TensorFlow 源码大坑(2) Session

0. 前言

1. Python Session相关API

  • Python中使用的Session API就是tf.Session相关内容。

1.1. tf.Session().run() 调用栈

  • def run(self, fetches, feed_dict=None, options=None, run_metadata=None)
    • 文件:session.py
    • 类:BaseSession
    • 概述:一般 TensorFlow Python API 调用session的入口。
  • def _run(self, handle, fetches, feed_dict, options, run_metadata)
    • 文件:session.py
    • 类:BaseSession
    • 概述:确保当前session的状态,处理feed_dict(判断是否合法等)。
  • def _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
    • 文件:session.py
    • 类:BaseSession
    • 概述:根据不同条件调用不同的 c api。
  • def _do_call(self, fn, *args)
    • 文件:session.py
    • 类:BaseSession
    • 概述:用于异常处理,并调用定义在_do_run中的函数 _run_fn
  • def _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
    • 文件:session.py
    • 类:BaseSession
    • 概述:直接调用底层 c api。
    • 关于如何将Python中的 tf.Graph 转换到c api中:本函数还会调用c api中的 TF_ExtendGraph
    • 一般情况下调用c api中的 TF_Run 方法。
  • void TF_Run(...)
    • 文件:c_api.c
    • 概述: c api语言入口函数,将需要的数据保存到 vector 中,等待 TF_Run_Helper 调用。
  • static void TF_Run_Helper(...)
    • 文件:c_api.c
    • 概述:这里终于调用了 Sessionrun方法,之后就可以参考 DirectSession 代码啦。

1.2. tf.Session 初始化调用栈

  • BaseSession的初始化方法 def __init__(self, target='', graph=None, config=None)
    • 文件:session.py
    • 概述:调用 c api 底层 TF_NewDeprecatedSession
  • TF_CAPI_EXPORT extern TF_DeprecatedSession* TF_NewDeprecatedSession(...)
    • 文件:c_api.c
    • 概述:一层封装,没什么花头。
  • Status NewSession(const SessionOptions& options, Session** out_session)
    • 文件:session.h
    • 概述:调用 SessionFactory 创建Session,重点在于获取 session factory。
  • static Status GetFactory(const SessionOptions& options, SessionFactory** out_factory)
    • 文件:session_factory.h
    • 概述:获取全局静态变量 SessionFactory(单例模式)。
    • 有两类SessionFactory,分别创建单机模式下的DirectSession和集群模式下的GrpcSession,根据BaseSession构造函数中的target来获取不同类型的Session。

2. C++ Session相关API

  • C++中使用的Session API是ClientSession相关内容。

2.1. ClientSession.run() 调用栈

  • Status Run(...) const
    • 文件:client_session.h
    • 概述:判断feed数据合法性,获取所有输入、输出tensor的名称,之后调用ClientSession::Impl
  • ClientSession::Impl().session_.run()方法
    • 文件:client_session.cc
    • 概述:本质就是调用了Session方法的run方法。
    • 感慨:果然C++ Client就比Python容易理解多了……

2.2. 其他内容

  • 如何构建计算图:
    • 在创建ClientSession对象时,需要调用Scope::NewRootScope创建root scope。
    • 在调用Scope::NewRootScope时会创建计算图对象。
  • 如何创建Session对象
    • ClientSession中会调用session.h中的Status NewSession(...)
    • 这一步以及之后创建Session的步骤与调用Python API中一致,所以不详细说明了。

3. DirectSession 概览

3.1. 基类Session

  • DirectSession继承自Session
    • Session的主要功能就是创建Graph对象(Create方法),向Graph对象中添加节点(Extend方法),运行计算图(Run方法)。
class Session {
 public:
  Session();
  virtual ~Session();

  // 创建Session所需Graph对象   virtual Status Create(const GraphDef& graph) = 0;

  // 向计算图中添加Ops   // 要求输入参数graph中要添加的op不能在现有Graph对象中存在   virtual Status Extend(const GraphDef& graph) = 0;

