如何使用PyTorch优化一个边缘检测器

 更新时间:2024年09月23日 09:42:13   作者:GarryLau  
这篇文章主要给大家介绍了关于如何使用PyTorch优化一个边缘检测器的相关资料,文中通过代码介绍的非常详细,对大家的学习或者工作具有一定的参考借鉴价值,需要的朋友可以参考下

import torch
import torch.nn as nn

X = torch.tensor([[10,10,10,0,0,0],[10,10,10,0,0,0],[10,10,10,0,0,0],[10,10,10,0,0,0],[10,10,10,0,0,0],[10,10,10,0,0,0]], dtype=float)
Y = torch.tensor([[0,30,30,0],[0,30,30,0],[0,30,30,0],[0,30,30,0]], dtype=float)

conv2d = nn.Conv2d(1,1,kernel_size=(3,3), bias=False, dtype=float)

X = X.reshape((1,1,6,6))
Y = Y.reshape((1,1,4,4))
lr = 0.0005

optim = torch.optim.RMSprop(conv2d.parameters(), lr=lr)
loss_fn = torch.nn.MSELoss()
for i in range(4000):
    Y_pred = conv2d(X)
    loss = loss_fn(Y_pred, Y)
    # 更新参数
    if 0: # 手动更新
        conv2d.zero_grad()
        loss.backward()
        conv2d.weight.data[:] -= lr * conv2d.weight.grad
    if 10: # 使用优化器更新
        optim.zero_grad()
        loss.backward()
        optim.step()
    if(i + 1) % 100 == 0:
        print(f'epoch {i+1}, loss {loss.sum():.4f}')

# 打印训练的参数
print(conv2d.weight.data.reshape(3,3))

输出:

epoch 100, loss 331.4604
epoch 200, loss 284.8803
epoch 300, loss 248.8032
epoch 400, loss 218.8007
epoch 500, loss 193.1186
epoch 600, loss 170.4061
epoch 700, loss 149.4530
epoch 800, loss 129.7580
epoch 900, loss 111.4134
epoch 1000, loss 94.5393
epoch 1100, loss 79.1782
epoch 1200, loss 65.3312
epoch 1300, loss 52.9822
epoch 1400, loss 42.1062
epoch 1500, loss 32.6718
epoch 1600, loss 24.6388
epoch 1700, loss 17.9555
epoch 1800, loss 12.5522
epoch 1900, loss 8.3332
epoch 2000, loss 5.1700
epoch 2100, loss 2.9096
epoch 2200, loss 1.4077
epoch 2300, loss 0.5341
epoch 2400, loss 0.1348
epoch 2500, loss 0.0166
epoch 2600, loss 0.0006
epoch 2700, loss 0.0000
epoch 2800, loss 0.0001
epoch 2900, loss 0.0001
epoch 3000, loss 0.0001
epoch 3100, loss 0.0001
epoch 3200, loss 0.0002
epoch 3300, loss 0.0002
epoch 3400, loss 0.0002
epoch 3500, loss 0.0002
epoch 3600, loss 0.0002
epoch 3700, loss 0.0002
epoch 3800, loss 0.0002
epoch 3900, loss 0.0002
epoch 4000, loss 0.0002
tensor([[ 1.3123, -0.0050, -1.0276],
        [ 0.8334,  0.0677, -0.8868],
        [ 0.8551, -0.0619, -1.0849]], dtype=torch.float64)

由训练出的结果可以看出卷积核参数与实际的卷积核挺接近了。

到此这篇关于如何使用PyTorch优化一个边缘检测器的文章就介绍到这了,更多相关PyTorch优化边缘检测器内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Qt6中重大改变的QtMultimedia多媒体模块实现

    Qt6中重大改变的QtMultimedia多媒体模块实现

    本文主要介绍了Qt6中重大改变的QtMultimedia多媒体模块实现,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2021-09-09
  • 儿童编程python入门

    儿童编程python入门

    很多家长都想让孩子学习编程,今天我们给大家分享一下关于儿童python的入门以及简单的代码,有兴趣的朋友阅读下吧。
    2018-05-05
  • Python采用Django制作简易的知乎日报API

    Python采用Django制作简易的知乎日报API

    这篇文章主要为大家详细介绍了Python采用Django制作简易的知乎日报API,感兴趣的小伙伴们可以参考一下
    2016-08-08
  • tkinter自定义下拉多选框问题

    tkinter自定义下拉多选框问题

    这篇文章主要介绍了tkinter自定义下拉多选框问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-01-01
  • Python、 Pycharm、Django安装详细教程(图文)

    Python、 Pycharm、Django安装详细教程(图文)

    这篇文章主要介绍了Python、 Pycharm、Django安装详细教程,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-04-04
  • numpy linalg模块的具体使用方法

    numpy linalg模块的具体使用方法

    这篇文章主要介绍了numpy linalg模块的具体使用方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-05-05
  • Python生成rsa密钥对操作示例

    Python生成rsa密钥对操作示例

    这篇文章主要介绍了Python生成rsa密钥对操作,涉及Python rsa加密与密钥生成相关操作技巧,需要的朋友可以参考下
    2019-04-04
  • Python面向对象中的封装详情

    Python面向对象中的封装详情

    这篇文章主要介绍了Python面向对象中的封装详情,在python中也有对对象的封装操作,使其对外只提供固定的访问模式,不能访问其内部的私有属性和私有方法。下文详细内容,需要的小伙伴可以参考一下
    2022-03-03
  • Python图像处理之图像金字塔详解

    Python图像处理之图像金字塔详解

    这篇文章主要介绍了图像处理中的图像金字塔,包括图像向上取样和向下取样。文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编学习一下
    2022-02-02
  • 一文搞懂python 中的迭代器和生成器

    一文搞懂python 中的迭代器和生成器

    这篇文章主要介绍了python 中的迭代器和生成器简单介绍,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2022-03-03

最新评论