(4) 使用多GPUs时需注意的事项
- 如果该设备上已存在model.cuda(),那么它不会完成任何操作。
- 始终输入到设备列表中的第一个设备上。
- 跨设备传输数据非常昂贵,不到万不得已不要这样做。
- 优化器和梯度将存储在GPU 0上。因此,GPU 0使用的内存很可能比其他处理器大得多。
9. 多节点GPU训练
每台机器上的各GPU都可获取一份模型的副本。每台机器分得一部分数据,并仅针对该部分数据进行训练。各机器彼此同步梯度。
做到了这一步,就可以在几分钟内训练Imagenet数据集了! 这没有想象中那么难,但需要更多有关计算集群的知识。这些指令假定你正在集群上使用SLURM。
Pytorch在各个GPU上跨节点复制模型并同步梯度,从而实现多节点训练。因此,每个模型都是在各GPU上独立初始化的,本质上是在数据的一个分区上独立训练的,只是它们都接收来自所有模型的梯度更新。
高级阶段:
- 在各GPU上初始化一个模型的副本(确保设置好种子,使每个模型初始化到相同的权值,否则操作会失效。)
- 将数据集分成子集。每个GPU只在自己的子集上训练。
- On .backward() 所有副本都会接收各模型梯度的副本。只有此时,模型之间才会相互通信。
Pytorch有一个很好的抽象概念,叫做分布式数据并行处理,它可以为你完成这一操作。要使用DDP(分布式数据并行处理),需要做4件事:
- def tng_dataloader():
-
- d = MNIST()
-
- # 4: Add distributed sampler
- # sampler sends a portion of tng data to each machine
- dist_sampler = DistributedSampler(dataset)
- dataloader = DataLoader(d, shuffle=False, sampler=dist_sampler)
-
- def main_process_entrypoint(gpu_nb):
- # 2: set up connections between all gpus across all machines
- # all gpus connect to a single GPU "root"
- # the default uses env://
-
- world = nb_gpus * nb_nodes
- dist.init_process_group("nccl", rank=gpu_nb, worldworld_size=world)
-
- # 3: wrap model in DPP
- torch.cuda.set_device(gpu_nb)
- model.cuda(gpu_nb)
- model = DistributedDataParallel(model, device_ids=[gpu_nb])
-
- # train your model now...
-
- if __name__ == '__main__':
- # 1: spawn number of processes
- # your cluster will call main for each machine
- mp.spawn(main_process_entrypoint, nprocs=8)
Pytorch团队对此有一份详细的实用教程
(https://github.com/pytorch/examples/blob/master/imagenet/main.py?source=post_page---------------------------)。
然而,在Lightning中,这是一个自带功能。只需设定节点数标志,其余的交给Lightning处理就好。
- # train on 1024 gpus across 128 nodes
- trainer = Trainer(nb_gpu_nodes=128, gpus=[0, 1, 2, 3, 4, 5, 6, 7])
Lightning还附带了一个SlurmCluster管理器,可助你简单地提交SLURM任务的正确细节(示例:
https://github.com/williamFalcon/pytorch-lightning/blob/master/examples/new_project_templates/multi_node_cluster_template.py?source=post_page---------------------------#L103-L134)
10. 福利!更快的多GPU单节点训练
事实证明,分布式数据并行处理要比数据并行快得多,因为其唯一的通信是梯度同步。因此,最好用分布式数据并行处理替换数据并行,即使只是在做单机训练。
(编辑:ASP站长网)
|