使用Pytorch实现Swish激活函数的示例详解

 更新时间:2023年11月21日 11:41:42   作者:鲸落_  
激活函数是人工神经网络的基本组成部分,他们将非线性引入模型,使其能够学习数据中的复杂关系,Swish 激活函数就是此类激活函数之一,在本文中,我们将深入研究 Swish 激活函数,提供数学公式,探索其相对于 ReLU 的优势,并使用 PyTorch 演示其实现

前言

激活函数是人工神经网络的基本组成部分。他们将非线性引入模型,使其能够学习数据中的复杂关系。Swish 激活函数就是此类激活函数之一,因其独特的属性和相对于广泛使用的整流线性单元 (ReLU) 激活的潜在优势而受到关注。在本文中,我们将深入研究 Swish 激活函数,提供数学公式,探索其相对于 ReLU 的优势,并使用 PyTorch 演示其实现。

Swish 激活功能

Swish 激活函数由 Google 研究人员于 2017 年推出,其数学定义如下:

Swish(x) = x * sigmoid(x)

Where:

  • x:激活函数的输入值。
  • sigmoid(x):sigmoid 函数,将任何实数值映射到范围 [0, 1]。随着 x 的增加,它从 0 平滑过渡到 1。

Swish 激活将线性分量(输入 x)与非线性分量(sigmoid函数)相结合,产生平滑且可微的激活函数。

在哪里使用 Swish 激活?

Swish 可用于各种神经网络架构,包括前馈神经网络、卷积神经网络 (CNN) 和循环神经网络 (RNN)。它的优势在深度网络中变得尤为明显,它可以帮助缓解梯度消失问题。

Swish 激活函数相对于 ReLU 的优点

现在,我们来探讨一下 Swish 激活函数与流行的 ReLU 激活函数相比的优势。

平滑度和可微分性

由于 sigmoid 分量的存在,Swish 是一个平滑且可微的函数。此属性使其非常适合基于梯度的优化技术,例如随机梯度下降 (SGD) 和反向传播。相比之下,ReLU 在零处不可微(ReLU 的导数在 x=0 时未定义),这可能会带来优化挑战。

改进深度网络的学习

在深度神经网络中,与 ReLU 相比,Swish 可以实现更好的学习和收敛。Swish 的平滑性有助于梯度在网络中更平滑地流动,减少训练期间梯度消失的可能性。这在非常深的网络中尤其有用。

类似的计算成本

Swish 激活的计算效率很高,类似于 ReLU。这两个函数都涉及基本的算术运算,不会显着增加训练或推理过程中的计算负担。

使用 PyTorch 实现

现在,我们来看看如何使用 PyTorch 实现 Swish 激活函数。我们将创建一个自定义 Swish 模块并将其集成到一个简单的神经网络中。

让我们从导入必要的库开始。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

一旦我们完成了库的导入,我们就可以定义自定义激活——Swish。

以下代码定义了一个继承 PyTorch 基类的类。类内部有一个forward方法。该方法定义模块如何处理输入数据。它将输入张量作为参数,并在应用 Swish 激活后返回输出张量。

# Swish功能
class Swish(nn.Module):
	def forward(self, x):
		return x * torch.sigmoid(x)

定义 Swish 类后,我们继续定义神经网络模型。

在下面的代码片段中,我们使用 PyTorch 定义了一个专为图像分类任务设计的神经网络模型。

  • 输入层有28×28像素。

  • 隐藏层

    • 第一个隐藏层由 256 个神经元组成。它采用扁平输入并应用线性变换来产生输出。
    • 第二个隐藏层由 128 个神经元组成,从前一层获取 256 维输出并产生 128 维输出。
    • Swish 激活函数应用于两个隐藏层,以向网络引入非线性。
    • 输出层由 10 个神经元组成,用于执行 10 个类别的分类。
# 定义神经网络模型
class Net(nn.Module):
	def __init__(self):
		super(Net, self).__init__()
		self.fc1 = nn.Linear(28 * 28, 256)
		self.fc2 = nn.Linear(256, 128)
		self.fc3 = nn.Linear(128, 10)
		self.swish = Swish()

	def forward(self, x):
		x = x.view(-1, 28 * 28)
		x = self.fc1(x)
		x = self.swish(x)
		x = self.fc2(x)
		x = self.swish(x)
		x = self.fc3(x)
		return x

为了设置用于训练的神经网络,我们创建模型的实例,定义损失函数、优化器和数据转换。

# 创建模型的实例
model = Net()

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 定义数据转换
transform = transforms.Compose([
	transforms.ToTensor(),
])

完成此步骤后,我们可以继续在数据集上训练和评估模型。让我们使用以下代码加载 MNIST 数据并创建用于训练的数据加载器。

