Implementing the forward and backward passes for convolution and pooling layers, with support for padding and stride
Tensor Shaping Modules¶
class Reshape(Module):
    "Module to reshape input tensor into tensor of (bs, `channels`, `size1`, `size2`)"
    def __init__(self, channels, size1, size2):
        super().__init__()
        self.size1 = size1
        self.size2 = size2
        self.channels = channels
    def forward(self, xb): return xb.view(-1, self.channels, self.size1, self.size2)
    def bwd(self, out, inp):
        inp.g = out.g.reshape(-1, self.channels*self.size1*self.size2)
    def __repr__(self): return f'Reshape({self.channels}, {self.size1}, {self.size2})'
class Flatten(Module):
    "Module to flatten tensor input into shape (bs, rest)"
    def __init__(self):
        super().__init__()
    def forward(self, xb):
        self.size1 = xb.shape[2]
        self.size2 = xb.shape[3]
        self.channels = xb.shape[1]
        return xb.view(xb.shape[0], -1)
    def bwd(self, out, inp): inp.g = out.g.view(-1, self.channels, self.size1, self.size2)
    def __repr__(self): return 'Flatten()'
Initialization¶
Please see https://arxiv.org/abs/1502.01852 for more details and explanation on Kaiming initialisation. The idea is to initialise the weights so that the activations keep a mean close to 0 and a standard deviation close to 1 as they pass through the network, which prevents them from vanishing or exploding.
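For a conv layer feeding a ReLU, Kaiming initialisation draws the weights from a normal distribution with standard deviation sqrt(2 / fan_in), where the factor 2 compensates for ReLU zeroing roughly half of the activations. The `get_weight` and `get_stats` helpers used below are defined earlier in the notebook; a minimal sketch of the same idea (hypothetical names, not the notebook's actual helpers) could look like this:

import math
import torch

def kaiming_filters(n_out, n_in, kh, kw):
    "Hypothetical sketch: Kaiming-normal filters for a conv layer followed by a ReLU."
    fan_in = n_in * kh * kw                           # inputs contributing to each output pixel
    return torch.randn(n_out, n_in, kh, kw) * math.sqrt(2. / fan_in)

def stats(x): return x.mean(), x.std()                # quick check of activation statistics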
import torch.nn as nn
from torch.nn import init
test_conv = nn.Conv2d(1, 8, 5)
xt, _, _, _ = get_mnist()
xt = xt.view(-1, 1, 28, 28)
test_conv.weight.shape
Without proper initialization
get_stats(test_conv(xt))
Using my own initialization (the mean is around 0.5 rather than 0 because the ReLU discards the negative activations)
test_conv.weight = nn.Parameter(get_weight(8, 1, 5, 5, 0, True))
get_stats(relu(test_conv(xt)))
PyTorch's Kaiming initialization
init.kaiming_normal_(test_conv.weight, a=0.)
get_stats(relu((test_conv(xt))))
Padding¶
The padding module can use either of the following modes:

Constant
: Adds a border of constant pixels with value `value` around the image

Reflection
: Repeats the outermost pixel values of the image
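The `Padding` module itself is defined earlier in the notebook; the code below only relies on it being callable and exposing a `size` and a writable `value` attribute. A rough sketch of such a module, assuming `torch.nn.functional.pad` under the hood (the actual implementation may differ):

import torch.nn.functional as F

class Padding():
    "Hypothetical sketch: pads the last two dimensions by `size` pixels on each side."
    def __init__(self, size, mode='constant', value=0.):
        self.size, self.mode, self.value = size, mode, value
    def __call__(self, x):
        p = [self.size] * 4                                    # (left, right, top, bottom)
        if self.mode == 'constant':
            return F.pad(x, p, mode='constant', value=self.value)
        return F.pad(x, p, mode='reflect')                     # reflection padding
    def __repr__(self): return f'Padding(size={self.size}, mode={self.mode})'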
Forward¶
def convolve(weight, filts, filts_bias, stride=1, padding=None):
    n_filt, depth_f, f_w, f_h = filts.shape
    bs, depth_im, im_w, im_h = weight.shape
    if padding is not None:
        weight = padding(weight)
        p_s = padding.size
    else: p_s = 0
    _, _, p_w, p_h = weight.shape
    assert depth_f == depth_im
    final = torch.zeros(bs, n_filt, int((im_w + 2*p_s - f_w)/stride)+1, int((im_h + 2*p_s - f_h)/stride)+1)
    for j in range(0, p_w-f_w+1, stride): # passes along the first spatial axis
        for k in range(0, p_h-f_h+1, stride): # passes along the second spatial axis
            # broadcast each window against every filter, sum over depth and both spatial axes, add the bias
            final[:,:,j//stride,k//stride] = (weight[:,:,j:j+f_w,k:k+f_h].unsqueeze(1)*filts).sum(dim=-1).sum(dim=-1).sum(dim=-1) + filts_bias.unsqueeze(0)
    return final
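Each spatial dimension of the output has size ⌊(size + 2·padding − filter)/stride⌋ + 1. For the test below (28×28 input, padding 2, 6×6 filters, stride 3) that is ⌊(28 + 4 − 6)/3⌋ + 1 = 9, so the result should have shape (10, 4, 9, 9).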
Testing against PyTorch's convolution:
weight = torch.randn(10, 16, 28, 28)
bs = weight.shape[0]
im_h = im_w = weight.shape[2]
w_dim = weight.shape[1]
weight.shape
stride = 3
pad_amount = 2
f_dim = 4
f_w = 6
f_h = f_w
test_conv = nn.Conv2d(w_dim, f_dim, f_w, padding=pad_amount, stride=stride)
f = test_conv.weight
test_conv.weight.shape
b = test_conv.bias
b.shape
pt_res = test_conv(weight)
pt_res.shape
pad = Padding(size=pad_amount)
pad
res = convolve(weight, f, b, stride, pad)
test_near(res, pt_res)
Backward¶
def conv_back(out, inp, weight, bias, stride=1, padding=None):
    "Backward pass: uses the output gradient `out.g` to compute the gradients of the loss with respect to the input `inp`, the filter `weight` and the `bias`."
    dZ = out.g
    (A_prev, W, b, stride) = inp, weight.d, bias.d, stride
    # Retrieve dimensions from A_prev's shape
    (m, n_C_prev, n_W_prev, n_H_prev) = A_prev.shape
    # Retrieve dimensions from W's shape
    (n_C, n_C_prev, f, f) = W.shape
    # Retrieve dimensions from dZ's shape
    (m, n_C, n_W, n_H) = dZ.shape
    # Initialize dA_prev, dW, db with the correct shapes
    dA_prev = torch.zeros((m, n_C_prev, n_W_prev, n_H_prev))
    dW = torch.zeros((n_C, n_C_prev, f, f))
    db = torch.zeros((n_C, 1, 1, 1))
    # Pad A_prev and dA_prev
    if padding is not None:
        A_prev = padding(A_prev)
        dA_prev = padding(dA_prev)
    for h in range(n_H): # loop over vertical axis of the output volume
        for w in range(n_W): # loop over horizontal axis of the output volume
            # Find the corners of the current "slice"
            vert_start = h*stride
            vert_end = vert_start + f
            horiz_start = w*stride
            horiz_end = horiz_start + f
            # Use the corners to define the slice from a_prev_pad
            a_slice = A_prev[:, :, horiz_start:horiz_end, vert_start:vert_end]
            ezdz = dZ[:, :, w, h].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
            # Update gradients for the filter, bias and input
            dA_prev[:, :, horiz_start:horiz_end, vert_start:vert_end] += (W * ezdz).sum(dim=1)
            dW += (a_slice.unsqueeze(1)*ezdz).sum(dim=0)
            db += dZ[:, :, w, h].sum(dim=0).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
    # Strip the padded border so the input gradient matches the original input shape
    if padding is not None: dA_prev = dA_prev[:, :, padding.size:-padding.size, padding.size:-padding.size]
    weight.update(dW)
    bias.update(db.view(-1))
    inp.g = dA_prev
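The `Conv` module used in the training section below is defined earlier in the notebook and presumably just wraps these two functions. A rough sketch of how the pieces could fit together, assuming the notebook's `Parameter` class can be constructed from a tensor and exposes it as `.d` (the exact initialization and the role of `leak` are guesses, not the actual implementation):

import math

class Conv(Module):
    "Hypothetical sketch of a conv layer built on `convolve`/`conv_back`."
    def __init__(self, n_in, n_out, ks=3, stride=1, leak=1., padding=None):
        super().__init__()
        self.stride, self.padding = stride, padding
        # Kaiming-style init; the gain depends on the leak of the following activation (assumed)
        gain = math.sqrt(2. / (1 + leak**2))
        self.filters = Parameter(torch.randn(n_out, n_in, ks, ks) * gain / math.sqrt(n_in*ks*ks))
        self.bias    = Parameter(torch.zeros(n_out))
    def forward(self, xb): return convolve(xb, self.filters.d, self.bias.d, self.stride, self.padding)
    def bwd(self, out, inp): conv_back(out, inp, self.filters, self.bias, self.stride, self.padding)
    def __repr__(self): return f'Conv({self.filters.d.shape[1]}, {self.filters.d.shape[0]})'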
Training¶
def get_basic_conv_model(lr):
    "Helper function to get a basic conv model up and running"
    pad1 = Padding(1)
    pad2 = Padding(1)
    model = SequentialModel(Reshape(1, 28, 28),
                            Conv(1, 16, 5, stride=3, leak=0, padding=pad1),
                            ReLU(),
                            Conv(16, 8, 5, stride=2, leak=1, padding=pad2),
                            Flatten(),
                            Linear(128, 10, False))
    loss_func = CrossEntropy()
    optimizer = Optimizer(model.parameters(), lr)
    return model, optimizer, loss_func
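The 128 input features of the final Linear layer come from tracing the shapes: a 28×28 input through a 5×5 conv with padding 1 and stride 3 gives ⌊(28 + 2 − 5)/3⌋ + 1 = 9, i.e. 16×9×9; the second conv (5×5, padding 1, stride 2) gives ⌊(9 + 2 − 5)/2⌋ + 1 = 4, i.e. 8×4×4, which flattens to 128.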
Testing vs PyTorch
m, o, l = get_basic_conv_model(0.1)
loss_func = l
xt, yt, _, _ = get_mnist()
sx, sy = xt[:100], yt[:100]
m
loss = loss_func(m(sx), sy)
loss_func.backward()
m.backward()
sxg = sx.g.clone()
cw1g = m.layers[1].filters.grad.clone()
cb1g = m.layers[1].bias.grad.clone()
cw2g = m.layers[3].filters.grad.clone()
cb2g = m.layers[3].bias.grad.clone()
lw = m.layers[5].w.grad.clone()
lb = m.layers[5].b.grad.clone()
sx2 = sx.clone().requires_grad_(True)
m.layers[1].filters.d.requires_grad_(True)
m.layers[1].bias.d.requires_grad_(True)
m.layers[3].filters.d.requires_grad_(True)
m.layers[3].bias.d.requires_grad_(True)
m.layers[5].w.d.requires_grad_(True)
m.layers[5].b.d.requires_grad_(True)
loss = loss_func(m(sx2), sy)
loss.backward()
test_near(lb, m.layers[5].b.d.grad)
test_near(lw, m.layers[5].w.d.grad)
test_near(cb2g, m.layers[3].bias.d.grad)
test_near(cw2g, m.layers[3].filters.d.grad)
test_near(cb1g, m.layers[1].bias.d.grad)
test_near(cw1g, m.layers[1].filters.d.grad)
m = [Reshape(1, 28, 28),
Conv(1, 4, 5, stride=1, leak=0),
ReLU(),
Conv(4, 1, 5, stride=1, leak=1),
Flatten(),
Linear(20*20, 10, False)]
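Here the shapes trace to 20×20: two unpadded 5×5, stride-1 convolutions shrink 28 → 24 → 20, and the second conv has a single output channel, hence 20*20 input features for the Linear layer.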
m, o, l = get_model(0.1, m)
train,valid = get_small_datasets()
fit(2, m, o, l, train, valid)
Pooling Layers¶
Forward¶
def pool(inp, ks, stride, padding=None, operation=max_pool):
    if padding is not None:
        # pad max-pooling with a value smaller than anything in the input, so padded pixels never win the max
        if operation == max_pool: padding.value = inp.min() - 1
        inpp = padding(inp)
        p_s = padding.size
    else:
        inpp = inp
        p_s = 0
    bs, nc, h, w = inp.shape
    # output size per spatial dim: (size + 2*pad - ks)//stride + 1
    nw, nh = int((w + 2*p_s - ks) / stride)+1, int((h + 2*p_s - ks) / stride)+1
    out = torch.zeros(bs, nc, nw, nh)
    for i in range(nh):
        for j in range(nw):
            window = inpp[:,:,j*stride:j*stride+ks, i*stride:i*stride+ks]
            out[:,:,j,i] = operation(window)
    return out
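`pool` assumes `max_pool` and `avg_pool` (defined earlier in the notebook) reduce a (bs, nc, ks, ks) window to a (bs, nc) tensor. Minimal sketches of what they might look like:

def max_pool(window): return window.max(dim=-1)[0].max(dim=-1)[0]  # max over the ks×ks window
def avg_pool(window): return window.mean(dim=-1).mean(dim=-1)      # mean over the ks×ks window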
Testing Pooling Forward Passes
test_t = torch.randn(16, 3, 28, 28)
ks = 3
stride = 1
pad_size = 0
py_max = nn.MaxPool2d(ks, stride=stride, padding=pad_size)
py_avg = nn.AvgPool2d(ks, stride=stride, padding=pad_size)
py_max_result = py_max(test_t)
py_avg_result = py_avg(test_t)
pad = Padding(size=pad_size) if pad_size > 0 else None
my_avg = pool(test_t, ks, stride, pad, avg_pool)
my_max = pool(test_t, ks, stride, pad, max_pool)
test_near(py_max_result, my_max)
test_near(py_avg_result, my_avg)
Backward¶
def pool_back(out, inp, ks, stride, padding=None, operation=max_pool):
    dZ = out.g
    if padding is not None:
        if operation == max_pool: padding.value = inp.min() - 1
        inpp = padding(inp)
    else: inpp = inp
    bs, nc, nh, nw = dZ.shape
    dA_prev = torch.zeros(inpp.shape)
    shape = (ks, ks)
    for i in range(nh):
        for j in range(nw):
            if operation == max_pool:
                # route the gradient only to the positions that held the window maximum
                window = inpp[:,:,j*stride:j*stride+ks, i*stride:i*stride+ks]
                mask = max_back(window)
                dA_prev[:,:,j*stride:j*stride+ks, i*stride:i*stride+ks] += mask*dZ[:,:,j,i].unsqueeze(-1).unsqueeze(-1)
            elif operation == avg_pool:
                # spread the gradient evenly over the whole window
                dz = dZ[:,:,j,i]
                dA_prev[:,:,j*stride:j*stride+ks, i*stride:i*stride+ks] += average_back(dz, shape)
    # strip the padded border so the gradient matches the original input shape
    if padding is not None: dA_prev = dA_prev[:, :, padding.size:-padding.size, padding.size:-padding.size]
    inp.g = dA_prev
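Likewise, `max_back` and `average_back` are defined earlier in the notebook: the first returns a mask selecting the position of each window's maximum, the second spreads the output gradient evenly over the window. Hedged sketches, not the actual implementations:

def max_back(window):
    "Mask that is 1 where each (bs, nc) slice of the window attains its maximum."
    mx = window.max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[0]
    return (window == mx).float()

def average_back(dz, shape):
    "Distribute the (bs, nc) output gradient evenly across a ks×ks window."
    ks1, ks2 = shape
    return dz.unsqueeze(-1).unsqueeze(-1).expand(*dz.shape, ks1, ks2) / (ks1*ks2)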
Module¶
layers = [Reshape(1, 28, 28),
Conv(1, 4),
Pool(avg_pool, ks=2, stride=1),
Conv(4, 1),
Flatten(),
Linear(529, 10, False)]
m, o, l = get_model(0.1, layers)
train,valid = get_small_datasets()
m
loss = loss_func(m(sx), sy)
loss_func.backward()
m.backward()
sxg = sx.g.clone()
cw = m.layers[1].filters.grad.clone()
cb = m.layers[1].bias.grad.clone()
lw = m.layers[5].w.grad.clone()
lb = m.layers[5].b.grad.clone()
sx2 = sx.clone().requires_grad_(True)
m.layers[1].filters.d.requires_grad_(True)
m.layers[1].bias.d.requires_grad_(True)
m.layers[5].w.d.requires_grad_(True)
m.layers[5].b.d.requires_grad_(True)