Optimizer State Sharding

class fairscale.optim.OSS(params: Any, optim: Type[torch.optim.optimizer.Optimizer] = <class 'torch.optim.sgd.SGD'>, group: Optional[Any] = None, broadcast_buffer_size: int = -1, broadcast_fp16: bool = False, force_broadcast_object: bool = False, **default: Any)

Wraps an arbitrary optim.Optimizer and shards its state as described by ZeRO.

opt = OSS(params, optim=torch.optim.Adam, lr=0.01)

We use a greedy algorithm to pack a number of parameters at each rank. Each parameter belongs to a single rank and is not divided among ranks.
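A minimal pure-Python sketch of such a greedy packing follows. It assumes the greedy rule is "give the next parameter to the rank with the smallest total size so far" and represents parameters only by their element counts; greedy_partition is a hypothetical helper, not part of fairscale.

```python
from typing import List, Sequence


def greedy_partition(sizes: Sequence[int], world_size: int) -> List[List[int]]:
    """Assign each parameter (given by its element count) to exactly one rank.

    Sketch only: parameters are visited in the given order and each one goes
    to the rank whose current total size is smallest (ties -> lowest rank).
    """
    assignment: List[List[int]] = [[] for _ in range(world_size)]
    load = [0] * world_size
    for index, numel in enumerate(sizes):
        rank = load.index(min(load))  # least-loaded rank so far
        assignment[rank].append(index)  # the whole parameter goes to one rank
        load[rank] += numel
    return assignment


# Six parameters with these element counts, split across two ranks.
print(greedy_partition([1000, 10, 500, 500, 10, 1000], world_size=2))
```

With these sizes the result is [[0, 4, 5], [1, 2, 3]], i.e. 2010 versus 1010 elements, which shows that a single greedy pass only approximately balances the load.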

After each rank has completed its parameter update, it broadcasts the new version of its parameters to all other ranks, so that every rank holds the same parameters for the next forward/backward pass.
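The two phases (each rank updates its own shard, then each owner broadcasts) can be simulated in a single process. In this pure-Python sketch, ranks are list indices, the broadcast is a plain copy rather than a torch.distributed call, plain SGD stands in for the wrapped optimizer, and sharded_sgd_step is a hypothetical name.

```python
from typing import Dict, List


def sharded_sgd_step(
    replicas: List[List[float]],
    grads: List[float],
    owner: Dict[int, int],
    lr: float,
) -> None:
    """Simulate one sharded step on world_size model replicas, in place.

    replicas[r][i] is rank r's copy of parameter i; owner[i] is the rank that
    holds the optimizer state for parameter i. Plain SGD stands in for the
    wrapped optimizer.
    """
    world_size = len(replicas)
    # 1. Each rank updates only the parameters it owns.
    for rank in range(world_size):
        for i, src in owner.items():
            if src == rank:
                replicas[rank][i] -= lr * grads[i]
    # 2. Each owner "broadcasts" its fresh value to every other rank.
    for i, src in owner.items():
        for rank in range(world_size):
            replicas[rank][i] = replicas[src][i]


replicas = [[1.0, 2.0, 3.0] for _ in range(2)]  # two ranks, identical start
sharded_sgd_step(replicas, grads=[0.5, 1.0, -1.0], owner={0: 0, 1: 1, 2: 0}, lr=0.1)
print(replicas)  # both ranks now hold the same updated parameters
```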


Parameters
  params (list of tensors) – parameters to be optimized

Keyword Arguments
  • optim (torch.optim.Optimizer) – optimizer class to shard (default: SGD)

  • group (group) – torch.distributed group (default: group.WORLD)

  • broadcast_buffer_size (int) – (deprecated) Used to cap the size of the broadcast buffers; no longer used.

  • broadcast_fp16 (bool) – Compress the model shards to fp16 before sharing them between ranks. This is safe to use when PyTorch AMP is activated; without AMP, it leads to a slight degradation in accuracy.

  • force_broadcast_object (bool) – If True, ‘_broadcast_object’ is always used to rebuild the sharded optimizer. If False, the choice between ‘_broadcast_object’ and ‘dist.broadcast_object_list’ is made based on GPU capabilities. This option exists because some newer GPUs still run into memory issues with dist.broadcast_object_list.
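To get a feel for the precision cost behind broadcast_fp16, the standard-library sketch below round-trips a few Python floats through IEEE 754 half precision using struct's "e" format. It only illustrates fp16 rounding on scalars; fairscale itself casts torch tensors, and fp16_roundtrip is a hypothetical helper.

```python
import struct


def fp16_roundtrip(value: float) -> float:
    """Cast a Python float to IEEE 754 half precision and back.

    This mimics the precision lost when a shard is sent as fp16; it is not
    how fairscale performs the cast (that happens on torch tensors).
    """
    # "e" is the struct format code for a 16-bit half-precision float.
    return struct.unpack("<e", struct.pack("<e", value))[0]


for weight in (0.1, 1.0, 1234.5678, 1e-8):
    print(weight, "->", fp16_roundtrip(weight))
```

Here 0.1 comes back as 0.0999755859375 and 1234.5678 as 1235.0, while values below the fp16 range, such as 1e-8, underflow to 0.0 entirely.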

optim: torch.optim.optimizer.Optimizer

The optimizer used for a given shard

in_super_constructor: bool

partition_parameters() → List[List[dict]]

Partitions parameters across distributed data parallel ranks.

Returns a list with one entry per rank, where each entry is that rank's param_groups (a list of dicts). Element 0 corresponds to rank 0, and so on. The partitions of all ranks are needed for the broadcast inside step().
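The shape of this return value can be sketched in plain Python. The hypothetical partition_param_groups below uses element counts in place of tensors and assumes a least-loaded greedy rule; it is an illustration of the List[List[dict]] layout, not the fairscale implementation.

```python
import copy
from typing import Any, Dict, List


def partition_param_groups(
    param_groups: List[Dict[str, Any]], world_size: int
) -> List[List[Dict[str, Any]]]:
    """Return one list of param_groups per rank (element r is for rank r).

    Sketch only: "params" holds element counts instead of tensors, and each
    parameter goes to the currently least-loaded rank.
    """
    per_rank: List[List[Dict[str, Any]]] = [[] for _ in range(world_size)]
    load = [0] * world_size
    for group in param_groups:
        split: List[List[int]] = [[] for _ in range(world_size)]
        for numel in group["params"]:
            rank = load.index(min(load))
            split[rank].append(numel)
            load[rank] += numel
        for rank in range(world_size):
            # Keep the group's hyperparameters (lr, momentum, ...) on every rank.
            shard = copy.copy(group)
            shard["params"] = split[rank]
            per_rank[rank].append(shard)
    return per_rank


groups = [{"params": [40, 30, 20], "lr": 0.1}, {"params": [10], "lr": 0.01}]
for rank, rank_groups in enumerate(partition_param_groups(groups, world_size=2)):
    print(rank, rank_groups)
```

In this sketch every rank receives an entry for every group, even when its params list is empty (rank 1's second group here), and each entry carries that group's own hyperparameters such as lr.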

step(closure: Optional[Callable[[], float]] = None, **kwargs: Any) → Optional[float]

Performs a single optimization step (parameter update).


Parameters
  closure (callable) – A closure that reevaluates the model and returns the loss. Optional for most optimizers.

clip_grad_norm(max_norm: Union[float, int], norm_type: Union[float, int] = 2.0, filter_params_fn: Optional[Callable[[Any], Any]] = None) → torch.Tensor

Clip all gradients at this point in time. The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place.

Parameters
  • max_norm (float or int) – max norm of the gradients

  • norm_type (float or int) – type of the used p-norm. Can be 'inf' for infinity norm.


Returns
  Total norm of the parameter gradients (viewed as a single vector).
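Because each rank holds only part of the gradients, the global norm has to be assembled from per-rank pieces: for a p-norm the local sums of |g|^p are added before taking the p-th root once, and for the infinity norm the local maxima are compared. The pure-Python sketch below simulates this with lists standing in for ranks and a plain sum or max standing in for the cross-rank reduction. global_grad_norm and clip_sharded_grads_ are hypothetical names, and the clip coefficient max_norm / (total + 1e-6) follows the form used by torch.nn.utils.clip_grad_norm_.

```python
import math
from typing import List


def global_grad_norm(shard_grads: List[List[float]], norm_type: float = 2.0) -> float:
    """Norm of all gradients viewed as one vector, built from per-rank pieces.

    shard_grads[r] holds the gradient entries owned by rank r; each rank only
    needs to contribute one number to the reduction.
    """
    if math.isinf(norm_type):
        # Infinity norm: the largest local maximum wins.
        return max(max((abs(g) for g in shard), default=0.0) for shard in shard_grads)
    # p-norm: add the local sums of |g|^p, then take the p-th root once.
    local_sums = [sum(abs(g) ** norm_type for g in shard) for shard in shard_grads]
    return sum(local_sums) ** (1.0 / norm_type)


def clip_sharded_grads_(
    shard_grads: List[List[float]], max_norm: float, norm_type: float = 2.0
) -> float:
    """Scale every gradient in place so the global norm is at most max_norm."""
    total = global_grad_norm(shard_grads, norm_type)
    clip_coef = max_norm / (total + 1e-6)
    if clip_coef < 1.0:
        for shard in shard_grads:
            for i, g in enumerate(shard):
                shard[i] = g * clip_coef
    return total


grads = [[3.0], [4.0]]  # two ranks, one gradient entry each
print(clip_sharded_grads_(grads, max_norm=1.0))  # about 5.0, the norm of (3, 4)
print(grads)  # roughly [[0.6], [0.8]]
```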

consolidate_state_dict(recipient_rank: int = 0) → None

Update the consolidated state_dict list, one per rank.

Parameters
  recipient_rank (int) – on which rank to materialize the full state dict. -1 is a special value, which means that all ranks should have the state.
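The effect of recipient_rank, including the special value -1, can be sketched in plain Python: each rank contributes the state of the parameters it owns, and the merged dictionary only materializes where requested. Keying shard state by global parameter index is a simplification, and consolidate is a hypothetical helper rather than the fairscale implementation, which moves data with torch.distributed.

```python
from typing import Any, Dict, List, Optional


def consolidate(
    shard_states: List[Dict[int, Any]], recipient_rank: int = 0
) -> List[Optional[Dict[int, Any]]]:
    """Return what each simulated rank holds after consolidation.

    shard_states[r] maps global parameter index -> optimizer state owned by
    rank r. The merged state lands on recipient_rank only, or on every rank
    when recipient_rank is -1.
    """
    merged: Dict[int, Any] = {}
    for shard in shard_states:
        merged.update(shard)  # shards are disjoint: each param has one owner
    result: List[Optional[Dict[int, Any]]] = []
    for rank in range(len(shard_states)):
        wants_state = recipient_rank == -1 or rank == recipient_rank
        result.append(dict(sorted(merged.items())) if wants_state else None)
    return result


shards = [{0: {"exp_avg": 0.1}, 2: {"exp_avg": 0.3}}, {1: {"exp_avg": 0.2}}]
print(consolidate(shards, recipient_rank=0))  # only rank 0 gets the full state
print(consolidate(shards, recipient_rank=-1))  # every rank gets it
```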

state_dict(all_ranks: bool = False) → Dict[str, Any]

Returns the last known global optimizer state. The returned state is compatible with PyTorch, in that the sharded properties are not exposed.

Parameters
  all_ranks (bool) – materialize the state on all ranks. In that case, .state_dict() needs to be called on all ranks.


Returns
  a dict with two entries
    • state - a dict holding current optimization state. Its content differs between optimizer classes.

    • param_groups - a list containing all parameter groups, where each parameter group is a dict

load_state_dict(state_dict: Dict[str, Any]) → None

Restore the global parameter groups as well as the shard.


Parameters
  state_dict (dict) – optimizer state. Should be an object returned from a call to state_dict().
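In the simulated version below, every rank restores the global param_groups but keeps optimizer state only for the parameters it owns. The dictionary layout follows PyTorch's state_dict format (integer parameter ids under "state", index lists under "params"), while load_sharded and the owner mapping are hypothetical and stand in for fairscale's own partitioning.

```python
import copy
from typing import Any, Dict, List, Tuple


def load_sharded(
    full_state: Dict[str, Any], owner: Dict[int, int], world_size: int
) -> List[Tuple[List[Dict[str, Any]], Dict[int, Any]]]:
    """Return (param_groups, local_state) as each simulated rank keeps them.

    Every rank restores the global param_groups, but keeps optimizer state
    only for the parameters it owns; owner[i] is the rank owning parameter i.
    """
    per_rank = []
    for rank in range(world_size):
        local_state = {
            i: copy.deepcopy(s) for i, s in full_state["state"].items() if owner[i] == rank
        }
        per_rank.append((copy.deepcopy(full_state["param_groups"]), local_state))
    return per_rank


full_state = {
    "state": {0: {"step": 10}, 1: {"step": 10}, 2: {"step": 10}},
    "param_groups": [{"params": [0, 1, 2], "lr": 0.01}],
}
for rank, (groups, local) in enumerate(load_sharded(full_state, {0: 0, 1: 1, 2: 0}, 2)):
    print(rank, groups, local)
```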

refresh_trainable() → None

Updates the partitioning and communication patterns if the trainability (requires_grad) of some parameters changed.

add_param_group(param_group: dict) → None

Add a param group to the Optimizer's param_groups.

This can be useful when fine tuning a pre-trained network as frozen layers can be made trainable and added to the Optimizer as training progresses.

Parameters
  param_group (dict) – Specifies what Tensors should be optimized along with group specific optimization options.
