如何在keras中添加自己的优化器(如adam等)

 更新时间:2020年06月19日 10:43:54   作者:阳光一直都在  
这篇文章主要介绍了在keras中实现添加自己的优化器(如adam等)方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

本文主要讨论windows下基于tensorflow的keras

1、找到tensorflow的根目录

如果安装时使用anaconda且使用默认安装路径,则在 C:\ProgramData\Anaconda3\envs\tensorflow-gpu\Lib\site-packages\tensorflow处可以找到(此处为GPU版本),cpu版本可在C:\ProgramData\Anaconda3\Lib\site-packages\tensorflow处找到。若并非使用默认安装路径,可参照根目录查看找到。

2、找到keras在tensorflow下的根目录

需要特别注意的是找到keras在tensorflow下的根目录而不是找到keras的根目录。一般来说,完成tensorflow以及keras的配置后即可在tensorflow目录下的python目录中找到keras目录,以GPU为例keras在tensorflow下的根目录为C:\ProgramData\Anaconda3\envs\tensorflow-gpu\Lib\site-packages\tensorflow\python\keras

3、找到keras目录下的optimizers.py文件并添加自己的优化器

找到optimizers.py中的adam等优化器类并在后面添加自己的优化器类

以本文来说,我在第718行添加如下代码

@tf_export('keras.optimizers.adamsss')
class Adamsss(Optimizer):

 def __init__(self,
  lr=0.002,
  beta_1=0.9,
  beta_2=0.999,
  epsilon=None,
  schedule_decay=0.004,
  **kwargs):
 super(Adamsss, self).__init__(**kwargs)
 with K.name_scope(self.__class__.__name__):
 self.iterations = K.variable(0, dtype='int64', name='iterations')
 self.m_schedule = K.variable(1., name='m_schedule')
 self.lr = K.variable(lr, name='lr')
 self.beta_1 = K.variable(beta_1, name='beta_1')
 self.beta_2 = K.variable(beta_2, name='beta_2')
 if epsilon is None:
 epsilon = K.epsilon()
 self.epsilon = epsilon
 self.schedule_decay = schedule_decay

 def get_updates(self, loss, params):
 grads = self.get_gradients(loss, params)
 self.updates = [state_ops.assign_add(self.iterations, 1)]

 t = math_ops.cast(self.iterations, K.floatx()) + 1

 # Due to the recommendations in [2], i.e. warming momentum schedule
 momentum_cache_t = self.beta_1 * (
 1. - 0.5 *
 (math_ops.pow(K.cast_to_floatx(0.96), t * self.schedule_decay)))
 momentum_cache_t_1 = self.beta_1 * (
 1. - 0.5 *
 (math_ops.pow(K.cast_to_floatx(0.96), (t + 1) * self.schedule_decay)))
 m_schedule_new = self.m_schedule * momentum_cache_t
 m_schedule_next = self.m_schedule * momentum_cache_t * momentum_cache_t_1
 self.updates.append((self.m_schedule, m_schedule_new))

 shapes = [K.int_shape(p) for p in params]
 ms = [K.zeros(shape) for shape in shapes]
 vs = [K.zeros(shape) for shape in shapes]

 self.weights = [self.iterations] + ms + vs

 for p, g, m, v in zip(params, grads, ms, vs):
 # the following equations given in [1]
 g_prime = g / (1. - m_schedule_new)
 m_t = self.beta_1 * m + (1. - self.beta_1) * g
 m_t_prime = m_t / (1. - m_schedule_next)
 v_t = self.beta_2 * v + (1. - self.beta_2) * math_ops.square(g)
 v_t_prime = v_t / (1. - math_ops.pow(self.beta_2, t))
 m_t_bar = (
  1. - momentum_cache_t) * g_prime + momentum_cache_t_1 * m_t_prime

 self.updates.append(state_ops.assign(m, m_t))
 self.updates.append(state_ops.assign(v, v_t))

 p_t = p - self.lr * m_t_bar / (K.sqrt(v_t_prime) + self.epsilon)
 new_p = p_t

 # Apply constraints.
 if getattr(p, 'constraint', None) is not None:
 new_p = p.constraint(new_p)

 self.updates.append(state_ops.assign(p, new_p))
 return self.updates

 def get_config(self):
 config = {
 'lr': float(K.get_value(self.lr)),
 'beta_1': float(K.get_value(self.beta_1)),
 'beta_2': float(K.get_value(self.beta_2)),
 'epsilon': self.epsilon,
 'schedule_decay': self.schedule_decay
 }
 base_config = super(Adamsss, self).get_config()
 return dict(list(base_config.items()) + list(config.items()))

