在 TensorFlow 中使用 Mac M1/M2 时,有没有办法将 Adam 更改为旧版?

问题描述 投票:0回答:1

我正在运行深度学习模型,但有时我使用工作 PC,有时使用 Mac。问题是,每当使用 Mac(带有 M2 芯片)时,我都会收到此警告消息:

警告:absl:目前,v2.11+ 优化器

tf.keras.optimizers.Adam
在 M1/M2 Mac 上运行缓慢,请改用旧版 Keras 优化器,位于
tf.keras.optimizers.legacy.Adam

我想知道是否有办法在我的 Mac 中切换到

tf.keras.optimizers.legacy.Adam
。作为一个附带问题,它有什么好处吗?我想是这样,因为考虑到问题的简单性,我的训练比我预期的要多。

作为最小可行示例,请查看我从 TensorFlow 网站改编的示例

import tensorflow as tf

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10)
])

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# If-else should go here:

# If in a Mac with M1 / M2:
# opt = tf.keras.optimizers.legacy.Adam(learning_rate=0.0005)

# Else:
opt = tf.keras.optimizers.Adam(learning_rate=0.0005)

model.compile(optimizer=opt,
              loss=loss_fn,
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=5)

提前致谢,如果已经回答了这个问题,我们深表歉意。我找不到类似的问题。

python macos tensorflow deep-learning tensorflow2.0
1个回答
0
投票

我相信我通过查看 keras 优化器代码找到了解决第一个问题的方法。

解决方案的示例如下:

import tensorflow as tf
import platform

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10)
])

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

if platform.system() == "Darwin" and platform.processor() == "arm":
    opt = tf.keras.optimizers.legacy.Adam(learning_rate=0.0005)
else:
    opt = tf.keras.optimizers.Adam(learning_rate=0.0005)

model.compile(optimizer=opt,
              loss=loss_fn,
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=5)

我解决的方法:

  1. 我在代码中找到了抛出错误的地方(代码中的第 69 行here,截至 2023 年 10 月 3 日):

if platform.system() == "Darwin" and platform.processor() == "arm":

  1. 为此,需要

    platform
    模块。我在上面的 MVE 的第二行中导入了它。

  2. 然后,我添加了优化器来匹配所需的条件。

我希望这有帮助。

© www.soinside.com 2019 - 2024. All rights reserved.