以下几行应该得到相同的结果:
import tensorflow as tf
print(tf.random.truncated_normal(shape=[2],seed=1234))
print(tf.random.truncated_normal(shape=[2],seed=1234))
但是我得到了:
tf.Tensor([-0.12297685 -0.76935077], shape=(2,), dtype=tf.float32)
tf.Tensor([0.37034193 1.3367208 ], shape=(2,), dtype=tf.float32)
为什么?
这似乎是故意的,请参阅文档这里。特别是“示例”部分。
您需要的是
stateless_truncated_normal
:
print(tf.random.stateless_truncated_normal(shape=[2],seed=[1234, 1]))
print(tf.random.stateless_truncated_normal(shape=[2],seed=[1234, 1]))
给我
tf.Tensor([1.0721238 0.10303579], shape=(2,), dtype=float32)
tf.Tensor([1.0721238 0.10303579], shape=(2,), dtype=float32)
注意:这里的种子需要是两个数字,老实说我不知道为什么(文档没有说)。
Tensorflow 有两种类型的种子:全局种子和操作种子 - 这也是为什么您需要将两个数字传递给
stateless_truncated_normal
,如 xdurch0 在他的答案中所述。 Tensorflow 将这两个种子组合起来生成一个新的种子。
tf.random.truncated_normal(shape=[2],seed=1234) # global seed #1 & operational 1234 -> Seed A
tf.random.truncated_normal(shape=[2],seed=1234) # global seed #2 & operational 1234 -> Seed B
有多种方法可以解决您的问题。也预先设置两次全局种子。 在
@tf.functions
内部工作,这些重置全局种子并拥有自己的操作计数器。
或者使用其他答案中所写的 stateless_truncated_normal
。
正如已经链接的那样,在文档中也有描述。