Back from winter break, I remembered the series I had promised to continue, but nothing struck me as a particularly good topic, and with internship recruiting just around the corner, I decided to start by expanding two answers I had previously written on Zhihu.
This post walks through how a dynamically linked C++ custom Op gets registered and how it is then invoked from Python; it is also something of a note to myself, since I have never been fond of high-level APIs. The post is roughly divided into three parts: registering Ops, registering Kernels, and invoking Ops.
- The Op registration process
Let's start with OpRegistrationData. Objects of this class are allocated by the global registry, and their job is simply to hold an OpDef and an OpShapeInferenceFn. The former carries all of an Op's concrete metadata and is filled in by OpDefBuilder at the final argument-parsing step (the member function Finalize); the latter is passed in through SetShapeFn (forwarded by the wrapper). "Registration" just means associating the op name with its OpRegistrationData, concretely by putting the pair into a hash map:
mutable std::unordered_map<string, const OpRegistrationData*> registry_;
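For reference, OpRegistrationData itself is little more than those two fields (simplified here from op_def_builder.h):
struct OpRegistrationData {
  OpDef op_def;                           // filled in by OpDefBuilder::Finalize
  OpShapeInferenceFn shape_inference_fn;  // set via SetShapeFn
};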
Next, the OpDefBuilder class. OpDefBuilder collects the Op's attribute and argument definitions (the ones you specify in REGISTER_OP, see below), parses them all at the end (note: it only parses, without guaranteeing validity), and hands the result, including the ShapeFn, over to OpRegistrationData.
We register our own ops through the following macro:
REGISTER_OP("YourOp")
.Attr("T: {float}")
.Input("logits: T")
.Input("Labels: T")
.Output("loss: T")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
c->set_output(0, c->MakeShape({1}));
return Status::OK();
});
All the details are in the REGISTER_OP macro, which simplifies to:
static OpDefBuilderReceiver register_op = OpDefBuilderWrapper("YourOp")
OpDefBuilderWrapper holds an OpDefBuilder member variable; everything you chain onto the REGISTER_OP macro, including the op's name, is ultimately forwarded to that single OpDefBuilder. OpDefBuilderReceiver then takes the BuilderWrapper and hands it to the Registry responsible for managing all Op registrations, which exposes a Register method for ops to register themselves. Quoting the official example for illustration:
//Example registration:
OpRegistry::Global()->Register(
[](OpRegistrationData* op_reg_data)->Status {
// Populate *op_reg_data here.
return Status::OK();
});
(A quick aside: OpRegistry::Global() is a plain singleton returning the one global OpRegistry instance; credit to the new standard here for guaranteeing thread-safe initialization of function-local statics.)
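The implementation in op_registry.cc is exactly what you would expect:
OpRegistry* OpRegistry::Global() {
  // A function-local static: initialization is thread-safe since C++11.
  static OpRegistry* global_op_registry = new OpRegistry;
  return global_op_registry;
}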
Inside that lambda you can do whatever you want, for example exactly what OpDefBuilderReceiver does: take the BuilderWrapper, strip the wrapper to get the OpDefBuilder out, and, see the op_reg_data parameter in the lambda above? That is precisely where the parsed arguments and the shape fn mentioned earlier get written into OpRegistrationData. Finally Register takes the op's name and the OpRegistrationData, puts the pair into the hash map to complete the registration, and performs some validity checks along the way. Like this:
OpRegistry::Global()->Register(
[wrapper](OpRegistrationData* op_reg_data) -> Status {
return wrapper.builder().Finalize(op_reg_data);
});
In fact, at this point the real registration has not necessarily happened yet; details below.
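A simplified sketch of Register from op_registry.cc shows why: until the registry is marked initialized (which happens on the first lookup or an explicit ProcessRegistrations call), the factory is merely queued:
// Simplified from op_registry.cc.
void OpRegistry::Register(const OpRegistrationDataFactory& op_data_factory) {
  mutex_lock lock(mu_);
  if (initialized_) {
    TF_QCHECK_OK(RegisterAlreadyLocked(op_data_factory));  // register right away
  } else {
    deferred_.push_back(op_data_factory);  // replayed later by ProcessRegistrations()
  }
}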
- The Kernel registration process
Kernel registration is similar to Op registration: there is a wrapper called KernelDefBuilder that holds a pointer to a KernelDef and is used to set its various properties; a final call to Build returns that pointer and resets the builder. Kernels are registered mainly through the following macro:
REGISTER_KERNEL_BUILDER(
    Name("PsRoiAlignGrad").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    PSROIAlignGradOp<GPUDevice, float>);
Name is a subclass of KernelDefBuilder: Name("KernelName") first creates a KernelDefBuilder and sets the kernel name. Each such setter returns the builder itself, which is what enables the chained calls; Device is set next, and finally the value float is added to the attribute T.
class Name : public KernelDefBuilder {
public:
// For system kernels, we ignore selective registration and
// unconditionally register the kernel.
explicit Name(const char* op) : KernelDefBuilder(op) {}
};
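The chaining works because every setter returns the builder itself; a simplified sketch of kernel_def_builder.h:
// Simplified from tensorflow/core/framework/kernel_def_builder.h.
class KernelDefBuilder {
 public:
  explicit KernelDefBuilder(const char* op_name);  // allocates kernel_def_
  KernelDefBuilder& Device(const char* device_type);
  template <typename T>
  KernelDefBuilder& TypeConstraint(const char* attr_name);
  const KernelDef* Build();  // hands over kernel_def_ and resets it to nullptr
 private:
  KernelDef* kernel_def_;
};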
Inside REGISTER_KERNEL_BUILDER it is all macro tricks; in essence it creates a uniquely named global static variable of type OpKernelRegistrar. Take a look if you are curious:
#define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__)
#define REGISTER_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \
REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)
#define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...) \
constexpr bool should_register_##ctr##__flag = \
SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__); \
static ::tensorflow::kernel_factory::OpKernelRegistrar \
registrar__body__##ctr##__object( \
should_register_##ctr##__flag \
? ::tensorflow::register_kernel::kernel_builder.Build() \
: nullptr, \
#__VA_ARGS__, \
[](::tensorflow::OpKernelConstruction* context) \
-> ::tensorflow::OpKernel* { \
return new __VA_ARGS__(context); \
});
Constructing the static OpKernelRegistrar takes three arguments, as shown below: the KernelDef, the name of the class defining the Kernel, and a function that creates the kernel object. As we will see shortly, all three get wrapped into the KernelRegistration struct, which serves as the value in the kernel registry. So the macro first calls KernelDefBuilder's Build function to obtain the KernelDef; then it takes the name of the C++ class implementing the Kernel (a class derived from OpKernel); finally it wraps a factory function that receives an OpKernelConstruction*, creates the corresponding Kernel object, and returns a pointer to it.
class OpKernelRegistrar {
public:
typedef OpKernel* (*Factory)(OpKernelConstruction*);
OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
Factory factory) {
if (kernel_def != nullptr) {
InitInternal(kernel_def, kernel_class_name, factory);
}
}
};
Here are the details of InitInternal:
void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
StringPiece kernel_class_name,
Factory factory) {
// See comments in register_kernel::Name in header for info on _no_register.
if (kernel_def->op() != "_no_register") {
const string key =
Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
kernel_def->label());
GlobalKernelRegistryTyped()->insert(std::make_pair(
key, KernelRegistration(*kernel_def, kernel_class_name, factory)));
}
delete kernel_def;
}
As you can see, OpKernelRegistrar takes the KernelDef and the kernel factory passed in, generates an appropriate key according to a fixed rule, and inserts the entry into the process-global kernel registry. The registry is of course a map, but note that it is a multimap, so one key can map to multiple kernel registrations:
typedef std::unordered_multimap<string, KernelRegistration> KernelRegistry;
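The key itself is just the op name, device type, and label concatenated (from op_kernel.cc):
// From op_kernel.cc: the registry key is "op:device:label".
static string Key(StringPiece op_type, const DeviceType& device_type,
                  StringPiece label) {
  return strings::StrCat(op_type, ":", DeviceTypeString(device_type), ":",
                         label);
}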
- Creating and invoking the OpKernel
If you recall, there is also the global OpRegistry from earlier: given the Op name in a NodeDef we can look up the Op's information, and combined with the device type we can find the corresponding Kernel's information as well. The NodeDef itself is created on the Python side before the Operation is created, see create_op; we will come back to where that function gets called.
With a NodeDef and the current device type, an OpKernel can then be created at runtime, and every OpKernel created this way has its lifetime managed automatically. Each Device holds an OpSegment object; the OpSegment manages the kernels used within a session and decides whether to create a new OpKernel or reuse an existing one. Concretely there are two nested hash maps: the first maps a session handle to a KernelMap, and the KernelMap is then searched by node name for an existing OpKernel; on a miss, a create_fn function is invoked to create one.
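A simplified sketch of those two nested maps, from op_segment.h:
// Simplified from tensorflow/core/framework/op_segment.h.
class OpSegment {
 public:
  typedef std::function<Status(OpKernel**)> CreateKernelFn;
  // Looks up (session_handle, node_name); on a miss, calls create_fn
  // and caches the resulting kernel.
  Status FindOrCreate(const string& session_handle, const string& node_name,
                      OpKernel** kernel, CreateKernelFn create_fn);

 private:
  typedef std::unordered_map<string, OpKernel*> KernelMap;  // node name -> kernel
  struct Item {
    KernelMap name_kernel;
  };
  typedef std::unordered_map<string, Item*> SessionMap;     // session handle -> Item
  SessionMap sessions_;
};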
So what sets all of this in motion? It all happens the first time a Session creates the Executors for a subgraph (more on Executors later): DirectSession::GetOrCreateExecutors. More directly, look at where an Executor is first created after a cache miss; the code snippet is as follows:
LocalExecutorParams params;
params.device = device;
params.function_library = lib;
auto opseg = device->op_segment();
params.create_kernel = [this, lib, opseg](const NodeDef& ndef, OpKernel** kernel) {
// We do not share the kernel via the OpSegment if the node is
// stateless, or a function.
// NOTE(mrry): We must not share function kernels (implemented
// using `CallOp`) between subgraphs, because `CallOp::handle_`
// is tied to a particular subgraph. Even if the function itself
// is stateful, the `CallOp` that invokes it is not.
if (!lib->IsStateful(ndef.op()) ||
lib->GetFunctionLibraryDefinition()->Find(ndef.op()) != nullptr) {
return lib->CreateKernel(ndef, kernel);
}
auto create_fn = [lib, &ndef](OpKernel** kernel) {
return lib->CreateKernel(ndef, kernel);
};
// Kernels created for subgraph nodes need to be cached. On
// cache miss, create_fn() is invoked to create a kernel based
// on the function library here + global op registry.
return opseg->FindOrCreate(session_handle_, ndef.name(), kernel,
create_fn);
};
You can see the OpSegment being fetched, create_fn being constructed, and FindOrCreate being called. The FunctionLibraryRuntime::CreateKernel that create_fn invokes internally is at FunctionLibraryRuntimeImpl::CreateKernel, and one level further down is CreateNonCachedKernel:
Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
const NodeDef& ndef, int graph_def_version,
OpKernel** kernel) {
const auto device_type = DeviceType(device->attributes().device_type());
auto allocator = device->GetAllocator(AllocatorAttributes());
return CreateOpKernel(device_type, device, allocator, flib, ndef,
graph_def_version, kernel);
}
There is the call to CreateOpKernel, which finally brings us back to where this all started, CreateOpKernel:
Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
Allocator* allocator, FunctionLibraryRuntime* flib,
const NodeDef& node_def, int graph_def_version,
OpKernel** kernel)
This core function mainly does the following: it takes the op name from node_def and looks it up in the OpRegistry, validating it against node_def's information, e.g. whether the interfaces match and whether node_def carries everything op_def requires; it then queries the KernelRegistry with the device_type and op name to obtain the KernelRegistration, i.e. the map value containing the three items mentioned earlier. Next it determines the input and output types and their memory types, and finally it creates an OpKernelConstruction object and passes it to the kernel's factory function, which brings us to the code the user wrote:
// Everything needed for OpKernel construction.
OpKernelConstruction context(
device_type, device, allocator, &node_def, op_def, flib, inputs,
input_memory_types, outputs, output_memory_types, graph_def_version, &s);
*kernel = (*registration->factory)(&context);
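That factory is exactly the lambda from REGISTER_KERNEL_BUILDER, so what it constructs is your own OpKernel subclass. A minimal, hypothetical skeleton (YourOp here just echoes the REGISTER_OP example from earlier):
// A minimal, hypothetical user-side kernel: the factory lambda calls
// `new YourOp(context)`, so construction happens here...
class YourOp : public OpKernel {
 public:
  explicit YourOp(OpKernelConstruction* context) : OpKernel(context) {
    // Attrs declared in REGISTER_OP can be read here via context->GetAttr(...).
  }
  // ...and this is what the executor will eventually invoke.
  void Compute(OpKernelContext* context) override {
    const Tensor& logits = context->input(0);
    Tensor* loss = nullptr;
    // Matches the MakeShape({1}) in the shape fn above.
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({1}), &loss));
    // The actual computation on `logits` goes here.
  }
};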
The kernel has been created, so when does it actually run? As noted above, OpKernels are created when the executor is first created, and in fact every Session::Run call is ultimately dispatched to the executor as well, which builds an OpKernelContext from the current runtime state and calls OpKernel::Compute:
// Synchronous computes.
OpKernelContext ctx(&params, item.num_outputs);
nodestats::SetOpStart(stats);
device->Compute(CHECK_NOTNULL(op_kernel), &ctx);
nodestats::SetOpEnd(stats);
For the device->Compute step, looking at the base-class implementation, Device::Compute, tells you essentially everything:
// Performs the actual compute function.
//
// Subclasses may override this function if they wish to perform
// some initialization before each compute.
virtual void Compute(OpKernel* op_kernel, OpKernelContext* context) {
op_kernel->Compute(context);
}
And there it is: the Compute method we wrote gets called right here. The story could end there, except we have been on the C++ side the whole time; how does the Python code come in?
- Op and Kernel registration: the sequel
From the expansion of REGISTER_KERNEL_BUILDER above, it is easy to see that once the dynamic library is loaded, the Kernels register themselves automatically, and Op registration works essentially the same way. The difference is that for dynamically linked Ops, before loading the library we set a flag to defer registration and install a Watcher, then trigger the registration manually. This is done mainly so that the Watcher can capture the OpDef extracted from each OpRegistrationData (the registry value) during registration, as you can see in LoadLibrary below. This step matters: the captured OpDefs form an OpList, which is serialized; the Python side parses these OpDefs back out and also calls into C++ to generate the corresponding ApiDefs from them, and from the two combined it dynamically generates the Python code defining the Op, which is then returned to the Python side and executed. Note that executing this code does not yet create the Op and add it to the Graph; it only defines the relevant functions. The path from Python's load_op_library all the way to the generated code is: load_op_library -> GetPythonWrappers -> GetPythonOps -> GetPythonOp -> GenPythonOp::Code(). The ApiDefs are built from the OpList in ApiDefMap::ApiDefMap(const OpList& op_list). If you are interested, the auto-generated code for an Op of mine is attached at the end of this post; in that generated code, apply_op is what adds the Op to the Graph, see apply_op, and at the very end of that function is the call to the Graph's create_op mentioned earlier.
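For reference, the Watcher hook is just a std::function (from op_registry.h) that gets called for every registration attempt with the resulting Status and the OpDef involved:
// From tensorflow/core/framework/op_registry.h.
typedef std::function<Status(const Status&, const OpDef&)> Watcher;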
Below is the LoadLibrary code for reference:
Status LoadLibrary(const char* library_filename, void** result,
const void** buf, size_t* len) {
static mutex mu;
static std::unordered_map<string, Library> loaded_libs;
Env* env = Env::Default();
Library library;
std::unordered_set<string> seen_op_names;
{
mutex_lock lock(mu);
if (loaded_libs.find(library_filename) != loaded_libs.end()) {
library = loaded_libs[library_filename];
} else {
Status s = OpRegistry::Global()->ProcessRegistrations();
if (!s.ok()) {
return s;
}
TF_RETURN_IF_ERROR(OpRegistry::Global()->SetWatcher(
[&library, &seen_op_names](const Status& s,
const OpDef& opdef) -> Status {
if (errors::IsAlreadyExists(s)) {
if (seen_op_names.find(opdef.name()) == seen_op_names.end()) {
// Over writing a registration of an op not in this custom op
// library. Treat this as not an error.
return Status::OK();
}
}
if (s.ok()) {
*library.op_list.add_op() = opdef;
seen_op_names.insert(opdef.name());
}
return s;
}));
OpRegistry::Global()->DeferRegistrations();
s = env->LoadLibrary(library_filename, &library.handle);
if (s.ok()) {
s = OpRegistry::Global()->ProcessRegistrations();
}
if (!s.ok()) {
OpRegistry::Global()->ClearDeferredRegistrations();
TF_RETURN_IF_ERROR(OpRegistry::Global()->SetWatcher(nullptr));
return s;
}
TF_RETURN_IF_ERROR(OpRegistry::Global()->SetWatcher(nullptr));
loaded_libs[library_filename] = library;
}
}
string str;
library.op_list.SerializeToString(&str);
char* str_buf = reinterpret_cast<char*>(port::Malloc(str.length()));
memcpy(str_buf, str.data(), str.length());
*buf = str_buf;
*len = str.length();
*result = library.handle;
return Status::OK();
}
Here is the auto-generated Python code; the corresponding C++ Op is here:
"""Python wrappers around TensorFlow ops. This file is MACHINE GENERATED! Do not edit. """
import collections as _collections
from tensorflow.core.framework import op_def_pb2 as _op_def_pb2
# Needed to trigger the call to _set_call_cpp_shape_fn.
from tensorflow.python.framework import common_shapes as _common_shapes
from tensorflow.python.framework import op_def_registry as _op_def_registry
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import op_def_library as _op_def_library
from tensorflow.python.util.tf_export import tf_export
_ps_roi_align_outputs = ["pooled_features", "pooled_index"]
_PsRoiAlignOutput = _collections.namedtuple(
"PsRoiAlign", _ps_roi_align_outputs)
@tf_export('ps_roi_align')
def ps_roi_align(inputs, rois, grid_dim_width, grid_dim_height, name=None):
r""" PsRoiAlign is a new PsRoiPooling method without align problems. The input rois to be pooled must in format [center_y, center_x, h, w] and each element must be in range [0, 1.]. The caller must make sure that all rois is valid (has a intersect region (one pixel at least) with the window [0.5, 0.5, 1., 1.]). Args: inputs: A `Tensor`. Must be one of the following types: `float32`. rois: A `Tensor`. Must have the same type as `inputs`. grid_dim_width: An `int`. grid_dim_height: An `int`. name: A name for the operation (optional). Returns: A tuple of `Tensor` objects (pooled_features, pooled_index). pooled_features: A `Tensor`. Has the same type as `inputs`. pooled_index: A `Tensor` of type `int32`. """
_result = _op_def_lib.apply_op("PsRoiAlign", inputs=inputs, rois=rois,
grid_dim_width=grid_dim_width,
grid_dim_height=grid_dim_height, name=name)
_result = _PsRoiAlignOutput._make(_result)
return _result
_ops.RegisterShape("PsRoiAlign")(None)
@tf_export('ps_roi_align_grad')
def ps_roi_align_grad(inputs, rois, pooled_features_grad, pooled_index, grid_dim_width, grid_dim_height, name=None):
r""" PsRoiAlignGrad is the Gradient op of PsRoiAlign. The input rois to be pooled must in format [center_y, center_x, h, w] and each element must be in range [0, 1.]. The caller must make sure that all rois is valid (has a intersect region (one pixel at least) with the window [0.5, 0.5, 1., 1.]). Args: inputs: A `Tensor`. Must be one of the following types: `float32`. rois: A `Tensor`. Must have the same type as `inputs`. pooled_features_grad: A `Tensor`. Must have the same type as `inputs`. pooled_index: A `Tensor` of type `int32`. grid_dim_width: An `int`. grid_dim_height: An `int`. name: A name for the operation (optional). Returns: A `Tensor`. Has the same type as `inputs`. """
_result = _op_def_lib.apply_op("PsRoiAlignGrad", inputs=inputs, rois=rois,
pooled_features_grad=pooled_features_grad,
pooled_index=pooled_index,
grid_dim_width=grid_dim_width,
grid_dim_height=grid_dim_height, name=name)
return _result
_ops.RegisterShape("PsRoiAlignGrad")(None)
def _InitOpDefLibrary(op_list_proto_bytes):
  op_list = _op_def_pb2.OpList()
  op_list.ParseFromString(op_list_proto_bytes)
  _op_def_registry.register_op_list(op_list)
  op_def_lib = _op_def_library.OpDefLibrary()
  op_def_lib.add_op_list(op_list)
  return op_def_lib
# op {
# name: "PsRoiAlign"
# input_arg {
# name: "inputs"
# type_attr: "T"
# }
# input_arg {
# name: "rois"
# type_attr: "T"
# }
# output_arg {
# name: "pooled_features"
# type_attr: "T"
# }
# output_arg {
# name: "pooled_index"
# type: DT_INT32
# }
# attr {
# name: "T"
# type: "type"
# allowed_values {
# list {
# type: DT_FLOAT
# }
# }
# }
# attr {
# name: "grid_dim_width"
# type: "int"
# }
# attr {
# name: "grid_dim_height"
# type: "int"
# }
# }
# op {
# name: "PsRoiAlignGrad"
# input_arg {
# name: "inputs"
# type_attr: "T"
# }
# input_arg {
# name: "rois"
# type_attr: "T"
# }
# input_arg {
# name: "pooled_features_grad"
# type_attr: "T"
# }
# input_arg {
# name: "pooled_index"
# type: DT_INT32
# }
# output_arg {
# name: "grad_output"
# type_attr: "T"
# }
# attr {
# name: "T"
# type: "type"
# allowed_values {
# list {
# type: DT_FLOAT
# }
# }
# }
# attr {
# name: "grid_dim_width"
# type: "int"
# }
# attr {
# name: "grid_dim_height"
# type: "int"
# }
# }
_op_def_lib = _InitOpDefLibrary(b"\n\215\001\n\nPsRoiAlign\022\013\n\006inputs\"\001T\022\t\n\004rois\"\001T\032\024\n\017pooled_features\"\001T\032\020\n\014pooled_index\030\003\"\020\n\001T\022\004type:\005\n\0032\001\001\"\025\n\016grid_dim_width\022\003int\"\026\n\017grid_dim_height\022\003int\n\250\001\n\016PsRoiAlignGrad\022\013\n\006inputs\"\001T\022\t\n\004rois\"\001T\022\031\n\024pooled_features_grad\"\001T\022\020\n\014pooled_index\030\003\032\020\n\013grad_output\"\001T\"\020\n\001T\022\004type:\005\n\0032\001\001\"\025\n\016grid_dim_width\022\003int\"\026\n\017grid_dim_height\022\003int")