Implementing batchnorm regularization
Testing CheckGrad on a linear model
run = Runner(get_learner(), [CheckGrad()])
run.model
run.fit(1, 0.1)
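Batch normalization normalizes each channel of the activations with statistics computed over the batch and spatial dimensions, then applies a learned per-channel scale and shift:

$$\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}, \qquad y = \gamma\,\hat{x} + \beta$$

Here $\gamma$ is the `multiplier` and $\beta$ the `adder` in the class below; $\mu$ and $\sigma^2$ are the batch mean and variance at training time, and running averages of them at inference time.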
#export
class Batchnorm(Module):
    "Module for applying batch normalization"
    def __init__(self, nf, mom=0.1, eps=1e-6):
        super().__init__()
        self.nf = nf
        self.mom, self.eps = mom, eps
        self.multiplier = Parameter(torch.ones(1, nf, 1, 1))
        self.adder = Parameter(torch.zeros(1, nf, 1, 1))
        # Running statistics, used to normalize at inference time
        self.means = torch.zeros(1, nf, 1, 1)
        self.vars = torch.ones(1, nf, 1, 1)
    def update(self, xb):
        # Per-channel mean and (biased) variance of the batch; update running
        # averages, weighting the new batch by mom (PyTorch convention)
        mean = xb.mean(dim=(0,2,3), keepdim=True)
        var = xb.var(dim=(0,2,3), keepdim=True, unbiased=False)
        self.means = (1-self.mom) * self.means + self.mom * mean
        self.vars = (1-self.mom) * self.vars + self.mom * var
        return mean, var
    def forward(self, xb):
        if not self.learner.model.training:
            # At inference, normalize with the running statistics
            normed = (xb - self.means) / (self.vars + self.eps).sqrt()
            return normed * self.multiplier.d + self.adder.d
        # At training time, normalize with the batch statistics, keeping them for the backward pass
        self.mean, self.var = self.update(xb)
        self.after_stats = (xb - self.mean) / (self.var + self.eps).sqrt()
        self.after_scaling = self.after_stats * self.multiplier.d + self.adder.d
        return self.after_scaling
    def bwd(self, out, inp):
        # Number of elements averaged per channel: batch size times spatial size
        count = out.g.shape[0] * out.g.shape[2] * out.g.shape[3]
        self.multiplier.update((out.g * self.after_stats).sum(dim=(0,2,3), keepdim=True))
        self.adder.update(out.g.sum(dim=(0,2,3), keepdim=True))
        var_factor = 1./(self.var + self.eps).sqrt()
        mean_factor = inp - self.mean
        delta_norm = out.g * self.multiplier.d
        delta_var = (delta_norm * mean_factor * -0.5 * (self.var + self.eps)**(-3/2)).sum(dim=(0,2,3), keepdim=True)
        delta_mean = (delta_norm * -var_factor).sum(dim=(0,2,3), keepdim=True) \
                     + delta_var * (-2 * mean_factor).sum(dim=(0,2,3), keepdim=True) / count
        inp.g = (delta_norm * var_factor) + (delta_mean / count) + (delta_var * 2 / count * mean_factor)
    def __repr__(self): return f'Batchnorm({self.nf})'
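For reference, `bwd` implements the standard batchnorm backward pass. Writing $g = \partial L/\partial y$ and letting $m$ be the number of elements per channel (`count` in the code), the chain rule through the normalization gives:

$$\frac{\partial L}{\partial \gamma} = \sum g\,\hat{x}, \qquad \frac{\partial L}{\partial \beta} = \sum g, \qquad \frac{\partial L}{\partial \hat{x}} = g\,\gamma$$

$$\frac{\partial L}{\partial \sigma^2} = \sum \frac{\partial L}{\partial \hat{x}}\,(x-\mu)\left(-\tfrac{1}{2}\right)(\sigma^2+\epsilon)^{-3/2}, \qquad \frac{\partial L}{\partial \mu} = -\sum \frac{\partial L}{\partial \hat{x}}\,\frac{1}{\sqrt{\sigma^2+\epsilon}} - \frac{\partial L}{\partial \sigma^2}\,\frac{2}{m}\sum(x-\mu)$$

$$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial \hat{x}}\,\frac{1}{\sqrt{\sigma^2+\epsilon}} + \frac{\partial L}{\partial \sigma^2}\,\frac{2(x-\mu)}{m} + \frac{\partial L}{\partial \mu}\,\frac{1}{m}$$

(The $\sum(x-\mu)$ term is zero in exact arithmetic, since $\mu$ is the batch mean, but it is kept for completeness.)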
learn = get_conv_learner()
learn.model.learner
run = get_conv_runner([CheckGrad()])
run.model
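As an extra sanity check beyond `CheckGrad`, the forward and backward formulas above can be compared against PyTorch autograd. This is a minimal standalone sketch, independent of the notebook's `Module`/`Parameter` framework; the shapes, seed, and tolerance are arbitrary choices, and the scale/shift are left at their initial values $\gamma=1$, $\beta=0$ so they drop out:

import torch

torch.manual_seed(0)
N, C, H, W, eps = 8, 4, 6, 6, 1e-6
x = torch.randn(N, C, H, W, requires_grad=True)

# Forward with batch statistics (biased variance, as in the class above)
mean = x.mean(dim=(0,2,3), keepdim=True)
var = x.var(dim=(0,2,3), keepdim=True, unbiased=False)
xhat = (x - mean) / (var + eps).sqrt()

# Autograd reference gradient for an arbitrary upstream gradient g
g = torch.randn(N, C, H, W)
xhat.backward(g)

# Manual backward, mirroring bwd() with multiplier=1
count = N * H * W
var_factor = 1. / (var + eps).sqrt()
mean_factor = (x - mean).detach()
delta_var = (g * mean_factor * -0.5 * (var + eps)**(-3/2)).sum(dim=(0,2,3), keepdim=True)
delta_mean = (g * -var_factor).sum(dim=(0,2,3), keepdim=True) \
             + delta_var * (-2 * mean_factor).sum(dim=(0,2,3), keepdim=True) / count
dx = g * var_factor + delta_mean / count + delta_var * 2 / count * mean_factor

print(torch.allclose(x.grad, dx.detach(), atol=1e-5))  # expected: True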