机器学习一百天-day3多元线性回归及虚拟变量陷阱分析 一,数据预处理 二,在训练集上训练多元线性回归模型 三,在测试集上预测结果 四,可视化
机器学习一百天-day3多元线性回归及虚拟变量陷阱分析
导入数据集
import pandas as pd import numpy as np dataset = pd.read_csv('D:\100Daysdatasets\50_Startups.csv') X = dataset.iloc[ : , :-1].values Y = dataset.iloc[ : , 4 ].values
数据集(head(5)):
R&D Spend Administration Marketing Spend State Profit
0 165349.20 136897.80 471784.10 New York 192261.83
1 162597.70 151377.59 443898.53 California 191792.06
2 153441.51 101145.55 407934.54 Florida 191050.39
3 144372.41 118671.85 383199.62 New York 182901.99
4 142107.34 91391.77 366168.42 Florida 166187.94
将类别数据数字化
from sklearn.preprocessing import LabelEncoder,OneHotEncoder labelencoder = LabelEncoder() onehotencoder = OneHotEncoder(categorical_features=[3]) X[ : , 3] = labelencoder.fit_transform(X[: , 3 ]) X = onehotencoder.fit_transform(X).toarray() #print(X)
先用labelencoder将字符串编码为0,1,2,再用onrhotencoder编码
此时的X :
[[0.0000000e+00 0.0000000e+00 1.0000000e+00 1.6534920e+05 1.3689780e+05 4.7178410e+05] [1.0000000e+00 0.0000000e+00 0.0000000e+00 1.6259770e+05 1.5137759e+05 4.4389853e+05] [0.0000000e+00 1.0000000e+00 0.0000000e+00 1.5344151e+05 1.0114555e+05 4.0793454e+05] [0.0000000e+00 0.0000000e+00 1.0000000e+00 1.4437241e+05 1.1867185e+05 3.8319962e+05] [0.0000000e+00 1.0000000e+00 0.0000000e+00 1.4210734e+05 9.1391770e+04 3.6616842e+05] [0.0000000e+00 0.0000000e+00 1.0000000e+00 1.3187690e+05 9.9814710e+04 3.6286136e+05] [1.0000000e+00 0.0000000e+00 0.0000000e+00 1.3461546e+05 1.4719887e+05 1.2771682e+05] [0.0000000e+00 1.0000000e+00 0.0000000e+00 1.3029813e+05 1.4553006e+05 3.2387668e+05] [0.0000000e+00 0.0000000e+00 1.0000000e+00 1.2054252e+05 1.4871895e+05 3.1161329e+05] [1.0000000e+00 0.0000000e+00 0.0000000e+00 1.2333488e+05 1.0867917e+05 3.0498162e+05] [0.0000000e+00 1.0000000e+00 0.0000000e+00 1.0191308e+05 1.1059411e+05 2.2916095e+05] [1.0000000e+00 0.0000000e+00 0.0000000e+00 1.0067196e+05 9.1790610e+04 2.4974455e+05] [0.0000000e+00 1.0000000e+00 0.0000000e+00 9.3863750e+04 1.2732038e+05 2.4983944e+05] [1.0000000e+00 0.0000000e+00 0.0000000e+00 9.1992390e+04 1.3549507e+05 2.5266493e+05] [0.0000000e+00 1.0000000e+00 0.0000000e+00 1.1994324e+05 1.5654742e+05 2.5651292e+05] [0.0000000e+00 0.0000000e+00 1.0000000e+00 1.1452361e+05 1.2261684e+05 2.6177623e+05] [1.0000000e+00 0.0000000e+00 0.0000000e+00 7.8013110e+04 1.2159755e+05 2.6434606e+05] [0.0000000e+00 0.0000000e+00 1.0000000e+00 9.4657160e+04 1.4507758e+05 2.8257431e+05] [0.0000000e+00 1.0000000e+00 0.0000000e+00 9.1749160e+04 1.1417579e+05 2.9491957e+05] [0.0000000e+00 0.0000000e+00 1.0000000e+00 8.6419700e+04 1.5351411e+05 0.0000000e+00] [1.0000000e+00 0.0000000e+00 0.0000000e+00 7.6253860e+04 1.1386730e+05 2.9866447e+05] [0.0000000e+00 0.0000000e+00 1.0000000e+00 7.8389470e+04 1.5377343e+05 2.9973729e+05] [0.0000000e+00 1.0000000e+00 0.0000000e+00 7.3994560e+04 1.2278275e+05 3.0331926e+05] [0.0000000e+00 1.0000000e+00 0.0000000e+00 6.7532530e+04 1.0575103e+05 3.0476873e+05] [0.0000000e+00 0.0000000e+00 1.0000000e+00 7.7044010e+04 9.9281340e+04 1.4057481e+05] [1.0000000e+00 0.0000000e+00 0.0000000e+00 6.4664710e+04 1.3955316e+05 1.3796262e+05] [0.0000000e+00 1.0000000e+00 0.0000000e+00 7.5328870e+04 1.4413598e+05 1.3405007e+05] [0.0000000e+00 0.0000000e+00 1.0000000e+00 7.2107600e+04 1.2786455e+05 3.5318381e+05] [0.0000000e+00 1.0000000e+00 0.0000000e+00 6.6051520e+04 1.8264556e+05 1.1814820e+05] [0.0000000e+00 0.0000000e+00 1.0000000e+00 6.5605480e+04 1.5303206e+05 1.0713838e+05] [0.0000000e+00 1.0000000e+00 0.0000000e+00 6.1994480e+04 1.1564128e+05 9.1131240e+04] [0.0000000e+00 0.0000000e+00 1.0000000e+00 6.1136380e+04 1.5270192e+05 8.8218230e+04] [1.0000000e+00 0.0000000e+00 0.0000000e+00 6.3408860e+04 1.2921961e+05 4.6085250e+04] [0.0000000e+00 1.0000000e+00 0.0000000e+00 5.5493950e+04 1.0305749e+05 2.1463481e+05] [1.0000000e+00 0.0000000e+00 0.0000000e+00 4.6426070e+04 1.5769392e+05 2.1079767e+05] [0.0000000e+00 0.0000000e+00 1.0000000e+00 4.6014020e+04 8.5047440e+04 2.0551764e+05] [0.0000000e+00 1.0000000e+00 0.0000000e+00 2.8663760e+04 1.2705621e+05 2.0112682e+05] [1.0000000e+00 0.0000000e+00 0.0000000e+00 4.4069950e+04 5.1283140e+04 1.9702942e+05] [0.0000000e+00 0.0000000e+00 1.0000000e+00 2.0229590e+04 6.5947930e+04 1.8526510e+05] [1.0000000e+00 0.0000000e+00 0.0000000e+00 3.8558510e+04 8.2982090e+04 1.7499930e+05] [1.0000000e+00 0.0000000e+00 0.0000000e+00 2.8754330e+04 1.1854605e+05 1.7279567e+05] [0.0000000e+00 1.0000000e+00 0.0000000e+00 2.7892920e+04 8.4710770e+04 1.6447071e+05] [1.0000000e+00 0.0000000e+00 0.0000000e+00 2.3640930e+04 9.6189630e+04 1.4800111e+05] [0.0000000e+00 0.0000000e+00 1.0000000e+00 1.5505730e+04 1.2738230e+05 3.5534170e+04] [1.0000000e+00 0.0000000e+00 0.0000000e+00 2.2177740e+04 1.5480614e+05 2.8334720e+04] [0.0000000e+00 0.0000000e+00 1.0000000e+00 1.0002300e+03 1.2415304e+05 1.9039300e+03] [0.0000000e+00 1.0000000e+00 0.0000000e+00 1.3154600e+03 1.1581621e+05 2.9711446e+05] [1.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 1.3542692e+05 0.0000000e+00] [0.0000000e+00 0.0000000e+00 1.0000000e+00 5.4205000e+02 5.1743150e+04 0.0000000e+00] [1.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 1.1698380e+05 4.5173060e+04]]
存在所谓的虚拟变量陷阱。意思就是:其实state只有3种取值,理论上2位二进制就可以表示,而这里用100,010,001三种表示。其实若把第一位统一去掉,变为00,10,01也是可以区分的。所以这里需要做一个处理:
躲避虚拟变量陷阱,把第一列去掉了
X = X[ : ,1:] print(X)
[[0.0000000e+00 1.0000000e+00 1.6534920e+05 1.3689780e+05 4.7178410e+05] [0.0000000e+00 0.0000000e+00 1.6259770e+05 1.5137759e+05 4.4389853e+05] [1.0000000e+00 0.0000000e+00 1.5344151e+05 1.0114555e+05 4.0793454e+05] [0.0000000e+00 1.0000000e+00 1.4437241e+05 1.1867185e+05 3.8319962e+05] [1.0000000e+00 0.0000000e+00 1.4210734e+05 9.1391770e+04 3.6616842e+05] [0.0000000e+00 1.0000000e+00 1.3187690e+05 9.9814710e+04 3.6286136e+05] [0.0000000e+00 0.0000000e+00 1.3461546e+05 1.4719887e+05 1.2771682e+05] [1.0000000e+00 0.0000000e+00 1.3029813e+05 1.4553006e+05 3.2387668e+05] [0.0000000e+00 1.0000000e+00 1.2054252e+05 1.4871895e+05 3.1161329e+05] [0.0000000e+00 0.0000000e+00 1.2333488e+05 1.0867917e+05 3.0498162e+05] [1.0000000e+00 0.0000000e+00 1.0191308e+05 1.1059411e+05 2.2916095e+05] [0.0000000e+00 0.0000000e+00 1.0067196e+05 9.1790610e+04 2.4974455e+05] [1.0000000e+00 0.0000000e+00 9.3863750e+04 1.2732038e+05 2.4983944e+05] [0.0000000e+00 0.0000000e+00 9.1992390e+04 1.3549507e+05 2.5266493e+05] [1.0000000e+00 0.0000000e+00 1.1994324e+05 1.5654742e+05 2.5651292e+05] [0.0000000e+00 1.0000000e+00 1.1452361e+05 1.2261684e+05 2.6177623e+05] [0.0000000e+00 0.0000000e+00 7.8013110e+04 1.2159755e+05 2.6434606e+05] [0.0000000e+00 1.0000000e+00 9.4657160e+04 1.4507758e+05 2.8257431e+05] [1.0000000e+00 0.0000000e+00 9.1749160e+04 1.1417579e+05 2.9491957e+05] [0.0000000e+00 1.0000000e+00 8.6419700e+04 1.5351411e+05 0.0000000e+00] [0.0000000e+00 0.0000000e+00 7.6253860e+04 1.1386730e+05 2.9866447e+05] [0.0000000e+00 1.0000000e+00 7.8389470e+04 1.5377343e+05 2.9973729e+05] [1.0000000e+00 0.0000000e+00 7.3994560e+04 1.2278275e+05 3.0331926e+05] [1.0000000e+00 0.0000000e+00 6.7532530e+04 1.0575103e+05 3.0476873e+05] [0.0000000e+00 1.0000000e+00 7.7044010e+04 9.9281340e+04 1.4057481e+05] [0.0000000e+00 0.0000000e+00 6.4664710e+04 1.3955316e+05 1.3796262e+05] [1.0000000e+00 0.0000000e+00 7.5328870e+04 1.4413598e+05 1.3405007e+05] [0.0000000e+00 1.0000000e+00 7.2107600e+04 1.2786455e+05 3.5318381e+05] [1.0000000e+00 0.0000000e+00 6.6051520e+04 1.8264556e+05 1.1814820e+05] [0.0000000e+00 1.0000000e+00 6.5605480e+04 1.5303206e+05 1.0713838e+05] [1.0000000e+00 0.0000000e+00 6.1994480e+04 1.1564128e+05 9.1131240e+04] [0.0000000e+00 1.0000000e+00 6.1136380e+04 1.5270192e+05 8.8218230e+04] [0.0000000e+00 0.0000000e+00 6.3408860e+04 1.2921961e+05 4.6085250e+04] [1.0000000e+00 0.0000000e+00 5.5493950e+04 1.0305749e+05 2.1463481e+05] [0.0000000e+00 0.0000000e+00 4.6426070e+04 1.5769392e+05 2.1079767e+05] [0.0000000e+00 1.0000000e+00 4.6014020e+04 8.5047440e+04 2.0551764e+05] [1.0000000e+00 0.0000000e+00 2.8663760e+04 1.2705621e+05 2.0112682e+05] [0.0000000e+00 0.0000000e+00 4.4069950e+04 5.1283140e+04 1.9702942e+05] [0.0000000e+00 1.0000000e+00 2.0229590e+04 6.5947930e+04 1.8526510e+05] [0.0000000e+00 0.0000000e+00 3.8558510e+04 8.2982090e+04 1.7499930e+05] [0.0000000e+00 0.0000000e+00 2.8754330e+04 1.1854605e+05 1.7279567e+05] [1.0000000e+00 0.0000000e+00 2.7892920e+04 8.4710770e+04 1.6447071e+05] [0.0000000e+00 0.0000000e+00 2.3640930e+04 9.6189630e+04 1.4800111e+05] [0.0000000e+00 1.0000000e+00 1.5505730e+04 1.2738230e+05 3.5534170e+04] [0.0000000e+00 0.0000000e+00 2.2177740e+04 1.5480614e+05 2.8334720e+04] [0.0000000e+00 1.0000000e+00 1.0002300e+03 1.2415304e+05 1.9039300e+03] [1.0000000e+00 0.0000000e+00 1.3154600e+03 1.1581621e+05 2.9711446e+05] [0.0000000e+00 0.0000000e+00 0.0000000e+00 1.3542692e+05 0.0000000e+00] [0.0000000e+00 1.0000000e+00 5.4205000e+02 5.1743150e+04 0.0000000e+00] [0.0000000e+00 0.0000000e+00 0.0000000e+00 1.1698380e+05 4.5173060e+04]]
拆分数据集
from sklearn.model_selection import train_test_split X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size = 0.2, random_state = 0)
二,在训练集上训练多元线性回归模型
from sklearn.linear_model import LinearRegression regressor = LinearRegression() regressor.fit(X_train,Y_train) #训练
三,在测试集上预测结果
#预测 Y_pred = regressor.predict(X_test)
也就是说其实多元线性回归与简单线性回归是基本一致的,区别就在于多元需要去掉虚拟变量陷阱,即去掉独热编码后的第一列,X = X[ : , 1: ]
四,可视化
#误差可视化 import pylab as plt plt.plot((Y_test.min(),Y_test.max()),(Y_test.min(),Y_test.max()),color = 'blue') plt.scatter(Y_test,Y_pred,color = 'red') plt.xlabel("Y_test");plt.ylabel("y_pred") plt.style.use('ggplot') plt.show()
参考链接中解释了虚拟变量陷阱的原因,遇到的时候再回来看
链接:https://blog.csdn.net/ssswill/article/details/86151933