Python之torch.no_grad()函数使用和示例

 更新时间:2024年03月26日 16:41:30   作者:木彳  
这篇文章主要介绍了Python之torch.no_grad()函数使用和示例,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

torch.no_grad()函数使用和示例

torch.no_grad() 是 PyTorch 中的一个上下文管理器,用于在进入该上下文时禁用梯度计算。

这在你只关心评估模型,而不是训练模型时非常有用,因为它可以显著减少内存使用并加速计算。

当你在 torch.no_grad() 上下文管理器中执行张量操作时,PyTorch 不会为这些操作计算梯度。

这意味着不会在 .grad 属性中累积梯度,并且操作会更快地执行。

使用torch.no_grad()

import torch

# 创建一个需要梯度的张量
x = torch.tensor([1.0], requires_grad=True)

# 使用 no_grad() 上下文管理器
with torch.no_grad():
    y = x * 2

    
y.backward()

print(x.grad)

输出:

RuntimeError                              Traceback (most recent call last)
Cell In[52], line 11
      7 with torch.no_grad():
      8     y = x * 2
---> 11 y.backward()
     13 print(x.grad)

File E:\anaconda\lib\site-packages\torch\_tensor.py:396, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    387 if has_torch_function_unary(self):
    388     return handle_torch_function(
    389         Tensor.backward,
    390         (self,),
   (...)
    394         create_graph=create_graph,
    395         inputs=inputs)
--> 396 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)

File E:\anaconda\lib\site-packages\torch\autograd\__init__.py:173, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    168     retain_graph = create_graph
    170 # The reason we repeat same the comment below is that
    171 # some Python versions print out the first line of a multi-line function
    172 # calls in the traceback and some print out the last line
--> 173 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    174     tensors, grad_tensors_, retain_graph, create_graph, inputs,
    175     allow_unreachable=True, accumulate_grad=True)

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

输出错误,因为使用了with torch.no_grad():。

不使用torch.no_grad()

import torch

# 创建一个需要梯度的张量
x = torch.tensor([1.0], requires_grad=True)

# 使用 no_grad() 上下文管理器
y = x * 2
y.backward()
print(x.grad)

输出:

tensor([2.])

@torch.no_grad()

with torch.no_grad()或者@torch.no_grad()中的数据不需要计算梯度,也不会进行反向传播

model.eval()                               
with torch.no_grad():
   ...

等价于

@torch.no_grad()
def eval():
    ...

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Tensorflow 定义变量,函数,数值计算等名字的更新方式

    Tensorflow 定义变量,函数,数值计算等名字的更新方式

    今天小编就为大家分享一篇Tensorflow 定义变量,函数,数值计算等名字的更新方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-02-02
  • python掌握字符串只需这一篇就够了

    python掌握字符串只需这一篇就够了

    字符串是 Python 中最常用的数据类型。我们可以使用引号('或")来创建字符串。创建字符串很简单,只要为变量分配一个值即可
    2021-11-11
  • Python如何对齐字符串

    Python如何对齐字符串

    这篇文章主要介绍了Python如何对齐字符串,文中讲解非常细致,代码帮助大家更好的理解和学习,感兴趣的朋友可以了解下
    2020-07-07
  • 详解python字符串驻留技术

    详解python字符串驻留技术

    在本文中,我们将深入研究 Python 的内部实现,并了解 Python 如何使用一种名为字符串驻留(String Interning)的技术,实现解释器的高性能。
    2021-05-05
  • pytest配置文件pytest.ini的具体使用

    pytest配置文件pytest.ini的具体使用

    本文主要介绍了pytest配置文件pytest.ini的具体使用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2022-07-07
  • python下载文件时显示下载进度的方法

    python下载文件时显示下载进度的方法

    这篇文章主要介绍了python下载文件时显示下载进度的方法,涉及Python文件操作的技巧,具有一定参考借鉴价值,需要的朋友可以参考下
    2015-04-04
  • 利用TensorFlow训练简单的二分类神经网络模型的方法

    利用TensorFlow训练简单的二分类神经网络模型的方法

    本篇文章主要介绍了利用TensorFlow训练简单的二分类神经网络模型的方法,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2018-03-03
  • Python趣味挑战之给幼儿园弟弟生成1000道算术题

    Python趣味挑战之给幼儿园弟弟生成1000道算术题

    为了让弟弟以后好好学习,我特地用Python给他生成了1000道算术题让他做,他以后一定会感谢我的!文中有非常详细的代码示例,需要的朋友可以参考下
    2021-05-05
  • pytorch无坑安装CPU版小白教程(配gpu版链接、conda命令教程)

    pytorch无坑安装CPU版小白教程(配gpu版链接、conda命令教程)

    pip安装无论是cpu还是gpu的pytorch安装,其实官方给了很好的安装流程,本文主要介绍了pytorch无坑安装CPU版小白教程,具有一定的参考价值,感兴趣的可以了解一下
    2024-03-03
  • pandas 如何分割字符的实现方法

    pandas 如何分割字符的实现方法

    这篇文章主要介绍了pandas 如何分割字符的实现方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-07-07

最新评论