Implementing an advanced training loop that takes advantage of callbacks. Defines new wrapper classes such as DataBunch and Learner.
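The DataBunch and Learner wrappers are defined earlier in the notebook, so only their use appears below: a DataBunch pairs the training and validation dataloaders (the loop later reads learner.db.train and learner.db.valid), and a Learner bundles the model, loss function, optimizer class, data and learning rate. A minimal sketch of what such wrappers could look like (the argument order follows the later Learner(model, loss_func, Optimizer, db, start_lr) call and is an assumption, not the notebook's exact code):

class DataBunch():
    "Sketch: pair the train and valid dataloaders"
    def __init__(self, train_dl, valid_dl): self.train, self.valid = train_dl, valid_dl

class Learner():
    "Sketch: bundle everything the training loop needs"
    def __init__(self, model, loss_func, opt_func, db, lr):
        self.model, self.loss_func, self.opt_func, self.db, self.lr = model, loss_func, opt_func, db, lr
        self.model.learner = self   # assumed back-reference, inspected below via learn.model.learner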
db = get_mnist_databunch()
db
model, _, _ = get_model(0.1, [Linear(784, 50, True), ReLU(), Linear(50, 10, False)])
learn = Learner(model, 0.1, Optimizer, db)
learn
learn.model.learner
def do_one_batch(self, xb, yb):
    "Applies the forward and backward passes of the model to one batch"
    self.xb, self.yb = xb, yb
    self.pred = self.learner.model(xb)
    self.loss = self.learner.loss_func(self.pred, yb)
    # in evaluation mode (or if a callback cancels here) we stop after computing the loss
    if self.check_callbacks('after_loss') or not self.learner.model.training: return
    self.learner.loss_func.backward()
    if self.check_callbacks('after_loss_back'): return
    self.learner.model.backward()
    if self.check_callbacks('after_model_back'): return
    self.opt.step()
    if self.check_callbacks('after_opt'): return
    self.opt.zero_grad()
    if self.check_callbacks('after_zero_grad'): return
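Every check_callbacks call in do_one_batch is an exit point: any callback that defines the corresponding method and returns True cancels the rest of the batch. As a made-up illustration (not a callback from the notebook), a guard that abandons the step whenever the loss is not finite could hook 'after_loss':

import math

class SkipNonFiniteLoss():
    "Hypothetical callback: abandon the batch when the loss is NaN or inf"
    _order = 0
    def __init__(self, run): self.run = run   # assumes callbacks keep a reference to the Runner
    def after_loss(self):
        # returning True makes check_callbacks('after_loss') return True, so do_one_batch stops
        return not math.isfinite(float(self.run.loss))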
def do_all_batches(self, dl):
    "Runs every batch of a dataloader through `do_one_batch`"
    self.iters, self.iters_done = len(dl), 0
    for xb, yb in dl:
        if self.stop: break
        if self.check_callbacks('before_batch'): return
        self.do_one_batch(xb, yb)
        if self.check_callbacks('after_batch'): return
    self.iters = 0
    self.stop = False
def fit(self, epochs, lr=0.1):
    "Fits the model for `epochs` epochs using learning rate `lr`"
    self.lr, self.epochs = lr, epochs
    if self.check_callbacks('before_fit'): return
    for epoch in range(epochs):
        self.epoch = epoch
        if self.check_callbacks('before_epoch'): return
        if not self.check_callbacks('before_train'): self.do_all_batches(self.learner.db.train)
        if not self.check_callbacks('before_valid'): self.do_all_batches(self.learner.db.valid)
        if self.check_callbacks('after_epoch'): break
    if self.check_callbacks('after_fit'): return
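Note that fit wraps the training and validation passes in the 'before_train' and 'before_valid' hooks, and do_one_batch only runs the backward pass and optimizer step while self.learner.model.training is set, so switching between training and evaluation mode can itself live in a callback. A sketch of such a callback (the notebook's own version, if any, may differ):

class TrainEval():
    "Sketch: flip the model's training flag around the train and valid passes"
    _order = 0
    def __init__(self, run): self.run = run                           # assumed Runner reference
    def before_train(self): self.run.learner.model.training = True    # backward/step enabled
    def before_valid(self): self.run.learner.model.training = False   # forward and loss only

Both hooks return None, so the `if not self.check_callbacks(...)` guards in fit still let the corresponding pass run.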
def check_callbacks(self, state):
    "Helper function that runs through each callback, calling its `state` method if defined"
    for cb in sorted(self.cbs, key=lambda x: x._order):
        f = getattr(cb, state, None)
        if f and f(): return True
    return False
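Two details of check_callbacks shape how callbacks are written: they run in ascending _order, so a callback that must act before another (a scheduler setting the learning rate before a recorder logs it, say) simply declares a lower _order, and because getattr(cb, state, None) silently skips missing methods, a callback only implements the stages it cares about. A minimal made-up example:

class PrintEpoch():
    "Made-up callback: only defines the one hook it needs"
    _order = 10                                  # runs after lower-_order callbacks at each stage
    def __init__(self, run): self.run = run      # assumed Runner reference
    def before_epoch(self): print(f'starting epoch {self.run.epoch}')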
Testing the Stats callback:
run = Runner(learn, [Stats([accuracy])])
run.fit(5, 0.1)
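Stats itself is defined earlier in the notebook and takes a list of metric functions such as accuracy. Roughly, a metrics callback of this shape accumulates each metric over the epoch's batches and reports the averages at the end; the sketch below is an assumption about how it might work (including the metric signature metric(pred, targ)), not the notebook's implementation:

class StatsSketch():
    "Sketch of a metrics callback: average the given metrics over each epoch"
    _order = 1
    def __init__(self, run, metrics): self.run, self.metrics = run, metrics   # assumed Runner reference
    def before_epoch(self):
        self.count, self.totals = 0, [0.] * len(self.metrics)
    def after_batch(self):
        bs = len(self.run.xb)
        self.count += bs
        for i, m in enumerate(self.metrics):
            self.totals[i] += float(m(self.run.pred, self.run.yb)) * bs
    def after_epoch(self):
        print(f'epoch {self.run.epoch}:', [t / self.count for t in self.totals])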
import math
from functools import partial

def annealer(f):
    "Allows easy implementation of different scheduling functions through inner functions"
    def _inner(start, end): return partial(f, start, end)
    return _inner
@annealer
def linear_scheduler(start, end, pos):
    "Schedule linearly from `start` to `end`"
    return pos * (end - start) + start
@annealer
def cos_scheduler(start, end, pos):
    "Schedule using a cosine function"
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2
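Because of the annealer decorator, calling a scheduler with only start and end returns a partial that still expects pos, the fraction of training completed in [0, 1]. For example:

sched = linear_scheduler(0.01, 0.1)   # a partial still waiting for `pos`
sched(0.0), sched(0.5), sched(1.0)    # (0.01, 0.055, 0.1)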
Testing HyperRecorder and Scheduler:
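The schedule below is built with combine_scheds, defined earlier in the notebook: it stitches several annealed schedules together, here spending the first 40% of training climbing from 0.01 to 0.1 and the remaining 60% decaying back down. A rough sketch of what such a helper does (this is not the notebook's implementation, and the argument names are assumptions):

def combine_scheds(pcts, scheds):
    "Sketch: pick the phase `pos` falls in and rescale `pos` within that phase"
    def _inner(pos):
        start = 0.
        for pct, sched in zip(pcts, scheds):
            if pos <= start + pct: return sched((pos - start) / pct)
            start += pct
        return scheds[-1](1.)   # guard against pos slightly above 1.0
    return _inner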
start_lr = 0.01
model, _, loss_func = get_linear_model(start_lr)
db = get_mnist_databunch()
learn = Learner(model, loss_func, Optimizer, db, start_lr)
schedule = combine_scheds([0.4, 0.6], [cos_scheduler(0.01, 0.1), cos_scheduler(0.1, 0.01)])
run = Runner(learn, [Scheduler('lr', schedule), HyperRecorder(['lr'])])
run.fit(2)
run.cbs[2].plot_loss()
run.cbs[2].plot_param('lr')
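HyperRecorder is also defined earlier in the notebook; it logs the loss and the requested hyperparameters batch by batch so they can be plotted afterwards with plot_loss and plot_param. A sketch of such a recorder (matplotlib and the way the current lr is read from the optimizer are assumptions, not the notebook's code):

import matplotlib.pyplot as plt

class HyperRecorderSketch():
    "Sketch: record the loss and chosen hyperparameters on every training batch"
    _order = 5
    def __init__(self, run, params): self.run, self.params = run, params   # assumed Runner reference
    def before_fit(self):
        self.losses, self.records = [], {p: [] for p in self.params}
    def after_batch(self):
        if not self.run.learner.model.training: return   # only record during training
        self.losses.append(float(self.run.loss))
        for p in self.params:
            self.records[p].append(getattr(self.run.opt, p))   # assumes the optimizer exposes e.g. `lr`
    def plot_loss(self): plt.plot(self.losses)
    def plot_param(self, p): plt.plot(self.records[p])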
Testing out the progress bar:
run = get_runner([Stats([accuracy]), ProgressCallback()])
run.fit(3, 0.2)