Implementing an advanced training loop that takes advantage of callbacks. Defines new wrapper classes such as `Databunch` and `Learner`.

DataLoader

class Dataset[source]

Dataset(x, y)

Container class to store and get input and target values from a dataset
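
A minimal sketch of what this container plausibly looks like (not the exact source; indexing returns an (x, y) pair):

class Dataset():
    "Container class to store and get input and target values from a dataset"
    def __init__(self, x, y): self.x, self.y = x, y
    def __len__(self): return len(self.x)
    def __getitem__(self, i): return self.x[i], self.y[i]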

class DataLoader[source]

DataLoader(ds, batcher, collate_fcn)

Refactored DataLoader that takes a `batcher` and collates the batcher's output into a single tensor for the model
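
A sketch under the assumption that `batcher` yields lists of indices and `collate_fcn` stacks the selected samples into tensors; the `collate` helper below is a hypothetical example of such a function:

import torch

class DataLoader():
    "Yields collated batches of `ds` in the order produced by `batcher`"
    def __init__(self, ds, batcher, collate_fcn):
        self.ds, self.batcher, self.collate_fcn = ds, batcher, collate_fcn
    def __len__(self): return len(self.batcher)
    def __iter__(self):
        # each batch of indices is gathered from the dataset, then collated
        for idxs in self.batcher:
            yield self.collate_fcn([self.ds[i] for i in idxs])

def collate(batch):
    "Hypothetical collate_fcn: stack (x, y) samples into an (xs, ys) pair"
    xs, ys = zip(*batch)
    return torch.stack(xs), torch.stack(ys)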

class Databunch[source]

Databunch(train_dl, valid_dl)

Wrapper to combine the training and validation DataLoaders
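
Since `Runner.fit` reads `db.train` and `db.valid`, the wrapper is presumably little more than:

class Databunch():
    "Wrapper to combine the training and validation DataLoaders"
    def __init__(self, train_dl, valid_dl):
        self.train, self.valid = train_dl, valid_dl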

get_databunch[source]

get_databunch(xt, yt, xv, yv, bs=64)

Helper function to get a Databunch with batch size `bs`
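
A sketch of how the helper could assemble the pieces above, reusing the `collate` helper sketched earlier; the `Sampler` batcher here is hypothetical, and the doubled validation batch size matches the repr below (bs = 64 train, bs = 128 valid):

import random

class Sampler():
    "Hypothetical batcher: yields lists of indices, `bs` at a time"
    def __init__(self, n, bs, shuffle=False):
        self.n, self.bs, self.shuffle = n, bs, shuffle
    def __len__(self): return (self.n + self.bs - 1) // self.bs
    def __iter__(self):
        idxs = list(range(self.n))
        if self.shuffle: random.shuffle(idxs)
        for i in range(0, self.n, self.bs): yield idxs[i:i+self.bs]

def get_databunch(xt, yt, xv, yv, bs=64):
    "Helper function to get a Databunch with training batch size `bs`"
    train_dl = DataLoader(Dataset(xt, yt), Sampler(len(xt), bs, shuffle=True), collate)
    valid_dl = DataLoader(Dataset(xv, yv), Sampler(len(xv), bs*2), collate)
    return Databunch(train_dl, valid_dl)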

get_mnist_databunch[source]

get_mnist_databunch()

Grabs an MNIST Databunch using get_mnist and get_databunch

db = get_mnist_databunch()
db
Databunch(
Train: Data: X: torch.Size([50000, 784]), Y: torch.Size([50000]), bs = 64, 
ValidData: X: torch.Size([50000, 784]), Y: torch.Size([50000]), bs = 128
)

Learner

class Learner[source]

Learner(model, loss_func, optimizer, db, lr=0.5)

get_learner[source]

get_learner()

Helper function to get a learner

Stepping through `Learner.__init__` with pdb (via the `pdb.set_trace()` visible in the traceback below) to check that `set_learner` gives the model, and every layer in it, a reference back to the learner:

model,_,_ = get_model(0.1, [Linear(784, 50, True), ReLU(), Linear(50, 10, False)])
learn = Learner(model, 0.1, Optimizer, db)
learn
> <ipython-input-19-51751b17fe8e>(7)__init__()
-> model.set_learner(self)
(Pdb) self
Data: 
 Databunch(
Train: Data: X: torch.Size([50000, 784]), Y: torch.Size([50000]), bs = 64, 
ValidData: X: torch.Size([50000, 784]), Y: torch.Size([50000]), bs = 128
) 
 Model: 
 (Layer1): Linear(784, 50)
(Layer2): ReLU()
(Layer3): Linear(50, 10)
(Pdb) model.set_learner(self)
(Pdb) model.learner
(Pdb) model.learner = 2
(Pdb) model.layers[0].learner
Data: 
 Databunch(
Train: Data: X: torch.Size([50000, 784]), Y: torch.Size([50000]), bs = 64, 
ValidData: X: torch.Size([50000, 784]), Y: torch.Size([50000]), bs = 128
) 
 Model: 
 (Layer1): Linear(784, 50)
(Layer2): ReLU()
(Layer3): Linear(50, 10)
(Pdb) exit
---------------------------------------------------------------------------
BdbQuit                                   Traceback (most recent call last)
<ipython-input-22-3d0b2ce1be4a> in <module>()
      1 model,_,_ = get_model(0.1, [Linear(784, 50, True), ReLU(), Linear(50, 10, False)])
----> 2 learn = Learner(model, 0.1, Optimizer, db)
      3 learn

<ipython-input-19-51751b17fe8e> in __init__(self, model, loss_func, optimizer, db, lr)
      5         self.model, self.loss_func, self.optimizer, self.db = model, loss_func, optimizer(model.parameters(), lr), db
      6         import pdb; pdb.set_trace()
----> 7         model.set_learner(self)
      8 
      9     def __repr__(self): return f'Data: \n {self.db} \n Model: \n {self.model}'

/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/bdb.py in trace_dispatch(self, frame, event, arg)
     86             return # None
     87         if event == 'line':
---> 88             return self.dispatch_line(frame)
     89         if event == 'call':
     90             return self.dispatch_call(frame, arg)

/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/bdb.py in dispatch_line(self, frame)
    111         if self.stop_here(frame) or self.break_here(frame):
    112             self.user_line(frame)
--> 113             if self.quitting: raise BdbQuit
    114         return self.trace_dispatch
    115 

BdbQuit: 
learn.model.learner

Runner

class Runner[source]

Runner(learner, cbs=None)

All-encompassing class to train a model with specific callbacks

    def do_one_batch(self, xb, yb):
        "Applies forward and backward passes of model to one batch"
        self.xb, self.yb = xb, yb

        self.pred = self.learner.model(xb)
        self.loss = self.learner.loss_func(self.pred, yb)
        if self.check_callbacks('after_loss') or not self.learner.model.training: return

        self.learner.loss_func.backward()
        if self.check_callbacks('after_loss_back'): return

        self.learner.model.backward()
        if self.check_callbacks('after_model_back'): return

        self.opt.step()
        if self.check_callbacks('after_opt'): return

        self.opt.zero_grad()
        if self.check_callbacks('after_zero_grad'): return

    def do_all_batches(self, dl):
        "Runs every batch of a dataloader through `do_one_batch`"
        self.iters, self.iters_done = len(dl), 0
        for xb, yb in dl:
            if self.stop: break
            if self.check_callbacks('before_batch'): return
            self.do_one_batch(xb,yb)
            if self.check_callbacks('after_batch'): return
        self.iters = 0

        self.stop = False

    def fit(self, epochs, lr=0.1):
        "Method to fit the model `epoch` times using learning rate `lr`"
        self.lr, self.epochs = lr, epochs
        if self.check_callbacks('before_fit'): return

        for epoch in range(epochs):
            self.epoch = epoch
            if self.check_callbacks('before_epoch'): return
            if not self.check_callbacks('before_train'): self.do_all_batches(self.learner.db.train)
            if not self.check_callbacks('before_valid'): self.do_all_batches(self.learner.db.valid)
            if self.check_callbacks('after_epoch'): break

        if self.check_callbacks('after_fit'): return

    def check_callbacks(self, state):
        "Helper functions to run through each callback, calling it's state method if applicable"
        for cb in sorted(self.cbs, key=lambda x: x._order):
            f = getattr(cb, state, None)
            if f and f(): return True
        return False
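
Because `check_callbacks` treats a truthy return as a signal to halt the current stage, a new callback only needs to define the stage methods it cares about. A hypothetical example (not from the notebook; it assumes the `Callback` base class below lets `self.loss` reach the runner's last computed loss):

class EarlyStopCallback(Callback):
    "Hypothetical callback: abandon the remaining batches once loss < thresh"
    def __init__(self, thresh=0.05): self.thresh = thresh
    def after_batch(self):
        # a truthy return makes check_callbacks('after_batch') return True,
        # which makes do_all_batches return early
        return self.loss < self.thresh

run = Runner(learn, cbs=[EarlyStopCallback()])
run.fit(5, lr=0.1)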

get_runner[source]

get_runner(callbacks=None)

Runner.__init__[source]

Runner.__init__(learner, cbs=None)

Initialize self. See help(type(self)) for accurate signature.

Runner.do_one_batch[source]

Runner.do_one_batch(xb, yb)

Applies forward and backward passes of model to one batch

Runner.do_all_batches[source]

Runner.do_all_batches(dl)

Runs every batch of a dataloader through do_one_batch

Runner.fit[source]

Runner.fit(epochs, lr=0.1)

Method to fit the model epochs times using learning rate lr

Runner.check_callbacks[source]

Runner.check_callbacks(state)

Helper function to run through each callback, calling its state method if applicable

Callbacks

Base Class

class Callback[source]

Callback()

Base class for callbacks; defines the order of execution and lets callbacks reach runner attributes through self
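
"Abstraction of self to the runner" suggests that attribute lookups which fail on the callback fall through to the runner, in the style of the fastai course notebooks. A sketch under that assumption (the `set_runner` name is a guess):

class Callback():
    "Base class for callbacks"
    _order = 0  # check_callbacks sorts callbacks by this before dispatching
    def set_runner(self, run): self.run = run
    def __getattr__(self, k):
        # anything not defined on the callback itself is looked up on the
        # runner, so subclasses can write self.loss, self.epoch, self.learner, ...
        return getattr(self.run, k)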

Built in Callback

class TrainEvalCallback[source]

TrainEvalCallback() :: Callback

Keeps track of training/eval mode of model and progress through training
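
Given that `Runner` fires `before_train`/`before_valid` and checks `self.learner.model.training`, one plausible implementation (assuming the model's mode is a plain `training` flag):

class TrainEvalCallback(Callback):
    "Keeps track of training/eval mode of model and progress through training"
    def before_train(self): self.learner.model.training = True
    def before_valid(self): self.learner.model.training = False
    def after_batch(self):
        # do_all_batches resets iters_done; count completed training batches
        if self.learner.model.training: self.run.iters_done += 1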

Metric Tracking

class Stat[source]

Stat(calc)

Defines a metric to keep track of through training, metric calculated using calc

class StatTracker[source]

StatTracker(metrics, in_train)

Class to implement the Stats callback using metrics of class Stat

class Stats[source]

Stats(metrics) :: Callback

Callback to keep track of metrics
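
A condensed sketch of how the three pieces could fit together; the per-epoch printing seen in the output below would live in an `after_epoch` method, omitted here. Attribute names beyond the documented signatures are assumptions:

class Stat():
    "Batch-size-weighted running average of `calc`"
    def __init__(self, calc): self.calc = calc; self.reset()
    def reset(self): self.total, self.count = 0., 0
    def accumulate(self, run):
        bs = run.xb.shape[0]
        self.total += self.calc(run.pred, run.yb) * bs
        self.count += bs

class StatTracker():
    "Groups the Stats for one phase (training or validation)"
    def __init__(self, metrics, in_train):
        self.stats, self.in_train = [Stat(m) for m in metrics], in_train
    def reset(self):
        for s in self.stats: s.reset()

class Stats(Callback):
    "Callback to keep track of metrics"
    def __init__(self, metrics):
        self.train_stats = StatTracker(metrics, in_train=True)
        self.valid_stats = StatTracker(metrics, in_train=False)
    def before_epoch(self): self.train_stats.reset(); self.valid_stats.reset()
    def after_loss(self):
        # runs in both phases: Runner fires 'after_loss' before it checks
        # whether the model is training
        tracker = self.train_stats if self.learner.model.training else self.valid_stats
        for s in tracker.stats: s.accumulate(self.run)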

Testing the Stats callback:

run = Runner(learn, [Stats([accuracy])])
run.fit(5, 0.1)
Epoch: 1
Train: Loss: 2.3827933546272106e-06, Accuracy: 0.9668400287628174
Valid: Loss: 0.00014949109754525125, Accuracy: 0.9701799750328064
Epoch: 2
Train: Loss: 2.326479034309159e-06, Accuracy: 0.9716399908065796
Valid: Loss: 7.430403638863936e-05, Accuracy: 0.9783200025558472
Epoch: 3
Train: Loss: 4.621696461981628e-06, Accuracy: 0.9759799838066101
Valid: Loss: 0.00013563691754825413, Accuracy: 0.9827600121498108
Epoch: 4
Train: Loss: 2.2784424800192937e-06, Accuracy: 0.9796000123023987
Valid: Loss: 8.738860196899623e-05, Accuracy: 0.985040009021759
Epoch: 5
Train: Loss: 7.541702507296577e-05, Accuracy: 0.9826599955558777
Valid: Loss: 0.0008652143878862262, Accuracy: 0.94514000415802

Parameter Scheduling

class Optimizer[source]

Optimizer(params, lr)

class Scheduler[source]

Scheduler(param, scheduler) :: Callback

Class to schedule the hyperparameters of the model
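
One plausible implementation, assuming the Optimizer stores each hyperparameter (here `lr`) as a plain attribute: compute overall progress through `fit` from the counters the `Runner` maintains, and set the parameter before every training batch.

class Scheduler(Callback):
    "Class to schedule the hyperparameters of the model"
    def __init__(self, param, scheduler):
        self.param, self.scheduler = param, scheduler
    def before_batch(self):
        if not self.learner.model.training: return
        # fraction of the whole fit completed, in [0, 1)
        pos = (self.epoch + self.iters_done/self.iters) / self.epochs
        setattr(self.learner.optimizer, self.param, self.scheduler(pos))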

annealer[source]

annealer(f)

Allows easy implementation of different scheduling functions through inner functions

annealer._inner[source]

annealer._inner(start, end)

import math
from functools import partial

def annealer(f):
    "Allows easy implementation of different scheduling functions through inner functions"
    def _inner(start, end): return partial(f, start, end)
    return _inner

@annealer
def linear_scheduler(start, end, pos): 
    "Schedule linearly from start to end"
    return pos*(end-start) + start

@annealer
def cos_scheduler(start, end, pos): 
    "Schedule using a cosine function"
    return start + (1 + math.cos(math.pi*(1-pos))) * (end-start) / 2
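
The decorator turns each three-argument schedule into a factory: calling the decorated function fixes `start` and `end` and returns a function of `pos` alone, which is exactly the shape `Scheduler` expects:

sched = linear_scheduler(1.0, 0.0)   # fix start=1.0, end=0.0
sched(0.0), sched(0.25), sched(1.0)  # -> (1.0, 0.75, 0.0)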

combine_scheds[source]

combine_scheds(pcts, scheds)

Combine multiple schedules, each covering a given percentage of training
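
A plain-Python sketch of what this combination might look like (not the notebook's exact source): `pcts` gives the fraction of total training each schedule covers, and `pos` in [0, 1] is overall progress.

def combine_scheds(pcts, scheds):
    "Run scheds[i] over fraction pcts[i] of training"
    assert abs(sum(pcts) - 1.0) < 1e-6
    bounds = [0.0]
    for p in pcts: bounds.append(bounds[-1] + p)
    def _inner(pos):
        # find the phase containing pos, rescale pos to [0, 1] within it
        for i in range(len(pcts)):
            if pos < bounds[i+1] or i == len(pcts) - 1:
                return scheds[i]((pos - bounds[i]) / (bounds[i+1] - bounds[i]))
    return _inner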

class HyperRecorder[source]

HyperRecorder(params) :: Callback

Callback to keep track of and visualize hyperparameter values and losses in the network
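
A sketch of the recorder, consistent with the `plot_loss`/`plot_param` calls below (the recording attributes are assumptions):

import matplotlib.pyplot as plt

class HyperRecorder(Callback):
    "Records losses and the given hyperparameters after every training batch"
    def __init__(self, params): self.params = params
    def before_fit(self):
        self.losses, self.records = [], {p: [] for p in self.params}
    def after_batch(self):
        if not self.learner.model.training: return
        self.losses.append(float(self.loss))
        for p in self.params:
            self.records[p].append(getattr(self.learner.optimizer, p))
    def plot_loss(self): plt.plot(self.losses)
    def plot_param(self, p): plt.plot(self.records[p])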

class Learner[source]

Learner(model, loss_func, optimizer, db, lr=0.5)

Testing HyperRecorder and Scheduler

start_lr = 0.01
model, _, loss_func= get_linear_model(start_lr)
db = get_mnist_databunch()
learn = Learner(model, loss_func, Optimizer, db, start_lr)
schedule = combine_scheds([0.4, 0.6], [cos_scheduler(0.01,0.1), cos_scheduler(0.1,0.01)])
run = Runner(learn, [Scheduler('lr', schedule), HyperRecorder(['lr'])])
run.fit(2)
run.cbs[2].plot_loss()
run.cbs[2].plot_param('lr')

Progress Bar

Reformatted Stats

accuracy[source]

accuracy(preds, targ)

Compute accuracy of preds with respect to targ
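
Presumably the standard definition, along these lines (assuming preds holds per-class scores):

def accuracy(preds, targ):
    "Compute accuracy of `preds` with respect to `targ`"
    return (preds.argmax(dim=1) == targ).float().mean()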

loss[source]

loss(preds, targ, loss, **kwargs)

Stat for loss

class Stat[source]

Stat(calc)

Defines a metric to keep track of through training, metric calculated using calc

class StatTracker[source]

StatTracker(metrics, in_train)

Class to implement the Stats callback using metrics of class Stat

class Stats[source]

Stats(metrics) :: Callback

Callback to keep track of metrics

class ProgressCallback[source]

ProgressCallback() :: Callback

Callback to make a nice progress bar with metrics for training. Slightly modified version of: https://github.com/fastai/course-v3/blob/master/nbs/dl2/09c_add_progress_bar.ipynb

Testing out the progress bar:

run = get_runner([Stats([accuracy]), ProgressCallback()])
run.fit(3, 0.2)
epoch  train_loss  train_accuracy  valid_loss  valid_accuracy  time
0      0.065741    0.979400        0.072409    0.977340        00:01
1      0.056172    0.982340        0.056556    0.981700        00:01
2      0.048253    0.984220        0.083364    0.973420        00:01