Tensorflow是如何注册和调用C++ New Op的

寒假回来想起来之前挖的坑,但好像并没有特别好的主题可以写,更不用说实习招聘近在眼前了,于是打算先扩展一下之前在知乎上的两个回答。

本文主要介绍动态链接的C++ New Op是如何被注册进来,又如何被Python代码调用的,也算是给自己的一个交代,毕竟本人一直不太喜欢high-level的API。本文大致分为三个模块:注册Ops,注册Kernel,调用Ops。

  • Ops的注册过程

先说一下OpRegistrationData这个东西,这个类的对象由全局注册器Registry负责分配,作用简单来说就是保存OpDef和OpShapeInferenceFn函数,前者保存有Op的各种具体信息,会由OpDefBuilder在最后的解析参数时(成员函数Finalize)放进来,后者在SetShapeFn传进来(由Wrapper转发),所谓注册就是将op name和OpRegistrationData关联起来,具体来说放进hashmap。

mutable std::unordered_map<string, const OpRegistrationData*> registry_;

还得先说一下OpDefBuilder这个类,OpDefBuilder会负责接收Op的各种属性和参数定义(就是REGISTER_OP时指定的,见下),最后统一解析(注意只是解析并不保证合法性之类的)并转给OpRegistrationData这个类(包括ShapeFn)。

我们自己注册op都会通过下面这个宏定义:

REGISTER_OP("YourOp")
  .Attr("T: {float}")
  .Input("logits: T")
  .Input("Labels: T")
  .Output("loss: T")
  .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
    c->set_output(0, c->MakeShape({1}));
    return Status::OK();
  });

细节都在REGISTER_OP那个宏定义里面,简化如下:

static OpDefBuilderReceiver register_op = OpDefBuilderWrapper('YourOp')

其中OpDefBuilderWrapper内部保存有一个OpDefBuilder成员变量,你所有对REGISTER_OP宏连续调用的操作包括op的名字最后都会一股脑转发给前面那个唯一的OpDefBuilder变量,而OpDefBuilderReceiver则拿过来BuilderWrapper交给一个负责管理所有Op注册的Registry,Registry暴露Register方法给op们注册,把官方的example摘过来示意一下:

//Example registration:
 OpRegistry::Global()->Register(
   [](OpRegistrationData* op_reg_data)->Status {
     // Populate *op_reg_data here.
     return Status::OK();
 });

(先解释下:OpRegistry::Global()简单的单例模式,返回OpRegistry的全局唯一实例,当然这里必须要感谢下新标准对static线程安全的保证。)

在那个lambda里面你就可以做任何想做的事情了,比如就像OpDefBuilderReceiver一样把BuilderWrapper拿进来,然后把wrapper去掉取出OpDefBuilder,看到上面lambda里面那个op_reg_data没,对这就是之前提到的将解析好参数及shapefn传到OpRegistrationData里,最后Register拿到op的name和OpRegistrationData组成pair放进hashmap完成注册,同时会做一些合法性检查的事情。如下:

OpRegistry::Global()->Register(
     [wrapper](OpRegistrationData* op_reg_data) -> Status {
       return wrapper.builder().Finalize(op_reg_data);
     });

其实到这里真正的注册并不一定会发生,下面会详细说。

  • Kernel的注册过程

与Ops的注册类似,也是有一个叫作KernelDefBuilder的wrapper,内部保存有KernelDef的一个指针,用于设置各种属性,最后调用Build函数可返回该指针并清空Builder,Kernel的注册主要是通过下面这个宏来实现的:

REGISTER_KERNEL_BUILDER(                                       \
     Name("PsRoiAlignGrad").Device(DEVICE_GPU).TypeConstraint<float>("T"), \
     PSROIAlignGradOp<GPUDevice, float>);

其中Name是KernelDefBuilder的一个派生类,Name(“KernelName”)会首先创建一个KernelDefBuilder同时设置设置kernel名称,每次调用这种setter函数就会返回Builder自身从而支持连续调用,然后是设置Device,最后添加值float到属性T中。

class Name : public KernelDefBuilder {
public:
 // For system kernels, we ignore selective registration and
 // unconditionally register the kernel.
 explicit Name(const char* op) : KernelDefBuilder(op) {}
};

