PyTorch实现线性回归详细过程
更新时间:2022年03月09日 17:13:45 作者:心️升明月
本文介绍PyTorch实现线性回归,线性关系是一种非常简单的变量之间的关系,因变量和自变量在线性关系的情况下,可以使用线性回归算法对一个或多个因变量和自变量间的线性关系进行建模,该模型的系数可以用最小二乘法进行求解,需要的朋友可以参考一下
一、实现步骤
1、准备数据
x_data = torch.tensor([[1.0],[2.0],[3.0]]) y_data = torch.tensor([[2.0],[4.0],[6.0]])
2、设计模型
class LinearModel(torch.nn.Module): def __init__(self): super(LinearModel,self).__init__() self.linear = torch.nn.Linear(1,1) def forward(self, x): y_pred = self.linear(x) return y_pred model = LinearModel()
3、构造损失函数和优化器
criterion = torch.nn.MSELoss(reduction='sum') optimizer = torch.optim.SGD(model.parameters(),lr=0.01)
4、训练过程
epoch_list = [] loss_list = [] w_list = [] b_list = [] for epoch in range(1000): y_pred = model(x_data) # 计算预测值 loss = criterion(y_pred, y_data) # 计算损失 print(epoch,loss) epoch_list.append(epoch) loss_list.append(loss.data.item()) w_list.append(model.linear.weight.item()) b_list.append(model.linear.bias.item()) optimizer.zero_grad() # 梯度归零 loss.backward() # 反向传播 optimizer.step() # 更新
5、结果展示
展示最终的权重和偏置:
# 输出权重和偏置 print('w = ',model.linear.weight.item()) print('b = ',model.linear.bias.item())
结果为:
w = 1.9998501539230347
b = 0.0003405189490877092
模型测试:
# 测试模型 x_test = torch.tensor([[4.0]]) y_test = model(x_test) print('y_pred = ',y_test.data) y_pred = tensor([[7.9997]])
分别绘制损失值随迭代次数变化的二维曲线图和其随权重与偏置变化的三维散点图:
# 二维曲线图 plt.plot(epoch_list,loss_list,'b') plt.xlabel('epoch') plt.ylabel('loss') plt.show() # 三维散点图 fig = plt.figure() ax = fig.add_subplot(111, projection='3d') ax.scatter(w_list,b_list,loss_list,c='r') #设置坐标轴 ax.set_xlabel('weight') ax.set_ylabel('bias') ax.set_zlabel('loss') plt.show()
结果如下图所示:
到此这篇关于PyTorch实现线性回归详细过程的文章就介绍到这了,更多相关PyTorch线性回归内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!
二、参考文献
相关文章
Python 比较文本相似性的方法(difflib,Levenshtein)
今天小编就为大家分享一篇Python 比较文本相似性的方法(difflib,Levenshtein),具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧2018-10-10Python数据处理中pd.concat与pd.merge的区别及说明
这篇文章主要介绍了Python数据处理中pd.concat与pd.merge的区别及说明,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教2024-02-02Python 中给请求设置用户代理 User-Agent的方法
本文介绍 HTTP 标头用户代理主题以及如何使用 Python 中的请求设置用户代理,您将了解 HTTP 标头及其在理解用户代理、获取用户代理以及学习使用 Python 中的请求设置用户代理的多种方法方面的重要性,感兴趣的朋友跟随小编一起看看吧2023-06-06Numpy中扁平化函数ravel()和flatten()的区别详解
本文主要介绍了Numpy中扁平化函数ravel()和flatten()的区别详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧2023-02-02
最新评论