Sharded Data Parallel

class fairscale.nn.ShardedDataParallel(module: torch.nn.modules.module.Module, sharded_optimizer: Union[fairscale.optim.oss.OSS, List[fairscale.optim.oss.OSS]], process_group: Optional[Any] = None, broadcast_buffers: bool = True, sync_models_at_startup: bool = True, reduce_buffer_size: int = 8388608, auto_refresh_trainable: bool = True, reduce_fp16: bool = False, warn_on_trainable_params_changed: bool = True)

Wrap the model, and reduce each gradient to the rank that owns it during the backward pass:

  • the partition (which rank owns which parameter) is given by the sharded optimizer

  • wrap the base model with a model which knows where to reduce each gradient

  • add an autograd function which triggers the model's gradient dispatch on the way back
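The flow above can be pictured with a toy, single-process simulation. Everything in it is hypothetical: plain floats stand in for gradient tensors, `partition` is a greedy least-loaded assignment that we assume is close to how OSS balances parameters across ranks, and `reduce_to_owners` averages each gradient over all ranks (as DDP does) but keeps the result only on the owning rank. It is an illustration, not fairscale's code.

```python
from typing import Dict, List


def partition(param_sizes: Dict[str, int], world_size: int) -> Dict[str, int]:
    """Hypothetical greedy partition: each parameter goes to the least-loaded rank."""
    loads = [0] * world_size
    owner: Dict[str, int] = {}
    # Place the largest parameters first so the greedy balance is tighter.
    for name, numel in sorted(param_sizes.items(), key=lambda kv: kv[1], reverse=True):
        rank = loads.index(min(loads))
        owner[name] = rank
        loads[rank] += numel
    return owner


def reduce_to_owners(
    grads_per_rank: List[Dict[str, float]], owner: Dict[str, int]
) -> List[Dict[str, float]]:
    """Average each gradient over all ranks, but keep the result only on its owner."""
    world_size = len(grads_per_rank)
    reduced: List[Dict[str, float]] = [{} for _ in range(world_size)]
    for name, rank in owner.items():
        reduced[rank][name] = sum(g[name] for g in grads_per_rank) / world_size
    return reduced


# Hypothetical model: parameter name -> number of elements.
sizes = {"embed": 1000, "fc1": 600, "fc2": 400, "bias": 10}
owner = partition(sizes, world_size=2)

# Each rank computed different local gradients during its backward pass.
grads = [{name: 1.0 for name in sizes}, {name: 2.0 for name in sizes}]
reduced = reduce_to_owners(grads, owner)
print(owner)    # {'embed': 0, 'fc1': 1, 'fc2': 1, 'bias': 0}
print(reduced)  # [{'embed': 1.5, 'bias': 1.5}, {'fc1': 1.5, 'fc2': 1.5}]
```

After the reduction, each rank holds averaged gradients only for the parameters it owns, which is all its optimizer shard needs to step.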

Parameters
  • module (nn.Module) – model to be wrapped

  • sharded_optimizer (OSS, or list of OSS) – the sharded optimizer(s) which will decide the gradient partitioning

Keyword Arguments
  • process_group (group) – torch.distributed group (default: group.WORLD)

  • broadcast_buffers (bool) – Whether to additionally broadcast model buffers between ranks at the beginning of each forward pass. Same setting as in PyTorch DDP; this is in addition to the broadcast and reduction of the model parameters.

  • sync_models_at_startup (bool) – Synchronize the models between the ranks when starting up. Not needed if each rank has the same seed, or if the training restarts from a saved state.

  • reduce_buffer_size (int) – The max size of the buffer used to batch the small parameter tensors, in number of elements (default: 8388608, i.e. 2^23 elements). This will impact the long-term memory consumption, because these buckets correspond to parameters which will not be sharded. Set to 0 to remove all bucketing; 1M to 8M is usually reasonable.

  • auto_refresh_trainable (bool) – (default: True) Check whether the parameters' trainability (requires_grad) has changed and update both ShardedDDP and OSS automatically if this is the case. If set to False, refresh_trainable() needs to be called whenever a parameter is frozen or unfrozen.

  • reduce_fp16 (bool) – Cast the grads to fp16 before reducing. Not needed if the model is already fp16, but will probably improve performance for multi-node jobs using PyTorch AMP. The effect is similar to DDP’s fp16_compress_hook and will also save some memory.

  • warn_on_trainable_params_changed (bool) – When set to False no warning will be logged whenever a parameter trainability change has been detected. Default is True.
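To get a feel for reduce_buffer_size and reduce_fp16, the sketch below plans buckets for a list of parameter sizes and computes the footprint of one full bucket. The packing rule (small tensors are packed in order until the next one would overflow the buffer; tensors at least as large as the buffer are reduced on their own) is a simplified assumption, not fairscale's exact bucketing strategy, and `plan_buckets` and `bucket_mib` are hypothetical helpers. The byte arithmetic only assumes 4 bytes per fp32 element and 2 bytes per fp16 element; whether the buckets themselves are held in fp16 when reduce_fp16 is set is also an assumption.

```python
from typing import List, Tuple

BYTES_PER_ELEMENT = {"fp32": 4, "fp16": 2}


def plan_buckets(
    params: List[Tuple[str, int]], buffer_size: int
) -> Tuple[List[List[str]], List[str]]:
    """Hypothetical bucketing: returns (buckets, unbucketed parameter names).

    Each bucket would be reduced with a single collective call, each
    unbucketed parameter with its own call.
    """
    if buffer_size == 0:
        # reduce_buffer_size=0: no bucketing at all.
        return [], [name for name, _ in params]
    buckets: List[List[str]] = []
    unbucketed: List[str] = []
    current: List[str] = []
    filled = 0
    for name, numel in params:
        if numel >= buffer_size:
            unbucketed.append(name)  # large tensors gain little from batching
            continue
        if filled + numel > buffer_size:
            buckets.append(current)  # current bucket is full, start a new one
            current, filled = [], 0
        current.append(name)
        filled += numel
    if current:
        buckets.append(current)
    return buckets, unbucketed


def bucket_mib(buffer_size: int, dtype: str) -> float:
    """Memory held by one full bucket, in MiB."""
    return buffer_size * BYTES_PER_ELEMENT[dtype] / 2**20


print(plan_buckets([("a", 3), ("b", 4), ("c", 10), ("d", 2)], buffer_size=8))
# ([['a', 'b'], ['d']], ['c'])
print(bucket_mib(8388608, "fp32"), bucket_mib(8388608, "fp16"))  # 32.0 16.0
```

With the default of 8388608 elements, each fp32 bucket costs 32 MiB, which is why the buffer size trades fewer collective calls against long-term memory.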

forward(*inputs: Any, **kwargs: Any) → Any

Module forward pass; handles any DDP-specific work in the background and primes the backward pass for gradient reduction to the proper ranks.

to(device: Optional[torch.device], dtype: Optional[torch.dtype] = None, non_blocking: bool = False) → fairscale.nn.data_parallel.sharded_ddp.ShardedDataParallel

Moves and/or casts the parameters and buffers.

Its signature is similar to torch.Tensor.to(), but only accepts floating point desired dtypes. In addition, this method will only cast the floating point parameters and buffers to dtype (if given). The integral parameters and buffers will be moved to device, if that is given, but with dtypes unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.

Note: This method modifies the module in-place.

Parameters
  • device (torch.device) – the desired device of the parameters and buffers in this module.

  • dtype (torch.dtype) – the desired floating point type of the floating point parameters and buffers.

  • non_blocking (bool) – make it an asynchronous call.

Returns
Module – self.
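The casting rule above (floating point parameters and buffers are cast, integral ones are only moved) can be modeled with a small, hypothetical helper. Strings stand in for torch.device and torch.dtype values, and `apply_to` illustrates the documented behavior only; it is not the method's implementation.

```python
from typing import Dict, Optional, Tuple

FLOATING_DTYPES = {"float16", "bfloat16", "float32", "float64"}

# name -> (device, dtype); plain strings stand in for torch.device / torch.dtype.
TensorMeta = Dict[str, Tuple[str, str]]


def apply_to(tensors: TensorMeta, device: Optional[str], dtype: Optional[str] = None) -> TensorMeta:
    """Toy model of the device/dtype rule described for to()."""
    if dtype is not None and dtype not in FLOATING_DTYPES:
        raise TypeError(f"to() only accepts floating point dtypes, got {dtype}")
    result: TensorMeta = {}
    for name, (dev, dt) in tensors.items():
        new_dev = device if device is not None else dev
        # Only floating point tensors are cast; integral ones keep their dtype.
        new_dt = dtype if (dtype is not None and dt in FLOATING_DTYPES) else dt
        result[name] = (new_dev, new_dt)
    return result


# A batch norm layer holds both a float weight and an integer counter buffer.
state = {"bn.weight": ("cpu", "float32"), "bn.num_batches_tracked": ("cpu", "int64")}
print(apply_to(state, "cuda:0", "float16"))
# {'bn.weight': ('cuda:0', 'float16'), 'bn.num_batches_tracked': ('cuda:0', 'int64')}
```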

refresh_trainable() → None

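If the trainability (requires_grad) of any parameter has changed, update both ShardedDDP and OSS accordingly. Only needs to be called manually when auto_refresh_trainable is False.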
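One way to picture what auto_refresh_trainable automates: keep a snapshot of every parameter's requires_grad flag, compare it before each forward pass, and refresh only when it differs. `Param` and `TrainabilityTracker` are hypothetical names, and fairscale's actual detection and rebuild logic may differ from this sketch.

```python
import logging
from dataclasses import dataclass
from typing import List


@dataclass
class Param:
    """Hypothetical stand-in for a torch.nn.Parameter."""
    name: str
    requires_grad: bool = True


class TrainabilityTracker:
    """Toy model of the check controlled by auto_refresh_trainable."""

    def __init__(self, params: List[Param], auto_refresh: bool = True, warn: bool = True) -> None:
        self.params = params
        self.auto_refresh = auto_refresh
        self.warn = warn  # plays the role of warn_on_trainable_params_changed
        self.refresh_count = 0
        self.snapshot = [p.requires_grad for p in params]

    def refresh_trainable(self) -> None:
        # ShardedDDP would rebuild its gradient reduction setup, and OSS its partition, here.
        self.snapshot = [p.requires_grad for p in self.params]
        self.refresh_count += 1

    def before_forward(self) -> bool:
        """Return True if trainability changed since the last refresh."""
        changed = [p.requires_grad for p in self.params] != self.snapshot
        if changed:
            if self.warn:
                logging.warning("Parameter trainability changed")
            if self.auto_refresh:
                self.refresh_trainable()
        return changed


params = [Param("encoder"), Param("head")]
tracker = TrainabilityTracker(params)
tracker.before_forward()          # no change, no refresh
params[0].requires_grad = False   # freeze the encoder
tracker.before_forward()          # change detected: warns and refreshes
```

With `auto_refresh=False` the tracker only reports the change, mirroring the documented requirement to call refresh_trainable() yourself in that mode.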

reduce() → None

Reduce the gradients manually. This does not normally need to be called: the gradient reduction is done automatically during the backward pass.

sync_buffers(blocking: bool = False) → None

Synchronize all the module buffers between ranks (including, for instance, batch norm statistics).

Parameters
blocking (bool) – wait for the operation to conclude.
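As we understand it, "synchronize" here means every rank ends up with one reference rank's copy of each buffer (a broadcast, like PyTorch DDP's buffer broadcasting) rather than an average across ranks. The toy below contrasts the two with plain floats; `broadcast_buffers`, `average_buffers` and the choice of rank 0 as the source are assumptions for illustration.

```python
from typing import Dict, List

Buffers = Dict[str, float]


def broadcast_buffers(buffers_per_rank: List[Buffers], src: int = 0) -> List[Buffers]:
    """Every rank receives a copy of the source rank's buffers."""
    reference = buffers_per_rank[src]
    return [dict(reference) for _ in buffers_per_rank]


def average_buffers(buffers_per_rank: List[Buffers]) -> List[Buffers]:
    """Alternative for comparison (not assumed to be what sync_buffers does): mean over ranks."""
    n = len(buffers_per_rank)
    mean = {k: sum(b[k] for b in buffers_per_rank) / n for k in buffers_per_rank[0]}
    return [dict(mean) for _ in buffers_per_rank]


# Batch norm running means drift apart because each rank sees different data.
ranks = [{"bn.running_mean": 0.25}, {"bn.running_mean": 0.75}]
print(broadcast_buffers(ranks))  # [{'bn.running_mean': 0.25}, {'bn.running_mean': 0.25}]
print(average_buffers(ranks))    # [{'bn.running_mean': 0.5}, {'bn.running_mean': 0.5}]
```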

zero_grad(set_to_none: bool = False) → None

Sets gradients of all model parameters to zero. See the similar function under torch.optim.Optimizer for more context.

Parameters
set_to_none (bool) – instead of setting to zero, set the grads to None. See torch.optim.Optimizer.zero_grad() for details.

__getattr__(name: str) → Any

Forward missing attributes to the wrapped module.

no_sync() → Generator

A context manager to disable gradient synchronization.
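Disabling synchronization is safe for gradient accumulation because averaging over ranks is linear: summing locally over several micro-batches and reducing once gives the same result as reducing after every micro-batch, with fewer collective calls. The sketch below checks that with hypothetical per-rank gradients; `allreduce_mean` is a stand-in for one collective call, not a real API.

```python
from typing import List


def allreduce_mean(values: List[float]) -> float:
    """Stand-in for one collective call that averages a gradient over ranks."""
    return sum(values) / len(values)


# Hypothetical gradients: micro_batches[step][rank] for 3 micro-batches on 2 ranks.
micro_batches = [[1.0, 3.0], [2.0, 2.0], [0.5, 1.5]]
world_size = 2

# Synchronizing after every backward pass: one collective call per micro-batch.
per_step = sum(allreduce_mean(step) for step in micro_batches)
calls_per_step = len(micro_batches)

# Accumulating locally (what no_sync() allows), then reducing once.
local_sums = [sum(step[rank] for step in micro_batches) for rank in range(world_size)]
once = allreduce_mean(local_sums)
calls_once = 1

print(per_step, once)  # 5.0 5.0
```

With the real wrapper, the usual pattern for DDP-style wrappers is to run the backward pass of every micro-batch except the last inside `with model.no_sync():`, and let the final backward pass trigger the reduction.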

training: bool