Fundamental operations and helper functions in use throughout the library
Helper Functions
Testing
# Sanity checks for the comparison helpers.
# NOTE(review): ones vs zeros are not close, so this first call presumably fails
# on purpose to demonstrate test_near's failure output — confirm before running
# this as a plain script, where a failure would stop execution.
test_near(torch.ones(2,2), torch.zeros(2,2))
# Identical tensors: expected to pass.
test_near(torch.ones(2,2), torch.ones(2,2))
# The mean of an all-zeros tensor is exactly zero.
test_near_zero(torch.zeros(2,2).mean())
Data Loading
def get_mnist():
    """Load the MNIST training and validation splits as float tensors.

    Downloads the gzipped pickle at ``MNIST_URL`` through
    ``datasets.download_data``, converts every array with ``tensor`` and
    normalizes the inputs. The training inputs use ``normalize``'s own defaults.
    The validation inputs receive the training mean and std, so both splits
    share one scaling.

    Returns:
        Tuple ``(x_train, y_train, x_valid, y_valid)``; every element is cast
        to float.
    """
    archive = datasets.download_data(MNIST_URL, ext='.gz')
    with gzip.open(archive, 'rb') as fh:
        # NOTE(review): pickle.load can execute arbitrary code embedded in the
        # file; only use this on a trusted download of MNIST_URL.
        # The third element (presumably the test split) is discarded.
        (raw_x_tr, raw_y_tr), (raw_x_va, raw_y_va), _ = pickle.load(fh, encoding='latin-1')
    x_tr, y_tr, x_va, y_va = (tensor(arr) for arr in (raw_x_tr, raw_y_tr, raw_x_va, raw_y_va))

    x_train_norm = normalize(x_tr).float()
    # Validation data is scaled with statistics taken from the training inputs.
    x_valid_norm = normalize(x_va, x_tr.mean(), x_tr.std()).float()
    return x_train_norm, y_tr.float(), x_valid_norm, y_va.float()
# Load the MNIST splits (normalized inputs, float labels) into module-level
# tensors used by the rest of the file.
xt,yt,xv,yv = get_mnist()
# NOTE(review): outside a notebook cell the return value of this bare call is
# discarded; presumably it reports summary statistics of the training inputs
# (e.g. to confirm normalization) — confirm.
get_stats(xt)
Visualization
# Display the training inputs. `x_train` is only a local variable inside
# get_mnist(), so the module-level training tensor `xt` is used instead.
# NOTE(review): `xt` holds the normalized inputs; assumes show_im handles
# normalized data — confirm.
show_im(xt)
Matmul
def matmul(a, b):
    """Multiply two 2-D tensors row by row using broadcasting.

    Output row ``i`` is ``sum_k a[i, k] * b[k, :]``. It is computed by
    broadcasting ``a[i]`` as a column against ``b`` and summing over the shared
    dimension.

    Args:
        a: 2-D tensor of shape ``(ar, ac)``.
        b: 2-D tensor of shape ``(br, bc)`` with ``br == ac``.

    Returns:
        Tensor of shape ``(ar, bc)`` on ``a``'s device. Its dtype is the
        promotion of ``a``, ``b`` and the default float dtype. Integer and half
        inputs therefore still produce the default float dtype, while float64
        inputs keep float64 instead of being silently downcast.

    Raises:
        ValueError: if either input is not 2-D (from unpacking ``.shape``) or
            the inner dimensions differ.
    """
    ar, ac = a.shape
    br, bc = b.shape
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, and a mismatch with ac == 1 would otherwise broadcast
    # silently into a wrong result.
    if ac != br:
        raise ValueError(
            f"inner dimensions must match, got a: {tuple(a.shape)} and b: {tuple(b.shape)}"
        )
    out_dtype = torch.promote_types(
        torch.promote_types(a.dtype, b.dtype), torch.get_default_dtype()
    )
    c = torch.zeros(ar, bc, dtype=out_dtype, device=a.device)
    for i in range(ar):
        # (ac, 1) * (br, bc) -> (ac, bc); summing over dim 0 yields row i.
        c[i] = (a[i].unsqueeze(-1) * b).sum(dim=0)
    return c
# Time one call of the broadcasting matmul (IPython `%time` magic; this line
# only runs inside IPython/Jupyter, not as plain Python).
# NOTE(review): ims, test_weight and ground_truth are not defined in the visible
# source — presumably created in an earlier cell; confirm.
%time result = matmul(ims,test_weight)
# Compare against the expected product (presumably a reference result such as
# ims @ test_weight — confirm).
test_near(result,ground_truth)