在线观看不卡亚洲电影_亚洲妓女99综合网_91青青青亚洲娱乐在线观看_日韩无码高清综合久久

鍍金池/ 教程/ 人工智能/ 增加一個(gè)新 Op <a class="md-anchor" id="AUTOGENERATED-adding-a-new-op"
BibTex 引用<a class="md-anchor" id="AUTOGENERATED-bibtex-citation"
術(shù)語(yǔ)表
自定義數(shù)據(jù)讀取 <a class="md-anchor" id="AUTOGENERATED-custom-data-reade
使用 GPUs <a class="md-anchor" id="AUTOGENERATED-using-gpus"></a>
Vector Representations of Words <a class="md-anchor" id="AUTOGEN
TensorFlow 個(gè)人學(xué)習(xí)心得
共享變量<a class="md-anchor" id="AUTOGENERATED-sharing-variables"></
應(yīng)用實(shí)例 <a class="md-anchor" id="AUTOGENERATED-example-uses"></a>
其他資源 <a class="md-anchor" id="AUTOGENERATED-additional-resources
偏微分方程 <a class="md-anchor" id="AUTOGENERATED-partial-differentia
TensorBoard:可視化學(xué)習(xí) <a class="md-anchor" id="AUTOGENERATED-tensorb
TensorFlow運(yùn)作方式入門(mén) <a class="md-anchor" id="AUTOGENERATED-tensorfl
常見(jiàn)問(wèn)題 <a class="md-anchor" id="AUTOGENERATED-frequently-asked-que
MNIST機(jī)器學(xué)習(xí)入門(mén) <a class="md-anchor" id="AUTOGENERATED-mnist-for-ml-
曼德布洛特(Mandelbrot)集合 <a class="md-anchor" id="AUTOGENERATED-mande
變量:創(chuàng)建、初始化、保存和加載
TensorBoard: 圖表可視化 <a class="md-anchor" id="AUTOGENERATED-tensor
簡(jiǎn)介 <a class="md-anchor" id="AUTOGENERATED-introduction"></a>
張量的階、形狀、數(shù)據(jù)類(lèi)型<a class="md-anchor" id="AUTOGENERATED-tensor-ranks-
線(xiàn)程和隊(duì)列 <a class="md-anchor" id="AUTOGENERATED-threading-and-queue
下載與安裝 <a class="md-anchor" id="AUTOGENERATED-download-and-setup"
常見(jiàn)問(wèn)題匯總
綜述
綜述 Overview
TensorFlow 相關(guān)資源
數(shù)據(jù)讀取 <a class="md-anchor" id="AUTOGENERATED-reading-data"></a>
遞歸神經(jīng)網(wǎng)絡(luò) <a class="md-anchor" id="AUTOGENERATED-recurrent-neural-n
深入MNIST <a class="md-anchor" id="AUTOGENERATED-deep-mnist-for-ex
增加一個(gè)新 Op <a class="md-anchor" id="AUTOGENERATED-adding-a-new-op"
卷積神經(jīng)網(wǎng)絡(luò) <a class="md-anchor" id="AUTOGENERATED-convolutional-neur
基本使用 <a class="md-anchor" id="AUTOGENERATED-basic-usage"></a>
MNIST 數(shù)據(jù)下載 <a class="md-anchor" id="AUTOGENERATED-mnist-data-dow

增加一個(gè)新 Op <a class="md-anchor" id="AUTOGENERATED-adding-a-new-op"

預(yù)備知識(shí):

如果現(xiàn)有的庫(kù)沒(méi)有涵蓋你想要的操作, 你可以自己定制一個(gè). 為了使定制的 Op 能夠兼容原有的庫(kù) , 你必須做以下工作:

  • 在一個(gè) C++ 文件中注冊(cè)新 Op. Op 的注冊(cè)與實(shí)現(xiàn)是相互獨(dú)立的. 在其注冊(cè)時(shí)描述了 Op 該如何執(zhí)行. 例如, 注冊(cè) Op 時(shí)定義了 Op 的名字, 并指定了它的輸入和輸出.
  • 使用 C++ 實(shí)現(xiàn) Op. 每一個(gè)實(shí)現(xiàn)稱(chēng)之為一個(gè) "kernel", 可以存在多個(gè) kernel, 以適配不同的架構(gòu) (CPU, GPU 等)或不同的輸入/輸出類(lèi)型.
  • 創(chuàng)建一個(gè) Python 包裝器(wrapper). 這個(gè)包裝器是創(chuàng)建 Op 的公開(kāi) API. 當(dāng)注冊(cè) Op 時(shí), 會(huì)自動(dòng)生成一個(gè)默認(rèn) 默認(rèn)的包裝器. 既可以直接使用默認(rèn)包裝器, 也可以添加一個(gè)新的包裝器.
  • (可選) 寫(xiě)一個(gè)函數(shù)計(jì)算 Op 的梯度.
  • (可選) 寫(xiě)一個(gè)函數(shù), 描述 Op 的輸入和輸出 shape. 該函數(shù)能夠允許從 Op 推斷 shape.
  • 測(cè)試 Op, 通常使用 Pyhton。如果你定義了梯度,你可以使用Python的GradientChecker來(lái)測(cè)試它。

