預(yù)備知識(shí):
如果現(xiàn)有的庫(kù)沒(méi)有涵蓋你想要的操作, 你可以自己定制一個(gè). 為了使定制的 Op 能夠兼容原有的庫(kù) , 你必須做以下工作:
向 TensorFlow 系統(tǒng)注冊(cè)來(lái)定義 Op 的接口. 在注冊(cè)時(shí), 指定 Op 的名稱(chēng), 它的輸入(類(lèi)型和名稱(chēng)) 和輸出(類(lèi)型和名稱(chēng)), 和所需要任何 屬性的文檔說(shuō)明.
為了讓你有直觀的認(rèn)識(shí), 創(chuàng)建一個(gè)簡(jiǎn)單的 Op 作為例子. 該 Op 接受一個(gè) int32 類(lèi)型 tensor 作為
輸入, 輸出這個(gè) tensor 的一個(gè)副本, 副本與原 tensor 唯一的區(qū)別在于第一個(gè)元素被置為 0. 創(chuàng)建
文件 tensorflow/core/user_ops/zero_out.cc, 并調(diào)用 REGISTER_OP 宏來(lái)定義 Op 的接口.
#include "tensorflow/core/framework/op.h"
REGISTER_OP("ZeroOut")
.Input("to_zero: int32")
.Output("zeroed: int32");
ZeroOut Op 接受 32 位整型的 tensor to_zero 作為輸入, 輸出 32 位整型的 tensor zeroed.
命名的注意事項(xiàng): Op 的名稱(chēng)必須是為唯一的, 并使用駝峰命名法. 以下劃線(xiàn)
_開(kāi)始的名稱(chēng)保留為內(nèi)部使用.
在定義接口之后, 提供一個(gè)或多個(gè) Op 的實(shí)現(xiàn). 為這些 kernel 的每一個(gè)創(chuàng)建一個(gè)對(duì)應(yīng)的類(lèi), 繼承
OpKernel, 覆蓋 Compute 方法. Compute 方法提供一個(gè)類(lèi)型為 OpKernelContext* 的參數(shù) context, 用于訪(fǎng)問(wèn)一些有用的信息, 例如輸入和輸出的 tensor.
將 kernel 添加到剛才創(chuàng)建的文件中, kernel 看起來(lái)和下面的代碼類(lèi)似:
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
class ZeroOutOp : public OpKernel {
public:
explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// 獲取輸入 tensor.
const Tensor& input_tensor = context->input(0);
auto input = input_tensor.flat<int32>();
// 創(chuàng)建一個(gè)輸出 tensor.
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
&output_tensor));
auto output = output_tensor->template flat<int32>();
// 設(shè)置 tensor 除第一個(gè)之外的元素均設(shè)為 0.
const int N = input.size();
for (int i = 1; i < N; i++) {
output(i) = 0;
}
// 盡可能地保留第一個(gè)元素的值.
if (N > 0) output(0) = input(0);
}
};
實(shí)現(xiàn) kernel 后, 將其注冊(cè)到 TensorFlow 系統(tǒng)中. 注冊(cè)時(shí), 可以指定該 kernel 運(yùn)行時(shí)的多個(gè)約束 條件. 例如可以指定一個(gè) kernel 在 CPU 上運(yùn)行, 另一個(gè)在 GPU 上運(yùn)行.
將下列代碼加入到 zero_out.cc 中, 注冊(cè) ZeroOut op:
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
一旦創(chuàng)建和重新安裝了 TensorFlow , Tensorflow 系統(tǒng)可以在需要時(shí)引用和使用該 Op.
當(dāng)編譯 TensorFlow 時(shí), 所有放在 tensorflow/core/user_ops 目錄下
的 Op 會(huì)自動(dòng)在 bazel-genfiles/tensorflow/python/ops/gen_user_ops.py 文件
中生成 Python Op 包裝器. 通過(guò)以下聲明, 把那些 Op 引入到 tensorflow/python/user_ops/user_ops.py
中:
from tensorflow.python.ops.gen_user_ops import *
你可以選擇性將部分函數(shù)替換為自己的實(shí)現(xiàn). 為此, 首先要隱藏自動(dòng)生成的代碼,
在 tensorflow/python/BUILD
文件中, 將其名字添加到 "user_ops" 的 hidden 列表.
tf_gen_op_wrapper_py(
name = "user_ops",
hidden = [
"Fact",
],
require_shape_functions = False,
)
緊接著 "Fact" 列出自己的 Op. 然后, 在
tensorflow/python/user_ops/user_ops.py
中添加你的替代實(shí)現(xiàn)函數(shù). 通常, 替代實(shí)現(xiàn)函數(shù)也會(huì)調(diào)用自動(dòng)生成函數(shù)來(lái)真正把 Op 添加
到圖中. 被隱藏的自動(dòng)生成函數(shù)位于 gen_user_ops 包中, 名稱(chēng)多了一個(gè)下劃線(xiàn)前綴
("_"). 例如:
def my_fact():
"""覆蓋一個(gè) Op 自動(dòng)生成代碼的示例."""
return gen_user_ops._fact()
當(dāng)編譯 TensorFlow 時(shí), 所有 tensorflow/core/user_ops 文件夾
下的 Op 會(huì)自動(dòng)創(chuàng)建 C++ Op 包裝器. 例如, tensorflow/core/user_ops/zero_out.cc 中的 Op 會(huì)自動(dòng)在 bazel-genfiles/tensorflow/cc/ops/user_ops.{h,cc}
中生成包裝器.
tensorflow/cc/ops/standard_ops.h 通過(guò)下述申明,
導(dǎo)入用戶(hù)自定義 Op 自動(dòng)生成的包裝器.
#include "tensorflow/cc/ops/user_ops.h"
驗(yàn)證已經(jīng)成功實(shí)現(xiàn) Op 的方式是編寫(xiě)測(cè)試程序. 創(chuàng)建文件
tensorflow/python/kernel_tests/zero_out_op_test.py,
包含以下內(nèi)容:
import tensorflow as tf
class ZeroOutTest(tf.test.TestCase):
def testZeroOut(self):
with self.test_session():
result = tf.user_ops.zero_out([5, 4, 3, 2, 1])
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
然后運(yùn)行測(cè)試:
$ bazel test tensorflow/python:zero_out_op_test
上述示例假定 Op 能夠應(yīng)用在任何 shape 的 tensor 上. 如果只想應(yīng)用到 vector 上 呢? 這意味需要在上述 OpKernel 實(shí)現(xiàn)中添加相關(guān)的檢查.
void Compute(OpKernelContext* context) override {
// 獲取輸入 tensor
const Tensor& input_tensor = context->input(0);
OP_REQUIRES(context, TensorShapeUtils::IsVector(input_tensor.shape()),
errors::InvalidArgument("ZeroOut expects a 1-D vector."));
// ...
}
OP_REQUIRES 斷言的輸入是一個(gè) vector, 如果不是 vector, 將設(shè)置 InvalidArgument 狀態(tài)并返回.
OP_REQUIRES 宏 有三個(gè)參數(shù):
context: 可以是一個(gè) OpKernelContext 或 OpKernelConstruction 指針
(參見(jiàn) tensorflow/core/framework/op_kernel.h),
其 SetStatus() 方法將被使用到.tensorflow/core/public/tensor_shape.h
中有一些驗(yàn)證 tensor shape 的函數(shù).Status 對(duì)象表示, 參見(jiàn)
tensorflow/core/public/status.h.
Status 包含一個(gè)類(lèi)型 (通常是 InvalidArgument, 但也可以是任何類(lèi)型) 和一個(gè)消息. 構(gòu)造
一個(gè)錯(cuò)誤的函數(shù)位于 tensorflow/core/lib/core/errors.h 中.如果想要測(cè)試一個(gè)函數(shù)返回的 Status 對(duì)象是否是一個(gè)錯(cuò)誤, 可以使用 OP_REQUIRES_OK.
這些宏如果檢測(cè)到錯(cuò)誤, 會(huì)直接跳出函數(shù), 終止函數(shù)執(zhí)行.
Op 可以有屬性, 屬性的值在 Op 添加到圖中時(shí)被設(shè)置. 屬性值用于配置 Op, 在 kernel 實(shí)現(xiàn)中, Op 注冊(cè)的輸入和輸出類(lèi)型中, 均可訪(fǎng)問(wèn)這些屬性值. 盡可能地使用輸入代替屬性, 因?yàn)檩斎氲撵`活性更高, 例如可以在執(zhí)行步驟中 中被更改, 可以使用 feed 等等. 屬性可用于實(shí)現(xiàn)一些輸入無(wú)法做到的事情, 例如影響 Op 簽名 (即輸入輸出的數(shù)量和類(lèi)型) 的配置或只讀配置可以通過(guò)屬性實(shí)現(xiàn).
注冊(cè) Op 時(shí)可以用 Attr 方法指定屬性的名稱(chēng)和類(lèi)型, 以此來(lái)定義一個(gè)屬性, 形式如下:
<name>: <attr-type-expr>
<name> 必須以字母開(kāi)頭, 可以由數(shù)字, 字母, 下劃線(xiàn)組成. <attr-type-expr> 是一個(gè)類(lèi)型表達(dá)式,
形式如下:
例如, 如果想要 ZeroOut Op 保存一個(gè)用戶(hù)索引, 指示該 Op 不僅僅只有一個(gè)元素, 你可以注冊(cè) Op 如下:
REGISTER_OP("ZeroOut")
.Attr("preserve_index: int")
.Input("to_zero: int32")
.Output("zeroed: int32");
你的 kernel 可以在構(gòu)造函數(shù)里, 通過(guò) context 參數(shù)訪(fǎng)問(wèn)這個(gè)屬性:
class ZeroOutOp : public OpKernel {
public:
explicit ZeroOutOp(OpKernelConstruction * context) : OpKernel(context) {
// 獲取欲保存的索引值
OP_REQUIRES_OK(context,
context->GetAttr("preserve_index", &preserve_index_));
// 檢查 preserve_index 是否為正
OP_REQUIRES(context, preserve_index_ >= 0,
errors::InvalidArgument("Need preserve_index >= 0, got ",
preserve_index_));
}
void Compute(OpKernelContext* context) override {
// ...
}
private:
int preserve_index_;
};
該值可以在 Compute 方法中被使用:
void Compute(OpKernelContext* context) override {
// ...
// 檢查 preserve_index 范圍是否合法
OP_REQUIRES(context, preserve_index_ < input.dimension(0),
errors::InvalidArgument("preserve_index out of range"));
// 設(shè)置輸出 tensor 所有的元素值為 0
const int N = input.size();
for (int i = 0; i < N; i++) {
output_flat(i) = 0;
}
// 保存請(qǐng)求的輸入值
output_flat(preserve_index_) = input(preserve_index_);
}
為了維持向后兼容性, 將一個(gè)屬性添加到一個(gè)已有的 Op 時(shí), 必須指定一個(gè)默認(rèn)值:
REGISTER_OP("ZeroOut")
.Attr("preserve_index: int = 0")
.Input("to_zero: int32")
.Output("zeroed: int32");
屬性可以使用下面的類(lèi)型:
string: 任何二進(jìn)制字節(jié)流 (UTF8 不是必須的).int: 一個(gè)有型整數(shù).float: 一個(gè)浮點(diǎn)數(shù).bool: 真或假.type: DataType 非引用類(lèi)型之一.shape: 一個(gè) TensorShapeProto.tensor: 一個(gè) TensorProto.list(<type>): <type> 列表, 其中 <type> 是上述類(lèi)型之一.
注意 list(list(<type>)) 是無(wú)效的.權(quán)威的列表以 op_def_builder.cc:FinalizeAttr 為準(zhǔn).
屬性可能有默認(rèn)值, 一些類(lèi)型的屬性可以有約束條件. 為了定義一個(gè)有約束條件的屬性, 你可以使用下列的
<attr-type-expr> 形式:
{'<string1>', '<string2>'}: 屬性值必須是一個(gè)字符串, 取值可以為 <string1> 或 <string2>.
值的語(yǔ)法已經(jīng)暗示了值的類(lèi)型為 string, 已經(jīng)暗示了. 下述語(yǔ)句模擬了一個(gè)枚舉值:REGISTER_OP("EnumExample")
.Attr("e: {'apple', 'orange'}");
{<type1>, <type2>}: 值是 type 類(lèi)型, 且必須為 <type1> 或 <type2> 之一, 當(dāng)然
<type1> 和 <type2> 必須都是有效的 tensor 類(lèi)型.
你無(wú)須指定屬性的類(lèi)型為 type, 而是通過(guò) {...} 語(yǔ)句給出一個(gè)類(lèi)型列表. 例如, 在下面的例子里,
屬性 t 的類(lèi)型必須為 int32, float, 或 bool:REGISTER_OP("RestrictedTypeExample")
.Attr("t: {int32, float, bool}");
這里有一些常見(jiàn)類(lèi)型約束條件的快捷方式:
numbertype: 限制類(lèi)型為數(shù)字類(lèi)型, 即非 string 非 bool 的類(lèi)型.realnumbertype: 與 numbertype 區(qū)別是不支持復(fù)雜類(lèi)型.quantizedtype: 與 numbertype 區(qū)別是只支持量化數(shù)值 (quantized number type).這些類(lèi)型的列表在 tensorflow/core/framework/types.h
文件中通過(guò)函數(shù)定義 (如 NumberTypes()).
本例中屬性 t 必須為某種數(shù)字類(lèi)型:
REGISTER_OP("NumberType")
.Attr("t: numbertype");
對(duì)于這個(gè) Op:
tf.number_type(t=tf.int32) # 有效
tf.number_type(t=tf.bool) # 無(wú)效
int >= <n>: 值必須是一個(gè)整數(shù), 且取值大于等于 <n>, <n> 是一個(gè)自然數(shù).例如, 下列 Op 注冊(cè)操作指定了屬性 a 的取值至少為 2.
REGISTER_OP("MinIntExample")
.Attr("a: int >= 2");
list(<type>) >= <n>: 一個(gè) <type> 類(lèi)型列表, 列表長(zhǎng)度必須大于等于 <n>.例如, 下面的 Op 注冊(cè)操作指定屬性 a 是一個(gè)列表, 列表中的元素類(lèi)型是 int32 或 float列表長(zhǎng)度至少為3.
REGISTER_OP("TypeListExample")
.Attr("a: list({int32, float}) >= 3");
通過(guò)添加 = <default> 到約束條件末尾, 給一個(gè)屬性設(shè)置默認(rèn)值 (使其在自動(dòng)生成的代碼里
變成可選屬性), 如下:
REGISTER_OP("AttrDefaultExample")
.Attr("i: int = 0");
默認(rèn)值支持的語(yǔ)法將在最終 GraphDef 定義的 protobuf 表示中被使用.
下面是給所有類(lèi)型賦予默認(rèn)值的例子:
REGISTER_OP("AttrDefaultExampleForAllTypes")
.Attr("s: string = 'foo'")
.Attr("i: int = 0")
.Attr("f: float = 1.0")
.Attr("b: bool = true")
.Attr("ty: type = DT_INT32")
.Attr("sh: shape = { dim { size: 1 } dim { size: 2 } }")
.Attr("te: tensor = { dtype: DT_INT32 int_val: 5 }")
.Attr("l_empty: list(int) = []")
.Attr("l_int: list(int) = [2, 3, 5, 7]");
請(qǐng)?zhí)貏e注意那些類(lèi)型值里面包含的 DT_* 名稱(chēng).
對(duì)于那些可以使用不同類(lèi)型輸入或產(chǎn)生不同類(lèi)型輸出的 Op, 可以注冊(cè) Op 時(shí)為輸入/輸出類(lèi)型里指定一個(gè)屬性.
一般緊接著, 會(huì)為每一個(gè)支持的類(lèi)型注冊(cè)一個(gè) OpKernel.
例如, 除了 int32 外, 想要 ZeroOut Op 支持 float, 注冊(cè)代碼如下:
REGISTER_OP("ZeroOut")
.Attr("T: {float, int32}")
.Input("to_zero: <b>T</b>")
.Output("zeroed: <b>T</b>");
這段 Op 注冊(cè)代碼現(xiàn)在指定了輸入的類(lèi)型必須為 float 或 int32, 而且
既然輸入和輸出制定了同樣的類(lèi)型 T, 輸出也同樣如此.
一個(gè)命名建議:{#naming} 輸入, 輸出, 和屬性通常使用 snake_case 命名法. 唯一的例外是屬性被用作輸入類(lèi)型或是輸入類(lèi)型的一部分. 當(dāng)添加到圖中時(shí), 這些屬性 可以被推斷出來(lái), 因此不會(huì)出現(xiàn)在 Op 的函數(shù)里. 例如, 最后一個(gè) ZeroOut 定義 生成的 Python 函數(shù)如下:
def zero_out(to_zero, name=None):
"""...
參數(shù):
to_zero: 一個(gè) `Tensor`. 必須為下列類(lèi)型之一:
`float32`, `int32`.
name: 操作的名字 (可選).
返回值:
一個(gè) `Tensor`, 類(lèi)型和 `to_zero` 一樣.
"""
如果輸入的
to_zero是一個(gè)int32的tensor, 然后T將被自動(dòng) 設(shè)置為int32(實(shí)際上是DT_INT32). 那些推導(dǎo)出的屬性的名稱(chēng)字母全大寫(xiě) 或采用駝峰命名法.下面是一個(gè)輸出類(lèi)型自動(dòng)推斷的例子, 讀者可以對(duì)比一下:
REGISTER_OP("StringToNumber")
.Input("string_tensor: string")
.Output("output: out_type")
.Attr("out_type: {float, int32}");
.Doc(R"doc(
Converts each string in the input Tensor to the specified numeric type.
)doc");
在這種情況下, 用戶(hù)需要在生成的 Python 代碼中指定輸出類(lèi)型.
def string_to_number(string_tensor, out_type=None, name=None):
"""將輸入 Tensor 中的每一個(gè)字符串轉(zhuǎn)化成指定的數(shù)字類(lèi)型
參數(shù):
string_tensor: 一個(gè) `string` 類(lèi)型的 `Tensor`.
out_type: 一個(gè)可選的 `tf.DType`, 取值為 `tf.float32, tf.int32`.
默認(rèn)值是 `tf.float32`.
name: 操作的名稱(chēng) (可選).
返回值:
一個(gè) `out_type` 類(lèi)型的 `Tensor`.
"""
#include "tensorflow/core/framework/op_kernel.h"
class ZeroOutInt32Op : public OpKernel {
// 和之前一樣
};
class ZeroOutFloatOp : public OpKernel {
public:
explicit ZeroOutFloatOp(OpKernelConstruction * context)
: OpKernel(context) {}
void Compute(OpKernelContext * context) override {
// 獲取輸入 tensor
const Tensor& input_tensor = context->input(0);
auto input = input_tensor.flat<float>();
// 創(chuàng)建一個(gè)輸出 tensor
Tensor * output = NULL;
OP_REQUIRES_OK(context,
context->allocate_output(0, input_tensor.shape(), &output));
auto output_flat = output->template flat<float>();
// 設(shè)置輸出 tensor 的所有元素為 0
const int N = input.size();
for (int i = 0; i < N; i++) {
output_flat(i) = 0;
}<br/>
// 保留第一個(gè)輸入值
if (N > 0) output_flat(0) = input(0);
}
};
// 注意, TypeConstraint<int32>("T") 意味著屬性 "T" (在上面 Op 注冊(cè)代碼中
// 定義的) 必須是 "int32", 才能實(shí)例化.
REGISTER_KERNEL_BUILDER(
Name("ZeroOut")
.Device(DEVICE_CPU)
.TypeConstraint<int32>("T"),
ZeroOutOpInt32);
REGISTER_KERNEL_BUILDER(
Name("ZeroOut")
.Device(DEVICE_CPU)
.TypeConstraint<float>("T"),
ZeroOutFloatOp);
REGISTER_OP("ZeroOut")
.Attr("T: {float, int32} = DT_INT32")
.Input("to_zero: T")
.Output("zeroed: T")
如果需要添加更多類(lèi)型, 例如 double:
REGISTER_OP("ZeroOut")
.Attr("T: {float, double, int32}")
.Input("to_zero: T")
.Output("zeroed: T");
為了避免為新增的類(lèi)型寫(xiě)冗余的 OpKernel 代碼, 通??梢詫?xiě)一個(gè) C++ 模板作為替代.
當(dāng)然, 仍然需要為每一個(gè)重載版本定義一個(gè) keneral 注冊(cè) (REGISTER\_KERNEL\_BUILDER 調(diào)用).
template <typename T>;
class ZeroOutOp : public OpKernel {
public:
explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// 獲取輸入 tensor
const Tensor& input_tensor = context->input(0);
auto input = input_tensor.flat<T>();
// 創(chuàng)建一個(gè)輸出 tensor
Tensor* output = NULL;
OP_REQUIRES_OK(context,
context->allocate_output(0, input_tensor.shape(), &output));
auto output_flat = output->template flat<T>();
// 設(shè)置輸出 tensor 的所有元素為 0
const int N = input.size();
for (int i = 0; i < N; i++) {
output_flat(i) = 0;
}
// Preserve the first input value
if (N > 0) output_flat(0) = input(0);
}
};
};<br/>
// 注意, TypeConstraint<int32>("T") 意味著屬性 "T" (在上面 Op 注冊(cè)代碼中
// 定義的) 必須是 "int32", 才能實(shí)例化. </b>
REGISTER_KERNEL_BUILDER(
Name("ZeroOut")
.Device(DEVICE_CPU)
.TypeConstraint<int32>("T"),
ZeroOutOp<int32>);
REGISTER_KERNEL_BUILDER(
Name("ZeroOut")
.Device(DEVICE_CPU)
.TypeConstraint<float>("T"),
ZeroOutOp<float>);
REGISTER_KERNEL_BUILDER(
Name("ZeroOut")
.Device(DEVICE_CPU)
.TypeConstraint<double>("T"),
ZeroOutOp<double>);
如果有很多重載版本, 可以將注冊(cè)操作通過(guò)一個(gè)宏來(lái)實(shí)現(xiàn).
#include "tensorflow/core/framework/op_kernel.h"
#define REGISTER_KERNEL(type) \
REGISTER_KERNEL_BUILDER( \
Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
ZeroOutOp<type>)
REGISTER_KERNEL(int32);
REGISTER_KERNEL(float);
REGISTER_KERNEL(double);
#undef REGISTER_KERNEL
取決于注冊(cè) kernel 使用哪些類(lèi)型, 你可能可以使用tensorflow/core/framework/register_types.h
提供的宏:
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
REGISTER_OP("ZeroOut")
.Attr("T: realnumbertype")
.Input("to_zero: T")
.Output("zeroed: T");
template <typename T>
class ZeroOutOp : public OpKernel { ... };
#define REGISTER_KERNEL(type) \
REGISTER_KERNEL_BUILDER( \
Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
ZeroOutOp<type>)
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL
除了能夠使用不同類(lèi)型的 tensor 作為輸入或輸出, Op 還支持使用多個(gè) tensor 作為輸入或輸出.
在接下來(lái)的例子里, 屬性 T 存儲(chǔ)了一個(gè)類(lèi)型列表, 并同時(shí)作為輸入 in 和輸出 out 的類(lèi)型.
輸入和輸出均為指定類(lèi)型的 tensor 列表. 既然輸入和輸出的類(lèi)型均為 T, 它們的 tensor 數(shù)量和類(lèi)型
是一致的.
REGISTER_OP("PolymorphicListExample")
.Attr("T: list(type)")
.Input("in: T")
.Output("out: T");
可以為列表中可存放的類(lèi)型設(shè)置約束條件. 在下一個(gè)例子中, 輸入是 float 和
double 類(lèi)型的 tensor 列表. 例如, 這個(gè) Op 可接受的
輸入類(lèi)型為 (float, double, float) 的數(shù)據(jù), 且在此情況下, 輸出類(lèi)型同樣
為 (float, double, float).
REGISTER_OP("ListTypeRestrictionExample")
.Attr("T: list({float, double})")
.Input("in: T")
.Output("out: T");
如果想要一個(gè)列表中的所有 tensor 是同一類(lèi)型, 你需要寫(xiě)下列代碼:
REGISTER_OP("IntListInputExample")
.Attr("N: int")
.Input("in: N * int32")
.Output("out: int32");
這段代碼接受 int32 tensor 列表, 并用一個(gè) int 屬性 N
來(lái)指定列表的長(zhǎng)度.
這也可用于類(lèi)型推斷. 在下一個(gè)例子中,
輸入是一個(gè) tensor 列表, 長(zhǎng)度為 "N", 類(lèi)型為 "T", 輸出是單個(gè) "T" 的 tensor:
REGISTER_OP("SameListInputExample")
.Attr("N: int")
.Attr("T: type")
.Input("in: N * T")
.Output("out: T");
默認(rèn)情況下, tensor 列表的最小長(zhǎng)度為1. 這個(gè)約束條件可以通過(guò)
為指定的屬性增加一個(gè) ">=" 約束來(lái)變更:
REGISTER_OP("MinLengthIntListExample")
.Attr("N: int >= 2")
.Input("in: N * int32")
.Output("out: int32");
同樣的語(yǔ)法也適用于 "list(type)" 屬性:
REGISTER_OP("MinimumLengthPolymorphicListExample")
.Attr("T: list(type) >= 3")
.Input("in: T")
.Output("out: T");
總結(jié)一下上述內(nèi)容, 一個(gè) Op 注冊(cè)操作可以指定多個(gè)輸入和輸出:
REGISTER_OP("MultipleInsAndOuts")
.Input("y: int32")
.Input("z: float")
.Output("a: string")
.Output("b: int32");
每一個(gè)輸入或輸出形式如下:
<name>: <io-type-expr>
其中, <name> 以字母打頭, 且只能由數(shù)字, 字母和下劃線(xiàn)組成. <io-type-expr> 可以是
下列類(lèi)型表達(dá)式之一:
<type>, 一個(gè)合法的輸入類(lèi)型, 如 float, int32, string. 這可用于指定給定類(lèi)型的單個(gè) tensor.參見(jiàn)合法 Tensor 類(lèi)型列表.
REGISTER_OP("BuiltInTypesExample")
.Input("integers: int32")
.Input("complex_numbers: scomplex64");
<attr-type>, 一個(gè)屬性和一個(gè)類(lèi)型 type 或類(lèi)型列表 list(type)(可能
包含類(lèi)型限制). 該語(yǔ)法可實(shí)現(xiàn)多態(tài) Op.REGISTER_OP("PolymorphicSingleInput")
.Attr("T: type")
.Input("in: T);
REGISTER_OP("RestrictedPolymorphicSingleInput")
.Attr("T: {int32, int64}")
.Input("in: T);
將屬性的類(lèi)型設(shè)置為 list(type) 將允許你接受一個(gè)序列的 tensor.
REGISTER_OP("ArbitraryTensorSequenceExample")
.Attr("T: list(type)")
.Input("in: T")
.Output("out: T");
REGISTER_OP("RestrictedTensorSequenceExample")
.Attr("T: list({int32, int64})")
.Input("in: T")
.Output("out: T");
注意, 輸入和輸出均為 T, 意味著輸入和輸出的類(lèi)型與數(shù)量均相同.
<number> * <type>, 一組擁有相同類(lèi)型的 tensor, <number> 是一個(gè) int 類(lèi)型屬性的名稱(chēng).
<type> 可以是一個(gè)類(lèi)似于 int32 和 float 的特定類(lèi)型,
或者一個(gè) type 類(lèi)型屬性的名字. 前者的例子如下, 該例子接受一個(gè) int32 tensor 列表作為 Op 輸入:REGISTER_OP("Int32SequenceExample")
.Attr("NumTensors: int")
.Input("in: NumTensors * int32")
后者的例子如下, 該例子接受一個(gè)泛型 tensor 列表作為 Op 輸入:
REGISTER_OP("SameTypeSequenceExample")
.Attr("NumTensors: int")
.Attr("T: type")
.Input("in: NumTensors * T")
Ref(<type>), 其中 <type> 是上述類(lèi)型之一.一個(gè)命名建議: 當(dāng)使用屬性表示一個(gè)輸入的類(lèi)型時(shí), 該類(lèi)型可以被推斷出來(lái). 實(shí)現(xiàn)該特性, 將需要推斷 的類(lèi)型用大寫(xiě)名稱(chēng)表示 (如
T或N), 其它的輸入, 輸出, 和屬性像使用函數(shù)參數(shù)一樣使用這些 大寫(xiě)名稱(chēng). 參見(jiàn)之前的命名建議章節(jié)查看更多細(xì)節(jié).
更多細(xì)節(jié)參見(jiàn) tensorflow/core/framework/op_def_builder.h.
通常, 對(duì)規(guī)范的改變必須保持向后兼容性: Op 使用新規(guī)范后, 需保證使用舊規(guī)范構(gòu)造的序列化 GraphDef 仍能正確工作.
下面是幾種保持向后兼容性的方式:
REGISTER_OP("MyGeneralUnaryOp")
.Input("in: float")
.Output("out: float");
可以通過(guò)下述方式將其變?yōu)槎鄳B(tài), 且保持向后兼容性:
REGISTER_OP("MyGeneralUnaryOp")
.Input("in: T")
.Output("out: T")
.Attr("T: numerictype = float");
1.放寬一個(gè)屬性的約束條件是安全的. 例如, 你可以將 {int32, int64} 變?yōu)?{int32, int64, float},
或者, 將 {"apple", "orange"} 變?yōu)?{"apple", "banana", "orange"}.
2.通過(guò)給 Op 名稱(chēng)添加一些項(xiàng)目中唯一的標(biāo)識(shí)作為前綴, 來(lái)為新建的 Op 添加命名空間. 命名空間 可以預(yù)防你的 Op 與 TensorFlow 未來(lái)版本里的內(nèi)置 Op 產(chǎn)生命名沖突.
3.超前計(jì)劃! 嘗試著去預(yù)測(cè) Op 未來(lái)的的用途, 超前設(shè)計(jì), 畢竟, 一些簽名的變更無(wú)法保證兼容性 (例如, 增加新的輸入, 或?qū)⒃瓉?lái)的單元素輸入變成一個(gè)列表).
如果不能以兼容的方式改變一個(gè)操作, 那就創(chuàng)建一個(gè)全新的操作, 來(lái)實(shí)現(xiàn)所需功能.
你可以實(shí)現(xiàn)不同的 OpKernel, 將其中之一注冊(cè)到 GPU, 另一個(gè)注冊(cè)到 GPU, 正如為不同的類(lèi)型注冊(cè) kernel 一樣.
tensorflow/core/kernels/ 中有一些 GPU 支持的例子.
注意, 一些 kernel 的 CPU 版本位于 .cc 文件, GPU 版本位于 _gpu.cu.cc 文件, 共享的代碼位于 .h 文件.
例如, pad op 除了 GPU kernel 外的其它代碼
均在 tensorflow/core/kernels/pad_op.cc 中. GPU kernel 位于 tensorflow/core/kernels/pad_op_gpu.cu.cc,
共享的一個(gè)模板類(lèi)代碼定義在 tensorflow/core/kernels/pad_op.h.
需要注意的事情是, 即使使用 pad 的 GPU 版本時(shí), 仍然需要將 "paddings" 輸入放置到內(nèi)存中.
為了實(shí)現(xiàn)這一點(diǎn), 將輸入或輸出標(biāo)記為必須保存在內(nèi)存中, 為 kernel 注冊(cè)一個(gè) HostMemory() 調(diào)用.
如下:
#define REGISTER_GPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("Pad") \
.Device(DEVICE_GPU) \
.TypeConstraint<T>("T") \
.HostMemory("paddings"), \
PadOp<GPUDevice, T>)
給定一個(gè) Op 組成的圖, TensorFlow 使用自動(dòng)微分 (反向傳播) 來(lái)添加新的 Op 以表示梯度運(yùn)算, 同時(shí) 不影響已有的 Op (參見(jiàn)梯度運(yùn)算). 為了使自動(dòng)微分能夠與新的 Op 協(xié)同工作, 必須注冊(cè)一個(gè)梯度函數(shù), 從 Op 的輸入計(jì)算梯度, 并返回代表 梯度值的輸出.
數(shù)學(xué)上, 如果一個(gè) Op 計(jì)算 \(y = f(x)\), 注冊(cè)的梯度 Op 通過(guò)以下鏈?zhǔn)椒▌t, 將 \(\partial / \partial y\) 的梯度運(yùn)算轉(zhuǎn)化為 \(\partial / \partial x\) 的梯度運(yùn)算.
$$\frac{\partial}{\partial x} = \frac{\partial}{\partial y} \frac{\partial y}{\partial x} = \frac{\partial}{\partial y} \frac{\partial f}{\partial x}.$$
在 ZeroOut 的例子中, 輸入中只有一個(gè)項(xiàng)會(huì)影響輸出, 所以, 代表輸入的梯度值的 tensor 也只有
一個(gè)輸入項(xiàng). 如下所示:
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops
@ops.RegisterGradient("ZeroOut")
def _zero_out_grad(op, grad):
"""`zero_out` 的梯度.
參數(shù):
op: 欲進(jìn)行微分的 `zero_out` `操作`, 可以用于獲取原始 Op 的輸入和輸出.
grad: 代表 `zero_out` 輸出的梯度 Op.
返回:
代表輸入 `zero_out` 的微分.
"""
to_zero = op.inputs[0]
shape = array_ops.shape(to_zero)
index = array_ops.zeros_like(shape)
first_grad = array_ops.reshape(grad, [-1])[0]
to_zero_grad = sparse_ops.sparse_to_dense(index, shape, first_grad, 0)
return [to_zero_grad] # 單個(gè) Tensor 的列表, 既然只有一個(gè)輸入
使用 ops.RegisterGradient
注冊(cè)梯度函數(shù)需要注意的一些細(xì)節(jié):
對(duì)于僅有一個(gè)輸出的 Op, 梯度函數(shù)使用 Operation op
和一個(gè) Tensor grad 作為參數(shù), 并從
op.inputs[i],
op.outputs[i],
和 grad 構(gòu)建新的 Op. 屬性的信息可以通過(guò) op.get_attr 獲取.
如果 Op 有多個(gè)輸出, 梯度函數(shù)將使用 op 和 grads 作為參數(shù), 其中, grads 是一個(gè)
梯度 Op 的列表, 為每一個(gè)輸出計(jì)算梯度. 梯度函數(shù)的輸出必須是一個(gè) Tensor 對(duì)象列表, 對(duì)應(yīng)到
每一個(gè)輸入的梯度.
如果沒(méi)有為一些輸入定義梯度, 譬如用作索引的整型, 這些輸入返回的梯度為 None