import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


def example(rank, world_size):
    """Run a single DDP forward/backward/optimizer step in one worker process.

    Args:
        rank: Index of this process within the group. It is also used as the
            device index for the model and the tensors (``.to(rank)``) and as
            the sole entry of ``device_ids`` passed to DDP.
        world_size: Total number of processes taking part in the group.
    """
    # create default process group
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
    try:
        # create local model
        model = nn.Linear(10, 10).to(rank)
        # construct DDP model
        ddp_model = DDP(model, device_ids=[rank])
        # define loss function and optimizer
        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        # forward pass
        outputs = ddp_model(torch.randn(20, 10).to(rank))
        labels = torch.randn(20, 10).to(rank)
        # backward pass
        loss_fn(outputs, labels).backward()
        # update parameters
        optimizer.step()
    finally:
        # Release the process group even if the training step raised, so the
        # worker does not exit with communication resources still open.
        dist.destroy_process_group()


def main():
    """Spawn ``world_size`` worker processes, each running :func:`example`."""
    world_size = 2
    # init_process_group() is called without an explicit init_method, so the
    # workers rendezvous through these environment variables. setdefault keeps
    # any values the caller has already exported.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    # mp.spawn passes the process index as the first argument (rank); the
    # remaining arguments come from ``args``.
    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    # The guard keeps spawned children, which re-import this module, from
    # launching workers of their own.
    main()
Preview:
Download PNG
Download JPEG
Download SVG
Tip: You can change the style, width & colours of the snippet with the inspect tool before clicking Download!
Click to optimize width for Twitter