keras实现theano和tensorflow训练的模型相互转换

 更新时间:2020年06月19日 11:50:22   作者:零落_World  
这篇文章主要介绍了keras实现theano和tensorflow训练的模型相互转换,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

我就废话不多说了,大家还是直接看代码吧~

</pre><pre code_snippet_id="1947416" snippet_file_name="blog_20161025_1_3331239" name="code" class="python">

# coding:utf-8
"""
If you want to load pre-trained weights that include convolutions (layers Convolution2D or Convolution1D),
be mindful of this: Theano and TensorFlow implement convolution in different ways (TensorFlow actually implements correlation, much like Caffe),
and thus, convolution kernels trained with Theano (resp. TensorFlow) need to be converted before being with TensorFlow (resp. Theano).
"""
from keras import backend as K
from keras.utils.np_utils import convert_kernel
from text_classifier import keras_text_classifier
import sys
 
def th2tf( model):
  import tensorflow as tf
  ops = []
  for layer in model.layers:
    if layer.__class__.__name__ in ['Convolution1D', 'Convolution2D']:
      original_w = K.get_value(layer.W)
      converted_w = convert_kernel(original_w)
      ops.append(tf.assign(layer.W, converted_w).op)
  K.get_session().run(ops)
  return model
 
def tf2th(model):
  for layer in model.layers:
    if layer.__class__.__name__ in ['Convolution1D', 'Convolution2D']:
      original_w = K.get_value(layer.W)
      converted_w = convert_kernel(original_w)
      K.set_value(layer.W, converted_w)
  return model
 
def conv_layer_converted(tf_weights, th_weights, m = 0):
  """
  :param tf_weights:
  :param th_weights:
  :param m: 0-tf2th, 1-th2tf
  :return:
  """
  if m == 0: # tf2th
    tc = keras_text_classifier(weights_path=tf_weights)
    model = tc.loadmodel()
    model = tf2th(model)
    model.save_weights(th_weights)
  elif m == 1: # th2tf
    tc = keras_text_classifier(weights_path=th_weights)
    model = tc.loadmodel()
    model = th2tf(model)
    model.save_weights(tf_weights)
  else:
    print("0-tf2th, 1-th2tf")
    return
if __name__ == '__main__':
  if len(sys.argv) < 4:
    print("python tf_weights th_weights <0|1>\n0-tensorflow to theano\n1-theano to tensorflow")
    sys.exit(0)
  tf_weights = sys.argv[1]
  th_weights = sys.argv[2]
  m = int(sys.argv[3])
  conv_layer_converted(tf_weights, th_weights, m)

补充知识:keras学习之修改底层为TensorFlow还是theano

我们知道,keras的底层是TensorFlow或者theano

要知道我们是用的哪个为底层,只需要import keras即可显示

修改方法:

打开

修改

以上这篇keras实现theano和tensorflow训练的模型相互转换就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 使用python实现tcp自动重连

    使用python实现tcp自动重连

    下面小编就为大家带来一篇使用python实现tcp自动重连实现方法。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。
    2017-07-07
  • caffe binaryproto 与 npy相互转换的实例讲解

    caffe binaryproto 与 npy相互转换的实例讲解

    今天小编就为大家分享一篇caffe binaryproto 与 npy相互转换的实例讲解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-07-07
  • 利用Python编写的实用运维脚本分享

    利用Python编写的实用运维脚本分享

    Python在很大程度上可以对shell脚本进行替代。笔者一般单行命令用shell,复杂点的多行操作就直接用Python了。本文归纳了Python中一些实用脚本操作,需要的可以参考一下
    2022-05-05
  • OpenCV图像颜色反转算法详解

    OpenCV图像颜色反转算法详解

    这篇文章主要介绍了OpenCV图像颜色反转算法详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-05-05
  • 基于Numpy.convolve使用Python实现滑动平均滤波的思路详解

    基于Numpy.convolve使用Python实现滑动平均滤波的思路详解

    这篇文章主要介绍了Python极简实现滑动平均滤波(基于Numpy.convolve)的相关知识,非常不错,具有一定的参考借鉴价值,需要的朋友可以参考下
    2019-05-05
  • 解析Python扩展模块的加速方案

    解析Python扩展模块的加速方案

    这章我们来介绍Python的扩展名之ctypes,教大家认识ctypes,有需要的朋友可以借鉴参考下,希望可以有所帮助,祝大家多多进步,早日升职加薪
    2021-09-09
  • python之excel文件(.xls文件)处理方式

    python之excel文件(.xls文件)处理方式

    这篇文章主要介绍了python之excel文件(.xls文件)处理方式,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-05-05
  • python实现用户名密码校验

    python实现用户名密码校验

    这篇文章主要为大家详细介绍了python实现用户名密码校验,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2020-03-03
  • 模拟浏览器的Python爬虫工具全面深入探索

    模拟浏览器的Python爬虫工具全面深入探索

    Python爬虫是获取网页信息的重要工具,但有时网站对爬虫有限制,要求模拟浏览器行为,本文将深入探讨如何使用Python模拟浏览器行为进行网络数据抓取,我们将介绍相关工具和技术,提供详细的示例代码和解释
    2024-01-01
  • python安装cx_Oracle模块常见问题与解决方法

    python安装cx_Oracle模块常见问题与解决方法

    这篇文章主要介绍了python安装cx_Oracle模块常见问题与解决方法,举例分析了Python在Windows平台与Linux平台安装cx_Oracle模块常见问题、解决方法及相关注意事项,需要的朋友可以参考下
    2017-02-02

最新评论