Fundamental operations and helper functions used throughout the library.
   
    
    
    
    
   
Helper Functions
Testing
# Exercise the comparison helpers on small 2x2 tensors.
# NOTE(review): the first call compares all-ones against all-zeros, so it
# presumably fails on purpose to show the check firing — confirm against
# test_near's definition.
test_near(torch.ones(2,2), torch.zeros(2,2))
test_near(torch.ones(2,2), torch.ones(2,2))
# The mean of an all-zeros tensor is 0.
test_near_zero(torch.zeros(2,2).mean())
Data Loading
def get_mnist():
    """Download the MNIST archive and return its train/validation tensors.

    The gzipped pickle at ``MNIST_URL`` is fetched, its training and
    validation splits are converted with ``tensor``, and four float tensors
    are returned.

    Returns:
        A 4-tuple ``(train_x, train_y, valid_x, valid_y)``. ``train_x`` is
        ``normalize(xt)``; ``valid_x`` is normalized with the mean and std
        taken from the training inputs rather than from the validation
        inputs themselves.
    """
    archive = datasets.download_data(MNIST_URL, ext='.gz')
    with gzip.open(archive, 'rb') as fh:
        # NOTE(review): pickle.load can execute arbitrary code from the file;
        # this is only safe while MNIST_URL points at a trusted source.
        train_split, valid_split, _ = pickle.load(fh, encoding='latin-1')
    x_train, y_train = train_split
    x_valid, y_valid = valid_split

    xt, yt, xv, yv = (tensor(arr) for arr in (x_train, y_train, x_valid, y_valid))

    # Statement order is deliberate: the training inputs are normalized before
    # xt.mean()/xt.std() are evaluated, exactly as a single tuple expression
    # would do (assumes normalize does not modify xt in place — TODO confirm).
    train_x = normalize(xt).float()
    train_y = yt.float()
    valid_x = normalize(xv, xt.mean(), xt.std()).float()
    valid_y = yv.float()
    return train_x, train_y, valid_x, valid_y
# Unpack the four tensors returned by get_mnist() and inspect the training inputs.
# NOTE(review): get_stats is defined elsewhere; presumably it reports summary
# statistics such as mean/std — confirm.
xt,yt,xv,yv = get_mnist()
get_stats(xt)
Visualization
# NOTE(review): x_train is not bound at top level in the visible cells (the
# get_mnist() results are unpacked into xt/yt/xv/yv) — confirm it is defined
# elsewhere, otherwise this call raises NameError.
show_im(x_train)
Matmul
def matmul(a, b):
    """Multiply two 2-D tensors, producing one output row per loop iteration.

    Row ``i`` of the result is computed by turning ``a[i]`` into a column,
    broadcasting it against ``b`` and summing over the shared dimension.

    Args:
        a: Tensor of shape ``(ar, ac)``.
        b: Tensor of shape ``(br, bc)``; ``br`` must equal ``ac``.

    Returns:
        Tensor of shape ``(ar, bc)`` in torch's default dtype, allocated on
        the same device as ``a``.

    Raises:
        AssertionError: If the inner dimensions of ``a`` and ``b`` differ.
        ValueError: If either input is not 2-D (raised by shape unpacking).
    """
    ar, ac = a.shape
    br, bc = b.shape
    assert ac == br, f"inner dimensions must match: a is {ar}x{ac}, b is {br}x{bc}"
    # Allocate the output on a's device; a plain torch.zeros(ar, bc) lands on
    # the default device and the accumulation below would then fail with a
    # device mismatch for non-CPU inputs.
    c = torch.zeros(ar, bc, device=a.device)
    for row in range(ar):
        # (ac, 1) * (ac, bc) -> (ac, bc); summing over dim 0 yields row `row`.
        c[row] += (a[row].unsqueeze(-1) * b).sum(dim=0)
    return c
# Time the row-wise matmul with the IPython %time magic, then compare the
# result against the reference value.
# NOTE(review): ims, test_weight and ground_truth are not defined in the
# visible cells — presumably set up elsewhere; confirm.
%time result = matmul(ims,test_weight)
test_near(result,ground_truth)