我正在开发一个基于人工神经网络的录取预测系统。训练模型时损失(loss)非常高。以下是与数据集和模型相关的必要信息。
数据集:
# Report the number of distinct values found in every column of `data`.
for col in data.columns:
    print(col, ":", len(data[col].unique()), 'labels')
标签编码(LabelEncoder)和标准化(StandardScaler):
# Categorical columns that are replaced in place by integer codes.
columns_to_encode = ['college_name', 'seat_type', 'branch', 'City', 'gender', 'family_income_less_than_8_lakh', 'category']

# Fit one LabelEncoder per column and keep each of them. Re-fitting a single
# shared encoder overwrites its classes on every iteration, so only the last
# column could ever be decoded; with one encoder per column, integer codes
# (e.g. college_name / branch) can be mapped back via
# label_encoders[column].inverse_transform(...).
label_encoders = {}
for column in columns_to_encode:
    label_encoder = LabelEncoder()
    data[column] = label_encoder.fit_transform(data[column])
    label_encoders[column] = label_encoder

# Standardise the numeric 'percentile' column into 'percentile_scaled'.
# NOTE(review): the scaler is fitted on the whole dataset, i.e. before the
# train/test split further down, so test-set statistics influence the
# training features; consider fitting it on the training rows only.
scaler = StandardScaler()
percentile_column = data['percentile'].values.reshape(-1, 1)  # 2-D input for the scaler
data['percentile_scaled'] = scaler.fit_transform(percentile_column)
训练集/测试集划分(train_test_split):
# Inputs: the scaled percentile plus the integer-coded city.
feature_columns = ['percentile_scaled', 'City']
# Targets: the integer-coded college and branch, kept as two separate columns.
target_columns = ['college_name', 'branch']

X = data[feature_columns].to_numpy()
y = data[target_columns].to_numpy()

# Hold out 20% of the rows for validation; the seed is fixed at 42.
X_train, X_test, Y_train, Y_test = train_test_split(
    X,
    y,
    random_state=42,
    test_size=0.2,
)
模型:
# Feed-forward network exactly as posted in the question: two inputs, two
# hidden Dense layers with dropout, and a 2-unit output layer.
# NOTE(review): every Dense layer (hidden ones included) uses softmax, and
# the 2-unit softmax output is trained against two integer-coded target
# columns; this is the configuration for which the question reports a loss
# of roughly 128 that does not decrease.
hidden_units = 64
dropout_rate = 0.5