REGISTER_KERNEL_BUILDER宏里面就是一些trick,实质是创建一个名称唯一的类型为OpKernelRegistrar的全局静态变量,如果你有兴趣可以看一下:

#define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
 REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__)

#define REGISTER_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \
 REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)

#define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...)        \
 constexpr bool should_register_##ctr##__flag =                      \
     SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__);                        \
 static ::tensorflow::kernel_factory::OpKernelRegistrar              \
     registrar__body__##ctr##__object(                               \
         should_register_##ctr##__flag                               \
             ? ::tensorflow::register_kernel::kernel_builder.Build() \
             : nullptr,                                              \
         #__VA_ARGS__,                                               \
         [](::tensorflow::OpKernelConstruction* context)             \
             -> ::tensorflow::OpKernel* {                            \
           return new __VA_ARGS__(context);                          \
         });

OpKernelRegistrar静态变量的构造需要三个参数,如下所示,第一个是KernelDef,第二个是定义Kernel的类名,第三个是创建kernel对象的函数,其实后面就可以知道这三个参数都会被包装到KernelRegistration这个结构体里,然后作为Kernel注册表的值。因此这个宏会首先调用KernelDefBuilder的Build函数获得对应的KernelDef;然后获取用于创建这个Kernel的C++类名称(这个类是继承自OpKernel的);最后包装一个factory函数用来接收传进来的OpKernelConstruction*,创建对应的Kernel类对象,并返回其指针。

class OpKernelRegistrar {
public:
 typedef OpKernel* (*Factory)(OpKernelConstruction*);

 OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
                   Factory factory) {
   if (kernel_def != nullptr) {
     InitInternal(kernel_def, kernel_class_name, factory);
   }
 }
};

这里是InitInternal的细节

void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
                                    StringPiece kernel_class_name,
                                    Factory factory) {
 // See comments in register_kernel::Name in header for info on _no_register.
 if (kernel_def->op() != "_no_register") {
   const string key =
       Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
           kernel_def->label());
   GlobalKernelRegistryTyped()->insert(std::make_pair(
       key, KernelRegistration(*kernel_def, kernel_class_name, factory)));
 }
 delete kernel_def;
}

可以看到OpKernelRegistrar这个类主要是负责根据传进来的KernelDef和KernelFactory,首先依据一定规则生成一个适当的key,并插入到一个全局唯一的Kernel注册表里,注册表当然是一个map但是值得注意的是它是multimap因此支持一个键对应多个kernel副本。

typedef std::unordered_multimap<string, KernelRegistration> KernelRegistry;
  • OpKernel的创建与调用

如果你还记得的话,前面还有一个全局的OpRegistry,这样根据NodeDef里的Op名称就可以获得Op对应的信息,再结合设备类型也就可以获得Kernel对应的信息了,而NodeDef是在Python创建Operation之前创建的,可以看这里create_op,后面会提到调用这个函数的地方。

然后就可以根据一个NodeDef和当前的设备类型在运行时创建一个OpKernel了,每个被创建的OpKernel都会被自动地管理生命周期。在Device类中会有一个OpSegment对象,OpSegment会管理一个sessions中用到的kernel,根据情况来决定是创建新的还是复用之前的OpKernel,具体来说是有两个嵌套的hashmap,第一个将session handle映射到一个KernelMap,然后在KernelMap就可以去查找是否有对应Op名的OpKernel,如果没有就调用一个create_fn函数进行创建。

那么问题来了,这背后的原动力在哪?事实上Session在第一次为某个Node创建Executor的时候这一切就发生了(后面会再说到Executor的):DirectSession::GetOrCreateExecutors,更直接地可以看查找失败后第一次创建Executor的地方,代码片段如下:

LocalExecutorParams params;
params.device = device;
params.function_library = lib;
auto opseg = device->op_segment();
params.create_kernel = [this, lib, opseg](const NodeDef& ndef, OpKernel** kernel) {
 // We do not share the kernel via the OpSegment if the node is
 // stateless, or a function.
 // NOTE(mrry): We must not share function kernels (implemented
 // using `CallOp`) between subgraphs, because `CallOp::handle_`
 // is tied to a particular subgraph. Even if the function itself
 // is stateful, the `CallOp` that invokes it is not.
 if (!lib->IsStateful(ndef.op()) ||
     lib->GetFunctionLibraryDefinition()->Find(ndef.op()) != nullptr) {
   return lib->CreateKernel(ndef, kernel);
 }
 auto create_fn = [lib, &ndef](OpKernel** kernel) {
   return lib->CreateKernel(ndef, kernel);
 };
 // Kernels created for subgraph nodes need to be cached.  On
 // cache miss, create_fn() is invoked to create a kernel based
 // on the function library here + global op registry.
 return opseg->FindOrCreate(session_handle_, ndef.name(), kernel,
                            create_fn);
};

