Python深度学习pyTorch权重衰减与L2范数正则化解析

 更新时间:2021年09月30日 12:14:06   作者:算法菜鸟飞高高  
这篇文章主要介绍了Python深度学习中的pyTorch权重衰减与L2范数正则化的详细解析,文中附含详细示例代码,有需要的朋友可以借鉴参考下

在这里插入图片描述

下面进行一个高维线性实验

假设我们的真实方程是:

在这里插入图片描述

假设feature数200,训练样本和测试样本各20个

模拟数据集

num_train,num_test = 10,10
num_features = 200
true_w = torch.ones((num_features,1),dtype=torch.float32) * 0.01
true_b = torch.tensor(0.5)
samples = torch.normal(0,1,(num_train+num_test,num_features))
noise = torch.normal(0,0.01,(num_train+num_test,1))
labels = samples.matmul(true_w) + true_b + noise
train_samples, train_labels= samples[:num_train],labels[:num_train]
test_samples, test_labels = samples[num_train:],labels[num_train:]

定义带正则项的loss function

def loss_function(predict,label,w,lambd):
    loss = (predict - label) ** 2
    loss = loss.mean() + lambd * (w**2).mean()
    return loss

画图的方法

def semilogy(x_val,y_val,x_label,y_label,x2_val,y2_val,legend):
    plt.figure(figsize=(3,3))
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.semilogy(x_val,y_val)
    if x2_val and y2_val:
        plt.semilogy(x2_val,y2_val)
        plt.legend(legend)
    plt.show()

拟合和画图

def fit_and_plot(train_samples,train_labels,test_samples,test_labels,num_epoch,lambd):
    w = torch.normal(0,1,(train_samples.shape[-1],1),requires_grad=True)
    b = torch.tensor(0.,requires_grad=True)
    optimizer = torch.optim.Adam([w,b],lr=0.05)
    train_loss = []
    test_loss = []
    for epoch in range(num_epoch):
        predict = train_samples.matmul(w) + b
        epoch_train_loss = loss_function(predict,train_labels,w,lambd)
        optimizer.zero_grad()
        epoch_train_loss.backward()
        optimizer.step()
        test_predict = test_sapmles.matmul(w) + b
        epoch_test_loss = loss_function(test_predict,test_labels,w,lambd)
        train_loss.append(epoch_train_loss.item())
        test_loss.append(epoch_test_loss.item())
    semilogy(range(1,num_epoch+1),train_loss,'epoch','loss',range(1,num_epoch+1),test_loss,['train','test'])

在这里插入图片描述
可以发现加了正则项的模型,在测试集上的loss确实下降了

以上就是Python深度学习pyTorch权重衰减与L2范数正则化解析的详细内容,更多关于Python pyTorch权重与L2范数正则化的资料请关注脚本之家其它相关文章!

相关文章

  • python中的下划线多种用法总结

    python中的下划线多种用法总结

    在 Python 中,下划线(underscore)有多种用法,它在不同的上下文中可以扮演不同的角色,本文将介绍python中的下划线用法总结,感兴趣的朋友一起看看吧
    2024-05-05
  • Python 列表(List)操作方法详解

    Python 列表(List)操作方法详解

    这篇文章主要介绍了Python中列表(List)的详解操作方法,包含创建、访问、更新、删除、其它操作等,需要的朋友可以参考下
    2014-03-03
  • Python实现将罗马数字转换成普通阿拉伯数字的方法

    Python实现将罗马数字转换成普通阿拉伯数字的方法

    这篇文章主要介绍了Python实现将罗马数字转换成普通阿拉伯数字的方法,简单分析了罗马数字的构成并结合实例形式给出了Python转换罗马数字为阿拉伯数字的实现方法,需要的朋友可以参考下
    2017-04-04
  • python-xpath获取html文档的部分内容

    python-xpath获取html文档的部分内容

    这篇文章主要介绍了python-xpath获取html文档的部分内容,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-03-03
  • Python matplotlib的spines模块实例详解

    Python matplotlib的spines模块实例详解

    作为程序员,经常需要进行绘图,下面这篇文章主要给大家介绍了关于Python matplotlib的spines模块的相关资料,文中通过实例代码介绍的非常详细,需要的朋友可以参考下
    2022-08-08
  • Python+random模块实现随机抽样

    Python+random模块实现随机抽样

    python的random库,提供了很多随机抽样方法。本文将通过几个示例为大家详细讲讲random模块实现随机抽样的方法,需要的可以参考一下
    2022-09-09
  • pytorch模型的保存和加载、checkpoint操作

    pytorch模型的保存和加载、checkpoint操作

    这篇文章主要介绍了pytorch模型的保存和加载、checkpoint操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2021-06-06
  • Python 命令行 prompt_toolkit 库详解

    Python 命令行 prompt_toolkit 库详解

    prompt_toolkit 是一个用于构建强大交互式命令行的 Python 工具库。接下来通过本文给大家介绍Python 命令行 prompt_toolkit 库的相关知识,感兴趣的朋友一起看看吧
    2022-01-01
  • python3之模块psutil系统性能信息使用

    python3之模块psutil系统性能信息使用

    psutil是个跨平台库,能够轻松实现获取系统运行的进程和系统利用率,这篇文章主要介绍了python3之模块psutil系统性能信息使用,感兴趣的小伙伴们可以参考一下
    2018-05-05
  • Python使用Bokeh库实现炫目的交互可视化

    Python使用Bokeh库实现炫目的交互可视化

    Bokeh是一个用于创建交互式可视化图形的强大Python库,它不仅易于使用,而且功能强大,适用于各种数据可视化需求,本文将介绍Bokeh库的绘图可视化基础入门,需要的可以了解下
    2024-03-03

最新评论