model = tf.keras.Sequential(
    [
        tf.keras.layers.Input(shape=(2,), name='percentile_and_city'),
        tf.keras.layers.Dense(hidden_units, activation='softmax'),
        tf.keras.layers.Dropout(dropout_rate),
        tf.keras.layers.Dense(hidden_units, activation='softmax'),
        tf.keras.layers.Dropout(dropout_rate),
        tf.keras.layers.Dense(2, activation='softmax', name='college_and_branch'),
    ]
)
from keras.optimizers import Adam
# Compile with Adam at a fixed step size, categorical cross-entropy as the
# loss, and accuracy as the reported metric.
learning_rate = 1e-3
optimizer = Adam(learning_rate=learning_rate)
model.compile(
    loss='categorical_crossentropy',
    optimizer=optimizer,
    metrics=['accuracy'],
)
训练输出(损失很高):
# Train for 50 epochs in mini-batches of 32, using the held-out split as
# validation data.
model.fit(
    X_train,
    Y_train,
    validation_data=(X_test, Y_test),
    batch_size=32,
    epochs=50,
)
Epoch 1/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.5021 - loss: 128.4439 - val_accuracy: 0.9035 - val_loss: 129.5756
Epoch 2/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.4981 - loss: 127.9602 - val_accuracy: 0.9035 - val_loss: 129.5739
Epoch 3/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.5042 - loss: 128.0659 - val_accuracy: 0.9035 - val_loss: 129.0870
Epoch 4/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.5022 - loss: 128.5717 - val_accuracy: 0.9035 - val_loss: 129.0783
Epoch 5/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - accuracy: 0.5011 - loss: 127.1210 - val_accuracy: 0.9035 - val_loss: 129.2309
Epoch 6/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.4959 - loss: 126.6648 - val_accuracy: 0.9035 - val_loss: 129.1433
Epoch 7/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - accuracy: 0.4871 - loss: 127.7158 - val_accuracy: 0.9035 - val_loss: 129.4971
Epoch 8/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.5018 - loss: 127.8886 - val_accuracy: 0.9035 - val_loss: 129.3699
Epoch 9/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - accuracy: 0.4968 - loss: 127.7575 - val_accuracy: 0.9035 - val_loss: 129.2882
Epoch 10/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.5049 - loss: 127.2928 - val_accuracy: 0.9035 - val_loss: 129.1008
Epoch 11/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - accuracy: 0.4917 - loss: 128.2043 - val_accuracy: 0.9035 - val_loss: 129.0593
Epoch 12/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.5068 - loss: 127.3596 - val_accuracy: 0.9035 - val_loss: 128.9710
Epoch 13/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.5069 - loss: 127.2081 - val_accuracy: 0.9035 - val_loss: 129.0541
Epoch 14/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.5026 - loss: 127.4382 - val_accuracy: 0.9035 - val_loss: 128.9952
Epoch 15/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.4946 - loss: 128.3105 - val_accuracy: 0.9035 - val_loss: 129.4832
Epoch 16/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - accuracy: 0.4993 - loss: 128.5704 - val_accuracy: 0.9035 - val_loss: 129.8361
Epoch 17/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.4984 - loss: 127.7996 - val_accuracy: 0.9035 - val_loss: 129.5564
Epoch 18/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.4979 - loss: 127.7063 - val_accuracy: 0.9035 - val_loss: 129.2761
Epoch 19/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.4953 - loss: 127.7529 - val_accuracy: 0.9035 - val_loss: 129.3995
Epoch 20/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step - accuracy: 0.5041 - loss: 128.7030 - val_accuracy: 0.9035 - val_loss: 129.4727
Epoch 21/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step - accuracy: 0.4971 - loss: 128.3986 - val_accuracy: 0.9035 - val_loss: 129.5111
Epoch 22/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.4959 - loss: 128.6659 - val_accuracy: 0.9035 - val_loss: 129.5424
Epoch 23/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.4985 - loss: 128.2918 - val_accuracy: 0.9035 - val_loss: 129.4988
Epoch 24/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - accuracy: 0.4999 - loss: 128.2642 - val_accuracy: 0.9035 - val_loss: 129.4657
Epoch 25/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.5004 - loss: 127.7300 - val_accuracy: 0.9035 - val_loss: 129.5337
Epoch 26/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.5014 - loss: 127.9406 - val_accuracy: 0.9035 - val_loss: 129.4309
Epoch 27/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.4973 - loss: 128.0004 - val_accuracy: 0.9035 - val_loss: 129.7136
Epoch 28/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.5021 - loss: 127.9145 - val_accuracy: 0.9035 - val_loss: 129.4570
Epoch 29/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.4950 - loss: 127.6928 - val_accuracy: 0.9035 - val_loss: 129.5459
Epoch 30/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.5025 - loss: 128.1482 - val_accuracy: 0.9035 - val_loss: 129.4936
Epoch 31/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.4998 - loss: 127.3472 - val_accuracy: 0.9035 - val_loss: 129.4570
Epoch 32/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.5053 - loss: 128.3870 - val_accuracy: 0.9035 - val_loss: 129.3508
Epoch 33/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.4983 - loss: 127.1599 - val_accuracy: 0.9035 - val_loss: 129.3961
Epoch 34/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.5103 - loss: 127.2617 - val_accuracy: 0.9035 - val_loss: 129.3125
Epoch 35/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.4935 - loss: 127.6443 - val_accuracy: 0.9035 - val_loss: 129.2361
Epoch 36/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.5016 - loss: 127.0884 - val_accuracy: 0.9035 - val_loss: 129.4727
Epoch 37/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.4986 - loss: 128.3033 - val_accuracy: 0.9035 - val_loss: 129.1841
Epoch 38/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.5092 - loss: 127.1143 - val_accuracy: 0.9035 - val_loss: 129.0852
Epoch 39/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.5021 - loss: 127.5894 - val_accuracy: 0.9035 - val_loss: 129.5180
Epoch 40/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.4977 - loss: 128.3102 - val_accuracy: 0.9035 - val_loss: 129.8432
Epoch 41/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.5008 - loss: 128.1937 - val_accuracy: 0.9035 - val_loss: 129.7153
Epoch 42/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.5034 - loss: 128.4547 - val_accuracy: 0.9035 - val_loss: 129.8221
Epoch 43/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.4951 - loss: 128.1209 - val_accuracy: 0.9035 - val_loss: 130.0643
Epoch 44/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.5059 - loss: 127.9155 - val_accuracy: 0.9035 - val_loss: 129.8502
Epoch 45/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.4916 - loss: 128.6513 - val_accuracy: 0.9035 - val_loss: 129.9238
Epoch 46/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.4958 - loss: 127.4559 - val_accuracy: 0.9035 - val_loss: 130.0151
Epoch 47/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.4900 - loss: 128.3895 - val_accuracy: 0.9035 - val_loss: 129.9308
Epoch 48/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.5023 - loss: 128.1325 - val_accuracy: 0.9035 - val_loss: 129.9027
Epoch 49/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.4982 - loss: 128.1566 - val_accuracy: 0.9035 - val_loss: 130.2086
Epoch 50/50
671/671 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.5010 - loss: 129.0888 - val_accuracy: 0.9035 - val_loss: 130.0749
我需要较高的准确率,但一直达不到,而且损失也非常高。我该如何改进我的模型?
我检查了您的代码并做了一些修改,准确率达到了 94%。我把前面几层的 softmax 激活函数换成了 ReLU,因为 softmax 通常只用在多分类任务的最后一层。另外,我把损失函数从分类交叉熵(categorical cross-entropy)换成了二元交叉熵(binary cross-entropy),它更适合二分类任务,并且给网络增加了额外的层。
请参考这个 gist(代码片段)。