內(nèi)容

增加一個(gè)新 Op

定義 Op 的接口

向 TensorFlow 系統(tǒng)注冊(cè)來(lái)定義 Op 的接口. 在注冊(cè)時(shí), 指定 Op 的名稱(chēng), 它的輸入(類(lèi)型和名稱(chēng)) 和輸出(類(lèi)型和名稱(chēng)), 和所需要任何 屬性的文檔說(shuō)明.

為了讓你有直觀的認(rèn)識(shí), 創(chuàng)建一個(gè)簡(jiǎn)單的 Op 作為例子. 該 Op 接受一個(gè) int32 類(lèi)型 tensor 作為 輸入, 輸出這個(gè) tensor 的一個(gè)副本, 副本與原 tensor 唯一的區(qū)別在于第一個(gè)元素被置為 0. 創(chuàng)建 文件 tensorflow/core/user_ops/zero_out.cc, 并調(diào)用 REGISTER_OP 宏來(lái)定義 Op 的接口.

 #include "tensorflow/core/framework/op.h"
REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32");

ZeroOut Op 接受 32 位整型的 tensor to_zero 作為輸入, 輸出 32 位整型的 tensor zeroed.

命名的注意事項(xiàng): Op 的名稱(chēng)必須是為唯一的, 并使用駝峰命名法. 以下劃線(xiàn) _ 開(kāi)始的名稱(chēng)保留為內(nèi)部使用.

為 Op 實(shí)現(xiàn) kernel

在定義接口之后, 提供一個(gè)或多個(gè) Op 的實(shí)現(xiàn). 為這些 kernel 的每一個(gè)創(chuàng)建一個(gè)對(duì)應(yīng)的類(lèi), 繼承 OpKernel, 覆蓋 Compute 方法. Compute 方法提供一個(gè)類(lèi)型為 OpKernelContext* 的參數(shù) context, 用于訪(fǎng)問(wèn)一些有用的信息, 例如輸入和輸出的 tensor.

將 kernel 添加到剛才創(chuàng)建的文件中, kernel 看起來(lái)和下面的代碼類(lèi)似:

 #include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
  void Compute(OpKernelContext* context) override {
    // 獲取輸入 tensor.
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<int32>();
   // 創(chuàng)建一個(gè)輸出 tensor.
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output = output_tensor->template flat<int32>();
    // 設(shè)置 tensor 除第一個(gè)之外的元素均設(shè)為 0.
    const int N = input.size();
    for (int i = 1; i < N; i++) {
      output(i) = 0;
    }
    // 盡可能地保留第一個(gè)元素的值.
    if (N > 0) output(0) = input(0);
  }
};

實(shí)現(xiàn) kernel 后, 將其注冊(cè)到 TensorFlow 系統(tǒng)中. 注冊(cè)時(shí), 可以指定該 kernel 運(yùn)行時(shí)的多個(gè)約束 條件. 例如可以指定一個(gè) kernel 在 CPU 上運(yùn)行, 另一個(gè)在 GPU 上運(yùn)行.

將下列代碼加入到 zero_out.cc 中, 注冊(cè) ZeroOut op:

REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);

一旦創(chuàng)建和重新安裝了 TensorFlow , Tensorflow 系統(tǒng)可以在需要時(shí)引用和使用該 Op.

生成客戶(hù)端包裝器

Python Op 包裝器

當(dāng)編譯 TensorFlow 時(shí), 所有放在 tensorflow/core/user_ops 目錄下 的 Op 會(huì)自動(dòng)在 bazel-genfiles/tensorflow/python/ops/gen_user_ops.py 文件 中生成 Python Op 包裝器. 通過(guò)以下聲明, 把那些 Op 引入到 tensorflow/python/user_ops/user_ops.py 中:

from tensorflow.python.ops.gen_user_ops import *

你可以選擇性將部分函數(shù)替換為自己的實(shí)現(xiàn). 為此, 首先要隱藏自動(dòng)生成的代碼, 在 tensorflow/python/BUILD 文件中, 將其名字添加到 "user_ops"hidden 列表.

tf_gen_op_wrapper_py(
    name = "user_ops",
    hidden = [
        "Fact",
    ],
    require_shape_functions = False,
)

緊接著 "Fact" 列出自己的 Op. 然后, 在 tensorflow/python/user_ops/user_ops.py 中添加你的替代實(shí)現(xiàn)函數(shù). 通常, 替代實(shí)現(xiàn)函數(shù)也會(huì)調(diào)用自動(dòng)生成函數(shù)來(lái)真正把 Op 添加 到圖中. 被隱藏的自動(dòng)生成函數(shù)位于 gen_user_ops 包中, 名稱(chēng)多了一個(gè)下劃線(xiàn)前綴 ("_"). 例如:

