0. 前言
- 参考资料:
- 《TensorFlow架构与设计:会话生命周期》,推荐
- 介绍TensorFlow Python API如何通过swig作为纽带调用 c api,最终调用c++核心代码,实现Session生命周期相关功能(创建、运行、关闭、销毁)。
- 《TensorFlow 拆包(一):Session.Run()》,推荐:详细介绍了
DirectSession
中run方法的流程。 - 《『深度长文』Tensorflow代码解析(五)》:描述了
DirectSession
与GrpcSession
调用run
方法的流程。 - 《[图解tensorflow源码] Session::Run()流程图 (单机版)》:图片描述
DirectSession
中run
方法的调用栈。 - 文章总体架构:
- 介绍Python/C++创建Session、调用Session中run方法的调用栈。
- 介绍
DirectSession
中run
方法。 - 感想:还有很多细节不清楚……总体上有个概念了……
1. Python Session相关API
- Python中使用的Session API就是
tf.Session
相关内容。
1.1. tf.Session().run() 调用栈
def run(self, fetches, feed_dict=None, options=None, run_metadata=None)
- 文件:
session.py
。 - 类:
BaseSession
。 - 概述:一般 TensorFlow Python API 调用session的入口。
def _run(self, handle, fetches, feed_dict, options, run_metadata)
- 文件:
session.py
。 - 类:
BaseSession
。 - 概述:确保当前session的状态,处理feed_dict(判断是否合法等)。
def _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
- 文件:
session.py
。 - 类:
BaseSession
。 - 概述:根据不同条件调用不同的 c api。
def _do_call(self, fn, *args)
- 文件:
session.py
。 - 类:
BaseSession
。 - 概述:用于异常处理,并调用定义在
_do_run
中的函数_run_fn
。 def _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
- 文件:
session.py
。 - 类:
BaseSession
。 - 概述:直接调用底层 c api。
- 关于如何将Python中的
tf.Graph
转换到c api中:本函数还会调用c api中的TF_ExtendGraph
。 - 一般情况下调用c api中的
TF_Run
方法。 void TF_Run(...)
- 文件:
c_api.c
。 - 概述: c api语言入口函数,将需要的数据保存到
vector
中,等待TF_Run_Helper
调用。 static void TF_Run_Helper(...)
- 文件:
c_api.c
。 - 概述:这里终于调用了
Session
的run
方法,之后就可以参考DirectSession
代码啦。
1.2. tf.Session 初始化调用栈
BaseSession
的初始化方法def __init__(self, target='', graph=None, config=None)
- 文件:
session.py
。 - 概述:调用 c api 底层
TF_NewDeprecatedSession
。 TF_CAPI_EXPORT extern TF_DeprecatedSession* TF_NewDeprecatedSession(...)
- 文件:
c_api.c
。 - 概述:一层封装,没什么花头。
Status NewSession(const SessionOptions& options, Session** out_session)
- 文件:
session.h
- 概述:调用
SessionFactory
创建Session,重点在于获取 session factory。 static Status GetFactory(const SessionOptions& options, SessionFactory** out_factory)
- 文件:
session_factory.h
- 概述:获取全局静态变量
SessionFactory
(单例模式)。 - 有两类
SessionFactory
,分别创建单机模式下的DirectSession
和集群模式下的GrpcSession
,根据BaseSession
构造函数中的target
来获取不同类型的Session。
2. C++ Session相关API
- C++中使用的Session API是
ClientSession
相关内容。
2.1. ClientSession.run() 调用栈
Status Run(...) const
- 文件:
client_session.h
- 概述:判断feed数据合法性,获取所有输入、输出tensor的名称,之后调用
ClientSession::Impl
。 ClientSession::Impl().session_.run()
方法- 文件:
client_session.cc
- 概述:本质就是调用了
Session
方法的run
方法。 - 感慨:果然C++ Client就比Python容易理解多了……
2.2. 其他内容
- 如何构建计算图:
- 在创建
ClientSession
对象时,需要调用Scope::NewRootScope
创建root scope。 - 在调用
Scope::NewRootScope
时会创建计算图对象。 - 如何创建
Session
对象 - 在
ClientSession
中会调用session.h
中的Status NewSession(...)
。 - 这一步以及之后创建Session的步骤与调用Python API中一致,所以不详细说明了。
3. DirectSession 概览
3.1. 基类Session
DirectSession
继承自Session
Session
的主要功能就是创建Graph对象(Create
方法),向Graph对象中添加节点(Extend
方法),运行计算图(Run
方法)。
class Session {
public:
Session();
virtual ~Session();
// 创建Session所需Graph对象 virtual Status Create(const GraphDef& graph) = 0;
// 向计算图中添加Ops // 要求输入参数graph中要添加的op不能在现有Graph对象中存在 virtual Status Extend(const GraphDef& graph) = 0;
// 运行计算图,获取 output_tensor_names 中各tensor的数值,保存到 outputs 中 virtual Status Run(const std::vector<std::pair<string, Tensor> >& inputs,
const std::vector<string>& output_tensor_names,
const std::vector<string>& target_node_names,
std::vector<Tensor>* outputs) = 0;
// 添加 `RunOptions` 的Create, Extend, Run方法 virtual Status Create(const RunOptions& run_options, const GraphDef& graph) {
return errors::Unimplemented(
"Create(const RunOptions& run_options, const GraphDef& graph) is not "
"supported for this session.");
}
virtual Status Extend(const RunOptions& run_options, const GraphDef& graph) {
return errors::Unimplemented(
"Extend(const RunOptions& run_options, const GraphDef& graph) is not "
"supported for this session.");
}
virtual Status Close(const RunOptions& run_options) {
return errors::Unimplemented(
"Close(const RunOptions& run_options) is not supported for this "
"session.");
}
virtual Status Run(const RunOptions& run_options,
const std::vector<std::pair<string, Tensor> >& inputs,
const std::vector<string>& output_tensor_names,
const std::vector<string>& target_node_names,
std::vector<Tensor>* outputs, RunMetadata* run_metadata);
// 省略其他一些方法 };
3.2. DirectSession的run方法
- 流程:
- 一些准备工作:判断session是否关闭、计算图是否构建,累加计数器,获取输入tensor名称。
- 获取Executor(如果已经存在则直接获取,不存在则创建)。
- 构建
FunctionCallFrame
,用于管理Executor的输入与输出。 - 运行计算图。
- 获取运行结果。
- 重点:Executor的获取,运行计算图。
Status DirectSession::Run(const RunOptions& run_options,
const NamedTensorList& inputs,
const std::vector<string>& output_names,
const std::vector<string>& target_nodes,
std::vector<Tensor>* outputs,
RunMetadata* run_metadata) {
# 判断Session状态和计算图是否构建 TF_RETURN_IF_ERROR(CheckNotClosed());
TF_RETURN_IF_ERROR(CheckGraphCreated("Run()"));
# 计数器 direct_session_runs->GetCell()->IncrementBy(1);
// 获取所有输入tensor的名称 std::vector<string> input_tensor_names;
input_tensor_names.reserve(inputs.size());
for (const auto& it : inputs) {
input_tensor_names.push_back(it.first);
}
// 获取Executor // 如果已经存在则直接获取,不存在则创建 ExecutorsAndKeys* executors_and_keys;
RunStateArgs run_state_args(run_options.debug_options());
TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names,
target_nodes, &executors_and_keys,
&run_state_args));
// 构建 FunctionCallFrame,好像是用来处理 Executor 的输入与输出 FunctionCallFrame call_frame(executors_and_keys->input_types,
executors_and_keys->output_types);
gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
for (const auto& it : inputs) {
if (it.second.dtype() == DT_RESOURCE) {
Tensor tensor_from_handle;
TF_RETURN_IF_ERROR(
ResourceHandleToInputTensor(it.second, &tensor_from_handle));
feed_args[executors_and_keys->input_name_to_index[it.first]] =
tensor_from_handle;
} else {
feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second;
}
}
const Status s = call_frame.SetArgs(feed_args);
if (errors::IsInternal(s)) {
return errors::InvalidArgument(s.error_message());
} else if (!s.ok()) {
return s;
}
const int64 step_id = step_id_counter_.fetch_add(1);
if (LogMemory::IsEnabled()) {
LogMemory::RecordStep(step_id, run_state_args.handle);
}
// 具体运行run过程 TF_RETURN_IF_ERROR(RunInternal(step_id, run_options, &call_frame,
executors_and_keys, run_metadata));
// 获取计算图运行结果 if (outputs) {
std::vector<Tensor> sorted_outputs;
const Status s = call_frame.ConsumeRetvals(&sorted_outputs);
if (errors::IsInternal(s)) {
return errors::InvalidArgument(s.error_message());
} else if (!s.ok()) {
return s;
}
const bool unique_outputs =
output_names.size() == executors_and_keys->output_name_to_index.size();
// first_indices[i] = j implies that j is the smallest value for which // output_names[i] == output_names[j]. std::vector<int> first_indices;
if (!unique_outputs) {
first_indices.resize(output_names.size());
for (int i = 0; i < output_names.size(); ++i) {
for (int j = 0; j <= i; ++j) {
if (output_names[i] == output_names[j]) {
first_indices[i] = j;
break;
}
}
}
}
outputs->clear();
outputs->reserve(sorted_outputs.size());
for (int i = 0; i < output_names.size(); ++i) {
const string& output_name = output_names[i];
if (first_indices.empty() || first_indices[i] == i) {
outputs->emplace_back(
std::move(sorted_outputs[executors_and_keys
->output_name_to_index[output_name]]));
} else {
outputs->push_back((*outputs)[first_indices[i]]);
}
}
}
return Status::OK();
}
3.3. 构建Executor
DirectSession::GetOrCreateExecutors()
函数详解- 获取的 Executor 保存在
ExecutorsAndKeys
对象中。 DirectSession
中存在一个map,用于映射输入以及对应的ExecutorsAndKeys
对象,命名为executors_
。
Status DirectSession::GetOrCreateExecutors(
gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
gtl::ArraySlice<string> target_nodes, ExecutorsAndKeys** executors_and_keys,
RunStateArgs* run_state_args) {
int64 handle_name_counter_value = -1;
if (LogMemory::IsEnabled() || run_state_args->is_partial_run) {
handle_name_counter_value = handle_name_counter_.fetch_add(1);
}
string debug_tensor_watches_summary;
if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
debug_tensor_watches_summary = SummarizeDebugTensorWatches(
run_state_args->debug_options.debug_tensor_watch_opts());
}
// 获取 executors_ 中的 key const string key = strings::StrCat(
str_util::Join(inputs, ","), "->", str_util::Join(outputs, ","), "/",
str_util::Join(target_nodes, ","), "/", run_state_args->is_partial_run,
"/", debug_tensor_watches_summary);
if (handle_name_counter_value >= 0) {
run_state_args->handle =
strings::StrCat(key, ";", handle_name_counter_value);
}
// 查看是否已经存在对应的 `ExecutorsAndKeys` 对象 {
mutex_lock l(executor_lock_); // could use reader lock auto it = executors_.find(key);
if (it != executors_.end()) {
*executors_and_keys = it->second.get();
return Status::OK();
}
}
// 与上面的类似,只不过把各个key排序 // 在后面的源码中,executors_ 的 key 是 sorted_key // 快速查找存在的意义在于:如果输入参数少,那么排不排序结果是一样的 std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
std::sort(inputs_sorted.begin(), inputs_sorted.end());
std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
std::sort(outputs_sorted.begin(), outputs_sorted.end());
std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
std::sort(tn_sorted.begin(), tn_sorted.end());
const string sorted_key = strings::StrCat(
str_util::Join(inputs_sorted, ","), "->",
str_util::Join(outputs_sorted, ","), "/", str_util::Join(tn_sorted, ","),
"/", run_state_args->is_partial_run, "/", debug_tensor_watches_summary);
if (handle_name_counter_value >= 0) {
run_state_args->handle =
strings::StrCat(sorted_key, ";", handle_name_counter_value);
}
{
mutex_lock l(executor_lock_);
auto it = executors_.find(sorted_key);
if (it != executors_.end()) {
*executors_and_keys = it->second.get();
// Insert this under the original key. executors_.emplace(key, it->second);
return Status::OK();
}
}
// 没有找到对应的 ExecutorsAndKeys 对象,需要自己创建 CallableOptions callable_options;
for (const string& input : inputs_sorted) {
callable_options.add_feed(input);
}
for (const string& output : outputs_sorted) {
callable_options.add_fetch(output);
}
for (const string& target : tn_sorted) {
callable_options.add_target(target);
}
*callable_options.mutable_run_options()->mutable_debug_options() =
run_state_args->debug_options;
// 创建Executors的细节都在 CreateExecutors 函数中 // 结果都保存在 ek 和 func_info 中 std::unique_ptr<ExecutorsAndKeys> ek;
std::unique_ptr<FunctionInfo> func_info;
TF_RETURN_IF_ERROR(
CreateExecutors(callable_options, &ek, &func_info, run_state_args));
// 将创建好的 FunctionInfo 保存到 functions_ 中 mutex_lock l(executor_lock_);
functions_.push_back(std::move(func_info));
// 将创建好的 ExecutorsAndKeys 保存到 executors_ 中 auto insert_result = executors_.emplace(
sorted_key, std::shared_ptr<ExecutorsAndKeys>(std::move(ek)));
executors_.emplace(key, insert_result.first->second);
*executors_and_keys = insert_result.first->second.get();
return Status::OK();
}
CreateExecutors
函数- 概述:创建新的
ExecutorsAndKeys
和FunctionInfo
对象。 - 其他:
- 通过
CreateGraphs
创建运行时计算图。 - 遍历所有Graph对象,进行一系列操作(分配设备、通过
GraphOptimizer
优化等、通过NewLocalExecutor
创建ExecutorImpl
)。
3.4. 具体运行计算图
DirectSession::RunInternal
函数详解- 这里只关注主要功能(运行计算图),其他一些功能(cost model等)源码就省略了。
Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
CallFrameInterface* call_frame,
ExecutorsAndKeys* executors_and_keys,
RunMetadata* run_metadata) {
const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1);
std::unique_ptr<DebuggerStateInterface> debugger_state;
if (!run_options.debug_options().debug_tensor_watch_opts().empty()) {
TF_RETURN_IF_ERROR(
CreateDebuggerState(executors_and_keys->callable_options,
run_options.debug_options().global_step(), step_id,
executor_step_count, &debugger_state));
}
// 构建 RunState 用于标记运行状态 // 构建 IntraProcessRendezvous 用于本地Tensor管理 RunState run_state(step_id, &devices_);
run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
// 构建 ExecutorBarrier 用于协调多个 Executor 并行计算,保持 graph 一致性 const size_t num_executors = executors_and_keys->items.size();
ExecutorBarrier* barrier = new ExecutorBarrier(
num_executors, run_state.rendez, [&run_state](const Status& ret) {
{
mutex_lock l(run_state.mu_);
run_state.status.Update(ret);
}
run_state.executors_done.Notify();
});
// 构建 args Executor::Args args;
args.step_id = step_id;
args.call_frame = call_frame;
args.rendezvous = run_state.rendez;
args.collective_executor =
(run_state.collective_executor ? run_state.collective_executor->get()
: nullptr);
CancellationManager step_cancellation_manager;
args.cancellation_manager = &step_cancellation_manager;
args.session_state = &session_state_;
args.tensor_store = &run_state.tensor_store;
args.step_container = &run_state.step_container;
args.sync_on_finish = sync_on_finish_;
...
// 处理 `Session::Close()` const CancellationToken cancellation_token =
cancellation_manager_->get_cancellation_token();
const bool already_cancelled = !cancellation_manager_->RegisterCallback(
cancellation_token, [&step_cancellation_manager]() {
step_cancellation_manager.StartCancel();
});
if (already_cancelled) {
run_state.executors_done.Notify();
delete barrier;
return errors::Cancelled("Run call was cancelled");
}
// 通过线程池实际运行Executor thread::ThreadPool* pool =
thread_pools_[run_options.inter_op_thread_pool()].first;
Executor::Args::Runner default_runner = [this,
pool](Executor::Args::Closure c) {
SchedClosure(pool, std::move(c));
};
for (const auto& item : executors_and_keys->items) {
// item是 ExecutorImpl 对象 thread::ThreadPool* device_thread_pool =
item.device->tensorflow_device_thread_pool();
if (!device_thread_pool) {
args.runner = default_runner;
} else {
args.runner = [this, device_thread_pool](Executor::Args::Closure c) {
SchedClosure(device_thread_pool, std::move(c));
};
}
// 运行的重点就在这 item.executor->RunAsync(args, barrier->Get());
}
...
// 保存运行结果 if (!run_state.tensor_store.empty()) {
TF_RETURN_IF_ERROR(run_state.tensor_store.SaveTensors(
{executors_and_keys->callable_options.fetch().begin(),
executors_and_keys->callable_options.fetch().end()},
&session_state_));
}
...
return Status::OK();
}
- 关于运行
item.executor->RunAsync(args, barrier->Get())
# 1. ExecutorImpl::RunAsync 的实现 # 概述:每个ExecutorImpl 对应一个 ExecutorState 对象,用于追踪输入节点状态,满足条件就可以调度到线程池中 void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
(new ExecutorState(args, this))->RunAsync(std::move(done));
}
# 2. ExecutorState::RunAsync 的实现 # 概述:获取context map上下文,初始化队列,开启线程池 void ExecutorState::RunAsync(Executor::DoneCallback done) {
const Graph* graph = impl_->graph_.get();
TaggedNodeSeq ready;
// 获取 context map,即运行时上下文 Device* device = impl_->params_.device;
const Status fill_status =
device->FillContextMap(graph, &device_context_map_);
if (!fill_status.ok()) {
done(fill_status);
return;
}
// 初始化 ready 队列,即存放入度为0的node for (const Node* n : impl_->root_nodes_) {
DCHECK_EQ(n->in_edges().size(), 0);
ready.push_back(TaggedNode{n, root_frame_, 0, false});
}
if (ready.empty()) {
done(Status::OK());
} else {
num_outstanding_ops_ = ready.size();
root_frame_->iterations[0]->outstanding_ops = ready.size();
done_cb_ = std::move(done);
// 线程池入口 ScheduleReady(ready, nullptr);
}
}
# 3. ExecutorState::ScheduleReady 的实现 # 概述:将节点分为 expensive & inexpensive 节点,将inexpensive节点放入 inline_ready 中 # 本函数重点在于运行 runner_ // Schedule all the expensive nodes in 'ready', and put all the inexpensive // nodes in 'ready' into 'inline_ready'. void ExecutorState::ScheduleReady(const TaggedNodeSeq& ready,
TaggedNodeReadyQueue* inline_ready) {
if (ready.empty()) return;
int64 scheduled_usec = 0;
if (stats_collector_) {
scheduled_usec = nodestats::NowInUsec();
}
if (inline_ready == nullptr) {
// 运行所有 ready ops for (auto& tagged_node : ready) {
runner_([=]() { Process(tagged_node, scheduled_usec); });
}
return;
}
// 将节点分类,运行 expensive node const GraphView& gview = impl_->gview_;
const TaggedNode* curr_expensive_node = nullptr;
for (auto& tagged_node : ready) {
const NodeItem& item = *gview.node(tagged_node.node->id());
if (tagged_node.is_dead || !item.kernel_is_expensive) {
inline_ready->push_back(tagged_node);
} else {
if (curr_expensive_node) {
runner_(std::bind(&ExecutorState::Process, this, *curr_expensive_node,
scheduled_usec));
}
curr_expensive_node = &tagged_node;
}
}
if (curr_expensive_node) {
if (inline_ready->empty()) {
inline_ready->push_back(*curr_expensive_node);
} else {
runner_(std::bind(&ExecutorState::Process, this, *curr_expensive_node,
scheduled_usec));
}
}
}
# 4. runner_ # 概述:其本质是一个 std::function<void(Closure)> 对象,在 DirectSession 中的定义就是 SchedClosure 函数,调用了Eigen::ThreadPoolTempl void DirectSession::SchedClosure(thread::ThreadPool* pool, std::function<void()> c) {
pool->Schedule(std::move(c));
}
# 5. ExecutorState::Process 详解 # 概述:线程池中跑的内容,代码太长不贴了。 # 主要流程: # + 将当前节点添加到 inline_ready 队列中。 # + 循环从 inline_ready 队列获取节点并运行,运行完毕后执行 NodeDone(有可能会添加新节点到inline_ready队列) # + 当inline ready队列为空时,跳出循环。 # 其他重要内容: # + 运行节点通过 device 的 ComputeAsync 或 Compute 方法 # + 处理输出结果使用 ProcessOutputs 函数和 PropagateOutputs 函数 # + 计算结束后通过 NodeDone 来收尾
# 6. 其他函数总结 # + ProcessOutputs:处理output tensor。 # + PropagateOutputs:获取ready队列,即当前Node执行完成后,有其他哪些Node可以用于计算。 # + NodeDone:最终运行了ScheduleReady函数,继续运行其他Node