predict() with sklearn LogisticRegression and RandomForest models always predicts the minority class (1)


I am building a logistic regression model to predict whether a transaction is valid (1) or invalid (0), using a dataset of only 150 observations. The data is distributed across the two classes as follows:

  • 106 observations are 0 (invalid)
  • 44 observations are 1 (valid)

I use two predictor variables (both numeric). Even though the data is mostly 0s, my classifier predicts 1 for every single transaction in the test set, even though most of them should be 0; it never outputs 0 for any observation.

Here is my entire code:

# Logistic Regression
import numpy as np
import pandas as pd
from pandas import Series, DataFrame

import scipy
from scipy.stats import spearmanr
from pylab import rcParams
import seaborn as sb
import matplotlib.pyplot as plt
import sklearn
from sklearn.preprocessing import scale
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn import metrics
from sklearn import preprocessing

address = "dummy_csv-150.csv"
trades = pd.read_csv(address)
trades.columns=['location','app','el','rp','rule1','rule2','rule3','validity','transactions']
trades.head()

trade_data = trades.iloc[:, [1, 8]].values   # select the 'app' and 'transactions' columns
trade_data_names = ['app','transactions']

# set dependent/response variable (the 'validity' column)
y = trades.iloc[:, 7].values

# standardize to zero mean and unit variance
X = scale(trade_data)

LogReg = LogisticRegression()

LogReg.fit(X,y)
print(LogReg.score(X,y))

y_pred = LogReg.predict(X)

from sklearn.metrics import classification_report

print(classification_report(y,y_pred)) 

log_prediction = LogReg.predict_log_proba(
    [
       [2, 14],[3,1], [1, 503],[1, 122],[1, 101],[1, 610],[1, 2120],[3, 85],[3, 91],[2, 167],[2, 553],[2, 144]
    ])
prediction = LogReg.predict([[2, 14],[3,1], [1, 503],[1, 122],[1, 101],[1, 610],[1, 2120],[3, 85],[3, 91],[2, 167],[2, 553],[2, 144]])

My model is defined as:

LogReg = LogisticRegression()  
LogReg.fit(X,y)

X looks like this:

X = array([[1, 345],
       [1, 222],
       [1, 500],
       [2, 120],
       ...])

and y is just 0 or 1 for every observation.

The standardized X passed to the model looks like this:

[[-1.67177659  0.14396503]
 [-1.67177659 -0.14538932]
 [-1.67177659  0.50859856]
 [-1.67177659 -0.3853417 ]
 [-1.67177659 -0.43239119]
 [-1.67177659  0.743846  ]
 [-1.67177659  4.32195953]
 [ 0.95657805 -0.46062089]
 [ 0.95657805 -0.45591594]
 [ 0.95657805 -0.37828428]
 [ 0.95657805 -0.52884264]
 [ 0.95657805 -0.20420118]
 [ 0.95657805 -0.63705646]
 [ 0.95657805 -0.65587626]
 [ 0.95657805 -0.66763863]
 [-0.35759927 -0.25125067]
 [-0.35759927  0.60975496]
 [-0.35759927 -0.33358727]
 [-0.35759927 -0.20420118]
 [-0.35759927  1.37195666]
 [-0.35759927  0.27805607]
 [-0.35759927  0.09456307]
 [-0.35759927  0.03810368]
 [-0.35759927 -0.41121892]
 [-0.35759927 -0.64411389]
 [-0.35759927 -0.69586832]
 [ 0.95657805 -0.57353966]
 [ 0.95657805 -0.57353966]
 [ 0.95657805 -0.53825254]
 [ 0.95657805 -0.53354759]
 [ 0.95657805 -0.52413769]
 [ 0.95657805 -0.57589213]
 [ 0.95657805  0.03810368]
 [ 0.95657805 -0.66293368]
 [ 0.95657805  2.86107294]
 [-1.67177659  0.14396503]
 [-1.67177659 -0.14538932]
 [-1.67177659  0.50859856]
 [-1.67177659 -0.3853417 ]
 [-1.67177659 -0.43239119]
 [-1.67177659  0.743846  ]
 [-1.67177659  4.32195953]
 [ 0.95657805 -0.46062089]
 [ 0.95657805 -0.45591594]
 [ 0.95657805 -0.37828428]
 [ 0.95657805 -0.52884264]
 [ 0.95657805 -0.20420118]
 [ 0.95657805 -0.63705646]
 [ 0.95657805 -0.65587626]
 [ 0.95657805 -0.66763863]
 [-0.35759927 -0.25125067]
 [-0.35759927  0.60975496]
 [-0.35759927 -0.33358727]
 [-0.35759927 -0.20420118]
 [-0.35759927  1.37195666]
 [-0.35759927  0.27805607]
 [-0.35759927  0.09456307]
 [-0.35759927  0.03810368]
 [-0.35759927 -0.41121892]
 [-0.35759927 -0.64411389]
 [-0.35759927 -0.69586832]
 [ 0.95657805 -0.57353966]
 [ 0.95657805 -0.57353966]
 [ 0.95657805 -0.53825254]
 [ 0.95657805 -0.53354759]
 [ 0.95657805 -0.52413769]
 [ 0.95657805 -0.57589213]
 [ 0.95657805  0.03810368]
 [ 0.95657805 -0.66293368]
 [ 0.95657805  2.86107294]
 [-1.67177659  0.14396503]
 [-1.67177659 -0.14538932]
 [-1.67177659  0.50859856]
 [-1.67177659 -0.3853417 ]
 [-1.67177659 -0.43239119]
 [-1.67177659  0.743846  ]
 [-1.67177659  4.32195953]
 [ 0.95657805 -0.46062089]
 [ 0.95657805 -0.45591594]
 [ 0.95657805 -0.37828428]
 [ 0.95657805 -0.52884264]
 [ 0.95657805 -0.20420118]
 [ 0.95657805 -0.63705646]
 [ 0.95657805 -0.65587626]
 [ 0.95657805 -0.66763863]
 [-0.35759927 -0.25125067]
 [-0.35759927  0.60975496]
 [-0.35759927 -0.33358727]
 [-0.35759927 -0.20420118]
 [-0.35759927  1.37195666]
 [-0.35759927  0.27805607]
 [-0.35759927  0.09456307]
 [-0.35759927  0.03810368]
 [-0.35759927 -0.41121892]
 [-0.35759927 -0.64411389]
 [-0.35759927 -0.69586832]
 [ 0.95657805 -0.57353966]
 [ 0.95657805 -0.57353966]
 [ 0.95657805 -0.53825254]
 [ 0.95657805 -0.53354759]
 [ 0.95657805 -0.52413769]
 [ 0.95657805 -0.57589213]
 [ 0.95657805  0.03810368]
 [ 0.95657805 -0.66293368]
 [ 0.95657805  2.86107294]
 [-1.67177659  0.14396503]
 [-1.67177659 -0.14538932]
 [-1.67177659  0.50859856]
 [-1.67177659 -0.3853417 ]
 [-1.67177659 -0.43239119]
 [-1.67177659  0.743846  ]
 [-1.67177659  4.32195953]
 [ 0.95657805 -0.46062089]
 [ 0.95657805 -0.45591594]
 [ 0.95657805 -0.37828428]
 [ 0.95657805 -0.52884264]
 [ 0.95657805 -0.20420118]
 [ 0.95657805 -0.63705646]
 [ 0.95657805 -0.65587626]
 [ 0.95657805 -0.66763863]
 [-0.35759927 -0.25125067]
 [-0.35759927  0.60975496]
 [-0.35759927 -0.33358727]
 [-0.35759927 -0.20420118]
 [-0.35759927  1.37195666]
 [-0.35759927  0.27805607]
 [-0.35759927  0.09456307]
 [-0.35759927  0.03810368]
 [-0.35759927 -0.41121892]
 [-0.35759927 -0.64411389]
 [-0.35759927 -0.69586832]
 [ 0.95657805 -0.57353966]
 [ 0.95657805 -0.57353966]
 [ 0.95657805 -0.53825254]
 [ 0.95657805 -0.53354759]
 [ 0.95657805 -0.52413769]
 [ 0.95657805 -0.57589213]
 [ 0.95657805  0.03810368]
 [ 0.95657805 -0.66293368]
 [ 0.95657805  2.86107294]
 [-0.35759927  0.60975496]
 [-0.35759927 -0.33358727]
 [-0.35759927 -0.20420118]
 [-0.35759927  1.37195666]
 [-0.35759927  0.27805607]
 [-0.35759927  0.09456307]
 [-0.35759927  0.03810368]]

and y is:

[0 0 0 0 0 0 1 1 0 0 0 1 1 1 1 0 0 0 0 1 0 0 0 0 1 1 0 0 0 0 0 0 1 1 1 0 0
 0 0 0 0 1 1 0 0 0 1 1 1 1 0 0 0 0 1 0 0 0 0 1 1 0 0 0 0 0 0 1 1 1 0 0 0 0
 0 0 1 1 0 0 0 1 1 1 1 0 0 0 0 1 0 0 0 0 1 1 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0
 1 1 0 0 0 1 1 1 1 0 0 0 0 1 0 0 0 0 1 1 0 0 0 0 0 0 1 1 1 0 0 0 1 0 0 0]

The model metrics are:

             precision    recall  f1-score   support

          0       0.78      1.00      0.88        98
          1       1.00      0.43      0.60        49

avg / total       0.85      0.81      0.78       147

and the score is 0.80.

When I run model.predict_log_proba(test_data), the log-probabilities I get look like this:

array([[ -1.10164032e+01,  -1.64301095e-05],
       [ -2.06326947e+00,  -1.35863187e-01],
       [            -inf,   0.00000000e+00],
       [            -inf,   0.00000000e+00],
       [            -inf,   0.00000000e+00],
       [            -inf,   0.00000000e+00],
       [            -inf,   0.00000000e+00],
       [            -inf,   0.00000000e+00],
       [            -inf,   0.00000000e+00],
       [            -inf,   0.00000000e+00],
       [            -inf,   0.00000000e+00],
       [            -inf,   0.00000000e+00]])

Below is my test set. All but 2 of the rows should be 0, yet every one of them is classified as 1. This happens with every test set, even ones containing values the model was trained on:

[2, 14],[3,1], [1, 503],[1, 122],[1, 101],[1, 610],[1, 2120],[3, 85],[3, 91],[2, 167],[2, 553],[2, 144]

I found a similar question here: https://stats.stackexchange.com/questions/168929/logistic-regression-is-predicting-all-1-and-no-0, but in that question the problem seemed to be that the data was mostly 1s, so it made sense for the model to output 1. My situation is the opposite: the training data is mostly 0s, yet for some reason my model always outputs 1, even though 1s are relatively rare. I also tried a RandomForestClassifier to rule out the model itself, but the same thing happened. Maybe it is my data, but I don't know what is wrong with it, since it satisfies all the assumptions.

What could be going wrong? The data satisfies all the assumptions of a logistic model (the two predictors are independent, the output is binary, and there are no missing data points).

python machine-learning scikit-learn logistic-regression
1 Answer

You did not scale your test data. While you were right to scale your training data:

X= scale(trade_data)

you never applied the same transformation to the test data before predicting:

log_prediction = LogReg.predict_log_proba(
[
   [2, 14],[3,1], [1, 503],[1, 122],[1, 101],[1, 610],[1, 2120],[3, 85],[3, 91],[2, 167],[2, 553],[2, 144]
])

The model's coefficients were learned from standardized inputs, but your test data is not standardized. Any positive coefficient gets multiplied by a huge raw value, which can easily drive every prediction to 1.
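To see the mechanism, here is a minimal sketch with made-up numbers (coef, intercept, and the sample rows are hypothetical, not your fitted values): the same coefficients give a sensible probability on a standardized row but saturate the sigmoid on a raw row, which is also why predict_log_proba reports -inf for class 0.

import numpy as np

coef = np.array([0.8, 1.2])            # hypothetical coefficients learned on scaled X
intercept = -0.5                       # hypothetical intercept

scaled_row = np.array([0.96, -0.46])   # a standardized row, like the training data
raw_row = np.array([1, 503])           # a raw, unscaled test row

for row in (scaled_row, raw_row):
    z = coef @ row + intercept         # linear decision value
    p = 1 / (1 + np.exp(-z))           # sigmoid -> P(class 1)
    print(z, p)

# scaled_row: z ≈ -0.28, so P(1) ≈ 0.43 -- a reasonable probability.
# raw_row:    z ≈ 603.9, so P(1) rounds to exactly 1.0, and
# log P(0) = log(0) = -inf, matching the predict_log_proba output above.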

The general rule is that any transformation you make on the training set should also be made on the test set, and both sets should go through the same fitted transformation. Instead of:

X = scale(trade_data)

you should create a scaler from the training data, like this:

from sklearn.preprocessing import StandardScaler

scaler = StandardScaler().fit(trade_data)
X = scaler.transform(trade_data)

Then apply that same scaler to your test data:

scaled_test = scaler.transform(test_x)
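
A less error-prone way to get the same behavior is to chain the scaler and the model in a Pipeline, so test data is always transformed by the scaler fitted on the training data. A sketch reusing trade_data and y from the question, with an assumed held-out split:

from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# stratify keeps the 0/1 class ratio similar in both splits
X_train, X_test, y_train, y_test = train_test_split(
    trade_data, y, test_size=0.2, stratify=y, random_state=42)

model = make_pipeline(StandardScaler(), LogisticRegression())
model.fit(X_train, y_train)           # the scaler is fit on the training fold only
print(model.score(X_test, y_test))    # test rows are scaled automatically

With this setup, calling model.predict on raw rows like [2, 14] applies the training-set scaling first, so the predictions are no longer all 1.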