keras中的loss、optimizer、metrics用法

 更新时间:2020年06月15日 09:48:38   作者:wyf  
这篇文章主要介绍了keras中的loss、optimizer、metrics用法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

用keras搭好模型架构之后的下一步,就是执行编译操作。在编译时,经常需要指定三个参数

loss

optimizer

metrics

这三个参数有两类选择:

使用字符串

使用标识符,如keras.losses,keras.optimizers,metrics包下面的函数

例如:

sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy',
  optimizer=sgd,
  metrics=['accuracy'])

因为有时可以使用字符串,有时可以使用标识符,令人很想知道背后是如何操作的。下面分别针对optimizer,loss,metrics三种对象的获取进行研究。

optimizer

一个模型只能有一个optimizer,在执行编译的时候只能指定一个optimizer。

在keras.optimizers.py中,有一个get函数,用于根据用户传进来的optimizer参数获取优化器的实例:

def get(identifier):
 # 如果后端是tensorflow并且使用的是tensorflow自带的优化器实例,可以直接使用tensorflow原生的优化器 
 if K.backend() == 'tensorflow':
 # Wrap TF optimizer instances
 if isinstance(identifier, tf.train.Optimizer):
  return TFOptimizer(identifier)
 # 如果以json串的形式定义optimizer并进行参数配置
 if isinstance(identifier, dict):
 return deserialize(identifier)
 elif isinstance(identifier, six.string_types):
 # 如果以字符串形式指定optimizer,那么使用优化器的默认配置参数
 config = {'class_name': str(identifier), 'config': {}}
 return deserialize(config)
 if isinstance(identifier, Optimizer):
 # 如果使用keras封装的Optimizer的实例
 return identifier
 else:
 raise ValueError('Could not interpret optimizer identifier: ' +
    str(identifier))

其中,deserilize(config)函数的作用就是把optimizer反序列化制造一个实例。

loss

keras.losses函数也有一个get(identifier)方法。其中需要注意以下一点:

如果identifier是可调用的一个函数名,也就是一个自定义的损失函数,这个损失函数返回值是一个张量。这样就轻而易举的实现了自定义损失函数。除了使用str和dict类型的identifier,我们也可以直接使用keras.losses包下面的损失函数。

def get(identifier):
 if identifier is None:
 return None
 if isinstance(identifier, six.string_types):
 identifier = str(identifier)
 return deserialize(identifier)
 if isinstance(identifier, dict):
 return deserialize(identifier)
 elif callable(identifier):
 return identifier
 else:
 raise ValueError('Could not interpret '
    'loss function identifier:', identifier)

metrics

在model.compile()函数中,optimizer和loss都是单数形式,只有metrics是复数形式。因为一个模型只能指明一个optimizer和loss,却可以指明多个metrics。metrics也是三者中处理逻辑最为复杂的一个。

在keras最核心的地方keras.engine.train.py中有如下处理metrics的函数。这个函数其实就做了两件事:

根据输入的metric找到具体的metric对应的函数

计算metric张量

在寻找metric对应函数时,有两种步骤:

使用字符串形式指明准确率和交叉熵

使用keras.metrics.py中的函数

def handle_metrics(metrics, weights=None):
 metric_name_prefix = 'weighted_' if weights is not None else ''

 for metric in metrics:
 # 如果metrics是最常见的那种:accuracy,交叉熵
 if metric in ('accuracy', 'acc', 'crossentropy', 'ce'):
  # custom handling of accuracy/crossentropy
  # (because of class mode duality)
  output_shape = K.int_shape(self.outputs[i])
  # 如果输出维度是1或者损失函数是二分类损失函数,那么说明是个二分类问题,应该使用二分类的accuracy和二分类的的交叉熵
  if (output_shape[-1] == 1 or
  self.loss_functions[i] == losses.binary_crossentropy):
  # case: binary accuracy/crossentropy
  if metric in ('accuracy', 'acc'):
   metric_fn = metrics_module.binary_accuracy
  elif metric in ('crossentropy', 'ce'):
   metric_fn = metrics_module.binary_crossentropy
  # 如果损失函数是sparse_categorical_crossentropy,那么目标y_input就不是one-hot的,所以就需要使用sparse的多类准去率和sparse的多类交叉熵
  elif self.loss_functions[i] == losses.sparse_categorical_crossentropy:
  # case: categorical accuracy/crossentropy
  # with sparse targets
  if metric in ('accuracy', 'acc'):
   metric_fn = metrics_module.sparse_categorical_accuracy
  elif metric in ('crossentropy', 'ce'):
   metric_fn = metrics_module.sparse_categorical_crossentropy
  else:
  # case: categorical accuracy/crossentropy
  if metric in ('accuracy', 'acc'):
   metric_fn = metrics_module.categorical_accuracy
  elif metric in ('crossentropy', 'ce'):
   metric_fn = metrics_module.categorical_crossentropy
  if metric in ('accuracy', 'acc'):
   suffix = 'acc'
  elif metric in ('crossentropy', 'ce'):
   suffix = 'ce'
  weighted_metric_fn = weighted_masked_objective(metric_fn)
  metric_name = metric_name_prefix + suffix
 else:
  # 如果输入的metric不是字符串,那么就调用metrics模块获取
  metric_fn = metrics_module.get(metric)
  weighted_metric_fn = weighted_masked_objective(metric_fn)
  # Get metric name as string
  if hasattr(metric_fn, 'name'):
  metric_name = metric_fn.name
  else:
  metric_name = metric_fn.__name__
  metric_name = metric_name_prefix + metric_name

 with K.name_scope(metric_name):
  metric_result = weighted_metric_fn(y_true, y_pred,
      weights=weights,
      mask=masks[i])

 # Append to self.metrics_names, self.metric_tensors,
 # self.stateful_metric_names
 if len(self.output_names) > 1:
  metric_name = self.output_names[i] + '_' + metric_name
 # Dedupe name
 j = 1
 base_metric_name = metric_name
 while metric_name in self.metrics_names:
  metric_name = base_metric_name + '_' + str(j)
  j += 1
 self.metrics_names.append(metric_name)
 self.metrics_tensors.append(metric_result)

 # Keep track of state updates created by
 # stateful metrics (i.e. metrics layers).
 if isinstance(metric_fn, Layer) and metric_fn.stateful:
  self.stateful_metric_names.append(metric_name)
  self.stateful_metric_functions.append(metric_fn)
  self.metrics_updates += metric_fn.updates