  // 运行计算图,获取 output_tensor_names 中各tensor的数值,保存到 outputs 中   virtual Status Run(const std::vector<std::pair<string, Tensor> >& inputs,
                     const std::vector<string>& output_tensor_names,
                     const std::vector<string>& target_node_names,
                     std::vector<Tensor>* outputs) = 0;

  // 添加 `RunOptions` 的Create, Extend, Run方法   virtual Status Create(const RunOptions& run_options, const GraphDef& graph) {
    return errors::Unimplemented(
        "Create(const RunOptions& run_options, const GraphDef& graph) is not "
        "supported for this session.");
  }
  virtual Status Extend(const RunOptions& run_options, const GraphDef& graph) {
    return errors::Unimplemented(
        "Extend(const RunOptions& run_options, const GraphDef& graph) is not "
        "supported for this session.");
  }
  virtual Status Close(const RunOptions& run_options) {
    return errors::Unimplemented(
        "Close(const RunOptions& run_options) is not supported for this "
        "session.");
  }
  virtual Status Run(const RunOptions& run_options,
                     const std::vector<std::pair<string, Tensor> >& inputs,
                     const std::vector<string>& output_tensor_names,
                     const std::vector<string>& target_node_names,
                     std::vector<Tensor>* outputs, RunMetadata* run_metadata);

  // 省略其他一些方法 };

3.2. DirectSession的run方法

  • 流程:
    • 一些准备工作:判断session是否关闭、计算图是否构建,累加计数器,获取输入tensor名称。
    • 获取Executor(如果已经存在则直接获取,不存在则创建)。
    • 构建FunctionCallFrame,用于管理Executor的输入与输出。
    • 运行计算图。
    • 获取运行结果。
  • 重点:Executor的获取,运行计算图。
Status DirectSession::Run(const RunOptions& run_options,
                          const NamedTensorList& inputs,
                          const std::vector<string>& output_names,
                          const std::vector<string>& target_nodes,
                          std::vector<Tensor>* outputs,
                          RunMetadata* run_metadata) {
  # 判断Session状态和计算图是否构建   TF_RETURN_IF_ERROR(CheckNotClosed());
  TF_RETURN_IF_ERROR(CheckGraphCreated("Run()"));

  # 计数器   direct_session_runs->GetCell()->IncrementBy(1);

  // 获取所有输入tensor的名称   std::vector<string> input_tensor_names;
  input_tensor_names.reserve(inputs.size());
  for (const auto& it : inputs) {
    input_tensor_names.push_back(it.first);
  }

  // 获取Executor   // 如果已经存在则直接获取,不存在则创建   ExecutorsAndKeys* executors_and_keys;
  RunStateArgs run_state_args(run_options.debug_options());
  TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names,
                                          target_nodes, &executors_and_keys,
                                          &run_state_args));

  // 构建 FunctionCallFrame,好像是用来处理 Executor 的输入与输出   FunctionCallFrame call_frame(executors_and_keys->input_types,
                               executors_and_keys->output_types);
  gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
  for (const auto& it : inputs) {
    if (it.second.dtype() == DT_RESOURCE) {
      Tensor tensor_from_handle;
      TF_RETURN_IF_ERROR(
          ResourceHandleToInputTensor(it.second, &tensor_from_handle));
      feed_args[executors_and_keys->input_name_to_index[it.first]] =
          tensor_from_handle;
    } else {
      feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second;
    }
  }
  const Status s = call_frame.SetArgs(feed_args);
  if (errors::IsInternal(s)) {
    return errors::InvalidArgument(s.error_message());
  } else if (!s.ok()) {
    return s;
  }

  const int64 step_id = step_id_counter_.fetch_add(1);

  if (LogMemory::IsEnabled()) {
    LogMemory::RecordStep(step_id, run_state_args.handle);
  }

  // 具体运行run过程   TF_RETURN_IF_ERROR(RunInternal(step_id, run_options, &call_frame,
                                 executors_and_keys, run_metadata));

  // 获取计算图运行结果   if (outputs) {
    std::vector<Tensor> sorted_outputs;
    const Status s = call_frame.ConsumeRetvals(&sorted_outputs);
    if (errors::IsInternal(s)) {
      return errors::InvalidArgument(s.error_message());
    } else if (!s.ok()) {
      return s;
    }
    const bool unique_outputs =
        output_names.size() == executors_and_keys->output_name_to_index.size();
    // first_indices[i] = j implies that j is the smallest value for which     // output_names[i] == output_names[j].     std::vector<int> first_indices;
    if (!unique_outputs) {
      first_indices.resize(output_names.size());
      for (int i = 0; i < output_names.size(); ++i) {
        for (int j = 0; j <= i; ++j) {
          if (output_names[i] == output_names[j]) {
            first_indices[i] = j;
            break;
          }
        }
      }
    }
    outputs->clear();
    outputs->reserve(sorted_outputs.size());
    for (int i = 0; i < output_names.size(); ++i) {
      const string& output_name = output_names[i];
      if (first_indices.empty() || first_indices[i] == i) {
        outputs->emplace_back(
            std::move(sorted_outputs[executors_and_keys
                                         ->output_name_to_index[output_name]]));
      } else {
        outputs->push_back((*outputs)[first_indices[i]]);
      }
    }
  }

  return Status::OK();
}