可以看到取出OpSegment,构造create_fn并调用FindOrCreate的过程。其中create_fn内部调用的FunctionLibraryRuntime的CreateKernel函数可以看这里:FunctionLibraryRuntimeImpl::CreateKernel,再往下CreateNonCachedKernel

Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
                            const NodeDef& ndef, int graph_def_version,
                            OpKernel** kernel) {
 const auto device_type = DeviceType(device->attributes().device_type());
 auto allocator = device->GetAllocator(AllocatorAttributes());
 return CreateOpKernel(device_type, device, allocator, flib, ndef,
                       graph_def_version, kernel);
}

看到了CreateOpKernel的调用,这下总算回到了我们最开始的地方CreateOpKernel

Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                     Allocator* allocator, FunctionLibraryRuntime* flib,
                     const NodeDef& node_def, int graph_def_version,
                     OpKernel** kernel)

这个核心函数主要是做一下以下几件事情:根据node_def取出op名,去查OpRegistry,并与node_def的信息进行校验,比如接口是否一致,node_def中是否包含所有op_def中的信息等,然后根据device_type和op名去查KernelRegistry获取KernelRegistration,就是map中的值,包含之前提到的三项。接着是确定输入输出类型及其存储位置,最后是创建一个OpKernelConstruction对象,并传给Kernel的factory函数函数,这就到了用户自己写的函数这边了:

// Everything needed for OpKernel construction.
OpKernelConstruction context(
 device_type, device, allocator, &node_def, op_def, flib, inputs,
 input_memory_types, outputs, output_memory_types, graph_def_version, &s);
*kernel = (*registration->factory)(&context);

Kernel创建完了,那么它什么时候被执行呢?前面说到第一次创建executor的时候会创建OpKernel,其实每次Session调用Run的时候最终也是转到executor这边来执行的,包括根据当前的运行时环境创建OpKernelContext以及OpKernel::Compute的调用

// Synchronous computes.
OpKernelContext ctx(&params, item.num_outputs);
nodestats::SetOpStart(stats);
device->Compute(CHECK_NOTNULL(op_kernel), &ctx);
nodestats::SetOpEnd(stats);

其中device->Compute这一步通过查看基类的实现就大概能知道所有细节了Device::Compute

// Performs the actual compute function.
//
// Subclasses may override this function if they wish to perform
// some initialization before each compute.
virtual void Compute(OpKernel* op_kernel, OpKernelContext* context) {
 op_kernel->Compute(context);
}

可以发现,我们写的Compute方法在这里就被调用了。至此故事好像可以告一段落了,不过说了半天好像一直在C++这边啊,那Python代码怎么调用的呢?

  • 注册Ops和Kernel后传

根据上面REGISTER_KERNEL_BUILDER所展开的两段程序很容易就判断出如果动态库被加载进来的话,Kernel就会自动完成注册,这跟Ops的注册基本是一样的,不同之处在于动态链接进来的Ops会在加载库之前设置延迟注册的标记,并添加一个Watcher,然后手动调用注册,这主要是为了通过Watcher获取注册过程中从OpRegistrationData(就是注册表的值)中取出的OpDef,这一点可以在后面的LoadLibrary中看到。这个过程很重要,通过获得的OpDef组成的OpList并序列化后,Python端就可以解析出这些OpDef,同时调用C++这边利用这些OpDef生成对应的ApiDef,二者结合就可以动态生成定义这个Op的Python代码,然后返回到Python端执行这些代码,注意这些代码的执行并不包括创建Op并添加到Graph这个过程,只包括定义相关代码段的函数,下面是从Python端load_op_library一直到生成Python代码的过程:load_op_library->GetPythonWrappers->GetPythonOps->GetPythonOp->GenPythonOp::Code()。还有从OpList生成ApiDef的地方ApiDefMap::ApiDefMap(const OpList& op_list)。如果你有兴趣的话可以去看一下我之前写的一个Op自动生成的代码,我附在了本文最后,生成代码中的apply_op就是添加Op到Graph的代码,可以看这里apply_op,这个函数的最后面就是前面提到的调用Graph的create_op。

