Keras class weights and KeyError issue

Question

I'm aware of the question "Keras class_weight error dictionary keys/values", which describes the same problem, but its solution doesn't seem to help me.

In this code, I'm simply adding class weights because the data for this binary classification task is very imbalanced:

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.utils.class_weight import compute_class_weight
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.regularizers import l2

# X (features) and y (pandas Series of 0/1 labels) are loaded earlier (not shown)

# Split the data
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

# Scale the data
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_val_scaled = scaler.transform(X_val)

# Calculate class weights ('balanced' = inverse class frequency)
class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(y_train), y=y_train)
# Keras expects a dict {class index: weight}, not the raw array
class_weights_dict = dict(zip(np.unique(y_train), class_weights))

print(f"Class Weights: {class_weights_dict}")

def create_model(input_shape):
    model = Sequential()

    # Increase number of neurons and add more layers
    model.add(Dense(128, activation='relu', input_shape=(input_shape,), kernel_regularizer=l2(0.01)))
    model.add(Dropout(0.4))

    model.add(Dense(64, activation='relu', kernel_regularizer=l2(0.01)))
    model.add(Dropout(0.4))

    model.add(Dense(32, activation='relu', kernel_regularizer=l2(0.01)))
    model.add(Dropout(0.4))

    # Output layer
    model.add(Dense(1, activation='sigmoid'))  # Binary classification

    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

    return model

# Create model and fit it
input_shape = X_train_scaled.shape[1]
model = create_model(input_shape)

# Train the model
history = model.fit(
    X_train_scaled, y_train,
    epochs=20,
    batch_size=32,
    validation_data=(X_val_scaled, y_val),
    class_weight=class_weights_dict  # dict of class index -> weight
)
Output of the print statement:

Class Weights: {0: 0.5029020103669545, 1: 86.64717674574005}
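For reference, scikit-learn documents the 'balanced' mode of compute_class_weight as n_samples / (n_classes * np.bincount(y)). The sketch below reimplements that formula with NumPy only (balanced_class_weights is a hypothetical helper and the labels are made up) and returns the dict with plain Python int keys and float values, the form Keras documents for class_weight. Inverting the formula on the weights printed above, the positive class is about 1 / (2 × 86.65), roughly 0.58% of y_train.

```python
import numpy as np

def balanced_class_weights(y):
    """Sketch of the 'balanced' heuristic: n_samples / (n_classes * count per class)."""
    y = np.asarray(y).astype(int)
    classes, counts = np.unique(y, return_counts=True)
    weights = len(y) / (len(classes) * counts)
    # Plain int keys and float values, as Keras documents for class_weight.
    return {int(c): float(w) for c, w in zip(classes, weights)}

# Made-up, imbalanced labels: 8 negatives, 2 positives.
y_example = np.array([0] * 8 + [1] * 2)
print(balanced_class_weights(y_example))  # {0: 0.625, 1: 2.5}
```

The rarer a class, the larger its weight, so each positive sample contributes far more to the loss than each negative one.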

I get this error:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[69], line 7
      4 model = create_model(input_shape)
      6 # Train the model
----> 7 history = model.fit(
      8     X_train_scaled, y_train, 
      9     epochs=20, 
     10     batch_size=32, 
     11     validation_data=(X_val_scaled, y_val), 
     12     class_weight= class_weights
     13 )
     15 # Save model
     16 with open('model.pkl', 'wb') as file:

File ~/.local/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    119     filtered_tb = _process_traceback_frames(e.__traceback__)
    120     # To get the full stack trace, call:
    121     # `keras.config.disable_traceback_filtering()`
--> 122     raise e.with_traceback(filtered_tb) from None
    123 finally:
    124     del filtered_tb

File ~/.local/lib/python3.10/site-packages/pandas/core/series.py:1111, in Series.__getitem__(self, key)
   1108     return self._values[key]
   1110 elif key_is_scalar:
-> 1111     return self._get_value(key)
   1113 # Convert generator to list before going through hashable part
   1114 # (We will iterate through the generator there to check for slices)
   1115 if is_iterator(key):

File ~/.local/lib/python3.10/site-packages/pandas/core/series.py:1227, in Series._get_value(self, label, takeable)
   1224     return self._values[label]
   1226 # Similar to Index.get_value, but we do not fall back to positional
-> 1227 loc = self.index.get_loc(label)
   1229 if is_integer(loc):
   1230     return self._values[loc]

File ~/.local/lib/python3.10/site-packages/pandas/core/indexes/base.py:3809, in Index.get_loc(self, key)
   3804     if isinstance(casted_key, slice) or (
   3805         isinstance(casted_key, abc.Iterable)
   3806         and any(isinstance(x, slice) for x in casted_key)
   3807     ):
   3808         raise InvalidIndexError(key)
-> 3809     raise KeyError(key) from err
   3810 except TypeError:
   3811     # If we have a listlike key, _check_indexing_error will raise
   3812     #  InvalidIndexError. Otherwise we fall through and re-raise
   3813     #  the TypeError.
   3814     self._check_indexing_error(key)

KeyError: 2

Tags: python, keras, scikit-learn, neural-network, imbalanced-data
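The last frames of the traceback are pandas' Series.__getitem__ and Index.get_loc, so the failure is a label lookup on a pandas Series, most likely y_train: KeyError: 2 means something asked y_train for the row labelled 2. train_test_split returns Series that keep their original index labels in shuffled order, so label 2 can end up in y_val instead. The sketch below reproduces this with pandas only; the hand-picked rows stand in for a real split (an assumption to keep it deterministic) and try_lookup is a hypothetical helper. It assumes Keras reads y_train with integer positions during fit, which the traceback suggests but does not prove.

```python
import pandas as pd

# Labels with the default RangeIndex 0..5, like y before splitting.
y = pd.Series([0, 1, 0, 0, 1, 0], name="label")

# Stand-in for train_test_split on a Series: rows are shuffled and keep
# their original index labels. Hand-picked rows here; label 2 is left out.
y_train = y.iloc[[5, 0, 3, 4]]

def try_lookup(labels, i):
    """Return labels[i] as an int, or the error text if the lookup fails."""
    try:
        return int(labels[i])
    except KeyError as exc:
        return f"KeyError: {exc}"

print(y_train.index.tolist())  # [5, 0, 3, 4]
print(try_lookup(y_train, 2))  # KeyError: 2 (label lookup, not position)
print(int(y_train.iloc[2]))    # 0 (third row by position)
```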
Answer

I found the solution here [1]: converting y_train (for example, from a pandas Series to a NumPy array) solves the problem.
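A minimal sketch of that fix, assuming "converting y_train" means turning the label Series into a NumPy array (the linked answer itself is not preserved here). The labels and weights below are made-up stand-ins for the question's variables, and the model.fit call is left as a comment because it needs the question's model and scaled features. Converting y_val as well is optional but keeps both targets the same type. The sketch also passes a dict rather than the raw array returned by compute_class_weight, since Keras documents class_weight as a dictionary mapping class indices to weights.

```python
import numpy as np
import pandas as pd

# Stand-ins for the question's split labels: Series whose index was
# shuffled by train_test_split (values and labels here are made up).
y_train = pd.Series([0, 0, 1, 0], index=[5, 0, 4, 3])
y_val = pd.Series([1, 0], index=[1, 2])

# The fix: give Keras plain NumPy arrays, so y[i] always means "the i-th row".
y_train_np = y_train.to_numpy()
y_val_np = y_val.to_numpy()

# class_weight must be a dict {class index: weight}. Illustrative values only.
class_weights_dict = {0: 0.5, 1: 2.0}

# Sanity check: every label in the training set has a weight.
missing = {int(c) for c in np.unique(y_train_np)} - set(class_weights_dict)
assert not missing, f"no class weight for labels {missing}"

# Then, with the question's model and scaled features:
# history = model.fit(
#     X_train_scaled, y_train_np,
#     epochs=20, batch_size=32,
#     validation_data=(X_val_scaled, y_val_np),
#     class_weight=class_weights_dict,
# )
```

Calling y_train.reset_index(drop=True) would also remove the mismatch between labels and positions, but a NumPy array avoids any pandas indexing rules altogether.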
