Implementing an LSTM. An extension of the RNN that handles long-term dependencies better by introducing a cell state controlled by forget, input (update), and output gates.
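For reference, the implementation below follows the standard (textbook) LSTM update, where $\odot$ is elementwise multiplication; in the code, $f_t$ is `forget`, $i_t$ is `addition_scale`, $\tilde{c}_t$ is `addition_base`, and $o_t$ is the output gate:

$$
\begin{aligned}
f_t &= \sigma\left(x_t W_{xf} + h_{t-1} W_{hf} + b_f\right) && \text{(forget gate)}\\
i_t &= \sigma\left(x_t W_{xi} + h_{t-1} W_{hi} + b_i\right) && \text{(input/update gate)}\\
\tilde{c}_t &= \tanh\left(x_t W_{xc} + h_{t-1} W_{hc} + b_c\right) && \text{(candidate cell values)}\\
o_t &= \sigma\left(x_t W_{xo} + h_{t-1} W_{ho} + b_o\right) && \text{(output gate)}\\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t\\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$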

class LSTM[source]

LSTM(input_sz, hidden_sz) :: Module


import torch
import torch.nn as nn

class LSTM(nn.Module):
    def __init__(self, input_sz, hidden_sz):
        "LSTM Module"
        super().__init__()
        # get_weight is the weight-initialisation helper assumed to be defined in an
        # earlier notebook of this series (returns an (input_sz, hidden_sz) tensor)

        # Forget gate: decides how much of the previous cell state to keep
        self.x_forget_weight = nn.Parameter(get_weight(input_sz, hidden_sz))
        self.h_forget_weight = nn.Parameter(get_weight(hidden_sz, hidden_sz))
        self.forget_bias = nn.Parameter(torch.zeros(hidden_sz))

        # Input (update) gate: scales how much of the candidate values enters the cell state
        self.x_scale_weight = nn.Parameter(get_weight(input_sz, hidden_sz))
        self.h_scale_weight = nn.Parameter(get_weight(hidden_sz, hidden_sz))
        self.scale_bias = nn.Parameter(torch.zeros(hidden_sz))

        # Candidate cell values: new information proposed for the cell state
        self.x_add_weight = nn.Parameter(get_weight(input_sz, hidden_sz))
        self.h_add_weight = nn.Parameter(get_weight(hidden_sz, hidden_sz))
        self.add_bias = nn.Parameter(torch.zeros(hidden_sz))

        # Output gate: controls how much of the cell state is exposed as the hidden state
        self.x_hidden_weight = nn.Parameter(get_weight(input_sz, hidden_sz))
        self.h_hidden_weight = nn.Parameter(get_weight(hidden_sz, hidden_sz))
        self.hidden_bias = nn.Parameter(torch.zeros(hidden_sz))

        self.tanh = nn.Tanh()
        self.sig = nn.Sigmoid()
        self.hidden_sz = hidden_sz

    def forward(self, inputs, states=None):
        # inputs: (batch, seq_len, input_sz); states: optional (hidden, cell) tuple
        bs, seq_len, _ = inputs.shape
        if states is None:
            hidden = inputs.new_zeros(bs, self.hidden_sz)
            cs = inputs.new_zeros(bs, self.hidden_sz)
        else:
            hidden, cs = states
        hiddens = []

        for t in range(seq_len):
            inp = inputs[:, t, :]

            # Gate activations for the current timestep
            forget = self.sig(inp @ self.x_forget_weight + hidden @ self.h_forget_weight + self.forget_bias)
            addition_scale = self.sig(inp @ self.x_scale_weight + hidden @ self.h_scale_weight + self.scale_bias)
            addition_base = self.tanh(inp @ self.x_add_weight + hidden @ self.h_add_weight + self.add_bias)
            output = self.sig(inp @ self.x_hidden_weight + hidden @ self.h_hidden_weight + self.hidden_bias)

            # Forget part of the old cell state, then add the scaled candidate values
            cs = cs * forget + addition_base * addition_scale
            # The hidden state is the output gate applied to the squashed cell state
            hidden = output * self.tanh(cs)

            hiddens.append(hidden.unsqueeze(1))

        # Stack the per-timestep hidden states into (batch, seq_len, hidden_sz)
        hiddens = torch.cat(hiddens, dim=1)

        return hiddens
lstm = LSTM(16, 10)
lstm(torch.randn(8,32,16)).shape
torch.Size([8, 32, 10])
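As a quick, optional sanity check (not part of the original notebook), the module can also be given an explicit initial (hidden, cell) state, and its output shape can be compared against PyTorch's built-in nn.LSTM with batch_first=True:

h0, c0 = torch.zeros(8, 10), torch.zeros(8, 10)  # explicit initial hidden and cell states
lstm(torch.randn(8, 32, 16), states=(h0, c0)).shape
torch.Size([8, 32, 10])

torch_lstm = nn.LSTM(input_size=16, hidden_size=10, batch_first=True)
out, (h_n, c_n) = torch_lstm(torch.randn(8, 32, 16))
out.shape
torch.Size([8, 32, 10])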
!python notebook2script.py LSTM.ipynb
Converted LSTM.ipynb to ModernArchitecturesFromPyTorch/nb_LSTM.py