import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms, models


def setup(rank, world_size):
    """Initialize the default process group for this rank."""
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '8889'
    # NCCL is the recommended backend for multi-GPU training. The default
    # env:// init method reads MASTER_ADDR and MASTER_PORT set above, so a
    # separate TCP init_method URL is not needed.
    dist.init_process_group(
        backend='nccl',
        rank=rank,
        world_size=world_size
    )
    # Pin this process to its own GPU.
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def train(rank, world_size):
    setup(rank, world_size)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Download the dataset on rank 0 only so the processes do not race to
    # write the same files; the other ranks wait at the barrier.
    if rank == 0:
        datasets.CIFAR10(root='./data', train=True, download=True)
    dist.barrier()
    dataset = datasets.CIFAR10(root='./data', train=True, download=False,
                               transform=transform)

    # DistributedSampler gives each rank a disjoint shard of the dataset.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
    dataloader = DataLoader(dataset, batch_size=64, sampler=sampler,
                            num_workers=4, pin_memory=True)

    # Wrap the model in DDP so gradients are all-reduced across ranks.
    model = models.resnet18(num_classes=10).to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    criterion = torch.nn.CrossEntropyLoss().to(rank)
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01, momentum=0.9)

    for epoch in range(10):
        ddp_model.train()
        # set_epoch makes the sampler reshuffle differently each epoch.
        sampler.set_epoch(epoch)
        for batch_idx, (data, target) in enumerate(dataloader):
            data, target = data.to(rank), target.to(rank)

            optimizer.zero_grad()
            output = ddp_model(data)
            loss = criterion(output, target)
            loss.backward()  # DDP synchronizes gradients across ranks here
            optimizer.step()

            if batch_idx % 10 == 0 and rank == 0:
                print(f"Epoch: {epoch+1}, Batch: {batch_idx}, Loss: {loss.item():.4f}")

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    if world_size == 0:
        print("No GPU devices available. Please ensure that your machine has "
              "CUDA-capable GPUs and PyTorch is installed with CUDA support.")
    else:
        # Launch one training process per GPU.
        mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)