Fully implemented ResNet architecture from scratch: https://arxiv.org/pdf/1512.03385.pdf

Helper

get_runner[source]

get_runner(callbacks=None)

Nested Modules

We first need new classes that allow architectures that aren't straightforward passes through a fixed set of layers. PyTorch normally handles this in the forward pass via autograd; we need to be a bit more clever, because each of our modules must define its own gradients.
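The idea can be sketched in miniature as follows. This is a simplified illustration with hypothetical `Scale`/`Nested` modules, not the library's actual `Module`/`NestedModel` API: an outer module delegates both the forward pass and the hand-written backward pass to its inner modules.

```python
class Module:
    "Minimal manual-gradient module: subclasses define forward and backward."
    def __call__(self, x):
        self.out = self.forward(x)
        return self.out
    def forward(self, x): raise NotImplementedError
    def backward(self, grad): raise NotImplementedError

class Scale(Module):
    "Multiply the input by a constant; the gradient is that same constant."
    def __init__(self, a): self.a = a
    def forward(self, x): return self.a * x
    def backward(self, grad): return self.a * grad

class Nested(Module):
    "Delegates forward AND backward to an inner sequence of modules."
    def __init__(self, *inner): self.inner = inner
    def forward(self, x):
        for m in self.inner: x = m(x)
        return x
    def backward(self, grad):
        # Without autograd, gradients must be pushed back through
        # the inner modules explicitly, in reverse order.
        for m in reversed(self.inner): grad = m.backward(grad)
        return grad

net = Nested(Scale(2.0), Scale(3.0))
y = net(1.5)           # 1.5 * 2 * 3 = 9.0
g = net.backward(1.0)  # d(6x)/dx = 6.0
```

The key point is the `backward` of the outer module: nesting only works if the wrapper also knows how to route gradients through its children.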

class NestedModel[source]

NestedModel() :: Module

NestedModel that allows a sequential model to be called within an outer model

class TestMixingGrads[source]

TestMixingGrads() :: NestedModel

Test module to see if nested SequentialModels will work

Testing the gradients and the outputs:

m = SequentialModel(TestMixingGrads(), Linear(25,10, False))
db = get_mnist_databunch()
lf = CrossEntropy()
optimizer = adam_opt()
m
(Layer1): 
SubModel( 
(Layer1): Linear(784, 50)
(Layer2): ReLU()
(Layer3): Linear(50, 25)
)
(Layer2): Linear(25, 10)
learn = Learner(m, lf, optimizer, db)
run = Runner(learn, [CheckGrad()])
run.fit(1,0.1)
good
good
good
good
good
good
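Each `good` above is the `CheckGrad` callback signing off on one layer. The callback's internals aren't shown here, but the standard technique it presumably uses is comparing the analytic gradient against a finite-difference estimate; a standalone sketch:

```python
def finite_diff_grad(f, x, eps=1e-6):
    "Central-difference estimate of df/dx at a scalar point x."
    return (f(x + eps) - f(x - eps)) / (2 * eps)

f = lambda x: x ** 2          # analytic gradient: 2x
analytic = 2 * 3.0
numeric = finite_diff_grad(f, 3.0)

# A manually derived backward pass is judged "good" if the two agree
# to within a small tolerance.
ok = abs(analytic - numeric) < 1e-4
print('good' if ok else 'bad')
```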

Refactored Conv Layers

Before we can start making ResNets, we first define a few helper modules that abstract some of the layers:

class AutoConv[source]

AutoConv(n_in, n_out, kernel_size=3, stride=1) :: Conv

Automatically sets padding based on kernel size so that the spatial dimensions of the output match the input
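For a stride-1 convolution with an odd kernel, "same" padding is simply `kernel_size // 2`. A quick check against the standard convolution output-size formula (function names here are illustrative, not the library's):

```python
def same_padding(kernel_size):
    "Padding that preserves spatial size for stride-1 convs with odd kernels."
    return kernel_size // 2

def conv_out_size(n, kernel_size, stride=1, padding=0):
    "Standard convolution output-size formula: (n + 2p - k) // s + 1."
    return (n + 2 * padding - kernel_size) // stride + 1

# A 28x28 input stays 28x28 for any odd kernel size.
for k in (1, 3, 5, 7):
    assert conv_out_size(28, k, stride=1, padding=same_padding(k)) == 28
```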

class ConvBatch[source]

ConvBatch(n_in, n_out, kernel_size=3, stride=1, **kwargs) :: NestedModel

Performs conv then batchnorm

class Identity[source]

Identity() :: Module

Module to perform the identity connection (what goes in, comes out)

ResBlocks

Fully assembled ResNet blocks implementing the skip-connection layers characteristic of a ResNet

class BasicRes[source]

BasicRes(n_in, n_out, expansion=1, stride=1, Activation='ReLU', *args, **kwargs) :: Module

Basic block to implement the two different ResBlocks presented in the paper

class BasicResBlock[source]

BasicResBlock(n_in, n_out, *args, **kwargs) :: BasicRes

Implements the basic ResBlock from the paper: two 3x3 convolutions plus the skip connection

class BottleneckBlock[source]

BottleneckBlock(n_in, n_out, *args, **kwargs) :: BasicRes

Implements the bottleneck ResBlock from the paper: 1x1, 3x3, 1x1 convolutions plus the skip connection

class ResBlock[source]

ResBlock(n_in, n_out, block='BasicResBlock', stride=1, kernel_size=3, Activation='ReLU', **kwargs) :: NestedModel

Adds the final activation after the skip connection addition
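The computation every variant shares is `y = activation(F(x) + shortcut(x))`: the shortcut is the identity when shapes match, and a projection (1x1 conv in the paper) when `F` changes the channel count or stride. A scalar toy version, with hypothetical helper names, just to make the data flow concrete:

```python
def relu(x): return max(x, 0.0)

def res_block(x, f, shortcut=None):
    "y = activation(F(x) + shortcut(x)); identity shortcut by default."
    s = x if shortcut is None else shortcut(x)
    return relu(f(x) + s)

# Identity shortcut: the block learns a residual on top of its input.
y = res_block(2.0, f=lambda x: 0.5 * x)                # relu(1.0 + 2.0) = 3.0

# Projection shortcut: stands in for the 1x1 conv used when dimensions change.
y2 = res_block(2.0, f=lambda x: -5.0 * x,
               shortcut=lambda x: 2.0 * x)             # relu(-10.0 + 4.0) = 0.0
```

Note the activation is applied after the addition, which is exactly what `ResBlock` wraps around `BasicRes`.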

class ResLayer[source]

ResLayer(block, n, n_in, n_out, *args, **kwargs) :: NestedModel

Sequential ResBlock layers as outlined in the paper

class ResLayer(NestedModel):
    "Sequential res layers"
    def __init__(self, block, n, n_in, n_out, *args, **kwargs):
        self.block, self.n, self.n_in, self.n_out = block, n, n_in, n_out

        # Downsample (stride 2) in the first block whenever the channel count changes
        downsampling = 2 if n_in != n_out else 1

        layers = [ResBlock(n_in, n_out, block, stride=downsampling),
                  *[ResBlock(n_out * block.expansion, n_out, block, stride=1)
                    for i in range(n - 1)]]

        self.layers = SequentialModel(*layers)

    def __repr__(self): return f'ResLayer(\n{self.layers}\n)'

ResNet

class ResNet[source]

ResNet(block, layer_sizes=[64, 128, 256, 512], depths=[2, 2, 2, 2], c_in=3, c_out=1000, im_size=(28, 28), activation='ReLU', *args, **kwargs) :: NestedModel

Class to create ResNet architectures of dynamic sizing

class ResNet(NestedModel):
    "Class to create ResNet architectures of dynamic sizing"
    def __init__(self, block, layer_sizes=[64, 128, 256, 512], depths=[2,2,2,2], c_in=3, 
               c_out=1000, im_size=(28,28), activation=ReLU, *args, **kwargs):

        self.layer_sizes = layer_sizes

        gate = [
            Reshape(c_in, im_size[0], im_size[1]),
            ConvBatch(c_in, self.layer_sizes[0], stride=2, kernel_size=7),
            activation(),
            Pool(max_pool, ks=3, stride=2, padding=Padding(1))
        ]

        self.conv_sizes = list(zip(self.layer_sizes, self.layer_sizes[1:]))
        body = [
            ResLayer(block, depths[0], self.layer_sizes[0], self.layer_sizes[0], Activation=activation, *args, **kwargs),
            *[ResLayer(block, n, n_in * block.expansion, n_out, Activation=activation)
             for (n_in,n_out), n in zip(self.conv_sizes, depths[1:])]
        ]

        tail = [
            Pool(avg_pool, ks=1, stride=1, padding=None),
            Flatten(),
            Linear(self.layer_sizes[-1]*block.expansion, c_out, relu_after=False)
        ]

        self.layers = SequentialModel(*gate, *body, *tail)

    def __repr__(self): return f'ResNet: \n{self.layers}'
res = ResNet(BasicResBlock)
res
ResNet: 
(Layer1): Reshape(3, 28, 28)
(Layer2): Conv(3, 64, ks = 7, stride = 2), Batchnorm
(Layer3): ReLU()
(Layer4): MaxPool(ks: 3, stride: 2)
(Layer5): ResLayer(
(Layer1): ResBlock(64, 64, kernel_size=3, stride=1)
(Layer2): ResBlock(64, 64, kernel_size=3, stride=1)
)
(Layer6): ResLayer(
(Layer1): ResBlock(64, 128, kernel_size=3, stride=2)
(Layer2): ResBlock(128, 128, kernel_size=3, stride=1)
)
(Layer7): ResLayer(
(Layer1): ResBlock(128, 256, kernel_size=3, stride=2)
(Layer2): ResBlock(256, 256, kernel_size=3, stride=1)
)
(Layer8): ResLayer(
(Layer1): ResBlock(256, 512, kernel_size=3, stride=2)
(Layer2): ResBlock(512, 512, kernel_size=3, stride=1)
)
(Layer9): AveragePool(ks: 1, stride: 1)
(Layer10): Flatten()
(Layer11): Linear(512, 1000)

GetResnet[source]

GetResnet(size, c_in=3, c_out=10, *args, **kwargs)

Helper function to get ResNet architectures of different sizes
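The standard configurations from the paper that such a helper maps onto are fixed: each named size pairs a block type with per-stage depths, and the name itself counts the weighted layers. A sketch of that mapping (the dict and counting function here are illustrative, not the library's internals):

```python
# Block type and per-stage depths for the standard ResNet variants
# from He et al., 2015.
RESNET_CONFIGS = {
    18:  ('BasicResBlock',   [2, 2, 2, 2]),
    34:  ('BasicResBlock',   [3, 4, 6, 3]),
    50:  ('BottleneckBlock', [3, 4, 6, 3]),
    101: ('BottleneckBlock', [3, 4, 23, 3]),
    152: ('BottleneckBlock', [3, 8, 36, 3]),
}

def total_weighted_layers(size):
    "Stem conv + convs inside every block + final linear layer."
    block, depths = RESNET_CONFIGS[size]
    convs_per_block = 3 if block == 'BottleneckBlock' else 2
    return 1 + convs_per_block * sum(depths) + 1

# The layer count recovers the model's name, e.g. 1 + 2*8 + 1 = 18.
assert all(total_weighted_layers(s) == s for s in RESNET_CONFIGS)
```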

Testing out the ResNet Architectures:

GetResnet(18, c_in=1, c_out=10)
ResNet: 
(Layer1): Reshape(1, 28, 28)
(Layer2): Conv(1, 64, ks = 7, stride = 2), Batchnorm
(Layer3): ReLU()
(Layer4): MaxPool(ks: 3, stride: 2)
(Layer5): ResLayer(
(Layer1): ResBlock(64, 64, kernel_size=3, stride=1)
(Layer2): ResBlock(64, 64, kernel_size=3, stride=1)
)
(Layer6): ResLayer(
(Layer1): ResBlock(64, 128, kernel_size=3, stride=2)
(Layer2): ResBlock(128, 128, kernel_size=3, stride=1)
)
(Layer7): ResLayer(
(Layer1): ResBlock(128, 256, kernel_size=3, stride=2)
(Layer2): ResBlock(256, 256, kernel_size=3, stride=1)
)
(Layer8): ResLayer(
(Layer1): ResBlock(256, 512, kernel_size=3, stride=2)
(Layer2): ResBlock(512, 512, kernel_size=3, stride=1)
)
(Layer9): AveragePool(ks: 1, stride: 1)
(Layer10): Flatten()
(Layer11): Linear(512, 10)
GetResnet(34, c_in=1, c_out=10)
ResNet: 
(Layer1): Reshape(1, 28, 28)
(Layer2): Conv(1, 64, ks = 7, stride = 2), Batchnorm
(Layer3): ReLU()
(Layer4): MaxPool(ks: 3, stride: 2)
(Layer5): ResLayer(
(Layer1): ResBlock(64, 64, kernel_size=3, stride=1)
(Layer2): ResBlock(64, 64, kernel_size=3, stride=1)
(Layer3): ResBlock(64, 64, kernel_size=3, stride=1)
)
(Layer6): ResLayer(
(Layer1): ResBlock(64, 128, kernel_size=3, stride=2)
(Layer2): ResBlock(128, 128, kernel_size=3, stride=1)
(Layer3): ResBlock(128, 128, kernel_size=3, stride=1)
(Layer4): ResBlock(128, 128, kernel_size=3, stride=1)
)
(Layer7): ResLayer(
(Layer1): ResBlock(128, 256, kernel_size=3, stride=2)
(Layer2): ResBlock(256, 256, kernel_size=3, stride=1)
(Layer3): ResBlock(256, 256, kernel_size=3, stride=1)
(Layer4): ResBlock(256, 256, kernel_size=3, stride=1)
(Layer5): ResBlock(256, 256, kernel_size=3, stride=1)
(Layer6): ResBlock(256, 256, kernel_size=3, stride=1)
)
(Layer8): ResLayer(
(Layer1): ResBlock(256, 512, kernel_size=3, stride=2)
(Layer2): ResBlock(512, 512, kernel_size=3, stride=1)
(Layer3): ResBlock(512, 512, kernel_size=3, stride=1)
)
(Layer9): AveragePool(ks: 1, stride: 1)
(Layer10): Flatten()
(Layer11): Linear(512, 10)
GetResnet(50, c_in=1, c_out=10)
ResNet: 
(Layer1): Reshape(1, 28, 28)
(Layer2): Conv(1, 64, ks = 7, stride = 2), Batchnorm
(Layer3): ReLU()
(Layer4): MaxPool(ks: 3, stride: 2)
(Layer5): ResLayer(
(Layer1): ResBlock(64, 256, kernel_size=3, stride=1)
(Layer2): ResBlock(256, 256, kernel_size=3, stride=1)
(Layer3): ResBlock(256, 256, kernel_size=3, stride=1)
)
(Layer6): ResLayer(
(Layer1): ResBlock(256, 512, kernel_size=3, stride=2)
(Layer2): ResBlock(512, 512, kernel_size=3, stride=1)
(Layer3): ResBlock(512, 512, kernel_size=3, stride=1)
(Layer4): ResBlock(512, 512, kernel_size=3, stride=1)
)
(Layer7): ResLayer(
(Layer1): ResBlock(512, 1024, kernel_size=3, stride=2)
(Layer2): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer3): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer4): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer5): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer6): ResBlock(1024, 1024, kernel_size=3, stride=1)
)
(Layer8): ResLayer(
(Layer1): ResBlock(1024, 2048, kernel_size=3, stride=2)
(Layer2): ResBlock(2048, 2048, kernel_size=3, stride=1)
(Layer3): ResBlock(2048, 2048, kernel_size=3, stride=1)
)
(Layer9): AveragePool(ks: 1, stride: 1)
(Layer10): Flatten()
(Layer11): Linear(2048, 10)
GetResnet(101, c_in=1, c_out=10)
ResNet: 
(Layer1): Reshape(1, 28, 28)
(Layer2): Conv(1, 64, ks = 7, stride = 2), Batchnorm
(Layer3): ReLU()
(Layer4): MaxPool(ks: 3, stride: 2)
(Layer5): ResLayer(
(Layer1): ResBlock(64, 256, kernel_size=3, stride=1)
(Layer2): ResBlock(256, 256, kernel_size=3, stride=1)
(Layer3): ResBlock(256, 256, kernel_size=3, stride=1)
)
(Layer6): ResLayer(
(Layer1): ResBlock(256, 512, kernel_size=3, stride=2)
(Layer2): ResBlock(512, 512, kernel_size=3, stride=1)
(Layer3): ResBlock(512, 512, kernel_size=3, stride=1)
(Layer4): ResBlock(512, 512, kernel_size=3, stride=1)
)
(Layer7): ResLayer(
(Layer1): ResBlock(512, 1024, kernel_size=3, stride=2)
(Layer2): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer3): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer4): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer5): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer6): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer7): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer8): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer9): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer10): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer11): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer12): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer13): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer14): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer15): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer16): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer17): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer18): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer19): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer20): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer21): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer22): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer23): ResBlock(1024, 1024, kernel_size=3, stride=1)
)
(Layer8): ResLayer(
(Layer1): ResBlock(1024, 2048, kernel_size=3, stride=2)
(Layer2): ResBlock(2048, 2048, kernel_size=3, stride=1)
(Layer3): ResBlock(2048, 2048, kernel_size=3, stride=1)
)
(Layer9): AveragePool(ks: 1, stride: 1)
(Layer10): Flatten()
(Layer11): Linear(2048, 10)
GetResnet(152, c_in=1, c_out=10)
ResNet: 
(Layer1): Reshape(1, 28, 28)
(Layer2): Conv(1, 64, ks = 7, stride = 2), Batchnorm
(Layer3): ReLU()
(Layer4): MaxPool(ks: 3, stride: 2)
(Layer5): ResLayer(
(Layer1): ResBlock(64, 256, kernel_size=3, stride=1)
(Layer2): ResBlock(256, 256, kernel_size=3, stride=1)
(Layer3): ResBlock(256, 256, kernel_size=3, stride=1)
)
(Layer6): ResLayer(
(Layer1): ResBlock(256, 512, kernel_size=3, stride=2)
(Layer2): ResBlock(512, 512, kernel_size=3, stride=1)
(Layer3): ResBlock(512, 512, kernel_size=3, stride=1)
(Layer4): ResBlock(512, 512, kernel_size=3, stride=1)
(Layer5): ResBlock(512, 512, kernel_size=3, stride=1)
(Layer6): ResBlock(512, 512, kernel_size=3, stride=1)
(Layer7): ResBlock(512, 512, kernel_size=3, stride=1)
(Layer8): ResBlock(512, 512, kernel_size=3, stride=1)
)
(Layer7): ResLayer(
(Layer1): ResBlock(512, 1024, kernel_size=3, stride=2)
(Layer2): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer3): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer4): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer5): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer6): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer7): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer8): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer9): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer10): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer11): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer12): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer13): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer14): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer15): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer16): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer17): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer18): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer19): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer20): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer21): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer22): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer23): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer24): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer25): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer26): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer27): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer28): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer29): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer30): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer31): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer32): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer33): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer34): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer35): ResBlock(1024, 1024, kernel_size=3, stride=1)
(Layer36): ResBlock(1024, 1024, kernel_size=3, stride=1)
)
(Layer8): ResLayer(
(Layer1): ResBlock(1024, 2048, kernel_size=3, stride=2)
(Layer2): ResBlock(2048, 2048, kernel_size=3, stride=1)
(Layer3): ResBlock(2048, 2048, kernel_size=3, stride=1)
)
(Layer9): AveragePool(ks: 1, stride: 1)
(Layer10): Flatten()
(Layer11): Linear(2048, 10)
run = get_runner(model=GetResnet(18,c_in=1, c_out=10))