Implementing the forward and backward passes to train a simple fully connected network
   
    
    
    
    
   
See https://arxiv.org/abs/1502.01852 (He et al., 2015, "Delving Deep into Rectifiers") for more details on, and an explanation of, Kaiming initialisation.
# Demo call of `get_weight` (defined elsewhere in the notebook). As a bare
# expression, its result is only shown when run as a notebook cell.
# NOTE(review): the meaning of the third argument (True) is not visible here;
# given the Kaiming-initialisation reference, presumably it toggles that
# scaling. Confirm against the definition.
get_weight(4,5, True)
# Small constant added to the probabilities before taking the log in
# `cross_entropy`, so that a zero probability does not produce log(0) = -inf.
eps = 1e-9
#hide
def cross_entropy(xb, targ):
    """Cross-entropy loss for a batch of predicted class probabilities.

    `xb` holds per-class probabilities (in this notebook, the softmax output)
    and `targ` holds the target class index for each row. The module-level
    `eps` is added before the log so a zero probability does not give -inf.
    Returns the mean negative log-probability of the target classes.
    """
    rows = range(targ.shape[0])
    log_probs = (xb + eps).log()
    picked = log_probs[rows, targ.long()]
    return -picked.mean()
def mse_grad(inp, targ):
    """Gradient of the mean squared error w.r.t. `inp`, stored in `inp.g`.

    The trailing axis of `inp` is squeezed before comparing with `targ` and
    restored afterwards (presumably `inp` is (n, 1) and `targ` is (n,)).
    """
    batch = inp.shape[0]
    residual = inp.squeeze(-1) - targ
    inp.g = (2. * residual / batch).unsqueeze(-1)
def rel_grad(inp, out):
    """Backpropagate through a ReLU, storing the gradient w.r.t. `inp` in `inp.g`.

    The upstream gradient `out.g` passes through only where the ReLU input was
    strictly positive; everywhere else the gradient is zero.
    """
    passes = inp.gt(0).float()
    inp.g = out.g * passes
def lin_grad(inp, out, w, b):
    """Backpropagate through a linear layer, setting `.g` on `inp`, `w` and `b`.

    Assumes the forward pass computed `out = inp @ w + b` with 2-D `inp`
    (batch, n_in) and `w` (n_in, n_out). This matches `inp.g = out.g @ w.t()`
    below. Reads the upstream gradient from `out.g`.
    """
    inp.g = out.g @ w.t()
    # The batch sum of outer products inp[i] (x) out.g[i] is exactly
    # inp.T @ out.g. The matmul avoids materialising a (batch, n_in, n_out)
    # intermediate, which for a full dataset can need gigabytes of memory.
    w.g = inp.t() @ out.g
    b.g = out.g.sum(0)
def softmax_cross_grad(inp, targ):
    """Gradient of softmax + mean cross-entropy w.r.t. the logits `inp`.

    Stores `(softmax(inp) - one_hot(targ)) / n` in `inp.g`, where `n` is the
    batch size. This is the combined gradient of the batch-mean cross-entropy
    of the softmax probabilities w.r.t. the pre-softmax logits. The small `eps`
    used inside the loss is ignored here, a negligible difference.
    """
    # Take the class count from the logits instead of hard-coding 10, so the
    # one-hot targets always have the same width as `inp`.
    n_classes = inp.shape[-1]
    targ = torch.nn.functional.one_hot(targ.to(torch.int64), n_classes)
    inp_s = softmax(inp)
    inp.g = (inp_s - targ) / targ.shape[0]
Putting it all together: one full forward and backward pass.
def full_pass(xb, targ):
    """Run one forward and one hand-written backward pass of the 2-layer net.

    Uses the module-level parameters `w1, b1, w2, b2`. The backward helpers
    store gradients as `.g` on `xb`, `w1`, `b1`, `w2`, `b2` and on the
    intermediate activations. Returns the cross-entropy loss of the batch.
    """
    # Forward: linear -> ReLU -> linear -> softmax -> loss.
    hidden = linear(xb, w1, b1)
    hidden_act = relu(hidden)
    logits = linear(hidden_act, w2, b2)
    probs = softmax(logits)
    loss = cross_entropy(probs, targ)

    # Backward, visiting the layers in reverse order of the forward pass.
    softmax_cross_grad(logits, targ)
    lin_grad(hidden_act, logits, w2, b2)
    rel_grad(hidden, hidden_act)
    lin_grad(xb, hidden, w1, b1)

    return loss
# Run the hand-written forward/backward pass on `xt`/`yt` (defined elsewhere).
# As a side effect this sets `.g` on `xt` and on the global parameters.
loss = full_pass(xt, yt)
# Bare expression: in a notebook cell this displays the loss value.
loss
# Same kind of network built from layer objects (classes defined elsewhere):
# Linear(784, 50) -> ReLU -> Linear(50, 10), presumably Linear(n_in, n_out, ...).
# NOTE(review): the meaning of Linear's third argument (True / False) is not
# visible here. Confirm against the class definition.
layers = [Linear(784,50,True), ReLU(), Linear(50,10, False)]
loss_func = CrossSoft()
model = Model(layers)
# Forward pass, then the hand-written backward: the loss object first, then
# the model (presumably propagating gradients back through its layers).
loss = loss_func(model(xt),yt)
loss_func.backward()
model.backward()
# Snapshot the manually computed gradients. Indices 0 and 2 are the two Linear
# layers (index 1 is the ReLU). They are compared against autograd below.
w1g = model.layers[0].w.g.clone()
w2g = model.layers[2].w.g.clone()
b1g = model.layers[0].b.g.clone()
b2g = model.layers[2].b.g.clone()
ig  = xt.g.clone()
# Reference check: repeat the same forward/backward with PyTorch autograd.
# Replace the input and parameters with fresh copies that require grad, so
# autograd fills in their `.grad` attributes.
xt = xt.clone().requires_grad_(True)
model.layers[0].w = model.layers[0].w.clone().requires_grad_(True)
model.layers[0].b = model.layers[0].b.clone().requires_grad_(True)
model.layers[2].w = model.layers[2].w.clone().requires_grad_(True)
model.layers[2].b = model.layers[2].b.clone().requires_grad_(True)
# IPython `%time` magic (notebook-only syntax): time the forward pass and the
# autograd backward pass.
%time loss = loss_func(model(xt), yt)
%time loss.backward()
# The manually computed gradients snapshotted earlier should match autograd's.
# `test_near` is defined elsewhere; presumably it asserts approximate equality.
test_near(w2g, model.layers[2].w.grad)
test_near(b2g, model.layers[2].b.grad)
test_near(w1g, model.layers[0].w.grad)
test_near(b1g, model.layers[0].b.grad)
test_near(ig, xt.grad)