将keras优化器作为字符串参数传递给keras优化器函数

问题描述 投票:0回答:1

我正在使用包含超参数的keras文件来调整config.json深度学习模型的超参数。

    { “opt: “Adam”,
      “lr”: 0.01,
       “grad_clip”: 0.5
    }

Keras允许通过两种方式指定优化器:

  1. 作为字符串参数,在调用函数时没有附加参数。
model.compile(loss='categorical_crossentropy’,
              optimizer=’Adam’, 
              metrics=['mse'])
  1. 具有附加参数的同名功能。
model.compile(loss='categorical_crossentropy',
              optimizer=optimizers.Adam(lr=0.01, clipvalue=0.5), 
              metrics=['mse'])

我的问题是:如何将优化器(SGD,Adam等)作为配置文件中的参数以及子参数传递,并像[2]中那样使用keras.optimizers.optimizer()函数调用?

from keras.models import Sequential
from keras.layers import LSTM, Dense, TimeDistributed, Bidirectional
from keras import optimizers

def train(X,y, opt, lr, clip):

   model = Sequential()
   model.add(Bidirectional(LSTM(100, return_sequences=True), input_shape=(500, 300)))    
   model.add(TimeDistributed(Dense(5, activation='sigmoid')))

   model.compile(loss='categorical_crossentropy',
                  optimizer=optimizers.opt(lr=lr, clipvalue=clip), 
                  metrics=['mse'])

   model.fit(X, y, epochs=100, batch_size=1, verbose=2)

   return(model)

当我尝试将参数从配置文件传递到上述train()函数时,出现以下错误:

AttributeError: module 'keras.optimizers' has no attribute 'opt'

如何从函数中解析字符串中的优化器?

python optimization keras deep-learning parameter-passing
1个回答
1
投票

您可以使用像这样构造优化器的类:

class Optimizer(object):
    def get_opt(self, opt, lr, clip):
        """Dispatch method"""
        method_name = 'opt_' + str(opt)
        # Get the method from 'self'. Default to a lambda.
        method = getattr(self, method_name, lambda: "Invalid optimizier")
        # Call the method as we return it
        return method()

    def opt_Adam(self):
        return optimizer.Adam(lr=lr,clipvalue=clip)

    def opt_example(self):
        return  optimizer.example(lr=lr,clipvalue=clip)

    #and so on for how many cases you would need

然后您可以将其称为:

a=Optimizer()
model.compile(loss='categorical_crossentropy',
              optimizer=a.get_opt(opt=opt, lr=lr, clip=clip), 
              metrics=['mse'])
© www.soinside.com 2019 - 2024. All rights reserved.