Tensorflow 的 random.truncated_normal 使用相同的种子返回不同的结果

问题描述 投票:0回答:2

以下几行应该得到相同的结果:

import tensorflow as tf
print(tf.random.truncated_normal(shape=[2],seed=1234))
print(tf.random.truncated_normal(shape=[2],seed=1234))

但是我得到了:

tf.Tensor([-0.12297685 -0.76935077], shape=(2,), dtype=tf.float32)
tf.Tensor([0.37034193 1.3367208 ], shape=(2,), dtype=tf.float32)

为什么?

python tensorflow machine-learning tensor random-seed
2个回答
2
投票

这似乎是故意的,请参阅文档这里。特别是“示例”部分。

您需要的是

stateless_truncated_normal

print(tf.random.stateless_truncated_normal(shape=[2],seed=[1234, 1]))
print(tf.random.stateless_truncated_normal(shape=[2],seed=[1234, 1]))

给我

tf.Tensor([1.0721238  0.10303579], shape=(2,), dtype=float32)
tf.Tensor([1.0721238  0.10303579], shape=(2,), dtype=float32)

注意:这里的种子需要是两个数字,老实说我不知道为什么(文档没有说)。


2
投票

Tensorflow 有两种类型的种子:全局种子和操作种子 - 这也是为什么您需要将两个数字传递给

stateless_truncated_normal
,如 xdurch0 在他的答案中所述。 Tensorflow 将这两个种子组合起来生成一个新的种子。

tf.random.truncated_normal(shape=[2],seed=1234) # global seed #1 & operational 1234 -> Seed A
tf.random.truncated_normal(shape=[2],seed=1234) # global seed #2 & operational 1234 -> Seed B

有多种方法可以解决您的问题。也预先设置两次全局种子。 在

@tf.functions
内部工作,这些重置全局种子并拥有自己的操作计数器。 或者使用其他答案中所写的
stateless_truncated_normal

正如已经链接的那样,在文档中也有描述。

最新问题
© www.soinside.com 2019 - 2025. All rights reserved.