Optimizer, Gradient and Model Sharding¶
Using torch.nn.parallel.DistributedDataParallel leads to some wasted communications in the case of OSS, but it is possible and makes OSS a drop in solution in your existing torch distributed code. Let’s suppose that your trainer looks like
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
def train(
rank: int,
world_size: int,
epochs: int):
# process group init
dist_init(rank, world_size)
# Problem statement
model = myAwesomeModel().to(rank)
model = DDP(model, device_ids=[rank])
dataloader = mySuperFastDataloader()
loss_ln = myVeryRelevantLoss()
# optimizer specific arguments e.g. LR, momentum, etc...
base_optimizer_arguments = { "lr": 1e-4}
optimizer = torch.optim.SGD(
params=model.parameters(),
**base_optimizer_arguments)
# Any relevant training loop, nothing specific to OSS. For example:
model.train()
for e in range(epochs):
for (data, target) in dataloader:
data, target = data.to(rank), target.to(rank)
# Train
model.zero_grad()
outputs = model(data)
loss = loss_fn(outputs, target)
loss.backward()
optimizer.step()
Then sharding the optimizer state is merely a matter of wrapping your optimizer in fairscale.optim.OSS, as follows. DDP can be used in place of ShardedDDP in the example below, but the memory savings will be reduced (the gradients are not as efficiently sharded).
import torch
from fairscale.optim.oss import OSS
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
def train(
rank: int,
world_size: int,
epochs: int):
# process group init
dist_init(rank, world_size)
# Problem statement
model = myAwesomeModel().to(rank)
dataloader = mySuperFastDataloader()
loss_ln = myVeryRelevantLoss()
# optimizer specific arguments e.g. LR, momentum, etc...
base_optimizer_arguments = { "lr": 1e-4}
# Wrap a base optimizer into OSS
base_optimizer = torch.optim.SGD # any pytorch compliant optimizer
optimizer = OSS(
params=model.parameters(),
optim=base_optimizer,
**base_optimizer_arguments)
# Wrap the model into ShardedDDP, which will reduce gradients to the proper ranks
model = ShardedDDP(model, optimizer)
# Any relevant training loop, nothing specific to OSS. For example:
model.train()
for e in range(epochs):
for (data, target) in dataloader:
data, target = data.to(rank), target.to(rank)
# Train
model.zero_grad()
outputs = model(data)
loss = loss_fn(outputs, target)
loss.backward()
optimizer.step()
The above train function can then be run via a multiprocessing.spawn call. Note that any launcher can be used, the only assumption being that each of the ranks lives in its own python process.
mp.spawn(
train,
args=(WORLD_SIZE, EPOCHS),
nprocs=WORLD_SIZE,
join=True
)
Using PyTorch Automatic Mixed Precision is possible, and its actual usage will depend on whether OSS is used with DDP or with ShardedDDP. If OSS is used with DDP, then the normal PyTorch GradScaler can be used, nothing needs to be changed. If OSS is used with ShardedDDP (to get the gradient sharding), then a very similar flow can be used, but it requires a shard-aware GradScaler, which is available in fairscale.optim.grad_scaler. In both cases Autocast can be used as is, and the loss will be scaled and handled in the same way. See [the original documentation] (https://pytorch.org/docs/stable/notes/amp_examples.html?highlight=automatic%20mixed%20precision) for more information.
from fairscale.optim.grad_scaler import ShardedGradScaler
# Creates model and optimizer in default precision
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
# Creates a ShardedGradScaler once at the beginning of training.
scaler = ShardedGradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
# Runs the forward pass with autocasting.
with autocast():
output = model(input)
loss = loss_fn(output, target)
# Scales loss. Calls backward() on scaled loss to create scaled gradients.
# Backward passes under autocast are not recommended.
# Backward ops run in the same dtype autocast chose for corresponding forward ops.
scaler.scale(loss).backward()
# scaler.step() first unscales the gradients of the optimizer's assigned params.
# If these gradients do not contain infs or NaNs, optimizer.step() is then called,
# otherwise, optimizer.step() is skipped.
scaler.step(optimizer)
# Updates the scale for next iteration.
scaler.update()
Parameters can be sharded using the FullyShardedDataParallel (FSDP) API. It involves wrapping your model similar to the SDP API above.
import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
def train(
rank: int,
world_size: int,
epochs: int):
# process group init
dist_init(rank, world_size)
# Problem statement
model = myAwesomeModel().to(rank)
dataloader = mySuperFastDataloader()
loss_ln = myVeryRelevantLoss()
# optimizer specific arguments e.g. LR, momentum, etc...
base_optimizer_arguments = { "lr": 1e-4}
# Wrap a base optimizer into OSS
base_optimizer = torch.optim.SGD # any pytorch compliant optimizer
# Wrap the model into FSDP, which will reduce parameters to the proper ranks
model = FSDP(model)
# Any relevant training loop. For example:
model.train()
for e in range(epochs):
for (data, target) in dataloader:
data, target = data.to(rank), target.to(rank)
# Train
model.zero_grad()
outputs = model(data)
loss = loss_fn(outputs, target)
loss.backward()
optimizer.step()
Auto wrapping sub-modules with FSDP is a convenient way to improve training speed by overlapping the allgather step across the forward passes of different submodules. It also improves memory efficiency by freeing gathered parameters after each layer finishes executing. For example:
import torch
from fairscale.nn.wrap import auto_wrap, enable_wrap, wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.utils.testing import DummyProcessGroup
tfmr = torch.nn.Transformer(num_encoder_layers=2, num_decoder_layers=2)
group = DummyProcessGroup(rank=0, size=1)
fsdp_params = dict(mixed_precision=True, flatten_parameters=True)
with enable_wrap(wrapper_cls=FSDP, process_group=group, **fsdp_params):
# Wraps layer in FSDP by default if within context
l1 = wrap(torch.nn.Linear(5, 5))
assert isinstance(l1, FSDP)
assert l1.mixed_precision and l1.flatten_parameters
# Separately Wraps children modules with more than 1e8 params
tfmr_auto_wrapped = auto_wrap(tfmr, min_num_params=1e6)
assert isinstance(l2, nn.Transformer)
for l in l2.encoder.layers:
assert isinstance(l, FSDP)
assert l.mixed_precision and l.flatten_parameters
assert isinstance(l.linear1, FSDP)
assert isinstance(l.linear2, FSDP)
assert not isinstance(l.self_attn, FSDP) # self attention is not auto-wrapped