# 加载MNIST数据集
train_dataset = datasets.MNIST('', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('', train=False, download=True, transform=transform)

# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

有了这些数据加载器,我们就可以继续训练循环来迭代批量的训练和测试数据。

在下面的代码中,我们执行了神经网络的训练循环。该循环将重复 5 个时期,在此期间更新模型的权重,以最大限度地减少损失并提高其在训练数据上的性能。

# 训练循环
num_epochs = 5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)

for epoch in range(num_epochs):
	model.train()
	total_loss = 0.0
	for batch_idx, (data, target) in enumerate(train_loader):
		data, target = data.to(device), target.to(device)
		optimizer.zero_grad()
		output = model(data)
		loss = criterion(output, target)
		loss.backward()
		optimizer.step()
		total_loss += loss.item()
	
	print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(train_loader)}")

输出:

Epoch 1/5, Loss: 1.6938323568503062
Epoch 2/5, Loss: 0.4569567457397779
Epoch 3/5, Loss: 0.3522500048557917
Epoch 4/5, Loss: 0.31695075702369213
Epoch 5/5, Loss: 0.2961081813474496

最后一步是模型评估步骤。

# 评估循环
model.eval()
correct = 0
total = 0
with torch.no_grad():
	for data, target in test_loader:
		data, target = data.to(device), target.to(device)
		outputs = model(data)
		_, predicted = torch.max(outputs.data, 1)
		total += target.size(0)
		correct += (predicted == target).sum().item()

print(f"Accuracy on test set: {100 * correct / total}%")

输出:

Accuracy on test set: 92.02%

结论

Swish 激活函数为 ReLU 等传统激活函数提供了一种有前景的替代方案。它的平滑性、可微性和改善深度网络学习的潜力使其成为现代神经网络架构的宝贵工具。通过在 PyTorch 中实施 Swish,您可以利用其优势并探索其在各种机器学习任务中的有效性。

以上就是使用Pytorch实现Swish激活函数的示例详解的详细内容,更多关于Pytorch Swish激活函数的资料请关注脚本之家其它相关文章!

相关文章

  • numpy 数组拷贝地址所引起的同步替换问题

    numpy 数组拷贝地址所引起的同步替换问题

    本文主要介绍了numpy 数组拷贝地址所引起的同步替换问题,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-02-02
  • Python利用Turtle绘制哆啦A梦和小猪佩奇

    Python利用Turtle绘制哆啦A梦和小猪佩奇

    turtle库是python的基础绘图库,经常被用来介绍编程知识的方法库,是标准库之一,利用turtle可以制作很多复杂的绘图。本文将为大家介绍通过turtle库绘制制哆啦A梦和小猪佩奇,感兴趣的小伙伴可以学习一下
    2022-04-04
  • python dict.get()和dict[''key'']的区别详解

    python dict.get()和dict[''key'']的区别详解

    下面小编就为大家带来一篇python dict.get()和dict['key']的区别详解。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2016-06-06
  • Flask实现跨域请求的处理方法

    Flask实现跨域请求的处理方法

    这篇文章主要介绍了Flask实现跨域请求的处理方法,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2018-09-09
  • 详解Python的Django框架中的模版继承

    详解Python的Django框架中的模版继承

    这篇文章主要介绍了详解Python的Django框架中的模版继承,就像Python中面对对象的方法继承道理类似,需要的朋友可以参考下
    2015-07-07
  • Python代码库之Tuple如何append添加元素问题

    Python代码库之Tuple如何append添加元素问题

    这篇文章主要介绍了Python代码库之Tuple如何append添加元素问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-01-01
  • Python 类方法和实例方法(@classmethod),静态方法(@staticmethod)原理与用法分析

    Python 类方法和实例方法(@classmethod),静态方法(@staticmethod)原理与用法分析

    这篇文章主要介绍了Python 类方法和实例方法(@classmethod),静态方法(@staticmethod),结合实例形式分析了Python 类方法和实例方法及静态方法相关原理、用法及相关操作注意事项,需要的朋友可以参考下
    2019-09-09
  • Python实现控制台输入密码的方法

    Python实现控制台输入密码的方法

    这篇文章主要介绍了Python实现控制台输入密码的方法,实例对比分析了几种输入密码的方法,具有一定参考借鉴价值,需要的朋友可以参考下
    2015-05-05
  • Python pyecharts实现绘制中国地图的实例详解

    Python pyecharts实现绘制中国地图的实例详解

    pyecharts是一个用于生成 Echarts 图表的类库。Echarts 是百度开源的一个数据可视化 JS 库。用 Echarts 生成的图可视化效果非常棒。本文将通过pyecharts绘制中国地图,需要的可以学习一下
    2022-01-01
  • pycharm设置虚拟环境与更换镜像教程

    pycharm设置虚拟环境与更换镜像教程

    这篇文章主要介绍了pycharm设置虚拟环境与更换镜像教程,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2021-09-09

最新评论