Pipe
- class fairscale.nn.Pipe(module: torch.nn.modules.container.Sequential, balance: Optional[Iterable[int]] = None, *, devices: Optional[Union[Iterable[Union[torch.device, int, str]], List[Union[torch.device, int, str]]]] = None, chunks: int = 1, checkpoint: str = 'except_last', deferred_batch_norm: bool = False)
Wraps an arbitrary nn.Sequential module to train on Pipe. If the module requires lots of memory, Pipe will be very efficient.

```python
model = nn.Sequential(a, b, c, d)
model = Pipe(model, balance=[1, 1, 1, 1], chunks=8)
output = model(input)
```
Pipe combines pipeline parallelism with checkpointing to reduce peak memory required to train while minimizing device under-utilization.
You should determine the balance when defining a Pipe module, as balancing will not be done automatically. The module will be partitioned across multiple devices according to the given balance. You may rely on heuristics to find your own optimal configuration.

- Parameters
  - module (torch.nn.Sequential) – sequential module to be parallelized
  - balance (list of int) – number of layers in each partition
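To make the balance semantics concrete, here is a small pure-Python sketch (not fairscale's code) of how a balance list splits a sequence of layers into consecutive partitions:

```python
def partition(layers, balance):
    """Split `layers` into consecutive partitions whose sizes are given by
    `balance`. A sketch of Pipe's partitioning, not fairscale's implementation."""
    if sum(balance) != len(layers):
        raise ValueError("module and sum of balance have different length")
    partitions, start = [], 0
    for size in balance:
        partitions.append(layers[start:start + size])
        start += size
    return partitions

# Four layers with balance=[2, 1, 1]: three partitions of 2, 1, and 1 layers,
# each of which Pipe would place on its own device.
parts = partition(["a", "b", "c", "d"], [2, 1, 1])
```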
- Keyword Arguments
  - devices (iterable of devices) – devices to use (default: all CUDA devices)
  - chunks (int) – number of micro-batches (default: 1)
  - checkpoint (str) – when to enable checkpointing, one of 'always', 'except_last', or 'never' (default: 'except_last')
  - deferred_batch_norm (bool) – whether to use deferred BatchNorm moving statistics (default: False, see Deferred Batch Normalization for more details)
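The chunks argument controls how each mini-batch is sliced into micro-batches that flow through the pipeline concurrently. A rough pure-Python sketch of that slicing, with a plain list standing in for a tensor's first (batch) dimension; this is an illustration, not fairscale's implementation:

```python
def scatter(batch, chunks):
    """Split a mini-batch (a plain list standing in for a tensor's batch
    dimension) into `chunks` micro-batches. Illustrative sketch only."""
    size = (len(batch) + chunks - 1) // chunks  # ceiling division
    return [batch[i:i + size] for i in range(0, len(batch), size)]

# A mini-batch of 8 samples with chunks=4 yields 4 micro-batches of 2.
micro_batches = scatter(list(range(8)), 4)
```

More chunks overlap computation across partitions better but add per-micro-batch overhead, so the optimal value is usually found empirically.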
- Raises
  - TypeError – the module is not an nn.Sequential
  - ValueError – invalid arguments, or wrong balance
  - IndexError – the number of devices is fewer than the number of partitions
- checkpoint: str = 'except_last'
  The checkpoint mode, which determines when checkpointing is enabled: one of 'always', 'except_last', or 'never'.
- devices: List[torch.device] = []
  The devices mapped to each partition. devices[-1] refers to the device of the last partition, which is therefore the output device. You will typically need it to move the target to that device, so the loss can be computed without a device-mismatch RuntimeError. For example:

```python
out_device = pipe.devices[-1]

for input, target in loader:
    target = target.to(out_device, non_blocking=True)
    output = pipe(input)
    loss = F.cross_entropy(output, target)
```
- __getitem__(index: int) → torch.nn.modules.module.Module
Gets a layer in the underlying sequential module.
- __iter__() → Iterable[torch.nn.modules.module.Module]
Iterates over children of the underlying sequential module.
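Both methods simply expose the layers of the wrapped sequential module, so a Pipe can be indexed and iterated like the nn.Sequential it wraps. A minimal pure-Python sketch of that delegation pattern (not fairscale's code; the layer names are placeholders):

```python
class SequentialWrapper:
    """Sketch of a wrapper exposing its underlying layers through
    __getitem__ and __iter__, as Pipe does. Not fairscale's implementation."""

    def __init__(self, *layers):
        self._layers = list(layers)

    def __getitem__(self, index):
        # Delegate indexing to the underlying sequence of layers.
        return self._layers[index]

    def __iter__(self):
        # Delegate iteration likewise.
        return iter(self._layers)

wrapped = SequentialWrapper("linear1", "relu", "linear2")
first = wrapped[0]
names = [layer for layer in wrapped]
```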
- cuda(device: Optional[Union[torch.device, int, str]] = None) → fairscale.nn.pipe.pipe.Pipe
- to(*args: Any, **kwargs: Any) → fairscale.nn.pipe.pipe.Pipe
  Deny these usages:
  - to(device[, dtype, non_blocking])
  - to(tensor[, non_blocking])

  But allow this:
  - to(dtype[, non_blocking])
- forward(input: Union[torch.Tensor, Tuple[torch.Tensor, ...]]) → Union[torch.Tensor, Tuple[torch.Tensor, ...]]
  Pipe is a fairly transparent module wrapper. It does not modify the input and output signatures of the underlying module, but there is a type restriction: input and output must each be a Tensor or a tuple of tensors. This restriction is applied at partition boundaries too.

  - Parameters
    - input (torch.Tensor or tuple of tensors) – input mini-batch
  - Returns
    - tensor or tuple of tensors – output mini-batch
  - Raises
    - TypeError – input is not a tensor or a tuple of tensors
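The type restriction above can be sketched as a small validation helper. This is a hypothetical illustration, shown with a generic tensor type standing in for torch.Tensor; it is not fairscale's internal check:

```python
def check_pipe_io(value, tensor_type):
    """Return `value` if it is a tensor or a non-empty tuple of tensors,
    else raise TypeError. `tensor_type` stands in for torch.Tensor;
    hypothetical helper, not fairscale code."""
    if isinstance(value, tensor_type):
        return value
    if isinstance(value, tuple) and value and all(
        isinstance(item, tensor_type) for item in value
    ):
        return value
    raise TypeError("expected Tensor or Tuple[Tensor, ...]")

# With floats standing in for tensors:
single = check_pipe_io(3.0, float)        # a single "tensor" is accepted
pair = check_pipe_io((1.0, 2.0), float)   # a tuple of "tensors" is accepted
# check_pipe_io([1.0, 2.0], float) would raise TypeError: lists are rejected.
```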