在 TensorFlow 2.0 中使用神经网络学习聚合器

问题描述 投票:0回答:0

我正在尝试使用 Tensorflow 设计以下神经网络模型:

模型的一个输入是 X,一个 n 个维度为 3 的向量的列表。模型的第二个输入是 Y,一个从 0 开始按升序排列的 n 个自然数的列表。模型的输出是 Z,一个 m 的列表维度 3.

的向量

Y中有m个唯一的数,代表3维输入向量的类别。不同类别的输入向量个数可能不同。

模型的第一层将 X 中的每个向量转换为维度为 2 的向量,并应用 'gelu' 激活函数。第二层执行“segment_sum”,使用 Y 将 n 个 2 维向量减少为 m 个 2 维向量。第三层将 m 个 2 维向量转换为 3 维,这是模型的输出。

我使用余弦差异损失和 Adam 优化器来训练模型。

这是我为此编写的代码:

import numpy as np
import tensorflow as tf
from tensorflow import keras

# Prepare the input and output data (example)
n = 10
m = 4
X = np.random.random((n, 3)).astype('float32')
Y = np.array([0, 0, 1, 1, 2, 2, 2, 3, 3, 3]).astype('int32')
Z = np.random.random((m, 3)).astype('float32')

class CustomModel(tf.keras.Model):
    def __init__(self):
        super(CustomModel, self).__init__()
        self.dense1 = keras.layers.Dense(2, activation='gelu')
        self.dense2 = keras.layers.Dense(3)

    def call(self, inputs):
        X, Y = inputs
        X = self.dense1(X)
        X = tf.math.segment_sum(X, Y)
        Z = self.dense2(X)
        return Z


model = CustomModel()

model.compile(loss=tf.keras.losses.CosineSimilarity(axis=1), optimizer=tf.keras.optimizers.Adam())

model.fit([X, Y], Z, epochs=10)

该模型旨在学习聚合函数。但是,我收到以下错误:

Traceback (most recent call last):
  File "/home/nitesh/PycharmProjects1/pythonProject/research/reasoning_with_vectors/custom_model.py", line 31, in <module>
    model.fit([X, Y], Z, epochs=10)
  File "/home/nitesh/miniconda3/envs/relbert/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/nitesh/miniconda3/envs/relbert/lib/python3.10/site-packages/keras/engine/data_adapter.py", line 1852, in _check_data_cardinality
    raise ValueError(msg)
ValueError: Data cardinality is ambiguous:
  x sizes: 10, 10
  y sizes: 4
Make sure all arrays contain the same number of samples.

我尝试了很多,但我没有找到任何方法来使用 TensorFlow 2.0 对模型进行编码。我尝试询问 GPT4 和 Bard,但没有得到任何满意的答案。

有人能帮忙吗?提前致谢。

tensorflow deep-learning aggregation max-pooling
© www.soinside.com 2019 - 2024. All rights reserved.