Optimizer, Gradient and Model Sharding

Using torch.nn.parallel.DistributedDataParallel leads to some wasted communication when combined with OSS, but it is possible, and it makes OSS a drop-in solution for your existing torch.distributed code. Let's suppose that your trainer looks like this:

import torch
from torch.nn.parallel import DistributedDataParallel as DDP


def train(
    rank: int,
    world_size: int,
    epochs: int):

    # process group init
    dist_init(rank, world_size)

    # Problem statement
    model = myAwesomeModel().to(rank)
    model = DDP(model, device_ids=[rank])
    dataloader = mySuperFastDataloader()
    loss_fn = myVeryRelevantLoss()

    # optimizer specific arguments e.g. LR, momentum, etc...
    base_optimizer_arguments = { "lr": 1e-4}
    optimizer = torch.optim.SGD(
        params=model.parameters(),
        **base_optimizer_arguments)

    # Any relevant training loop, nothing specific to OSS. For example:
    model.train()
    for e in range(epochs):
        for (data, target) in dataloader:
            data, target = data.to(rank), target.to(rank)
            # Train
            model.zero_grad()
            outputs = model(data)
            loss = loss_fn(outputs, target)
            loss.backward()
            optimizer.step()
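
The helper dist_init is not defined on this page; a minimal sketch, assuming the usual environment-variable rendezvous (the localhost address and port below are arbitrary placeholders), could look like this:

import os
import torch.distributed as dist


def dist_init(rank: int, world_size: int, backend: str = "nccl") -> None:
    # Hypothetical helper: point every rank at the same rendezvous and join the process group.
    # Use the "gloo" backend on CPU-only machines, and adapt the address/port to your cluster.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)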

Then sharding the optimizer state is merely a matter of wrapping your optimizer in fairscale.optim.OSS, as follows. DDP can still be used in place of ShardedDDP in the example below, but the memory savings will be smaller, since plain DDP keeps a full copy of the gradients on every rank instead of sharding them.

import torch
from fairscale.optim.oss import OSS
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP


def train(
    rank: int,
    world_size: int,
    epochs: int):

    # process group init
    dist_init(rank, world_size)

    # Problem statement
    model = myAwesomeModel().to(rank)
    dataloader = mySuperFastDataloader()
    loss_fn = myVeryRelevantLoss()

    # optimizer specific arguments e.g. LR, momentum, etc...
    base_optimizer_arguments = { "lr": 1e-4}

    # Wrap a base optimizer into OSS
    base_optimizer = torch.optim.SGD  # any PyTorch-compliant optimizer
    optimizer = OSS(
        params=model.parameters(),
        optim=base_optimizer,
        **base_optimizer_arguments)

    # Wrap the model into ShardedDDP, which will reduce gradients to the proper ranks
    model = ShardedDDP(model, optimizer)

    # Any relevant training loop, nothing specific to OSS. For example:
    model.train()
    for e in range(epochs):
        for (data, target) in dataloader:
            data, target = data.to(rank), target.to(rank)
            # Train
            model.zero_grad()
            outputs = model(data)
            loss = loss_fn(outputs, target)
            loss.backward()
            optimizer.step()
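
As mentioned above, plain DDP also works on top of OSS if you only need the optimizer state sharded; a minimal sketch of that variant, reusing the hypothetical helpers from the examples above, only changes the wrapping:

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from fairscale.optim.oss import OSS

# Gradients stay fully replicated under DDP; only the optimizer state is sharded
model = myAwesomeModel().to(rank)
model = DDP(model, device_ids=[rank])
optimizer = OSS(
    params=model.parameters(),
    optim=torch.optim.SGD,
    lr=1e-4)

The rest of the training loop is unchanged; only the gradient sharding provided by ShardedDDP is lost.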

The train function above can then be run via a torch.multiprocessing.spawn call. Note that any launcher can be used; the only assumption is that each rank lives in its own Python process.

import torch.multiprocessing as mp

mp.spawn(
    train,
    args=(WORLD_SIZE, EPOCHS),
    nprocs=WORLD_SIZE,
    join=True,
)

Using PyTorch Automatic Mixed Precision is possible, and its exact usage depends on whether OSS is combined with DDP or with ShardedDDP. If OSS is used with DDP, the normal PyTorch GradScaler can be used and nothing needs to be changed. If OSS is used with ShardedDDP (to get gradient sharding as well), a very similar flow applies, but it requires a shard-aware GradScaler, available in fairscale.optim.grad_scaler. In both cases autocast can be used as is, and the loss is scaled and handled in the same way. See the original PyTorch documentation (https://pytorch.org/docs/stable/notes/amp_examples.html?highlight=automatic%20mixed%20precision) for more information.

import torch.optim as optim
from torch.cuda.amp import autocast
from fairscale.optim.grad_scaler import ShardedGradScaler


# Creates model and optimizer in default precision
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

# Creates a ShardedGradScaler once at the beginning of training.
scaler = ShardedGradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()

        # Runs the forward pass with autocasting.
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)

        # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
        # Backward passes under autocast are not recommended.
        # Backward ops run in the same dtype autocast chose for corresponding forward ops.
        scaler.scale(loss).backward()

        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(optimizer)

        # Updates the scale for next iteration.
        scaler.update()
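
For the OSS + DDP combination mentioned above, the only change to this snippet is the scaler class; a minimal sketch:

from torch.cuda.amp import GradScaler

# With OSS + DDP the gradients are not sharded, so the stock GradScaler is sufficient
scaler = GradScaler()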

Parameters can be sharded as well, using the FullyShardedDataParallel (FSDP) API. This involves wrapping your model in much the same way as with ShardedDDP above.

import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP


def train(
    rank: int,
    world_size: int,
    epochs: int):

    # process group init
    dist_init(rank, world_size)

    # Problem statement
    model = myAwesomeModel().to(rank)
    dataloader = mySuperFastDataloader()
    loss_fn = myVeryRelevantLoss()

    # optimizer specific arguments e.g. LR, momentum, etc...
    base_optimizer_arguments = { "lr": 1e-4}

    # Wrap the model into FSDP, which will shard its parameters across the ranks
    model = FSDP(model)

    # Build the optimizer on the wrapped model's parameters; no OSS wrapping is needed with FSDP
    optimizer = torch.optim.SGD(
        params=model.parameters(),
        **base_optimizer_arguments)

    # Any relevant training loop. For example:
    model.train()
    for e in range(epochs):
        for (data, target) in dataloader:
            data, target = data.to(rank), target.to(rank)
            # Train
            model.zero_grad()
            outputs = model(data)
            loss = loss_fn(outputs, target)
            loss.backward()
            optimizer.step()

Auto-wrapping sub-modules with FSDP is a convenient way to improve training speed, since the all-gather for one wrapped sub-module can overlap with the forward pass of another. It also improves memory efficiency, because the gathered parameters are freed as soon as each wrapped layer finishes executing. For example:

import torch
from fairscale.nn.wrap import auto_wrap, enable_wrap, wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.utils.testing import DummyProcessGroup


tfmr = torch.nn.Transformer(num_encoder_layers=2, num_decoder_layers=2)

group = DummyProcessGroup(rank=0, size=1)
fsdp_params = dict(mixed_precision=True, flatten_parameters=True)
with enable_wrap(wrapper_cls=FSDP, process_group=group, **fsdp_params):

    # Wraps layer in FSDP by default if within context
    l1 = wrap(torch.nn.Linear(5, 5))
    assert isinstance(l1, FSDP)
    assert l1.mixed_precision and l1.flatten_parameters
    # Separately wraps children modules with more than 1e6 params
    tfmr_auto_wrapped = auto_wrap(tfmr, min_num_params=1e6)
    assert isinstance(tfmr_auto_wrapped, torch.nn.Transformer)
    for l in tfmr_auto_wrapped.encoder.layers:
        assert isinstance(l, FSDP)
        assert l.mixed_precision and l.flatten_parameters
        assert isinstance(l.linear1, FSDP)
        assert isinstance(l.linear2, FSDP)
        assert not isinstance(l.self_attn, FSDP) # self attention is not auto-wrapped
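
In a full training script the same pattern can replace the plain FSDP wrapping from the earlier example. The sketch below is only an illustration under a few assumptions: it reuses the hypothetical myAwesomeModel helper, assumes dist_init has already created the default process group, and wraps the remaining root module in FSDP afterwards, mirroring the example above.

import torch
from fairscale.nn.wrap import auto_wrap, enable_wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

model = myAwesomeModel().to(rank)

# Auto-wrap the larger sub-modules so their all-gathers can overlap with other compute,
# then shard whatever parameters remain at the root
fsdp_params = dict(mixed_precision=True, flatten_parameters=True)
with enable_wrap(wrapper_cls=FSDP, **fsdp_params):
    model = auto_wrap(model, min_num_params=1e6)
model = FSDP(model, **fsdp_params)

optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)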