I followed the code from the PyKan documentation exactly, but it keeps giving me this error:
description: 0%| | 0/20 [00:00<?, ?it/s]
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[44], line 9
6 def test_acc():
7 return torch.mean((torch.argmax(model(dataset['test_input']), dim=1) == dataset['test_label']).float())
----> 9 results = model.train(dataset, opt="LBFGS", steps=20, metrics=(train_acc, test_acc), loss_fn=torch.nn.CrossEntropyLoss())
File ~\anaconda3\Lib\site-packages\kan\KAN.py:898, in KAN.train(self, dataset, opt, steps, log, lamb, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff, update_grid, grid_update_num, loss_fn, lr, stop_grid_update_step, batch, small_mag_threshold, small_reg_factor, metrics, sglr_avoid, save_fig, in_vars, out_vars, beta, save_fig_freq, img_folder, device)
895 test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)
897 if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid:
--> 898 self.update_grid_from_samples(dataset['train_input'][train_id].to(device))
900 if opt == "LBFGS":
901 optimizer.step(closure)
File ~\anaconda3\Lib\site-packages\kan\KAN.py:243, in KAN.update_grid_from_samples(self, x)
220 '''
221 update grid from samples
222
(...)
240 tensor([0.0128, 1.0064, 2.0000, 2.9937, 3.9873, 4.9809])
241 '''
242 for l in range(self.depth):
--> 243 self.forward(x)
244 self.act_fun[l].update_grid_from_samples(self.acts[l])
File ~\anaconda3\Lib\site-packages\kan\KAN.py:311, in KAN.forward(self, x)
307 self.acts.append(x) # acts shape: (batch, width[l])
309 for l in range(self.depth):
--> 311 x_numerical, preacts, postacts_numerical, postspline = self.act_fun[l](x)
313 if self.symbolic_enabled == True:
314 x_symbolic, postacts_symbolic = self.symbolic_fun[l](x)
File ~\anaconda3\Lib\site-packages\torch\nn\modules\module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File ~\anaconda3\Lib\site-packages\torch\nn\modules\module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File ~\anaconda3\Lib\site-packages\kan\KANLayer.py:173, in KANLayer.forward(self, x)
171 preacts = x.permute(1, 0).clone().reshape(batch, self.out_dim, self.in_dim)
172 base = self.base_fun(x).permute(1, 0) # shape (batch, size)
--> 173 y = coef2curve(x_eval=x, grid=self.grid[self.weight_sharing], coef=self.coef[self.weight_sharing], k=self.k, device=self.device) # shape (size, batch)
174 y = y.permute(1, 0) # shape (batch, size)
175 postspline = y.clone().reshape(batch, self.out_dim, self.in_dim)
File ~\anaconda3\Lib\site-packages\kan\spline.py:100, in coef2curve(x_eval, grid, coef, k, device)
65 '''
66 converting B-spline coefficients to B-spline curves. Evaluate x on B-spline curves (summing up B_batch results over B-spline basis).
67
(...)
96 torch.Size([5, 100])
97 '''
98 # x_eval: (size, batch), grid: (size, grid), coef: (size, coef)
99 # coef: (size, coef), B_batch: (size, coef, batch), summer over coef
--> 100 y_eval = torch.einsum('ij,ijk->ik', coef, B_batch(x_eval, grid, k, device=device))
101 return y_eval
File ~\anaconda3\Lib\site-packages\torch\functional.py:385, in einsum(*args)
380 return einsum(equation, *_operands)
382 if len(operands) <= 2 or not opt_einsum.enabled:
383 # the path for contracting 0 or 1 time(s) is already optimized
384 # or the user has disabled using opt_einsum
--> 385 return _VF.einsum(equation, operands) # type: ignore[attr-defined]
387 path = None
388 if opt_einsum.is_available():
RuntimeError: expected scalar type Double but found Float
How do I fix this?
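From what I understand, this dtype error just means the model's parameters and the input tensors use different floating-point precisions (PyTorch modules default to float32, while torch.from_numpy on make_moons data gives float64), so one of them has to be cast to match the other. A minimal sketch, assuming a dataset dict like the one below; either cast alone should remove the dtype mismatch:

# Option 1: cast the model's parameters to float64 to match tensors created with torch.from_numpy
model = model.double()

# Option 2: cast the inputs to float32 to match a default float32 model
dataset['train_input'] = dataset['train_input'].float()
dataset['test_input'] = dataset['test_input'].float()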
I changed my code to this:
from kan import KAN
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
import torch
import numpy as np
device = "cuda:0" if torch.cuda.is_available() else "cpu"
dataset = {}
train_input, train_label = make_moons(
n_samples=1000, shuffle=True, noise=0.1, random_state=None)
test_input, test_label = make_moons(
n_samples=1000, shuffle=True, noise=0.1, random_state=None)
dataset['train_input'] = torch.from_numpy(train_input)
dataset['test_input'] = torch.from_numpy(test_input)
dataset['train_label'] = torch.from_numpy(train_label)
dataset['test_label'] = torch.from_numpy(test_label)
# move every dataset tensor to the chosen device
dataset['train_input'] = dataset['train_input'].to(device)
dataset['test_input'] = dataset['test_input'].to(device)
dataset['train_label'] = dataset['train_label'].to(device)
dataset['test_label'] = dataset['test_label'].to(device)
X = dataset['train_input']
y = dataset['train_label']
model = KAN(width=[2, 2], grid=3, k=3, device=device).double()
print(model.device)
print(dataset['train_input'].device)
print(dataset['test_input'].device)
print(dataset['train_label'].device)
print(dataset['test_label'].device)
def train_acc():
    return torch.mean((torch.argmax(model(dataset['train_input']), dim=1) == dataset['train_label']).float())

def test_acc():
    return torch.mean((torch.argmax(model(dataset['test_input']), dim=1) == dataset['test_label']).float())
results = model.train(dataset, opt="LBFGS", steps=20, metrics=(
train_acc, test_acc), loss_fn=torch.nn.CrossEntropyLoss())
However, I keep getting the error below. I even tried to fix it by setting the device manually (as shown in the code above), but nothing has worked.
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
description: 0%| | 0/20 [00:00<?, ?it/s]
Traceback (most recent call last):
File "c:\Users\kshit\OneDrive\Documents\IRIS\KANforIRIS.py", line 45, in <module>
results = model.train(dataset, opt="LBFGS", steps=20, metrics=(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\kshit\AppData\Local\Programs\Python\Python311\Lib\site-packages\kan\KAN.py", line 898, in train
self.update_grid_from_samples(dataset['train_input'][train_id].to(device))
File "C:\Users\kshit\AppData\Local\Programs\Python\Python311\Lib\site-packages\kan\KAN.py", line 243, in update_grid_from_samples
self.forward(x)
File "C:\Users\kshit\AppData\Local\Programs\Python\Python311\Lib\site-packages\kan\KAN.py", line 311, in forward
x_numerical, preacts, postacts_numerical, postspline = self.act_fun[l](x)
^^^^^^^^^^^^^^^^^^
File "C:\Users\kshit\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\kshit\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\kshit\AppData\Local\Programs\Python\Python311\Lib\site-packages\kan\KANLayer.py", line 170, in forward
x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, device=self.device)).reshape(batch, self.size).permute(1, 0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\kshit\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\functional.py", line 380, in einsum
return _VF.einsum(equation, operands) # type: ignore[attr-defined]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
Edit: Sorry for the initially short error message. For convenience I have now added the full error message and fixed a few minor issues in the original code, so the only remaining error is the one shown in this post.
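For reference, here is a plain PyTorch check (not PyKan-specific) that lists any parameters or registered buffers that are not on the target device; it is only a diagnostic sketch and cannot see tensors the library keeps as plain attributes:

# print every parameter/buffer whose device differs from the target device
target = torch.device(device)
for name, p in model.named_parameters():
    if p.device != target:
        print("parameter on wrong device:", name, p.device)
for name, b in model.named_buffers():
    if b.device != target:
        print("buffer on wrong device:", name, b.device)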
I had a similar problem. The repository here provides the solution: https://github.com/KindXiaoming/pykan/blob/master/kan/KAN.py (line 761). You need to pass 'device' into the method signature.
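For anyone hitting the same thing, a minimal sketch of that fix on the calling side, assuming the pykan version from the tracebacks above (its train() signature already takes a device argument): pass the same device the model and data live on, so the internal update_grid_from_samples call no longer moves the batch to the default CPU device.

results = model.train(
    dataset,
    opt="LBFGS",
    steps=20,
    metrics=(train_acc, test_acc),
    loss_fn=torch.nn.CrossEntropyLoss(),
    device=device,  # same device that was passed to KAN(..., device=device)
)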
Best regards