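I am trying to create a model subclass with a variable number of layers and variable hidden-layer sizes.
Because the number and size of the hidden layers are not fixed, I append the instantiated Keras layers to a list based on the constructor arguments. However, I don't understand why the model ignores the layers' weights when the Keras layers are kept in the list self.W.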
import keras
from keras.layers import Dense

class MLP(keras.Model):
    def __init__(self, first_size, num_hidden_layers, hidden_activation, num_classes, **kwargs):
        super(MLP, self).__init__()
        # Hidden layers: each one is half as wide as the previous one
        self.W = [Dense(units=first_size//(2**i), activation=hidden_activation) for i in range(num_hidden_layers)]
        # Regression task
        if num_classes == 0:
            self.W.append(Dense(units=1, activation='linear'))
        # Classification task
        else:
            self.W.append(Dense(units=num_classes, activation='softmax'))

    def call(self, inputs):
        x = inputs
        for w in self.W:
            x = w(x)
        return x

model = MLP(first_size=128, num_hidden_layers=4, hidden_activation='relu', num_classes=10)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['acc'])
model.fit(x_train, y_train, batch_size=64, epochs=20, validation_data=(x_val, y_val))
model.summary()
Model: "mlp_23"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________
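One possible explanation, offered as an assumption rather than a confirmed diagnosis: older standalone Keras releases appear to register only layers that are assigned directly as attributes of the model, so layers held in a plain Python list are not tracked (tf.keras, as far as I know, does track lists). A minimal sketch of a workaround is to assign each layer as its own attribute. The attribute names `dense_0`, `dense_1`, ... and the `depth` counter are hypothetical, and the demo call at the bottom assumes a recent Keras release that can be called directly on a NumPy batch.

```python
import numpy as np
import keras
from keras.layers import Dense


class MLP(keras.Model):
    def __init__(self, first_size, num_hidden_layers, hidden_activation,
                 num_classes, **kwargs):
        super().__init__(**kwargs)
        # Hypothetical bookkeeping attribute: total number of Dense layers.
        self.depth = num_hidden_layers + 1
        for i in range(num_hidden_layers):
            # Assigning the layer directly as an attribute lets Keras track it.
            setattr(self, f"dense_{i}",
                    Dense(units=first_size // (2 ** i),
                          activation=hidden_activation))
        if num_classes == 0:
            # Regression task
            head = Dense(units=1, activation="linear")
        else:
            # Classification task
            head = Dense(units=num_classes, activation="softmax")
        setattr(self, f"dense_{num_hidden_layers}", head)

    def call(self, inputs):
        x = inputs
        for i in range(self.depth):
            x = getattr(self, f"dense_{i}")(x)
        return x


model = MLP(first_size=128, num_hidden_layers=4,
            hidden_activation="relu", num_classes=10)
# Call the model once on a dummy batch with 4 features so the weights are created.
_ = model(np.zeros((1, 4), dtype="float32"))
total_params = model.count_params()
print(total_params)
```

If the layers are tracked, `model.summary()` should then list five Dense layers instead of reporting zero parameters.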
I think you can do this quite easily with the functional API. As an example, I used the iris dataset from sklearn.
from keras.models import Model
from keras.layers import Input, Dense
import sklearn.datasets
iris_dataset = sklearn.datasets.load_iris()
x_train = iris_dataset["data"]
y_train = iris_dataset["target"]
inputs = Input(shape=x_train[0].shape)
x = inputs
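num_hidden_layers = 4
# Iris only has 3 classes; 10 is used here to match the settings in the question
num_classes = 10
hidden_activation = 'relu'
first_size = 128
# Stack the hidden layers, halving the width each time
for i in range(num_hidden_layers):
    x = Dense(units=first_size//(2**i), activation=hidden_activation)(x)
outputs = Dense(units=num_classes, activation='softmax')(x)
model = Model(inputs=inputs, outputs=outputs)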
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['acc'])
model.summary()
Output:
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         (None, 4)                 0
_________________________________________________________________
dense_1 (Dense)              (None, 128)               640
_________________________________________________________________
dense_2 (Dense)              (None, 64)                8256
_________________________________________________________________
dense_3 (Dense)              (None, 32)                2080
_________________________________________________________________
dense_4 (Dense)              (None, 16)                528
_________________________________________________________________
dense_5 (Dense)              (None, 10)                170
=================================================================
Total params: 11,674
Trainable params: 11,674
Non-trainable params: 0
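
As a sanity check on the summary above, the parameter counts can be reproduced by hand: a Dense layer with a bias has `inputs * units + units` parameters. The sketch below is plain Python with no Keras dependency; the helper name `dense_param_counts` is made up for illustration, and the input dimension of 4 comes from the four iris features.

```python
def dense_param_counts(input_dim, first_size, num_hidden_layers, num_classes):
    """Return the layer widths and per-layer parameter counts of the MLP."""
    # Hidden widths halve at each layer, followed by the output layer.
    widths = [first_size // (2 ** i) for i in range(num_hidden_layers)]
    widths.append(num_classes)

    counts = []
    fan_in = input_dim
    for units in widths:
        # Weight matrix (fan_in x units) plus one bias per unit.
        counts.append(fan_in * units + units)
        fan_in = units
    return widths, counts


widths, counts = dense_param_counts(input_dim=4, first_size=128,
                                    num_hidden_layers=4, num_classes=10)
print(widths)       # [128, 64, 32, 16, 10]
print(counts)       # [640, 8256, 2080, 528, 170]
print(sum(counts))  # 11674
```

These values match the Param # column and the total of 11,674 reported by `model.summary()`.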