7.2 CNN卷积神经网络
CNN计算机视觉领域占据着重要地位,而CNN同样可以用在时间序列上。区别在于应用在图像上的卷积核是二维的,而应用在时间序列上的卷积核是一维的,也就是一维卷积神经网络,1D CNN。
相比于基于RNN的LSTM等模型,1D CNN的优势是训练快,可以并行计算,并且在某些场景下可以获得不输给LSTM的模型效果。
下面就来学习如何用1D CNN训练时间序列数据。
# 使用和上一节中LSTM准备好的相同数据样本,不再重复
# 构建一个简单的1D CNN模型
class CNNnetwork(nn.Module):
def __init__(self):
super().__init__()
self.conv1d = nn.Conv1d(1,64,kernel_size=2)
self.relu = nn.ReLU(inplace=True)
self.fc1 = nn.Linear(64*11,50)
self.fc2 = nn.Linear(50,1)
def forward(self,x):
# 该模型的网络结构为 一维卷积层 -> Relu层 -> Flatten -> 全连接层1 -> 全连接层2
x = self.conv1d(x)
x = self.relu(x)
x = x.view(-1)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
torch.manual_seed(101)
model =CNNnetwork()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
model
# CNN用到的模型参数更少
def count_parameters(model):
params = [p.numel() for p in model.parameters() if p.requires_grad]
for item in params:
print(f'{item:>6}')
print(f'______\n{sum(params):>6}')
count_parameters(model)
128
64
35200
50
50
1
______
35493
epochs = 100
model.train()
start_time = time.time()
for epoch in range(epochs):
for seq, y_train in train_data:
# 每次更新参数前都梯度归零和初始化
optimizer.zero_grad()
# 注意这里要对样本进行reshape,转换成conv1d的input size(batch size, channel, series length)
y_pred = model(seq.reshape(1,1,-1))
loss = criterion(y_pred, y_train)
loss.backward()
optimizer.step()
print(f'Epoch: {epoch+1:2} Loss: {loss.item():10.8f}')
print(f'\nDuration: {time.time() - start_time:.0f} seconds')
future = 12
# 选取序列最后12个值开始预测
preds = train_norm[-window_size:].tolist()
# 设置成eval模式
model.eval()
# 循环的每一步表示向时间序列向后滑动一格
for i in range(future):
seq = torch.FloatTensor(preds[-window_size:])
with torch.no_grad():
preds.append(model(seq.reshape(1,1,-1)).item())
# 逆归一化还原真实值
true_predictions = scaler.inverse_transform(np.array(preds[window_size:]).reshape(-1, 1))
# 对比真实值和预测值
plt.figure(figsize=(12,4))
plt.grid(True)
plt.plot(df['S4248SM144NCEN'])
x = np.arange('2018-02-01', '2019-02-01', dtype='datetime64[M]').astype('datetime64[D]')
plt.plot(x,true_predictions)
plt.show()
# 放大看
fig = plt.figure(figsize=(12,4))
plt.grid(True)
fig.autofmt_xdate()
plt.plot(df['S4248SM144NCEN']['2017-01-01':])
plt.plot(x,true_predictions)
plt.show()
Last updated