我计划使用 Google TPU 进行科学数值模拟(有限元分析)。
也就是说,TPU 支持
float32
有一个通知这里,但它说XLA会自动转换它。
是的。这不是默认设置,但它可用。
您可以在特定运算上设置
precision=jax.lax.Precision.HIGHEST
它“使用更多的 MXU 通道来实现完整的 float32 精度”。