pytorch使用nn.Moudle实现逻辑回归

 更新时间:2022年07月30日 15:42:35   作者:ALEN.Z  
这篇文章主要为大家详细介绍了pytorch使用nn.Moudle实现逻辑回归,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

本文实例为大家分享了pytorch使用nn.Moudle实现逻辑回归的具体代码,供大家参考,具体内容如下

内容

pytorch使用nn.Moudle实现逻辑回归

问题

loss下降不明显

解决方法

#源代码 out的数据接收方式
     if torch.cuda.is_available():
         x_data=Variable(x).cuda()
         y_data=Variable(y).cuda()
     else:
         x_data=Variable(x)
         y_data=Variable(y)
    
    out=logistic_model(x_data)  #根据逻辑回归模型拟合出的y值
    loss=criterion(out.squeeze(),y_data)  #计算损失函数
#源代码 out的数据有拼装数据直接输入
#     if torch.cuda.is_available():
#         x_data=Variable(x).cuda()
#         y_data=Variable(y).cuda()
#     else:
#         x_data=Variable(x)
#         y_data=Variable(y)
    
    out=logistic_model(x_data)  #根据逻辑回归模型拟合出的y值
    loss=criterion(out.squeeze(),y_data)  #计算损失函数
    print_loss=loss.data.item()  #得出损失函数值

源代码

import torch
from torch import nn
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy as np

#生成数据
sample_nums = 100
mean_value = 1.7
bias = 1
n_data = torch.ones(sample_nums, 2)
x0 = torch.normal(mean_value * n_data, 1) + bias      # 类别0 数据 shape=(100, 2)
y0 = torch.zeros(sample_nums)                         # 类别0 标签 shape=(100, 1)
x1 = torch.normal(-mean_value * n_data, 1) + bias     # 类别1 数据 shape=(100, 2)
y1 = torch.ones(sample_nums)                          # 类别1 标签 shape=(100, 1)
x_data = torch.cat((x0, x1), 0)  #按维数0行拼接
y_data = torch.cat((y0, y1), 0)

#画图
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=y.data.numpy(), s=100, lw=0, cmap='RdYlGn')
plt.show()

# 利用torch.nn实现逻辑回归
class LogisticRegression(nn.Module):
    def __init__(self):
        super(LogisticRegression, self).__init__()
        self.lr = nn.Linear(2, 1)
        self.sm = nn.Sigmoid()

    def forward(self, x):
        x = self.lr(x)
        x = self.sm(x)
        return x
    
logistic_model = LogisticRegression()
# if torch.cuda.is_available():
#     logistic_model.cuda()

#loss函数和优化
criterion = nn.BCELoss()
optimizer = torch.optim.SGD(logistic_model.parameters(), lr=0.01, momentum=0.9)
#开始训练
#训练10000次
for epoch in range(10000):
#     if torch.cuda.is_available():
#         x_data=Variable(x).cuda()
#         y_data=Variable(y).cuda()
#     else:
#         x_data=Variable(x)
#         y_data=Variable(y)
    
    out=logistic_model(x_data)  #根据逻辑回归模型拟合出的y值
    loss=criterion(out.squeeze(),y_data)  #计算损失函数
    print_loss=loss.data.item()  #得出损失函数值
    #反向传播
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    
    mask=out.ge(0.5).float()  #以0.5为阈值进行分类
    correct=(mask==y_data).sum().squeeze()  #计算正确预测的样本个数
    acc=correct.item()/x_data.size(0)  #计算精度
    #每隔20轮打印一下当前的误差和精度
    if (epoch+1)%100==0:
        print('*'*10)
        print('epoch {}'.format(epoch+1))  #误差
        print('loss is {:.4f}'.format(print_loss))
        print('acc is {:.4f}'.format(acc))  #精度
        
        
w0, w1 = logistic_model.lr.weight[0]
w0 = float(w0.item())
w1 = float(w1.item())
b = float(logistic_model.lr.bias.item())
plot_x = np.arange(-7, 7, 0.1)
plot_y = (-w0 * plot_x - b) / w1
plt.xlim(-5, 7)
plt.ylim(-7, 7)
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=logistic_model(x_data)[:,0].cpu().data.numpy(), s=100, lw=0, cmap='RdYlGn')
plt.plot(plot_x, plot_y)
plt.show()

输出结果

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。

相关文章

  • Python 爬虫之超链接 url中含有中文出错及解决办法

    Python 爬虫之超链接 url中含有中文出错及解决办法

    这篇文章主要介绍了Python 爬虫之超链接 url中含有中文出错及解决办法的相关资料,出现UnicodeEncodeError: 'ascii' codec can't encode characters,的错误解决办法,需要的朋友可以参考下
    2017-08-08
  • Python 解释器的站点配置和模块搜索路径详解

    Python 解释器的站点配置和模块搜索路径详解

    Python 解释器的站点配置是指一组配置和路径设置,用于支持特定于站点的定制和扩展,这些配置和路径信息由 Python 的内置 site 模块提供,这篇文章主要介绍了Python 解释器的站点配置和模块搜索路径详解,需要的朋友可以参考下
    2022-01-01
  • Python面向对象之类和对象

    Python面向对象之类和对象

    这篇文章主要为大家介绍了Python类和对象,具有一定的参考价值,感兴趣的小伙伴们可以参考一下,希望能够给你带来帮助
    2021-12-12
  • pytorch-autograde-计算图的特点说明

    pytorch-autograde-计算图的特点说明

    这篇文章主要介绍了pytorch-autograde-计算图的特点,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2021-05-05
  • Python中多个数组行合并及列合并的方法总结

    Python中多个数组行合并及列合并的方法总结

    下面小编就为大家分享一篇Python中多个数组行合并及列合并的方法总结,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-04-04
  • 一行Python命令实现批量加水印

    一行Python命令实现批量加水印

    工作的时候,尤其是自媒体工作者,必备水印添加工具以保护知识产权。本文为大家提供了一个快速加水印的方法:一行Python命令就能实现,快来了解一下吧
    2022-04-04
  • python查看列的唯一值方法

    python查看列的唯一值方法

    今天小编就为大家分享一篇python查看列的唯一值方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-07-07
  • 关于Python Error标准异常的总结

    关于Python Error标准异常的总结

    这篇文章主要介绍了关于Python Error标准异常的总结,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2023-09-09
  • Python嵌入C/C++进行开发详解

    Python嵌入C/C++进行开发详解

    在本篇文章里小编给大家分享了关于Python嵌入C/C++进行开发的相关知识点内容,有兴趣的朋友们可以参考下。
    2020-06-06
  • python 实现简单的吃豆人游戏

    python 实现简单的吃豆人游戏

    这篇文章主要介绍了python 如何实现简单的吃豆人游戏,帮助大家更好的理解和学习使用python制作游戏,感兴趣的朋友可以了解下
    2021-04-04

最新评论