列出 keras 模型中编译的指标?

问题描述 投票:0回答:2

我已经加载了一个 Keras 模型(或者刚刚创建并编译了它)。如何访问编译模型所用的指标对象列表? 我可以使用以下方法访问损失和优化器:

model.loss
model.optimizer
。 因此,我假设我会在
model.metrics
中找到指标列表,但这只会返回一个空列表。

python tensorflow keras
2个回答
3
投票

您必须运行模型至少 1 个时期才能使指标名称可用:

import numpy as np
import tensorflow as tf
x = np.random.uniform(0,1, (37432,512))
y = np.random.randint(0,2, (37432,1))
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(256, activation = 'relu'))
model.add(tf.keras.layers.Dense(2, activation='softmax'))

model.compile(loss="sparse_categorical_crossentropy",
              optimizer='adam',
              metrics=['accuracy'])
print(model.metrics_names)
_ = model.fit(x= x, y = y, validation_split=0.2, verbose = 0)
print(model.metrics_names)

输出:

[]
['loss', 'accuracy']

对于度量对象:

model.metrics[1:]

输出:

[<tensorflow.python.keras.metrics.MeanMetricWrapper at 0x7fbe702aee50>]

2
投票

您可以在

model.fit()
之前从
model.compiled_metrics
属性获取它们,该属性是在
model.compile()
中创建的 MetricGenerator 对象。拟合前后的内存地址是相同的,所以我假设它是同一个对象。这适用于 tf 2.6.0。

>>> model.compile(metrics=[tf.keras.losses.sparse_categorical_crossentropy])
>>> model.metrics
[]
>>> model.compiled_metrics
<keras.engine.compile_utils.MetricsContainer at 0x7f701c7ed4a8>
>>> model.compiled_metrics._metrics
[<keras.metrics.SparseCategoricalCrossentropy object at 0x7facf8109b00>]
>>> model.fit(x)
...
>>> model.metrics
[<keras.metrics.Mean object at 0x7facf81099e8>, 
<keras.metrics.SparseCategoricalCrossentropy object at 0x7facf8109b00>]
© www.soinside.com 2019 - 2024. All rights reserved.