Activation Checkpoint

class fairscale.nn.checkpoint.checkpoint_wrapper(module: torch.nn.modules.module.Module, offload_to_cpu: bool = False)

A friendlier wrapper for performing activation checkpointing.

Compared to the PyTorch version, this version:

  • wraps an nn.Module, so that all subsequent calls will use checkpointing

  • handles keyword arguments in the forward

  • handles non-Tensor outputs from the forward

  • supports offloading activations to CPU


Usage:

import torch
from fairscale.nn.checkpoint import checkpoint_wrapper
checkpointed_module = checkpoint_wrapper(my_module, offload_to_cpu=True)
a, b = checkpointed_module(x, y=3, z=torch.Tensor([1]))

To understand the benefits of checkpointing and the offload_to_cpu flag, divide activations into two types with respect to the checkpointed modules: inner activations (those produced inside a checkpointed module) and outer activations (those at its boundary, such as its inputs). Activation checkpointing saves the GPU memory of the inner ones; offload_to_cpu saves the GPU memory of the outer ones.

In terms of GPU memory savings:

  • When the inner activations are large and the outer ones are small, checkpointing helps a lot and offload_to_cpu may help a little.

  • When the inner activations are small and the outer ones are large, checkpointing helps little and offload_to_cpu helps a lot.

  • When both inner and outer activations are large, both help and the benefits are additive.
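
The three cases above can be illustrated with a rough accounting sketch. It is not a measurement of fairscale: it assumes a stack of identical checkpointed modules, and a simplified model in which checkpointing keeps only each module's outer activation plus the inner activations of the one module being recomputed, while offloading leaves a single outer activation on the GPU at a time. The function name `peak_activation_memory` and the sizes are hypothetical, and parameters, gradients, and allocator overhead are ignored.

```python
def peak_activation_memory(n_modules, inner, outer,
                           checkpoint=False, offload_to_cpu=False):
    """Rough peak GPU memory held by activations, in arbitrary units.

    Toy model (an assumption, not fairscale's actual behavior):
      * no checkpointing: every module keeps inner and outer activations;
      * checkpointing: every module keeps only its outer activation, and the
        inner activations of one module exist at a time during recomputation;
      * checkpointing + offload: outer activations live on the CPU, so only
        one module's inner and outer activations are on the GPU at a time.
    offload_to_cpu is ignored here unless checkpoint is True.
    """
    if not checkpoint:
        return n_modules * (inner + outer)
    if offload_to_cpu:
        return inner + outer
    return n_modules * outer + inner


# Hypothetical sizes for the three cases in the list above.
CASES = {
    "large inner, small outer": dict(inner=100, outer=1),
    "small inner, large outer": dict(inner=1, outer=100),
    "both large": dict(inner=100, outer=100),
}

for name, sizes in CASES.items():
    base = peak_activation_memory(10, **sizes)
    ckpt = peak_activation_memory(10, checkpoint=True, **sizes)
    both = peak_activation_memory(10, checkpoint=True, offload_to_cpu=True, **sizes)
    print(f"{name}: none={base} checkpoint={ckpt} checkpoint+offload={both}")
```

Under these assumptions, the first case drops sharply with checkpointing alone, the second drops sharply only once offloading is added, and the third needs both.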


The first and last layers are not likely to benefit from the `offload_to_cpu` flag
because (1) there are typically other references to the first layer's input, so
its GPU memory won't be freed; (2) the input to the last layer is immediately
used by the backward pass, so offloading it won't result in memory savings.
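
One possible reading of this note is to enable offloading only for the middle layers. The sketch below is a hypothetical helper, not part of fairscale: `wrap_layers` is an invented name, and `wrap` stands for any callable that accepts `(module, offload_to_cpu=...)`, such as `checkpoint_wrapper`.

```python
def wrap_layers(layers, wrap):
    """Wrap every layer, turning on CPU offload only for the middle ones.

    `wrap` is assumed to be callable as wrap(module, offload_to_cpu=bool).
    The first and last layers are still wrapped, but with offload disabled.
    """
    last = len(layers) - 1
    return [
        wrap(layer, offload_to_cpu=(0 < i < last))
        for i, layer in enumerate(layers)
    ]
```

With fairscale installed, `wrap_layers(list(model.children()), checkpoint_wrapper)` would be one way to call it, and the returned list could be passed to `torch.nn.Sequential(*...)`. Whether the first and last layers should be checkpointed at all is a separate choice that this sketch does not settle.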

Parameters:

  • module (nn.Module) – The module to be wrapped.

  • offload_to_cpu (bool) – Whether to offload activations to CPU.

Returns:

  (nn.Module) – The wrapped module.
