Basic Usage
Let’s make small modifications to PyTorch’s MNIST example. The example below is trimmed for brevity; the complete script can be found on GitHub at test_mnist.py.
First, let’s start by creating the model.
import torch
import torch.nn as nn
from torchvision import datasets

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        # etc; unchanged from example

    def forward(self, x):
        # etc; unchanged from example
        ...

model = Net()
This is a standard PyTorch model definition. Now, let’s create the dataset:
train_set = datasets.MNIST(...) # internals untouched
test_set = datasets.MNIST(...) # internals untouched
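For reference, a typical construction mirrors PyTorch’s MNIST example; the normalization constants below are that example’s MNIST mean and standard deviation, and the data directory is illustrative:

from torchvision import transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),  # MNIST mean/std from PyTorch's example
])
train_set = datasets.MNIST("../data", train=True, download=True, transform=transform)
test_set = datasets.MNIST("../data", train=False, transform=transform)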
Notice that a DataLoader object is not created (as it would be in PyTorch’s example); only the dataset that would normally be passed to a DataLoader is needed.
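For contrast, this is the step being skipped. In a standard PyTorch pipeline the batch size is fixed when the DataLoader is built (a minimal sketch; the batch size is illustrative):

from torch.utils.data import DataLoader

# Standard PyTorch: the batch size is fixed here, once, for all of training.
# adadamp skips this step so the optimizer can grow the batch size itself.
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)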
Now, let’s create the loss function to optimize, alongside the optimizer:
import torch.optim as optim
import torch.nn.functional as F

_optimizer = optim.SGD(model.parameters(), lr=args.lr)  # args comes from the script's argparse setup
loss = F.nll_loss
Again, the optimizer is a pretty standard definition. Use of SGD is recommended because it is what the underlying mathematics assumes; momentum and Nesterov acceleration are likely beneficial in practice, even though they are not backed by the mathematics.
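If momentum is desired, the standard torch.optim.SGD arguments apply (the values below are illustrative, not recommendations):

_optimizer = optim.SGD(
    model.parameters(), lr=args.lr, momentum=0.9, nesterov=True
)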
Let’s use PadaDamp, which will grow the batch size and is an approximation to the firmly grounded but impractical AdaDamp:
from adadamp import PadaDamp

optimizer = PadaDamp(
    model=model,
    dataset=train_set,
    opt=_optimizer,
    loss=loss,
    device="cpu",
    batch_growth_rate=0.01,
    initial_batch_size=32,
    max_batch_size=1024,
)
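Here, batch_growth_rate controls how quickly the batch size grows with the number of model updates, between initial_batch_size and max_batch_size. As a rough illustration only (the exact schedule PadaDamp uses is not spelled out here, so the linear rule below is an assumption):

import math

def batch_size_sketch(updates, initial=32, rate=0.01, maximum=1024):
    # Assumed linear growth: the batch size rises by `rate` per model
    # update until it hits the cap. The actual PadaDamp schedule may differ.
    return min(initial + math.ceil(rate * updates), maximum)

print(batch_size_sketch(0))     # 32
print(batch_size_sketch(1000))  # 42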
This optimizer is a drop-in replacement for any of PyTorch’s optimizers like torch.optim.SGD or torch.optim.Adagrad. This means we can use it in our (custom) training functions by calling optimizer.step().
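A minimal sketch of such a custom loop, assuming one call to step() draws one batch and performs one update (updates_per_epoch is a hypothetical placeholder, not part of the API):

for epoch in range(1, args.epochs + 1):
    for _ in range(updates_per_epoch):  # updates_per_epoch is hypothetical
        optimizer.step()  # assumed to draw a batch and apply one update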
However, it might be easier to use the built-in train/test functions:
from adadamp.experiment import train, test

for epoch in range(1, args.epochs + 1):
    train(model=model, opt=optimizer)
    data = test(model=model, loss=loss, dataset=test_set)
    print(data)
These train and test functions are small modifications of the functions in PyTorch’s MNIST example.