Optimizer State Sharding
- class fairscale.optim.OSS(params: Any, optim: Type[torch.optim.optimizer.Optimizer] = <class 'torch.optim.sgd.SGD'>, group: Optional[Any] = None, broadcast_buffer_size: int = -1, broadcast_fp16: bool = False, force_broadcast_object: bool = False, **default: Any)
Example (additional keyword arguments such as lr are passed on to the wrapped optimizer):
    opt = OSS(params, optim=torch.optim.Adam, lr=0.01)
We use a greedy algorithm to distribute the parameters across ranks. Each parameter is assigned to exactly one rank and is never split across ranks.
After each rank has completed its parameter update, it broadcasts the new values of its parameters to all other ranks, so that every rank holds the same parameters for the next forward/backward pass.
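The placement rule can be illustrated in plain Python. This is a minimal sketch, assuming a "least-loaded rank first" greedy rule driven by each parameter's element count; greedy_partition is a hypothetical helper, not part of fairscale, and the exact heuristic OSS uses may differ. Element counts stand in for tensors so the example runs without torch or a process group.

```python
from typing import List


def greedy_partition(numels: List[int], world_size: int) -> List[List[int]]:
    """Assign each parameter (given by its element count) to exactly one rank.

    Hypothetical helper (not a fairscale API): each parameter goes to the rank
    that currently holds the fewest elements, so no parameter is ever split.
    """
    assignment: List[List[int]] = [[] for _ in range(world_size)]
    load = [0] * world_size  # running element count per rank
    for index, numel in enumerate(numels):
        rank = load.index(min(load))  # least-loaded rank; ties go to the lowest rank
        assignment[rank].append(index)
        load[rank] += numel
    return assignment


if __name__ == "__main__":
    # Six parameters of different sizes spread over two ranks.
    print(greedy_partition([1000, 10, 500, 500, 20, 30], world_size=2))
    # [[0, 4], [1, 2, 3, 5]]
```

Every index lands in exactly one list, so no parameter is split across ranks; the greedy rule only balances the total element count per rank approximately.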
- Parameters
params (list of tensors) – parameters to be optimized
- Keyword Arguments
optim (type of torch.optim.Optimizer) – the optimizer class (not an instance) to shard (default: SGD)
group (ProcessGroup, optional) – torch.distributed process group to shard across (default: group.WORLD)
broadcast_buffer_size (int) – (deprecated) formerly used to cap the size of the broadcast buffers; no longer used.
broadcast_fp16 (bool) – compress the model shards to fp16 before sharing them between ranks. This is safe to use when PyTorch AMP is enabled; without AMP it will slightly degrade accuracy. (default: False)
force_broadcast_object (bool) – if True, ‘_broadcast_object’ is used for rebuilding the sharded optimizer. If False, whether to use ‘_broadcast_object’ or ‘dist.broadcast_object_list’ is decided based on GPU capabilities. This option exists because some newer GPUs still run into memory issues with dist.broadcast_object_list. (default: False)
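To get a sense of the accuracy cost of broadcast_fp16, the sketch below round-trips Python floats through IEEE 754 half precision using the standard library's struct module ("e" format). It only illustrates the rounding each value undergoes; OSS itself casts torch tensors, and fp16_round_trip is a hypothetical helper.

```python
import struct


def fp16_round_trip(value: float) -> float:
    """Pack a float into IEEE 754 half precision and unpack it again.

    Hypothetical helper: mimics what happens to one value when a shard is
    cast to fp16 for broadcasting and read back on the receiving rank.
    """
    return struct.unpack("<e", struct.pack("<e", value))[0]


if __name__ == "__main__":
    for x in (0.1, 1.0 / 3.0, 1234.567):
        y = fp16_round_trip(x)
        print(f"{x!r} -> {y!r} (abs error {abs(x - y):.2e})")
```

Small values keep roughly three significant decimal digits, while larger magnitudes lose their fractional part entirely (fp16 values between 1024 and 2048 are spaced 1.0 apart).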
- optim: torch.optim.optimizer.Optimizer
The optimizer instance used for this rank's shard of the parameters.
- partition_parameters() → List[List[dict]]
Partitions parameters across distributed data parallel ranks.
Returns a list with one element per rank; each element is that rank's list of param_groups (each group being a dict). Element 0 corresponds to rank 0, and so on. The partitions of all ranks are needed for the broadcast inside step().
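The shape of the value returned by partition_parameters() can be sketched without torch. In this hypothetical version, parameters are (name, numel) pairs, each group is split with a least-loaded greedy rule that restarts for every group (an assumption), and every rank receives a shallow copy of each group so that its hyperparameters travel with the shard.

```python
import copy
from typing import Any, Dict, List


def partition_param_groups(
    param_groups: List[Dict[str, Any]], world_size: int
) -> List[List[Dict[str, Any]]]:
    """Return one list of param_groups per rank (element 0 is rank 0, etc.).

    Hypothetical sketch (not fairscale's code): "params" holds (name, numel)
    pairs instead of tensors, and each group is split with an assumed
    least-loaded greedy rule.
    """
    partitions: List[List[Dict[str, Any]]] = [[] for _ in range(world_size)]
    for group in param_groups:
        per_rank: List[List[Any]] = [[] for _ in range(world_size)]
        load = [0] * world_size  # assumption: the load counter restarts per group
        for param in group["params"]:
            rank = load.index(min(load))
            per_rank[rank].append(param)
            load[rank] += param[1]  # param[1] is the element count
        for rank, params in enumerate(per_rank):
            group_for_rank = copy.copy(group)  # shallow copy keeps lr, momentum, ...
            group_for_rank["params"] = params  # replaces the key on the copy only
            partitions[rank].append(group_for_rank)
    return partitions


if __name__ == "__main__":
    groups = [
        {"params": [("w1", 400), ("b1", 20), ("w2", 300)], "lr": 0.1},
        {"params": [("emb", 1000)], "lr": 0.01},
    ]
    for rank, rank_groups in enumerate(partition_param_groups(groups, world_size=2)):
        print(rank, rank_groups)
```

In the sketch, rank 1 still receives the second group with an empty params list, so every rank sees the same number of groups.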
- step(closure: Optional[Callable[[], float]] = None, **kwargs: Any) → Optional[float]
Performs a single optimization step (parameter update).
closure (callable) – A closure that reevaluates the model and returns the loss. Optional for most optimizers.
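The update-then-broadcast cycle that OSS performs in each step can be simulated with scalar parameters. sharded_sgd_step is a hypothetical helper using plain SGD; the real step() delegates the update to the wrapped optimizer and communicates tensors through torch.distributed instead of copying Python dicts.

```python
from typing import Dict, List


def sharded_sgd_step(
    params: Dict[str, float],
    grads: Dict[str, float],
    owner: Dict[str, int],
    world_size: int,
    lr: float,
) -> List[Dict[str, float]]:
    """Simulate one OSS-style step on `world_size` ranks (hypothetical helper).

    Every rank starts from the same replica; each rank applies plain SGD only
    to the parameters it owns, then each parameter is "broadcast" from its
    owner so that all replicas match again.
    """
    replicas = [dict(params) for _ in range(world_size)]
    # Phase 1: local update, rank r only touches the parameters it owns.
    for rank, replica in enumerate(replicas):
        for name, value in replica.items():
            if owner[name] == rank:
                replica[name] = value - lr * grads[name]
    # Phase 2: broadcast, every rank copies each parameter from its owner.
    for name, src in owner.items():
        for replica in replicas:
            replica[name] = replicas[src][name]
    return replicas


if __name__ == "__main__":
    out = sharded_sgd_step(
        params={"w": 1.0, "b": 0.5},
        grads={"w": 0.25, "b": -0.5},
        owner={"w": 0, "b": 1},  # "w" lives on rank 0, "b" on rank 1
        world_size=2,
        lr=0.5,
    )
    print(out)
```

Without the second phase, rank 0 would still hold the old value of "b" and rank 1 the old value of "w".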
- clip_grad_norm(max_norm: Union[float, int], norm_type: Union[float, int] = 2.0, filter_params_fn: Optional[Callable[[Any], Any]] = None) → torch.Tensor
Clips all gradients at this point in time. The norm is computed over all gradients together, as if they were concatenated into a single vector, and the gradients are modified in-place. Returns the total norm of the gradients, viewed as a single vector.
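Since each rank is responsible for only part of the parameters, a global norm has to be assembled from per-rank contributions. The sketch below assumes each rank contributes sum(|g|^p) over the gradients of the parameters it owns (or max(|g|) for the infinity norm), that these are combined the way an all-reduce would (SUM or MAX), and that the clipping coefficient follows the formula of torch.nn.utils.clip_grad_norm_. Lists of floats stand in for gradient tensors; sharded_clip_grad_norm is hypothetical.

```python
import math
from typing import List


def local_norm_contribution(grads: List[float], norm_type: float) -> float:
    """One rank's share: sum(|g|^p), or max(|g|) for the infinity norm."""
    if math.isinf(norm_type):
        return max((abs(g) for g in grads), default=0.0)
    return sum(abs(g) ** norm_type for g in grads)


def sharded_clip_grad_norm(
    shards: List[List[float]], max_norm: float, norm_type: float = 2.0
) -> float:
    """Clip per-rank gradient shards in place and return the total norm.

    Hypothetical helper: the reduction over `shards` stands in for an
    all-reduce across ranks (SUM for a p-norm, MAX for the infinity norm).
    """
    contributions = [local_norm_contribution(s, norm_type) for s in shards]
    if math.isinf(norm_type):
        total_norm = max(contributions)
    else:
        total_norm = sum(contributions) ** (1.0 / norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)  # same epsilon as torch's utility
    if clip_coef < 1.0:  # only ever scale gradients down
        for shard in shards:
            for i, g in enumerate(shard):
                shard[i] = g * clip_coef
    return total_norm


if __name__ == "__main__":
    shards = [[3.0], [4.0]]  # rank 0 owns one gradient, rank 1 another
    print(sharded_clip_grad_norm(shards, max_norm=1.0), shards)
```

Because the total norm depends on gradients owned by every rank, this computation only makes sense if all ranks take part in the reduction.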
- consolidate_state_dict(recipient_rank: int = 0) → None
Updates the consolidated state_dict list, one entry per rank.
recipient_rank (int) – the rank on which to materialize the full state dict. -1 is a special value, which means that all ranks should have the state.
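The recipient_rank semantics can be illustrated with a simulated gather in which Python lists replace communication. consolidate is a hypothetical helper, not fairscale code: the recipient rank (every rank when recipient_rank is -1) ends up holding the list of per-rank states, and the other ranks hold nothing.

```python
from typing import Any, Dict, List, Optional


def consolidate(
    shard_states: List[Dict[str, Any]], recipient_rank: int = 0
) -> List[Optional[List[Dict[str, Any]]]]:
    """Return what each rank holds after a simulated consolidation.

    Hypothetical helper: shard_states[r] is rank r's local optimizer state.
    The recipient rank, or every rank when recipient_rank == -1, receives the
    full list (one entry per rank); all other ranks receive None.
    """
    world_size = len(shard_states)
    # Stand-in for the communication: copy every rank's local state into one list.
    gathered = [dict(state) for state in shard_states]
    return [
        list(gathered) if recipient_rank in (-1, rank) else None
        for rank in range(world_size)
    ]


if __name__ == "__main__":
    print(consolidate([{"step": 1}, {"step": 2}], recipient_rank=1))
    print(consolidate([{"step": 1}, {"step": 2}], recipient_rank=-1))
```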
- state_dict(all_ranks: bool = False) → Dict[str, Any]
Returns the last known global optimizer state. The returned state is compatible with PyTorch, in that the sharded properties are not exposed.
all_ranks (bool) – materialize the state on all ranks. In that case, .state_dict() needs to be called on all ranks.
- Returns: a dict with two entries
state - a dict holding the current optimization state. Its content differs between optimizer classes.
param_groups - a list containing all parameter groups, each of which is a dict.
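A PyTorch-compatible optimizer state dict keys "state" by integer parameter index and lists parameter groups whose "params" entries are those indices. The sketch below shows how per-rank shard states could be merged into that single, unsharded layout; merge_shard_states and the local-to-global index mapping it takes are hypothetical and do not describe fairscale's internal bookkeeping.

```python
from typing import Any, Dict, List


def merge_shard_states(
    shard_states: List[Dict[int, Dict[str, Any]]],
    local_to_global: List[List[int]],
    param_groups: List[Dict[str, Any]],
) -> Dict[str, Any]:
    """Build a PyTorch-style {"state", "param_groups"} dict from per-rank shards.

    Hypothetical helper: shard_states[r] maps rank r's local parameter indices
    to their optimizer state, and local_to_global[r][i] is the global index of
    rank r's i-th parameter.
    """
    state: Dict[int, Dict[str, Any]] = {}
    for rank, shard in enumerate(shard_states):
        for local_index, param_state in shard.items():
            state[local_to_global[rank][local_index]] = param_state
    # The sharding is not visible in the result: one flat state, global groups.
    return {"state": state, "param_groups": param_groups}


if __name__ == "__main__":
    sd = merge_shard_states(
        shard_states=[
            {0: {"momentum_buffer": 0.1}},
            {0: {"momentum_buffer": 0.2}, 1: {"momentum_buffer": 0.3}},
        ],
        local_to_global=[[1], [0, 2]],  # rank 0 owns param 1; rank 1 owns 0 and 2
        param_groups=[{"lr": 0.1, "params": [0, 1, 2]}],
    )
    print(sd)
```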
- load_state_dict(state_dict: Dict[str, Any]) → None
Restores the global parameter groups as well as this rank's shard of the optimizer state. state_dict is expected to be a dict as returned by state_dict().
- refresh_trainable() → None
Updates the partitioning and communication patterns if the trainability (requires_grad) of some parameters has changed, for instance after freezing or unfreezing layers.
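Because the partition depends on which parameters require gradients, it has to be rebuilt when that pattern changes. The following is a minimal sketch of the idea, not fairscale's implementation: TrainabilityTracker is hypothetical, it caches the tuple of requires_grad flags and re-partitions only when that tuple differs, and it assumes frozen parameters are simply left out of the partition.

```python
from typing import List, Sequence, Tuple


class TrainabilityTracker:
    """Hypothetical helper: rebuild the partition only when requires_grad flags change."""

    def __init__(self, world_size: int) -> None:
        self.world_size = world_size
        self._flags: Tuple[bool, ...] = ()
        self.partition: List[List[int]] = []

    def refresh(self, numels: Sequence[int], requires_grad: Sequence[bool]) -> bool:
        """Re-partition if the trainability pattern changed; return True if it did."""
        flags = tuple(requires_grad)
        if flags == self._flags:
            return False  # nothing changed, keep the current partition
        self._flags = flags
        self.partition = [[] for _ in range(self.world_size)]
        load = [0] * self.world_size
        for index, (numel, trainable) in enumerate(zip(numels, flags)):
            if not trainable:
                continue  # assumption: frozen params carry no optimizer state here
            rank = load.index(min(load))  # least-loaded rank gets the parameter
            self.partition[rank].append(index)
            load[rank] += numel
        return True


if __name__ == "__main__":
    tracker = TrainabilityTracker(world_size=2)
    tracker.refresh([100, 50, 50], [True, False, True])  # parameter 1 frozen
    print(tracker.partition)
    tracker.refresh([100, 50, 50], [True, True, True])  # parameter 1 unfrozen
    print(tracker.partition)
```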
- add_param_group(param_group: dict) → None
Add a param group to the Optimizer's param_groups.
This can be useful when fine-tuning a pre-trained network, as frozen layers can be made trainable and added to the Optimizer as training progresses.
param_group (dict) – specifies which Tensors should be optimized, along with group-specific optimization options.
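For the fine-tuning scenario, adding a group involves two steps: filling in any option the new group does not set from the optimizer's defaults (as torch.optim.Optimizer.add_param_group does) and splitting the new group's parameters across ranks. In this hypothetical sketch, parameters are (name, numel) pairs and the split uses a least-loaded greedy rule, which is an assumption about how OSS partitions.

```python
import copy
from typing import Any, Dict, List, Tuple

Param = Tuple[str, int]  # hypothetical stand-in for a tensor: (name, numel)


def add_param_group(
    param_groups: List[Dict[str, Any]],
    defaults: Dict[str, Any],
    new_group: Dict[str, Any],
    world_size: int,
) -> List[Dict[str, Any]]:
    """Append new_group to param_groups and return its per-rank shards.

    Hypothetical helper: missing options are filled in from defaults, then
    the group's params are split with an assumed least-loaded greedy rule.
    """
    group = dict(new_group)
    for key, value in defaults.items():
        group.setdefault(key, value)  # explicit group options win over defaults
    param_groups.append(group)

    load = [0] * world_size
    per_rank: List[List[Param]] = [[] for _ in range(world_size)]
    for param in group["params"]:
        rank = load.index(min(load))
        per_rank[rank].append(param)
        load[rank] += param[1]

    shards: List[Dict[str, Any]] = []
    for params in per_rank:
        shard = copy.copy(group)  # each rank keeps the group's options
        shard["params"] = params
        shards.append(shard)
    return shards


if __name__ == "__main__":
    groups = [{"params": [("backbone.w", 10)], "lr": 0.1, "momentum": 0.9}]
    new_shards = add_param_group(
        groups,
        defaults={"lr": 0.1, "momentum": 0.9},
        new_group={"params": [("head.w", 64), ("head.b", 8)], "lr": 0.01},
        world_size=2,
    )
    print(groups[-1])
    print(new_shards)
```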