TypeError when using the "WeightNormalization" layer from tensorflow-addons


While following the example code for the "WeightNormalization" wrapper implemented in tensorflow-addons, I ran into the following error.

Environment:

  • OS: Ubuntu 18.04 LTS
  • TensorFlow: 2.11 (latest-gpu), typeguard==3.0.1
  • Python: 3.8.10
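The traceback paths below point at /usr/local/lib/python3.9, so the Colab runtime may not match the Python 3.8.10 listed above. A minimal sketch, using only the standard library, for printing the versions that the running notebook actually sees (the helper name installed_versions is hypothetical):

```python
import sys
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {distribution name: version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    print("Python:", sys.version.split()[0])
    # Distribution names as published on PyPI.
    for pkg, ver in installed_versions(
        ["tensorflow", "tensorflow-addons", "typeguard"]
    ).items():
        print(f"{pkg}: {ver or 'not installed'}")
```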

The following code does not work in a Colab notebook:

import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
from matplotlib import pyplot as plt

# Hyper Parameters
batch_size = 32
epochs = 10
num_classes=10

# WeightNorm ConvNet
wn_model = tf.keras.Sequential([
    tfa.layers.WeightNormalization(tf.keras.layers.Conv2D(6, 5, activation='relu')),
    tf.keras.layers.MaxPooling2D(2, 2),
    tfa.layers.WeightNormalization(tf.keras.layers.Conv2D(16, 5, activation='relu')),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    tfa.layers.WeightNormalization(tf.keras.layers.Dense(120, activation='relu')),
    tfa.layers.WeightNormalization(tf.keras.layers.Dense(84, activation='relu')),
    tfa.layers.WeightNormalization(tf.keras.layers.Dense(num_classes, activation='softmax')),
])

The TypeError raised by tfa.layers.WeightNormalization():

TypeError                                 Traceback (most recent call last)
<ipython-input-5-973cec7abd8b> in <module>
      1 # WeightNorm ConvNet
      2 wn_model = tf.keras.Sequential([
----> 3     tfa.layers.WeightNormalization(tf.keras.layers.Conv2D(6, 5, activation='relu')),
      4     tf.keras.layers.MaxPooling2D(2, 2),
      5     tfa.layers.WeightNormalization(tf.keras.layers.Conv2D(16, 5, activation='relu')),

2 frames
/usr/local/lib/python3.9/dist-packages/tensorflow_addons/layers/wrappers.py in __init__(self, layer, data_init, **kwargs)
     57 
     58     @typechecked
---> 59     def __init__(self, layer: tf.keras.layers, data_init: bool = True, **kwargs):
     60         super().__init__(layer, **kwargs)
     61         self.data_init = data_init

/usr/local/lib/python3.9/dist-packages/typeguard/_functions.py in check_argument_types(memo)
    111             value = memo.arguments[argname]
    112             try:
--> 113                 check_type_internal(value, expected_type, memo=memo)
    114             except TypeCheckError as exc:
    115                 qualname = qualified_name(value, add_class_prefix=True)

/usr/local/lib/python3.9/dist-packages/typeguard/_checkers.py in check_type_internal(value, annotation, memo)
    668             return
    669 
--> 670     if not isinstance(value, origin_type):
    671         raise TypeCheckError(f"is not an instance of {qualified_name(origin_type)}")
    672 

TypeError: isinstance() arg 2 must be a type or tuple of types

I suspect the 'typeguard' package is causing this, but I haven't been able to fix it. Could you help?
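The last frame of the traceback can be reproduced without TensorFlow. The annotation in wrappers.py is "layer: tf.keras.layers", which is a module rather than a class, and the traceback shows typeguard 3.0.1 passing it as the second argument of isinstance(). A minimal sketch, using the standard-library json module as a hypothetical stand-in for tf.keras.layers (the exact error message varies between Python versions):

```python
import json  # any module works; it stands in for tf.keras.layers here


def check_instance(value, expected):
    """Mimic the failing isinstance() call in typeguard's check_type_internal()."""
    return isinstance(value, expected)


try:
    # A module is not a type, so isinstance() rejects it.
    check_instance(object(), json)
except TypeError as exc:
    print("TypeError:", exc)
```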

1 Answer

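The "layer" argument in wrappers.py is annotated as tf.keras.layers, which is a module rather than a class, and typeguard 3.0.1 fails when it hands that annotation to isinstance(). You should use a typeguard version below 3.0: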

pip install "typeguard<3.0"

(The quotes stop the shell from treating < as input redirection.)
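A minimal sketch of a guard that makes this version requirement explicit, assuming the fix is simply to keep typeguard below 3.0. The helpers is_below_3 and check_typeguard are hypothetical names, not part of tensorflow-addons or typeguard:

```python
from importlib.metadata import PackageNotFoundError, version


def is_below_3(ver):
    """True if a version string such as '2.13.3' has a major version below 3."""
    return int(ver.split(".")[0]) < 3


def check_typeguard():
    """Fail early with a readable message instead of the isinstance() TypeError.

    Hypothetical helper: call it before importing tensorflow_addons.
    """
    try:
        ver = version("typeguard")
    except PackageNotFoundError:
        return  # typeguard is not installed, nothing to check
    if not is_below_3(ver):
        raise RuntimeError(
            f"typeguard {ver} is installed; run "
            'pip install "typeguard<3.0" before importing tensorflow_addons'
        )
```

In a notebook, calling check_typeguard() in the first cell, before "import tensorflow_addons as tfa", surfaces the problem directly. If typeguard was already imported in the session, the downgrade likely only takes effect after restarting the runtime, since Python keeps the already-loaded module.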