Implementing backward and forward passes to train a simple fully connected network
Please see https://arxiv.org/abs/1502.01852 for more details and an explanation of Kaiming initialisation.
# Sample call to get_weight (defined elsewhere); presumably builds a 4x5 weight
# tensor — the meaning of the third argument is not visible here, TODO confirm.
get_weight(4,5, True)
# Small constant that cross_entropy adds to its input before taking the log,
# which keeps the log finite when an entry is exactly 0.
eps = 1e-9
#hide
def cross_entropy(xb, targ):
    """Cross entropy loss.

    Adds the module-level `eps` to `xb`, takes the log, picks for each row
    the entry at that row's target index, and returns the negated mean of
    the picked values.

    `xb` is presumably a (batch, classes) tensor of probabilities and `targ`
    a tensor of class indices — confirm against the caller.
    """
    row_idx = range(targ.shape[0])
    log_vals = (xb + eps).log()
    # One log value per row: the one at that row's target class.
    picked = log_vals[row_idx, targ.long()]
    return -(picked.mean())
def mse_grad(inp, targ):
    """Grad for mean squared error.

    Squeezes the last dimension of `inp`, subtracts `targ`, restores the
    trailing dimension, and scales by 2 / inp.shape[0]. The result is stored
    on `inp` as the attribute `g`; nothing is returned.
    """
    n_rows = inp.shape[0]
    residual = inp.squeeze(-1) - targ
    inp.g = 2. * residual.unsqueeze(-1) / n_rows
def rel_grad(inp, out):
    """Grad for ReLU layer.

    Passes `out.g` through wherever `inp` is strictly positive and zeroes it
    elsewhere. The result is stored on `inp` as the attribute `g`; nothing is
    returned.
    """
    # 1.0 where the input was strictly positive, 0.0 otherwise.
    positive_mask = (inp > 0).float()
    inp.g = out.g * positive_mask
def lin_grad(inp, out, w, b):
    """Grad for linear layer.

    Reads the upstream gradient from `out.g` and stores the gradients with
    respect to the layer's input, weight and bias on `inp.g`, `w.g` and
    `b.g`. Nothing is returned.

    The matrix products below imply `inp` is (batch, n_in), `w` is
    (n_in, n_out) and `out.g` is (batch, n_out).
    """
    inp.g = out.g @ w.t()
    # Same value as summing the per-sample outer products
    # (inp.unsqueeze(-1) * out.g.unsqueeze(1)).sum(0), but computed as one
    # matrix product so no (batch, n_in, n_out) temporary is materialised.
    w.g = inp.t() @ out.g
    b.g = out.g.sum(0)
def softmax_cross_grad(inp, targ):
    """Grad for softmax and cross entropy loss.

    One-hot encodes `targ`, subtracts it from `softmax(inp)`, and divides by
    the number of rows. The result is stored on `inp` as the attribute `g`;
    nothing is returned.

    The number of classes is taken from the last dimension of `inp` rather
    than being fixed at 10, so the function works for any class count while
    giving the same result for 10-class inputs.
    """
    n_classes = inp.shape[-1]
    one_hot_targ = torch.nn.functional.one_hot(targ.to(torch.int64), n_classes)
    probs = softmax(inp)
    inp.g = (probs - one_hot_targ) / one_hot_targ.shape[0]
Putting it all together...
def full_pass(xb, targ):
    """Run one forward pass and the matching manual backward pass.

    Uses the module-level parameters `w1`, `b1`, `w2`, `b2`. Returns the
    cross entropy loss; the gradients are left in the `.g` attributes set by
    the grad functions.
    """
    # Forward: linear -> relu -> linear -> softmax -> cross entropy.
    hidden = linear(xb, w1, b1)
    hidden_act = relu(hidden)
    logits = linear(hidden_act, w2, b2)
    probs = softmax(logits)
    loss = cross_entropy(probs, targ)

    # Backward: visit the layers in reverse so each step finds the `.g` of
    # its output already populated.
    softmax_cross_grad(logits, targ)
    lin_grad(hidden_act, logits, w2, b2)
    rel_grad(hidden, hidden_act)
    lin_grad(xb, hidden, w1, b1)
    return loss
# Run the hand-written forward and backward pass on xt / yt (defined elsewhere);
# full_pass returns the loss and leaves the gradients in `.g` attributes.
loss = full_pass(xt, yt)
# Bare expression: displays the loss when run as a notebook cell.
loss
# Build the same network from layer objects (Linear, ReLU, CrossSoft and Model
# are defined elsewhere). Presumably 784 -> 50 -> 10 units.
# NOTE(review): the meaning of Linear's third argument is not visible here — confirm.
layers = [Linear(784,50,True), ReLU(), Linear(50,10, False)]
loss_func = CrossSoft()
model = Model(layers)
# Forward pass, then the manual backward pass through the loss and the model.
loss = loss_func(model(xt),yt)
loss_func.backward()
model.backward()
# Snapshot the manually computed gradients (read from the `.g` attributes) so
# they can be compared after the parameters are replaced below.
w1g = model.layers[0].w.g.clone()
w2g = model.layers[2].w.g.clone()
b1g = model.layers[0].b.g.clone()
b2g = model.layers[2].b.g.clone()
ig = xt.g.clone()
# Replace the input and the parameters with clones that track gradients, so
# the same computation can be differentiated automatically (presumably by
# PyTorch autograd).
xt = xt.clone().requires_grad_(True)
model.layers[0].w = model.layers[0].w.clone().requires_grad_(True)
model.layers[0].b = model.layers[0].b.clone().requires_grad_(True)
model.layers[2].w = model.layers[2].w.clone().requires_grad_(True)
model.layers[2].b = model.layers[2].b.clone().requires_grad_(True)
# Time the forward pass and the automatic backward pass (IPython %time magic).
%time loss = loss_func(model(xt), yt)
%time loss.backward()
# Check that each manual gradient matches the corresponding `.grad` value.
test_near(w2g, model.layers[2].w.grad)
test_near(b2g, model.layers[2].b.grad)
test_near(w1g, model.layers[0].w.grad)
test_near(b1g, model.layers[0].b.grad)
test_near(ig, xt.grad)