Implementing batch normalization (batchnorm)
   
    
    
    
    
   
Testing CheckGrad on a linear model
# Build a Runner around the learner returned by get_learner(), with a
# CheckGrad callback attached (presumably verifies manual gradients -- confirm).
run = Runner(get_learner(), [CheckGrad()])
# Bare expression: only displays the model when executed as a notebook cell.
run.model
# NOTE(review): assumed to be fit(epochs=1, lr=0.1); confirm against Runner.fit.
run.fit(1, 0.1)
#export
class Batchnorm(Module):
    """Module for applying batch normalization.

    Every statistic is reduced over dims (0, 2, 3), and the learnable
    parameters have shape (1, nf, 1, 1). Inputs are therefore presumably
    (batch, channels, height, width).

    In training mode, each channel is normalized with the mean and (biased)
    variance of the current batch, and the running statistics are updated.
    In evaluation mode, the running statistics are used instead. The result
    is then scaled by ``multiplier`` and shifted by ``adder``.

    Args:
        nf: number of channels.
        mom: weight of the newest batch when updating the running statistics,
            i.e. ``running = (1 - mom) * running + mom * batch``. This matches
            the ``momentum`` convention of ``torch.nn.BatchNorm2d``.
        eps: added to the variance before the square root, for stability.
    """
    def __init__(self, nf, mom=0.1, eps=1e-6):
        super().__init__()
        self.nf = nf
        self.mom, self.eps = mom, eps
        self.multiplier = Parameter(torch.ones(1, nf, 1, 1))
        self.adder = Parameter(torch.zeros(1, nf, 1, 1))
        # Running statistics, used only in evaluation mode (not trained).
        self.means = torch.zeros(1, nf, 1, 1)
        self.vars = torch.ones(1, nf, 1, 1)

    def update(self, xb):
        """Return ``(mean, var)`` of ``xb`` per channel and update the running stats."""
        dims = (0, 2, 3)
        mean = xb.mean(dim=dims, keepdim=True)
        # Biased variance (divide by N*H*W). Both the normalization and the
        # closed-form gradient in `bwd` assume this estimator.
        var = ((xb - mean) ** 2).mean(dim=dims, keepdim=True)
        self.means = (1 - self.mom) * self.means + self.mom * mean
        self.vars = (1 - self.mom) * self.vars + self.mom * var
        return mean, var

    def forward(self, xb):
        # `self.learner` is set outside this class (not visible here).
        if not self.learner.model.training:
            normed = (xb - self.means) / (self.vars + self.eps).sqrt()
            return normed * self.multiplier.d + self.adder.d
        mean, var = self.update(xb)
        # Keep the batch variance so `bwd` differentiates through exactly the
        # normalization applied here, not the running averages.
        self.batch_var = var
        self.after_stats = (xb - mean) / (var + self.eps).sqrt()
        self.after_scaling = self.after_stats * self.multiplier.d + self.adder.d
        return self.after_scaling

    def bwd(self, out, inp):
        """Back-propagate ``out.g`` into ``inp.g`` and into the scale/shift parameters.

        Relies on the batch statistics recorded by the last training-mode
        `forward` call.
        """
        dims = (0, 2, 3)
        g = out.g
        x_hat = self.after_stats
        inv_std = 1. / (self.batch_var + self.eps).sqrt()
        # Compute the gradient w.r.t. the normalized input before the
        # parameter updates below, in case `update` modifies `.d` in place.
        d_norm = g * self.multiplier.d
        # Chain rule through the batch mean and variance, in closed form:
        #   dx = (d_norm - mean(d_norm) - x_hat * mean(d_norm * x_hat)) / std
        # where the means are over the N*H*W elements of each channel.
        inp.g = inv_std * (
            d_norm
            - d_norm.mean(dim=dims, keepdim=True)
            - x_hat * (d_norm * x_hat).mean(dim=dims, keepdim=True)
        )
        self.multiplier.update((g * x_hat).sum(dim=dims, keepdim=True))
        self.adder.update(g.sum(dim=dims, keepdim=True))

    def __repr__(self): return f'Batchnorm({self.nf})'
# Build the convolutional learner returned by get_conv_learner().
learn = get_conv_learner()
# Bare expression (notebook display): shows the `learner` attribute on the
# model. Batchnorm.forward reads `self.learner.model.training`, so modules
# presumably carry this back-reference -- TODO confirm where it is assigned.
learn.model.learner
# Build a convolutional runner with a CheckGrad callback attached.
run = get_conv_runner([CheckGrad()])
# Bare expression: only displays the model when executed as a notebook cell.
run.model