Implementing a GRU, a simplified LSTM-style cell with one fewer gate: it keeps only a reset gate and an update gate and a single hidden state (no separate cell state), so it has fewer parameters and is cheaper to train.
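For reference, these are the standard GRU equations that the code below implements (r is the reset gate, z is the update gate, h~ is the candidate state):

    r_t  = sigmoid(x_t @ W_r + h_{t-1} @ U_r + b_r)
    z_t  = sigmoid(x_t @ W_z + h_{t-1} @ U_z + b_z)
    h~_t = tanh(x_t @ W_h + (r_t * h_{t-1}) @ U_h + b_h)
    h_t  = (1 - z_t) * h_{t-1} + z_t * h~_t

In the code, reset_scale plays the role of r, update of z, and add of z * h~_t.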
import torch
import torch.nn as nn

class GRU(nn.Module):
    def __init__(self, input_sz, hidden_sz):
        "GRU network"
        super().__init__()
        # Reset ("rescale") gate parameters; get_weight is assumed to be the
        # weight-initialisation helper defined earlier in the notebook.
        self.s_rescale_weight = nn.Parameter(get_weight(hidden_sz, hidden_sz))
        self.inp_rescale_weight = nn.Parameter(get_weight(input_sz, hidden_sz))
        self.rescale_bias = nn.Parameter(torch.zeros(hidden_sz))
        # Update gate parameters
        self.s_update_weight = nn.Parameter(get_weight(hidden_sz, hidden_sz))
        self.inp_update_weight = nn.Parameter(get_weight(input_sz, hidden_sz))
        self.update_bias = nn.Parameter(torch.zeros(hidden_sz))
        # Candidate-state ("add") parameters
        self.s_add_weight = nn.Parameter(get_weight(hidden_sz, hidden_sz))
        self.inp_add_weight = nn.Parameter(get_weight(input_sz, hidden_sz))
        self.add_bias = nn.Parameter(torch.zeros(hidden_sz))
        self.tanh = nn.Tanh()
        self.sig = nn.Sigmoid()
        self.hidden_sz = hidden_sz

    def forward(self, input, state=None):
        bs, seq_len, _ = input.shape
        # Unlike an LSTM, a GRU carries a single hidden state, so there is no
        # separate cell state to initialise or return.
        if state is None:
            state = torch.zeros(bs, self.hidden_sz)
        hiddens = []
        for t in range(seq_len):
            inp = input[:, t, :]
            # Reset gate: how much of the previous state feeds the candidate
            reset_scale = self.sig(state @ self.s_rescale_weight + inp @ self.inp_rescale_weight + self.rescale_bias)
            reset = state * reset_scale
            # Update gate: how much of the state is replaced by the candidate
            update = self.sig(state @ self.s_update_weight + inp @ self.inp_update_weight + self.update_bias)
            add = self.tanh(reset @ self.s_add_weight + inp @ self.inp_add_weight + self.add_bias) * update
            # h_t = (1 - z_t) * h_{t-1} + z_t * h~_t
            state = state * (1 - update) + add
            hiddens.append(state.unsqueeze(1))
        hiddens = torch.cat(hiddens, dim=1)
        return hiddens, state
gru = GRU(16, 10)
gru(torch.randn(16,8,16))[0].shape
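As a quick sanity check, the output shape can be compared against PyTorch's built-in nn.GRU (a minimal sketch; with batch_first=True the built-in layer also takes (batch, seq, features) input):

import torch
import torch.nn as nn

ref = nn.GRU(input_size=16, hidden_size=10, batch_first=True)
x = torch.randn(16, 8, 16)
out_ref, h_ref = ref(x)      # out_ref: (16, 8, 10); h_ref: (1, 16, 10)
out_ours, h_ours = gru(x)    # out_ours: (16, 8, 10); h_ours: (16, 10)
assert out_ref.shape == out_ours.shape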
!python notebook2script.py GRU.ipynb