我使用 Jax 后端在 Keras 中完全子类化了 llama2 模型。 我第一次运行代码时,它训练得很好(在莎士比亚数据集上),并且预测也很好。
但是下次我运行它时,模型不会收敛,并且损失总是在第 6 纪元上升并振荡。与上次正确训练相比,它的发生并没有发生任何变化。即使我会在数据集中设置一个用于洗牌的种子
以下是训练参数:
@dataclass
class TArgs:
checkpoint:str = "weights/llama_shakespeare/Epoch{epoch}.weights.h5"
steps_per_epoch:int = 100
batch_size:int = 16
num_steps:int = 3000
epochs:int = num_steps//steps_per_epoch
# cosine decay with warmup
init_lr:float = 8e-4
max_lr:float = 1e-3
min_lr:float = 5e-4
alpha:float = min_lr/init_lr
warmup_steps:int = 100
decay_steps:int = 85*num_steps//100
# adamw
beta1:float = 0.9
beta2:float = 0.95
clipvalue:float = 1e0
weight_decay:float = 1e-1
我陷入困境,请帮助我并告诉我是否需要任何其他信息。谢谢!
您的网络从不同的初始权重开始。
最好消除所有随机性并将种子设置为 numpy、keras 等任何相关的地方。
就您而言,
keras.utils.set_random_seed(812)
可能会有所帮助。来源:
https://keras.io/examples/keras_recipes/reproducibility_recipes/
为了确保在导入后立即添加此内容。然后运行两次训练,每次在 1 个时期后停止 - 如果它给你相同的模型,那就好了。现在您需要找到导致训练收敛的种子。
旁注:您的批量较小。也许如果你增加它,所有的训练看起来都会更加相似。一切都取决于数据,但这对我来说是一个“怀疑”。