我正在使用 slim walkthrough notebook 中的 flowers 数据集示例,并尝试让训练模型和验证模型重用(共享)同一组权重:
def my_cnn(images, num_classes, is_training):
    """Build a small convolutional classifier and return its logits.

    The network is two 5x5 convolutions with 64 filters, each followed by a
    3x3 / stride-2 max pool, then a 192-unit fully connected layer and a
    final linear layer with ``num_classes`` outputs.

    Args:
        images: Input tensor fed to the first convolution
            (presumably a batch of images -- confirm against the caller).
        num_classes: Number of output units of the final layer.
        is_training: Accepted for interface compatibility; not used.

    Returns:
        The output of the last fully connected layer (no activation applied).
    """
    # No layer is given an explicit scope here, so slim picks the variable
    # scope names on its own for every call.
    with slim.arg_scope([slim.max_pool2d], kernel_size=[3, 3], stride=2):
        features = slim.conv2d(images, 64, [5, 5])
        features = slim.max_pool2d(features)
        features = slim.conv2d(features, 64, [5, 5])
        features = slim.max_pool2d(features)
        flat = slim.flatten(features)
        hidden = slim.fully_connected(flat, 192)
        logits = slim.fully_connected(hidden, num_classes, activation_fn=None)
    return logits
...
# Build the training and validation graphs under one variable scope so the
# second my_cnn call is asked to reuse the variables created by the first.
with tf.variable_scope("model") as scope:
    logits = my_cnn(
        images,
        num_classes=dataset.num_classes,
        is_training=True,
    )
    # After this call, variable lookups in this scope must resolve to
    # variables that already exist.
    scope.reuse_variables()
    val_logits = my_cnn(
        val_images,
        num_classes=dataset.num_classes,
        is_training=False,
    )
但是当我尝试运行此会话时,我仍然收到此错误:
<ipython-input-49-15390a9fff86> in <module>()
21 logits = my_cnn(images, num_classes=dataset.num_classes, is_training=True)
22 scope.reuse_variables()
---> 23 val_logits = my_cnn(val_images, num_classes=dataset.num_classes, is_training=False)
24
25 # Specify the `train` loss function:
...
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variable_scope.py in _get_single_variable(self, name, shape, dtype, initializer, regularizer, partition_info, reuse, trainable, collections, caching_device, validate_shape, use_resource, constraint)
763 raise ValueError("Variable %s does not exist, or was not created with "
764 "tf.get_variable(). Did you mean to set "
--> 765 "reuse=tf.AUTO_REUSE in VarScope?" % name)
766 if not shape.is_fully_defined() and not initializing_from_value:
767 raise ValueError("Shape of a new variable (%s) must be fully defined, "
ValueError: Variable model/Conv_2/weights does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=tf.AUTO_REUSE in VarScope?
使用 reuse=tf.AUTO_REUSE 似乎可以解决这个问题。
另外,请务必为每个带权重的层(conv2d、fully_connected)显式指定 scope(名称),这样两次调用才会解析到相同的变量名:
def my_cnn(images, num_classes, is_training):
    """Build a small convolutional classifier whose weights can be shared.

    Every layer that owns weights gets an explicit scope name ("conv1",
    "conv2", "fc1", "fc2") and is created with ``reuse=tf.AUTO_REUSE``, so
    calling this function again under the same enclosing variable scope
    resolves to the same variable names instead of creating new ones.

    Args:
        images: Input tensor fed to the first convolution
            (presumably a batch of images -- confirm against the caller).
        num_classes: Number of output units of the final layer.
        is_training: Accepted for interface compatibility; not used.

    Returns:
        The output of the last fully connected layer (no activation applied).
    """
    weighted_layers = [slim.conv2d, slim.fully_connected]
    with slim.arg_scope(weighted_layers, reuse=tf.AUTO_REUSE):
        with slim.arg_scope([slim.max_pool2d], kernel_size=[3, 3], stride=2):
            features = slim.conv2d(images, 64, [5, 5], scope="conv1")
            features = slim.max_pool2d(features)
            features = slim.conv2d(features, 64, [5, 5], scope="conv2")
            features = slim.max_pool2d(features)
            flat = slim.flatten(features)
            hidden = slim.fully_connected(flat, 192, scope="fc1")
            logits = slim.fully_connected(
                hidden, num_classes, activation_fn=None, scope="fc2")
    return logits