Implementing forward and backward passes for convolution and pooling layers, as well as support for padding and stride

Tensor Shaping Modules

class Reshape[source]

Reshape(channels, size1, size2) :: Module

Module to reshape input tensor into tensor of (bs, channels, size1, size2)

class Reshape(Module):
    "Module to reshape input tensor into tensor of (bs, `channels`, `size1`, `size2`)"
    def __init__(self, channels, size1, size2):
        super().__init__()
        self.size1 = size1
        self.size2 = size2
        self.channels = channels

    def forward(self, xb): return xb.view(-1, self.channels, self.size1, self.size2)

    def bwd(self, out, inp): 
        inp.g = out.g.reshape(-1, self.channels*self.size1*self.size2)

    def __repr__(self): return f'Reshape({self.channels}, {self.size1}, {self.size2})'

class Flatten[source]

Flatten() :: Module

Module to flatten tensor input into shape (bs, rest)

class Flatten(Module):
    "Module to flatten tensor input into shape (bs, rest)"
    def __init__(self):
        super().__init__()

    def forward(self, xb): 
        self.size1 = xb.shape[2]
        self.size2 = xb.shape[3]
        self.channels = xb.shape[1]
        return xb.view(xb.shape[0],-1)

    def bwd(self, out, inp): inp.g = out.g.view(-1, self.channels, self.size1, self.size2)

    def __repr__(self): return f'Flatten()'
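
A quick shape check of the two modules together, calling forward directly rather than going through a SequentialModel:

xb = torch.randn(64, 784)          # a batch of flattened MNIST-sized images
r, f = Reshape(1, 28, 28), Flatten()
imgs = r.forward(xb)               # (64, 1, 28, 28)
flat = f.forward(imgs)             # back to (64, 784)
assert flat.shape == xb.shape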

Initialization

get_fan[source]

get_fan(dim1, dim2, dim3, dim4, fan_out)

Get the appropriate fan value based on the receptive field size and number of activations in the previous layer
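
get_fan's implementation isn't shown here, but given the get_weight call further down it most likely multiplies the receptive field size by the channel count on one side or the other; a sketch, assuming the dims are ordered (out channels, in channels, filter height, filter width):

def get_fan(dim1, dim2, dim3, dim4, fan_out):
    "Sketch: fan = channel count times receptive field size"
    receptive_field = dim3 * dim4
    # fan_in counts the incoming channels, fan_out the outgoing ones
    return (dim1 if fan_out else dim2) * receptive_field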

get_gain[source]

get_gain(leak)

Get the proper initialization gain factor based on the leak amount of the following layer. A leak of 1 is used if no ReLU follows
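
The gain presumably follows the standard Kaiming formula for a leaky ReLU, which gives √2 for a plain ReLU (leak 0) and falls back to 1 when the leak is 1, matching the note above; a minimal sketch:

import math

def get_gain(leak):
    "Sketch: Kaiming gain for a (leaky) ReLU"
    return math.sqrt(2.0 / (1 + leak**2))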

get_weight[source]

get_weight(in_d, out_d, relu_after)

Returns weight matrix of size in_d x out_d initialized using Kaiming initialization

Please see https://arxiv.org/abs/1502.01852 for more details on Kaiming initialisation. The idea is to keep the mean and standard deviation of each layer's activations close to 0 and 1 respectively, so that the signal neither vanishes nor explodes as it passes through the network.
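
Putting the two helpers together, get_weight for a conv filter is likely just a scaled normal draw; a sketch, with the argument order taken from the call below (so treat the exact signature as an assumption):

import math

def get_weight(dim1, dim2, dim3, dim4, leak, fan_out):
    "Sketch: Kaiming-initialised tensor of shape (dim1, dim2, dim3, dim4)"
    fan = get_fan(dim1, dim2, dim3, dim4, fan_out)
    gain = get_gain(leak)
    return torch.randn(dim1, dim2, dim3, dim4) * gain / math.sqrt(fan)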

import torch.nn as nn
from torch.nn import init
test_conv = nn.Conv2d(1, 8, 5)
xt, _, _, _ = get_mnist()
xt = xt.view(-1, 1, 28, 28)
test_conv.weight.shape
torch.Size([8, 1, 5, 5])

Without proper initialization

get_stats(test_conv(xt))
Mean: -0.010123089887201786
Std: 0.6414140462875366

Using my own initialization (the mean is around 0.5 rather than 0 because of the ReLU activation)

test_conv.weight = nn.Parameter(get_weight(8, 1, 5, 5, 0, True)) 
get_stats(relu(test_conv(xt)))
Mean: 0.5077283382415771
Std: 0.7732385993003845

PyTorch initialization

init.kaiming_normal_(test_conv.weight, a=0.)
get_stats(relu((test_conv(xt))))
Mean: 0.5470803380012512
Std: 1.1031794548034668

Padding

class Padding[source]

Padding(size=1, mode='constant', value=0)

Adds padding around an image, size pixels wide.

Can use any of the following modes:
Constant: pads with a constant pixel equal to value
Reflection: repeats the outermost pixel values of the image
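
Padding is callable and exposes its size, mode and value attributes (the convolution code below relies on all three). A minimal sketch built on torch.nn.functional.pad, not necessarily the original implementation:

import torch.nn.functional as F

class Padding():
    "Sketch: pads `size` pixels on every side of the last two dimensions"
    def __init__(self, size=1, mode='constant', value=0):
        self.size, self.mode, self.value = size, mode, value

    def __call__(self, x):
        pad = [self.size] * 4                    # (left, right, top, bottom)
        if self.mode == 'constant':
            return F.pad(x, pad, mode='constant', value=self.value)
        return F.pad(x, pad, mode=self.mode)     # reflection-style modes ignore `value`

    def __repr__(self): return f'Padding (Mode: {self.mode}, Size: {self.size}, Value: {self.value})'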

Forward

convolve[source]

convolve(weight, filts, filts_bias, stride=1, padding=None)

Performs a convolution over the input tensor weight using the given filters filts and bias filts_bias. A stride for the convolution and a padding for the input image can also be specified.

def convolve(weight, filts, filts_bias, stride=1, padding=None):
    n_filt, depth_f, f_w, f_h = filts.shape
    bs, depth_im, im_w, im_h = weight.shape

    # Pad the input image if a Padding layer was given
    if padding is not None: 
        weight = padding(weight)
        p_s = padding.size
    else: p_s = 0

    _, _, p_w, p_h = weight.shape

    assert depth_f == depth_im

    final = torch.zeros(bs, n_filt, int((im_w + 2*p_s - f_w)/stride)+1, int((im_h + 2*p_s - f_h)/stride)+1)
    for j in range(0, p_w-f_w+1, stride): # passes along the first spatial dim
        for k in range(0, p_h-f_h+1, stride): # passes along the second spatial dim
            # Multiply every filter against the current window, then sum over depth and the window
            final[:,:,j//stride,k//stride] = (weight[:,:,j:j+f_w,k:k+f_h].unsqueeze(1)*filts).sum(dim=-1).sum(dim=-1).sum(dim=-1) + filts_bias.unsqueeze(0)

    return final

Testing against PyTorch's convolution:

weight = torch.randn(10, 16, 28, 28)
bs = weight.shape[0]
im_h = im_w = weight.shape[2]
w_dim = weight.shape[1]
weight.shape
torch.Size([10, 16, 28, 28])
stride = 3
pad_amount = 2
f_dim = 4
f_w = 6
f_h = f_w
test_conv = nn.Conv2d(w_dim, f_dim, f_w, padding=pad_amount, stride=stride)
f = test_conv.weight
test_conv.weight.shape
torch.Size([4, 16, 6, 6])
b = test_conv.bias
b.shape
torch.Size([4])
pt_res = test_conv(weight)
pt_res.shape
torch.Size([10, 4, 9, 9])
pad = Padding(size=pad_amount)
pad
Padding (Mode: constant, Size: 2, Value: 0)
res = convolve(weight, f, b, stride, pad)
test_near(res, pt_res)
good

Backward

conv_back[source]

conv_back(out, inp, weight, bias, stride=1, padding=None)

Performs a backward pass: uses the gradient stored on out to compute the gradients with respect to the input inp, the filter weight and the bias.

def conv_back(out, inp, weight, bias, stride=1, padding=None):
    "Performs a backward pass: uses the gradient stored on `out` to compute the gradients with respect to the input `inp`, the filter `weight` and the `bias`."
    dZ = out.g

    A_prev, W = inp, weight.d

    # Retrieve dimensions from A_prev's shape
    (m, n_C_prev, n_W_prev, n_H_prev) = A_prev.shape

    # Retrieve dimensions from W's shape
    (n_C, n_C_prev, f, f) = W.shape

    # Retrieve dimensions from dZ's shape
    (m, n_C, n_W, n_H) = dZ.shape

    # Initialize dA_prev, dW, db with the correct shapes
    dA_prev = torch.zeros((m, n_C_prev, n_W_prev, n_H_prev))
    dW = torch.zeros((n_C, n_C_prev, f, f))
    db = torch.zeros((n_C, 1, 1, 1))

    # Pad A_prev and dA_prev
    if padding is not None:
        A_prev = padding(A_prev)
        dA_prev = padding(dA_prev)

    for h in range(n_H):     # loop over vertical axis of the output volume
        for w in range(n_W): # loop over horizontal axis of the output volume

            # Find the corners of the current "slice"
            vert_start = h*stride
            vert_end = vert_start + f
            horiz_start = w*stride
            horiz_end = horiz_start + f

            # Use the corners to define the slice from the (padded) input
            a_slice = A_prev[:, :, horiz_start:horiz_end, vert_start:vert_end]

            ezdz = dZ[:, :, w, h].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)

            # Update gradients for the input, the filters and the bias
            dA_prev[:, :, horiz_start:horiz_end, vert_start:vert_end] += (W * ezdz).sum(dim=1)
            dW += (a_slice.unsqueeze(1)*ezdz).sum(dim=0)
            db += dZ[:, :, w, h].sum(dim=0).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)

    # Remove the padding from the input gradient before storing it
    if padding is not None: dA_prev = dA_prev[:, :, padding.size:-padding.size, padding.size:-padding.size]

    weight.update(dW)
    bias.update(db.view(-1))
    inp.g = dA_prev

Training

class Conv[source]

Conv(n_in, n_out, kernel_size=3, stride=1, leak=1, padding=None) :: Module

Module to perform convolutions. Can specify kernel size, stride and padding
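
Before the helper below, here is a sketch of how such a module could wire convolve and conv_back together; the Parameter wrapper (with a .d tensor, a .grad and an update method) is assumed from the earlier fully connected notebooks, so the details may differ from the real class:

class Conv(Module):
    "Sketch of a convolution module wrapping `convolve` and `conv_back`"
    def __init__(self, n_in, n_out, kernel_size=3, stride=1, leak=1, padding=None):
        super().__init__()
        self.n_in, self.n_out, self.ks = n_in, n_out, kernel_size
        self.stride, self.padding = stride, padding
        # Kaiming-initialised filters of shape (n_out, n_in, ks, ks), one bias per filter
        self.filters = Parameter(get_weight(n_out, n_in, kernel_size, kernel_size, leak, True))
        self.bias = Parameter(torch.zeros(n_out))

    def forward(self, xb):
        return convolve(xb, self.filters.d, self.bias.d, self.stride, self.padding)

    def bwd(self, out, inp):
        conv_back(out, inp, self.filters, self.bias, self.stride, self.padding)

    def __repr__(self): return f'Conv({self.n_in}, {self.n_out}, ks = {self.ks}, stride = {self.stride})'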

def get_basic_conv_model(lr):
    "Helper function to get a basic conv model up and running"
    pad1 = Padding(1)
    pad2 = Padding(1)
    model = SequentialModel(Reshape(1, 28, 28), 
                            Conv(1, 16, 5, stride=3, leak=0, padding=pad1), 
                            ReLU(), 
                            Conv(16, 8, 5, stride=2, leak=1, padding=pad2), 
                            Flatten(), 
                            Linear(128, 10, False))
    loss_func = CrossEntropy()
    optimizer = Optimizer(model.parameters(), lr)
    return model, optimizer, loss_func

Testing vs PyTorch

m, o, l = get_basic_conv_model(0.1)
loss_func = l
xt, yt, _, _ = get_mnist()
sx, sy = xt[:100], yt[:100]
m
(Layer1): Reshape(1, 28, 28)
(Layer2): Conv(1, 16, ks = 5, stride = 3)
(Layer3): ReLU()
(Layer4): Conv(16, 8, ks = 5, stride = 2)
(Layer5): Flatten()
(Layer6): Linear(128, 10)
loss = loss_func(m(sx), sy)
loss_func.backward()
m.backward()
sxg = sx.g.clone()
cw1g = m.layers[1].filters.grad.clone()
cb1g = m.layers[1].bias.grad.clone()
cw2g = m.layers[3].filters.grad.clone()
cb2g = m.layers[3].bias.grad.clone()
lw = m.layers[5].w.grad.clone()
lb = m.layers[5].b.grad.clone()
sx2 = sx.clone().requires_grad_(True)
m.layers[1].filters.d.requires_grad_(True)
m.layers[1].bias.d.requires_grad_(True)
m.layers[3].filters.d.requires_grad_(True)
m.layers[3].bias.d.requires_grad_(True)
m.layers[5].w.d.requires_grad_(True)
m.layers[5].b.d.requires_grad_(True)
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)
loss = loss_func(m(sx2), sy)
loss.backward()
test_near(lb, m.layers[5].b.d.grad)
test_near(lw, m.layers[5].w.d.grad)
test_near(cb2g, m.layers[3].bias.d.grad)
test_near(cw2g, m.layers[3].filters.d.grad)
test_near(cb1g, m.layers[1].bias.d.grad)
test_near(cw1g, m.layers[1].filters.d.grad)
good
good
good
good
good
good

get_linear_model[source]

get_linear_model(lr)

Easy helper function to get a basic fully connected network with optimizer and loss function; takes the learning rate, lr, as a parameter

get_model[source]

get_model(lr, layers)

Helper function to build a model from the given layers, together with an optimizer and loss function; takes the learning rate, lr, as a parameter
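
Given the usage below, get_model is presumably just get_basic_conv_model generalised to an arbitrary list of layers; a sketch:

def get_model(lr, layers):
    "Sketch: wrap the given `layers` in a SequentialModel with a CrossEntropy loss and an Optimizer"
    model = SequentialModel(*layers)
    loss_func = CrossEntropy()
    optimizer = Optimizer(model.parameters(), lr)
    return model, optimizer, loss_func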

m = [Reshape(1, 28, 28), 
     Conv(1, 4, 5, stride=1, leak=0), 
     ReLU(), 
     Conv(4, 1, 5, stride=1, leak=1), 
     Flatten(), 
     Linear(20*20, 10, False)]

m, o, l = get_model(0.1, m)

get_small_datasets[source]

get_small_datasets()

Helper function to get smaller versions of MNIST datasets

train,valid = get_small_datasets()
fit(2, m, o, l, train, valid)
Epoch 1, Accuracy: 0.0963541716337204, Loss: nan
Epoch 2, Accuracy: 0.0963541716337204, Loss: nan

Pooling Layers

Forward

max_pool[source]

max_pool(inp)

Applies a max pooling operation on inp

avg_pool[source]

avg_pool(inp)

Applies an average pooling operation on inp
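
Both reducers take a (bs, nc, ks, ks) window and collapse the two spatial dimensions; plausible one-liners (the originals may differ slightly):

def max_pool(window):
    "Maximum over the two spatial dimensions of the window"
    return window.max(dim=-1)[0].max(dim=-1)[0]

def avg_pool(window):
    "Mean over the two spatial dimensions of the window"
    return window.mean(dim=(-2, -1))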

pool[source]

pool(inp, ks, stride, padding=None, operation=max_pool)

Runs a pooling operation on inp of type operation with given stride, ks and padding

def pool(inp, ks, stride, padding=None, operation=max_pool):
    if padding is not None:
        # Pad with a value below the input's minimum so padded pixels never win a max pool
        if operation == max_pool: padding.value = inp.min() - 1
        inpp = padding(inp)
    else: inpp = inp

    bs, nc, h, w = inpp.shape  # use the (possibly padded) dimensions for the output size
    nh, nw = (h - ks)//stride + 1, (w - ks)//stride + 1

    out = torch.zeros(bs, nc, nh, nw)

    for i in range(nh):     # vertical passes
        for j in range(nw): # horizontal passes
            window = inpp[:,:,i*stride:i*stride+ks, j*stride:j*stride+ks]
            out[:,:,i,j] = operation(window)

    return out

Testing Pooling Forward Passes

test_t = torch.randn(16, 3, 28, 28)
ks = 3
stride = 1
pad_size = 0
py_max = nn.MaxPool2d(ks, stride=stride, padding=pad_size)
py_avg = nn.AvgPool2d(ks, stride=stride, padding=pad_size)
py_max_result = py_max(test_t)
py_avg_result = py_avg(test_t)
pad = Padding(size=pad_size) if pad_size > 0 else None
my_avg = pool(test_t, ks, stride, pad, avg_pool)
my_max = pool(test_t, ks, stride, pad, max_pool)
test_near(py_max_result, my_max)
test_near(py_avg_result, my_avg)
good
good

Backward

max_back[source]

max_back(window)

Gradient for window of max pooling

average_back[source]

average_back(window, shape)

Gradient for window of average pooling
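
Hedged sketches of the two window gradients, shaped so they slot straight into pool_back below: max_back builds a mask that routes the gradient to the arg-max location, while average_back spreads a (bs, nc) gradient slice evenly over the window:

def max_back(window):
    "Sketch: 0/1 mask marking where the maximum of each window sits"
    mx = window.max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[0]
    return (window == mx).float()

def average_back(window, shape):
    "Sketch: distribute the (bs, nc) gradient `window` uniformly over a `shape`-sized window"
    ks1, ks2 = shape
    return window.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, ks1, ks2) / (ks1 * ks2)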

pool_back[source]

pool_back(out, inp, ks, stride, padding=None, operation=max_pool)

Function to pass gradient through pooling layers

def pool_back(out, inp, ks, stride, padding=None, operation=max_pool):
    dZ = out.g

    inp_orig = inp  # keep a reference so the gradient ends up on the unpadded input
    if padding is not None:
        if operation == max_pool: padding.value = inp.min() - 1
        inp = padding(inp)

    bs, nc, nh, nw = dZ.shape

    dA_prev = torch.zeros(inp.shape)

    shape = (ks, ks)

    for i in range(nh):     # vertical passes
        for j in range(nw): # horizontal passes

            if operation == max_pool:
                window = inp[:,:,i*stride:i*stride+ks, j*stride:j*stride+ks]

                # Route the gradient to the position of the maximum in each window
                mask = max_back(window)

                dA_prev[:,:,i*stride:i*stride+ks, j*stride:j*stride+ks] += mask*dZ[:,:,i,j].unsqueeze(-1).unsqueeze(-1)

            elif operation == avg_pool:
                # Spread the gradient evenly over the window
                dz = dZ[:,:,i,j]
                dA_prev[:,:,i*stride:i*stride+ks, j*stride:j*stride+ks] += average_back(dz, shape)

    # Remove the padding from the gradient before storing it on the input
    if padding is not None:
        dA_prev = dA_prev[:,:,padding.size:-padding.size, padding.size:-padding.size]

    inp_orig.g = dA_prev

Module

class Pool[source]

Pool(operation, ks=2, stride=2, padding=None) :: Module

Module for defining a pooling layer in a model
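
A sketch of the module, matching the repr seen in the model printout below; it simply defers to pool and pool_back:

class Pool(Module):
    "Sketch of a pooling layer built on `pool` and `pool_back`"
    def __init__(self, operation, ks=2, stride=2, padding=None):
        super().__init__()
        self.operation, self.ks, self.stride, self.padding = operation, ks, stride, padding

    def forward(self, xb): return pool(xb, self.ks, self.stride, self.padding, self.operation)

    def bwd(self, out, inp): pool_back(out, inp, self.ks, self.stride, self.padding, self.operation)

    def __repr__(self):
        name = 'MaxPool' if self.operation == max_pool else 'AveragePool'
        return f'{name}(ks: {self.ks}, stride: {self.stride})'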

layers = [Reshape(1, 28, 28),
          Conv(1, 4),
          Pool(avg_pool, ks=2, stride=1),
          Conv(4, 1),
          Flatten(),
          Linear(529, 10, False)]
m, o, l = get_model(0.1, layers)
train,valid = get_small_datasets()
m
(Layer1): Reshape(1, 28, 28)
(Layer2): Conv(1, 4, ks = 3, stride = 1)
(Layer3): AveragePool(ks: 2, stride: 1)
(Layer4): Conv(4, 1, ks = 3, stride = 1)
(Layer5): Flatten()
(Layer6): Linear(529, 10)
loss = loss_func(m(sx), sy)
loss_func.backward()
m.backward()
sxg = sx.g.clone()
cw = m.layers[1].filters.grad.clone()
cb = m.layers[1].bias.grad.clone()
lw = m.layers[5].w.grad.clone()
lb = m.layers[5].b.grad.clone()
sx2 = sx.clone().requires_grad_(True)
m.layers[1].filters.d.requires_grad_(True)
m.layers[1].bias.d.requires_grad_(True)
m.layers[5].w.d.requires_grad_(True)
m.layers[5].b.d.requires_grad_(True)
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)