浅谈tensorflow与pytorch的相互转换

 更新时间:2022年06月28日 08:40:28   作者:wendy_ya  
本文主要介绍了简单介绍一下tensorflow与pytorch的相互转换,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧

本文以一段代码为例,简单介绍一下tensorflow与pytorch的相互转换(主要是tensorflow转pytorch),可能介绍的没有那么详细,仅供参考。

由于本人只熟悉pytorch,而对tensorflow一知半解,而代码经常遇到tensorflow,而我希望使用pytorch,因此简单介绍一下tensorflow转pytorch,可能存在诸多错误,希望轻喷~

1.变量预定义

在TensorFlow的世界里,变量的定义和初始化是分开的。
tensorflow中一般都是在开头预定义变量,声明其数据类型、形状等,在执行的时候再赋具体的值,如下图所示,而pytorch用到时才会定义,定义和变量初始化是合在一起的。

在这里插入图片描述

2.创建变量并初始化

tensorflow中利用tf.Variable创建变量并进行初始化,而pytorch中使用torch.tensor创建变量并进行初始化,如下图所示。

在这里插入图片描述

3.语句执行

在TensorFlow的世界里,变量的定义和初始化是分开的,所有关于图变量的赋值和计算都要通过tf.Session的run来进行。

sess.run([G_solver, G_loss_temp, MSE_loss],
             feed_dict = {X: X_mb, M: M_mb, H: H_mb})

而在pytorch中,并不需要通过run进行,赋值完了直接计算即可。

4.tensor

pytorch运算时要创建完的numpy数组转为tensor,如下:

if use_gpu is True:
	X_mb = torch.tensor(X_mb, device="cuda")
	M_mb = torch.tensor(M_mb, device="cuda")
	H_mb = torch.tensor(H_mb, device="cuda")
else:
	X_mb = torch.tensor(X_mb)
	M_mb = torch.tensor(M_mb)
	H_mb = torch.tensor(H_mb)

最后运行完还要将tensor数据类型转换回numpy数组:

if use_gpu is True:
	imputed_data=imputed_data.cpu().detach().numpy()
else:
	imputed_data=imputed_data.detach().numpy()

而tensorflow中不需要这种操作。

5.其他函数

在tensorflow中包含诸多函数是pytorch中没有的,但是都可以在其他库中找到类似,具体如下表所示。

tensorflow中函数pytorch中代替(所在库)参数区别
tf.sqrtnp.sqrt(numpy)完全相同
tf.random_normalnp.random.normal(numpy)tf.random_normal(shape = size, stddev = xavier_stddev)
np.random.normal(size = size, scale = xavier_stddev)
tf.concattorch.cat(torch)inputs = tf.concat(values = [x, m], axis = 1)
inputs = torch.cat(dim=1, tensors=[x, m])
tf.nn.reluF.relu(torch.nn.functional)完全相同
tf.nn.sigmoidtorch.sigmoid(torch)完全相同
tf.matmultorch.matmul(torch)完全相同
tf.reduce_meantorch.mean(torch)完全相同
tf.logtorch.log(torch)完全相同
tf.zerosnp.zeros完全相同
tf.train.AdamOptimizertorch.optim.Adam(torch)optimizer_D = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)
optimizer_D = torch.optim.Adam(params=theta_D)

到此这篇关于浅谈tensorflow与pytorch的相互转换的文章就介绍到这了,更多相关tensorflow与pytorch的相互转换内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • python使用mysql数据库示例代码

    python使用mysql数据库示例代码

    本篇文章主要介绍了python使用mysql数据库示例代码,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2017-05-05
  • Python对多个sheet表进行整合实例讲解

    Python对多个sheet表进行整合实例讲解

    在本篇文章里小编给大家整理的是一篇关于Python对多个sheet表进行整合实例讲解内容,有兴趣的朋友们可以学习下。
    2021-04-04
  • Pandas实现列(column)排序的几种方法

    Pandas实现列(column)排序的几种方法

    Pandas是一种高效的数据处理库,在数据处理过程中,咱们经常需要将列按照一定的要求进行排序,本文就来介绍一下Pandas实现列(column)排序的几种方法,感兴趣的可以了解一下
    2023-11-11
  • Python可变集合和不可变集合的构造方法大全

    Python可变集合和不可变集合的构造方法大全

    Python集合分为变集合和不可变集合两种,本文就详细的来介绍一下这两种集合的使用,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2021-12-12
  • Python实现区域填充的示例代码

    Python实现区域填充的示例代码

    这篇文章主要介绍了Python实现区域填充的示例代码,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-02-02
  • Anaconda中更新当前环境的Python版本详细步骤

    Anaconda中更新当前环境的Python版本详细步骤

    Anaconda是一个开源的Python发行版本,其包含了conda、Python等180多个科学包及其依赖项,下面这篇文章主要给大家介绍了关于Anaconda中更新当前环境的Python版本的详细步骤,需要的朋友可以参考下
    2024-08-08
  • 对python:print打印时加u的含义详解

    对python:print打印时加u的含义详解

    今天小编就为大家分享一篇对python:print打印时加u的含义详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-12-12
  • Python脚本实现datax全量同步mysql到hive

    Python脚本实现datax全量同步mysql到hive

    这篇文章主要和大家分享一下mysql全量同步到hive自动生成json文件的python脚本,文中的示例代码讲解详细,有需要的小伙伴可以参加一下
    2024-10-10
  • Python基于yield遍历多个可迭代对象

    Python基于yield遍历多个可迭代对象

    这篇文章主要介绍了Python基于yield遍历多个可迭代对象,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-03-03
  • 详解pycharm连接远程linux服务器的虚拟环境的方法

    详解pycharm连接远程linux服务器的虚拟环境的方法

    这篇文章主要介绍了pycharm连接远程linux服务器的虚拟环境的详细教程,本文通过图文并茂的形式给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-11-11

最新评论