import pytorch_lightning as pl
import torch
import torch.nn as nn


class LSTMModel(pl.LightningModule):
    """Autoregressive stacked-LSTM forecaster.

    The LSTM first consumes the whole input sequence. Its projected output at
    the last time step is then fed back in as the next (length-1) input to
    roll forward one step at a time. A linear head maps each step's last
    LSTM output to that step's prediction.
    """

    # Feature width of the input. It is also used as the LSTM projection size,
    # which is what lets an LSTM output step be fed straight back as input.
    INPUT_SIZE = 21
    HIDDEN_SIZE = 512
    NUM_LAYERS = 3
    # Width of each per-step prediction produced by the linear head.
    OUTPUT_SIZE = 7
    # Number of autoregressive steps produced by ``forward`` unless overridden.
    DEFAULT_STEPS = 20

    def __init__(self, **config):
        """Build the LSTM and the linear output head.

        Args:
            **config: Arbitrary hyperparameters. They are only recorded via
                ``save_hyperparameters``; layer sizes come from the class
                constants above.
        """
        super().__init__()
        self.save_hyperparameters(config)
        self.lstm = nn.LSTM(
            input_size=self.INPUT_SIZE,
            hidden_size=self.HIDDEN_SIZE,
            num_layers=self.NUM_LAYERS,
            # proj_size == input_size so each output step is a valid next input.
            proj_size=self.INPUT_SIZE,
            batch_first=True,
        )
        self.linear = nn.Linear(
            in_features=self.INPUT_SIZE, out_features=self.OUTPUT_SIZE
        )

    def forward(self, x, steps=DEFAULT_STEPS):
        """Roll the LSTM forward autoregressively and predict each step.

        Args:
            x: Batched input sequence of shape (batch, seq, INPUT_SIZE). The
                LSTM is ``batch_first`` and the ``[:, -1, :]`` indexing
                below requires a 3-D tensor.
            steps: Number of predictions to produce (default 20, matching the
                previous fixed behaviour). Must be >= 1.

        Returns:
            Tensor of shape (batch, steps, OUTPUT_SIZE): the linear head's
            output for each step, stacked along dim 1.

        Raises:
            ValueError: If ``steps`` is less than 1.
        """
        if steps < 1:
            # torch.stack on an empty list would fail with an opaque error.
            raise ValueError(f"steps must be >= 1, got {steps}")

        # Step 0: run over the full input sequence, starting from zero state.
        output, state = self.lstm(x)
        last = output[:, -1, :]
        predictions = [self.linear(last)]

        # Remaining steps: feed the previous last-step LSTM output back as a
        # length-1 sequence, carrying the (hidden, cell) state forward.
        for _ in range(steps - 1):
            output, state = self.lstm(last.unsqueeze(1), state)
            last = output[:, -1, :]
            predictions.append(self.linear(last))

        return torch.stack(predictions, dim=1)