3.3. 构建Executor

  • DirectSession::GetOrCreateExecutors()函数详解
    • 获取的 Executor 保存在 ExecutorsAndKeys 对象中。
    • DirectSession中存在一个map,用于映射输入以及对应的 ExecutorsAndKeys 对象,命名为executors_
Status DirectSession::GetOrCreateExecutors(
    gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
    gtl::ArraySlice<string> target_nodes, ExecutorsAndKeys** executors_and_keys,
    RunStateArgs* run_state_args) {
  int64 handle_name_counter_value = -1;
  if (LogMemory::IsEnabled() || run_state_args->is_partial_run) {
    handle_name_counter_value = handle_name_counter_.fetch_add(1);
  }

  string debug_tensor_watches_summary;
  if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
    debug_tensor_watches_summary = SummarizeDebugTensorWatches(
        run_state_args->debug_options.debug_tensor_watch_opts());
  }

  // 获取 executors_ 中的 key   const string key = strings::StrCat(
      str_util::Join(inputs, ","), "->", str_util::Join(outputs, ","), "/",
      str_util::Join(target_nodes, ","), "/", run_state_args->is_partial_run,
      "/", debug_tensor_watches_summary);
  if (handle_name_counter_value >= 0) {
    run_state_args->handle =
        strings::StrCat(key, ";", handle_name_counter_value);
  }

  // 查看是否已经存在对应的 `ExecutorsAndKeys` 对象   {
    mutex_lock l(executor_lock_);  // could use reader lock     auto it = executors_.find(key);
    if (it != executors_.end()) {
      *executors_and_keys = it->second.get();
      return Status::OK();
    }
  }

  // 与上面的类似,只不过把各个key排序   // 在后面的源码中,executors_ 的 key 是 sorted_key   // 快速查找存在的意义在于:如果输入参数少,那么排不排序结果是一样的   std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
  std::sort(inputs_sorted.begin(), inputs_sorted.end());
  std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
  std::sort(outputs_sorted.begin(), outputs_sorted.end());
  std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
  std::sort(tn_sorted.begin(), tn_sorted.end());
  const string sorted_key = strings::StrCat(
      str_util::Join(inputs_sorted, ","), "->",
      str_util::Join(outputs_sorted, ","), "/", str_util::Join(tn_sorted, ","),
      "/", run_state_args->is_partial_run, "/", debug_tensor_watches_summary);
  if (handle_name_counter_value >= 0) {
    run_state_args->handle =
        strings::StrCat(sorted_key, ";", handle_name_counter_value);
  }
  {
    mutex_lock l(executor_lock_);
    auto it = executors_.find(sorted_key);
    if (it != executors_.end()) {
      *executors_and_keys = it->second.get();
      // Insert this under the original key.       executors_.emplace(key, it->second);
      return Status::OK();
    }
  }

  // 没有找到对应的 ExecutorsAndKeys 对象,需要自己创建   CallableOptions callable_options;
  for (const string& input : inputs_sorted) {
    callable_options.add_feed(input);
  }
  for (const string& output : outputs_sorted) {
    callable_options.add_fetch(output);
  }
  for (const string& target : tn_sorted) {
    callable_options.add_target(target);
  }
  *callable_options.mutable_run_options()->mutable_debug_options() =
      run_state_args->debug_options;

  // 创建Executors的细节都在 CreateExecutors 函数中   // 结果都保存在 ek 和 func_info 中   std::unique_ptr<ExecutorsAndKeys> ek;
  std::unique_ptr<FunctionInfo> func_info;
  TF_RETURN_IF_ERROR(
      CreateExecutors(callable_options, &ek, &func_info, run_state_args));

  // 将创建好的 FunctionInfo 保存到 functions_ 中   mutex_lock l(executor_lock_);
  functions_.push_back(std::move(func_info));

  // 将创建好的 ExecutorsAndKeys 保存到 executors_ 中   auto insert_result = executors_.emplace(
      sorted_key, std::shared_ptr<ExecutorsAndKeys>(std::move(ek)));
  executors_.emplace(key, insert_result.first->second);
  *executors_and_keys = insert_result.first->second.get();

  return Status::OK();
}
  • CreateExecutors函数
    • 概述:创建新的 ExecutorsAndKeysFunctionInfo 对象。
    • 其他:
      • 通过 CreateGraphs 创建运行时计算图。
      • 遍历所有Graph对象,进行一系列操作(分配设备、通过GraphOptimizer优化等、通过NewLocalExecutor创建ExecutorImpl)。

