我有一个关于 JAX 中浮点精度的问题。对于以下代码,
import numpy as np
import jax.numpy as jnp
print('jnp.arctan(10) is:','%.60f' % jnp.arctan(10))
print('np.arctan(10) is:','%.60f' % np.arctan(10))
jnp.arctan(10) is: 1.471127629280090332031250000000000000000000000000000000000000
np.arctan(10) is: 1.471127674303734700345103192375972867012023925781250000000000
print('jnp.arctan(10+1e-7) is:','%.60f' % jnp.arctan(10+1e-7))
print('np.arctan(10+1e-7) is:','%.60f' % np.arctan(10+1e-7))
jnp.arctan(10+1e-7) is: 1.471127629280090332031250000000000000000000000000000000000000
np.arctan(10+1e-7) is: 1.471127675293833592107262120407540351152420043945312500000000
对于输入变量 (1e-7) 的微小变化,jnp 对于 arctan(x) 给出了相同的结果,但 np 没有。我的问题是如何让 jax.numpy 对 x 的微小变化获得正确的数字?
欢迎任何评论。
JAX 默认使用 float32 计算,其相对精度约为
1E-7
。这意味着您的两个输入实际上是相同的:
>>> np.float32(10) == np.float32(10 + 1E-7)
True
如果您想要像 NumPy 一样的 64 位精度,您可以按照JAX 锐位:双精度中的讨论启用它,然后结果将匹配 64 位精度:
import jax
jax.config.update('jax_enable_x64', True)
import jax.numpy as jnp
import numpy as np
print('jnp.arctan(10) is:','%.60f' % jnp.arctan(10))
print('np.arctan(10) is: ','%.60f' % np.arctan(10))
print('jnp.arctan(10+1e-7) is:','%.60f' % jnp.arctan(10+1e-7))
print('np.arctan(10+1e-7) is: ','%.60f' % np.arctan(10+1e-7))
jnp.arctan(10) is: 1.471127674303734700345103192375972867012023925781250000000000
np.arctan(10) is: 1.471127674303734700345103192375972867012023925781250000000000
jnp.arctan(10+1e-7) is: 1.471127675293833592107262120407540351152420043945312500000000
np.arctan(10+1e-7) is: 1.471127675293833592107262120407540351152420043945312500000000
(但请注意,即使 Python 和 NumPy 使用的 64 位精度也只能精确到 10^16 分之一,因此与真正的反正切值相比,您打印的表示形式中的大多数数字都不准确)。