Pytorch+PyG实现GraphSAGE过程示例详解

 更新时间:2023年04月21日 09:54:31   作者:实力  
这篇文章主要为大家介绍了Pytorch+PyG实现GraphSAGE过程示例详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

GraphSAGE简介

GraphSAGE(Graph Sampling and Aggregation)是一种常见的图神经网络模型,主要用于结点级别的表征学习。该模型基于采样和聚合策略,将一个结点及其邻居节点信息融合在一起,得到其表征表示,并通过多轮迭代更新来提高表征的精度。

实现步骤

数据准备

在本次实现中,我们仍然使用Cora数据集作为示例进行测试,由于GraphSage主要聚焦于单一节点特征的更新,因此这里不需要对数据集做特别处理,只需要将数据转化成PyG格式即可。

import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import from_networkx, to_networkx
# 加载cora数据集
dataset = Planetoid(root='./cora', name='Cora')
data = dataset[0]
# 将nx.Graph形式的图转换成PyG需要的格式
graph = to_networkx(data)
data = from_networkx(graph)
# 获取节点数量和特征向量维度
num_nodes = data.num_nodes
num_features = dataset.num_features
num_classes = dataset.num_classes
# 建立需要训练的节点分割数据集
data.train_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.val_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.test_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.train_mask[:num_nodes - 1000] = True
data.test_mask[-1000:] = True
data.val_mask[num_nodes - 2000: num_nodes - 1000] = True

实现模型

接下来,我们需要定义GraphSAGE模型。与传统的GCN中只需要一层卷积操作不同,GraphSAGE包含两层卷积和采样(也称“聚合”)操作。

from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_geometric.nn import SAGEConv
class GraphSAGE(torch.nn.Module):
    def __init__(self, hidden_channels, num_layers):
        super(GraphSAGE, self).__init__()
        self.convs = nn.ModuleList()
        for i in range(num_layers):
            in_channels = hidden_channels if i != 0 else num_features
            out_channels = num_classes if i == num_layers - 1 else hidden_channels
            self.convs.append(SAGEConv(in_channels, out_channels))
    def forward(self, x, edge_index):
        for _, conv in enumerate(self.convs[:-1]):
            x = F.relu(conv(x, edge_index))
        # 最后一层不用激活函数
        x = self.convs[-1](x, edge_index)
        return F.log_softmax(x, dim=-1)

在上述代码中,我们实现了多层GraphSAGE卷积和相应的聚合函数,并使用ReLU和softmax函数来进行特征提取和分类分数的输出。

模型训练

定义好模型之后,就可以开始针对Cora数据集进行模型训练。首先还是需要先指定优化器和损失函数,并设定一些参数用于记录训练过程中的信息,如Epochs、Batch size、学习率等。

# 初始化GraphSage并指定参数
num_layers = 2
hidden_channels = 256
model = GraphSAGE(hidden_channels, num_layers).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_func = nn.CrossEntropyLoss()
# 训练过程
for epoch in range(500):
    model.train()
    optimizer.zero_grad()
    out = model(data.x.to(device), data.edge_index.to(device))
    loss = loss_func(out[data.train_mask], data.y.to(device)[data.train_mask])
    loss.backward()
    optimizer.step()
    # 在各个测试阶段检测一下准确率
    if epoch % 10 == 0:
        with torch.no_grad():
            _, pred = model(data.x.to(device), data.edge_index.to(device)).max(dim=1)
            correct = float(pred[data.test_mask].eq(data.y.to(device)[data.test_mask]).sum().item())
            acc = correct / data.test_mask.sum().item()
            print("Epoch {:03d}, Train Loss {:.4f}, Test Acc {:.4f}".format(
                epoch, loss.item(), acc))

在上述代码中,我们使用有标记的训练数据拟合GraphSAGE模型,在各个验证阶段测试准确率,并通过梯度下降法优化损失函数。

以上就是Pytorch+PyG实现GraphSAGE过程示例详解的详细内容,更多关于Pytorch PyG实现GraphSAGE的资料请关注脚本之家其它相关文章!

相关文章

  • python清除字符串里非数字字符的方法

    python清除字符串里非数字字符的方法

    这篇文章主要介绍了python清除字符串里非数字字符的方法,涉及Python使用re模块正则替换操作字符串的技巧,需要的朋友可以参考下
    2015-07-07
  • python模块itsdangerous简单介绍

    python模块itsdangerous简单介绍

    这篇文章主要介绍了python模块itsdangerous简单介绍,本文通过案例分析给大家详细讲解,对python模块itsdangerous相关知识感兴趣的朋友一起看看吧
    2022-11-11
  • Python如何为图片添加水印

    Python如何为图片添加水印

    这篇文章主要介绍了Python如何使用Python-Pillow库给图片添加水印的方法,非常的简单实用,有需要的小伙伴可以参考下
    2016-11-11
  • pycharm修改界面主题颜色的方法

    pycharm修改界面主题颜色的方法

    今天小编就为大家分享一篇pycharm修改界面主题颜色的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-01-01
  • 使用python模拟高斯分布例子

    使用python模拟高斯分布例子

    今天小编就为大家分享一篇使用python模拟高斯分布例子,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-12-12
  • 使用django和vue进行数据交互的方法步骤

    使用django和vue进行数据交互的方法步骤

    这篇文章主要介绍了使用django和vue进行数据交互的方法步骤,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-11-11
  • python爬取淘宝商品详情页数据

    python爬取淘宝商品详情页数据

    这篇文章主要为大家详细介绍了python爬取淘宝商品详情页数据的相关资料,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-02-02
  • Python Pytest装饰器@pytest.mark.parametrize详解

    Python Pytest装饰器@pytest.mark.parametrize详解

    本文主要介绍了Python Pytest装饰器@pytest.mark.parametrize详解,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2021-08-08
  • python实现俄罗斯方块游戏

    python实现俄罗斯方块游戏

    这篇文章主要为大家介绍了python实现俄罗斯方块游戏的详细代码,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-06-06
  • 浅析python中SQLAlchemy排序的一个坑

    浅析python中SQLAlchemy排序的一个坑

    这篇文章主要介绍了关于python中SQLAlchemy排序的一个坑,文中给出了详细的示例代码,需要的朋友可以参考借鉴,感兴趣的朋友们下面来一起学习学习吧。
    2017-02-02

最新评论