下面是LoadLibrary的代码段,可以对照一下:

Status LoadLibrary(const char* library_filename, void** result,
                  const void** buf, size_t* len) {
 static mutex mu;
 static std::unordered_map<string, Library> loaded_libs;
 Env* env = Env::Default();
 Library library;
 std::unordered_set<string> seen_op_names;
 {
   mutex_lock lock(mu);
   if (loaded_libs.find(library_filename) != loaded_libs.end()) {
     library = loaded_libs[library_filename];
   } else {
     Status s = OpRegistry::Global()->ProcessRegistrations();
     if (!s.ok()) {
       return s;
     }
     TF_RETURN_IF_ERROR(OpRegistry::Global()->SetWatcher(
         [&library, &seen_op_names](const Status& s,
                                    const OpDef& opdef) -> Status {
           if (errors::IsAlreadyExists(s)) {
             if (seen_op_names.find(opdef.name()) == seen_op_names.end()) {
               // Over writing a registration of an op not in this custom op
               // library. Treat this as not an error.
               return Status::OK();
             }
           }
           if (s.ok()) {
             *library.op_list.add_op() = opdef;
             seen_op_names.insert(opdef.name());
           }
           return s;
         }));
     OpRegistry::Global()->DeferRegistrations();
     s = env->LoadLibrary(library_filename, &library.handle);
     if (s.ok()) {
       s = OpRegistry::Global()->ProcessRegistrations();
     }
     if (!s.ok()) {
       OpRegistry::Global()->ClearDeferredRegistrations();
       TF_RETURN_IF_ERROR(OpRegistry::Global()->SetWatcher(nullptr));
       return s;
     }
     TF_RETURN_IF_ERROR(OpRegistry::Global()->SetWatcher(nullptr));

     loaded_libs[library_filename] = library;
   }
 }
 string str;
 library.op_list.SerializeToString(&str);
 char* str_buf = reinterpret_cast<char*>(port::Malloc(str.length()));
 memcpy(str_buf, str.data(), str.length());
 *buf = str_buf;
 *len = str.length();

 *result = library.handle;
 return Status::OK();
}

自动生成的Python代码,这里是对应的C++ Op

"""Python wrappers around TensorFlow ops. This file is MACHINE GENERATED! Do not edit. """
import collections as _collections
from tensorflow.core.framework import op_def_pb2 as _op_def_pb2
# Needed to trigger the call to _set_call_cpp_shape_fn.
from tensorflow.python.framework import common_shapes as _common_shapes
from tensorflow.python.framework import op_def_registry as _op_def_registry
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import op_def_library as _op_def_library
from tensorflow.python.util.tf_export import tf_export

_ps_roi_align_outputs = ["pooled_features", "pooled_index"]
_PsRoiAlignOutput = _collections.namedtuple(
   "PsRoiAlign", _ps_roi_align_outputs)

@tf_export('ps_roi_align')
def ps_roi_align(inputs, rois, grid_dim_width, grid_dim_height, name=None):
 r""" PsRoiAlign is a new PsRoiPooling method without align problems.   The input rois to be pooled must in format [center_y, center_x, h, w] and each element must be in range [0, 1.].  The caller must make sure that all rois is valid (has a intersect region (one pixel at least) with the window [0.5, 0.5, 1., 1.]).   Args:  inputs: A `Tensor`. Must be one of the following types: `float32`.  rois: A `Tensor`. Must have the same type as `inputs`.  grid_dim_width: An `int`.  grid_dim_height: An `int`.  name: A name for the operation (optional).   Returns:  A tuple of `Tensor` objects (pooled_features, pooled_index).   pooled_features: A `Tensor`. Has the same type as `inputs`.  pooled_index: A `Tensor` of type `int32`.  """
 _result = _op_def_lib.apply_op("PsRoiAlign", inputs=inputs, rois=rois,
                                grid_dim_width=grid_dim_width,
                                grid_dim_height=grid_dim_height, name=name)
 _result = _PsRoiAlignOutput._make(_result)
 return _result

