I am trying to build a style transfer model; these are the functions I am using:
def style_transfer(model, style_path, content_path, learning_rate, style_weight, content_weight, epochs):
    history = {"content_loss": [], "style_loss": [], "total_loss": []}
    style_image = img_preprocess(style_path)
    content_image = img_preprocess(content_path)
    # getting content and style features
    style_features = get_features(model, style_image)["style"]
    content_features = get_features(model, content_image)["content"]
    # adding some noise to the content_image
    content_img_shape = tf.shape(content_image)
    noise = tf.random.uniform(content_img_shape, minval=0, maxval=0.5)
    gen_image = tf.add(content_image, noise)
    gen_image = tf.Variable(gen_image, dtype=tf.float32)
    # changing the gen_image
    opt = tf.keras.optimizers.Adam(learning_rate, beta_1=0.99, epsilon=1e-8)
    for i in range(epochs):
        all_losses = gradient_descent(model,
                                      opt,
                                      gen_image,
                                      style_features,
                                      content_features,
                                      style_weight,
                                      content_weight)
        history["total_loss"].append(all_losses[0].numpy())
        history["style_loss"].append(all_losses[1].numpy())
        history["content_loss"].append(all_losses[2].numpy())
        if i % 100 == 0:
            img = PIL.Image.fromarray(deprocess_img(gen_image.numpy()))
            display.clear_output(wait=True)
            display.display_png(img)
            print(f"epoch: {i}\ntotal_loss: {all_losses[0]}")
    return history
And the gradient descent function:
@tf.function()
def gradient_descent(model, optimizer, gen_image, style_features,
                     content_features, style_weight, content_weight):
    with tf.GradientTape() as g:
        g.watch(gen_image)
        gen_image_sfeatures, gen_image_cfeatures = get_features(model, gen_image).values()
        j_content = content_loss(content_features, gen_image_cfeatures)
        j_style = style_loss(style_features, gen_image_sfeatures)
        j_total = total_loss(j_style, j_content, style_weight, content_weight)
    norm_means = np.array([103.939, 116.779, 123.68])
    min_vals = -norm_means
    max_vals = 255 - norm_means
    grad = g.gradient(j_total, gen_image)
    optimizer.apply_gradients([(grad, gen_image)])
    gen_image.assign(clip(gen_image, min_vals, max_vals))
    return j_total, j_style, j_content
The first time I execute the program, it runs without problems. However, when I try to execute the style_transfer function a second time, I get the error below, and to make it work again I have to re-execute the gradient descent function. So every time I want to run style_transfer, I first have to re-run the cell that defines gradient_descent.
File /opt/conda/lib/python3.10/site-packages/keras/optimizers/optimizer.py:512, in _BaseOptimizer.add_variable_from_reference(self, model_variable, variable_name, shape, initial_value)
510 else:
511 initial_value = tf.zeros(shape, dtype=model_variable.dtype)
512 variable = tf.Variable(
513 initial_value=initial_value,
514 name=f"{variable_name}/{model_variable._shared_name}",
515 dtype=model_variable.dtype,
516 trainable=False,
517 )
518 self._variables.append(variable)
519 return variable
ValueError: in user code:
File "/tmp/ipykernel_29/1467961479.py", line 16, in gradient_descent *
optimizer.apply_gradients([(grad, gen_image)])
File "/opt/conda/lib/python3.10/site-packages/keras/optimizers/optimizer.py", line 1174, in apply_gradients **
return super().apply_gradients(grads_and_vars, name=name)
File "/opt/conda/lib/python3.10/site-packages/keras/optimizers/optimizer.py", line 637, in apply_gradients
self.build(trainable_variables)
File "/opt/conda/lib/python3.10/site-packages/keras/optimizers/adam.py", line 139, in build
self.add_variable_from_reference(
File "/opt/conda/lib/python3.10/site-packages/keras/optimizers/optimizer.py", line 1106, in add_variable_from_reference
return super().add_variable_from_reference(
File "/opt/conda/lib/python3.10/site-packages/keras/optimizers/optimizer.py", line 512, in add_variable_from_reference
variable = tf.Variable(
ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.
But the thing is, I only call tf.Variable once, and outside of the tf.function. Can you help me solve this problem?
I think the problem is probably the Adam optimizer. That optimizer keeps track of momentum for every Variable, and to do so it creates new Variables associated with each Variable being optimized (i.e., each Variable watched by the gradient tape). Those new Variables associated with gen_image are what is giving tf.function trouble: they are created inside the traced gradient_descent the first time apply_gradients runs, and when a fresh optimizer tries to create them again on your second call to style_transfer, tf.function raises the error above.
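One way to sidestep those slot Variables, as a minimal sketch that assumes the newer Keras optimizer API your traceback shows (it exposes a build(var_list) method), is to create them eagerly inside style_transfer, before the traced gradient_descent ever runs:

opt = tf.keras.optimizers.Adam(learning_rate, beta_1=0.99, epsilon=1e-8)
# create Adam's momentum/velocity slot Variables now, outside tf.function,
# so apply_gradients inside the traced function has nothing left to create
opt.build([gen_image])
for i in range(epochs):
    all_losses = gradient_descent(model, opt, gen_image, style_features,
                                  content_features, style_weight, content_weight)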
Perhaps the most direct solution is to remove the tf.function decorator and see how much of a slowdown that actually causes. In the long run, it would probably be more idiomatic to wrap the optimization loop inside a keras.Model and attach the optimizer via model.compile(optimizer=opt); it then also becomes easier to isolate the optimization for different images by creating separate instances of that Model.
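If dropping the decorator turns out to be too slow, a middle ground along the same lines is sketched below. It is only a sketch: it assumes gradient_descent is defined exactly as above but without the @tf.function decorator, and it wraps it in tf.function inside style_transfer instead. Every call to style_transfer then gets a brand-new traced function, so the optimizer's slot Variables are always created on that function's first call, which tf.function permits:

# Sketch: gradient_descent is defined WITHOUT the @tf.function decorator and
# wrapped per call instead, so its variable-creating "first call" happens
# once per style_transfer invocation.
def style_transfer(model, style_path, content_path, learning_rate,
                   style_weight, content_weight, epochs):
    history = {"content_loss": [], "style_loss": [], "total_loss": []}
    style_image = img_preprocess(style_path)
    content_image = img_preprocess(content_path)
    style_features = get_features(model, style_image)["style"]
    content_features = get_features(model, content_image)["content"]
    noise = tf.random.uniform(tf.shape(content_image), minval=0, maxval=0.5)
    gen_image = tf.Variable(content_image + noise, dtype=tf.float32)
    opt = tf.keras.optimizers.Adam(learning_rate, beta_1=0.99, epsilon=1e-8)
    step = tf.function(gradient_descent)  # fresh trace for this call only
    for i in range(epochs):
        all_losses = step(model, opt, gen_image, style_features,
                          content_features, style_weight, content_weight)
        history["total_loss"].append(all_losses[0].numpy())
        history["style_loss"].append(all_losses[1].numpy())
        history["content_loss"].append(all_losses[2].numpy())
    return history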