生成人工智能中的 Synthcity DECAF 形状错误

问题描述 投票:0回答:1

我正在尝试利用 DECAF 生成器生成新数据,但出现无法解决的错误。

我使用的代码与主存储库文档中提到的完全相同(链接Synthcity Docs):

from sklearn.datasets import load_iris
from synthcity.plugins import Plugins

X, y = load_iris(as_frame = True, return_X_y = True)
X["target"] = y

plugin = Plugins().get("decaf", n_iter = 100)
plugin.fit(X)

plugin.generate(50)

我不断获得

ValueError: Shape of passed values is (150, 1), indices imply (150, 3)

无论我在做什么。我对这个错误的发生感到有点惊讶,因为它实际上是作者的一个案例研究。

有人能够解释或更重要的是解决这个错误吗?

machine-learning error-handling generative-adversarial-network
1个回答
0
投票

在这种情况下,错误消息表明传递的值的形状为 (150, 1),但索引暗示形状为 (150, 3)。这表明 DECAF 生成器期望数据中有 3 个特征(列),但您只传递了 1 个。 问题在于您加载 iris 数据集的方式。

默认情况下,load_iris 返回具有 4 个特征(萼片长度、萼片宽度、花瓣长度和花瓣宽度)的数据集。但是,您将目标变量 y 转换为具有单列(“目标”)的数据帧,然后将其与 X 连接,这会产生具有单个特征(目标变量)的数据帧。

要修复此错误,您需要将 iris 数据集的所有特征传递给 DECAF 生成器。不要将 y 转换为数据帧并将其与 X 连接,而是尝试将 X 直接传递给 plugin.fit 方法,如下所示:

from sklearn.datasets import load_iris
from synthcity.plugins import Plugins

X, y = load_iris(as_frame=True, return_X_y=True)

plugin = Plugins().get("decaf", n_iter=100)
plugin.fit(X)  # Pass X directly, without concatenating with y

plugin.generate(50)

通过执行此操作,您将把虹膜数据集的所有 4 个特征传递给 DECAF 生成器,该生成器应与预期尺寸匹配并解决形状不匹配错误。

© www.soinside.com 2019 - 2024. All rights reserved.