# Basic Usage

Let’s make small modifications to PyTorch’s MNIST example. The code below is trimmed for brevity; the complete script, test_mnist.py, is on GitHub.

First, let’s create the model.

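```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets


class Net(nn.Module):
    # Same architecture as PyTorch's MNIST example
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)  # 64 channels * 12 * 12 after pooling
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(self.dropout1(x), 1)
        x = self.dropout2(F.relu(self.fc1(x)))
        return F.log_softmax(self.fc2(x), dim=1)  # log-probabilities for nll_loss

model = Net()
```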


This is a standard PyTorch model definition. Now, let’s create the dataset:

```python
train_set = datasets.MNIST(...)  # same arguments as in PyTorch's MNIST example
test_set = datasets.MNIST(...)  # same arguments as in PyTorch's MNIST example
```


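Notice that, unlike in plain PyTorch, no DataLoader is created: only the dataset that would normally be passed to one. PadaDamp takes the dataset directly and builds its own batches, since the batch size changes during training.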
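Because the batch size changes from step to step, a DataLoader with a fixed `batch_size` does not fit well. As a rough sketch of the idea (not adadamp's actual code), an optimizer that owns the dataset can draw a fresh set of indices of whatever size the current step needs. The helper `sample_batch` is hypothetical, sampling without replacement within a batch is our assumption, and a plain list stands in for `train_set`.

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def sample_batch(dataset: Sequence[T], batch_size: int, rng: random.Random) -> List[T]:
    """Draw `batch_size` distinct examples from `dataset` (hypothetical helper).

    Anything with __len__ and __getitem__, such as a torchvision dataset,
    can be indexed this way, so no DataLoader is needed.
    """
    batch_size = min(batch_size, len(dataset))  # never ask for more than exists
    indices = rng.sample(range(len(dataset)), batch_size)
    return [dataset[i] for i in indices]


# Toy stand-in for train_set: (image, label) pairs
toy_dataset = [(f"img{i}", i % 10) for i in range(100)]
rng = random.Random(0)

# The batch size may differ on every step
for batch_size in (4, 8, 16):
    batch = sample_batch(toy_dataset, batch_size, rng)
    print(len(batch))
```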

Now, let’s create the loss function to optimize, alongside the optimizer:

```python
import torch.optim as optim
import torch.nn.functional as F

# `args` holds the parsed command-line options, as in PyTorch's MNIST example
_optimizer = optim.SGD(model.parameters(), lr=args.lr)
loss = F.nll_loss
```


Again, the optimizer is a standard definition. Plain SGD is recommended because it is the case the underlying mathematics covers. Momentum and Nesterov acceleration are likely beneficial as well, even though the mathematics does not cover them.
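`torch.optim.SGD` enables these through its `momentum` and `nesterov` arguments, e.g. `optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, nesterov=True)`, where 0.9 is only an illustrative value. To make the difference concrete, here is a pure-Python sketch of the update rule that PyTorch's SGD documentation gives for these options (with dampening and weight decay left at 0), applied to a one-dimensional quadratic. It is an illustration of the rule, not PyTorch's implementation, and `sgd` is a hypothetical helper.

```python
from typing import Callable, List


def sgd(
    grad: Callable[[float], float],
    theta: float,
    lr: float,
    steps: int,
    momentum: float = 0.0,
    nesterov: bool = False,
) -> List[float]:
    """SGD on a scalar parameter, following PyTorch's documented update rule."""
    buf = 0.0  # momentum buffer; starting at 0 makes the first buffer equal g
    history = [theta]
    for _ in range(steps):
        g = grad(theta)
        if momentum != 0.0:
            buf = momentum * buf + g
            # Nesterov looks ahead along the buffer; heavy-ball uses it directly
            g = g + momentum * buf if nesterov else buf
        theta = theta - lr * g
        history.append(theta)
    return history


def grad(theta: float) -> float:
    # Gradient of f(theta) = (theta - 3)^2, which is minimized at theta = 3
    return 2.0 * (theta - 3.0)


plain = sgd(grad, theta=0.0, lr=0.1, steps=100)
nest = sgd(grad, theta=0.0, lr=0.1, steps=100, momentum=0.9, nesterov=True)
print(plain[-1], nest[-1])
```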

Now let’s wrap it in PadaDamp, which grows the batch size during training. PadaDamp is a practical approximation to AdaDamp, which is firmly grounded in theory but impractical to use:

```python
from adadamp import PadaDamp

optimizer = PadaDamp(
    model=model,
    dataset=train_set,
    opt=_optimizer,
    loss=loss,
    device="cpu",
    batch_growth_rate=0.01,
    initial_batch_size=32,
    max_batch_size=1024,
)
```
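The three batch-size arguments above control how the batch grows. The exact formula lives in adadamp's source; the sketch below assumes a linear schedule in the number of model updates k, namely B_k = min(B_max, B_0 + ceil(rate * k)), which is our reading of `batch_growth_rate` and should be checked against the library. The function name `padadamp_batch_size` is hypothetical.

```python
import math


def padadamp_batch_size(
    model_updates: int,
    initial_batch_size: int = 32,
    batch_growth_rate: float = 0.01,
    max_batch_size: int = 1024,
) -> int:
    """Hypothetical linear batch-size schedule (an assumption, not adadamp's code).

    Starts at `initial_batch_size`, grows by `batch_growth_rate` examples per
    model update, and is capped at `max_batch_size`.
    """
    grown = initial_batch_size + math.ceil(batch_growth_rate * model_updates)
    return min(grown, max_batch_size)


for k in (0, 1_000, 10_000, 100_000, 1_000_000):
    print(k, padadamp_batch_size(k))
```

Under this assumption, a growth rate of 0.01 adds one example every 100 updates, so the batch reaches the cap of 1024 after about (1024 - 32) / 0.01 = 99,200 updates.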


This optimizer is a drop-in replacement for any of PyTorch’s optimizers, such as torch.optim.SGD or torch.optim.Adagrad. This means it can be used in our own (custom) training functions by calling optimizer.step().
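A custom training function therefore only has to call `step()` repeatedly. In this sketch, `train_steps`, `SupportsStep` and `CountingOptimizer` are hypothetical names, not part of adadamp; the counting stand-in lets the loop run without PyTorch, and in practice you would pass the PadaDamp `optimizer` instead. The comment describing what each `step()` does is our assumption, based on PadaDamp holding the model, dataset and loss.

```python
from typing import Any, Protocol


class SupportsStep(Protocol):
    """Anything with a no-argument step(), such as the PadaDamp optimizer."""

    def step(self) -> Any: ...


def train_steps(optimizer: SupportsStep, num_steps: int) -> None:
    # Assumed: each step() samples a batch, computes the loss and updates
    # the model, so the loop passes no data of its own.
    for _ in range(num_steps):
        optimizer.step()


class CountingOptimizer:
    """Toy stand-in that only records how many steps were taken."""

    def __init__(self) -> None:
        self.steps_taken = 0

    def step(self) -> None:
        self.steps_taken += 1


opt = CountingOptimizer()
train_steps(opt, num_steps=5)
print(opt.steps_taken)  # 5
```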

However, it might be easier to use the built-in train/test functions:

```python
from adadamp.experiment import train, test

for epoch in range(1, args.epochs + 1):
    train(model=model, opt=optimizer)
    data = test(model=model, loss=loss, dataset=test_set)
    print(data)
```


These train and test functions are small modifications of the functions in PyTorch’s MNIST example.
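For reference, the test function in PyTorch's MNIST example reports the average negative log-likelihood and the accuracy over the test set. The pure-Python sketch below computes those two quantities from per-example log-probabilities (what `log_softmax` returns) and integer labels. It is only an illustration: `evaluate` is a hypothetical name, and the keys of the dictionary that adadamp's `test` returns may differ from the `"loss"` and `"accuracy"` used here.

```python
import math
from typing import Dict, Sequence


def evaluate(log_probs: Sequence[Sequence[float]], targets: Sequence[int]) -> Dict[str, float]:
    """Average NLL loss and accuracy, mirroring PyTorch's MNIST test loop.

    log_probs[i][c] is the log-probability of class c for example i.
    """
    total_loss = 0.0
    correct = 0
    for row, target in zip(log_probs, targets):
        total_loss += -row[target]  # nll_loss picks out -log p(target)
        prediction = max(range(len(row)), key=row.__getitem__)  # argmax
        correct += int(prediction == target)
    n = len(targets)
    return {"loss": total_loss / n, "accuracy": correct / n}


probs = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6], [0.5, 0.4, 0.1]]
log_probs = [[math.log(p) for p in row] for row in probs]
print(evaluate(log_probs, targets=[0, 2, 1]))
```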