将值传递给 KerasClassifier 时出现 AttributeError

from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
from sklearn.model_selection import cross_val_score
from scikeras.wrappers import KerasClassifier
def create_model(n_features: int = 16) -> Sequential:
    """Build and compile a small fully connected binary classifier.

    The network is 32 -> 16 -> 1 units with ReLU hidden layers and a single
    sigmoid output. It is compiled with binary cross-entropy, the Adam
    optimizer and an accuracy metric.

    Args:
        n_features: Width of the input layer, i.e. the number of feature
            columns per sample. Defaults to 16, the value that was previously
            hard-coded. When wrapped by SciKeras it can be supplied as a routed
            parameter, e.g. ``KerasClassifier(model=create_model,
            model__n_features=...)``.

    Returns:
        A compiled ``Sequential`` model, ready for ``fit``.
    """
    model = Sequential([
        Dense(32, input_dim=n_features, kernel_initializer='normal', activation='relu'),
        Dense(16, kernel_initializer='normal', activation='relu'),
        # A single sigmoid unit pairs with the binary_crossentropy loss below,
        # so this model expects binary (0/1) targets.
        Dense(1, kernel_initializer='normal', activation='sigmoid'),
    ])
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model
# Wrap the Keras model factory in a scikit-learn compatible estimator so that
# cross_val_score can clone and fit a fresh model for every fold.
#
# NOTE(review): the AttributeError in the traceback is raised inside sklearn's
# is_classifier() -> get_tags() -> ClassifierMixin.__sklearn_tags__() ->
# super().__sklearn_tags__(), i.e. while sklearn inspects the estimator type,
# before any fitting happens. That code path belongs to scikit-learn's newer
# (1.6+) estimator-tags API, and the installed scikeras likely predates it.
# Pinning scikit-learn<1.6, or upgrading scikeras to a release that implements
# __sklearn_tags__, should resolve it. Confirm against the installed versions.
estimator = KerasClassifier(model=create_model, epochs=100, verbose=0)

# 10-fold cross-validation. The result holds one score per fold, from the
# estimator's default scorer.
cv_scores = cross_val_score(estimator=estimator, X=all_features, y=all_classes, cv=10)
mean_accuracy = cv_scores.mean()
print("Mean cross-validation accuracy:", mean_accuracy)
在将数据传入 cross_val_score(estimator, all_features, all_classes, cv=10) 时，这里出现了 AttributeError。我不确定输入参数是否需要做任何修改。
错误信息：
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[38], line 18
13 return model
16 estimator = KerasClassifier(model=create_model,epochs=100,verbose=0)
---> 18 cv_scores = cross_val_score(estimator, all_features, all_classes, cv=10)
19 print("Mean cross-validation accuracy:", cv_scores.mean())
File ~\anaconda3\Lib\site-packages\sklearn\utils\_param_validation.py:216, in validate_params.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
210 try:
211 with config_context(
212 skip_parameter_validation=(
213 prefer_skip_nested_validation or global_skip_validation
214 )
215 ):
--> 216 return func(*args, **kwargs)
217 except InvalidParameterError as e:
218 # When the function is just a wrapper around an estimator, we allow
219 # the function to delegate validation to the estimator, but we replace
220 # the name of the estimator by the name of the function in the error
221 # message to avoid confusion.
222 msg = re.sub(
223 r"parameter of \w+ must be",
224 f"parameter of {func.__qualname__} must be",
225 str(e),
226 )
File ~\anaconda3\Lib\site-packages\sklearn\model_selection\_validation.py:684, in cross_val_score(estimator, X, y, groups, scoring, cv, n_jobs, verbose, params, pre_dispatch, error_score)
681 # To ensure multimetric format is not supported
682 scorer = check_scoring(estimator, scoring=scoring)
--> 684 cv_results = cross_validate(
685 estimator=estimator,
686 X=X,
687 y=y,
688 groups=groups,
689 scoring={"score": scorer},
690 cv=cv,
691 n_jobs=n_jobs,
692 verbose=verbose,
693 params=params,
694 pre_dispatch=pre_dispatch,
695 error_score=error_score,
696 )
697 return cv_results["test_score"]
File ~\anaconda3\Lib\site-packages\sklearn\utils\_param_validation.py:216, in validate_params.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
210 try:
211 with config_context(
212 skip_parameter_validation=(
213 prefer_skip_nested_validation or global_skip_validation
214 )
215 ):
--> 216 return func(*args, **kwargs)
217 except InvalidParameterError as e:
218 # When the function is just a wrapper around an estimator, we allow
219 # the function to delegate validation to the estimator, but we replace
220 # the name of the estimator by the name of the function in the error
221 # message to avoid confusion.
222 msg = re.sub(
223 r"parameter of \w+ must be",
224 f"parameter of {func.__qualname__} must be",
225 str(e),
226 )
File ~\anaconda3\Lib\site-packages\sklearn\model_selection\_validation.py:347, in cross_validate(estimator, X, y, groups, scoring, cv, n_jobs, verbose, params, pre_dispatch, return_train_score, return_estimator, return_indices, error_score)
345 X, y = indexable(X, y)
346 params = {} if params is None else params
--> 347 cv = check_cv(cv, y, classifier=is_classifier(estimator))
349 scorers = check_scoring(
350 estimator, scoring=scoring, raise_exc=(error_score == "raise")
351 )
353 if _routing_enabled():
354 # For estimators, a MetadataRouter is created in get_metadata_routing
355 # methods. For these router methods, we create the router to use
356 # `process_routing` on it.
File ~\anaconda3\Lib\site-packages\sklearn\base.py:1237, in is_classifier(estimator)
1230 warnings.warn(
1231 f"passing a class to {print(inspect.stack()[0][3])} is deprecated and "
1232 "will be removed in 1.8. Use an instance of the class instead.",
1233 FutureWarning,
1234 )
1235 return getattr(estimator, "_estimator_type", None) == "classifier"
-> 1237 return get_tags(estimator).estimator_type == "classifier"
File ~\anaconda3\Lib\site-packages\sklearn\utils\_tags.py:430, in get_tags(estimator)
428 for klass in reversed(type(estimator).mro()):
429 if "__sklearn_tags__" in vars(klass):
--> 430 sklearn_tags_provider[klass] = klass.__sklearn_tags__(estimator) # type: ignore[attr-defined]
431 class_order.append(klass)
432 elif "_more_tags" in vars(klass):
File ~\anaconda3\Lib\site-packages\sklearn\base.py:540, in ClassifierMixin.__sklearn_tags__(self)
539 def __sklearn_tags__(self):
--> 540 tags = super().__sklearn_tags__()
541 tags.estimator_type = "classifier"
542 tags.classifier_tags = ClassifierTags()