Fully Sharded Data Parallel
- class fairscale.nn.FullyShardedDataParallel(module: torch.nn.modules.module.Module, process_group: Optional[ProcessGroup] = None, process_group_reduce_scatter: Any = ProcessGroupName.reduce_scatter, reshard_after_forward: bool = True, disable_reshard_on_root: bool = True, mixed_precision: bool = False, fp32_reduce_scatter: bool = False, flatten_parameters: bool = True, move_params_to_cpu: bool = False, compute_dtype: Optional[torch.dtype] = None, buffer_dtype: Optional[torch.dtype] = None, move_grads_to_cpu: Optional[bool] = None, bucket_cap_mb: int = 25, compute_device: Optional[torch.device] = None, no_broadcast_optim_state: Optional[bool] = False, state_dict_device: Optional[torch.device] = None, clear_autocast_cache: bool = False, force_input_to_fp32: bool = False, verbose: bool = False, cpu_offload: bool = False, state_dict_on_rank_0_only: bool = False, gradient_predivide_factor: Optional[float] = None, allow_reset_parameters: bool = False)
A wrapper for sharding Module parameters across data parallel workers. This is inspired by Xu et al. as well as by ZeRO Stage 3 from DeepSpeed. FullyShardedDataParallel is commonly shortened to FSDP.
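Usage example (a minimal script meant to be launched with torchrun; the Linear module and input shapes are placeholders):
    # Launch with, e.g.: torchrun --nproc_per_node=2 example.py
    import os
    import torch
    import torch.distributed as dist
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
    dist.init_process_group(backend="nccl")    # FSDP needs an initialized process group
    device_id = int(os.environ["LOCAL_RANK"])  # set by torchrun
    torch.cuda.set_device(device_id)
    my_module = torch.nn.Linear(8, 4).cuda()   # any nn.Module
    sharded_module = FSDP(my_module)
    # Create the optimizer only after wrapping (see the warning below).
    optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
    x = torch.randn(16, 8, device="cuda")
    loss = sharded_module(x).sum()
    loss.backward()
    optim.step()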
It is also possible to shard individual layers separately and have an outer wrapper handle any leftover parameters. This can help to further reduce GPU memory usage, to reduce system memory usage when initializing large models, and to improve training speed by overlapping the all-gather step with computation across the forward pass. For example:
    import torch
    from fairscale.nn.wrap import wrap, enable_wrap, auto_wrap
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
    from fairscale.utils.testing import dist_init, teardown, rmf

    result = dist_init(0, 1, "/tmp/t1", "/tmp/t2")
    assert result
    fsdp_params = dict(wrapper_cls=FSDP, mixed_precision=True, flatten_parameters=True)
    with enable_wrap(**fsdp_params):
        # Wraps layer in FSDP by default if within context
        l1 = wrap(torch.nn.Linear(5, 5))
        assert isinstance(l1, FSDP)
        # Separately wraps children modules with more than 1e8 params
        large_tfmr = torch.nn.Transformer(d_model=2048, num_encoder_layers=12, num_decoder_layers=12)
        l2 = auto_wrap(large_tfmr)
        assert isinstance(l2.encoder, FSDP)
        assert isinstance(l2.decoder, FSDP)
        print(l2)  # You can print the model to examine FSDP wrapping.
    teardown()
    rmf("/tmp/t1")
    rmf("/tmp/t2")
Warning
The optimizer must be initialized after the module has been wrapped, since FSDP will shard parameters in-place and this will break any previously initialized optimizers.
Warning
If you wrap every parameter inside a nested FSDP and leave the outer FSDP without any parameters, activation checkpointing may trigger an assert on the backward pass. The solution is to leave some parameters to the outer FSDP.
Warning
If activation checkpointing is used with FSDP, it is strongly encouraged to use the checkpoint_wrapper function from FairScale instead of the checkpoint function from PyTorch.
- Parameters
module (nn.Module) – module to be wrapped with FSDP.
process_group (Optional) – process group for sharding
process_group_reduce_scatter (Optional) – process group for reduce-scatter. Defaults to ProcessGroupName.reduce_scatter, in which case a separate process group is initialized and assigned to the reduce_scatter operation, so that reduce_scatter overlaps with other operations in the backward pass. If it is a specific ProcessGroup, reduce_scatter operates on that ProcessGroup and the overlap still happens. To disable the overlap, set it to ProcessGroupName.default; reduce_scatter then uses the default process group. If the size of the reduce-scatter process group differs from the size of the default process group, reduce_scatter falls back to using the default process group.
reshard_after_forward (bool, Optional) – if True, reshard parameters after the forward pass. This saves memory but slows training. This is only relevant when individual layers are sharded separately (nested FSDP wrapping).
disable_reshard_on_root (bool, Optional) – if True, reshard_after_forward is set to False for an FSDP root module to improve performance. The full parameters of an FSDP root module are typically needed immediately for the backward pass, so resharding them would only force them to be gathered again. If False, performance will be lower, but memory can be saved: for example, when an FSDP root module is a submodule of a larger model, the backward pass may not start immediately after the root module finishes its forward pass, and resharding its parameters frees memory in the meantime. In certain cases there is no slowdown at all, because the cached full parameters may be stale anyway due to load_local_state_dict() calls. Default: True.
mixed_precision (bool, Optional) – if True, inputs, activations and gradients will be kept in FP16; computation and communication will occur in FP16; and a (sharded) master copy of the model weights will be maintained in FP32.
fp32_reduce_scatter (bool, Optional) – if True, reduce-scatter gradients in FP32. This is only relevant when mixed_precision is True.
flatten_parameters (bool, Optional) – if True, flatten parameters into a single contiguous tensor, which improves training speed.
move_params_to_cpu (bool, Optional) – if True, offload params to CPU. Default: False.
compute_dtype (torch.dtype, Optional) – dtype for full parameters for computation. Defaults to torch.float32 unless mixed_precision is set, in which case it defaults to torch.float16.
buffer_dtype (torch.dtype, Optional) – dtype for buffers for computation. Defaults to compute_dtype.
move_grads_to_cpu (bool, Optional) – move the gradient shard to CPU after reduction. This is useful when combined with CPU-based optimizers. Defaults to the value of move_params_to_cpu.
bucket_cap_mb (int, Optional) – FSDP buckets parameters so that gradient reduction can be more efficient for small parameters. bucket_cap_mb controls the bucket size in megabytes (MB). Buckets are sub-divided based on world_size, so the max shard size is roughly bucket_cap_mb / world_size. There is one bucketer (with potentially multiple bucket_cap_mb sized buffers) shared by all FSDP instances. Large gradient tensors are reduced directly without using the buffers; the buffers exist to reduce communication overhead for small tensors. Overlap with computation comes from using a CUDA stream separate from the computation stream. The total memory overhead per buffer is around bucket_cap_mb / world_size * (world_size + 1). The buffers are allocated during the backward pass and freed at its end, to leave more memory for other phases of training. Note that the memory vs. speed trade-off of bucket size is very different from that of the DDP engine. In DDP, the buffer size is 1MB + n*cap_mb, until n is big enough to cover the entire model size; the order in which buffers become ready is more rigid, and DDP requires all gradients to be computed in the backward pass. In FSDP, the buffer size does not change with model size (it changes based on the number of <dtype, device, process_group> tuples), and the order in which gradients become ready matters little, since FSDP has a final flush call that ensures everything is reduced and not all gradients need to be known upfront. Overlap with compute is also done differently. Values <= 0 disable bucketing. Default: 25.
compute_device (torch.device, Optional) – device for computation. If not given and module params are on a CUDA device, the params' device will be used. If not given and module params are on CPU, the current CUDA device (as indicated by torch.cuda.current_device()) will be used.
no_broadcast_optim_state (bool, Optional) – do not broadcast this module's optimizer state when gather_full_optim_state_dict is called. If you set this to True, you are expected to overwrite the relevant state entries of the returned optimizer state dict with the proper state at each rank. This is useful in situations, like Mixture of Experts, where all but a few parameters can fit on one node. Default: False.
state_dict_device (torch.device, Optional) – device for parameters returned by state_dict(). If not given, this defaults to compute_device. Note that only the device type is respected (e.g., “cuda:0” and “cuda:1” are the same).
clear_autocast_cache (bool) – when using mixed precision training with torch.amp.autocast and model weights in FP32, autocast maintains a cache of downcast weights. This cache can cause GPU OOM during the forward pass. Setting this flag to True clears the cache as inner FSDP instances finish their part of the forward pass, saving GPU memory. Default: False.
force_input_to_fp32 (bool) – set to True to force floating point input tensors to FP32 (if they are FP16) when the FSDP instance is in full precision mode. This helps avoid issues when running SyncBatchNorm with AMP and checkpoint_wrapper. Default: False.
verbose (bool) – set to True to turn on verbose output for the model's string representation. Default: False.
cpu_offload (bool, Optional) – if True, offload params to CPU. Note: this argument will be deprecated in favor of move_params_to_cpu in an upcoming release.
state_dict_on_rank_0_only (bool) – when set to True, model.state_dict() returns the full state dict only on rank 0 and an empty dict on all other ranks, which allows FullyShardedDataParallel to skip the GPU -> CPU copy on non-zero ranks altogether and prevents OOM. Default: False.
gradient_predivide_factor (float, Optional) – if supplied, pre-divide the gradients by this factor before the reduce-scatter. Default: None.
allow_reset_parameters (bool) – if True, allow the reset_parameters API to be proxied to the wrapped module. Default: False.
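As a quick sanity check on the bucket_cap_mb description, the rough per-buffer estimate bucket_cap_mb / world_size * (world_size + 1) can be evaluated directly. The helper names below are hypothetical, and the arithmetic only restates the approximate formula quoted above; it is not an exact accounting of FSDP's allocations.

```python
def approx_shard_size_mb(bucket_cap_mb: float, world_size: int) -> float:
    """Rough max shard size per bucket: bucket_cap_mb / world_size."""
    return bucket_cap_mb / world_size


def approx_buffer_overhead_mb(bucket_cap_mb: float, world_size: int) -> float:
    """Rough memory per buffer: bucket_cap_mb / world_size * (world_size + 1).

    Values <= 0 disable bucketing, so no buffer memory is counted.
    """
    if bucket_cap_mb <= 0:
        return 0.0
    return approx_shard_size_mb(bucket_cap_mb, world_size) * (world_size + 1)


for world_size in (1, 8, 64):
    # 25 is the default bucket_cap_mb.
    print(world_size, approx_buffer_overhead_mb(25, world_size))
```

With the default of 25 MB this gives about 28.1 MB per buffer at world_size=8 and about 25.4 MB at world_size=64: the (world_size + 1) / world_size factor tends to 1, so the overhead per buffer approaches bucket_cap_mb itself as the world grows.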
- set_gradient_divide_factors(pre: float, post: float, recursive: bool) → None
Allows the user to override the pre- and post-divide factors.
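The idea behind separate pre- and post-divide factors can be simulated in plain Python. Rather than summing raw gradients and dividing by world_size once, each rank divides by pre before the reduction and the reduced value is divided by post afterwards; when pre * post == world_size the result is still the average, while the division is split between the two steps (commonly done to limit overflow in the summed FP16 values). The power-of-two heuristic in pick_predivide_factor is a hypothetical choice for illustration; the default factors FSDP actually uses may differ.

```python
from typing import List


def pick_predivide_factor(world_size: int) -> float:
    """Hypothetical power-of-two heuristic for the pre-divide factor."""
    factor = 1
    while world_size % factor == 0 and world_size / factor > factor:
        factor *= 2
    return float(factor)


def reduce_with_divide_factors(per_rank_grads: List[float], pre: float, post: float) -> float:
    """Simulate a sum-reduction of one gradient value with pre- and post-division."""
    summed = sum(g / pre for g in per_rank_grads)  # each rank divides by `pre`, then values are summed
    return summed / post                            # the reduced value is divided by `post`


world_size = 8
pre = pick_predivide_factor(world_size)  # 4.0
post = world_size / pre                  # 2.0, so pre * post == world_size
grads = [float(rank + 1) for rank in range(world_size)]
print(pre, post, reduce_with_divide_factors(grads, pre, post))  # average of 1..8 is 4.5
```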
- property module: fairscale.nn.misc.flatten_params_wrapper.FlattenParamsWrapper
Makes model.module accessible, just like DDP.
Add a param that’s already owned by another FSDP wrapper.
Warning
This is experimental!
This only works when all sharing FSDP modules are un-flattened.
p must already be sharded by the owning module.
Check the corresponding unit tests to see how it is used and tested. In particular, the sharing FSDP wrappers are “siblings”, not “parent” and “child”, of each other in the nested module structure.
- Parameters
p (Parameter) – The shared parameter.
Return the list of non-shared parameters.
- apply(fn: Callable[[torch.nn.modules.module.Module], None]) → fairscale.nn.data_parallel.fully_sharded_data_parallel.FullyShardedDataParallel
Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model.
Compared to torch.nn.Module.apply, this version additionally gathers the full parameters before applying fn. It should not be called from within another summon_full_params context.
- Parameters
fn (Callable[[nn.Module], None]) – function to be applied to each submodule
- Returns
Module – self
- property params_with_grad: List[torch.nn.parameter.Parameter]
[p for p in self.parameters() if p.grad is not None]
- clip_grad_norm_(max_norm: Union[float, int], norm_type: Union[float, int] = 2.0) → torch.Tensor
Clip all gradients at this point in time. The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place.
- Parameters
max_norm (float or int) – max norm of the gradients
norm_type (float or int) – type of the p-norm used; can be inf for the infinity norm. Default: 2.0
- Returns
Total norm of the parameters (viewed as a single vector).
Note
This is analogous to torch.nn.utils.clip_grad_norm_ but handles the partitioning and multiple devices per rank under the hood. The default torch util is not applicable here, because each rank only has a partial view of all the grads in the model, so calling it in the FSDP context would lead to different scaling being applied per subset of model parameters.
Warning
This needs to be called on all ranks, since synchronization primitives will be used.
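Why the plain torch utility is not enough can be seen in a small simulation: each rank holds only a shard of the gradients, so the global p-norm has to be assembled from per-rank partial terms (an all-reduce in a real run, a Python sum or max here) before a single clip coefficient is computed and applied to every shard. The helper names are hypothetical; this sketches the general technique and is not FSDP's implementation.

```python
import math
from typing import List


def local_norm_term(shard: List[float], norm_type: float) -> float:
    """Per-rank contribution: sum(|g|^p) for finite p, max(|g|) for inf."""
    if math.isinf(norm_type):
        return max((abs(g) for g in shard), default=0.0)
    return sum(abs(g) ** norm_type for g in shard)


def global_grad_norm(shards: List[List[float]], norm_type: float = 2.0) -> float:
    terms = [local_norm_term(s, norm_type) for s in shards]
    if math.isinf(norm_type):
        return max(terms)                   # stands in for an all-reduce with MAX
    return sum(terms) ** (1.0 / norm_type)  # stands in for an all-reduce with SUM


def clip_shards_(shards: List[List[float]], max_norm: float, norm_type: float = 2.0) -> float:
    """Scale every shard in-place by the same coefficient; return the total norm."""
    total_norm = global_grad_norm(shards, norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for shard in shards:
            for i, g in enumerate(shard):
                shard[i] = g * clip_coef
    return total_norm


# Two "ranks", each owning half of the gradient vector [3, 4, 0, 12].
shards = [[3.0, 4.0], [0.0, 12.0]]
print(clip_shards_(shards, max_norm=6.5))  # global L2 norm: sqrt(9 + 16 + 0 + 144) = 13
```

Clipping each shard on its own would instead compare local norms of 5 and 12 against max_norm separately and scale the two halves differently, which is the inconsistency described in the note above.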
- __getstate__() → Dict[str, str]
Serialize the state of the current FSDP instance.
Some properties are not serializable (e.g., process groups, streams), so we remove them and try to reconstruct them in __setstate__().
- __setstate__(state: Dict[str, Any]) → None
Intercept state setting and perform needed changes on params.
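The approach described for __getstate__() and __setstate__() follows a standard Python pickling pattern: drop attributes that cannot be serialized and rebuild them on load. The sketch below uses a threading.Lock as a stand-in for a process group or CUDA stream; the class is hypothetical and only illustrates the pattern, not FSDP's actual code.

```python
import pickle
import threading
from typing import Any, Dict


class ShardedWorkerState:
    """Hypothetical object holding a handle that cannot be pickled."""

    def __init__(self, rank: int) -> None:
        self.rank = rank
        self.lock = threading.Lock()  # stand-in for a process group or CUDA stream

    def __getstate__(self) -> Dict[str, Any]:
        state = self.__dict__.copy()
        state.pop("lock")  # not serializable, so leave it out
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        self.__dict__.update(state)
        self.lock = threading.Lock()  # reconstruct the dropped attribute


clone = pickle.loads(pickle.dumps(ShardedWorkerState(rank=3)))
print(clone.rank, type(clone.lock))
```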
- parameters(recurse: bool = True) → Iterator[torch.nn.parameter.Parameter]
Returns an iterator over the module parameters, yielding all the parameters that are part of the model.
- named_parameters(*args: Any, **kwargs: Any) → Iterator[Tuple[str, torch.nn.parameter.Parameter]]
Returns an iterator over the module parameters, yielding both the name of the parameter as well as the parameter.
With FSDP, the named_parameters function implemented in nn.Module will not be able to return the name and param when we use flattened parameters unless we call this function under a summon_full_params context.
If you want the full param to be returned, you should call this function under a summon_full_params context when using flattened or original params.
Warning
This overloaded method will not be called when a parent module contains an FSDP-wrapped child module. Calling parent.named_parameters() will return the original, uncleaned key strings, i.e. _fsdp_wrapped_module and _fpw_module are included in the key strings.
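When a parent module's key strings still contain the wrapper segments named in the warning above, a small helper can strip them. clean_fsdp_key is hypothetical (not a FairScale API) and assumes the segments appear as whole dot-separated components. It only removes the wrapper prefixes: with flatten_parameters=True a key may still name a flattened parameter rather than an original one.

```python
from typing import Iterable, List

# Wrapper segments named in the warning above.
FSDP_WRAPPER_SEGMENTS = ("_fsdp_wrapped_module", "_fpw_module")


def clean_fsdp_key(key: str) -> str:
    """Drop FSDP wrapper components from a dotted parameter name."""
    parts = [p for p in key.split(".") if p not in FSDP_WRAPPER_SEGMENTS]
    return ".".join(parts)


def clean_fsdp_keys(keys: Iterable[str]) -> List[str]:
    return [clean_fsdp_key(k) for k in keys]


print(clean_fsdp_key("encoder._fsdp_wrapped_module._fpw_module.layer.weight"))  # encoder.layer.weight
```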
- state_dict(destination: Mapping[str, torch.Tensor], prefix: str = ..., keep_vars: bool = ...) → Mapping[str, torch.Tensor]
- state_dict(prefix: str = ..., keep_vars: bool = ...) → OrderedDict[str, torch.Tensor]
Returns the whole (unsharded) state of the module. Parameters are not sharded, so the resulting state_dict can be loaded directly by the wrapped Module without any sharding-specific logic. Returned tensors will be full precision (e.g., FP32).
Warning
This needs to be called on all ranks, since synchronization primitives will be used.
- local_state_dict(destination: Mapping[str, torch.Tensor], prefix: str = ..., keep_vars: bool = ...) → Mapping[str, torch.Tensor]
- local_state_dict(prefix: str = ..., keep_vars: bool = ...) → OrderedDict[str, torch.Tensor]
Returns the local (sharded) state of the module. Parameters are sharded, so the resulting state_dict can only be loaded after the Module has been wrapped with FSDP.
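To make "sharded" concrete, the sketch below splits a flattened parameter into world_size equal-length pieces, zero-padding the tail so every rank stores the same number of elements. The chunking rule (ceil(numel / world_size) elements per rank, in the style of torch.chunk) and the zero padding are assumptions for illustration; each rank's local state would hold one such piece rather than the full tensor.

```python
import math
from typing import List


def shard_flat_param(flat: List[float], world_size: int) -> List[List[float]]:
    """Split `flat` into `world_size` equal-length shards, zero-padding the tail."""
    shard_len = math.ceil(len(flat) / world_size)
    padded = flat + [0.0] * (shard_len * world_size - len(flat))
    return [padded[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]


flat = [float(i) for i in range(10)]  # 10 elements
for rank, shard in enumerate(shard_flat_param(flat, world_size=4)):
    print(rank, shard)                # the last shard carries 2 padding zeros
```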
- load_state_dict(state_dict: Union[Dict[str, torch.Tensor], OrderedDict[str, torch.Tensor]], strict: bool = True) → NamedTuple
- load_local_state_dict(state_dict: Union[Dict[str, torch.Tensor], OrderedDict[str, torch.Tensor]], strict: bool = True) → NamedTuple
Load a local (sharded) state_dict.
- no_sync() → Generator
A context manager to disable gradient synchronizations across FSDP processes. Within this context, gradients will be accumulated on module variables, which will later be synchronized in the first forward-backward pass after exiting the context.
Note
This likely results in higher memory usage because FSDP will accumulate the full model gradients (instead of gradient shards) until the eventual sync.
Note
Gradient accumulation can also be done without this context, avoiding the extra GPU memory overhead at the cost of extra networking overhead: the training loop simply runs multiple forward/backward passes, outside the no_sync context, before calling step().
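The trade-off between these two notes can be simulated without any GPUs. With no_sync, each rank keeps full-size accumulated gradients and communicates once at the end; without it, every micro-batch is reduce-scattered immediately and only the local shard is accumulated. Both end with the same shard values, but the first holds full gradients in memory while the second communicates on every micro-batch. reduce_scatter below is a hypothetical stand-in (sum across ranks, then split, with no averaging) for the real collective, and it assumes the gradient length divides evenly by world_size.

```python
from typing import List, Optional, Tuple

Grad = List[float]


def reduce_scatter(full_grads_per_rank: List[Grad]) -> List[Grad]:
    """Stand-in collective: sum across ranks, then give each rank one equal slice."""
    world_size = len(full_grads_per_rank)
    summed = [sum(vals) for vals in zip(*full_grads_per_rank)]
    n = len(summed) // world_size
    return [summed[r * n:(r + 1) * n] for r in range(world_size)]


def accumulate_with_no_sync(micro_batches: List[List[Grad]]) -> Tuple[List[Grad], int]:
    """Accumulate full gradients locally; one final reduce-scatter stands in for the
    sync that happens on the first backward pass after leaving the context."""
    world_size = len(micro_batches[0])
    full = [[0.0] * len(micro_batches[0][0]) for _ in range(world_size)]
    for step in micro_batches:
        for rank, g in enumerate(step):
            full[rank] = [a + b for a, b in zip(full[rank], g)]
    return reduce_scatter(full), 1


def accumulate_without_no_sync(micro_batches: List[List[Grad]]) -> Tuple[List[Grad], int]:
    """Reduce-scatter every micro-batch and accumulate only the local shards."""
    shards: Optional[List[Grad]] = None
    for step in micro_batches:
        step_shards = reduce_scatter(step)
        shards = step_shards if shards is None else [
            [a + b for a, b in zip(s, t)] for s, t in zip(shards, step_shards)
        ]
    return shards, len(micro_batches)


# 3 micro-batches x 2 ranks, each rank producing a 4-element full gradient.
micro_batches = [[[1.0, 2.0, 3.0, 4.0], [4.0, 3.0, 2.0, 1.0]] for _ in range(3)]
print(accumulate_with_no_sync(micro_batches))     # same shards, 1 collective
print(accumulate_without_no_sync(micro_batches))  # same shards, 3 collectives
```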
- reset_parameters() → None
Special reset_parameters API handling.
We don't allow this API by default because it has at least 2 issues:
1. Calling it after wrapping can crash due to unexpected tensor sizes and dimensions caused by flattening and sharding, so a summon_full_params context might be required.
2. Calling it after wrapping can result in incorrect init values due to flattening. See this gist for an example of the init issue when parameters are flattened:
https://gist.github.com/407bb158f0d0612e157c2cbcf5c8b76a
Or, as in 1, the init function can silently initialize the weight differently because of the dimensions.
Finally, be advised that init on CPU vs. on GPU can produce different values. If a model is originally on CPU and is moved to GPU after wrapping, calling this will again be problematic.
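The dimension problems listed above are easy to reproduce in miniature. Many PyTorch initializers derive their scale from a weight's fan-in, which needs at least two dimensions: a flattened (1-D) parameter has no fan-in (the crash case), and a 2-D view with a different shape reports a different one (the silent case). fan_in and uniform_bound below are hypothetical re-implementations of that rule for illustration, not calls into torch.nn.init; 1 / sqrt(fan_in) is the bound nn.Linear's default initialization uses for its weight and bias.

```python
import math
from typing import Sequence


def fan_in(shape: Sequence[int]) -> int:
    """Fan-in as commonly defined for init: shape[1] * prod(shape[2:])."""
    if len(shape) < 2:
        raise ValueError("fan-in is undefined for tensors with fewer than 2 dimensions")
    return shape[1] * math.prod(shape[2:])


def uniform_bound(shape: Sequence[int]) -> float:
    """Bound b of U(-b, b) with b = 1 / sqrt(fan_in)."""
    return 1.0 / math.sqrt(fan_in(shape))


original = (64, 256)            # a Linear(256, 64) weight
print(uniform_bound(original))  # 1 / sqrt(256) = 0.0625

flat = (64 * 256,)              # the same weight after flattening
try:
    uniform_bound(flat)
except ValueError as err:
    print("crash:", err)        # the "crash" case

reshaped = (128, 128)           # a hypothetical 2-D view with the same number of elements
print(uniform_bound(reshaped))  # silently different: 1 / sqrt(128)
```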
- summon_full_params(recurse: bool = True, volatile: bool = False) → Generator
A context manager to expose full params for the current FSDP instance. Can be useful after forward/backward for a model to get the params for additional processing or checking. Parameters will be gathered in full precision (e.g., FP32).
Note
This can be used on inner FSDPs.
Note
This cannot be used within a forward or backward pass, nor can forward and backward be started from within this context.
Note
The full parameters will be freed after the context manager exits; it is up to the caller to clone them if needed.
Note
The full parameters can be modified, but only the portion corresponding to the local param shard will persist after the context manager exits (unless volatile=True, in which case there are no guarantees about persistence).
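The persistence rule in the last note can be pictured with a small simulation of the default volatile=False case, seen from one rank: the context gathers all shards into a full view, the caller edits that view, and on exit the rank copies back only the slice it owns. full_params_view is a hypothetical stand-in for FSDP's internals. In a real run every rank does this for its own shard, so an edit persists everywhere only if every rank makes it.

```python
from contextlib import contextmanager
from typing import Iterator, List


@contextmanager
def full_params_view(shards: List[List[float]], rank: int) -> Iterator[List[float]]:
    """Hypothetical stand-in: yield a gathered copy, write back only `rank`'s slice."""
    shard_len = len(shards[rank])
    full = [x for shard in shards for x in shard]    # "all-gather"
    yield full
    start = rank * shard_len
    shards[rank][:] = full[start:start + shard_len]  # only the local slice persists


shards = [[1.0, 2.0], [3.0, 4.0]]                    # 2 ranks' shards of one parameter
with full_params_view(shards, rank=0) as full:
    full[:] = [v * 10 for v in full]                 # rank 0 edits the whole parameter
print(shards)                                        # [[10.0, 20.0], [3.0, 4.0]]
```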
- forward(*args: Any, **kwargs: Any) → torch.Tensor
- local_metadata_dict() → Dict[str, Any]
Get the information needed to reconstruct the model from shards offline.
See the consolidate_shard_weights method below.
- static consolidate_shard_weights(shard_weights: List[Dict[str, torch.Tensor]], shard_metadata: List[Dict[str, Any]], with_module_buffers: bool = True, strict: bool = True) → Dict[str, torch.Tensor]
Given a list of weights and metadata associated with N shards, reconstruct the weights of an equivalent consolidated (non-sharded) state dict.
Module parameters are consolidated using the shard metadata.
Module buffers are taken from shard 0: this assumes that module buffers are either synchronized or that the shard 0 value is valid for all shards. If this behavior is not correct for your module (for instance, if buffers need to be all-reduced instead), you can disable it with with_module_buffers=False.
This method is used to re-assemble checkpoints of shards without having to instantiate FSDP wrappers with the world size (i.e. large number of GPUs) originally used to save the shards.
- Parameters
shard_weights (List[Dict[str, torch.Tensor]]) – list of dictionaries that contain the sharded weights from each rank.
shard_metadata (List[Dict[str, Any]]) – list of dictionaries that contain the metadata from each shard. See local_metadata_dict above.
with_module_buffers (bool) – If shard 0’s buffer should be returned in the consolidated weight dict. Default: True.
strict (bool) – if True, every key in the metadata must be present in the weights; if False, incomplete shard weights are allowed. Default: True.
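The core of consolidation can be sketched for a single flattened parameter: concatenate the per-rank shards in rank order, drop the tail padding, and cut the result back into the original parameters by name and shape. The metadata layout used here (a list of (name, shape) pairs plus a padding count) is a hypothetical simplification; the real shard_metadata produced by local_metadata_dict() has its own structure, and the sketch returns flat lists instead of reshaped tensors.

```python
import math
from typing import Dict, List, Sequence, Tuple


def consolidate_flat_shards(
    shards: List[List[float]],
    param_specs: Sequence[Tuple[str, Tuple[int, ...]]],
    num_padded: int,
) -> Dict[str, List[float]]:
    """Rebuild named (still 1-D) parameters from rank-ordered shards of one flat param."""
    flat = [x for shard in shards for x in shard]
    if num_padded:
        flat = flat[:-num_padded]                # remove tail padding
    out: Dict[str, List[float]] = {}
    offset = 0
    for name, shape in param_specs:
        numel = math.prod(shape)
        out[name] = flat[offset:offset + numel]  # reshaping to `shape` is left out
        offset += numel
    if offset != len(flat):
        raise ValueError("shard sizes do not match the parameter specs")
    return out


# A (2, 2) weight and a (1,) bias flattened to 5 values, sharded over 2 ranks (1 pad).
shards = [[1.0, 2.0, 3.0], [4.0, 5.0, 0.0]]
specs = [("weight", (2, 2)), ("bias", (1,))]
print(consolidate_flat_shards(shards, specs, num_padded=1))
```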
- assert_state(state: Union[fairscale.nn.data_parallel.fully_sharded_data_parallel.TrainingState, List[fairscale.nn.data_parallel.fully_sharded_data_parallel.TrainingState]]) → None
Assert we are in the given state.
- gather_full_optim_state_dict(optim: torch.optim.optimizer.Optimizer, **ignored: Dict) → Optional[Dict[str, Any]]
Return the last known global optimizer state. The returned state is compatible with PyTorch, in that the sharded properties are not exposed. Multiple parameter groups are not yet supported.
This should be called only on the root FSDP instance. Nested FSDP instances are supported as long as they have the same world_size as the parent or world_size=1.
- Parameters
optim (Optimizer) – an optimizer instance for this FSDP rank. Its state_dict is used in the consolidation. However, its state is not modified.
- Returns
A dict with four entries on rank zero (other workers return None):
state – a dict holding gathered optimization state, 1 entry per unflat parameter
param_groups – a dict containing the 1 parameter group
param_id_map – global (unflat) to local (flat) id mapping
uncollected_local_ids – keys in the state dict that were not broadcast
- get_shard_from_optim_state_dict(full_optim_state_dict: Dict[str, Any]) → Dict[str, Any]
Get the portion of the optimizer state dict associated with the shard.
This can be used to get the right sharded optimizer state to be loaded into the sharded optimizer for this FSDP rank.
Warning
The input state dict is modified in-place, assuming the original full state isn't going to be used anymore. This is done so that no extra state needs to be copied. It is the caller's responsibility to make a copy if the original state dict must not be modified.
- Parameters
full_optim_state_dict (dict) – consolidated optimizer state returned by gather_full_optim_state_dict, or loaded from a checkpoint.
- Returns
(dict) – a shard of the optimizer state.