def my_fact():
    """覆蓋一個(gè) Op 自動(dòng)生成代碼的示例."""
    return gen_user_ops._fact()

C++ Op 包裝器

當(dāng)編譯 TensorFlow 時(shí), 所有 tensorflow/core/user_ops 文件夾 下的 Op 會(huì)自動(dòng)創(chuàng)建 C++ Op 包裝器. 例如, tensorflow/core/user_ops/zero_out.cc 中的 Op 會(huì)自動(dòng)在 bazel-genfiles/tensorflow/cc/ops/user_ops.{h,cc} 中生成包裝器.

tensorflow/cc/ops/standard_ops.h 通過(guò)下述申明, 導(dǎo)入用戶(hù)自定義 Op 自動(dòng)生成的包裝器.

 #include "tensorflow/cc/ops/user_ops.h"

檢查 Op 能否正常工作

驗(yàn)證已經(jīng)成功實(shí)現(xiàn) Op 的方式是編寫(xiě)測(cè)試程序. 創(chuàng)建文件 tensorflow/python/kernel_tests/zero_out_op_test.py, 包含以下內(nèi)容:

import tensorflow as tf
class ZeroOutTest(tf.test.TestCase):
  def testZeroOut(self):
    with self.test_session():
      result = tf.user_ops.zero_out([5, 4, 3, 2, 1])
      self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])

然后運(yùn)行測(cè)試:

$ bazel test tensorflow/python:zero_out_op_test

驗(yàn)證條件

上述示例假定 Op 能夠應(yīng)用在任何 shape 的 tensor 上. 如果只想應(yīng)用到 vector 上 呢? 這意味需要在上述 OpKernel 實(shí)現(xiàn)中添加相關(guān)的檢查.

  void Compute(OpKernelContext* context) override {
   // 獲取輸入 tensor
    const Tensor& input_tensor = context->input(0);
    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_tensor.shape()),
                errors::InvalidArgument("ZeroOut expects a 1-D vector."));
    // ...
  }

OP_REQUIRES 斷言的輸入是一個(gè) vector, 如果不是 vector, 將設(shè)置 InvalidArgument 狀態(tài)并返回. OP_REQUIRES 有三個(gè)參數(shù):

如果想要測(cè)試一個(gè)函數(shù)返回的 Status 對(duì)象是否是一個(gè)錯(cuò)誤, 可以使用 OP_REQUIRES_OK. 這些宏如果檢測(cè)到錯(cuò)誤, 會(huì)直接跳出函數(shù), 終止函數(shù)執(zhí)行.

Op 注冊(cè)

屬性

