我正在尝试并且未能将参数传递给scikit中的自定义估算器。我想在gridsearch期间更改参数lr
。问题是lr
参数没有变化......
代码示例从here复制和更新
(原始代码对我不起作用)
GridSearchCV
与自定义估算器的任何完整工作示例,具有更改参数将不胜感激。
我使用ubuntu
0.20.2在scikit-learn
18.10
from sklearn.model_selection import GridSearchCV
from sklearn.base import BaseEstimator, ClassifierMixin
import numpy as np
class MyClassifier(BaseEstimator, ClassifierMixin):
def __init__(self, lr=0.1):
# Some code
print('lr:', lr)
return self
def fit(self, X, y):
# Some code
return self
def predict(self, X):
# Some code
return X % 3
params = {
'lr': [0.1, 0.5, 0.7]
}
gs = GridSearchCV(MyClassifier(), param_grid=params, cv=4)
x = np.arange(30)
y = np.concatenate((np.zeros(10), np.ones(10), np.ones(10) * 2))
gs.fit(x, y)
问候,马克
由于您在构造函数中打印,因此无法看到lr
值的更改。
如果我们在.fit()
函数内打印,我们可以看到lr
值的变化。它发生的原因是way创建了不同的估算副本。请参阅here以了解创建多个副本的过程。
from sklearn.model_selection import GridSearchCV
from sklearn.base import BaseEstimator, ClassifierMixin
import numpy as np
class MyClassifier(BaseEstimator, ClassifierMixin):
def __init__(self, lr=0):
# Some code
print('lr:', lr)
self.lr = lr
def fit(self, X, y):
# Some code
print('lr:', self.lr)
return self
def predict(self, X):
# Some code
return X % 3
params = {
'lr': [0.1, 0.5, 0.7]
}
gs = GridSearchCV(MyClassifier(), param_grid=params, cv=4)
x = np.arange(30)
y = np.concatenate((np.zeros(10), np.ones(10), np.ones(10) * 2))
gs.fit(x, y)
gs.predict(x)
输出:
lr: 0
lr: 0
lr: 0
lr: 0.1
lr: 0
lr: 0.1
lr: 0
lr: 0.1
lr: 0
lr: 0.1
lr: 0
lr: 0.5
lr: 0
lr: 0.5
lr: 0
lr: 0.5
lr: 0
lr: 0.5
lr: 0
lr: 0.7
lr: 0
lr: 0.7
lr: 0
lr: 0.7
lr: 0
lr: 0.7
lr: 0
lr: 0.1