Implementing the GRU, a modified LSTM cell with one fewer gate, which makes it more efficient to train.
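For reference, these are the standard GRU equations the code below implements, with reset gate $r_t$, update gate $z_t$, and candidate state $\tilde{h}_t$:

$$
\begin{aligned}
r_t &= \sigma(h_{t-1} W_r + x_t U_r + b_r) \\
z_t &= \sigma(h_{t-1} W_z + x_t U_z + b_z) \\
\tilde{h}_t &= \tanh\big((r_t \odot h_{t-1}) W_h + x_t U_h + b_h\big) \\
h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t
\end{aligned}
$$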

class GRU[source]

GRU(input_sz, hidden_sz) :: Module

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call `.to()`, etc.
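The implementation below relies on a get_weight helper defined in an earlier notebook of this repo, so it isn't shown here. A minimal stand-in, assuming a simple scaled random initialization (the repo's actual scheme may differ):

import math
import torch

def get_weight(*dims):
    # Hypothetical stand-in for the repo's weight initializer:
    # Gaussian weights scaled by 1/sqrt(fan_in)
    return torch.randn(*dims) / math.sqrt(dims[0])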

import torch
import torch.nn as nn

class GRU(nn.Module):
    def __init__(self, input_sz, hidden_sz):
        "GRU network"
        super().__init__()
        # Reset gate: weights for the previous state, the input, and a bias
        self.s_rescale_weight = nn.Parameter(get_weight(hidden_sz, hidden_sz))
        self.inp_rescale_weight = nn.Parameter(get_weight(input_sz, hidden_sz))
        self.rescale_bias = nn.Parameter(torch.zeros(hidden_sz))

        # Update gate: weights for the previous state, the input, and a bias
        self.s_update_weight = nn.Parameter(get_weight(hidden_sz, hidden_sz))
        self.inp_update_weight = nn.Parameter(get_weight(input_sz, hidden_sz))
        self.update_bias = nn.Parameter(torch.zeros(hidden_sz))

        # Candidate state: weights for the reset state, the input, and a bias
        self.s_add_weight = nn.Parameter(get_weight(hidden_sz, hidden_sz))
        self.inp_add_weight = nn.Parameter(get_weight(input_sz, hidden_sz))
        self.add_bias = nn.Parameter(torch.zeros(hidden_sz))

        self.tanh = nn.Tanh()
        self.sig = nn.Sigmoid()
        self.hidden_sz = hidden_sz

    def forward(self, input, state=None):
        bs, fs, _ = input.shape
        # A GRU carries a single state tensor, unlike the LSTM's (hidden, cell) pair
        if state is None: state = torch.zeros(bs, self.hidden_sz)
        hiddens = []

        for feat in range(fs):
            inp = input[:,feat,:]

            # Reset gate: how much of the previous state feeds the candidate
            reset_scale = self.sig(state @ self.s_rescale_weight + inp @ self.inp_rescale_weight + self.rescale_bias)
            reset = state * reset_scale

            # Update gate: interpolates between the old state and the candidate
            update = self.sig(state @ self.s_update_weight + inp @ self.inp_update_weight + self.update_bias)

            # Candidate state, pre-scaled by the update gate
            add = self.tanh(reset @ self.s_add_weight + inp @ self.inp_add_weight + self.add_bias) * update

            # h_t = (1 - z_t) * h_{t-1} + z_t * h~_t
            state = state * (1 - update) + add

            hiddens.append(state.unsqueeze(1))

        # Stack per-timestep states into (bs, seq_len, hidden_sz)
        hiddens = torch.cat(hiddens, dim=1)

        return hiddens, state
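Quick shape check: a batch of 16 sequences, 8 timesteps each with 16 input features, should produce outputs with 10 hidden units per timestep.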
gru = GRU(16, 10)
gru(torch.randn(16,8,16))[0].shape
torch.Size([16, 8, 10])
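For comparison, PyTorch's built-in nn.GRU with batch_first=True uses the same (batch, seq, features) layout and returns outputs of the same shape:

import torch
import torch.nn as nn

# Reference check against the built-in GRU (same hyperparameters as above)
ref = nn.GRU(input_size=16, hidden_size=10, batch_first=True)
out, h_n = ref(torch.randn(16, 8, 16))
out.shape
torch.Size([16, 8, 10])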
!python notebook2script.py GRU.ipynb
Converted GRU.ipynb to ModernArchitecturesFromPyTorch/nb_GRU.py