然后修改之后的优化器调用类添加我自己的优化器adamss

需要修改的有(下面的两处修改依旧在optimizers.py内)

# Aliases.

sgd = SGD
rmsprop = RMSprop
adagrad = Adagrad
adadelta = Adadelta
adam = Adam
adamsss = Adamsss
adamax = Adamax
nadam = Nadam

以及

def deserialize(config, custom_objects=None):
 """Inverse of the `serialize` function.

 Arguments:
 config: Optimizer configuration dictionary.
 custom_objects: Optional dictionary mapping
  names (strings) to custom objects
  (classes and functions)
  to be considered during deserialization.

 Returns:
 A Keras Optimizer instance.
 """
 if tf2.enabled():
 all_classes = {
 'adadelta': adadelta_v2.Adadelta,
 'adagrad': adagrad_v2.Adagrad,
 'adam': adam_v2.Adam,
		'adamsss': adamsss_v2.Adamsss,
 'adamax': adamax_v2.Adamax,
 'nadam': nadam_v2.Nadam,
 'rmsprop': rmsprop_v2.RMSprop,
 'sgd': gradient_descent_v2.SGD
 }
 else:
 all_classes = {
 'adadelta': Adadelta,
 'adagrad': Adagrad,
 'adam': Adam,
 'adamax': Adamax,
 'nadam': Nadam,
		'adamsss': Adamsss,
 'rmsprop': RMSprop,
 'sgd': SGD,
 'tfoptimizer': TFOptimizer
 }

这里我们并没有v2版本,所以if后面的部分不改也可以。

4、调用我们的优化器对模型进行设置

model.compile(loss = 'crossentropy', optimizer = 'adamss', metrics=['accuracy'])

5、训练模型

train_history = model.fit(x, y_label, validation_split = 0.2, epoch = 10, batch = 128, verbose = 1)

补充知识:keras设置学习率--优化器的用法

优化器的用法

优化器 (optimizer) 是编译 Keras 模型的所需的两个参数之一:

from keras import optimizers
 
model = Sequential()
model.add(Dense(64, kernel_initializer='uniform', input_shape=(10,)))
model.add(Activation('softmax'))
 
sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='mean_squared_error', optimizer=sgd)

你可以先实例化一个优化器对象,然后将它传入 model.compile(),像上述示例中一样, 或者你可以通过名称来调用优化器。在后一种情况下,将使用优化器的默认参数。

# 传入优化器名称: 默认参数将被采用
model.compile(loss='mean_squared_error', optimizer='sgd')

以上这篇如何在keras中添加自己的优化器(如adam等)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python Process创建进程的2种方法详解

    Python Process创建进程的2种方法详解

    这篇文章主要介绍了Python Process创建进程的2种方法详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-01-01
  • Django中的静态文件管理过程解析

    Django中的静态文件管理过程解析

    这篇文章主要介绍了Django中的静态文件管理过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-08-08
  • Python读写压缩文件的方法

    Python读写压缩文件的方法

    这篇文章主要介绍了Python读写压缩文件的方法,文中讲解非常细致,代码帮助大家更好的理解和学习,感兴趣的朋友可以了解下
    2020-07-07
  • python主线程与子线程的结束顺序实例解析

    python主线程与子线程的结束顺序实例解析

    这篇文章主要介绍了python主线程与子线程的结束顺序实例解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-12-12
  • Python安装OpenCV库超时失败解决

    Python安装OpenCV库超时失败解决

    使用pip installopencv-python安装时,安装速度很慢,本文主要介绍了Python安装OpenCV库超时失败,具有一定的参考价值,感兴趣的可以了解一下
    2024-05-05
  • Django如何创作一个简单的最小程序

    Django如何创作一个简单的最小程序

    这篇文章主要介绍了Django如何创作一个简单的最小程序,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-05-05
  • Python四款GUI图形界面库介绍

    Python四款GUI图形界面库介绍

    这篇文章介绍了Python的四款GUI图形界面库,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2022-06-06
  • python基础教程之获取本机ip数据包示例

    python基础教程之获取本机ip数据包示例

    本文主要介绍了python获取本机ip数据包的示例,大家参考使用吧
    2014-02-02
  • 详解Python+opencv裁剪/截取图片的几种方式

    详解Python+opencv裁剪/截取图片的几种方式

    这篇文章主要介绍了详解Python+opencv裁剪/截取图片的几种方式,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-04-04
  • Python3 循环语句(for、while、break、range等)

    Python3 循环语句(for、while、break、range等)

    这篇文章主要介绍了Python3 循环语句(for、while、break、range等),大家把下面的文章看完就基本上就可以了解了python的循环实现方式了
    2017-11-11

最新评论