3.4. 具体运行计算图

  • DirectSession::RunInternal函数详解
    • 这里只关注主要功能(运行计算图),其他一些功能(cost model等)源码就省略了。
Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
                                  CallFrameInterface* call_frame,
                                  ExecutorsAndKeys* executors_and_keys,
                                  RunMetadata* run_metadata) {
  const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1);

  std::unique_ptr<DebuggerStateInterface> debugger_state;
  if (!run_options.debug_options().debug_tensor_watch_opts().empty()) {
    TF_RETURN_IF_ERROR(
        CreateDebuggerState(executors_and_keys->callable_options,
                            run_options.debug_options().global_step(), step_id,
                            executor_step_count, &debugger_state));
  }

  // 构建 RunState 用于标记运行状态   // 构建 IntraProcessRendezvous 用于本地Tensor管理   RunState run_state(step_id, &devices_);
  run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());

  // 构建 ExecutorBarrier 用于协调多个 Executor 并行计算,保持 graph 一致性   const size_t num_executors = executors_and_keys->items.size();
  ExecutorBarrier* barrier = new ExecutorBarrier(
      num_executors, run_state.rendez, [&run_state](const Status& ret) {
        {
          mutex_lock l(run_state.mu_);
          run_state.status.Update(ret);
        }
        run_state.executors_done.Notify();
      });

  // 构建 args   Executor::Args args;
  args.step_id = step_id;
  args.call_frame = call_frame;
  args.rendezvous = run_state.rendez;
  args.collective_executor =
      (run_state.collective_executor ? run_state.collective_executor->get()
                                     : nullptr);
  CancellationManager step_cancellation_manager;
  args.cancellation_manager = &step_cancellation_manager;
  args.session_state = &session_state_;
  args.tensor_store = &run_state.tensor_store;
  args.step_container = &run_state.step_container;
  args.sync_on_finish = sync_on_finish_;

  ...

  // 处理 `Session::Close()`   const CancellationToken cancellation_token =
      cancellation_manager_->get_cancellation_token();
  const bool already_cancelled = !cancellation_manager_->RegisterCallback(
      cancellation_token, [&step_cancellation_manager]() {
        step_cancellation_manager.StartCancel();
      });
  if (already_cancelled) {
    run_state.executors_done.Notify();
    delete barrier;
    return errors::Cancelled("Run call was cancelled");
  }

  // 通过线程池实际运行Executor   thread::ThreadPool* pool =
      thread_pools_[run_options.inter_op_thread_pool()].first;
  Executor::Args::Runner default_runner = [this,
                                           pool](Executor::Args::Closure c) {
    SchedClosure(pool, std::move(c));
  };
  for (const auto& item : executors_and_keys->items) {
    // item是 ExecutorImpl 对象     thread::ThreadPool* device_thread_pool =
        item.device->tensorflow_device_thread_pool();
    if (!device_thread_pool) {
      args.runner = default_runner;
    } else {
      args.runner = [this, device_thread_pool](Executor::Args::Closure c) {
        SchedClosure(device_thread_pool, std::move(c));
      };
    }
    // 运行的重点就在这     item.executor->RunAsync(args, barrier->Get());
  }

  ...

  // 保存运行结果   if (!run_state.tensor_store.empty()) {
    TF_RETURN_IF_ERROR(run_state.tensor_store.SaveTensors(
        {executors_and_keys->callable_options.fetch().begin(),
         executors_and_keys->callable_options.fetch().end()},
        &session_state_));
  }

  ...

  return Status::OK();
}
  • 关于运行 item.executor->RunAsync(args, barrier->Get())