无论怎么使用metric,最终都会变成metrics包下面的函数。当使用字符串形式指明accuracy和crossentropy时,keras会非常智能地确定应该使用metrics包下面的哪个函数。因为metrics包下的那些metric函数有不同的使用场景,例如:

有的处理的是one-hot形式的y_input(数据的类别),有的处理的是非one-hot形式的y_input

有的处理的是二分类问题的metric,有的处理的是多分类问题的metric

当使用字符串“accuracy”和“crossentropy”指明metric时,keras会根据损失函数、输出层的shape来确定具体应该使用哪个metric函数。在任何情况下,直接使用metrics下面的函数名是总不会出错的。

keras.metrics.py文件中也有一个get(identifier)函数用于获取metric函数。

def get(identifier):
 if isinstance(identifier, dict):
 config = {'class_name': str(identifier), 'config': {}}
 return deserialize(config)
 elif isinstance(identifier, six.string_types):
 return deserialize(str(identifier))
 elif callable(identifier):
 return identifier
 else:
 raise ValueError('Could not interpret '
    'metric function identifier:', identifier)

如果identifier是字符串或者字典,那么会根据identifier反序列化出一个metric函数。

如果identifier本身就是一个函数名,那么就直接返回这个函数名。这种方式就为自定义metric提供了巨大便利。

keras中的设计哲学堪称完美。

以上这篇keras中的loss、optimizer、metrics用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python-tkinter之按钮的使用,开关方法

    python-tkinter之按钮的使用,开关方法

    今天小编就为大家分享一篇python-tkinter之按钮的使用,开关方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-06-06
  • Python join()函数原理及使用方法

    Python join()函数原理及使用方法

    这篇文章主要介绍了Python join()函数原理及使用方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-11-11
  • 详解python3 + Scrapy爬虫学习之创建项目

    详解python3 + Scrapy爬虫学习之创建项目

    这篇文章主要介绍了python3 Scrapy爬虫创建项目,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-04-04
  • Scrapy中诡异xpath的匹配内容失效问题及解决

    Scrapy中诡异xpath的匹配内容失效问题及解决

    这篇文章主要介绍了Scrapy中诡异xpath的匹配内容失效问题及解决方案,具有很好的价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-12-12
  • python处理中文编码和判断编码示例

    python处理中文编码和判断编码示例

    在开发自用爬虫过程中,有的网页是utf-8,有的是gb2312,有的是gbk,如果不加处理,采集到的都是乱码,解决的方法是将html处理成统一的utf-8编码
    2014-02-02
  • 使用Python创建多功能文件管理器的代码示例

    使用Python创建多功能文件管理器的代码示例

    在本文中,我们将探索一个使用Python的wxPython库开发的文件管理器应用程序,这个应用程序不仅能够浏览和选择文件,还支持文件预览、压缩、图片转换以及生成PPT演示文稿的功能,需要的朋友可以参考下
    2024-08-08
  • Python多线程编程(六):可重入锁RLock

    Python多线程编程(六):可重入锁RLock

    这篇文章主要介绍了Python多线程编程(六):可重入锁RLock,本文直接给出使用实例,然后讲解如何使用RLock避免死锁,需要的朋友可以参考下
    2015-04-04
  • 实现ECharts双Y轴左右刻度线一致的例子

    实现ECharts双Y轴左右刻度线一致的例子

    这篇文章主要介绍了实现ECharts双Y轴左右刻度线一致的例子,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-05-05
  • python 插入Null值数据到Postgresql的操作

    python 插入Null值数据到Postgresql的操作

    这篇文章主要介绍了python 插入Null值数据到Postgresql的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-03-03
  • Django中template for如何使用方法

    Django中template for如何使用方法

    这篇文章主要介绍了Django中template for如何使用方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-01-01

最新评论