Fully Sharded Data Parallel
- class fairscale.nn.FullyShardedDataParallel(module: torch.nn.modules.module.Module, process_group: Optional[ProcessGroup] = None, process_group_reduce_scatter: Any = ProcessGroupName.reduce_scatter, reshard_after_forward: bool = True, disable_reshard_on_root: bool = True, mixed_precision: bool = False, fp32_reduce_scatter: bool = False, flatten_parameters: bool = True, move_params_to_cpu: bool = False, compute_dtype: Optional[torch.dtype] = None, buffer_dtype: Optional[torch.dtype] = None, move_grads_to_cpu: Optional[bool] = None, bucket_cap_mb: int = 25, compute_device: Optional[torch.device] = None, no_broadcast_optim_state: Optional[bool] = False, state_dict_device: Optional[torch.device] = None, clear_autocast_cache: bool = False, force_input_to_fp32: bool = False, verbose: bool = False, cpu_offload: bool = False, state_dict_on_rank_0_only: bool = False, gradient_predivide_factor: Optional[float] = None, allow_reset_parameters: bool = False)
A wrapper for sharding Module parameters across data parallel workers. This is inspired by Xu et al. as well as the ZeRO Stage 3 from DeepSpeed. FullyShardedDataParallel is commonly shortened to FSDP.
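```python
import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Assumes torch.distributed is already initialized (e.g. via torch.distributed.init_process_group).
# device_id, my_module and x are placeholders for your own device index, model and input.
torch.cuda.set_device(device_id)
sharded_module = FSDP(my_module)
optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
x = sharded_module(x, y=3, z=torch.Tensor())
loss = x.sum()
loss.backward()
optim.step()
```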
It is also possible to shard individual layers separately and have an outer wrapper handle any leftover parameters. This can help to further reduce GPU memory usage, to reduce system memory usage when initializing large models, and to improve training speed by overlapping the all-gather step across the forward pass. For example:
```python
import torch
from fairscale.nn.wrap import wrap, enable_wrap, auto_wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.utils.testing import dist_init, teardown, rmf

result = dist_init(0, 1, "/tmp/t1", "/tmp/t2")
assert result
fsdp_params = dict(wrapper_cls=FSDP, mixed_precision=True, flatten_parameters=True)
with enable_wrap(**fsdp_params):
    # Wraps layer in FSDP by default if within context
    l1 = wrap(torch.nn.Linear(5, 5))
    assert isinstance(l1, FSDP)
    # Separately wraps children modules with more than 1e8 params
    large_tfmr = torch.nn.Transformer(d_model=2048, num_encoder_layers=12, num_decoder_layers=12)
    l2 = auto_wrap(large_tfmr)
    assert isinstance(l2.encoder, FSDP)
    assert isinstance(l2.decoder, FSDP)
    print(l2)  # You can print the model to examine FSDP wrapping.
teardown()
rmf("/tmp/t1")
rmf("/tmp/t2")
```
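The size-based decision behind auto_wrap can be pictured with a toy model. The sketch below is a simplified, hypothetical policy over a plain-Python module tree: the Node class, the wrap_by_size and wrapped_names helpers, and all parameter counts are invented for illustration. It is not fairscale's implementation, which operates on real nn.Module objects and supports further options.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
    """Hypothetical stand-in for an nn.Module: its own parameter count plus children."""
    name: str
    own_params: int = 0
    children: List["Node"] = field(default_factory=list)
    wrapped: bool = False

    def total_params(self) -> int:
        return self.own_params + sum(c.total_params() for c in self.children)


def wrap_by_size(node: Node, min_num_params: int) -> int:
    """Mark nodes to wrap, bottom-up; return how many params ended up wrapped in `node`."""
    # Children first, so the smallest subtrees that are large enough get wrapped.
    wrapped_below = sum(wrap_by_size(c, min_num_params) for c in node.children)
    # Only parameters not already claimed by a wrapped descendant count here.
    remainder = node.total_params() - wrapped_below
    if remainder >= min_num_params:
        node.wrapped = True
        return node.total_params()
    return wrapped_below


def wrapped_names(node: Node) -> List[str]:
    """List the names of all nodes marked for wrapping, in depth-first order."""
    names = [node.name] if node.wrapped else []
    for c in node.children:
        names += wrapped_names(c)
    return names


# Made-up sizes: 12 layers of 25M params on each side, plus a few leftover root params.
encoder = Node("encoder", children=[Node(f"enc{i}", own_params=25_000_000) for i in range(12)])
decoder = Node("decoder", children=[Node(f"dec{i}", own_params=25_000_000) for i in range(12)])
root = Node("transformer", own_params=10_000, children=[encoder, decoder])

wrap_by_size(root, min_num_params=int(1e8))
print(wrapped_names(root))  # ['encoder', 'decoder']
```

In this toy tree no single layer reaches the 1e8 threshold, but each stack does, so the encoder and decoder are wrapped while the root's few leftover parameters remain for an outer wrapper, matching the idea described above.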
The optimizer must be initialized after the module has been wrapped, since FSDP will shard parameters in-place and this will break any previously initialized optimizers.
If you wrap every parameter inside nested FSDP instances and leave the outer FSDP without any parameters, activation checkpointing may trigger an assertion failure in the backward pass. The solution is to leave some parameters to the outer FSDP.
If activation checkpointing is used with FSDP, it is strongly encouraged to use the checkpoint_wrapper function from FairScale instead of the checkpoint function from PyTorch.
module (nn.Module) – module to be wrapped with FSDP.
process_group (Optional) – process group for sharding
process_group_reduce_scatter (Optional) – process group for the reduce-scatter operation. It defaults to ProcessGroupName.reduce_scatter, in which case a separate process group is initialized and assigned to the reduce-scatter operation, and the reduce-scatter overlaps with other operations in the backward pass. If it is a specific ProcessGroup, the reduce-scatter runs on that ProcessGroup and the overlap still happens. To disable the overlap, set the process group to ProcessGroupName.default; the reduce-scatter then uses the default process group. If the size of the reduce-scatter process group differs from the size of the default process group, the reduce-scatter falls back to using the default process group.
reshard_after_forward (bool, Optional) – if True, reshard parameters after the forward pass. This saves memory but slows training. This is only relevant when resharding individual layers.
disable_reshard_on_root (bool, Optional) – If True, reshard_after_forward will be set to False if the module is an FSDP root module, to improve performance. In this case the full parameters of the FSDP root module are not resharded, since they are needed immediately for the backward pass. If False, performance will be lower, but memory is saved. This matters, for example, when an FSDP root module is a submodule of a larger model: the backward pass may not start immediately after the FSDP root module finishes its forward pass, so resharding the root module's parameters saves memory in the meantime. In certain cases, performance is not even slower, because the cached full parameter state may be stale anyway due to load_local_state_dict() calls. Default: True.
mixed_precision (bool, Optional) – if True, inputs, activations and gradients will be kept in FP16; computation and communication will occur in FP16; and a (sharded) master copy of the model weights will be maintained in FP32.
fp32_reduce_scatter (bool, Optional) – if True, reduce-scatter gradients in FP32. This is only relevant when mixed_precision is True.
flatten_parameters (bool, Optional) – if True, flatten parameters into a single contiguous tensor, which improves training speed.
move_params_to_cpu (bool, Optional) – if True, offload params to CPU. Default: False
compute_dtype (torch.dtype, Optional) – dtype for full parameters for computation. This defaults to torch.float32 unless mixed_precision is set, in which case it defaults to torch.float16.
buffer_dtype (torch.dtype, Optional) – dtype for buffers for computation. This defaults to
move_grads_to_cpu (bool, Optional) – move the gradient shard to CPU after reduction. This is useful when combined with CPU-based optimizers. It defaults to the value of move_params_to_cpu.
bucket_cap_mb (int, Optional) – FSDP will bucket parameters so that gradient reduction can be more efficient for small parameters. bucket_cap_mb controls the bucket size in megabytes (MB). Buckets are sub-divided based on world_size, so the max shard size is roughly bucket_cap_mb / world_size. There is one bucketer (with potentially multiple bucket_cap_mb-sized buffers) shared by all FSDP instances. Large gradient tensors are reduced directly without using the buffers; the buffers exist to reduce communication overhead for small tensors. Overlap with computation comes from running the reduction on a different CUDA stream than the computation. The total memory overhead per buffer is around bucket_cap_mb / world_size * (world_size + 1). The buffers are allocated during the backward pass and freed at the end of it, to save memory for other phases of training. Note that the memory vs. speed tradeoff of bucket size is very different from that of the DDP engine. In DDP, the buffer size is 1MB + n*cap_mb, until n is big enough to cover the entire model size, the order in which buffers become ready is more rigid, and all gradients must be computed in the backward pass. In FSDP, the buffer size does not change with model size (it changes based on the number of <dtype, device, process_group> tuples), and the order in which gradients become ready matters little, since FSDP has a final flush call that ensures everything is reduced and not all gradients need to be known upfront. Overlapping with compute is done differently too. Values <= 0 disable bucketing. Default: 25.
compute_device (torch.device, Optional) – device for computation. If not given and module params are on a CUDA device, the params' device will be used. If not given and module params are on CPU, then the current CUDA device (as indicated by torch.cuda.current_device()) will be used.
no_broadcast_optim_state (bool, Optional) – do not broadcast this module's optimizer state when gather_full_optim_state_dict is called. If you set this to True, you are expected to overwrite the relevant state entries of the returned optimizer state dict with the proper state at each rank. This is useful for situations, like Mixture of Experts, where all but a few parameters can fit on one node. Default: False
state_dict_device (torch.device, Optional) – device for parameters returned by state_dict(). If not given, this will default to compute_device. Note that only the device type will be respected (e.g., “cuda:0” and “cuda:1” are the same).
clear_autocast_cache (bool) – When using mixed precision training with torch.amp.autocast, if the model weights are in FP32, autocast maintains a cache of downcast weights. This cache can cause GPU OOM during the forward pass. Setting this flag to True clears the cache as inner FSDP instances finish their part of the forward pass, to save GPU memory. Default: False
force_input_to_fp32 (bool) – Set to True to force input floating point tensors to be FP32 (if they are FP16) when the FSDP instance is in full precision mode. This helps avoid issues when running SyncBatchNorm with AMP and checkpoint_wrapper. Default: False
verbose (bool) – Set this to True to turn on verbose output for the model’s string representation. Default: False
cpu_offload (bool, Optional) – if True, offload params to CPU. Note: this argument will be deprecated in favor of move_params_to_cpu in an upcoming release.
state_dict_on_rank_0_only (bool) – When set to True, model.state_dict() will only return the full state dict on rank 0 and will return an empty dict on all other ranks, which allows FullyShardedDataParallel to skip the GPU -> CPU copy on non-zero ranks altogether and prevent OOM. Default: False
gradient_predivide_factor (float, optional) – If supplied, pre-divide the gradients by this factor before the reduce-scatter. Default: None
allow_reset_parameters (bool) – If True, allow the reset_parameters API to be proxied to the wrapped module. Default: False
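The bucket_cap_mb estimates can be made concrete with a little arithmetic. The helper below (a hypothetical name, not part of fairscale) only evaluates the approximate formulas stated in the bucket_cap_mb description, max shard size ≈ bucket_cap_mb / world_size and per-buffer overhead ≈ bucket_cap_mb / world_size * (world_size + 1); it does not model fairscale's actual allocator.

```python
from typing import Tuple


def bucket_memory_estimate(bucket_cap_mb: float, world_size: int) -> Tuple[float, float]:
    """Return (approx. max shard size in MB, approx. memory overhead per buffer in MB).

    Evaluates the rough formulas from the bucket_cap_mb description. For values
    <= 0 (bucketing disabled) we assume no buffer memory is used.
    """
    if bucket_cap_mb <= 0:
        return 0.0, 0.0
    shard_mb = bucket_cap_mb / world_size
    overhead_mb = shard_mb * (world_size + 1)
    return shard_mb, overhead_mb


for ws in (1, 8, 64):
    shard, overhead = bucket_memory_estimate(25, ws)
    print(f"world_size={ws:>2}: shard ~ {shard:.3f} MB, overhead per buffer ~ {overhead:.3f} MB")
```

With the default of 25 MB and 8 ranks, this gives a 3.125 MB shard and about 28.1 MB of overhead per buffer; as world_size grows, the per-buffer overhead approaches bucket_cap_mb itself.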
- set_gradient_divide_factors(pre: float, post: float, recursive: bool) → None
Allow the user to override the pre- and post-divide factors.
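The pre- and post-divide factors split gradient averaging into two steps: each rank divides its gradient by pre before the reduce-scatter sums across ranks, and the summed result is divided by post afterwards. When pre * post equals world_size, the end result is the mean of the per-rank gradients. Splitting the division this way is commonly done to lower the risk of overflow (from dividing only after summing) and underflow (from dividing fully before summing) in low precision. The sketch below simulates this in plain Python, with a list of per-rank scalars standing in for the collective; it is not fairscale code, and the choice pre = 2 is only an example.

```python
from typing import List


def reduce_with_divide_factors(per_rank_grads: List[float], pre: float, post: float) -> float:
    """Simulate: divide by `pre` on each rank, sum across ranks, then divide by `post`."""
    prescaled = [g / pre for g in per_rank_grads]  # done locally on every rank
    summed = sum(prescaled)                        # stands in for the reduce-scatter sum
    return summed / post                           # applied after communication


grads = [1.0, 2.0, 3.0, 6.0]  # one scalar "gradient" per rank, world_size = 4
world_size = len(grads)
pre = 2.0
post = world_size / pre       # pre * post == world_size, so the result is the mean
averaged = reduce_with_divide_factors(grads, pre, post)
print(averaged)  # 3.0
```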
- property module: fairscale.nn.misc.flatten_params_wrapper.FlattenParamsWrapper
Make model.module accessible, just like DDP.
Add a param that’s already owned by another FSDP wrapper.
This is experimental!
This only works when all sharing FSDP modules are un-flattened.
p must already be sharded by the owning module.
Check the corresponding unit tests to see how it is used and tested. In particular, the sharing FSDP wrappers are “siblings”, not “parent” and “child”, of each other in the nested module structure.
p (Parameter) – The shared parameter.
Return the list of non-shared parameters.
- apply(fn: Callable[[torch.nn.modules.module.Module], None]) → fairscale.nn.data_parallel.fully_sharded_data_parallel.FullyShardedDataParallel
Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model.
Compared to torch.nn.Module.apply, this version additionally gathers the full parameters before applying fn. It should not be called from within another
fn (Callable) – function to be applied to each submodule
Returns: Module – self
- property params_with_grad: List[torch.nn.parameter.Parameter]
[p for p in self.parameters() if p.grad is not None]
- clip_grad_norm_(max_norm: Union[float, int], norm_type: Union[float, int] = 2.0) → torch.Tensor
Clip all gradients at this point in time. The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place.
Returns: total norm of the parameters (viewed as a single vector).
This is analogous to torch.nn.utils.clip_grad_norm_ but handles the partitioning and multiple devices per rank under the hood. The default torch util is not applicable here, because each rank only has a partial view of all the grads in the model, so calling it in the FSDP context would lead to different scaling being applied per subset of model parameters.
This needs to be called on all ranks, since synchronization primitives will be used.
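The reason the stock utility does not apply is that each rank holds only a shard of the gradients. A global p-norm can still be computed by having each rank sum |g|**p over its shard, summing those partial values across ranks (an all-reduce), and taking the p-th root. The sketch below simulates that with plain Python lists standing in for per-rank gradient shards. The helper names are hypothetical, the 1e-6 epsilon mirrors the one used by torch.nn.utils.clip_grad_norm_, only finite p is handled, and this is not fairscale's implementation.

```python
from typing import List

Shard = List[float]  # one rank's slice of the flattened gradients


def local_norm_pow(shard: Shard, p: float) -> float:
    """Sum of |g|**p over this rank's shard: the value each rank contributes."""
    return sum(abs(g) ** p for g in shard)


def clip_sharded_grads(shards: List[Shard], max_norm: float, p: float = 2.0) -> float:
    """Clip every shard in place using the global norm; return that norm."""
    # Simulated all-reduce: sum the partial values from every rank, then take the root.
    total_norm = sum(local_norm_pow(s, p) for s in shards) ** (1.0 / p)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        # Every rank applies the same coefficient, so the gradient direction is kept.
        for s in shards:
            for i in range(len(s)):
                s[i] *= clip_coef
    return total_norm


shards = [[3.0, 0.0], [4.0], [0.0, 0.0]]  # three "ranks"; global vector (3, 0, 4, 0, 0)
norm = clip_sharded_grads(shards, max_norm=1.0)
print(norm)  # 5.0
```

Had each rank clipped its own shard independently, the first rank would scale by about 1/3 and the second by about 1/4, which distorts the overall gradient direction; that is the per-subset scaling problem described above.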
- __getstate__() → Dict[str, str]
Serialize the state of the current FSDP instance.
Some properties are not serializable (e.g., process groups, streams), so we remove them and try to reconstruct them in
- __setstate__(state: Dict[str, Any]) → None
Intercept state setting and perform needed changes on params.
- parameters(recurse: bool = True) → Iterator[torch.nn.parameter.Parameter]
Returns an iterator over the module parameters, yielding all the parameters that are part of the model.
- named_parameters(*args: Any, **kwargs: Any) → Iterator[Tuple[str, torch.nn.parameter.Parameter]]
Returns an iterator over the module parameters, yielding both the name of the parameter as well as the parameter.
With FSDP, the named_parameters function implemented in nn.Module will not be able to return the name and param when we use flattened parameters unless we call this function under a summon_full_params context.
If you want the full param to be returned, you should call this function under a summon_full_params context when using flattened or original params.
This overloaded method will not be called in the case of a parent module containing an FSDP-wrapped child module. Calling parent.named_parameters() will return the original, unclean key strings, i.e., _fsdp_wrapped_module and _fpw_module are included in the key strings.
- state_dict(destination: Mapping[str, torch.Tensor], prefix: str = ..., keep_vars: bool = ...) → Mapping[str, torch.Tensor]
- state_dict(prefix: str = ..., keep_vars: bool = ...) → OrderedDict[str, torch.Tensor]
Returns the whole (unsharded) state of the module. Parameters are not sharded, so the resulting state_dict can be loaded directly by the wrapped Module without any sharding-specific logic. Returned tensors will be full precision (e.g., FP32).
This needs to be called on all ranks, since synchronization primitives will be used.
- local_state_dict(destination: Mapping[str, torch.Tensor], prefix: str = ..., keep_vars: bool = ...) → Mapping[str, torch.Tensor]
- local_state_dict(prefix: str = ..., keep_vars: bool = ...) → OrderedDict[str, torch.Tensor]
Returns the local (sharded) state of the module. Parameters are sharded, so the resulting state_dict can only be loaded after the Module has been wrapped with FSDP.
- load_state_dict(state_dict: Union[Dict[str, torch.Tensor], OrderedDict[str, torch.Tensor]], strict: bool = True) → NamedTuple
- load_local_state_dict(state_dict: Union[Dict[str, torch.Tensor], OrderedDict[str, torch.Tensor]], strict: bool = True) → NamedTuple
Load a local (sharded) state_dict.
- no_sync() → Generator
A context manager to disable gradient synchronizations across FSDP processes. Within this context, gradients will be accumulated on module variables, which will later be synchronized in the first forward-backward pass after exiting the context.
This likely results in higher memory usage because FSDP will accumulate the full model gradients (instead of gradient shards) until the eventual sync.
Gradient accumulation can be done without this context, avoiding the extra GPU memory overhead but incurring extra networking overhead: the training loop simply runs multiple forward/backward passes without calling step() and without entering the no_sync context.
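The memory side of this tradeoff is simple arithmetic: inside no_sync each rank keeps full-size gradients for the whole model, whereas syncing on every backward pass leaves only a 1/world_size shard resident. The helper below is a rough, hypothetical estimate that counts gradient storage only (it ignores parameters, optimizer state, activations and communication buffers) and assumes 4 bytes per gradient element, as for FP32.

```python
def grad_memory_gb(num_params: int, world_size: int, use_no_sync: bool, bytes_per_elem: int = 4) -> float:
    """Rough per-rank gradient memory in GB (10**9 bytes).

    Inside no_sync, each rank accumulates full (unsharded) gradients; otherwise
    gradients are reduce-scattered every backward pass and only a shard is kept.
    """
    elems = num_params if use_no_sync else num_params / world_size
    return elems * bytes_per_elem / 1e9


params = 1_000_000_000  # a hypothetical 1B-parameter model
for flag in (False, True):
    gb = grad_memory_gb(params, world_size=8, use_no_sync=flag)
    print(f"no_sync={flag}: ~{gb:.2f} GB of gradients per rank")
```

The flip side is communication: without no_sync, a reduce-scatter runs after every micro-batch instead of once per accumulation cycle.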
- reset_parameters() → None
Special reset_parameters API handling.
We don’t by default allow this API because it has at least 2 issues:
1. Calling it after wrapping can crash due to unexpected tensor sizes and dimensions caused by flattening and sharding, so a summon_full_params context might be required.
2. Calling it after wrapping can result in incorrect init values due to flattening.
See this gist for an example of the init issue when parameters are flattened.
Or, as in issue 1, the init function can silently initialize the weights differently because of the changed dimensions.
Finally, be advised that init on CPU vs. on GPU can produce different values. If a model is originally on CPU and is moved to GPU after wrapping, calling this will again be problematic.
- summon_full_params(recurse: bool = True, volatile: bool = False) → Generator
A context manager to expose full params for the current FSDP instance. Can be useful after forward/backward for a model to get the params for additional processing or checking. Parameters will be gathered in full precision (e.g., FP32).
This can be used on inner FSDPs.
This cannot be used within a forward or backward pass, nor can a forward or backward pass be started from within this context.
The full parameters will be freed after the context manager exits; it is up to the caller to clone them if needed.
The full parameters can be modified, but only the portion corresponding to the local param shard will persist after the context manager exits (unless volatile=True, in which case there are no guarantees about persistence).
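The persistence rule can be illustrated with a toy simulation in plain Python: the shards of a flat parameter are gathered into a full list, the caller edits the full copy, and on exit the rank writes back only its own slice. The function and variable names are hypothetical, and the sketch models the described semantics for volatile=False only; it is not fairscale's implementation.

```python
from typing import Callable, List


def edit_full_params(shards: List[List[float]], rank: int, edit: Callable[[List[float]], None]) -> None:
    """Gather all shards, let `edit` mutate the full copy, keep only `rank`'s slice."""
    full = [x for shard in shards for x in shard]  # simulated all-gather
    edit(full)                                     # caller modifies the full params
    start = sum(len(s) for s in shards[:rank])
    end = start + len(shards[rank])
    shards[rank][:] = full[start:end]              # only the local shard persists
    # `full` goes out of scope here, like the gathered params freed on exit.


def zero_all(params: List[float]) -> None:
    for i in range(len(params)):
        params[i] = 0.0


shards = [[1.0, 2.0], [3.0, 4.0]]  # two ranks
edit_full_params(shards, rank=0, edit=zero_all)
print(shards)  # [[0.0, 0.0], [3.0, 4.0]]
```

Although rank 0 zeroed the whole gathered copy, only its own shard changed. If every rank applies the same edit inside the context, every shard ends up updated.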
- forward(*args: Any, **kwargs: Any) → torch.Tensor
- local_metadata_dict() → Dict[str, Any]
Get the information needed to reconstruct the model from shards offline.
See the consolidate_shard_weights method below.
- static consolidate_shard_weights(shard_weights: List[Dict[str, torch.Tensor]], shard_metadata: List[Dict[str, Any]], with_module_buffers: bool = True, strict: bool = True) → Dict[str, torch.Tensor]
Given a list of weights and metadata associated with N shards, reconstruct the weights of an equivalent consolidated (non-sharded) state dict.
Module parameters are consolidated using the shard metadata.
Module buffers are taken from shard 0: this assumes that module buffers are either synchronized or that the shard 0 value is valid for all shards. If this behavior is not correct for your module (for instance, if buffers need to be all-reduced instead), you can disable it with with_module_buffers=False.
This method is used to re-assemble checkpoints of shards without having to instantiate FSDP wrappers with the world size (i.e., large number of GPUs) originally used to save the shards.
shard_weights (List[Dict[str, torch.Tensor]]) – list of dictionaries that contain sharded weights from each rank.
shard_metadata (List[Dict[str, Any]]) – list of dictionaries that contain metadata from each shard. See local_metadata_dict above.
with_module_buffers (bool) – whether shard 0’s buffers should be included in the consolidated weight dict. Default: True.
strict (bool) – if True, every key in the metadata must be present in the weights; if False, incomplete shard weights are allowed. Default: True.
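At its core, consolidating a flattened parameter means concatenating the flat shards from all ranks in rank order, dropping any padding added so the flat parameter splits evenly, and cutting the result back into the original tensors. The sketch below shows that idea with plain Python lists and a hypothetical, much-simplified metadata format (a list of (name, shape) pairs in flattening order). fairscale's real metadata, as returned by local_metadata_dict, is structured differently, and this is not its implementation.

```python
from math import prod
from typing import Dict, List, Tuple

# Hypothetical, simplified metadata: (name, shape) for each original parameter,
# listed in the order in which they were flattened into one flat parameter.
Meta = List[Tuple[str, Tuple[int, ...]]]


def consolidate(shards: List[List[float]], meta: Meta) -> Dict[str, List[float]]:
    """Rebuild unsharded parameters (as flat lists) from per-rank flat shards."""
    flat = [x for shard in shards for x in shard]  # shards are ordered by rank
    numels = [prod(shape) for _, shape in meta]
    flat = flat[: sum(numels)]                     # drop the zero padding at the end
    out: Dict[str, List[float]] = {}
    offset = 0
    for (name, _shape), n in zip(meta, numels):
        out[name] = flat[offset : offset + n]      # a real tensor would be reshaped to _shape
        offset += n
    return out


meta = [("weight", (2, 2)), ("bias", (2,))]  # 6 elements in total
# 6 elements over 4 ranks: padded to 8, so each rank holds 2 and the last
# rank holds only padding.
shards = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [0.0, 0.0]]
print(consolidate(shards, meta))  # {'weight': [1.0, 2.0, 3.0, 4.0], 'bias': [5.0, 6.0]}
```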
- assert_state(state: Union[fairscale.nn.data_parallel.fully_sharded_data_parallel.TrainingState, List[fairscale.nn.data_parallel.fully_sharded_data_parallel.TrainingState]]) → None
Assert we are in the given state.
- gather_full_optim_state_dict(optim: torch.optim.optimizer.Optimizer, **ignored: Dict) → Optional[Dict[str, Any]]
Return the last known global optimizer state. The returned state is compatible with PyTorch, in that the sharded properties are not exposed. Multiple parameter groups are not yet supported.
This should be called only on the root FSDP instance. Nested FSDP instances are supported as long as they have the same world_size as the parent or world_size=1.
optim (Optimizer) – an optimizer instance for this FSDP rank. Its state_dict is used in the consolidation. However, its state is not modified.
Returns: a dict with four entries (on rank zero; other workers return None):
  - state - a dict holding gathered optimization state, 1 entry per unflat parameter
  - param_groups - a dict containing the 1 parameter group
  - param_id_map - global (unflat) to local (flat) id mapping
  - uncollected_local_ids - keys in the state dict that were not broadcast
- get_shard_from_optim_state_dict(full_optim_state_dict: Dict[str, Any]) → Dict[str, Any]
Get the portion of the optimizer state dict associated with the shard.
This can be used to get the right sharded optimizer state to be loaded into the sharded optimizer for this FSDP rank.
Warning: the input state dict is modified in-place, assuming the original full state isn’t going to be used anymore. This is done so that extra state does not need to be copied. It is the caller’s responsibility to make a copy if it doesn’t want the original state dict modified.
full_optim_state_dict (dict) – consolidated optimizer state returned by gather_full_optim_state_dict, or loaded from a checkpoint.
Returns: (dict) – a shard of the optimizer state.
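Conceptually, extracting a rank's portion of a consolidated optimizer state means slicing each flat per-parameter state tensor the same way the parameters themselves were sharded. The sketch below shows that slicing for a single hypothetical state entry (an Adam-style exp_avg held as a plain list), assuming equal-sized chunks with zero padding at the tail. It is a simplification, not fairscale's implementation, which also has to deal with the parameter-id mapping described under gather_full_optim_state_dict.

```python
from typing import List


def shard_of(full: List[float], rank: int, world_size: int) -> List[float]:
    """Return `rank`'s equally sized slice of a flat list, zero-padding the tail."""
    chunk = -(-len(full) // world_size)  # ceiling division
    padded = full + [0.0] * (chunk * world_size - len(full))
    return padded[rank * chunk : (rank + 1) * chunk]


exp_avg = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]  # hypothetical full state for one flat param
print([shard_of(exp_avg, r, world_size=4) for r in range(4)])
# [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.0, 0.0]]
```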