I have a Keras TextVectorization layer that uses a custom standardization function.
import re
import string

import tensorflow as tf
from tensorflow import keras

def custom_standardization(input_string, preserve=['[', ']'], add=['¿']):
    strip_chars = string.punctuation
    for item in add:
        strip_chars += item
    for item in preserve:
        strip_chars = strip_chars.replace(item, '')
    lowercase = tf.strings.lower(input_string)
    output = tf.strings.regex_replace(lowercase, f'[{re.escape(strip_chars)}]', '')
    return output
target_vectorization = keras.layers.TextVectorization(max_tokens=vocab_size,
                                                      output_mode='int',
                                                      output_sequence_length=sequence_length + 1,
                                                      standardize=custom_standardization)
target_vectorization.adapt(train_spanish_texts)
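I want to save the adapted layer so that an inference model can use it.
One approach, as described here, is to save the layer's weights and config separately as pickle files and reload them.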
However, target_vectorization.get_config() returns
{'name': 'text_vectorization_5',
 'trainable': True,
 ...
 'standardize': <function __main__.custom_standardization(input_string, preserve=['[', ']'], add=['¿'])>,
 ...
 'vocabulary_size': 15000}
and this is what gets saved to the pickle file.
Trying to load this config with
keras.layers.TextVectorization.from_config(pickle.load(open('ckpts/spanish_vectorization.pkl', 'rb'))['config'])
raises TypeError: Could not parse config: <function custom_standardization at 0x2a1973a60>
because the file contains no information about this custom standardization function.
In this situation, what is a good way to save the TextVectorization weights and config for use in an inference model?
Problem
This appears to be a problem with serializing the custom standardization callable. See the documentation here: (tf.keras.layers.TextVectorization).
Solution
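The documentation states that you should register the layer as a Keras-serializable object, using a wrapper class decorated with tf.keras.saving.register_keras_serializable (also exported as keras.utils.register_keras_serializable, which is the name the code uses).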
I tested a minimal working example with your custom function; it works with Python 3.9.12 and Keras/TensorFlow 2.15:
import tensorflow as tf
from tensorflow import keras
import string
import re
import pickle

def custom_standardization(input_string, preserve=['[', ']'], add=['¿']):
    strip_chars = string.punctuation
    for item in add:
        strip_chars += item
    for item in preserve:
        strip_chars = strip_chars.replace(item, '')
    lowercase = tf.strings.lower(input_string)
    output = tf.strings.regex_replace(lowercase, f'[{re.escape(strip_chars)}]', '')
    return output
@keras.utils.register_keras_serializable(package='custom_layers', name='TextVectorizer')
class TextVectorizer(keras.layers.Layer):
    def __init__(self, custom_standardization, **kwargs):
        super(TextVectorizer, self).__init__(**kwargs)
        self.custom_standardization = custom_standardization
        self.vectorizer = tf.keras.layers.TextVectorization(
            standardize=self.custom_standardization,
            max_tokens=1000,  # You can adjust this parameter based on your dataset
            output_mode='int'
        )

    def call(self, inputs):
        return self.vectorizer(inputs)

    def get_config(self):
        config = super(TextVectorizer, self).get_config()
        return config
# Example usage
text_vectorizer = TextVectorizer(custom_standardization)
# Adapt the TextVectorization layer to your training data
train_data = tf.constant(['Hello [World]!', 'Another [example].'])
text_vectorizer.vectorizer.adapt(train_data)
# Build the layer to initialize the TextVectorization layer
text_vectorizer.build(input_shape=(None,))
# Create a model to include the TextVectorization layer
model = tf.keras.Sequential([text_vectorizer])
model.build(input_shape=())
# Save the weights of the model
model.save_weights('text_vectorizer_weights.tf')
# Load the weights into a new instance of TextVectorization
loaded_text_vectorizer = TextVectorizer(custom_standardization)
loaded_text_vectorizer.build(input_shape=(None,))
# Create a model to include the loaded TextVectorization layer
loaded_model = tf.keras.Sequential([loaded_text_vectorizer])
# Adapt the TextVectorization layer to the same training data
loaded_text_vectorizer.vectorizer.adapt(train_data)
# Load the weights into the new model
loaded_model.load_weights('text_vectorizer_weights.tf')
# Compile the model after loading the weights
loaded_model.compile()
# Test the loaded layer
text_input = tf.constant(['Hello [World]!'])
output = loaded_model(text_input)
print(output)
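An alternative that sidesteps serializing the callable entirely (a different technique from the wrapper class above, not something the documentation prescribes) is to persist only the adapted vocabulary, which is a plain list of strings, and rebuild the layer at inference time with the same standardization function and constructor arguments. This is a minimal sketch assuming TensorFlow 2.x; the file name spanish_vocab.pkl, the helper make_vectorizer, the sample sentences, and the values VOCAB_SIZE and SEQUENCE_LENGTH are hypothetical placeholders, not the question's real settings.

```python
import pickle
import re
import string

import tensorflow as tf
from tensorflow import keras


def custom_standardization(input_string, preserve=('[', ']'), add=('¿',)):
    # The question's function (tuple defaults instead of list defaults):
    # lowercase, then strip punctuation except the brackets, plus '¿'.
    strip_chars = string.punctuation + ''.join(add)
    for item in preserve:
        strip_chars = strip_chars.replace(item, '')
    lowercase = tf.strings.lower(input_string)
    return tf.strings.regex_replace(lowercase, f'[{re.escape(strip_chars)}]', '')


# Hypothetical placeholder values; use your real vocab_size / sequence_length.
VOCAB_SIZE = 1000
SEQUENCE_LENGTH = 8


def make_vectorizer(vocabulary=None):
    # Hypothetical helper: build the layer the same way at training and
    # inference time so that only the vocabulary has to be stored.
    return keras.layers.TextVectorization(
        max_tokens=VOCAB_SIZE,
        output_mode='int',
        output_sequence_length=SEQUENCE_LENGTH,
        standardize=custom_standardization,
        vocabulary=vocabulary,
    )


# Training side: adapt, then save only the vocabulary (a list of str).
train_texts = tf.constant(['¿Hola, [Mundo]!', 'Otro [ejemplo].'])
vectorizer = make_vectorizer()
vectorizer.adapt(train_texts)
with open('spanish_vocab.pkl', 'wb') as f:
    pickle.dump(vectorizer.get_vocabulary(), f)

# Inference side: rebuild from the same code plus the saved vocabulary.
# get_vocabulary() starts with the mask ('') and OOV ('[UNK]') tokens;
# passing them back in those leading positions is accepted by the layer.
with open('spanish_vocab.pkl', 'rb') as f:
    restored = make_vectorizer(vocabulary=pickle.load(f))

sample = tf.constant(['¿Hola, [ejemplo]!'])
print(vectorizer(sample))
print(restored(sample))
```

The design choice is that only data (the vocabulary list) crosses the file boundary, while the callable stays in code that both the training and the inference script import, so nothing has to be deserialized by name.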