
Efficient memory usage using Activation Checkpointing

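Adapted from torch.utils.checkpoint, this is a friendlier wrapper for performing activation checkpointing. Activation checkpointing trades compute for memory: intermediate activations inside the checkpointed block are not stored during the forward pass and are recomputed during the backward pass.

Compared to the PyTorch version, which is a function you call at every call site, this version wraps an nn.Module so that all of its subsequent forward calls are checkpointed.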
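To make the distinction concrete, here is a minimal sketch of the idea using only plain PyTorch. The class `SimpleCheckpointWrapper` is hypothetical: it is not part of fairscale and is not how fairscale implements `checkpoint_wrapper`. It stores a module and routes every forward call through `torch.utils.checkpoint.checkpoint`, so the caller never has to invoke `checkpoint` explicitly.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class SimpleCheckpointWrapper(nn.Module):
    """Hypothetical illustration: checkpoint every forward call of a wrapped module."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module  # registered as a submodule, so its parameters are tracked

    def forward(self, *args):
        # Activations inside self.module are not kept; they are recomputed
        # during backward. use_reentrant=False is the mode PyTorch recommends.
        return checkpoint(self.module, *args, use_reentrant=False)


ffn = nn.Sequential(nn.Linear(32, 128), nn.ReLU(), nn.Linear(128, 32))
wrapped = SimpleCheckpointWrapper(ffn)

x = torch.randn(4, 32, requires_grad=True)
out = wrapped(x)          # each call is checkpointed, no per-call boilerplate
out.sum().backward()      # gradients reach ffn's parameters as usual
print(ffn[0].weight.grad.shape)  # torch.Size([128, 32])
```

Wrapping once keeps the surrounding model code unchanged; the cost is extra compute, since the wrapped module's forward runs a second time during each backward pass.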

import torch
import torch.nn as nn
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper


class CheckpointModel(nn.Module):

    def __init__(self, **kwargs):
        super().__init__()
        torch.manual_seed(0)  # make sure weights are deterministic.
        self.ffn_module = nn.Sequential(
            nn.Linear(32, 128),
            nn.Dropout(p=0.5),
            nn.Linear(128, 32),
        )

        self.ffn_module = checkpoint_wrapper(self.ffn_module, **kwargs)
        self.last_linear = nn.Linear(32, 1)

    def forward(self, input):
        output = self.ffn_module(input)
        return self.last_linear(output)
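The model above places nn.Dropout inside the checkpointed block. Because that block's forward is re-run during backward, the recomputation must reuse the same dropout mask, or the gradients would not match a normal run. The sketch below checks this with plain `torch.utils.checkpoint`, which restores the RNG state before recomputing when `preserve_rng_state=True` (its default). It is an assumption-laden illustration: it does not exercise fairscale's `checkpoint_wrapper`, and the helpers `make_ffn` and `grads` are hypothetical names.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def make_ffn() -> nn.Sequential:
    torch.manual_seed(0)  # identical initial weights for both runs
    return nn.Sequential(nn.Linear(32, 128), nn.Dropout(p=0.5), nn.Linear(128, 32))


def grads(use_checkpoint: bool) -> torch.Tensor:
    ffn = make_ffn()             # modules start in training mode, so dropout is active
    x = torch.randn(8, 32)       # drawn right after seeding, so identical in both runs
    torch.manual_seed(1)         # same dropout mask in the forward pass of both runs
    if use_checkpoint:
        # RNG state is saved here and restored before the backward recomputation.
        out = checkpoint(ffn, x, use_reentrant=False, preserve_rng_state=True)
    else:
        out = ffn(x)
    out.sum().backward()
    return ffn[0].weight.grad.clone()


g_plain = grads(use_checkpoint=False)
g_ckpt = grads(use_checkpoint=True)
print(torch.allclose(g_plain, g_ckpt))  # True
```

This is also why the example seeds torch before building its layers: reproducible weights plus a restored RNG state make the checkpointed and non-checkpointed runs directly comparable.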