JAX 的精度

问题描述 投票:0回答:1

我有一个关于 JAX 中浮点精度的问题。对于以下代码,

import numpy as np
import jax.numpy as jnp

print('jnp.arctan(10) is:','%.60f' % jnp.arctan(10))
print('np.arctan(10) is:','%.60f' % np.arctan(10))

jnp.arctan(10) is: 1.471127629280090332031250000000000000000000000000000000000000
np.arctan(10) is: 1.471127674303734700345103192375972867012023925781250000000000


print('jnp.arctan(10+1e-7) is:','%.60f' % jnp.arctan(10+1e-7))
print('np.arctan(10+1e-7) is:','%.60f' % np.arctan(10+1e-7))

jnp.arctan(10+1e-7) is: 1.471127629280090332031250000000000000000000000000000000000000
np.arctan(10+1e-7) is: 1.471127675293833592107262120407540351152420043945312500000000
对于输入变量 (1e-7) 的微小变化,

jnp 对于 arctan(x) 给出了相同的结果,但 np 没有。我的问题是如何让 jax.numpy 对 x 的微小变化获得正确的数字?

欢迎任何评论。

python jax
1个回答
1
投票

JAX 默认使用 float32 计算,其相对精度约为

1E-7
。这意味着您的两个输入实际上是相同的:

>>> np.float32(10) == np.float32(10 + 1E-7)
True

如果您想要像 NumPy 一样的 64 位精度,您可以按照JAX 锐位:双精度中的讨论启用它,然后结果将匹配 64 位精度:

import jax
jax.config.update('jax_enable_x64', True)

import jax.numpy as jnp
import numpy as np

print('jnp.arctan(10) is:','%.60f' % jnp.arctan(10))
print('np.arctan(10) is: ','%.60f' % np.arctan(10))

print('jnp.arctan(10+1e-7) is:','%.60f' % jnp.arctan(10+1e-7))
print('np.arctan(10+1e-7) is: ','%.60f' % np.arctan(10+1e-7))
jnp.arctan(10) is: 1.471127674303734700345103192375972867012023925781250000000000
np.arctan(10) is:  1.471127674303734700345103192375972867012023925781250000000000
jnp.arctan(10+1e-7) is: 1.471127675293833592107262120407540351152420043945312500000000
np.arctan(10+1e-7) is:  1.471127675293833592107262120407540351152420043945312500000000

(但请注意,即使 Python 和 NumPy 使用的 64 位精度也只能精确到 10^16 分之一,因此与真正的反正切值相比,您打印的表示形式中的大多数数字都不准确)。

© www.soinside.com 2019 - 2024. All rights reserved.