Skip to article frontmatterSkip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

8.5 线性回归的实现

在Python中,你可以使用scikit-learn库或者statsmodels库来实现一元线性回归。

1. sklearn机器学习库linear_model模块

sklearn机器学习库linear_model模块下的LinearRegression类如下:

class sklearn.linear_model.LinearRegression(*, fit_intercept=True, copy_X=True, n_jobs=None, positive=False)[source]¶
类别名称含义
参数fit_intercept默认为True,是否计算模型的截距,如果设为False,是指数据是以原点为中心的,不会计算截距。
属性coef_线性回归问题的估计系数。
intercept_线性模型中的独立项,也就是截距
方法fit(X, y)训练模型(或称估计器、学习器)
predict(X)在训练后,使用模型预测
score(X, y)用来计算模型的精度
get_params()获得模型的参数

1.1 模型输入X和y

上述回归模型中使用XXyy的数据结构如下,XX包含多个样本,以及每个样本的属性,也就是自变量,和X的每个样本对应的就是我们的预测目标yy,也就是因变量。

在实际编程中,一般使用pd.DataFrame来表示XXyy

X&y

1.2 实现流程

针对于多元回归分析,其一般化的流程如下:

regression_steps

1)创建学习器,也就是初始化线性回归模型

from sklearn.linear_model import LinearRegression
model = LinearRegression()

也可以这样表述:

from sklearn import linear_model
model = linear_model.LinearRegression()

2)训练模型

model.fit(X, y)

3)生成预测结果

predicted_y = model.predict(X)

4)计算模型预测精度和拟合优度R2R^2

precision = model.score(X, y)
from sklearn.metrics import r2_score
r2 = r2_score(predicted_y, y)

5)生成汇总信息(summary)

print("系数(beta1): %s" %model.coef_)
print("截距(beta0): %.4f" %model.intercept_)
print("样本内(IS)训练集精度:%.2f" %precision)
print("拟合优度R-squared: %.2f" % r2)

2. 使用statsmodels库来实现一元线性回归

使用statsmodels库来实现一元线性回归也很简单。以下是一个示例代码:

import statsmodels.api as sm
import numpy as np

# 构造样本数据
x = np.array([1, 2, 3, 4, 5])
y = np.array([2, 3, 4, 5, 6])

# 添加常数项(截距)
x = sm.add_constant(x)

# 创建模型对象
model = sm.OLS(y, x)

# 拟合模型
result = model.fit()

# 打印模型摘要
print(result.summary())
                            OLS Regression Results                            
==============================================================================
Dep. Variable:                      y   R-squared:                       1.000
Model:                            OLS   Adj. R-squared:                  1.000
Method:                 Least Squares   F-statistic:                 3.042e+31
Date:                Sun, 05 May 2024   Prob (F-statistic):           1.31e-47
Time:                        21:00:55   Log-Likelihood:                 169.66
No. Observations:                   5   AIC:                            -335.3
Df Residuals:                       3   BIC:                            -336.1
Df Model:                           1                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
const          1.0000   6.01e-16   1.66e+15      0.000       1.000       1.000
x1             1.0000   1.81e-16   5.52e+15      0.000       1.000       1.000
==============================================================================
Omnibus:                          nan   Durbin-Watson:                   0.400
Prob(Omnibus):                    nan   Jarque-Bera (JB):                0.770
Skew:                          -0.844   Prob(JB):                        0.680
Kurtosis:                       2.078   Cond. No.                         8.37
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
/Users/xhc/miniconda3/envs/d2l/lib/python3.11/site-packages/statsmodels/stats/stattools.py:74: ValueWarning: omni_normtest is not valid with less than 8 observations; 5 samples were given.
  warn("omni_normtest is not valid with less than 8 observations; %i "

参考

  1. https://www.jmp.com/en_us/statistics-knowledge-portal/what-is-regression.html

  2. 詹姆斯*斯托克,马克*沃森《计量经济学》第三版

  3. sklearn官网:链接