Pytorch distributed 多卡并行载入模型操作

 更新时间:2021年06月05日 10:09:20   作者:orientliu96  
这篇文章主要介绍了Pytorch distributed 多卡并行载入模型操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

一、Pytorch distributed 多卡并行载入模型

这次来介绍下如何载入模型。

目前没有找到官方的distribute 载入模型的方式,所以采用如下方式。

大部分情况下,我们在测试时不需要多卡并行计算。

所以,我在测试时只使用单卡。

from collections import OrderedDict
device = torch.device("cuda")
model = DGCNN(args).to(device)  #自己的模型
state_dict = torch.load(args.model_path)    #存放模型的位置

new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove `module.`
    new_state_dict[name] = v
    # load params
model.load_state_dict (new_state_dict)

二、pytorch DistributedParallel进行单机多卡训练

One_导入库:

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

Two_进程初始化:

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
# 添加必要参数
# local_rank:系统自动赋予的进程编号,可以利用该编号控制打印输出以及设置device

torch.distributed.init_process_group(backend="nccl", init_method='file://shared/sharedfile',
rank=local_rank, world_size=world_size)

# world_size:所创建的进程数,也就是所使用的GPU数量
# (初始化设置详见参考文档)

Three_数据分发:

dataset = datasets.ImageFolder(dataPath)
data_sampler = DistributedSampler(dataset, rank=local_rank, num_replicas=world_size)
# 使用DistributedSampler来为各个进程分发数据,其中num_replicas与world_size保持一致,用于将数据集等分成不重叠的数个子集

dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=1,drop_last=True, pin_memory=True, sampler=data_sampler)
# 在Dataloader中指定sampler时,其中的shuffle必须为False,而DistributedSampler中的shuffle项默认为True,因此训练过程默认执行shuffle

Four_网络模型:

torch.cuda.set_device(local_rank)
device = torch.device('cuda:'+f'{local_rank}')
# 设置每个进程对应的GPU设备

D = Model()
D = torch.nn.SyncBatchNorm.convert_sync_batchnorm(D).to(device)
# 由于在训练过程中各卡的前向后向传播均独立进行,因此无法进行统一的批归一化,如果想要将各卡的输出统一进行批归一化,需要将模型中的BN转换成SyncBN
   
D = torch.nn.parallel.DistributedDataParallel(
D, find_unused_parameters=True, device_ids=[local_rank], output_device=local_rank)
# 如果有forward的返回值如果不在计算loss的计算图里,那么需要find_unused_parameters=True,即返回值不进入backward去算grad,也不需要在不同进程之间进行通信。

Five_迭代:

data_sampler.set_epoch(epoch)
# 每个epoch需要为sampler设置当前epoch

Six_加载:

dist.barrier()
D.load_state_dict(torch.load('D.pth'), map_location=torch.device('cpu'))
dist.barrier()
# 加载模型前后用dist.barrier()来同步不同进程间的快慢

Seven_启动:

CUDA_VISIBLE_DEVICES=1,3 python -m torch.distributed.launch --nproc_per_node=2 train.py --epochs 15000 --batchsize 10 --world_size 2
# 用-m torch.distributed.launch启动,nproc_per_node为所使用的卡数,batchsize设置为每张卡各自的批大小

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • pycharm 终端部启用虚拟环境详情

    pycharm 终端部启用虚拟环境详情

    这篇文章主要介绍了pycharm 终端部启用虚拟环境详情,文章围绕pycharm 终端部启用虚拟环境商务相关资料展开全文章的详细内容,需要的小伙伴可以参考一下
    2021-12-12
  • 教你如何在Pygame 中移动你的游戏角色

    教你如何在Pygame 中移动你的游戏角色

    Pygame是一组跨平台的 Python 模块,专为编写视频游戏而设计,您可以使用 pygame 创建不同类型的游戏,包括街机游戏、平台游戏等等,今天通过本文给大家介绍Pygame移动游戏角色的实现过程,一起看看吧
    2021-09-09
  • 微信小程序跳一跳游戏 python脚本跳一跳刷高分技巧

    微信小程序跳一跳游戏 python脚本跳一跳刷高分技巧

    这篇文章主要为大家详细介绍了微信小程序跳一跳游戏,python脚本跳一跳刷高分技巧,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-01-01
  • jupyter运行时左边一直出现*号问题及解决

    jupyter运行时左边一直出现*号问题及解决

    这篇文章主要介绍了jupyter运行时左边一直出现*号问题及解决方案,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2023-09-09
  • 如何利用Pandas删除某列指定值所在的行

    如何利用Pandas删除某列指定值所在的行

    工作中通常会遇到大量的数据集需要处理,其中的一项就是将含有某些数据的行删除掉,下面这篇文章主要给大家介绍了关于如何利用Pandas删除某列指定值所在的行的相关资料,需要的朋友可以参考下
    2022-04-04
  • python 如何对logging日志封装

    python 如何对logging日志封装

    这篇文章主要介绍了python 如何对logging日志封装,帮助大家更好的理解和使用python,感兴趣的朋友可以了解下
    2020-12-12
  • django+celery+RabbitMQ自定义多个消息队列的实现

    django+celery+RabbitMQ自定义多个消息队列的实现

    本文主要介绍了django+celery+RabbitMQ自定义多个消息队列的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-02-02
  • 安装ElasticSearch搜索工具并配置Python驱动的方法

    安装ElasticSearch搜索工具并配置Python驱动的方法

    这篇文章主要介绍了安装ElasticSearch搜索工具并配置Python驱动的方法,文中还介绍了其与Kibana数据显示客户端的配合使用,需要的朋友可以参考下
    2015-12-12
  • opencv+python识别七段数码显示器的数字(数字识别)

    opencv+python识别七段数码显示器的数字(数字识别)

    本文主要介绍了opencv+python识别七段数码显示器的数字(数字识别),文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2022-01-01
  • 基于Python实现语音识别和语音转文字

    基于Python实现语音识别和语音转文字

    这篇文章主要为大家详细介绍了如何利用Python实现语音识别和语音转文字功能,文中的示例代码讲解详细,感兴趣的小伙伴可以了解一下
    2022-09-09

最新评论