如何在 TF 2.6.0 / Python 3.9.7 中保存并重新加载子类模型而不导致性能下降?

问题描述 投票:0回答:1

看起来像是百万美元的问题。我通过 Keras 中的子类

Model
构建了下面的模型。

模型训练得很好并且具有良好的性能,但我找不到一种方法来保存和恢复模型而不导致显着的性能损失。 我在 ROC 曲线上跟踪 AUC 以进行异常检测,加载模型后的 ROC 曲线比之前更差,使用完全相同的验证数据集。

我怀疑问题来自 BatchNormalization,但我可能是错的。

我尝试了几种选择:

这可行,但会导致性能下降。

model.save() / tf.keras.models.load()

这有效,但也会导致性能下降:

model.save_weights() / model.load_weights()

这不起作用,我收到以下错误:

tf.saved_model.save() / tf.saved_model.load()

AttributeError: '_UserObject' object has no attribute 'predict'

这也不起作用,因为子类模型不支持 json 导出:

model.to_json()

这是模型

class Deep_Seq2Seq_Detector(Model):
  def __init__(self, flight_len, param_len, hidden_state=16):
    super(Deep_Seq2Seq_Detector, self).__init__()
    self.input_dim = (None, flight_len, param_len)
    self._name_ = "LSTM"
    self.units = hidden_state
    
    self.regularizer0 = tf.keras.Sequential([
        layers.BatchNormalization()
        ])
    
    self.encoder1 = layers.LSTM(self.units,
                  return_state=False,
                  return_sequences=True,
                  #activation="tanh",
                  name='encoder1',
                  input_shape=self.input_dim)#,
                  #kernel_regularizer= tf.keras.regularizers.l1(),
                  #)
    
    self.regularizer1 = tf.keras.Sequential([
        layers.BatchNormalization(),
        layers.Activation("tanh")
        ])
    
    self.encoder2 = layers.LSTM(self.units,
                  return_state=False,
                  return_sequences=True,
                  #activation="tanh",
                  name='encoder2')#,
                  #kernel_regularizer= tf.keras.regularizers.l1()
                  #) #                    input_shape=(None, self.input_dim[1],self.units),
    
    self.regularizer2 = tf.keras.Sequential([
        layers.BatchNormalization(),
        layers.Activation("tanh")
        ])
    
    self.encoder3 = layers.LSTM(self.units,
                  return_state=True,
                  return_sequences=False,
                  activation="tanh",
                  name='encoder3')#,
                  #kernel_regularizer= tf.keras.regularizers.l1(),
                  #) #                   input_shape=(None, self.input_dim[1],self.units),
    
    self.repeat = layers.RepeatVector(self.input_dim[1])
    
    self.decoder = layers.LSTM(self.units,
                  return_sequences=True,
                  activation="tanh",
                  name="decoder",
                  input_shape=(self.input_dim[1],self.units))
    
    self.dense = layers.TimeDistributed(layers.Dense(self.input_dim[2]))

  @tf.function 
  def call(self, x):
    
    # Encoder
    x0 = self.regularizer0(x)
    x1 = self.encoder1(x0)
    x11 = self.regularizer1(x1)
    
    x2 = self.encoder2(x11)
    x22 = self.regularizer2(x2)
    
    output, hs, cs = self.encoder3(x22)
    
    # see https://www.tensorflow.org/guide/keras/rnn 
    encoded_state = [hs, cs] 
    repeated_vec = self.repeat(output)
    
    # Decoder
    decoded = self.decoder(repeated_vec, initial_state=encoded_state)
    output_decoder = self.dense(decoded)

    return output_decoder

我见过 Git 线程,但没有直接答案: https://github.com/keras-team/keras/issues/4875

有人找到解决办法了吗?我必须使用函数式 API 或顺序式 API 吗?

tensorflow keras tensorflow2.0 tf.keras
1个回答
0
投票

问题似乎出在

Subclassing API

我使用

Functional API
重建了完全相同的模型,现在
model.save
/
model.load
产生了相似的结果。

© www.soinside.com 2019 - 2024. All rights reserved.