AdaScale SGD
- class fairscale.optim.AdaScale(optimizer: torch.optim.optimizer.Optimizer, world_size: Optional[int] = None, scale: Optional[float] = None, smoothing: Optional[float] = None, num_gradients_to_accumulate: int = 1, is_scaled_loss: bool = True, debias_ewma: bool = True)
Implements the AdaScale algorithm for scaling the learning rate in distributed and large-batch training. It can be used in combination with torch.nn.parallel.DistributedDataParallel and torch.optim.SGD.
This class subclasses Optimizer so that torch.optim.lr_scheduler can work with it. In other words, AdaScale is intended to be a complete wrapper of a torch Optimizer.
Note that AdaScale does not help increase the per-GPU batch size.
There are several ways to integrate AdaScale with your training loop. We show two examples below.
Example 1: using PyTorch’s lr_scheduler classes.
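```python
from fairscale.optim import AdaScale
from torch.nn.parallel import DistributedDataParallel
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

# Assumes torch.distributed is initialized and that model, dataset,
# criterion and MAX_EPOCHS are defined by your own code.
optim = AdaScale(SGD(model.parameters(), lr=0.001))
model = DistributedDataParallel(model)
scheduler = LambdaLR(optim, lr_lambda=...)  # supply your own lr_lambda

last_epoch = 0
done = False
step = 0
while not done:
    for inputs, targets in dataset:  # assumes (inputs, targets) batches
        optim.zero_grad()
        logits = model(inputs)
        loss = criterion(logits, targets)
        loss.backward()
        # Advance by the gain rather than by 1, so epochs are counted
        # in gain-weighted steps.
        step += optim.gain()
        optim.step()
        epoch = step // len(dataset)
        if epoch > last_epoch:
            scheduler.step()
            last_epoch = epoch
        if epoch >= MAX_EPOCHS:
            done = True
            break
```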
Example 2: using a custom update_lr() function that updates the learning rate based on the current step count.
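```python
from fairscale.optim import AdaScale
from torch.nn.parallel import DistributedDataParallel
from torch.optim import SGD

# Assumes torch.distributed is initialized; model, criterion, max_steps and
# update_lr (your own function) are defined elsewhere.
optim = AdaScale(SGD(model.parameters(), lr=0.001))
model = DistributedDataParallel(model)
step = 0
while step < max_steps:
    for inputs, targets in ...:  # your data loader
        optim.zero_grad()
        logits = model(inputs)
        loss = criterion(logits, targets)
        loss.backward()
        step += optim.gain()
        optim.step()
        update_lr(step)
        if step >= max_steps:
            break
```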
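Example 2 leaves update_lr() to the application. As one possibility, the following sketch implements a step-decay schedule keyed on the gain-weighted step count. The factory name make_update_lr, its milestones and gamma parameters, and the decay policy itself are hypothetical choices; the sketch assumes, as with LambdaLR in Example 1, that writing the base learning rate into the optimizer's param_groups is how a schedule reaches AdaScale.

```python
from bisect import bisect_right
from typing import Any, Callable, Sequence


def make_update_lr(optim: Any, base_lr: float, milestones: Sequence[float],
                   gamma: float = 0.1) -> Callable[[float], float]:
    """Build a hypothetical update_lr(step) for Example 2.

    The learning rate is base_lr multiplied by gamma once for every milestone
    the gain-weighted step count has reached. `milestones` must be sorted.
    Assumes `optim` exposes torch-style `param_groups` (a list of dicts).
    """
    def update_lr(step: float) -> float:
        lr = base_lr * gamma ** bisect_right(milestones, step)
        for group in optim.param_groups:
            group["lr"] = lr  # base LR; AdaScale applies its gain inside step()
        return lr

    return update_lr
```

The factory would be called once before the loop, e.g. update_lr = make_update_lr(optim, base_lr=0.001, milestones=[30_000, 60_000]), where the milestone values are purely illustrative.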
- Parameters
optimizer (torch.optim.Optimizer) – Optimizer to apply AdaScale to.
world_size (int) – Number of workers (ranks) in distributed training. If None, defaults to dist.get_world_size().
scale (float) – Scaling factor of the batch size relative to the baseline (scale = 1); e.g., using a 10x larger batch size (summed across all ranks, including gradient accumulation) means a scale of 10. If None, defaults to world_size * num_gradients_to_accumulate.
smoothing (float) – Smoothing factor for the moving average. If None, defaults to max(1 - (world_size * num_gradients_to_accumulate)/1000, 0). Note that very large-scale training may need a higher smoothing value, especially at the beginning of training. Therefore, if your scale is close to or larger than 1000 (where the default becomes 0) and the final accuracy is poor, try experimenting with a smoothing value > 0.
num_gradients_to_accumulate (int) – Number of passes over which gradients are accumulated locally between optimizer steps. This can be changed during training as long as the training loop changes its gradient accumulation accordingly. The loss in each pass can be either scaled or unscaled; see is_scaled_loss below. Default: 1, which does not accumulate gradients.
is_scaled_loss (bool) – If True, assume that the loss is scaled by num_gradients_to_accumulate. If False, the loss is not scaled. Default: True.
debias_ewma (bool) – (experimental) Use a debiased exponential moving average for smoothing the mu and sigma variables. If False, the method in the paper’s Appendix B.3 is used instead. Default: True, which is the setting that has been validated so far.
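The defaults for scale and smoothing described above can be written out as a small helper. This is a sketch of the documented formulas only; the function name resolve_defaults is hypothetical and not part of fairscale.

```python
from typing import Optional, Tuple


def resolve_defaults(world_size: int, num_gradients_to_accumulate: int = 1,
                     scale: Optional[float] = None,
                     smoothing: Optional[float] = None) -> Tuple[float, float]:
    """Apply the documented defaults for `scale` and `smoothing`."""
    total = world_size * num_gradients_to_accumulate
    if scale is None:
        # e.g. 8 ranks x 4 accumulation passes -> a 32x larger batch
        scale = float(total)
    if smoothing is None:
        # Decreases linearly with the total and is clamped at 0 from 1000 upward.
        smoothing = max(1.0 - total / 1000, 0.0)
    return scale, smoothing


print(resolve_defaults(8, 4))     # scale 32.0, smoothing close to 0.968
print(resolve_defaults(1024, 1))  # (1024.0, 0.0): the default no longer smooths
```

Note that, per the formula above, the default smoothing depends on world_size * num_gradients_to_accumulate, not on an explicitly passed scale.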
- __del__() → None
Unhook in case the caller forgets to call unhook().
This, however, may not work, because there is a circular reference between the hook objects and this object; in that case, neither will be garbage-collected. Call unhook() explicitly if you really want to delete AdaScale from memory.
- unhook() → None
Unregister hook handles.
This is public because the caller may need to call it to ensure that all GPU memory is released. Otherwise, the hooks may prevent parameters from being released from the GPU memory pool.
Internally, we use this to support the add_param_group() API.
- property scale: float
The scaling factor of the current batch size, relative to the baseline batch size, which may itself come from DDP training. For example, if the baseline uses a per-GPU batch size of 32 on 2 GPUs and the scaled-up run uses a per-GPU batch size of 80 on 4 GPUs, then the scaling factor is 80 * 4 / 32 / 2 = 5.
This is exposed mainly for logging purposes. Note that this is different from self.gain().
- Returns
(float) – The current scaling factor.
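The arithmetic in the example above generalizes to a ratio of effective (global) batch sizes. The helper below is a hypothetical illustration, not a fairscale function; it assumes the effective batch size is the per-GPU batch size times the number of GPUs times the number of accumulation passes, matching how the scale parameter is described.

```python
def batch_scale(per_gpu_batch: int, num_gpus: int,
                base_per_gpu_batch: int, base_num_gpus: int,
                num_gradients_to_accumulate: int = 1,
                base_num_gradients_to_accumulate: int = 1) -> float:
    """Ratio of the current effective batch size to the baseline's."""
    current = per_gpu_batch * num_gpus * num_gradients_to_accumulate
    baseline = base_per_gpu_batch * base_num_gpus * base_num_gradients_to_accumulate
    return current / baseline


print(batch_scale(80, 4, 32, 2))  # 320 / 64 -> 5.0
```

A value computed this way is what an application would pass to set_scale() when its batch size changes.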
- property smoothing: float
The smoothing constant used in the exponentially weighted moving averages that track the gradient norm mean and variance within AdaScale.
This is exposed because the value is computed internally, and the caller may want to obtain and log it.
- Returns
(float) – The current smoothing value.
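For intuition about what the smoothing constant does, here is a standard bias-corrected exponentially weighted moving average. This is a generic sketch, not fairscale's internal code; we assume it is similar in spirit to what debias_ewma=True enables, and that smoothing lies in [0, 1). With smoothing = 0 the average is just the latest value, and values closer to 1 average over more steps.

```python
class DebiasedEWMA:
    """Bias-corrected exponentially weighted moving average (generic sketch)."""

    def __init__(self, smoothing: float) -> None:
        if not 0.0 <= smoothing < 1.0:
            raise ValueError("smoothing must be in [0, 1)")
        self.smoothing = smoothing
        self._biased = 0.0  # plain EWMA, pulled toward 0 during early steps
        self._weight = 0.0  # total weight so far, equal to 1 - smoothing**t

    def update(self, value: float) -> float:
        """Fold in one observation and return the debiased average."""
        self._biased = self.smoothing * self._biased + (1.0 - self.smoothing) * value
        self._weight = self.smoothing * self._weight + (1.0 - self.smoothing)
        return self._biased / self._weight


avg = DebiasedEWMA(smoothing=0.5)
first = avg.update(2.0)  # 2.0: the first value is returned without bias toward 0
print(avg.update(4.0))   # (0.5*1.0 + 0.5*4.0) / 0.75 = 2.5 / 0.75, about 3.333
```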
- set_scale(scale: float, update_estimate: bool = True) → None
Set the scaling factor of the current batch size. It is up to the application to invoke this function to make sure that AdaScale’s scaling factor matches the actual batch size used during training.
- gain(pg_idx: Optional[int] = None) → float
Current estimate of the AdaScale gain ratio (r_t in the paper).
- Parameters
pg_idx (int) – Optional index of a parameter group. Default: None, which returns the “averaged” gain across all groups.
- Returns
(float) – Estimate of gain ratio.
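In the AdaScale paper, the gain ratio is r_t = (σ² + μ²) / (σ²/S + μ²), where σ² is the gradient variance, μ² is the squared norm of the true gradient and S is the scale, so 1 ≤ r_t ≤ S. The sketch below only evaluates that closed form for given estimates; fairscale estimates σ² and μ² from per-rank gradients with the smoothing described above, which this sketch does not attempt. The function name gain_ratio is hypothetical. Summing r_t over steps, as the training-loop examples do with step += optim.gain(), gives the gain-weighted step count.

```python
def gain_ratio(grad_var: float, grad_sqr: float, scale: float) -> float:
    """Closed-form AdaScale gain r = (var + sqr) / (var / S + sqr).

    grad_var: estimate of the gradient variance (sigma^2).
    grad_sqr: estimate of the squared norm of the mean gradient (mu^2).
    scale:    batch-size scale S >= 1.
    """
    if scale < 1 or grad_var < 0 or grad_sqr < 0 or grad_var + grad_sqr == 0:
        raise ValueError("need scale >= 1, non-negative estimates, not both zero")
    return (grad_var + grad_sqr) / (grad_var / scale + grad_sqr)


print(gain_ratio(0.0, 2.0, 8))  # no gradient noise: 1.0
print(gain_ratio(3.0, 0.0, 8))  # noise only: 8.0, equal to S
```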
- step(*args: Any, **kwargs: Any) → Optional[float]
Run one optimizer step using AdaScale. Essentially just invokes optimizer.step(*args, **kwargs) with a scaled learning rate.
Note
This function can become a performance bottleneck if you update frequently. To avoid that, taking bigger steps and reducing the update frequency is generally better for performance.
- Parameters
args (Any) – Positional arguments passed to optimizer.step.
kwargs (Any) – Keyword arguments passed to optimizer.step.
- Returns
(Tensor) – The loss tensor, if a closure is used to re-evaluate the model.
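The description of step() above (invoke optimizer.step() with a scaled learning rate) can be sketched as follows. This is a simplified illustration, not fairscale's implementation: the per-group gains are passed in rather than estimated, and restoring the original learning rates afterwards is our assumption, made so that a scheduler keeps seeing the unscaled base value. The function name scaled_step is hypothetical.

```python
from typing import Any, Callable, Optional, Sequence


def scaled_step(optimizer: Any, gains: Sequence[float],
                closure: Optional[Callable[[], Any]] = None) -> Any:
    """Run optimizer.step() with each param group's lr multiplied by its gain.

    `optimizer` is anything with torch-style `param_groups` and `step()`.
    """
    original = [group["lr"] for group in optimizer.param_groups]
    for group, gain in zip(optimizer.param_groups, gains):
        group["lr"] = gain * group["lr"]
    try:
        # Forward the closure only if given, mirroring optimizer.step(closure).
        return optimizer.step(closure) if closure is not None else optimizer.step()
    finally:
        # Put the unscaled learning rates back, even if step() raised.
        for group, lr in zip(optimizer.param_groups, original):
            group["lr"] = lr
```

With a plain torch.optim.SGD instance, scaled_step(sgd, [2.0]) would perform one SGD update at twice the configured learning rate and leave the configured value unchanged.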
- add_param_group(pg: Dict) → None
Support adding parameter groups.
We need to resize some of the state and re-register the backward hooks.
- state_dict() → Dict
Proxy function to the wrapped optimizer; checkpointing needs this.
Note
Do NOT checkpoint in the middle of gradient accumulation, since the associated AdaScale internal state is not saved in the checkpoint.
- load_state_dict(data: Dict) → None
Proxy function to the wrapped optimizer; checkpointing needs this.
Note
Do NOT checkpoint in the middle of gradient accumulation, since the associated AdaScale internal state is not saved in the checkpoint.
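One way to respect the note above is to checkpoint only immediately after an optimizer step, when no gradients are partially accumulated. The loop below is a hypothetical sketch: train_with_safe_checkpoints, backward and save_checkpoint are placeholder names for your own code, and optim can be any object with step(), zero_grad() and state_dict().

```python
from typing import Any, Callable, Iterable


def train_with_safe_checkpoints(optim: Any, batches: Iterable[Any],
                                num_gradients_to_accumulate: int,
                                backward: Callable[[Any], None],
                                save_checkpoint: Callable[[dict], None],
                                checkpoint_every: int) -> int:
    """Accumulate gradients, step, and checkpoint only at step boundaries.

    Returns the number of optimizer steps taken.
    """
    num_steps = 0
    for i, batch in enumerate(batches, start=1):
        backward(batch)  # accumulate this pass's gradients
        if i % num_gradients_to_accumulate != 0:
            continue  # mid-accumulation: never checkpoint here
        optim.step()
        optim.zero_grad()
        num_steps += 1
        if num_steps % checkpoint_every == 0:
            # Safe point: accumulation for this step is complete.
            save_checkpoint(optim.state_dict())
    return num_steps
```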
- set_num_gradients_to_accumulate(num_gradients_to_accumulate: int, update_smoothing: bool = True) → None
Set the number of gradients to accumulate to a new value.
This is experimental. It can be called during training to gradually increase the number of steps between updates. Almost always, set_scale needs to be called to update the scale as well.
TODO (min): need a way to determine how much to increase the step size.
TODO (min): having both set_scale and set_num_gradients_to_accumulate is hard to use and error-prone. I think it is better to specify a base_scale, but more discussion is needed here.
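Because set_scale() almost always has to accompany set_num_gradients_to_accumulate(), an application might wrap the two calls together. The helper below is hypothetical; it assumes the baseline is one rank with no accumulation and an unchanged per-rank batch size, so that the new scale is world_size * num_gradients_to_accumulate, matching the documented default for scale. Both methods are called with their default keyword arguments, and the call order is our choice rather than something this documentation prescribes.

```python
from typing import Any


def set_accumulation_and_scale(optim: Any, world_size: int,
                               num_gradients_to_accumulate: int) -> float:
    """Change the accumulation count and keep AdaScale's scale consistent.

    `optim` is expected to provide set_num_gradients_to_accumulate() and
    set_scale(), as documented for AdaScale. Returns the new scale.
    """
    if num_gradients_to_accumulate < 1:
        raise ValueError("num_gradients_to_accumulate must be >= 1")
    optim.set_num_gradients_to_accumulate(num_gradients_to_accumulate)
    new_scale = float(world_size * num_gradients_to_accumulate)
    optim.set_scale(new_scale)
    return new_scale
```

For example, on 8 ranks, moving from 1 to 4 accumulation passes would set the scale to 32.0. The training loop itself must switch to 4 accumulation passes as well.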