假设我正在构建一个维度部分已知的模型:

# Multilayer Perceptron
from keras.layers import Input
from keras.layers import Dense
import tensorflow as tf

inputs = Input(shape=(10,))

hidden1 = Dense(10, activation='relu')(inputs)
hidden2 = tf.math.add(hidden1, 5)
hidden3 = tf.math.add(hidden2, 5)
hidden4 = my_custom_op(hidden2, hidden3)
output = Dense(10, activation='linear')(hidden4)

my_custom_op是一个大而复杂的张量函数,它在不同的地方使用断言来确保满足关于形状和秩的假设.为了再现这个问题,我将只做如下处理:

def my_custom_op(hidden_x, hidden_y):
    tf.assert_equal(tf.shape(hidden_x), tf.shape(hidden_y))    
    return hidden

当我运行此命令时,会出现以下错误:

TypeError:无法为名称生成TypeSpec:

我不明白这个错误消息告诉我什么.如果我运行tf.assert_equal(2, 2),我没有得到异常,所以我假设这与维度还未知有关.

但是,当维度已知时,它们运行的不是断言的要点吗?如果不是,这是否意味着我不能在my_custom_op中使用断言,因为它们在构建图时会导致这些错误?

以下是完整的错误消息:

TypeError: Could not build a TypeSpec for name: "tf.debugging.assert_equal_1/assert_equal_1/Assert/Assert"
op: "Assert"
input: "tf.debugging.assert_equal_1/assert_equal_1/All"
input: "tf.debugging.assert_equal_1/assert_equal_1/Assert/Assert/data_0"
input: "tf.debugging.assert_equal_1/assert_equal_1/Assert/Assert/data_1"
input: "Shape"
input: "tf.debugging.assert_equal_1/assert_equal_1/Assert/Assert/data_3"
input: "Shape_1"
attr {
  key: "T"
  value {
    list {
      type: DT_STRING
      type: DT_STRING
      type: DT_INT32
      type: DT_STRING
      type: DT_INT32
    }
  }
}
attr {
  key: "summarize"
  value {
    i: 3
  }
}
 of unsupported type <class 'tensorflow.python.framework.ops.Operation'>.

推荐答案

问题是,您无法将Keras个符号张量提供给某些Tensorflow个API.只需将函数my_custom_op包装在Lambda或自定义层中,它就可以工作了:

import tensorflow as tf

def my_custom_op(x):
    hidden_x, hidden_y = x
    tf.assert_equal(tf.shape(hidden_x), tf.shape(hidden_y))    
    return hidden_x

inputs = tf.keras.layers.Input(shape=(10,))

hidden1 = tf.keras.layers.Dense(10, activation='relu')(inputs)
hidden2 = tf.math.add(hidden1, 5)
hidden3 = tf.math.add(hidden2, 5)
hidden4 = tf.keras.layers.Lambda(my_custom_op)([hidden2, hidden3])
output = tf.keras.layers.Dense(10, activation='linear')(hidden4)

Python相关问答推荐

从管道将Python应用程序部署到Azure Web应用程序,不包括需求包

Pandas 填充条件是另一列

根据另一列中的nan重置值后重新加权Pandas列

如何从具有不同len的列表字典中创建摘要表?

如何找到满足各组口罩条件的第一行?

按顺序合并2个词典列表

如何在虚拟Python环境中运行Python程序?

修复mypy错误-赋值中的类型不兼容(表达式具有类型xxx,变量具有类型yyy)

我如何使法国在 map 中完全透明的代码?

Pandas—合并数据帧,在公共列上保留非空值,在另一列上保留平均值

计算天数

为什么'if x is None:pass'比'x is None'单独使用更快?

交替字符串位置的正则表达式

如何在海上配对图中使某些标记周围的黑色边框

如何在Python中自动创建数字文件夹和正在进行的文件夹?

对数据帧进行分组,并按组间等概率抽样n行

Polars定制函数返回多列

如何在python tkinter中绑定键盘上的另一个回车?

大型稀疏CSR二进制矩阵乘法结果中的错误

将数据从一个单元格保存到Jupyter笔记本中的下一个单元格