
Offload Model

class fairscale.experimental.nn.OffloadModel(model: Any, device: torch.device, offload_device: torch.device = torch.device('cpu'), num_slices: int = 3, checkpoint_activation: bool = False, num_microbatches: int = 1)

Wraps an arbitrary nn.Sequential module so it can be trained by offloading the majority of the model parameters to the CPU. OffloadModel is heavily inspired by the L2L algorithm and ZeRO-Offload.

import torch
from fairscale.experimental.nn import OffloadModel

model = get_model()  # any nn.Sequential model
device = torch.device("cuda")
offload_model = OffloadModel(model, device,
                             offload_device=torch.device("cpu"),
                             num_slices=3,
                             checkpoint_activation=True,
                             num_microbatches=5)

At each step, a layer (or a series of layers) is loaded onto the GPU for the forward and backward pass, with intermediate activations copied onto the GPU as required. Once the forward or backward pass is completed for a given shard, it is moved back to the CPU.
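This load-compute-evict cycle can be illustrated with a minimal, forward-only sketch. This is not fairscale's implementation: the shards, sizes, and the offloaded_forward helper are invented for the example, and a real backward pass additionally has to keep gradients in sync with the offloaded parameter copies, which OffloadModel handles internally.

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
offload_device = torch.device("cpu")

# Three "shards" that live on the CPU between uses.
shards = [nn.Linear(32, 32).to(offload_device) for _ in range(3)]

def offloaded_forward(x):
    x = x.to(device)
    for shard in shards:
        shard.to(device)          # load the active shard onto the GPU
        x = shard(x)              # intermediate activations stay on the GPU
        shard.to(offload_device)  # evict the shard before loading the next one
    return x

out = offloaded_forward(torch.randn(4, 32))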

OffloadModel supports activation checkpointing, which reduces the memory footprint. You can also increase the number of microbatches, which translates to more computation cycles for every shard load and helps offset the cost of moving each shard between CPU and GPU.
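Continuing the sketch above (again illustrative, not fairscale internals), microbatching amortizes the transfer: each shard is copied to the GPU once and then processes every microbatch before being evicted, so the copy cost is paid per shard rather than per shard per microbatch.

def offloaded_forward_microbatched(microbatches):
    outs = [mb.to(device) for mb in microbatches]
    for shard in shards:
        shard.to(device)                 # one CPU->GPU transfer per shard...
        outs = [shard(x) for x in outs]  # ...amortized over all microbatches
        shard.to(offload_device)
    return torch.cat(outs)

# A batch of 20 split into 5 microbatches of 4.
out = offloaded_forward_microbatched(torch.randn(20, 32).chunk(5))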

Note: OffloadModel currently only supports nn.Sequential models.
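Putting it together, a single training step might look like the sketch below. The architecture, sizes, loss, and optimizer are illustrative choices, and the example assumes the batch size is divisible by num_microbatches.

import torch
import torch.nn as nn
from fairscale.experimental.nn import OffloadModel

device = torch.device("cuda")

# OffloadModel requires an nn.Sequential model.
model = nn.Sequential(
    nn.Linear(128, 128), nn.ReLU(),
    nn.Linear(128, 128), nn.ReLU(),
    nn.Linear(128, 10),
)

offload_model = OffloadModel(model, device,
                             offload_device=torch.device("cpu"),
                             num_slices=3,
                             checkpoint_activation=True,
                             num_microbatches=4)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(offload_model.parameters(), lr=1e-3)

inputs = torch.randn(16, 128).to(device)   # 16 is divisible by num_microbatches
targets = torch.randint(0, 10, (16,)).to(device)

optimizer.zero_grad()
loss = criterion(offload_model(inputs), targets)
loss.backward()
optimizer.step()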

Parameters
  • model (nn.Sequential) – Model to be offloaded.

  • device (torch.device) – Device where the active model should reside.

  • offload_device (torch.device) – Device where the inactive model should reside.

  • num_slices (int) – Number of slices into which the model should be chunked.

  • checkpoint_activation (bool) – Whether to checkpoint intermediate activation states on the CPU. Default value is False.

  • num_microbatches (int) – Number of microbatches which should be run per model shard on device.

forward(*inputs: Any, **_: Any) → Any
training: bool