Sharded Data Parallel
- class fairscale.nn.ShardedDataParallel(module: torch.nn.modules.module.Module, sharded_optimizer: Union[fairscale.optim.oss.OSS, List[fairscale.optim.oss.OSS]], process_group: Optional[Any] = None, broadcast_buffers: bool = True, sync_models_at_startup: bool = True, reduce_buffer_size: int = 8388608, auto_refresh_trainable: bool = True, reduce_fp16: bool = False, warn_on_trainable_params_changed: bool = True)
Wrap the model, and reduce the gradients to the right rank during the backward pass:
- the partition is given by the sharded optimizer,
- the base model is wrapped with a model which knows where to reduce each gradient,
- an autograd function is added, which calls the model's gradient dispatch on the way back.
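The reduction scheme described above can be modeled in plain Python. This is a minimal sketch under stated assumptions, not fairscale code: partition_params and reduce_to_owners are hypothetical helpers, parameters are assigned greedily to the least-loaded rank (the exact partitioning policy of OSS may differ), gradients are plain lists of floats, and "reducing" means averaging over all ranks and keeping the result only on the rank that owns the parameter, so the other ranks no longer need their copy.

```python
from typing import Dict, List

# Hypothetical helpers: a pure-Python model of sharded gradient reduction,
# not part of fairscale's API.
Grads = Dict[str, List[float]]


def partition_params(sizes: Dict[str, int], world_size: int) -> Dict[str, int]:
    """Greedily assign each parameter to the rank with the fewest elements so far."""
    load = [0] * world_size
    owner: Dict[str, int] = {}
    for name, numel in sizes.items():
        rank = load.index(min(load))  # least-loaded rank; lowest index wins ties
        owner[name] = rank
        load[rank] += numel
    return owner


def reduce_to_owners(local_grads: List[Grads], owner: Dict[str, int]) -> List[Grads]:
    """Average each gradient over all ranks, keeping the result only on its owner."""
    world_size = len(local_grads)
    reduced: List[Grads] = [{} for _ in range(world_size)]
    for name, rank in owner.items():
        # Element-wise sum of this parameter's gradient across every rank
        summed = [sum(values) for values in zip(*(g[name] for g in local_grads))]
        reduced[rank][name] = [v / world_size for v in summed]
    return reduced


# Two simulated ranks; "w", "b" and "v" are made-up parameter names.
sizes = {"w": 4, "b": 2, "v": 3}
owner = partition_params(sizes, world_size=2)
local_grads = [
    {"w": [1.0] * 4, "b": [2.0, 2.0], "v": [0.0] * 3},  # rank 0
    {"w": [3.0] * 4, "b": [4.0, 4.0], "v": [2.0] * 3},  # rank 1
]
reduced = reduce_to_owners(local_grads, owner)
print(owner)    # {'w': 0, 'b': 1, 'v': 1}
print(reduced)  # [{'w': [2.0, 2.0, 2.0, 2.0]}, {'b': [3.0, 3.0], 'v': [1.0, 1.0, 1.0]}]
```

Each rank ends up holding the averaged gradients only for the parameters its optimizer shard will update, which is what lets the sharded optimizer skip the others.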
- Parameters
module (nn.Module) – model to be wrapped
sharded_optimizer (OSS, or list of OSS) – the sharded optimizer(s) which will decide the gradient partitioning
- Keyword Arguments
process_group (group) – torch.distributed group (default: group.WORLD)
broadcast_buffers (bool) – Whether to additionally broadcast model buffers between ranks at the beginning of each forward pass. Same setting as in PyTorch DDP; this is in addition to the broadcast and reduction of the model parameters.
sync_models_at_startup (bool) – Synchronize the models between the ranks when starting up. Not needed if each rank has the same seed, or if the training restarts from a saved state.
reduce_buffer_size (int) – The maximum size of the buffer used to batch the small parameter tensors, in number of elements (default: 8388608, i.e. 2^23). This impacts long-term memory consumption, because these buckets correspond to parameters which will not be sharded. Set to 0 to remove all bucketing; 1M to 8M is usually reasonable.
auto_refresh_trainable (bool) – (default: True) Check whether the trainability (requires_grad) of the parameters has changed, and if so, update both ShardedDDP and OSS automatically. If set to False, refresh_trainable() needs to be called whenever a parameter is frozen or unfrozen.
reduce_fp16 (bool) – Cast the gradients to fp16 before reducing them. Not needed if the model is already in fp16, but will probably improve performance for multi-node jobs using PyTorch AMP. The effect is similar to DDP’s fp16_compress_hook, and will also save some memory.
warn_on_trainable_params_changed (bool) – When set to False, no warning will be logged whenever a change in parameter trainability is detected. Default is True.
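To illustrate how reduce_buffer_size trades the number of communication calls against memory, here is a toy planner in plain Python. It is a sketch, not ShardedDataParallel's actual bucketing logic: plan_buckets and num_reduce_calls are hypothetical, and the sketch assumes one bucket per destination rank that small tensors fill until buffer_size elements are reached, with every other tensor reduced on its own.

```python
from typing import Dict, List, Tuple


def plan_buckets(
    sizes: Dict[str, int], owner: Dict[str, int], buffer_size: int
) -> Tuple[Dict[int, List[str]], List[str]]:
    """Hypothetical policy: pack tensors into one bucket per destination rank while they fit."""
    buckets: Dict[int, List[str]] = {}
    used: Dict[int, int] = {}
    standalone: List[str] = []
    for name, numel in sizes.items():
        rank = owner[name]
        if buffer_size > 0 and used.get(rank, 0) + numel <= buffer_size:
            buckets.setdefault(rank, []).append(name)
            used[rank] = used.get(rank, 0) + numel
        else:
            standalone.append(name)  # too big, or bucketing disabled: reduced on its own
    return buckets, standalone


def num_reduce_calls(buckets: Dict[int, List[str]], standalone: List[str]) -> int:
    # One collective per non-empty bucket plus one per standalone tensor
    return len(buckets) + len(standalone)


# Made-up parameter sizes (in elements) and owning ranks
sizes = {"w1": 1000, "b1": 10, "w2": 1000, "b2": 10, "ln": 20}
owner = {"w1": 0, "b1": 0, "w2": 1, "b2": 1, "ln": 0}

buckets, standalone = plan_buckets(sizes, owner, buffer_size=64)
print(buckets, standalone)                    # {0: ['b1', 'ln'], 1: ['b2']} ['w1', 'w2']
print(num_reduce_calls(buckets, standalone))  # 4

no_buckets, all_standalone = plan_buckets(sizes, owner, buffer_size=0)
print(num_reduce_calls(no_buckets, all_standalone))  # 5
```

Here bucketing saves a single call; on real models with many small tensors (biases, normalization weights) the saving grows accordingly, at the cost of keeping the bucket buffers allocated.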
- forward(*inputs: Any, **kwargs: Any) → Any
Module forward pass, handles any DDP-specific work in the background. Primes the backward pass for gradient reduction to the proper ranks.
- to(device: Optional[torch.device], dtype: Optional[torch.dtype] = None, non_blocking: bool = False) → fairscale.nn.data_parallel.sharded_ddp.ShardedDataParallel
Moves and/or casts the parameters and buffers.
Its signature is similar to torch.Tensor.to(), but only accepts floating point desired dtypes. In addition, this method will only cast the floating point parameters and buffers to dtype (if given). The integral parameters and buffers will be moved to device, if that is given, but with dtypes unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.
This method modifies the module in-place.
- Parameters
device (torch.device) – the desired device of the parameters and buffers in this module.
dtype (torch.dtype) – the desired floating point type of the floating point parameters and buffers.
non_blocking (bool) – make it an asynchronous call.
- Returns
Module – self.
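The casting rules of to() can be summarized with a small pure-Python model in which each tensor is described only by a (dtype, device) pair of strings. apply_to and the bn.* entries are hypothetical, and no tensors are actually moved; the sketch only encodes the rules stated above: floating point tensors are cast to dtype, integral ones (such as a batch norm's num_batches_tracked counter) keep their dtype, and every tensor moves to device if one is given.

```python
from typing import Dict, Optional, Tuple

# Hypothetical model of the casting rules: a tensor is described by (dtype, device) strings.
FLOAT_DTYPES = {"float16", "bfloat16", "float32", "float64"}
TensorInfo = Tuple[str, str]


def apply_to(
    tensors: Dict[str, TensorInfo],
    device: Optional[str] = None,
    dtype: Optional[str] = None,
) -> Dict[str, TensorInfo]:
    if dtype is not None and dtype not in FLOAT_DTYPES:
        raise TypeError(f"only floating point dtypes are accepted, got {dtype}")
    result: Dict[str, TensorInfo] = {}
    for name, (t_dtype, t_device) in tensors.items():
        # Only floating point tensors are cast; integral ones keep their dtype
        new_dtype = dtype if dtype is not None and t_dtype in FLOAT_DTYPES else t_dtype
        # Every tensor is moved when a device is given
        new_device = device if device is not None else t_device
        result[name] = (new_dtype, new_device)
    return result


state = {
    "bn.weight": ("float32", "cpu"),
    "bn.running_mean": ("float32", "cpu"),
    "bn.num_batches_tracked": ("int64", "cpu"),
}
moved = apply_to(state, device="cuda:0", dtype="float16")
print(moved["bn.weight"])               # ('float16', 'cuda:0')
print(moved["bn.num_batches_tracked"])  # ('int64', 'cuda:0')
```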
- reduce() → None
This does not need to be called: the gradient reduction is done automatically during the backward pass. Use this method to reduce the gradients manually.
- sync_buffers(blocking: bool = False) → None
Sync all the model buffers between ranks (including, for instance, batch norm statistics).
blocking (bool) – wait for the operation to conclude.
- zero_grad(set_to_none: bool = False) → None
Sets gradients of all model parameters to zero. See the similar function under torch.optim.Optimizer for more context.
- no_sync() → Generator
A context manager to disable gradient synchronization.
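A typical use of no_sync() is gradient accumulation: the backward passes of all micro-batches but the last run inside the context, so gradients only accumulate locally, and the final backward pass reduces the accumulated sum once. The toy class below models that pattern with one scalar gradient per simulated rank. ToyShardedDDP is hypothetical and performs no real communication; it assumes the same accumulation semantics as PyTorch DDP's no_sync().

```python
from contextlib import contextmanager
from typing import Iterator, List, Optional


class ToyShardedDDP:
    """Hypothetical stand-in that tracks one scalar gradient per simulated rank."""

    def __init__(self, world_size: int) -> None:
        self.local_grads = [0.0] * world_size  # per-rank accumulated gradient
        self.should_sync = True
        self.num_reductions = 0
        self.reduced: Optional[float] = None

    @contextmanager
    def no_sync(self) -> Iterator[None]:
        previous = self.should_sync
        self.should_sync = False
        try:
            yield
        finally:
            self.should_sync = previous

    def backward(self, grads_per_rank: List[float]) -> None:
        # Gradients accumulate locally, like .grad in PyTorch
        for rank, grad in enumerate(grads_per_rank):
            self.local_grads[rank] += grad
        if self.should_sync:
            # Communication happens only here: average the accumulated gradients
            self.reduced = sum(self.local_grads) / len(self.local_grads)
            self.num_reductions += 1


model = ToyShardedDDP(world_size=2)
micro_batches = [[1.0, 3.0], [2.0, 2.0], [0.0, 4.0]]  # made-up per-rank gradients
with model.no_sync():
    for grads in micro_batches[:-1]:
        model.backward(grads)  # local accumulation only
model.backward(micro_batches[-1])  # the last micro-batch triggers the reduction
print(model.num_reductions, model.reduced)  # 1 6.0
```

Three micro-batches cost a single reduction instead of three, which is the point of wrapping all but the last backward pass in no_sync().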