import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms, models

# Initialize the distributed environment
def setup(rank, world_size):
    # Address and port for rendezvous between processes
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '8889'
    # Initialize the process group; NCCL is the standard backend for
    # multi-GPU training (Gloo also works, but is mainly for CPU setups)
    dist.init_process_group(
        backend='nccl',
        init_method='env://',  # reads MASTER_ADDR / MASTER_PORT set above
        rank=rank,
        world_size=world_size
    )
    # Bind this process to its own GPU
    torch.cuda.set_device(rank)

# Clean up the distributed environment
def cleanup():
    dist.destroy_process_group()

# Training function executed by each process
def train(rank, world_size):
    setup(rank, world_size)

    # Data preprocessing and loading
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # scale each channel to [-1, 1]
    ])
    # Download on rank 0 only so the processes don't race on the same files;
    # the other ranks wait at the barrier and then read the cached copy
    if rank == 0:
        datasets.CIFAR10(root='./data', train=True, download=True)
    dist.barrier()
    dataset = datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
    # batch_size is per process, so the effective global batch is 64 * world_size;
    # num_workers spawns worker processes (not threads) to load data in parallel
    dataloader = DataLoader(dataset, batch_size=64, sampler=sampler, num_workers=4)

    # Model definition: ResNet-18 with a 10-way head for CIFAR-10
    model = models.resnet18(num_classes=10).to(rank)
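    # Optional (a sketch, not required by the rest of this script): replace
    # BatchNorm layers with SyncBatchNorm so normalization statistics are
    # computed across all processes rather than per GPU. If used, it must be
    # applied before the model is wrapped in DDP below.
    # model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)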
    ddp_model = DDP(model, device_ids=[rank])

    # Loss function and optimizer
    criterion = torch.nn.CrossEntropyLoss().to(rank)
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01, momentum=0.9)

    # Training loop
    for epoch in range(10):  # train for 10 epochs
        ddp_model.train()
        sampler.set_epoch(epoch)  # reshuffle the data differently each epoch
        for batch_idx, (data, target) in enumerate(dataloader):
            data, target = data.to(rank), target.to(rank)

            optimizer.zero_grad()
            output = ddp_model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            # Log the loss every 10 batches, from the main process only
            if batch_idx % 10 == 0 and rank == 0:
                print(f"Epoch: {epoch+1}, Batch: {batch_idx}, Loss: {loss.item():.4f}")

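    # Save the trained weights once from the main process (a minimal sketch;
    # the file name 'ddp_resnet18.pth' is an arbitrary example). Going through
    # ddp_model.module stores the plain model, so the checkpoint can later be
    # loaded without the DDP wrapper.
    if rank == 0:
        torch.save(ddp_model.module.state_dict(), 'ddp_resnet18.pth')
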
    cleanup()

if __name__ == "__main__":
    # One training process per available GPU
    world_size = torch.cuda.device_count()
    if world_size == 0:
        print("No GPU devices available. Please ensure that your machine has CUDA-capable GPUs and PyTorch is installed with CUDA support.")
    else:
        # Spawn one training process per GPU
        mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
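
# Launch note (an aside; assumes this file is saved as train_ddp.py): instead
# of mp.spawn, the script could be adapted to torchrun, which exports RANK,
# WORLD_SIZE, MASTER_ADDR and MASTER_PORT into the environment for you:
#   torchrun --nproc_per_node=<num_gpus> train_ddp.py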