API

class adadamp.AdaDamp(*args, approx_loss=False, **kwargs)
damping() → int

Adaptively damp the noise in the gradient estimate, based on the current loss, by setting the batch size to

\[B_k = \left\lceil B_0\frac{F(x_0) - F^\star}{F(x_k) - F^\star}\right\rceil\]

Warning

This batch size is expensive to compute: it requires evaluating the loss function \(F\) over the entire dataset. PadaDamp is recommended instead.
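As a rough illustration of the formula above, the sketch below computes \(B_k\) from losses that have already been evaluated on the full dataset. The function name adadamp_batch_size and its arguments are hypothetical and not part of adadamp; best_loss stands in for \(F^\star\).

```python
import math


def adadamp_batch_size(b0: int, loss0: float, loss_k: float, best_loss: float) -> int:
    """Hypothetical helper: B_k = ceil(B_0 * (F(x_0) - F*) / (F(x_k) - F*)).

    loss0 and loss_k are full-dataset losses F(x_0) and F(x_k);
    best_loss plays the role of F*. Not adadamp's API.
    """
    gap0 = loss0 - best_loss
    gap_k = loss_k - best_loss
    if gap_k <= 0:
        # The formula is undefined once the loss reaches (or passes) F*.
        raise ValueError("F(x_k) must be strictly greater than F*")
    return math.ceil(b0 * gap0 / gap_k)


# The batch size grows as the optimality gap shrinks.
print(adadamp_batch_size(32, loss0=2.0, loss_k=0.5, best_loss=0.0))  # 128
```

Every call needs a fresh \(F(x_k)\), which is the full-dataset evaluation the warning refers to.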

class adadamp.BaseDamper(model: torch.nn.modules.module.Module, dataset: torch.utils.data.dataset.Dataset, opt: torch.optim.optimizer.Optimizer, loss: Callable = <function nll_loss>, initial_batch_size: int = 1, device: str = 'cpu', max_batch_size: Optional[int] = None, best_train_loss: Optional[float] = None, random_state: Optional[int] = None, dwell: int = 20, **kwargs)

Damp the noise in the gradient estimate.

Parameters
  • model (nn.Module) – The model to train

  • dataset (torch.Dataset) – Dataset to use for training

  • opt (torch.optim.Optimizer) – The optimizer to use

  • loss (callable, default=torch.nn.functional.nll_loss) – The loss function to use. It must support the reduction keyword, with signature loss(output, target, reduction="sum").

  • initial_batch_size (int, default=1) – Initial batch size

  • device (str, default="cpu") – The device to use.

  • max_batch_size (int, float, None, default=None) – The maximum batch size. If the batch size exceeds this value, the learning rate is decayed by an appropriate amount. If None, it is set to the size of the dataset. Setting it to NaN removes the maximum batch size.

  • dwell (int, default=20) – The number of model updates for which the batch size is held constant. This is similar to the “relaxation time” parameter in simulated annealing. Setting dwell=1 re-evaluates the batch size at every model update.

  • random_state (int, optional) – Seed for the random state used to select samples.
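The interaction between dwell and max_batch_size can be sketched as follows. This is a hypothetical illustration, not BaseDamper's implementation: batch_schedule and its damping(k) callback are invented names, and the learning-rate rule (scale by max_batch_size divided by the requested batch size) is an assumption, since the text above only says the learning rate is decayed "by an appropriate amount".

```python
import math
from typing import Callable, Iterator, Tuple


def batch_schedule(
    damping: Callable[[int], int],
    num_updates: int,
    dwell: int = 20,
    max_batch_size: float = math.nan,
) -> Iterator[Tuple[int, float]]:
    """Yield (batch_size, lr_scale) for each of num_updates model updates.

    Hypothetical sketch: damping(k) is an invented callback returning the
    requested batch size at update k; NaN means "no maximum batch size".
    """
    for k in range(num_updates):
        if k % dwell == 0:
            # Re-evaluate the batch size only once every `dwell` updates
            # (k == 0 always triggers, so batch_size is set before use).
            batch_size = damping(k)
        if not math.isnan(max_batch_size) and batch_size > max_batch_size:
            # Assumed rule: cap the batch size and shrink the learning
            # rate by the same factor the batch size was cut by.
            yield int(max_batch_size), max_batch_size / batch_size
        else:
            yield batch_size, 1.0


for bs, lr_scale in batch_schedule(lambda k: 10 + k, num_updates=10, dwell=5, max_batch_size=12):
    print(bs, lr_scale)
```

With these settings the first five updates use a batch size of 10; at update 5 the requested size is 15, so the sketch caps it at 12 and scales the learning rate by 0.8 for the next five updates.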

Notes

By default, this class does not perform any damping (but its subclasses do). If a function needs an instance of BaseDamper, this class can wrap any optimizer.

damping() → int

Determines how strongly the noise in the stochastic gradient estimate is damped.

Notes

This is the main method for subclasses to override. By default, it wraps an optimizer with a static self.initial_batch_size. A brief usage example:

>>> dataset = datasets.MNIST(...)
>>> model = Net()
>>> opt = optim.Adagrad(model.parameters())
>>> opt = BaseDamper(model, dataset, opt, initial_batch_size=32)
>>> opt.damping()
32

get_params() → Dict[str, Any]

Get parameters for this optimizer.

property meta

Get meta information about this optimizer, including the number of model updates and the number of examples processed.

step(**kwargs)

Perform an optimization step.

Parameters

kwargs (Dict[str, Any], optional) – Keyword arguments passed to the wrapped PyTorch optimizer's opt.step (e.g., that of torch.optim.Adagrad).

class adadamp.CntsDampLR(*args, dampingfactor=0.02, **kwargs)
damping() → int

Decay the learning rate by \(1/k\) after \(k\) model updates.

class adadamp.GeoDamp(*args, dampingdelay=5, dampingfactor=2, **kwargs)
damping() → int

Set the batch size to increase by a factor of dampingfactor every dampingdelay epochs.
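A sketch of this geometric schedule, assuming epochs are counted as a (possibly fractional) number of passes over the dataset and that the increases happen at epochs dampingdelay, 2·dampingdelay, and so on. geodamp_batch_size is a hypothetical helper, not adadamp's API.

```python
def geodamp_batch_size(
    initial_batch_size: int,
    epoch: float,
    dampingdelay: int = 5,
    dampingfactor: float = 2,
) -> int:
    """Hypothetical sketch: multiply the batch size by dampingfactor
    once every dampingdelay epochs."""
    # Number of completed dampingdelay-epoch periods so far.
    num_increases = int(epoch // dampingdelay)
    return int(initial_batch_size * dampingfactor ** num_increases)


print([geodamp_batch_size(32, e) for e in (0, 4.9, 5, 12)])  # [32, 32, 64, 128]
```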

class adadamp.GeoDampLR(*args, **kwargs)
damping() → int

Set the learning rate to decrease by a factor of dampingfactor every dampingdelay epochs.
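The learning-rate counterpart can be sketched the same way, assuming GeoDampLR accepts the same dampingdelay and dampingfactor keyword arguments as GeoDamp (its signature above only shows *args and **kwargs). geodamp_lr is a hypothetical helper.

```python
def geodamp_lr(
    initial_lr: float,
    epoch: float,
    dampingdelay: int = 5,
    dampingfactor: float = 2,
) -> float:
    """Hypothetical sketch: divide the learning rate by dampingfactor
    once every dampingdelay epochs."""
    num_decreases = int(epoch // dampingdelay)
    return initial_lr / dampingfactor ** num_decreases


print(geodamp_lr(0.1, 0), geodamp_lr(0.1, 10))
```

In plain PyTorch, a similar step decay is available through torch.optim.lr_scheduler.StepLR with step_size=dampingdelay and gamma=1/dampingfactor, stepped once per epoch.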

class adadamp.GradientDescent(*args, **kwargs)

This class performs full gradient descent.

damping() → int

class adadamp.PadaDamp(*args, batch_growth_rate=None, **kwargs)
Parameters
  • args (list) – Passed to BaseDamper

  • batch_growth_rate (float) –

    The rate at which to increase the damping. That is, set the batch size to be

    \[B_k = B_0 + \lceil \textrm{rate}\cdot k \rceil\]

    after the model is updated \(k\) times.

  • kwargs (dict) – Passed to BaseDamper

Notes

The number of examples processed in \(u\) model updates is

\[uB_0 + \sum_{k=1}^u \lceil \textrm{rate} \cdot k\rceil\]

Dividing this by the number of training examples gives the number of epochs.
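The PadaDamp batch size and the example count above can be checked numerically with the sketch below. Both helpers are hypothetical and simply transcribe \(B_k = B_0 + \lceil \textrm{rate}\cdot k\rceil\) and its sum over \(k = 1, \ldots, u\).

```python
import math


def padadamp_batch_size(initial_batch_size: int, rate: float, k: int) -> int:
    """Hypothetical helper: B_k = B_0 + ceil(rate * k) after k model updates."""
    return initial_batch_size + math.ceil(rate * k)


def examples_processed(initial_batch_size: int, rate: float, u: int) -> int:
    """Hypothetical helper: u * B_0 + sum_{k=1}^{u} ceil(rate * k)."""
    return sum(padadamp_batch_size(initial_batch_size, rate, k) for k in range(1, u + 1))


n_examples = examples_processed(32, rate=0.5, u=4)
print(n_examples)  # 134
# Epochs so far would be n_examples / len(dataset) for the training set in use.
```

With \(B_0 = 32\) and rate = 0.5, four updates process \(4 \cdot 32 + (1 + 1 + 2 + 2) = 134\) examples. Because the batch size grows only linearly in \(k\), the example count grows quadratically in \(u\).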

Note

This class is only appropriate for convex and non-convex loss functions. It is not appropriate for strongly convex or Polyak-Łojasiewicz (PL) loss functions.

damping() → int

Approximate AdaDamp with less computation via

\[B_k = B_0 + \lceil \textrm{rate}\cdot k\rceil\]

where \(k\) is the number of model updates.