Op 可以有屬性, 屬性的值在 Op 添加到圖中時(shí)被設(shè)置. 屬性值用于配置 Op, 在 kernel 實(shí)現(xiàn)中, Op 注冊(cè)的輸入和輸出類(lèi)型中, 均可訪(fǎng)問(wèn)這些屬性值. 盡可能地使用輸入代替屬性, 因?yàn)檩斎氲撵`活性更高, 例如可以在執(zhí)行步驟中 中被更改, 可以使用 feed 等等. 屬性可用于實(shí)現(xiàn)一些輸入無(wú)法做到的事情, 例如影響 Op 簽名 (即輸入輸出的數(shù)量和類(lèi)型) 的配置或只讀配置可以通過(guò)屬性實(shí)現(xiàn).

注冊(cè) Op 時(shí)可以用 Attr 方法指定屬性的名稱(chēng)和類(lèi)型, 以此來(lái)定義一個(gè)屬性, 形式如下:

<name>: <attr-type-expr>

<name> 必須以字母開(kāi)頭, 可以由數(shù)字, 字母, 下劃線(xiàn)組成. <attr-type-expr> 是一個(gè)類(lèi)型表達(dá)式, 形式如下:

例如, 如果想要 ZeroOut Op 保存一個(gè)用戶(hù)索引, 指示該 Op 不僅僅只有一個(gè)元素, 你可以注冊(cè) Op 如下:

REGISTER_OP("ZeroOut")
    .Attr("preserve_index: int")
    .Input("to_zero: int32")
    .Output("zeroed: int32");

你的 kernel 可以在構(gòu)造函數(shù)里, 通過(guò) context 參數(shù)訪(fǎng)問(wèn)這個(gè)屬性:

class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction * context) : OpKernel(context) {
   // 獲取欲保存的索引值
    OP_REQUIRES_OK(context,
                   context->GetAttr("preserve_index", &preserve_index_));
    // 檢查 preserve_index 是否為正
    OP_REQUIRES(context, preserve_index_ >= 0,
                errors::InvalidArgument("Need preserve_index >= 0, got ",
                                        preserve_index_));
  }
  void Compute(OpKernelContext* context) override {
    // ...
}
 private:
  int preserve_index_;
};

該值可以在 Compute 方法中被使用:

void Compute(OpKernelContext* context) override {
    // ...
   // 檢查 preserve_index 范圍是否合法
OP_REQUIRES(context, preserve_index_ < input.dimension(0),
                errors::InvalidArgument("preserve_index out of range"));
    // 設(shè)置輸出 tensor 所有的元素值為 0
   const int N = input.size();
    for (int i = 0; i < N; i++) {
      output_flat(i) = 0;
    }
    // 保存請(qǐng)求的輸入值
   output_flat(preserve_index_) = input(preserve_index_);
  }

為了維持向后兼容性, 將一個(gè)屬性添加到一個(gè)已有的 Op 時(shí), 必須指定一個(gè)默認(rèn)值:

REGISTER_OP("ZeroOut")
     .Attr("preserve_index: int = 0")
     .Input("to_zero: int32")
     .Output("zeroed: int32");

屬性類(lèi)型

屬性可以使用下面的類(lèi)型:

  • string: 任何二進(jìn)制字節(jié)流 (UTF8 不是必須的).
  • int: 一個(gè)有型整數(shù).
  • float: 一個(gè)浮點(diǎn)數(shù).
  • bool: 真或假.
  • type: DataType 非引用類(lèi)型之一.
  • shape: 一個(gè) TensorShapeProto.
  • tensor: 一個(gè) TensorProto.
  • list(<type>): <type> 列表, 其中 <type> 是上述類(lèi)型之一. 注意 list(list(<type>)) 是無(wú)效的.

權(quán)威的列表以 op_def_builder.cc:FinalizeAttr 為準(zhǔn).

默認(rèn)值和約束條件

屬性可能有默認(rèn)值, 一些類(lèi)型的屬性可以有約束條件. 為了定義一個(gè)有約束條件的屬性, 你可以使用下列的 <attr-type-expr> 形式:

  • {'<string1>', '<string2>'}: 屬性值必須是一個(gè)字符串, 取值可以為 <string1><string2>. 值的語(yǔ)法已經(jīng)暗示了值的類(lèi)型為 string, 已經(jīng)暗示了. 下述語(yǔ)句模擬了一個(gè)枚舉值:
REGISTER_OP("EnumExample")
      .Attr("e: {'apple', 'orange'}");
  • {<type1>, <type2>}: 值是 type 類(lèi)型, 且必須為 <type1><type2> 之一, 當(dāng)然 <type1><type2> 必須都是有效的 tensor 類(lèi)型. 你無(wú)須指定屬性的類(lèi)型為 type, 而是通過(guò) {...} 語(yǔ)句給出一個(gè)類(lèi)型列表. 例如, 在下面的例子里, 屬性 t 的類(lèi)型必須為 int32, float, 或 bool:
REGISTER_OP("RestrictedTypeExample")
      .Attr("t: {int32, float, bool}");
  • 這里有一些常見(jiàn)類(lèi)型約束條件的快捷方式:

    • numbertype: 限制類(lèi)型為數(shù)字類(lèi)型, 即非 string 非 bool 的類(lèi)型.
    • realnumbertype: 與 numbertype 區(qū)別是不支持復(fù)雜類(lèi)型.
    • quantizedtype: 與 numbertype 區(qū)別是只支持量化數(shù)值 (quantized number type).

這些類(lèi)型的列表在 tensorflow/core/framework/types.h 文件中通過(guò)函數(shù)定義 (如 NumberTypes()). 本例中屬性 t 必須為某種數(shù)字類(lèi)型:

REGISTER_OP("NumberType")
        .Attr("t: numbertype");

對(duì)于這個(gè) Op:

tf.number_type(t=tf.int32)  # 有效
tf.number_type(t=tf.bool)   # 無(wú)效
  • int >= <n>: 值必須是一個(gè)整數(shù), 且取值大于等于 <n>, <n> 是一個(gè)自然數(shù).

例如, 下列 Op 注冊(cè)操作指定了屬性 a 的取值至少為 2.

REGISTER_OP("MinIntExample")
      .Attr("a: int >= 2");
  • list(<type>) >= <n>: 一個(gè) <type> 類(lèi)型列表, 列表長(zhǎng)度必須大于等于 <n>.

例如, 下面的 Op 注冊(cè)操作指定屬性 a 是一個(gè)列表, 列表中的元素類(lèi)型是 int32float列表長(zhǎng)度至少為3.

REGISTER_OP("TypeListExample")
      .Attr("a: list({int32, float}) >= 3");

通過(guò)添加 = <default> 到約束條件末尾, 給一個(gè)屬性設(shè)置默認(rèn)值 (使其在自動(dòng)生成的代碼里 變成可選屬性), 如下:

REGISTER_OP("AttrDefaultExample")
    .Attr("i: int = 0");

默認(rèn)值支持的語(yǔ)法將在最終 GraphDef 定義的 protobuf 表示中被使用.

下面是給所有類(lèi)型賦予默認(rèn)值的例子:

REGISTER_OP("AttrDefaultExampleForAllTypes")
   .Attr("s: string = 'foo'")
   .Attr("i: int = 0")
   .Attr("f: float = 1.0")
   .Attr("b: bool = true")
   .Attr("ty: type = DT_INT32")
   .Attr("sh: shape = { dim { size: 1 } dim { size: 2 } }")
   .Attr("te: tensor = { dtype: DT_INT32 int_val: 5 }")
   .Attr("l_empty: list(int) = []")
   .Attr("l_int: list(int) = [2, 3, 5, 7]");

請(qǐng)?zhí)貏e注意那些類(lèi)型值里面包含的 DT_* 名稱(chēng).

多態(tài)

Type Polymorphism

對(duì)于那些可以使用不同類(lèi)型輸入或產(chǎn)生不同類(lèi)型輸出的 Op, 可以注冊(cè) Op 時(shí)為輸入/輸出類(lèi)型里指定一個(gè)屬性. 一般緊接著, 會(huì)為每一個(gè)支持的類(lèi)型注冊(cè)一個(gè) OpKernel.

例如, 除了 int32 外, 想要 ZeroOut Op 支持 float, 注冊(cè)代碼如下:

REGISTER_OP("ZeroOut")
    .Attr("T: {float, int32}")
    .Input("to_zero: <b>T</b>")
    .Output("zeroed: <b>T</b>");

這段 Op 注冊(cè)代碼現(xiàn)在指定了輸入的類(lèi)型必須為 floatint32, 而且 既然輸入和輸出制定了同樣的類(lèi)型 T, 輸出也同樣如此.

一個(gè)命名建議:{#naming} 輸入, 輸出, 和屬性通常使用 snake_case 命名法. 唯一的例外是屬性被用作輸入類(lèi)型或是輸入類(lèi)型的一部分. 當(dāng)添加到圖中時(shí), 這些屬性 可以被推斷出來(lái), 因此不會(huì)出現(xiàn)在 Op 的函數(shù)里. 例如, 最后一個(gè) ZeroOut 定義 生成的 Python 函數(shù)如下:

def zero_out(to_zero, name=None):
   """...
   參數(shù):
     to_zero: 一個(gè) `Tensor`. 必須為下列類(lèi)型之一:
         `float32`, `int32`.
     name: 操作的名字 (可選).

   返回值:
     一個(gè) `Tensor`, 類(lèi)型和 `to_zero` 一樣.
   """

如果輸入的 to_zero 是一個(gè) int32 的tensor, 然后 T 將被自動(dòng) 設(shè)置為 int32 (實(shí)際上是 DT_INT32). 那些推導(dǎo)出的屬性的名稱(chēng)字母全大寫(xiě) 或采用駝峰命名法.

下面是一個(gè)輸出類(lèi)型自動(dòng)推斷的例子, 讀者可以對(duì)比一下:

REGISTER_OP("StringToNumber")
     .Input("string_tensor: string")
     .Output("output: out_type")
     .Attr("out_type: {float, int32}");
     .Doc(R"doc(
 Converts each string in the input Tensor to the specified numeric type.
 )doc");

在這種情況下, 用戶(hù)需要在生成的 Python 代碼中指定輸出類(lèi)型.

def string_to_number(string_tensor, out_type=None, name=None):
   """將輸入 Tensor 中的每一個(gè)字符串轉(zhuǎn)化成指定的數(shù)字類(lèi)型

   參數(shù):
     string_tensor: 一個(gè) `string` 類(lèi)型的 `Tensor`.
     out_type: 一個(gè)可選的 `tf.DType`, 取值為 `tf.float32, tf.int32`.
       默認(rèn)值是 `tf.float32`.
     name: 操作的名稱(chēng) (可選).

   返回值:
     一個(gè) `out_type` 類(lèi)型的 `Tensor`.
   """
 #include "tensorflow/core/framework/op_kernel.h"
class ZeroOutInt32Op : public OpKernel {
  // 和之前一樣
};
class ZeroOutFloatOp : public OpKernel {
 public:
  explicit ZeroOutFloatOp(OpKernelConstruction * context)
      : OpKernel(context) {}
  void Compute(OpKernelContext * context) override {
    // 獲取輸入 tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<float>();
    // 創(chuàng)建一個(gè)輸出 tensor
    Tensor * output = NULL;
    OP_REQUIRES_OK(context,
                    context->allocate_output(0, input_tensor.shape(), &output));
    auto output_flat = output->template flat<float>();
    // 設(shè)置輸出 tensor 的所有元素為 0
    const int N = input.size();
    for (int i = 0; i &lt; N; i++) {
      output_flat(i) = 0;
    }<br/>
    // 保留第一個(gè)輸入值
    if (N &gt; 0) output_flat(0) = input(0);
  }
};
// 注意, TypeConstraint<int32>("T") 意味著屬性 "T" (在上面 Op 注冊(cè)代碼中
// 定義的) 必須是 "int32", 才能實(shí)例化. 
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint&lt;int32&gt;("T"),
    ZeroOutOpInt32);
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<float>("T"),
    ZeroOutFloatOp);

為了保持向后兼容性, 你在為一個(gè) 已有的 op 添加屬性時(shí), 必須指定一個(gè)默認(rèn)值:

REGISTER_OP("ZeroOut")
  .Attr("T: {float, int32} = DT_INT32")
  .Input("to_zero: T")
  .Output("zeroed: T")

如果需要添加更多類(lèi)型, 例如 double:

REGISTER_OP("ZeroOut")
    .Attr("T: {float, double, int32}")
    .Input("to_zero: T")
    .Output("zeroed: T");

為了避免為新增的類(lèi)型寫(xiě)冗余的 OpKernel 代碼, 通??梢詫?xiě)一個(gè) C++ 模板作為替代. 當(dāng)然, 仍然需要為每一個(gè)重載版本定義一個(gè) keneral 注冊(cè) (REGISTER\_KERNEL\_BUILDER 調(diào)用).

template <typename T>;
class ZeroOutOp : public OpKernel {
 public:
    explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
  void Compute(OpKernelContext* context) override {
    // 獲取輸入 tensor
     const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<T>();
    // 創(chuàng)建一個(gè)輸出 tensor
      Tensor* output = NULL;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_tensor.shape(), &output));
    auto output_flat = output->template flat<T>();
    // 設(shè)置輸出 tensor 的所有元素為 0
   const int N = input.size();
    for (int i = 0; i < N; i++) {
      output_flat(i) = 0;
    }
    // Preserve the first input value
    if (N > 0) output_flat(0) = input(0);
  }
};
};<br/>
// 注意, TypeConstraint<int32>("T") 意味著屬性 "T" (在上面 Op 注冊(cè)代碼中
// 定義的) 必須是 "int32", 才能實(shí)例化. </b>
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<int32>("T"),
    ZeroOutOp<int32>);
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<float>("T"),
    ZeroOutOp<float>);
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<double>("T"),
    ZeroOutOp<double>);

如果有很多重載版本, 可以將注冊(cè)操作通過(guò)一個(gè)宏來(lái)實(shí)現(xiàn).

 #include "tensorflow/core/framework/op_kernel.h"
 #define REGISTER_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      ZeroOutOp<type>)
REGISTER_KERNEL(int32);
REGISTER_KERNEL(float);
REGISTER_KERNEL(double);
 #undef REGISTER_KERNEL

取決于注冊(cè) kernel 使用哪些類(lèi)型, 你可能可以使用tensorflow/core/framework/register_types.h 提供的宏:

 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
REGISTER_OP("ZeroOut")
    .Attr("T: realnumbertype")
    .Input("to_zero: T")
    .Output("zeroed: T");
template <typename T>
class ZeroOutOp : public OpKernel { ... };
 #define REGISTER_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      ZeroOutOp<type>)
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
 #undef REGISTER_KERNEL

列表輸入和輸出

除了能夠使用不同類(lèi)型的 tensor 作為輸入或輸出, Op 還支持使用多個(gè) tensor 作為輸入或輸出.

在接下來(lái)的例子里, 屬性 T 存儲(chǔ)了一個(gè)類(lèi)型列表, 并同時(shí)作為輸入 in 和輸出 out 的類(lèi)型. 輸入和輸出均為指定類(lèi)型的 tensor 列表. 既然輸入和輸出的類(lèi)型均為 T, 它們的 tensor 數(shù)量和類(lèi)型 是一致的.

REGISTER_OP("PolymorphicListExample")
    .Attr("T: list(type)")
    .Input("in: T")
    .Output("out: T");

可以為列表中可存放的類(lèi)型設(shè)置約束條件. 在下一個(gè)例子中, 輸入是 floatdouble 類(lèi)型的 tensor 列表. 例如, 這個(gè) Op 可接受的 輸入類(lèi)型為 (float, double, float) 的數(shù)據(jù), 且在此情況下, 輸出類(lèi)型同樣 為 (float, double, float).

REGISTER_OP("ListTypeRestrictionExample")
    .Attr("T: list({float, double})")
    .Input("in: T")
    .Output("out: T");

如果想要一個(gè)列表中的所有 tensor 是同一類(lèi)型, 你需要寫(xiě)下列代碼:

REGISTER_OP("IntListInputExample")
    .Attr("N: int")
    .Input("in: N * int32")
    .Output("out: int32");

這段代碼接受 int32 tensor 列表, 并用一個(gè) int 屬性 N 來(lái)指定列表的長(zhǎng)度.

這也可用于類(lèi)型推斷. 在下一個(gè)例子中, 輸入是一個(gè) tensor 列表, 長(zhǎng)度為 "N", 類(lèi)型為 "T", 輸出是單個(gè) "T" 的 tensor:

REGISTER_OP("SameListInputExample")
    .Attr("N: int")
    .Attr("T: type")
    .Input("in: N * T")
    .Output("out: T");

默認(rèn)情況下, tensor 列表的最小長(zhǎng)度為1. 這個(gè)約束條件可以通過(guò) 為指定的屬性增加一個(gè) ">=" 約束來(lái)變更:

REGISTER_OP("MinLengthIntListExample")
    .Attr("N: int >= 2")
    .Input("in: N * int32")
    .Output("out: int32");

同樣的語(yǔ)法也適用于 "list(type)" 屬性:

REGISTER_OP("MinimumLengthPolymorphicListExample")
    .Attr("T: list(type) >= 3")
    .Input("in: T")
    .Output("out: T");

輸入和輸出

總結(jié)一下上述內(nèi)容, 一個(gè) Op 注冊(cè)操作可以指定多個(gè)輸入和輸出:

REGISTER_OP("MultipleInsAndOuts")
    .Input("y: int32")
    .Input("z: float")
    .Output("a: string")
    .Output("b: int32");

每一個(gè)輸入或輸出形式如下:

<name>: <io-type-expr>

其中, <name> 以字母打頭, 且只能由數(shù)字, 字母和下劃線(xiàn)組成. <io-type-expr> 可以是 下列類(lèi)型表達(dá)式之一:

  • <type>, 一個(gè)合法的輸入類(lèi)型, 如 float, int32, string. 這可用于指定給定類(lèi)型的單個(gè) tensor.

參見(jiàn)合法 Tensor 類(lèi)型列表.

REGISTER_OP("BuiltInTypesExample")
      .Input("integers: int32")
      .Input("complex_numbers: scomplex64");
  • <attr-type>, 一個(gè)屬性和一個(gè)類(lèi)型 type 或類(lèi)型列表 list(type)(可能 包含類(lèi)型限制). 該語(yǔ)法可實(shí)現(xiàn)多態(tài) Op.
REGISTER_OP("PolymorphicSingleInput")
      .Attr("T: type")
      .Input("in: T);
REGISTER_OP("RestrictedPolymorphicSingleInput")
      .Attr("T: {int32, int64}")
      .Input("in: T);

將屬性的類(lèi)型設(shè)置為 list(type) 將允許你接受一個(gè)序列的 tensor.

REGISTER_OP("ArbitraryTensorSequenceExample")
      .Attr("T: list(type)")
      .Input("in: T")
      .Output("out: T");
REGISTER_OP("RestrictedTensorSequenceExample")
      .Attr("T: list({int32, int64})")
      .Input("in: T")
      .Output("out: T");

注意, 輸入和輸出均為 T, 意味著輸入和輸出的類(lèi)型與數(shù)量均相同.

  • <number> * <type>, 一組擁有相同類(lèi)型的 tensor, <number> 是一個(gè) int 類(lèi)型屬性的名稱(chēng). <type> 可以是一個(gè)類(lèi)似于 int32float 的特定類(lèi)型, 或者一個(gè) type 類(lèi)型屬性的名字. 前者的例子如下, 該例子接受一個(gè) int32 tensor 列表作為 Op 輸入:
REGISTER_OP("Int32SequenceExample")
      .Attr("NumTensors: int")
      .Input("in: NumTensors * int32")

后者的例子如下, 該例子接受一個(gè)泛型 tensor 列表作為 Op 輸入:

REGISTER_OP("SameTypeSequenceExample")
      .Attr("NumTensors: int")
      .Attr("T: type")
      .Input("in: NumTensors * T")
  • Tensor 的引用表示為 Ref(<type>), 其中 <type> 是上述類(lèi)型之一.

一個(gè)命名建議: 當(dāng)使用屬性表示一個(gè)輸入的類(lèi)型時(shí), 該類(lèi)型可以被推斷出來(lái). 實(shí)現(xiàn)該特性, 將需要推斷 的類(lèi)型用大寫(xiě)名稱(chēng)表示 (如 TN), 其它的輸入, 輸出, 和屬性像使用函數(shù)參數(shù)一樣使用這些 大寫(xiě)名稱(chēng). 參見(jiàn)之前的命名建議章節(jié)查看更多細(xì)節(jié).

更多細(xì)節(jié)參見(jiàn) tensorflow/core/framework/op_def_builder.h.

向后兼容性

通常, 對(duì)規(guī)范的改變必須保持向后兼容性: Op 使用新規(guī)范后, 需保證使用舊規(guī)范構(gòu)造的序列化 GraphDef 仍能正確工作.

下面是幾種保持向后兼容性的方式:

  1. 任何添加到 Op 的新屬性必須有默認(rèn)值, 且默認(rèn)值下的行為有明確定義. 將一個(gè)非多態(tài)的操作變?yōu)槎鄳B(tài)操作, 你必須為新的類(lèi)型屬性賦予默認(rèn)值, 以保持原始的函數(shù)簽名. 例如, 有如下操作:
REGISTER_OP("MyGeneralUnaryOp")
       .Input("in: float")
       .Output("out: float");

可以通過(guò)下述方式將其變?yōu)槎鄳B(tài), 且保持向后兼容性:

REGISTER_OP("MyGeneralUnaryOp")
       .Input("in: T")
       .Output("out: T")
       .Attr("T: numerictype = float");

1.放寬一個(gè)屬性的約束條件是安全的. 例如, 你可以將 {int32, int64} 變?yōu)?{int32, int64, float}, 或者, 將 {"apple", "orange"} 變?yōu)?{"apple", "banana", "orange"}.

2.通過(guò)給 Op 名稱(chēng)添加一些項(xiàng)目中唯一的標(biāo)識(shí)作為前綴, 來(lái)為新建的 Op 添加命名空間. 命名空間 可以預(yù)防你的 Op 與 TensorFlow 未來(lái)版本里的內(nèi)置 Op 產(chǎn)生命名沖突.

3.超前計(jì)劃! 嘗試著去預(yù)測(cè) Op 未來(lái)的的用途, 超前設(shè)計(jì), 畢竟, 一些簽名的變更無(wú)法保證兼容性 (例如, 增加新的輸入, 或?qū)⒃瓉?lái)的單元素輸入變成一個(gè)列表).

如果不能以兼容的方式改變一個(gè)操作, 那就創(chuàng)建一個(gè)全新的操作, 來(lái)實(shí)現(xiàn)所需功能.

GPU 支持

你可以實(shí)現(xiàn)不同的 OpKernel, 將其中之一注冊(cè)到 GPU, 另一個(gè)注冊(cè)到 GPU, 正如為不同的類(lèi)型注冊(cè) kernel 一樣. tensorflow/core/kernels/ 中有一些 GPU 支持的例子. 注意, 一些 kernel 的 CPU 版本位于 .cc 文件, GPU 版本位于 _gpu.cu.cc 文件, 共享的代碼位于 .h 文件.

例如, pad op 除了 GPU kernel 外的其它代碼 均在 tensorflow/core/kernels/pad_op.cc 中. GPU kernel 位于 tensorflow/core/kernels/pad_op_gpu.cu.cc, 共享的一個(gè)模板類(lèi)代碼定義在 tensorflow/core/kernels/pad_op.h. 需要注意的事情是, 即使使用 pad 的 GPU 版本時(shí), 仍然需要將 "paddings" 輸入放置到內(nèi)存中. 為了實(shí)現(xiàn)這一點(diǎn), 將輸入或輸出標(biāo)記為必須保存在內(nèi)存中, 為 kernel 注冊(cè)一個(gè) HostMemory() 調(diào)用. 如下:

 #define REGISTER_GPU_KERNEL(T)                         \
REGISTER_KERNEL_BUILDER(Name("Pad")                  \
                              .Device(DEVICE_GPU)      \
                              .TypeConstraint<T>("T")  \
                              .HostMemory("paddings"), \
                          PadOp<GPUDevice, T>)

使用 Python 實(shí)現(xiàn)梯度

給定一個(gè) Op 組成的圖, TensorFlow 使用自動(dòng)微分 (反向傳播) 來(lái)添加新的 Op 以表示梯度運(yùn)算, 同時(shí) 不影響已有的 Op (參見(jiàn)梯度運(yùn)算). 為了使自動(dòng)微分能夠與新的 Op 協(xié)同工作, 必須注冊(cè)一個(gè)梯度函數(shù), 從 Op 的輸入計(jì)算梯度, 并返回代表 梯度值的輸出.

數(shù)學(xué)上, 如果一個(gè) Op 計(jì)算 \(y = f(x)\), 注冊(cè)的梯度 Op 通過(guò)以下鏈?zhǔn)椒▌t, 將 \(\partial / \partial y\) 的梯度運(yùn)算轉(zhuǎn)化為 \(\partial / \partial x\) 的梯度運(yùn)算.

$$\frac{\partial}{\partial x} = \frac{\partial}{\partial y} \frac{\partial y}{\partial x} = \frac{\partial}{\partial y} \frac{\partial f}{\partial x}.$$

ZeroOut 的例子中, 輸入中只有一個(gè)項(xiàng)會(huì)影響輸出, 所以, 代表輸入的梯度值的 tensor 也只有 一個(gè)輸入項(xiàng). 如下所示:

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops

@ops.RegisterGradient("ZeroOut")
def _zero_out_grad(op, grad):
  """`zero_out` 的梯度.

  參數(shù):
    op: 欲進(jìn)行微分的 `zero_out` `操作`, 可以用于獲取原始 Op 的輸入和輸出.
    grad: 代表 `zero_out` 輸出的梯度 Op.

  返回:
    代表輸入 `zero_out` 的微分.
  """
  to_zero = op.inputs[0]
  shape = array_ops.shape(to_zero)
  index = array_ops.zeros_like(shape)
  first_grad = array_ops.reshape(grad, [-1])[0]
  to_zero_grad = sparse_ops.sparse_to_dense(index, shape, first_grad, 0)
  return [to_zero_grad]  # 單個(gè) Tensor 的列表, 既然只有一個(gè)輸入

使用 ops.RegisterGradient 注冊(cè)梯度函數(shù)需要注意的一些細(xì)節(jié):