class_id 支持使用 metric_recall_at_ precision 或 metric_ precision_at_recall 的分类深度学习问题

问题描述 投票:0回答:1

在支持 GPU 的 Windows 10 计算机上使用 R 3.6.3、keras 2.9.0 和 tensorflow 2.9.0(网状点指向 python 3.6.10)
我无法使用指标

class_id
metric_recall_at_precision
的可选
metric_precision_at_recall
参数来编译模型(3 个分类类)。 产生以下错误:

Error in py_call_impl(callable, dots$args, dots$keywords) : 
  TypeError: __init__() got an unexpected keyword argument 'class_id'

这些指标的 keras 文档明确指出“class_id”是一个可选参数...... 该模型使用

metric_sparse_categorical_accuracy
正确编译,或者如果我将模型转换为二元分类(S形输出)并使用
metric_recall_at_precision
metric_precision_at_recall

这是生成错误的(简化)模型的代码:

model <- keras_model_sequential() %>% 
     layer_conv_1d(filters = 64, kernel_size = 11, strides = 5, activation = "relu", input_shape = c(446,3)) %>% 
     layer_max_pooling_1d(pool_size = 5) 

model %>% 
    layer_dropout(rate = 0.1) %>%
    layer_flatten() %>% 
    layer_dense(units = 64, activation = "relu") %>%
    layer_dense(units = 3, activation = "softmax")  

model %>% compile(
    optimizer = "adam",
    loss = "sparse_categorical_crossentropy",   
    metrics =  metric_recall_at_precision(precision=precision, class_id=0))

知道如何使用 class_id 参数编译此模型吗?

r tensorflow keras metrics
1个回答
0
投票

升级我的python虚拟环境中的Tensorflow版本解决了问题! 我升级到

Tensorflow v2.6.0
,现在可以编译我的模型了!

感谢t-kalinowskiQuinten的指点!

© www.soinside.com 2019 - 2024. All rights reserved.