我已经在这里问过这个问题,但我认为 StackOverflow 会有更多的流量/可能知道答案的人。
我正在构建一个自定义 keras 层,类似于此处找到的示例。 我希望类内的 call
方法能够知道流经该方法的
batch_size
数据的
inputs
是什么,但是
inputs.shape
在模型预测期间显示为
(None, 3)
。 这是一个具体的例子:我初始化一个简单的数据集,如下所示:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, Model
# Create fake data to use for model testing
n = 1000
np.random.seed(123)
x1 = np.random.random(n)
x2 = np.random.normal(0, 1, size=n)
x3 = np.random.lognormal(0, 1, size=n)
X = pd.DataFrame(np.concatenate([
np.reshape(x1, (-1, 1)),
np.reshape(x2, (-1, 1)),
np.reshape(x3, (-1, 1)),
], axis=1))
然后我定义一个自定义类来测试/显示我在说什么:
class TestClass(tf.keras.layers.Layer):
def __init__(self, **kwargs):
super(TestClass, self).__init__(**kwargs)
def get_config(self):
config = super(TestClass, self).get_config()
return config
def call(self, inputs: tf.Tensor):
if inputs.dtype.base_dtype != self._compute_dtype_object.base_dtype:
inputs = tf.cast(inputs, dtype=self._compute_dtype_object)
print(inputs)
record_count, n = inputs.shape
print(f'inputs.shape = {inputs.shape}')
return inputs
然后,当我创建一个简单的模型并强制它进行前向传递时......
input_layer = layers.Input(3)
test = TestClass()(input_layer)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.00025)
model = Model(input_layer, test)
model.compile(loss='mse', optimizer=optimizer, metrics=['mae', 'mse'])
model.predict(X.loc[:9, :])
...我将此输出打印到屏幕上
model.predict(X.loc[:9, :])
Tensor("model_1/Cast:0", shape=(None, 3), dtype=float32)
inputs.shape = (None, 3)
1/1 [==============================] - 0s 28ms/step
Out[34]:
array([[ 0.5335418 , 0.7788839 , 0.64132416],
[ 0.2924202 , -0.08321562, 0.412311 ],
[ 0.5118007 , -0.6822934 , 1.1782378 ],
[ 0.03780456, -0.19350041, 0.7637337 ],
[ 0.86494124, -3.196387 , 4.8535166 ],
[ 0.26708454, -0.49397194, 0.91296834],
[ 0.49734482, -1.6618049 , 0.50054324],
[ 0.8563762 , 0.7956695 , 0.29466265],
[ 0.7682351 , 0.86538637, 0.6633331 ],
[ 0.85322225, 0.868021 , 0.1776046 ]], dtype=float32)
您可以看到,在 model.predict
调用期间,
inputs.shape
打印出
(None, 3)
的值,但显然这不是真的,因为
call
方法返回形状为
(10, 3)
的输出。 如何在
10
方法中捕获此示例中的
call
值?更新1
tf.shape
时,我可以将值打印到屏幕上,但是当我尝试在变量中捕获该值时出现错误。
class TestClass(tf.keras.layers.Layer):
def __init__(self, **kwargs):
super(TestClass, self).__init__(**kwargs)
def get_config(self):
config = super(TestClass, self).get_config()
return config
def call(self, inputs: tf.Tensor):
if inputs.dtype.base_dtype != self._compute_dtype_object.base_dtype:
inputs = tf.cast(inputs, dtype=self._compute_dtype_object)
record_count, n = tf.shape(inputs)
tf.print("Dynamic batch size", tf.shape(inputs)[0])
return inputs
此代码会导致 record_count, ...
行出现错误。
Traceback (most recent call last):
File "/Users/username/opt/miniconda3/envs/myenv/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3378, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-22-104d812c32e6>", line 1, in <module>
test = TestClass()(input_layer)
File "/Users/username/opt/miniconda3/envs/myenv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/Users/username/opt/miniconda3/envs/myenv/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 692, in wrapper
raise e.ag_error_metadata.to_exception(e)
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: Exception encountered when calling layer "test_class_4" (type TestClass).
in user code:
File "<ipython-input-21-2dec1d5b9547>", line 12, in call *
record_count, n = tf.shape(inputs)
OperatorNotAllowedInGraphError: Iterating over a symbolic `tf.Tensor` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.
Call arguments received by layer "test_class_4" (type TestClass):
• inputs=tf.Tensor(shape=(None, 3), dtype=float32)
我尝试用 call
装饰
@tf.function
方法,但我得到了同样的错误。更新2
class TestClass(tf.keras.layers.Layer):
def __init__(self, **kwargs):
super(TestClass, self).__init__(**kwargs)
def get_config(self):
config = super(TestClass, self).get_config()
return config
def call(self, inputs: tf.Tensor):
if inputs.dtype.base_dtype != self._compute_dtype_object.base_dtype:
inputs = tf.cast(inputs, dtype=self._compute_dtype_object)
shape = tf.shape(inputs)
record_count = shape[0]
n = shape[1]
tf.print("Dynamic batch size", tf.shape(inputs)[0])
return inputs
TL;DR --> 如果您想在 tf.shape(inputs)[0]
方法中捕获动态批量大小,请使用
call
,或者您可以只使用可以在模型创建中指定的静态批量大小。
call
装饰
__call__
和
call
(这就是
tf.function
方法调用)方法。使用
print
和
.shape
将无法按预期工作。使用
tf.function
,Python 代码将被跟踪并转换为原生 TensorFlow 操作。之后,创建一个静态图,这只是tf.Graph的一个实例。最后,操作在该图中执行。 Python 的
print
函数仅在第一步中考虑,因此这不是以图形模式打印内容的正确方法(用
tf.function
装饰)。张量形状在运行时是动态的,因此您需要使用
tf.shape(inputs)[0]
这将为您提供该批次的批次大小。如果你真的想在
10
中看到
call
:
class TestClass(tf.keras.layers.Layer):
def __init__(self, **kwargs):
super(TestClass, self).__init__(**kwargs)
def get_config(self):
config = super(TestClass, self).get_config()
return config
def call(self, inputs: tf.Tensor):
if inputs.dtype.base_dtype != self._compute_dtype_object.base_dtype:
inputs = tf.cast(inputs, dtype=self._compute_dtype_object)
tf.print("Dynamic batch size", tf.shape(inputs)[0])
return inputs
跑步:
input_layer = layers.Input(3)
test = TestClass()(input_layer)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.00025)
model = Model(input_layer, test)
model.compile(loss='mse', optimizer=optimizer, metrics=['mae', 'mse'])
model.predict(X.loc[:9, :])
会回来:
Dynamic batch size 10
1/1 [==============================] - 0s 65ms/step
array([[ 6.9646919e-01, -1.0032653e-02, 3.7556963e+00],
[ 2.8613934e-01, -8.4564441e-01, 9.9685013e-01],
[ 2.2685145e-01, 9.1146064e-01, 6.5008003e-01],
[ 5.5131477e-01, -1.3744969e+00, 8.6379850e-01],
[ 7.1946895e-01, -5.4706562e-01, 3.1904945e+00],
[ 4.2310646e-01, -7.5526608e-05, 5.2649558e-01],
[ 9.8076421e-01, -1.2116680e-01, 7.4064606e-01],
[ 6.8482971e-01, -2.0085855e+00, 5.3138912e-01],
[ 4.8093191e-01, -9.2064655e-01, 8.1520426e-01],
[ 3.9211753e-01, 1.6823435e-01, 1.2382457e+00]], dtype=float32)
然而,这已经贬值并且不再有效。我能找到的唯一选择是 Tf.shape(密集层输出) 但出现错误 kerastensor 不能用作张量流函数的输入 所以我首先尝试 covert_to_tensor ,但得到了同样的错误。 我已阅读上面的内容,但看不出它如何修复我的应用程序。 tf.shape 可以做到这一点吗?如果没有,除了折旧的 kerasbackendshape 是否没有其他选择?我已经尝试了几天来找到解决方案