Case Study: Stock Return Prediction with an RNN

Continuing from the previous section, this section uses a recurrent neural network (RNN) to regress the stock return series.

import pandas as pd
X_train = pd.read_csv('datasets/X_train.csv', index_col=0)
X_test = pd.read_csv('datasets/X_test.csv', index_col=0)
y_train = pd.read_csv('datasets/y_train.csv', index_col=0)
y_test = pd.read_csv('datasets/y_test.csv', index_col=0)
X_train.index = pd.to_datetime(X_train.index)
X_test.index = pd.to_datetime(X_test.index)
y_train.index = pd.to_datetime(y_train.index)
y_test.index = pd.to_datetime(y_test.index)

7. Deep Learning Model: RNN

Now let us prepare the dataset for the RNN model. All input and output variables need to be in array form.

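The logic behind the RNN is that each input is taken from the previous n days (for each day, the features of all related assets plus the lagged variables of MSFT), and we try to predict day n+1. We then move the window forward by one day and predict the following day, repeating this over the whole dataset (in batches, of course).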
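
The sliding-window idea described above can be sketched with NumPy as follows. This is a minimal illustration rather than the exact code used in this section: the helper name `make_windows` is hypothetical, and the toy arrays stand in for the real feature and target data. Each sample holds `time_steps` consecutive rows of features, and its target is the value on the day right after the window.

```python
import numpy as np

def make_windows(X, y, time_steps):
    """Build sliding windows (hypothetical helper, for illustration only).

    X: array of shape (n_days, n_features); y: array of shape (n_days,).
    Returns windows of shape (n_days - time_steps, time_steps, n_features)
    and targets of shape (n_days - time_steps,).
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    n_samples = X.shape[0] - time_steps
    # Window i covers days i .. i+time_steps-1
    windows = np.stack([X[i:i + time_steps] for i in range(n_samples)])
    # Target i is the value on day i+time_steps, the day after the window
    targets = y[time_steps:]
    return windows, targets

# Toy data: 6 days, 2 features
X_toy = np.arange(12).reshape(6, 2)
y_toy = np.arange(6) * 10.0
X_win, y_win = make_windows(X_toy, y_toy, time_steps=3)
print(X_win.shape, y_win.shape)  # (3, 3, 2) (3,)
```

Unlike arrays preallocated with one row per day, this variant returns only the windows that can be fully filled, so no all-zero samples are left at the end.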

Initialize the training and test sets for the model:

import numpy as np
time_steps = 10                    # window length
feature_n = X_train.shape[1]
# Each sample has shape (time_steps, feature_n)
X_train_RNN = np.zeros((X_train.shape[0], time_steps, feature_n))  
y_train_RNN = np.zeros((X_train.shape[0], feature_n))  
X_test_RNN = np.zeros((X_test.shape[0], time_steps, feature_n))  
y_test_RNN = np.zeros((X_test.shape[0], feature_n))  
X_train_RNN.shape, y_train_RNN.shape
((540, 10, 11), (540, 11))
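# Window i holds rows i .. i+time_steps-1; its target is row i+time_steps.
# Note: only the first N - time_steps samples are filled, so the last
# time_steps rows of each array remain all-zero.
for i in range(X_train.shape[0]-time_steps):
    X_train_RNN[i, :, :] = X_train.iloc[i:i+time_steps,:]
    # If y_train has a single column, its value is broadcast to all feature_n columns
    y_train_RNN[i] = y_train.iloc[i+time_steps]
for i in range(X_test.shape[0]-time_steps):
    X_test_RNN[i, :, :] = X_test.iloc[i:i+time_steps,:]
    y_test_RNN[i] = y_test.iloc[i+time_steps]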

Build the RNN model with the Keras library:

from keras.models import Sequential
from keras.layers import SimpleRNN, Dense
input_shape = (time_steps, feature_n)    # (window length, number of features)
output_units = feature_n                 # one output per column of y_train_RNN
model = Sequential()
model.add(SimpleRNN(50, input_shape=input_shape))   # 50 hidden units
model.add(Dense(output_units))
model.compile(loss='mean_squared_error', optimizer='adam')
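
To make the `SimpleRNN(50)` layer less of a black box, here is a rough NumPy sketch of the recurrence a simple RNN cell computes, h_t = tanh(x_t W + h_{t-1} U + b), followed by a dense output layer. The weights are random placeholders and the function name `simple_rnn_forward` is hypothetical. It mirrors the layer's default behavior (tanh activation, only the last hidden state returned) but is not Keras's implementation.

```python
import numpy as np

def simple_rnn_forward(x_seq, W, U, b):
    """Run one sequence through a simple RNN cell and return the last hidden state.

    x_seq: (time_steps, n_features); W: (n_features, units);
    U: (units, units); b: (units,)
    """
    h = np.zeros(U.shape[0])              # initial hidden state is all zeros
    for x_t in x_seq:                     # one step per day in the window
        h = np.tanh(x_t @ W + h @ U + b)
    return h

rng = np.random.default_rng(0)
time_steps, n_features, units = 10, 11, 50   # shapes used in this section
x_seq = rng.normal(size=(time_steps, n_features))
W = rng.normal(scale=0.1, size=(n_features, units))
U = rng.normal(scale=0.1, size=(units, units))
b = np.zeros(units)
h_last = simple_rnn_forward(x_seq, W, U, b)

# Dense layer on top of the last hidden state (random placeholder weights)
V = rng.normal(scale=0.1, size=(units, n_features))
c = np.zeros(n_features)
y_hat = h_last @ V + c
print(h_last.shape, y_hat.shape)  # (50,) (11,)
```

With these shapes the recurrent layer has 11·50 + 50·50 + 50 = 3,100 parameters and the dense layer has 50·11 + 11 = 561.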

Now train the model:

model_fit = model.fit(X_train_RNN, y_train_RNN,
                      validation_data=(X_test_RNN,y_test_RNN),
                      epochs=100,
                      batch_size=50,
                      verbose=1,
                      shuffle=False)
Epoch 1/100
11/11 [==============================] - 0s 9ms/step - loss: 0.0123 - val_loss: 0.0065
Epoch 2/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0053 - val_loss: 0.0042
Epoch 3/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0035 - val_loss: 0.0032
Epoch 4/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0027 - val_loss: 0.0027
Epoch 5/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0022 - val_loss: 0.0024
Epoch 6/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0019 - val_loss: 0.0022
Epoch 7/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0018 - val_loss: 0.0020
Epoch 8/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0017 - val_loss: 0.0019
Epoch 9/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0016 - val_loss: 0.0018
Epoch 10/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0015 - val_loss: 0.0018
Epoch 11/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0015 - val_loss: 0.0017
Epoch 12/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0014 - val_loss: 0.0017
Epoch 13/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0014 - val_loss: 0.0016
Epoch 14/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0013 - val_loss: 0.0016
Epoch 15/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0013 - val_loss: 0.0016
Epoch 16/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0013 - val_loss: 0.0015
Epoch 17/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0013 - val_loss: 0.0015
Epoch 18/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0013 - val_loss: 0.0015
Epoch 19/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0013 - val_loss: 0.0015
Epoch 20/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0012 - val_loss: 0.0015
Epoch 21/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0012 - val_loss: 0.0015
Epoch 22/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0012 - val_loss: 0.0015
Epoch 23/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0012 - val_loss: 0.0015
Epoch 24/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0012 - val_loss: 0.0014
Epoch 25/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0012 - val_loss: 0.0014
Epoch 26/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0012 - val_loss: 0.0014
Epoch 27/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0012 - val_loss: 0.0014
Epoch 28/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0012 - val_loss: 0.0014
Epoch 29/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0012 - val_loss: 0.0014
Epoch 30/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0012 - val_loss: 0.0014
Epoch 31/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0012 - val_loss: 0.0014
Epoch 32/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0012 - val_loss: 0.0014
Epoch 33/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 34/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 35/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 36/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 37/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 38/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 39/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 40/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 41/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 42/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 43/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 44/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 45/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 46/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 47/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 48/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 49/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 50/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 51/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 52/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 53/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 54/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 55/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 56/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 57/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 58/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 59/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 60/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 61/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 62/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 63/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 64/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 65/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 66/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 67/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 68/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 69/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 70/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 71/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0014
Epoch 72/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0015
Epoch 73/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0015
Epoch 74/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0015
Epoch 75/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0015
Epoch 76/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0015
Epoch 77/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0015
Epoch 78/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0015
Epoch 79/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0015
Epoch 80/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0015
Epoch 81/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0015
Epoch 82/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0015
Epoch 83/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0015
Epoch 84/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0015
Epoch 85/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0015
Epoch 86/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0015
Epoch 87/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0015
Epoch 88/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0015
Epoch 89/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0015
Epoch 90/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0011 - val_loss: 0.0015
Epoch 91/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0010 - val_loss: 0.0015
Epoch 92/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0010 - val_loss: 0.0015
Epoch 93/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0010 - val_loss: 0.0015
Epoch 94/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0010 - val_loss: 0.0015
Epoch 95/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0010 - val_loss: 0.0015
Epoch 96/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0010 - val_loss: 0.0015
Epoch 97/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0010 - val_loss: 0.0015
Epoch 98/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0010 - val_loss: 0.0015
Epoch 99/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0010 - val_loss: 0.0015
Epoch 100/100
11/11 [==============================] - 0s 2ms/step - loss: 0.0010 - val_loss: 0.0015
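
The `11/11` at the start of each log line counts the batches processed per epoch. As a quick check, assuming the 540 training samples from the shape output earlier and the `batch_size=50` passed to `model.fit`:

```python
import math

n_samples = 540   # first dimension of X_train_RNN, from the shape output
batch_size = 50   # value passed to model.fit
steps_per_epoch = math.ceil(n_samples / batch_size)
last_batch = n_samples - (steps_per_epoch - 1) * batch_size
print(steps_per_epoch, last_batch)  # 11 40
```

The final batch of each epoch therefore contains 40 samples instead of 50.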

With the RNN fitted, we now plot how the loss on the training set and on the test set changes over the epochs:

import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = 'Songti SC'  # use the Songti typeface (supports Chinese labels)
plt.rcParams['axes.unicode_minus'] = False  # render the minus sign "-" correctly
from matplotlib_inline import backend_inline
backend_inline.set_matplotlib_formats('svg') 

plt.figure(figsize=(12,4))
plt.plot(model_fit.history['loss'], label='train loss', )
plt.plot(model_fit.history['val_loss'], '--',label='test loss',)
plt.title('RNN model performance')
plt.legend(loc='lower left')
plt.show()

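As the number of epochs grows, the training and test errors fall together during the first twenty or so epochs. After that the training loss keeps edging down (to 0.0010 by epoch 100), while the test loss levels off at 0.0014 and drifts back up to 0.0015 from epoch 72 onward, a sign of mild overfitting.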
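
Because the training log reports a validation loss for every epoch, one can check at which epoch the test loss was lowest instead of relying on the final epoch. The sketch below uses a short made-up list in place of `model_fit.history['val_loss']`; the helper name `best_epoch` is hypothetical.

```python
def best_epoch(val_losses):
    """Return (1-based epoch, loss) for the first epoch with the lowest validation loss."""
    best_idx = min(range(len(val_losses)), key=val_losses.__getitem__)
    return best_idx + 1, val_losses[best_idx]

# Made-up values for illustration; in the notebook one would pass
# model_fit.history['val_loss'] instead.
toy_val_loss = [0.0065, 0.0042, 0.0032, 0.0014, 0.0014, 0.0015]
epoch, loss = best_epoch(toy_val_loss)
print(epoch, loss)  # 4 0.0014
```

Keras also offers the `keras.callbacks.EarlyStopping` callback (with `monitor='val_loss'` and `restore_best_weights=True`) to stop training once the validation loss stops improving. Note that choosing the stopping point on the test set makes the reported test error optimistic, so a separate validation split would be the safer choice.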

8. Model Comparison

train_results, test_results, names = [], [], []
# LR
train_results.append(0.001146)
test_results.append(0.00129)

# ARIMA
train_results.append(0.001139)
test_results.append(0.001246)

# RNN (rounded from the training log; the final-epoch val_loss there is 0.0015)
train_results.append(0.0010)
test_results.append(0.0014)

names.append("LR")
names.append("ARIMA")
names.append("RNN")
ind = np.arange(len(names))
width = 0.35
fig = plt.figure(figsize=(8,4))
ax = fig.add_subplot(111)
plt.bar(ind-width/2, train_results, width=width, label='Train Error')
plt.bar(ind+width/2, test_results, width=width, label='Test Error')
plt.xticks(ind)
ax.set_xticklabels(names)
plt.legend()
plt.show()
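
To read the bar chart numerically, the sketch below recomputes the gap between test and training error for each model from the values listed above. The dictionary layout is just one convenient way to organize them.

```python
# Errors copied from the lists above: (train, test)
errors = {
    "LR": (0.001146, 0.00129),
    "ARIMA": (0.001139, 0.001246),
    "RNN": (0.0010, 0.0014),
}

for name, (train_err, test_err) in errors.items():
    gap = test_err - train_err   # larger gap = weaker generalization
    print(f"{name:5s} train={train_err:.6f} test={test_err:.6f} gap={gap:.6f}")

best_train = min(errors, key=lambda k: errors[k][0])
best_test = min(errors, key=lambda k: errors[k][1])
print(best_train, best_test)  # RNN ARIMA
```

The RNN has the lowest training error but also the largest train-test gap of the three models.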

Conclusion

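This case study used three models (LR, ARIMA, and RNN), combining a stock's related assets with its own historical data, to predict its return. It illustrates a general machine-learning workflow for stock prediction, covering the whole process from data collection and cleaning to building and tuning different models. In terms of error, the nonlinear model represented by the RNN achieves the lowest training error (0.0010) of the three, although its test error (0.0014) is slightly higher than those of the linear models LR (0.00129) and ARIMA (0.001246), so its advantage on the training set does not carry over to out-of-sample data here.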