Fundamental operations and helper functions in use throughout the library

Helper Functions

Testing

is_equal

is_equal(a, b)

Test for equality between a and b
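A minimal sketch of such an equality check follows. It assumes the check should handle both PyTorch tensors and plain Python values, and the name `is_equal_sketch` is hypothetical. The library's own `is_equal` may be implemented differently.

```python
import torch

def is_equal_sketch(a, b):
    """Hypothetical equality check: exact match for tensors, == otherwise."""
    # torch.equal is True only if both tensors have the same shape and elements
    if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
        return torch.equal(a, b)
    # Fall back to ordinary equality for non-tensor values
    return a == b
```

Using `torch.equal` avoids the ambiguity of `a == b` on tensors, which returns an element-wise tensor rather than a single boolean.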

near

near(a, b)

Test if tensors a and b are the same within a small tolerance
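One way to express a tolerance check is with `torch.allclose`. This is a sketch only; the name `near_sketch` and the `rtol`/`atol` defaults are assumptions, not values taken from the library.

```python
import torch

def near_sketch(a, b, rtol=1e-3, atol=1e-5):
    """Hypothetical tolerance check; the rtol/atol defaults are assumptions."""
    # allclose tests |a - b| <= atol + rtol * |b| for every element
    return torch.allclose(a, b, rtol=rtol, atol=atol)
```

For example, `near_sketch(torch.ones(3), torch.ones(3) + 1e-6)` is `True`, while `near_sketch(torch.ones(2, 2), torch.zeros(2, 2))` is `False`.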

test_near

test_near(a, b)

Test if tensors a and b are near within a small tolerance, and print the result
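A reporting wrapper in the spirit of `test_near` might look like the sketch below. The name `report_near`, the tolerance defaults, and the boolean return value are all hypothetical; only the printed messages are modeled on the example outputs in this section.

```python
import torch

def report_near(a, b, rtol=1e-3, atol=1e-5):
    """Hypothetical: print whether a and b agree within tolerance."""
    ok = torch.allclose(a, b, rtol=rtol, atol=atol)
    # Messages mirror the "good" / "not near" outputs shown in this section
    print("good" if ok else "not near")
    return ok
```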

test_near_zero

test_near_zero(data, tol=0.001)

Test whether the values of tensor 'data' are near zero, within the tolerance 'tol'

test_near(torch.ones(2,2), torch.zeros(2,2))
not near
test_near(torch.ones(2,2), torch.ones(2,2))
good
test_near_zero(torch.zeros(2,2).mean())
Near zero: 0.0
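The last example checks a scalar mean against zero. A minimal sketch of that kind of check follows, assuming it should work for both 0-d and n-d tensors by looking at the largest absolute value. The name `report_near_zero` and the "Not near zero" message are hypothetical.

```python
import torch

def report_near_zero(data, tol=1e-3):
    """Hypothetical: check that every value in data has magnitude below tol."""
    # Largest absolute value; works for 0-d and n-d tensors alike
    worst = data.abs().max().item()
    ok = worst < tol
    print(f"Near zero: {worst}" if ok else f"Not near zero: {worst}")
    return ok
```

Called on `torch.zeros(2,2).mean()`, this prints `Near zero: 0.0`.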

Data Loading

normalize

normalize(datasets, mean=None, std=None)

Normalize 'datasets' using the given 'mean' and 'std', or the mean and std of 'datasets' if none are given
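Normalization here means standardizing to zero mean and unit standard deviation. The sketch below assumes a single tensor input, which is how `get_mnist` calls `normalize`. The name `normalize_sketch` is hypothetical.

```python
import torch

def normalize_sketch(x, mean=None, std=None):
    """Hypothetical: standardize x with the given stats, else x's own stats."""
    # Fall back to the statistics of x itself when none are supplied
    if mean is None:
        mean = x.mean()
    if std is None:
        std = x.std()
    return (x - mean) / std
```

Passing the training statistics explicitly, as `get_mnist` does for the validation set, puts both splits on the same scale.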

get_mnist

get_mnist()

Helper function to load normalized train and validation MNIST datasets

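import gzip, pickle
from torch import tensor
from fastai import datasets  # assumed: fastai v1 download helpers

def get_mnist():
    # MNIST_URL is assumed to be defined elsewhere in the library
    path = datasets.download_data(MNIST_URL, ext='.gz')
    with gzip.open(path, 'rb') as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')

    xt,yt,xv,yv = map(tensor, (x_train, y_train, x_valid, y_valid))
    # Validation data is normalized with the training mean and std
    return normalize(xt).float(), yt.float(), normalize(xv, xt.mean(), xt.std()).float(), yv.float()

xt,yt,xv,yv = get_mnist()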

get_stats

get_stats(data)

Print the mean and standard deviation of the given data

get_stats(xt)
Mean: 0.00012300178059376776
Std: 1.0
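A sketch of a function producing output in the format above follows. The name `get_stats_sketch` and the returned `(mean, std)` tuple are hypothetical additions.

```python
import torch

def get_stats_sketch(data):
    """Hypothetical: print and return the mean and std of a tensor."""
    mean, std = data.mean().item(), data.std().item()
    # Same "Mean: ... / Std: ..." layout as the example output above
    print(f"Mean: {mean}")
    print(f"Std: {std}")
    return mean, std
```

After normalization, a mean close to 0 and a std close to 1 confirm the data was standardized as intended.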

Visualization

show_im

show_im(image, size=28)

Display 'image' as a 'size' x 'size' picture, or a random image from the set if several are given

show_im(xt)
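The display logic described above could be sketched with matplotlib as follows. This assumes images are stored flattened, one per row, as in the MNIST tensors (784 values each), so a 2-D input is treated as a batch. The name `show_im_sketch` and the returned grid are hypothetical.

```python
import random
import torch
import matplotlib.pyplot as plt

def show_im_sketch(image, size=28):
    """Hypothetical: display one flattened image as a size x size grid."""
    # A 2-D input is treated as a batch of flattened images: pick one at random
    if image.dim() == 2:
        image = image[random.randrange(image.shape[0])]
    grid = image.reshape(size, size)  # e.g. 784 -> 28 x 28 for MNIST
    plt.imshow(grid, cmap="gray")
    plt.axis("off")
    plt.show()
    return grid
```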

Matmul

matmul

matmul(a, b)

Perform matrix multiplication on a and b

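import torch

def matmul(a, b):
    ar, ac = a.shape
    br, bc = b.shape
    assert ac == br, "inner dimensions must match"
    c = torch.zeros(ar, bc)
    for ar_in in range(ar):
        # (ac, 1) * (br, bc) broadcasts to (ac, bc); summing over dim 0 gives row ar_in of a @ b
        c[ar_in] = (a[ar_in].unsqueeze(-1) * b).sum(dim=0)
    return c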
%time result = matmul(ims,test_weight)
CPU times: user 506 µs, sys: 337 µs, total: 843 µs
Wall time: 487 µs
test_near(result,ground_truth)
good
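The loop above broadcasts one row of `a` at a time. The same idea can be pushed further so that all rows are broadcast at once, removing the Python loop entirely. The sketch below does this; the name `matmul_broadcast` is hypothetical, and the trade-off is an intermediate tensor of shape `(ar, ac, bc)`, which can be large.

```python
import torch

def matmul_broadcast(a, b):
    """Hypothetical fully broadcast matmul with no Python loop."""
    ar, ac = a.shape
    br, bc = b.shape
    assert ac == br, "inner dimensions must match"
    # a[:, :, None] is (ar, ac, 1) and b[None] is (1, br, bc): product is (ar, ac, bc)
    # Summing over the shared inner dimension yields the (ar, bc) result
    return (a[:, :, None] * b[None]).sum(dim=1)
```

For small integer-valued inputs the result matches `a @ b` exactly, for instance `[[1, 2], [3, 4]]` times `[[5, 6], [7, 8]]` gives `[[19, 22], [43, 50]]`. For general floats, compare with a tolerance check such as `torch.allclose`.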