Summary: code for simple (one-variable) and multiple linear regression
From the two previous code examples, the general machine-learning workflow is already clear, namely:
(1) Look at the big picture
(2) Get the data
(3) Discover and visualize the data to gain insights
(4) Prepare the data for machine learning algorithms
(5) Select a model and train it
(6) Fine-tune your model
(7) Present your solution
(8) Launch, monitor, and maintain your system
The steps are not strictly sequential; visualization, for example, can be woven into any of them. They have been emphasized repeatedly in the earlier notes, and merely repeating the code will not deepen understanding any further. In short, as things stand, the difficulties of machine learning are:
(1) Acquiring data (not a difficulty for now);
(2) Preprocessing the data (this touches on quite a few issues and requires solid familiarity with pandas and numpy);
(3) Model selection and tuning (the real appeal of machine learning; so far everything has been done by blindly calling sklearn methods, which is rather mechanical, so later every model will be implemented by hand, since only after understanding the algorithms can models be chosen and tuned well; see the sketch below);
(4) Model evaluation (each family of models has its own evaluation methods);
(5) Data visualization (fluent use of matplotlib).
Preprocessing and model evaluation will each get a separate write-up; model selection, tuning, and visualization will run through the entire course of studying machine learning.
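As a first taste of the hand-implementation promised in point (3), here is a minimal sketch of ordinary least squares solved directly via the normal equation theta = (X^T X)^(-1) X^T y. The function name ols_normal_equation and its variable names are illustrative, not from this post; applied to the same train/test split used below, it should closely reproduce the coefficients that sklearn's LinearRegression reports.

import numpy as np

def ols_normal_equation(X, y):
    # Ordinary least squares via the normal equation:
    # prepend a column of ones so the intercept is learned jointly,
    # then solve theta = (X_b^T X_b)^{-1} X_b^T y.
    # (A sketch for understanding; sklearn uses a more stable least-squares solver.)
    X_b = np.c_[np.ones((X.shape[0], 1)), X]   # add bias column
    theta = np.linalg.pinv(X_b.T @ X_b) @ X_b.T @ y
    return theta[0], theta[1:]                 # intercept b, weights w

# Hypothetical usage on the training split defined further below:
# b, w = ols_normal_equation(x_train, y_train)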
Below is the code for the simple and multiple linear regression models.
Dataset:
,TV,radio,newspaper,sales
1,230.1,37.8,69.2,22.1
2,44.5,39.3,45.1,10.4
3,17.2,45.9,69.3,9.3
4,151.5,41.3,58.5,18.5
5,180.8,10.8,58.4,12.9
6,8.7,48.9,75,7.2
7,57.5,32.8,23.5,11.8
8,120.2,19.6,11.6,13.2
9,8.6,2.1,1,4.8
10,199.8,2.6,21.2,10.6
11,66.1,5.8,24.2,8.6
12,214.7,24,4,17.4
13,23.8,35.1,65.9,9.2
14,97.5,7.6,7.2,9.7
15,204.1,32.9,46,19
16,195.4,47.7,52.9,22.4
17,67.8,36.6,114,12.5
18,281.4,39.6,55.8,24.4
19,69.2,20.5,18.3,11.3
20,147.3,23.9,19.1,14.6
21,218.4,27.7,53.4,18
22,237.4,5.1,23.5,12.5
23,13.2,15.9,49.6,5.6
24,228.3,16.9,26.2,15.5
25,62.3,12.6,18.3,9.7
26,262.9,3.5,19.5,12
27,142.9,29.3,12.6,15
28,240.1,16.7,22.9,15.9
29,248.8,27.1,22.9,18.9
30,70.6,16,40.8,10.5
31,292.9,28.3,43.2,21.4
32,112.9,17.4,38.6,11.9
33,97.2,1.5,30,9.6
34,265.6,20,0.3,17.4
35,95.7,1.4,7.4,9.5
36,290.7,4.1,8.5,12.8
37,266.9,43.8,5,25.4
38,74.7,49.4,45.7,14.7
39,43.1,26.7,35.1,10.1
40,228,37.7,32,21.5
41,202.5,22.3,31.6,16.6
42,177,33.4,38.7,17.1
43,293.6,27.7,1.8,20.7
44,206.9,8.4,26.4,12.9
45,25.1,25.7,43.3,8.5
46,175.1,22.5,31.5,14.9
47,89.7,9.9,35.7,10.6
48,239.9,41.5,18.5,23.2
49,227.2,15.8,49.9,14.8
50,66.9,11.7,36.8,9.7
51,199.8,3.1,34.6,11.4
52,100.4,9.6,3.6,10.7
53,216.4,41.7,39.6,22.6
54,182.6,46.2,58.7,21.2
55,262.7,28.8,15.9,20.2
56,198.9,49.4,60,23.7
57,7.3,28.1,41.4,5.5
58,136.2,19.2,16.6,13.2
59,210.8,49.6,37.7,23.8
60,210.7,29.5,9.3,18.4
61,53.5,2,21.4,8.1
62,261.3,42.7,54.7,24.2
63,239.3,15.5,27.3,15.7
64,102.7,29.6,8.4,14
65,131.1,42.8,28.9,18
66,69,9.3,0.9,9.3
67,31.5,24.6,2.2,9.5
68,139.3,14.5,10.2,13.4
69,237.4,27.5,11,18.9
70,216.8,43.9,27.2,22.3
71,199.1,30.6,38.7,18.3
72,109.8,14.3,31.7,12.4
73,26.8,33,19.3,8.8
74,129.4,5.7,31.3,11
75,213.4,24.6,13.1,17
76,16.9,43.7,89.4,8.7
77,27.5,1.6,20.7,6.9
78,120.5,28.5,14.2,14.2
79,5.4,29.9,9.4,5.3
80,116,7.7,23.1,11
81,76.4,26.7,22.3,11.8
82,239.8,4.1,36.9,12.3
83,75.3,20.3,32.5,11.3
84,68.4,44.5,35.6,13.6
85,213.5,43,33.8,21.7
86,193.2,18.4,65.7,15.2
87,76.3,27.5,16,12
88,110.7,40.6,63.2,16
89,88.3,25.5,73.4,12.9
90,109.8,47.8,51.4,16.7
91,134.3,4.9,9.3,11.2
92,28.6,1.5,33,7.3
93,217.7,33.5,59,19.4
94,250.9,36.5,72.3,22.2
95,107.4,14,10.9,11.5
96,163.3,31.6,52.9,16.9
97,197.6,3.5,5.9,11.7
98,184.9,21,22,15.5
99,289.7,42.3,51.2,25.4
100,135.2,41.7,45.9,17.2
101,222.4,4.3,49.8,11.7
102,296.4,36.3,100.9,23.8
103,280.2,10.1,21.4,14.8
104,187.9,17.2,17.9,14.7
105,238.2,34.3,5.3,20.7
106,137.9,46.4,59,19.2
107,25,11,29.7,7.2
108,90.4,0.3,23.2,8.7
109,13.1,0.4,25.6,5.3
110,255.4,26.9,5.5,19.8
111,225.8,8.2,56.5,13.4
112,241.7,38,23.2,21.8
113,175.7,15.4,2.4,14.1
114,209.6,20.6,10.7,15.9
115,78.2,46.8,34.5,14.6
116,75.1,35,52.7,12.6
117,139.2,14.3,25.6,12.2
118,76.4,0.8,14.8,9.4
119,125.7,36.9,79.2,15.9
120,19.4,16,22.3,6.6
121,141.3,26.8,46.2,15.5
122,18.8,21.7,50.4,7
123,224,2.4,15.6,11.6
124,123.1,34.6,12.4,15.2
125,229.5,32.3,74.2,19.7
126,87.2,11.8,25.9,10.6
127,7.8,38.9,50.6,6.6
128,80.2,0,9.2,8.8
129,220.3,49,3.2,24.7
130,59.6,12,43.1,9.7
131,0.7,39.6,8.7,1.6
132,265.2,2.9,43,12.7
133,8.4,27.2,2.1,5.7
134,219.8,33.5,45.1,19.6
135,36.9,38.6,65.6,10.8
136,48.3,47,8.5,11.6
137,25.6,39,9.3,9.5
138,273.7,28.9,59.7,20.8
139,43,25.9,20.5,9.6
140,184.9,43.9,1.7,20.7
141,73.4,17,12.9,10.9
142,193.7,35.4,75.6,19.2
143,220.5,33.2,37.9,20.1
144,104.6,5.7,34.4,10.4
145,96.2,14.8,38.9,11.4
146,140.3,1.9,9,10.3
147,240.1,7.3,8.7,13.2
148,243.2,49,44.3,25.4
149,38,40.3,11.9,10.9
150,44.7,25.8,20.6,10.1
151,280.7,13.9,37,16.1
152,121,8.4,48.7,11.6
153,197.6,23.3,14.2,16.6
154,171.3,39.7,37.7,19
155,187.8,21.1,9.5,15.6
156,4.1,11.6,5.7,3.2
157,93.9,43.5,50.5,15.3
158,149.8,1.3,24.3,10.1
159,11.7,36.9,45.2,7.3
160,131.7,18.4,34.6,12.9
161,172.5,18.1,30.7,14.4
162,85.7,35.8,49.3,13.3
163,188.4,18.1,25.6,14.9
164,163.5,36.8,7.4,18
165,117.2,14.7,5.4,11.9
166,234.5,3.4,84.8,11.9
167,17.9,37.6,21.6,8
168,206.8,5.2,19.4,12.2
169,215.4,23.6,57.6,17.1
170,284.3,10.6,6.4,15
171,50,11.6,18.4,8.4
172,164.5,20.9,47.4,14.5
173,19.6,20.1,17,7.6
174,168.4,7.1,12.8,11.7
175,222.4,3.4,13.1,11.5
176,276.9,48.9,41.8,27
177,248.4,30.2,20.3,20.2
178,170.2,7.8,35.2,11.7
179,276.7,2.3,23.7,11.8
180,165.6,10,17.6,12.6
181,156.6,2.6,8.3,10.5
182,218.5,5.4,27.4,12.2
183,56.2,5.7,29.7,8.7
184,287.6,43,71.8,26.2
185,253.8,21.3,30,17.6
186,205,45.1,19.6,22.6
187,139.5,2.1,26.6,10.3
188,191.1,28.7,18.2,17.3
189,286,13.9,3.7,15.9
190,18.7,12.1,23.4,6.7
191,39.5,41.1,5.8,10.8
192,75.5,10.8,6,9.9
193,17.2,4.1,31.6,5.9
194,166.8,42,3.6,19.6
195,149.7,35.6,6,17.3
196,38.2,3.7,13.8,7.6
197,94.2,4.9,8.1,9.7
198,177,9.3,6.4,12.8
199,283.6,42,66.2,25.5
200,232.1,8.6,8.7,13.4
(File upload isn't possible here, so the whole dataset is pasted in.)
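Save the block above as Advertising.csv and the code below runs as-is. Alternatively, the pasted text can be loaded without touching the filesystem; a minimal sketch, where csv_text is a stand-in for the full block above and not a variable from the original code:

import io
import pandas as pd

csv_text = """,TV,radio,newspaper,sales
1,230.1,37.8,69.2,22.1
2,44.5,39.3,45.1,10.4
"""  # ... paste the remaining rows here

data = pd.read_csv(io.StringIO(csv_text), index_col=0)  # same result as reading the file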
Code:
#!/usr/bin/env python
# coding: utf-8
# In[1]:
import numpy as np
import pandas as pd
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
# In[3]:
data = pd.read_csv('Advertising.csv', index_col=0)
data
# In[4]:
data = data.dropna(how='any')  # drop rows containing any missing value
data = data.drop_duplicates()  # drop duplicate rows
data
# In[5]:
data.corr()['sales']  # linear correlation of every column with sales
# In[6]:
# Plotting helper: for each (series, label) pair, draw a line plus circle markers
def figure(title: str, *datalist: tuple):
    import matplotlib.pyplot as plt
    plt.rcParams['font.sans-serif'] = ['SimHei']  # font setting kept from the original (needed for CJK labels)
    plt.figure(figsize=(20, 16), facecolor='gray')
    for v in datalist:
        plt.plot(v[0], '-', label=v[1], linewidth=2)
        plt.plot(v[0], 'o')
    plt.title(title, fontsize=20)
    plt.legend(fontsize=16)
    plt.grid()
    plt.show()
# In[8]:
# Simple (one-variable) linear regression
# The correlations above show the first column (TV) has the strongest linear relationship with sales, so it is used as the single feature
x = np.array(data.iloc[:, :1])
y = np.array(data.iloc[:, -1:])
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=1)
lr = LinearRegression()
lr.fit(x_train, y_train)
y_train_pred = lr.predict(x_train)
y_test_pred = lr.predict(x_test)
print("MSE on the training set: {}".format(mean_squared_error(y_train, y_train_pred)))
print("MSE on the test set: {}".format(mean_squared_error(y_test, y_test_pred)))
print("R^2 on the training set: {}".format(lr.score(x_train, y_train)))
print("R^2 on the test set: {}".format(lr.score(x_test, y_test)))
figure("Predicted vs. actual values, model $R^2={:.4f}$".format(lr.score(x_test, y_test)), (y_test, "actual"), (y_test_pred, "predicted"))
print("Linear regression coefficients:\nw = {};\nb = {}".format(lr.coef_, lr.intercept_))
# In[29]:
# Visualizing the one-variable fit
import matplotlib.pyplot as plt
plt.figure(figsize=(16, 8))
plt.scatter(x_train, y_train, label="training set")
plt.plot(x_train, y_train_pred, '-', label="fit on training set", linewidth=2, color='yellow')
plt.legend(fontsize=20)
plt.figure(figsize=(16, 8))
plt.scatter(x_test, y_test, label="test set")
plt.plot(x_test, y_test_pred, '-', label="fit on test set", linewidth=2, color='yellow')
plt.legend(fontsize=20)
plt.show()  # needed when running outside a notebook
# In[30]:
# Multiple linear regression on all three features
x = np.array(data.iloc[:, :-1])  # every column except sales
y = np.array(data.iloc[:, -1:])
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=1)
lr = LinearRegression()
lr.fit(x_train, y_train)
y_train_pred = lr.predict(x_train)
y_test_pred = lr.predict(x_test)
print("MSE on the training set: {}".format(mean_squared_error(y_train, y_train_pred)))
print("MSE on the test set: {}".format(mean_squared_error(y_test, y_test_pred)))
print("R^2 on the training set: {}".format(lr.score(x_train, y_train)))
print("R^2 on the test set: {}".format(lr.score(x_test, y_test)))
figure("Predicted vs. actual values, model $R^2={:.4f}$".format(lr.score(x_test, y_test)), (y_test, "actual"), (y_test_pred, "predicted"))
print("Linear regression coefficients:\nw = {};\nb = {}".format(lr.coef_, lr.intercept_))