How do I configure scikit-learn's r2_score function to recognize that I have installed PyTorch with GPU support?


I'm new to PyTorch. I just installed it and ran a Hello-World example.

First, I installed PyTorch as follows:

$ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
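To confirm that the CUDA build was actually installed and that PyTorch can see a GPU, a quick check along these lines may help. It is a minimal sketch that only reads `torch.__version__`, `torch.version.cuda`, `torch.cuda.is_available()` and `torch.cuda.get_device_name()`; the helper name `cuda_status` is hypothetical, not part of any library.

```python
import importlib.util


def cuda_status():
    """Report which PyTorch build is installed and whether it can see a GPU.

    Hypothetical helper for illustration; returns a plain dict that is easy
    to print or log.
    """
    if importlib.util.find_spec("torch") is None:
        return {"torch_installed": False}

    import torch

    status = {
        "torch_installed": True,
        "torch_version": torch.__version__,
        # None for CPU-only builds, a version string such as "11.8" for a CUDA build
        "built_with_cuda": torch.version.cuda,
        # True only if a working CUDA driver and a compatible GPU are present
        "cuda_available": torch.cuda.is_available(),
    }
    if status["cuda_available"]:
        status["device_name"] = torch.cuda.get_device_name(0)
    return status


if __name__ == "__main__":
    print(cuda_status())
```

If `built_with_cuda` comes back as `None`, a CPU-only wheel was installed rather than the cu118 one.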

I also installed scikit-learn, because the Hello-World example uses it:

$ pip3 install scikit-learn

Then I ran the following program:

import torch
from torch import nn
from sklearn.metrics import r2_score


class MyMachine(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(2,5),
            nn.ReLU(),
            nn.Linear(5,1)
        )

    def forward(self, x):
        x = self.fc(x)
        return x


def get_dataset():
    X = torch.rand((1000,2))
    x1 = X[:,0]
    x2 = X[:,1]
    y = x1 * x2
    return X, y


def train():
    model = MyMachine()
    model.train()
    X, y = get_dataset()
    NUM_EPOCHS = 1000
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=1e-5)
    criterion = torch.nn.MSELoss(reduction='mean')

    for epoch in range(NUM_EPOCHS):
        optimizer.zero_grad()
        y_pred = model(X)
        y_pred = y_pred.reshape(1000)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()
        print(f'Epoch:{epoch}, Loss:{loss.item()}')
    torch.save(model.state_dict(), 'model.h5')


def test():
    model = MyMachine()
    model.load_state_dict(torch.load("model.h5"))
    model.eval()
    X, y = get_dataset()

    with torch.no_grad():
        y_pred = model(X)
        print(r2_score(y, y_pred))


train()
test()

I then got the following warning and error:

/home/mylab/Workspace/PyTorchScripts/helloworld.py:49: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  model.load_state_dict(torch.load("model.h5"))
Traceback (most recent call last):
  File "/home/mylab/Workspace/PyTorchScripts/helloworld.py", line 59, in <module>
    test()
  File "/home/mylab/Workspace/PyTorchScripts/helloworld.py", line 55, in test
    print(r2_score(y, y_pred))
  File "/home/mylab/.local/lib/python3.10/site-packages/sklearn/utils/_param_validation.py", line 216, in wrapper
    return func(*args, **kwargs)
  File "/home/mylab/.local/lib/python3.10/site-packages/sklearn/metrics/_regression.py", line 1281, in r2_score
    return _assemble_r2_explained_variance(
  File "/home/mylab/.local/lib/python3.10/site-packages/sklearn/metrics/_regression.py", line 935, in _assemble_r2_explained_variance
    output_scores = xp.ones([n_outputs], device=device, dtype=dtype)
  File "/home/mylab/.local/lib/python3.10/site-packages/sklearn/utils/_array_api.py", line 316, in wrapped_func
    _check_device_cpu(kwargs.pop("device", None))
  File "/home/mylab/.local/lib/python3.10/site-packages/sklearn/utils/_array_api.py", line 310, in _check_device_cpu
    raise ValueError(f"Unsupported device for NumPy: {device!r}")

According to the PyTorch website, this installation should make PyTorch run on my GPU, but scikit-learn seems to be referring to the CPU.

How can I fix this?

Tags: scikit-learn, pytorch

1 answer

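By default, scikit-learn works with NumPy arrays (or pandas objects), which live in CPU memory. When you pass PyTorch tensors directly to the r2_score() function, scikit-learn tries to handle them like NumPy arrays and rejects the tensor's device, which is the "Unsupported device for NumPy" error at the end of your traceback. Note also that installing the CUDA build of PyTorch does not move anything to the GPU by itself: models and tensors are created on the CPU and stay there until you call .to("cuda") (or .cuda()), so in your script everything is actually still on the CPU.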
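A small conversion helper makes the required step explicit. The sketch below detaches a tensor from the autograd graph, copies it to host memory, converts it to NumPy, and flattens it, so that `y` (shape `(1000,)`) and the model output (shape `(1000, 1)`) line up. It is duck-typed so it does not import torch itself. The names `to_numpy_1d` and `r2_by_hand` are hypothetical; `r2_by_hand` simply evaluates the usual single-output definition R² = 1 - SS_res / SS_tot on plain NumPy arrays as a cross-check.

```python
import numpy as np


def to_numpy_1d(x):
    """Return x as a flat float64 NumPy array in host (CPU) memory.

    Anything exposing .detach() and .cpu() (i.e. a torch.Tensor, on CPU or
    GPU) is treated as a tensor; other inputs go through np.asarray.
    """
    if hasattr(x, "detach") and hasattr(x, "cpu"):
        # drop the autograd graph, copy to host memory, then convert to NumPy
        x = x.detach().cpu().numpy()
    return np.asarray(x, dtype=np.float64).reshape(-1)


def r2_by_hand(y_true, y_pred):
    """Coefficient of determination for a single output: 1 - SS_res / SS_tot."""
    y_true = to_numpy_1d(y_true)
    y_pred = to_numpy_1d(y_pred)
    ss_res = np.sum((y_true - y_pred) ** 2)           # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)    # total sum of squares
    return 1.0 - ss_res / ss_tot


if __name__ == "__main__":
    y_true = np.array([1.0, 2.0, 3.0, 4.0])
    y_pred = np.array([[1.1], [1.9], [3.2], [3.8]])  # column vector, like a model output
    print(r2_by_hand(y_true, y_pred))
```

For a single, non-constant target, `r2_score(to_numpy_1d(y), to_numpy_1d(y_pred))` should agree with `r2_by_hand` on the same inputs.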

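You can still keep the PyTorch model and tensors on the GPU for inference; you just need to move the results to the CPU with .cpu() and convert them to NumPy arrays with .numpy() before passing them to r2_score():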

print(r2_score(y.cpu().numpy(), y_pred.cpu().numpy()))
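Putting it together, a device-aware replacement for the question's test() could look like the sketch below. It runs the forward pass on the GPU when PyTorch can see one and only moves the results to the CPU for scikit-learn. The helper names `pick_device`, `evaluate_r2` and `evaluate_checkpoint` are hypothetical, and the model class is passed in rather than hard-coded so the block stands alone. Loading with `weights_only=True` (supported by recent PyTorch releases) should also silence the FutureWarning from the question, since a `state_dict` contains only tensors.

```python
import torch
from torch import nn
from sklearn.metrics import r2_score


def pick_device() -> torch.device:
    # Modules and tensors are created on the CPU; use the GPU only if PyTorch can see one.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def evaluate_r2(model: nn.Module, X: torch.Tensor, y: torch.Tensor, device: torch.device) -> float:
    """Run inference on `device`, then score on the CPU with scikit-learn."""
    model = model.to(device)
    model.eval()
    with torch.no_grad():
        y_pred = model(X.to(device)).reshape(-1)  # (N, 1) -> (N,)
    # scikit-learn expects NumPy arrays in host memory
    return r2_score(y.cpu().numpy(), y_pred.cpu().numpy())


def evaluate_checkpoint(model_cls, checkpoint="model.h5"):
    """Load a saved state_dict into model_cls() and print R^2 on fresh data."""
    device = pick_device()
    model = model_cls()
    # weights_only=True restricts unpickling to tensors and other safe types;
    # map_location lets GPU-saved weights load on a CPU-only machine (and vice versa)
    state = torch.load(checkpoint, map_location=device, weights_only=True)
    model.load_state_dict(state)
    X = torch.rand((1000, 2))
    y = X[:, 0] * X[:, 1]  # same target as get_dataset() in the question
    print(evaluate_r2(model, X, y, device))
```

In the question's script this would be called as `evaluate_checkpoint(MyMachine)` after `train()`, in place of `test()`.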
© www.soinside.com 2019 - 2024. All rights reserved.