我有一个名为 `ControlledBackward` 的自定义层，它接受一个布尔掩码输入 input_layer_of_mask 和上一层的输出 prev_layer：
class ControlledBackward(Layer):
    """Selectively blocks gradient flow through ``prev_layer`` per sample.

    Forward pass is the identity: the output equals ``prev_layer`` for every
    sample (the two branches sum to the original value). Backward pass:
    where ``mask`` is True, gradients flow normally; where it is False, the
    value is routed through ``ops.stop_gradient`` so upstream weights
    receive no update from those samples.

    Call inputs: ``[mask, prev_layer]`` where ``mask`` is boolean with a
    shape broadcastable to ``prev_layer``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        mask, prev_layer = inputs
        # Complement the mask while it is still boolean (logical_not
        # requires a bool tensor), then cast BOTH branches to the tensor
        # dtype for arithmetic. The original code cast first and negated
        # afterwards, applying logical_not to a float tensor.
        keep = ops.cast(mask, prev_layer.dtype)
        block = ops.cast(ops.logical_not(mask), prev_layer.dtype)
        # Ops are called directly: wrapping them in Lambda/Multiply/Add
        # layers inside call() built fresh layer objects on every forward
        # pass, which is both wasteful and unidiomatic in a custom layer.
        stopped_part = ops.stop_gradient(prev_layer) * block
        flowing_part = prev_layer * keep
        return stopped_part + flowing_part
这对于隐藏层工作正常，但不适用于输出层。考虑下面的测试代码，其中 X_mask 标记了哪些样本需要阻断梯度：
import numpy as np
# Topology: input -> hidden Dense -> gradient blocker -> output Dense.
# NOTE(review): the blocker sits BEFORE output_a, so output_a's own weights
# are downstream of the blocked branch and still receive gradients — this
# is exactly the behavior the question reports below.
input_layer = Input(shape=(1,), name='input_layer')
gradient_blocker_mask = Input(shape=(1,), dtype='bool', name='cg')
hidden_a = Dense(1, name='hidden_a')(input_layer)
controlled_hidden_a = ControlledBackward(name='gradient_blocker')([gradient_blocker_mask, hidden_a])
output_a = Dense(1, name='output_a')(controlled_hidden_a)
model = Model(inputs=[input_layer, gradient_blocker_mask], outputs=[output_a])
# Snapshot weights before training so we can compare after fit.
print('hidden_weights', model.get_layer('hidden_a').get_weights())
print('output_weights',model.get_layer('output_a').get_weights())
# Dummy data: both samples carry mask=False, i.e. gradients should be
# blocked for every sample upstream of the blocker.
X = np.array([[42], [3]])
X_mask = np.array([[False], [False]])
y = np.array([[7], [5]])
# Prediction before training (forward pass is unaffected by the mask).
y_pred = model.predict([X, X_mask], verbose=0)
print(y_pred)
print('')
# Train: hidden_a is expected to stay frozen; output_a still updates.
model.compile(optimizer='adam', loss='mse')
model.fit([X, X_mask], y, epochs=100, verbose=0)
# Weights after training — compare with the pre-training snapshot.
print('hidden_weights', model.get_layer('hidden_a').get_weights())
print('output_weights',model.get_layer('output_a').get_weights())
# Prediction after training.
y_pred = model.predict([X, X_mask], verbose=0)
print(y_pred)
# display(plot_model(model, show_shapes=True, show_layer_names=True))
输出:
# Before Train
hidden_weights [array([[1.2073823]], dtype=float32), array([0.], dtype=float32)]
output_weights [array([[0.7683755]], dtype=float32), array([0.], dtype=float32)]
[[38.964367]
[ 2.783169]]
# After Train
hidden_weights [array([[1.2073823]], dtype=float32), array([0.], dtype=float32)]
output_weights [array([[0.6712855]], dtype=float32), array([-0.09659182], dtype=float32)]
[[33.944336 ]
[ 2.3349032]]
如您所见，hidden layer 的权重符合预期，没有发生变化；但 output layer 仍然更新了自己的权重。那么，如何才能根据 X_mask 同样阻止 output layer 更新权重呢？
原来我只需要把 ControlledBackward 放在输出层之后，也就是把 output_a 的输出作为 prev_layer 传入：
import numpy as np
# Topology: input -> hidden Dense -> output Dense -> gradient blocker.
# The blocker is now placed AFTER output_a (output_a feeds in as
# prev_layer), so for mask=False samples no gradient reaches output_a
# or anything upstream of it.
input_layer = Input(shape=(1,), name='input_layer')
gradient_blocker_mask = Input(shape=(1,), dtype='bool', name='cg')
hidden_a = Dense(1, name='hidden_a')(input_layer)
output_a = Dense(1, name='output_a')(hidden_a)
controlled_output_a = ControlledBackward(name='gradient_blocker')([gradient_blocker_mask, output_a])
model = Model(inputs=[input_layer, gradient_blocker_mask], outputs=[controlled_output_a])
# Snapshot weights before training so we can compare after fit.
print('hidden_weights', model.get_layer('hidden_a').get_weights())
print('output_weights',model.get_layer('output_a').get_weights())
# Dummy data: both samples carry mask=False — every gradient is blocked.
X = np.array([[42], [3]])
X_mask = np.array([[False], [False]])
y = np.array([[7], [5]])
# Prediction before training (forward pass is unaffected by the mask).
y_pred = model.predict([X, X_mask], verbose=0)
print(y_pred)
print('')
# Train: now BOTH layers are expected to stay frozen.
model.compile(optimizer='adam', loss='mse')
model.fit([X, X_mask], y, epochs=100, verbose=0)
# Weights after training — identical to the pre-training snapshot.
print('hidden_weights', model.get_layer('hidden_a').get_weights())
print('output_weights',model.get_layer('output_a').get_weights())
# Prediction after training.
y_pred = model.predict([X, X_mask], verbose=0)
print(y_pred)
# NOTE(review): display/plot_model are not imported in this snippet —
# presumably run in a notebook with keras.utils.plot_model available.
display(plot_model(model, show_shapes=True, show_layer_names=True))
输出:
hidden_weights [array([[-0.74082583]], dtype=float32), array([0.], dtype=float32)]
output_weights [array([[1.0905765]], dtype=float32), array([0.], dtype=float32)]
[[-33.932945 ]
[ -2.4237816]]
hidden_weights [array([[-0.74082583]], dtype=float32), array([0.], dtype=float32)]
output_weights [array([[1.0905765]], dtype=float32), array([0.], dtype=float32)]
[[-33.932945 ]
[ -2.4237816]]