_ops.RegisterShape("PsRoiAlign")(None)

@tf_export('ps_roi_align_grad')
def ps_roi_align_grad(inputs, rois, pooled_features_grad, pooled_index, grid_dim_width, grid_dim_height, name=None):
 r""" PsRoiAlignGrad is the Gradient op of PsRoiAlign.   The input rois to be pooled must in format [center_y, center_x, h, w] and each element must be in range [0, 1.].  The caller must make sure that all rois is valid (has a intersect region (one pixel at least) with the window [0.5, 0.5, 1., 1.]).   Args:  inputs: A `Tensor`. Must be one of the following types: `float32`.  rois: A `Tensor`. Must have the same type as `inputs`.  pooled_features_grad: A `Tensor`. Must have the same type as `inputs`.  pooled_index: A `Tensor` of type `int32`.  grid_dim_width: An `int`.  grid_dim_height: An `int`.  name: A name for the operation (optional).   Returns:  A `Tensor`. Has the same type as `inputs`.  """
 _result = _op_def_lib.apply_op("PsRoiAlignGrad", inputs=inputs, rois=rois,
                                pooled_features_grad=pooled_features_grad,
                                pooled_index=pooled_index,
                                grid_dim_width=grid_dim_width,
                                grid_dim_height=grid_dim_height, name=name)
 return _result

_ops.RegisterShape("PsRoiAlignGrad")(None)
def _InitOpDefLibrary(op_list_proto_bytes):
 op_list = _op_def_pb2.OpList()
 op_list.ParseFromString(op_list_proto_bytes)
 _op_def_registry.register_op_list(op_list)
 op_def_lib = _op_def_library.OpDefLibrary()
 op_def_lib.add_op_list(op_list)
 return op_def_lib

# op {
# name: "PsRoiAlign"
# input_arg {
# name: "inputs"
# type_attr: "T"
# }
# input_arg {
# name: "rois"
# type_attr: "T"
# }
# output_arg {
# name: "pooled_features"
# type_attr: "T"
# }
# output_arg {
# name: "pooled_index"
# type: DT_INT32
# }
# attr {
# name: "T"
# type: "type"
# allowed_values {
# list {
# type: DT_FLOAT
# }
# }
# }
# attr {
# name: "grid_dim_width"
# type: "int"
# }
# attr {
# name: "grid_dim_height"
# type: "int"
# }
# }
# op {
# name: "PsRoiAlignGrad"
# input_arg {
# name: "inputs"
# type_attr: "T"
# }
# input_arg {
# name: "rois"
# type_attr: "T"
# }
# input_arg {
# name: "pooled_features_grad"
# type_attr: "T"
# }
# input_arg {
# name: "pooled_index"
# type: DT_INT32
# }
# output_arg {
# name: "grad_output"
# type_attr: "T"
# }
# attr {
# name: "T"
# type: "type"
# allowed_values {
# list {
# type: DT_FLOAT
# }
# }
# }
# attr {
# name: "grid_dim_width"
# type: "int"
# }
# attr {
# name: "grid_dim_height"
# type: "int"
# }
# }
_op_def_lib = _InitOpDefLibrary(b"\\n\\215\\001\\n\\nPsRoiAlign\\022\\013\\n\\006inputs\\"\\001T\\022\\t\\n\\004rois\\"\\001T\\032\\024\\n\\017pooled_features\\"\\001T\\032\\020\\n\\014pooled_index\\030\\003\\"\\020\\n\\001T\\022\\004type:\\005\\n\\0032\\001\\001\\"\\025\\n\\016grid_dim_width\\022\\003int\\"\\026\\n\\017grid_dim_height\\022\\003int\\n\\250\\001\\n\\016PsRoiAlignGrad\\022\\013\\n\\006inputs\\"\\001T\\022\\t\\n\\004rois\\"\\001T\\022\\031\\n\\024pooled_features_grad\\"\\001T\\022\\020\\n\\014pooled_index\\030\\003\\032\\020\\n\\013grad_output\\"\\001T\\"\\020\\n\\001T\\022\\004type:\\005\\n\\0032\\001\\001\\"\\025\\n\\016grid_dim_width\\022\\003int\\"\\026\\n\\017grid_dim_height\\022\\003int")

    原文作者:Kapok
    原文地址: https://zhuanlan.zhihu.com/p/34168765
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