Pytorch 多块GPU的使用详解

 更新时间:2019年12月31日 09:36:44   作者:R.X.Geng  
今天小编就为大家分享一篇Pytorch 多块GPU的使用详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

注:本文针对单个服务器上多块GPU的使用,不是多服务器多GPU的使用。

在一些实验中,由于Batch_size的限制或者希望提高训练速度等原因,我们需要使用多块GPU。本文针对Pytorch中多块GPU的使用进行说明。

1. 设置需要使用的GPU编号

import os
 
os.environ["CUDA_VISIBLE_DEVICES"] = "0,4"
ids = [0,1]

比如我们需要使用第0和第4块GPU,只用上述三行代码即可。

其中第二行指程序只能看到第1块和第4块GPU;

第三行的0即为第二行中编号为0的GPU;1即为编号为4的GPU。

2.更改网络,可以理解为将网络放入GPU

class CNN(nn.Module):
  def __init__(self):
    super(CNN,self).__init__()
    self.conv1 = nn.Sequential(
    ......
    )
    
    ......
    
    self.out = nn.Linear(Liner_input,2)
 
  ......
    
  def forward(self,x):
    x = self.conv1(x)
    ......
    output = self.out(x)
    return output,x
  
cnn = CNN()
 
# 更改,.cuda()表示将本存储到CPU的网络及其参数存储到GPU!
cnn.cuda()

3. 更改输出数据(如向量/矩阵/张量):

for epoch in range(EPOCH):
  epoch_loss = 0.
  for i, data in enumerate(train_loader2):
    image = data['image'] # data是字典,我们需要改的是其中的image
 
    #############更改!!!##################
    image = Variable(image).float().cuda()
    ############################################
 
    label = inputs['label']
    #############更改!!!##################
    label = Variable(label).type(torch.LongTensor).cuda()
    ############################################
    label = label.resize(BATCH_SIZE)
    output = cnn(image)[0]
    loss = loss_func(output, label)  # cross entropy loss
    optimizer.zero_grad()      # clear gradients for this training step
    loss.backward()         # backpropagation, compute gradients
    optimizer.step() 
    ... ...

4. 更改其他CPU与GPU冲突的地方

有些函数必要在GPU上完成,例如将Tensor转换为Numpy,就要使用data.cpu().numpy(),其中data是GPU上的Tensor。

若直接使用data.numpy()则会报错。除此之外,plot等也需要在CPU中完成。如果不是很清楚哪里要改的话可以先不改,等到程序报错了,再哪里错了改哪里,效率会更高。例如:

  ... ...
    #################################################
    pred_y = torch.max(test_train_output, 1)[1].data.cpu().numpy()
    
    accuracy = float((pred_y == label.cpu().numpy()).astype(int).sum()) / float(len(label.cpu().numpy()))

假如不加.cpu()便会报错,此时再改即可。

5. 更改前向传播函数,从而使用多块GPU

以VGG为例:

class VGG(nn.Module):
 
  def __init__(self, features, num_classes=2, init_weights=True):
    super(VGG, self).__init__()
... ...
 
  def forward(self, x):
    #x = self.features(x)
    #################Multi GPUS#############################
    x = nn.parallel.data_parallel(self.features,x,ids)
    x = x.view(x.size(0), -1)
    # x = self.classifier(x)
    x = nn.parallel.data_parallel(self.classifier,x,ids)
    return x
... ...

然后就可以看运行结果啦,nvidia-smi查看GPU使用情况:

可以看到0和4都被使用啦

以上这篇Pytorch 多块GPU的使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 使用pandas中的DataFrame.rolling方法查看时间序列中的异常值

    使用pandas中的DataFrame.rolling方法查看时间序列中的异常值

    Pandas是Python中最受欢迎的数据分析和处理库之一,提供了许多强大且灵活的数据操作工具,在Pandas中,DataFrame.rolling方法是一个强大的工具,在本文中,我们将深入探讨DataFrame.rolling方法的各种参数和示例,以帮助您更好地理解和应用这个功能
    2023-12-12
  • Spyder中如何设置默认python解释器

    Spyder中如何设置默认python解释器

    Spyder作为一款流行的Python IDE,支持用户自定义Python解释器,包括虚拟环境的设置,通过打开Spyder,选择“Tools”->“Preferences”,在弹出窗口中选择“Use the following Python interpreter”后,浏览并选择相应的解释器或虚拟环境路径
    2024-09-09
  • python常用运维脚本实例小结

    python常用运维脚本实例小结

    这篇文章主要介绍了python常用运维脚本实例小结,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-02-02
  • Python中容易被忽视的核心功能总结

    Python中容易被忽视的核心功能总结

    Python是一门富有魅力的编程语言,拥有丰富的功能和库,以及强大的社区支持,然而,有一些核心功能经常被忽视,而它们实际上可以极大地提高代码的质量、可读性和性能,本文将给大家详细的介绍一下这些容易被忽视的功能,需要的朋友可以参考下
    2023-11-11
  • matplotlib 对坐标的控制,加图例注释的操作

    matplotlib 对坐标的控制,加图例注释的操作

    这篇文章主要介绍了matplotlib 对坐标的控制,加图例注释的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-04-04
  • Python中pip安装非PyPI官网第三方库的方法

    Python中pip安装非PyPI官网第三方库的方法

    这篇文章主要介绍了Python中pip安装非PyPI官网第三方库的方法,pip最新的版本(1.5以上的版本), 出于安全的考 虑,pip不允许安装非PyPI的URL,本文就给出两种解决方法,需要的朋友可以参考下
    2015-06-06
  • Django windows使用Apache实现部署流程解析

    Django windows使用Apache实现部署流程解析

    这篇文章主要介绍了Django windows使用Apache实现部署流程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-10-10
  • Python实现将视频按照时间维度剪切

    Python实现将视频按照时间维度剪切

    这篇文章主要为大家详细介绍了如何利用Python实现将视频按照时间维度进行剪切,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起了解一下
    2022-12-12
  • Python超细致探究面向对象

    Python超细致探究面向对象

    面向对象编程是一种编程方式,此编程方式的落地需要使用“类”和 “对象”来实现,所以,面向对象编程其实就是对 “类”和“对象” 的使用,今天给大家介绍下python 面向对象开发及基本特征,感兴趣的朋友一起看看吧
    2022-06-06
  • Python使用迭代器捕获Generator返回值的方法

    Python使用迭代器捕获Generator返回值的方法

    这篇文章主要介绍了Python使用迭代器捕获Generator返回值的方法,结合具体实例形式分析了Python迭代器获取生成器返回值的相关操作技巧,需要的朋友可以参考下
    2017-04-04

最新评论