- class fairscale.nn.checkpoint.checkpoint_wrapper(module: torch.nn.modules.module.Module, offload_to_cpu: bool = False)
A friendlier wrapper for performing activation checkpointing.
Compared to the PyTorch version (`torch.utils.checkpoint.checkpoint`), this version:
- wraps an nn.Module, so that all subsequent calls use checkpointing;
- handles keyword arguments in the forward;
- handles non-Tensor outputs from the forward;
- supports offloading activations to CPU.
Usage:

    checkpointed_module = checkpoint_wrapper(my_module, offload_to_cpu=True)
    a, b = checkpointed_module(x, y=3, z=torch.Tensor())
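For comparison, here is a minimal sketch of a wrapper with the same calling convention (keyword arguments in, non-Tensor values out), built directly on PyTorch's non-reentrant checkpoint. It assumes PyTorch 2.x, where `torch.utils.checkpoint.checkpoint(..., use_reentrant=False)` forwards `**kwargs` to the wrapped function and returns its output unchanged. `SimpleCheckpointWrapper` and `TwoOutputs` are hypothetical names; this is not fairscale's implementation, and it has no `offload_to_cpu` option.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class SimpleCheckpointWrapper(nn.Module):
    """Hypothetical wrapper: every forward call goes through activation checkpointing.

    Sketch only, assuming PyTorch 2.x; unlike fairscale's checkpoint_wrapper
    it does not offload anything to CPU.
    """

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        # Non-reentrant checkpointing passes **kwargs through to the wrapped
        # module and returns whatever it returns (tuples, ints, dicts, ...).
        return checkpoint(self.module, *args, use_reentrant=False, **kwargs)


class TwoOutputs(nn.Module):
    """Hypothetical toy module with a keyword argument and a non-Tensor output."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x, y=1):
        # Inner activations (linear output, relu output) are discarded by
        # checkpointing and recomputed during the backward pass.
        h = torch.relu(self.linear(x)) * y
        return h, y  # the second output is a plain int, not a Tensor
```

Calling `SimpleCheckpointWrapper(TwoOutputs())(x, y=3)` returns a `(Tensor, int)` pair whose values and gradients match those of the unwrapped module.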
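To understand the benefits of checkpointing and the offload_to_cpu flag, divide activations into two types with respect to a checkpointed module: inner activations, which are computed inside the module, and outer activations, which are the module's inputs that must be kept so the module can be recomputed during the backward pass. Checkpointing saves the memory of the inner activations (they are discarded after the forward pass and recomputed during backward); offload_to_cpu saves the memory of the outer activations by moving them to CPU.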
In terms of GPU memory savings:
- When inner activations are large and outer activations are small, checkpointing helps a lot and offload_to_cpu may help a little.
- When inner activations are small and outer activations are large, checkpointing helps little and offload_to_cpu helps a lot.
- When both are large, both help and the savings are additive.
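The three cases can be checked with back-of-envelope arithmetic. The sketch below is a simplified model, not a measurement: it assumes checkpointing frees all inner activations, offloading moves all outer activations to CPU, and it ignores the transient memory used while one layer is recomputed during backward. `LayerActivations` and `gpu_activation_bytes` are hypothetical names, and the byte counts are made up.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass
class LayerActivations:
    """Hypothetical per-layer activation sizes, in bytes."""
    inner: int  # activations computed inside the checkpointed module
    outer: int  # the module's inputs, kept so it can be recomputed


def gpu_activation_bytes(layers: Sequence[LayerActivations],
                         checkpoint: bool, offload: bool) -> int:
    """Estimate activation bytes held on the GPU between forward and backward.

    Simplified model: checkpointing drops inner activations, offloading moves
    outer activations to CPU, recomputation peaks are ignored.
    """
    if offload and not checkpoint:
        # offload_to_cpu is an option of the checkpoint wrapper, so it only
        # applies to checkpointed modules.
        raise ValueError("offload requires checkpoint")
    total = 0
    for layer in layers:
        if not checkpoint:
            total += layer.inner
        if not offload:
            total += layer.outer
    return total
```

With three layers of `inner=100, outer=4`, the estimate drops from 312 to 12 bytes with checkpointing and to 0 with offloading as well; with `inner=4, outer=100`, checkpointing only reaches 300, and offloading removes the remaining 300. The two savings add up because each technique removes a disjoint set of activations.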
The first and last layers are unlikely to benefit from the `offload_to_cpu` flag because (1) other references to the first layer's input typically exist, so offloading it does not free the GPU memory; and (2) the last layer's input is needed immediately by the backward pass, so offloading it yields no memory savings.
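One way to act on this advice is to checkpoint every layer but enable offloading only for the interior ones. The sketch below encodes that heuristic; `offload_plan` and `wrap_with_plan` are hypothetical helpers, and `wrap` is assumed to accept `(layer, offload_to_cpu=...)` the way `checkpoint_wrapper` does, so any callable with that shape can be passed in.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")
W = TypeVar("W")


def offload_plan(num_layers: int) -> List[bool]:
    """offload_to_cpu flag per layer: off for the first and last layer, on elsewhere."""
    if num_layers < 0:
        raise ValueError("num_layers must be non-negative")
    return [0 < i < num_layers - 1 for i in range(num_layers)]


def wrap_with_plan(layers: Sequence[T], wrap: Callable[..., W]) -> List[W]:
    """Wrap every layer, enabling offloading only where offload_plan allows it.

    `wrap` is assumed to take (layer, offload_to_cpu=bool), e.g. checkpoint_wrapper.
    """
    flags = offload_plan(len(layers))
    return [wrap(layer, offload_to_cpu=flag) for layer, flag in zip(layers, flags)]
```

With fairscale installed, something like `nn.Sequential(*wrap_with_plan(list(model), checkpoint_wrapper))` would apply the plan to an existing `nn.Sequential` named `model` (a hypothetical variable). Whether the interior layers actually save memory still depends on the sizes of their outer activations.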
Parameters:
- module (nn.Module) – The module to be wrapped.
- offload_to_cpu (bool) – Whether to offload activations to CPU.

Returns:
(nn.Module) – The wrapped module.