Efficient memory usage using Activation Checkpointing
Adapted from torch.utils.checkpoint, this is a friendlier wrapper for performing activation checkpointing.
Compared to the PyTorch version, this version wraps an nn.Module, so that all subsequent calls to that module are checkpointed.
import torch
import torch.nn as nn

from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper


class CheckpointModel(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        torch.manual_seed(0)  # make sure weights are deterministic.
        self.ffn_module = nn.Sequential(
            nn.Linear(32, 128),
            nn.Dropout(p=0.5),
            nn.Linear(128, 32),
        )
        # Wrap the feed-forward block so its activations are recomputed
        # during the backward pass instead of being stored in memory.
        self.ffn_module = checkpoint_wrapper(self.ffn_module, **kwargs)
        self.last_linear = nn.Linear(32, 1)

    def forward(self, input):
        output = self.ffn_module(input)
        return self.last_linear(output)
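For reference, a minimal usage sketch of the wrapped model is shown below; the batch size and the sum-based dummy loss are illustrative assumptions, not part of the original example:

# Illustrative usage only: run one forward/backward pass through the model.
model = CheckpointModel()
input = torch.randn(8, 32)   # assumed batch of 8 samples with 32 features
loss = model(input).sum()    # dummy scalar loss for the example
loss.backward()              # ffn_module activations are recomputed here, saving memory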