tensorflow自定义算子开发1:CPU实例

人工智能109

本文将介绍如果用C++在tensorflow中新建一个算子,参考官方文档通过一个简单的例子来说明。操作系统是Ubuntu,且系统已经安装tensorflow。

首先,创建一个名为 zero_out.cc 的文件,所有内容均在本文件中实现

定义运算接口

对于一个新的操作,首先要在C++中定义这个操作,通过将 接口注册到 TensorFlow 系统来定义运算的接口。注册中需要指定该运算的 名称、输入 类型名称以及输出 类型名称,还有 文档字符串和该运算可能需要的任意特性。下面给出示例:

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;

REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

这里给出了具体的注册过程,这里调用REGISTER_OP宏注册了一个ZeroOut的操作,输入命名为to_zero,类型为int32,输出命名zeroed,类型为int32,最后set_output用来保证输入输出的维度是一致的。

实现运算内核

定义完接口之后,可以为此操作定义一个或多个 内核实现,内核的实现需要继承 <strong>OpKernel</strong>&#x7C7B;&#xFF0C;&#x5E76;&#x4E14;&#x91CD;&#x8F7D;<strong>Compute</strong>&#x65B9;&#x6CD5;&#xFF0C;Compute&#x53C2;&#x6570;&#x4E2D;&#x7C7B;&#x578B;&#x4E3A;<strong>OpKernelContext*</strong>的参数 context,从中可以访问输入输出张量等有用信息。内核代码如下:

#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // 得到输入张量
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat();

    // 创建输出张量
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat();

    // 除了第一个元素,其他元素全部置为0
    const int N = input.size();
    for (int i = 1; i < N; i++) {
      output_flat(i) = 0;
    }

    // 如果输入维度大于0,输出第1维度等于输入第1维度
    if (N > 0) output_flat(0) = input(0);
  }
};

内核注册

内核实现后,需要将其注册到tensorflow系统,需要指定内核执行时的不同约束条件,例如针对CPU和GPU通常会有两个内核。

REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);

需要注意 <strong>OpKernel</strong>&#x53EF;&#x80FD;&#x5E76;&#x884C;&#x65B9;&#x4F4D;&#xFF0C;&#x9700;&#x8981;&#x4FDD;&#x8BC1;&#x7EBF;&#x7A0B;&#x5B89;&#x5168;&#x3002;&#x8FD9;&#x91CC;&#x53EA;&#x7ED9;&#x51FA;&#x4E86;CPU&#x5185;&#x6838;&#x7684;&#x5B9E;&#x73B0;&#xFF0C;GPU&#x5185;&#x6838;&#x5C06;&#x5728;&#x4E0B;&#x4E00;&#x8282;&#x4ECB;&#x7ECD;&#x3002;

完整代码

给出完整代码:

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat();

    // Create an output tensor
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat();

    // Set all but the first element of the output tensor to 0.

    const int N = input.size();
    for (int i = 1; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value if possible.

    if (N > 0) output_flat(0) = input(0);
  }
};

构建运算库

这里采用系统编译库来实现,采用g++编译器,且系统已经安装二进制tensorflow, PIP 包管理器来安装二进制 TensorFlow 时,已经包含了编译操作所需的头文件和库文件。但是,TensorFlow Python 库已经提供了 get_include 函数来获取头文件目录,以及 get_lib 函数来回去库目录。 U可以测试如下函数的输出

$ python
>>> import tensorflow as tf
>>> tf.sysconfig.get_include()
/usr/local/lib/python3.5/dist-packages/tensorflow/include
>>> tf.sysconfig.get_lib()
'/usr/local/lib/python3.6/site-packages/tensorflow'

构建

基于tensorflow的两个接口函数,我们用g++来编译新的算子

TF_CFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') )
TF_LFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') )
g++ -std=c++11 -shared zero_out.cc -o zero_out.so -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2

正常情况下会生成动态库:

tensorflow自定义算子开发1:CPU实例

python中使用

tensorflow的python接口,提供了函数 tf.load_op_library 来加载动态库并向tensorflow 框架注册运算。 load_op_library 会返回一个 python 模块,其中包含运算和内核的 Python 封装容器。因此,在构建此运算后,就可以执行刚才定义的 算子

import tensorflow as tf
zero_out_module = tf.load_op_library('./zero_out.so')
sess = tf.Session()
result = zero_out_module.zero_out([[1, 2], [3, 4]])
print("****************")
print(sess.run(result))

执行结果如下:

[[1 0]
 [0 0]]

Original: https://blog.csdn.net/fangfanglovezhou/article/details/124573170
Author: I_belong_to_jesus
Title: tensorflow自定义算子开发1:CPU实例