# 1. ExecutorImpl::RunAsync 的实现 # 概述:每个ExecutorImpl 对应一个 ExecutorState 对象,用于追踪输入节点状态,满足条件就可以调度到线程池中 void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
  (new ExecutorState(args, this))->RunAsync(std::move(done));
}

# 2. ExecutorState::RunAsync 的实现 # 概述:获取context map上下文,初始化队列,开启线程池 void ExecutorState::RunAsync(Executor::DoneCallback done) {
  const Graph* graph = impl_->graph_.get();
  TaggedNodeSeq ready;

  // 获取 context map,即运行时上下文   Device* device = impl_->params_.device;
  const Status fill_status =
      device->FillContextMap(graph, &device_context_map_);
  if (!fill_status.ok()) {
    done(fill_status);
    return;
  }

  // 初始化 ready 队列,即存放入度为0的node   for (const Node* n : impl_->root_nodes_) {
    DCHECK_EQ(n->in_edges().size(), 0);
    ready.push_back(TaggedNode{n, root_frame_, 0, false});
  }
  if (ready.empty()) {
    done(Status::OK());
  } else {
    num_outstanding_ops_ = ready.size();
    root_frame_->iterations[0]->outstanding_ops = ready.size();
    done_cb_ = std::move(done);
    // 线程池入口     ScheduleReady(ready, nullptr);
  }
}

# 3. ExecutorState::ScheduleReady 的实现 # 概述:将节点分为 expensive & inexpensive 节点,将inexpensive节点放入 inline_ready 中 # 本函数重点在于运行 runner_ // Schedule all the expensive nodes in 'ready', and put all the inexpensive // nodes in 'ready' into 'inline_ready'. void ExecutorState::ScheduleReady(const TaggedNodeSeq& ready,
                                  TaggedNodeReadyQueue* inline_ready) {
  if (ready.empty()) return;

  int64 scheduled_usec = 0;
  if (stats_collector_) {
    scheduled_usec = nodestats::NowInUsec();
  }

  if (inline_ready == nullptr) {
    // 运行所有 ready ops     for (auto& tagged_node : ready) {
      runner_([=]() { Process(tagged_node, scheduled_usec); });
    }
    return;
  }

  // 将节点分类,运行 expensive node   const GraphView& gview = impl_->gview_;
  const TaggedNode* curr_expensive_node = nullptr;
  for (auto& tagged_node : ready) {
    const NodeItem& item = *gview.node(tagged_node.node->id());
    if (tagged_node.is_dead || !item.kernel_is_expensive) {
      inline_ready->push_back(tagged_node);
    } else {
      if (curr_expensive_node) {
        runner_(std::bind(&ExecutorState::Process, this, *curr_expensive_node,
                          scheduled_usec));
      }
      curr_expensive_node = &tagged_node;
    }
  }

  if (curr_expensive_node) {
    if (inline_ready->empty()) {
      inline_ready->push_back(*curr_expensive_node);
    } else {
      runner_(std::bind(&ExecutorState::Process, this, *curr_expensive_node,
                        scheduled_usec));
    }
  }
}

# 4. runner_ # 概述:其本质是一个 std::function<void(Closure)> 对象,在 DirectSession 中的定义就是 SchedClosure 函数,调用了Eigen::ThreadPoolTempl void DirectSession::SchedClosure(thread::ThreadPool* pool, std::function<void()> c) {
  pool->Schedule(std::move(c));
}

# 5. ExecutorState::Process 详解 # 概述:线程池中跑的内容,代码太长不贴了。 # 主要流程: # + 将当前节点添加到 inline_ready 队列中。 # + 循环从 inline_ready 队列获取节点并运行,运行完毕后执行 NodeDone(有可能会添加新节点到inline_ready队列) # + 当inline ready队列为空时,跳出循环。 # 其他重要内容: # + 运行节点通过 device 的 ComputeAsync 或 Compute 方法 # + 处理输出结果使用 ProcessOutputs 函数和 PropagateOutputs 函数 # + 计算结束后通过 NodeDone 来收尾 
# 6. 其他函数总结 # + ProcessOutputs:处理output tensor。 # + PropagateOutputs:获取ready队列,即当前Node执行完成后,有其他哪些Node可以用于计算。 # + NodeDone:最终运行了ScheduleReady函数,继续运行其他Node 

    原文作者:清欢守护者
    原文地址: https://zhuanlan.zhihu.com/p/42187985
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