
SlowMo Distributed Data Parallel

class fairscale.experimental.nn.data_parallel.SlowMoDistributedDataParallel(module: torch.nn.modules.module.Module, nprocs_per_node: int, broadcast_buffers: bool = True, slowmo_base_algorithm: fairscale.experimental.nn.data_parallel.gossip.distributed.SlowMoBaseAlgorithm = SlowMoBaseAlgorithm.LOCALSGD, slowmo_momentum: float = 0.5, slowmo_memory_efficient: bool = True, slowmo_frequency: int = 48, slowmo_lr: float = 1.0, slowmo_num_shards: int = 32, localsgd_frequency: int = 3, graph: Optional[fairscale.experimental.nn.data_parallel.gossip.graph_manager.GraphManager] = None, mixing: Optional[fairscale.experimental.nn.data_parallel.gossip.mixing_manager.MixingManager] = None, push_sum: bool = True, overlap: bool = False, synch_freq: int = 0, use_streams: bool = True, slowmo_sgp_average_params: bool = False, verbose: bool = False, profile_mode: bool = False, process_rank: Optional[int] = None, process_world_size: Optional[int] = None, global_group: Optional[torch.distributed.distributed_c10d.ProcessGroup] = None, master_group: Optional[torch.distributed.distributed_c10d.ProcessGroup] = None, local_node_group: Optional[torch.distributed.distributed_c10d.ProcessGroup] = None, comm_device: Optional[torch.device] = None)

Wraps an arbitrary nn.Module module and allows it to be run on multiple GPUs (distributed) in a data parallel setting.

This container parallelizes the application of the given module by splitting the input across the specified devices by chunking in the batch dimension. The module is replicated on each machine and each device, and each such replica handles a portion of the input. After the optimizer update, it synchronizes the parameters on the different nodes using SlowMo (https://arxiv.org/abs/1910.00643).

Please make sure to read the documentation for the slowmo_memory_efficient parameter, as it describes a non-trivial trick used to optimize our implementation.

Please refer to the documentation of torch.nn.parallel.DistributedDataParallel for other useful tips for using this container.

Parameters
  • module (Module) – module to be parallelized

  • nprocs_per_node (int) – Number of processes per node (one per GPU). This needs to be specified for optimal accuracy and speed, since syncing across GPUs within a node is extremely fast and we exploit this for performance optimization

  • broadcast_buffers (bool) – Flag that enables syncing (broadcasting) the buffers of the module (e.g. BatchNorm buffers) at the beginning of the forward function. Setting it to False improves performance due to less communication on the network, but might result in reduced accuracy (default: True)

  • slowmo_base_algorithm (SlowMoBaseAlgorithm) – The base algorithm to be used for approximately averaging the parameters across nodes. The base algorithm is responsible for the communication efficiency of this module and, combined with SlowMo, results in significant speedups without accuracy loss. Either Stochastic Gradient Push (SlowMoBaseAlgorithm.SGP, https://arxiv.org/abs/1811.10792) or LocalSGD (SlowMoBaseAlgorithm.LOCALSGD, https://arxiv.org/abs/1808.07217) can be used here; see the construction sketch after this list (default: SlowMoBaseAlgorithm.LOCALSGD)
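
As a rough sketch of how the base algorithm is selected at construction time (assuming the process group is already initialized, model is an nn.Module on the current GPU, and that SlowMoBaseAlgorithm and SlowMoDistributedDataParallel are importable from fairscale.experimental.nn.data_parallel, as the class path above suggests):

>>> from fairscale.experimental.nn.data_parallel import (
...     SlowMoBaseAlgorithm, SlowMoDistributedDataParallel)
>>> # LocalSGD is the default base algorithm
>>> net = SlowMoDistributedDataParallel(model, nprocs_per_node=8)
>>> # Stochastic Gradient Push can be selected instead
>>> net = SlowMoDistributedDataParallel(
...     model, nprocs_per_node=8,
...     slowmo_base_algorithm=SlowMoBaseAlgorithm.SGP)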

SlowMo Parameters
  • slowmo_momentum (float) – This specifies the value of slowmo momentum to be used (read https://arxiv.org/abs/1910.00643 for more details). This parameter might need to be tuned, and the optimal value varies according to the use case and the number of nodes being run on. The optimal value typically increases with the number of nodes. When training transformers on the WMT 16 En-De dataset, we have found the optimal values to be 0 for less than 4 nodes, 0.2 for 4 nodes, 0.5 for 8 nodes and 0.6 for 16 nodes (default: 0.5)

  • slowmo_memory_efficient (bool) – If enabled, use a memory-efficient implementation of SlowMo. The basic implementation of SlowMo occupies extra memory equal to double the memory occupied by the model parameters; the memory-efficient implementation shards that memory across a number of shards specified by the slowmo_num_shards parameter below. In addition, slowmo_memory_efficient leads to extra communication with throughput equivalent to an allreduce, and performs an allreduce as a side effect. To optimize the implementation, we skip the usual allreduce when slowmo_base_algorithm is LocalSGD and the LocalSGD step and slowmo step occur on the same iteration, and we skip the gossip step when slowmo_base_algorithm is SGP; both can be skipped because the memory-efficient slowmo step already performs an allreduce as a side effect. Due to this skipping, when slowmo_base_algorithm is LocalSGD we recommend setting slowmo_frequency to a multiple of localsgd_frequency. We recommend setting this parameter to True when slowmo_base_algorithm is LocalSGD. For SGP, there is a tradeoff between the extra memory usage (double the memory occupied by the parameters) and the extra time spent (half the time taken by an allreduce every slowmo_frequency iterations), and we suggest setting it to False (default: True)

  • slowmo_frequency (int) – This specifies how often (number of iterations) slow momentum is to be performed. We recommend keeping slowmo_frequency as a multiple of localsgd_frequency. Please look at the documentation of slowmo_memory_efficient for the reasoning (default: 48)

  • slowmo_lr (float) – This specifies the value of slowmo learning rate to be used (read https://arxiv.org/abs/1910.00643 for more details). We do not recommend changing this (default: 1.0)

  • slowmo_num_shards (int) – The number of shards between which slow momentum parameters are distributed. This is only used when memory_efficient is set to True. The number of shards should scale with the number of parameters in the model. Increasing the number of shards decreases the memory used per node for storing the slow momentum parameters. However, if the shard size per node is too small, it results in a communication overhead (default: 32)

LocalSGD Parameters

  • localsgd_frequency (int) – LocalSGD typically averages the parameters once every few iterations; this parameter specifies the frequency of that averaging. We recommend keeping slowmo_frequency a multiple of localsgd_frequency. Please look at the documentation of slowmo_memory_efficient for the reasoning; a combined configuration sketch follows this list (default: 3)
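
Tying the SlowMo and LocalSGD knobs together, a hedged configuration sketch (the values are illustrative, taken from the 8-node guidance above, and the setup is assumed to be the same as in the construction sketch earlier):

>>> # Illustrative 8-node LocalSGD configuration; tune slowmo_momentum and the
>>> # frequencies for your own workload.
>>> net = SlowMoDistributedDataParallel(
...     model, nprocs_per_node=8,
...     slowmo_momentum=0.5,           # guidance above suggests 0.5 for 8 nodes
...     slowmo_memory_efficient=True,  # shard the slow momentum buffers
...     slowmo_frequency=48,           # kept a multiple of localsgd_frequency
...     localsgd_frequency=3)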

SGP Parameters
  • graph (Optional[GraphManager]) – Graph to be used for gossip communication. This is used to specify the interaction graph between the different nodes (default: None)

  • mixing (Optional[MixingManager]) – Mixing manager to be used for gossip communication. This is used to specify weights given to outgoing and incoming messages (default: None)

  • push_sum (bool) – Whether to use PushSum or PushPull gossip (default: True)

  • overlap (bool) – Whether to use the overlap form of SGP. This feature is currently disabled until further testing is done for its use (default: False)

  • synch_freq (int) – How often (number of iterations) to synchronize for overlap SGP. A value of 0 means to synchronize overlap SGP every iteration (default: 0)

  • use_streams (bool) – Whether to use CUDA streams to speed up SGP overlap (default: True)

  • slowmo_sgp_average_params (bool) – Whether to average the parameters exactly when the slowmo step is performed, instead of the partial averaging that happens every iteration; see the SGP sketch after this list (default: False)
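
A hedged sketch of an SGP configuration based on the recommendations above; graph and mixing are left as None so the library falls back to its defaults (the behavior of those defaults is an assumption here, not spelled out in this page):

>>> # SGP base algorithm; memory-efficient mode is suggested to be off for SGP,
>>> # and parameters are fully averaged whenever the slowmo step runs.
>>> net = SlowMoDistributedDataParallel(
...     model, nprocs_per_node=8,
...     slowmo_base_algorithm=SlowMoBaseAlgorithm.SGP,
...     slowmo_memory_efficient=False,
...     slowmo_sgp_average_params=True)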

Debugging Parameters
  • verbose (bool) – Prints various logs which are useful for debugging (default: False)

  • profile_mode (bool) – Prints the time taken by different parts of the code, which can help in finding bottlenecks (default: False)

Parameters for Advanced Users
  • process_rank (Optional[int]) – Rank of the current process in the process group (default: None)

  • process_world_size (Optional[int]) – Size of the process group (default: None)

  • global_group (Optional[torch.distributed.ProcessGroup]) – Global process group initialized by init_process_group (default: None)

  • master_group (Optional[torch.distributed.ProcessGroup]) – Process group which only contains the master GPUs of each node (default: None)

  • local_node_group (Optional[torch.distributed.ProcessGroup]) – Process group which only contains the GPUs local to the current node (default: None)

  • comm_device (Optional[torch.device]) – The torch.device on which torch tensors are to be placed before communication (default: None). A sketch showing these advanced arguments follows this list
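
For completeness, a sketch of how the advanced arguments might be supplied explicitly; most users can leave all of them as None. The rank and world-size calls assume the global process group has already been initialized with init_process_group:

>>> # Explicit rank, world size and communication device; the process groups
>>> # are left as None so the wrapper constructs its own.
>>> net = SlowMoDistributedDataParallel(
...     model, nprocs_per_node=8,
...     process_rank=torch.distributed.get_rank(),
...     process_world_size=torch.distributed.get_world_size(),
...     comm_device=torch.device("cuda"))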

Example

>>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
>>> net = fairscale.experimental.nn.data_parallel.SlowMoDistributedDataParallel(model, nprocs_per_node=8)
>>> loss = criterion(net(inputs), targets)
>>> loss.backward()
>>> optimizer.step()
>>> net.perform_slowmo(optimizer)
perform_slowmo(optimizer: torch.optim.optimizer.Optimizer, fp32_params: Optional[torch.Tensor] = None) → None

This is to be called after optimizer.step(). It performs the approximate averaging using the base algorithm (SGP/LocalSGD) and the slow momentum step. Since the base algorithm's averaging and the slow momentum step are not performed every iteration, they are only performed when needed.

It is recommended to call model.zero_grad(set_to_none=True) just before calling this function, since that frees the memory occupied by the gradients, some of which may be reused by this function (see the sketch below).
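
A minimal per-iteration sketch of the recommended ordering, assuming net, optimizer, criterion, inputs and targets are defined as in the example above:

>>> loss = criterion(net(inputs), targets)
>>> loss.backward()
>>> optimizer.step()
>>> net.zero_grad(set_to_none=True)  # frees gradient memory that perform_slowmo may reuse
>>> net.perform_slowmo(optimizer)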

Parameters
  • optimizer (torch.optim.Optimizer) – The optimizer being used for training the model

  • fp32_params (Optional[torch.Tensor]) – To be used when performing fp16 training. Needs to be set to the fp32 copy of the parameters (default: None)
