AdaScale SGD

class fairscale.optim.AdaScale(optimizer: torch.optim.optimizer.Optimizer, world_size: Optional[int] = None, scale: Optional[float] = None, smoothing: Optional[float] = None, num_gradients_to_accumulate: int = 1, debias_ewma: bool = True)

Implements the AdaScale algorithm for scaling the learning rate for distributed and large batch size training. Can be used in combination with torch.nn.parallel.DistributedDataParallel and torch.optim.SGD.

This class subclasses Optimizer so that torch.optim.lr_scheduler can work with it. In other words, AdaScale is intended to be a complete wrapper of a torch Optimizer.

Note that AdaScale does not help increase the per-GPU batch size.

There are several ways to integrate AdaScale with your training loop. We show two examples below.

Example 1: using PyTorch’s lr_scheduler classes.

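from torch.nn.parallel import DistributedDataParallel
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from fairscale.optim import AdaScale

optim = AdaScale(SGD(model.parameters(), lr=0.001))
model = DistributedDataParallel(model)
scheduler = LambdaLR(optim, lr_lambda=...)

last_epoch = 0
done = False
step = 0
while not done:
    for batch in dataset:
        optim.zero_grad()
        logits = model()
        loss = criterion(logits, ...)
        loss.backward()
        # Advance by the gain so that "step" counts scale-invariant iterations.
        step += optim.gain()
        optim.step()
        epoch = step // len(dataset)
        if epoch > last_epoch:
            # Step the LR scheduler once per scale-invariant epoch.
            scheduler.step()
            last_epoch = epoch
        if epoch >= MAX_EPOCHS:
            done = True
            break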
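The epoch bookkeeping in Example 1 counts progress in scale-invariant iterations: each real iteration advances step by the current gain instead of by 1, so a run at a larger scale needs fewer real iterations to reach MAX_EPOCHS. The pure-Python sketch below assumes a constant gain (in real training, optim.gain() changes every iteration), and iterations_to_finish is a hypothetical helper, not part of fairscale.

```python
def iterations_to_finish(gain: float, dataset_len: int, max_epochs: int) -> int:
    """Count real iterations until the scale-invariant epoch reaches max_epochs.

    Mirrors the loop in Example 1, with a constant (hypothetical) gain
    standing in for optim.gain().
    """
    if gain <= 0 or dataset_len <= 0:
        raise ValueError("gain and dataset_len must be positive")
    step = 0.0
    iterations = 0
    while True:
        step += gain  # scale-invariant progress, like step += optim.gain()
        iterations += 1
        epoch = step // dataset_len
        if epoch >= max_epochs:
            return iterations
```

For example, with a dataset of 10 batches and MAX_EPOCHS = 2, a constant gain of 4 finishes after 5 real iterations instead of 20.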

Example 2: using a custom update_lr() function that updates the learning rate based on the current step count.

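from torch.nn.parallel import DistributedDataParallel
from torch.optim import SGD
from fairscale.optim import AdaScale

optim = AdaScale(SGD(model.parameters(), lr=0.001))
model = DistributedDataParallel(model)

step = 0
while step < max_steps:
    for batch in ...:
        optim.zero_grad()
        logits = model()
        loss = criterion()
        loss.backward()
        step += optim.gain()
        optim.step()
        update_lr(step)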
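Example 2 leaves update_lr() to the application. One possible shape, sketched here as an assumption rather than fairscale API, is a step decay keyed on the scale-invariant step count that writes the new value into a torch-style param_groups list (dicts with an "lr" key). The extra arguments (param_groups, base_lr, decay_every, factor) are hypothetical.

```python
from typing import Dict, List


def update_lr(param_groups: List[Dict], step: float, base_lr: float,
              decay_every: float = 1000.0, factor: float = 0.1) -> float:
    """Hypothetical step-decay schedule driven by the scale-invariant step.

    The LR is multiplied by `factor` once for every `decay_every` steps.
    Writes the new LR into every parameter group and returns it.
    """
    num_decays = int(step // decay_every)
    lr = base_lr * (factor ** num_decays)
    for group in param_groups:
        group["lr"] = lr
    return lr
```

Because step grows by the gain rather than by 1, the decay points fall at the same amount of scale-invariant progress regardless of the batch-size scale.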

Parameters

  • optimizer (torch.optim.Optimizer) – Optimizer to apply AdaScale to.

  • world_size (int) – Number of ranks (processes) participating in distributed training. If None, defaults to dist.get_world_size().

  • scale (float) – Scaling factor of the batch size relative to a baseline where scale equals 1; e.g., using a 10x larger batch size (summed across all ranks, including gradient accumulation) means a scale of 10. If None, defaults to world_size * num_gradients_to_accumulate.

  • smoothing (float) – Smoothing factor for the moving average. If None, defaults to max(1 - (world_size * num_gradients_to_accumulate)/1000, 0). Note that very high-scale training may need a higher smoothing value, especially at the beginning of training. Therefore, if your scale is close to or larger than 1000 (where the default formula reaches 0), try a smoothing value > 0 if the final accuracy is poor.

  • num_gradients_to_accumulate (int) – Number of backward passes over which gradients are accumulated locally between optimizer steps. This can be changed during training as long as the training loop changes its gradient accumulation accordingly. Defaults to 1, which means no accumulation.

  • debias_ewma (bool) – (experimental) Use a debiased exponential moving average when smoothing the mu and sigma variables. False uses the method in Appendix B.3 of the paper. Default: True, which is the setting validated so far.
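The documented defaults for scale and smoothing are plain arithmetic on world_size and num_gradients_to_accumulate. The hypothetical helper below only restates those two formulas so they can be checked for a given setup; it does not call fairscale.

```python
from typing import Tuple


def adascale_defaults(world_size: int,
                      num_gradients_to_accumulate: int = 1) -> Tuple[float, float]:
    """Restate the documented defaults for `scale` and `smoothing`."""
    combined = world_size * num_gradients_to_accumulate
    scale = float(combined)
    # Smoothing shrinks linearly with the combined scale and is clamped at 0.
    smoothing = float(max(1 - combined / 1000, 0))
    return scale, smoothing
```

With 8 ranks and 4 accumulation steps, this gives a scale of 32 and a smoothing of about 0.968; once the product reaches 1000, smoothing falls to 0.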

__del__() → None

Unhook in case the caller forgets to call unhook().

This may not work, however, since there would be a circular reference between the hook objects and this object. In that case, neither will be garbage-collected. Call unhook() explicitly if you really want to delete AdaScale from memory.

unhook() → None

Unregister hook handles.

This is public because the caller may need to call it to ensure all GPU memory is released. Otherwise, the hooks may prevent parameters from being released from the GPU memory pool.

Internally, we use this to support the add_param_group() API.

property scale: float

The scaling factor of the current batch size, relative to the baseline batch size, which could itself come from DDP training. For example, if the baseline batch size is 32 per GPU on 2 GPUs and the scaled-up batch size is 80 per GPU on 4 GPUs, then the scaling factor is (80 * 4) / (32 * 2) = 5.

This is exposed mainly for logging purposes. Note that this is different from self.gain().

Returns

(float) – The current scaling factor.
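The scale arithmetic above is the ratio of total batch sizes, where the total is per-GPU batch times number of GPUs times gradient-accumulation steps. The helper below is a hypothetical sketch of that ratio, not part of fairscale.

```python
def scale_factor(base_per_gpu_batch: int, base_gpus: int,
                 new_per_gpu_batch: int, new_gpus: int,
                 new_accumulation: int = 1, base_accumulation: int = 1) -> float:
    """Ratio of the new total batch size to the baseline total batch size.

    Total batch size = per-GPU batch x number of GPUs x accumulation steps.
    """
    baseline_total = base_per_gpu_batch * base_gpus * base_accumulation
    new_total = new_per_gpu_batch * new_gpus * new_accumulation
    return new_total / baseline_total
```

scale_factor(32, 2, 80, 4) reproduces the example's value of 5; accumulating 2 gradients per step in the scaled-up run doubles it to 10.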

property smoothing: float

The smoothing constant used in exponentially-weighted moving average tracking the gradient norm mean and variance within AdaScale.

This is exposed because the value is computed internally and the caller may want to obtain and log it.

Returns

(float) – The current smoothing value.
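For intuition about what the smoothing constant controls, here is a generic exponentially weighted moving average with Adam-style bias correction. It is a sketch of the standard technique and an assumption about the general shape of the estimator that debias_ewma enables, not a copy of fairscale's internal code; DebiasedEWMA is a hypothetical name.

```python
class DebiasedEWMA:
    """Exponentially weighted moving average with bias correction.

    `smoothing` is the weight kept from the previous average: a value near 1
    averages over many steps, while 0 simply tracks the latest sample.
    """

    def __init__(self, smoothing: float) -> None:
        if not 0.0 <= smoothing < 1.0:
            raise ValueError("smoothing must be in [0, 1)")
        self.smoothing = smoothing
        self._avg = 0.0
        self._weight = 0.0  # Equals 1 - smoothing**t after t updates.

    def update(self, value: float) -> float:
        self._avg = self.smoothing * self._avg + (1.0 - self.smoothing) * value
        self._weight = self.smoothing * self._weight + (1.0 - self.smoothing)
        # Dividing by the accumulated weight removes the bias toward 0
        # caused by starting the average at 0.
        return self._avg / self._weight
```

Without the division, early estimates with a smoothing near 1 would be pulled strongly toward the initial value of 0, which is the kind of start-of-training effect the smoothing note above warns about.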

set_scale(scale: float, update_estimate: bool = True) → None

Set the scaling factor of the current batch size. It is up to the application to invoke this function to make sure that AdaScale’s scaling factor matches the actual batch size used during training.

Parameters

  • scale (float) – New scaling factor to be applied to AdaScale.

  • update_estimate (bool) – Whether to update the scale-dependent estimate of gradient variance; this is highly recommended. (default: True)

gain(pg_idx: Optional[int] = None) → float

Current estimate of the AdaScale gain ratio (r_t in the paper).

Parameters

pg_idx (int) – Optional index of a parameter group. Default: None, which returns the “averaged” gain across all groups.

Returns

(float) – Estimate of gain ratio.
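In the AdaScale paper, the gain ratio is r_t = (σ² + μ²) / (σ²/S + μ²), where σ² is the variance of the gradient at scale 1, μ² is the squared norm of the expected gradient, and S is the scale. The sketch below only evaluates that formula for given estimates; how fairscale estimates σ² and μ² is not shown, gain_ratio is a hypothetical name, and returning 1.0 when both estimates are 0 is a choice made here.

```python
def gain_ratio(sigma_sq: float, mu_sq: float, scale: float) -> float:
    """Evaluate r = (sigma^2 + mu^2) / (sigma^2 / S + mu^2).

    sigma_sq: estimated gradient variance at scale 1.
    mu_sq: estimated squared norm of the expected gradient.
    scale: batch-size scale S (>= 1).
    """
    if scale < 1:
        raise ValueError("scale must be >= 1")
    denominator = sigma_sq / scale + mu_sq
    if denominator == 0:
        return 1.0  # No signal and no noise: this sketch reports no gain.
    return (sigma_sq + mu_sq) / denominator
```

The ratio always lies in [1, S]: with pure noise (μ² = 0) it equals S, so a larger batch is fully exploited, and with no noise (σ² = 0) it equals 1, so a larger batch brings no extra progress.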

step(*args: Any, **kwargs: Any) → Optional[float]

Run one optimizer step using AdaScale. Essentially, this just invokes optimizer.step(*args, **kwargs) with a scaled learning rate.

Note

This function can become a performance bottleneck if updates are frequent. To avoid that, taking larger steps less frequently is generally better for performance.

Parameters

  • args (Any) – Positional arguments passed to optimizer.step.

  • kwargs (Any) – Keyword arguments passed to optimizer.step.

Returns

(Tensor) – The loss tensor if a closure is used to re-evaluate the model.
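The description of step() says it calls the wrapped optimizer with a scaled learning rate. A minimal sketch of that idea follows, assuming each parameter group's LR is multiplied by its gain for the duration of one step and then restored; it uses a stand-in optimizer with a torch-style param_groups list. scaled_step and RecordingOptimizer are hypothetical names, and this is not fairscale's implementation.

```python
from typing import Any, Callable, Dict, List, Optional


class RecordingOptimizer:
    """Hypothetical stand-in for a torch optimizer that records the LRs it saw."""

    def __init__(self, lrs: List[float]) -> None:
        self.param_groups: List[Dict[str, Any]] = [{"lr": lr} for lr in lrs]
        self.seen_lrs: List[List[float]] = []

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        self.seen_lrs.append([group["lr"] for group in self.param_groups])
        return closure() if closure is not None else None


def scaled_step(optimizer: Any, gains: List[float], *args: Any, **kwargs: Any) -> Any:
    """Run optimizer.step() with each group's LR multiplied by its gain."""
    original = [group["lr"] for group in optimizer.param_groups]
    for group, gain in zip(optimizer.param_groups, gains):
        group["lr"] *= gain
    try:
        return optimizer.step(*args, **kwargs)
    finally:
        # Restore the unscaled LRs so LR schedulers keep operating on base values.
        for group, lr in zip(optimizer.param_groups, original):
            group["lr"] = lr
```

Restoring the base LR after each step keeps the gain from compounding across iterations and lets torch.optim.lr_scheduler keep adjusting the unscaled value.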

add_param_group(pg: Dict) → None

Support adding parameter groups.

This requires resizing some of the internal state and re-registering the backward hooks.

zero_grad() → None

Proxy function to the wrapped optimizer, because some training loops need this.

state_dict() → Dict

Proxy function to the wrapped optimizer; checkpointing needs this.

Note

Do NOT checkpoint in the middle of gradient accumulation, since the associated AdaScale internal states are not saved in the checkpoint.

load_state_dict(data: Dict) → None

Proxy function to the wrapped optimizer; checkpointing needs this.

Note

Do NOT checkpoint in the middle of gradient accumulation, since the associated AdaScale internal states are not saved in the checkpoint.
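Both checkpointing notes reduce to one rule: only save or load at an optimizer-step boundary. The helper below is hypothetical (not part of fairscale) and assumes the training loop counts completed backward passes in a micro_step counter that starts at 0.

```python
def at_accumulation_boundary(micro_step: int, num_gradients_to_accumulate: int) -> bool:
    """Return True when no partially accumulated gradients are pending.

    micro_step counts backward passes completed so far; a boundary is reached
    after every num_gradients_to_accumulate passes.
    """
    if num_gradients_to_accumulate < 1:
        raise ValueError("num_gradients_to_accumulate must be >= 1")
    return micro_step % num_gradients_to_accumulate == 0
```

In a loop, state_dict() would then be called only when this returns True, i.e. right after optim.step().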

set_num_gradients_to_accumulate(num_gradients_to_accumulate: int, update_smoothing: bool = True) → None

Set the number of gradients to accumulate to a new value.

This is experimental. It can be called during training to gradually increase the number of steps between updates. Almost always, set_scale needs to be called to update the scale as well.

TODO (min): need a way to determine how much to increase the step size?

TODO (min): having both set_scale and set_num_gradients_to_accumulate is hard to use and error-prone. I think it would be better to specify a base_scale, but more discussion is needed here.

Parameters

  • num_gradients_to_accumulate (int) – Number of gradients to accumulate (calls to backward) between optimizer steps.

  • update_smoothing (bool) – Whether to update the smoothing factor. Default: True.

__getattr__(name: str) → Any

Forward missing attributes to wrapped optimizer.
