Keras自定义实现带masking的meanpooling层方式

 更新时间:2020年06月16日 11:55:17   作者:蕉叉熵  
这篇文章主要介绍了Keras自定义实现带masking的meanpooling层方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

Keras确实是一大神器,代码可以写得非常简洁,但是最近在写LSTM和DeepFM的时候,遇到了一个问题:样本的长度不一样。对不定长序列的一种预处理方法是,首先对数据进行padding补0,然后引入keras的Masking层,它能自动对0值进行过滤。

问题在于keras的某些层不支持Masking层处理过的输入数据,例如Flatten、AveragePooling1D等等,而其中meanpooling是我需要的一个运算。例如LSTM对每一个序列的输出长度都等于该序列的长度,那么均值运算就只应该除以序列长度,而不是padding后的最长长度。

例如下面这个 3x4 大小的张量,经过补零padding的。我希望做axis=1的meanpooling,则第一行应该是 (10+20)/2,第二行应该是 (10+20+30)/3,第三行应该是 (10+20+30+40)/4。

Keras如何自定义层

在 Keras2.0 版本中(如果你使用的是旧版本请更新),自定义一个层的方法参考这里。具体地,你只要实现三个方法即可。

build(input_shape) : 这是你定义层参数的地方。这个方法必须设self.built = True,可以通过调用super([Layer], self).build()完成。如果这个层没有需要训练的参数,可以不定义。

call(x) : 这里是编写层的功能逻辑的地方。你只需要关注传入call的第一个参数:输入张量,除非你希望你的层支持masking。

compute_output_shape(input_shape) : 如果你的层更改了输入张量的形状,你应该在这里定义形状变化的逻辑,这让Keras能够自动推断各层的形状。

下面是一个简单的例子:

from keras import backend as K
from keras.engine.topology import Layer
import numpy as np

class MyLayer(Layer):

 def __init__(self, output_dim, **kwargs):
 self.output_dim = output_dim
 super(MyLayer, self).__init__(**kwargs)

 def build(self, input_shape):
 # Create a trainable weight variable for this layer.
 self.kernel = self.add_weight(name='kernel', 
  shape=(input_shape[1], self.output_dim),
  initializer='uniform',
  trainable=True)
 super(MyLayer, self).build(input_shape) # Be sure to call this somewhere!

 def call(self, x):
 return K.dot(x, self.kernel)

 def compute_output_shape(self, input_shape):
 return (input_shape[0], self.output_dim)

Keras自定义层如何允许masking

观察了一些支持masking的层,发现他们对masking的支持体现在两方面。

在 __init__ 方法中设置 supports_masking=True。

实现一个compute_mask方法,用于将mask传到下一层。

部分层会在call中调用传入的mask。

自定义实现带masking的meanpooling

假设输入是3d的。首先,在__init__方法中设置self.supports_masking = True,然后在call中实现相应的计算。

from keras import backend as K
from keras.engine.topology import Layer
import tensorflow as tf

class MyMeanPool(Layer):
 def __init__(self, axis, **kwargs):
 self.supports_masking = True
 self.axis = axis
 super(MyMeanPool, self).__init__(**kwargs)

 def compute_mask(self, input, input_mask=None):
 # need not to pass the mask to next layers
 return None

 def call(self, x, mask=None):
 if mask is not None:
 mask = K.repeat(mask, x.shape[-1])
 mask = tf.transpose(mask, [0,2,1])
 mask = K.cast(mask, K.floatx())
 x = x * mask
 return K.sum(x, axis=self.axis) / K.sum(mask, axis=self.axis)
 else:
 return K.mean(x, axis=self.axis)

 def compute_output_shape(self, input_shape):
 output_shape = []
 for i in range(len(input_shape)):
 if i!=self.axis:
 output_shape.append(input_shape[i])
 return tuple(output_shape)

使用举例:

from keras.layers import Input, Masking
from keras.models import Model
from MyMeanPooling import MyMeanPool

data = [[[10,10],[0, 0 ],[0, 0 ],[0, 0 ]],
 [[10,10],[20,20],[0, 0 ],[0, 0 ]],
 [[10,10],[20,20],[30,30],[0, 0 ]],
 [[10,10],[20,20],[30,30],[40,40]]]

A = Input(shape=[4,2]) # None * 4 * 2
mA = Masking()(A)
out = MyMeanPool(axis=1)(mA)

model = Model(inputs=[A], outputs=[out])

print model.summary()
print model.predict(data)

结果如下,每一行对应一个样本的结果,例如第一个样本只有第一个时刻有值,输出结果是[10. 10. ],是正确的。

[[10. 10.]
 [15. 15.]
 [20. 20.]
 [25. 25.]]

在DeepFM中,每个样本都是由ID构成的,多值field往往会导致样本长度不一的情况,例如interest这样的field,同一个样本可能在该field中有多项取值,毕竟每个人的兴趣点不止一项。

采取padding的方法将每个field的特征补长到最长的长度,则数据尺寸是 [batch_size, max_timestep],经过Embedding为每个样本的每个特征ID配一个latent vector,数据尺寸将变为 [batch_size, max_timestep,latent_dim]。

我们希望每一个field的Embedding之后的尺寸为[batch_size, latent_dim],然后进行concat操作横向拼接,所以这里就可以使用自定义的MeanPool层了。希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • PyCharm-错误-找不到指定文件python.exe的解决方法

    PyCharm-错误-找不到指定文件python.exe的解决方法

    今天小编就为大家分享一篇PyCharm-错误-找不到指定文件python.exe的解决方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-07-07
  • Python用二分法求平方根的案例

    Python用二分法求平方根的案例

    这篇文章主要介绍了Python用二分法求平方根的案例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-03-03
  • Python中try excpet BaseException(异常处理捕获)的使用

    Python中try excpet BaseException(异常处理捕获)的使用

    本文主要介绍了Python中try excpet BaseException(异常处理捕获)的使用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-03-03
  • macOS M1(Apple Silicon)安装配置Conda环境的具体实现

    macOS M1(Apple Silicon)安装配置Conda环境的具体实现

    由于常用的Anaconda和Miniconda现在都没有提供M1处理器支持的conda环境,以下是conda-forge提供的miniforge,感兴趣的可以了解一下
    2021-08-08
  • python PyQt5中QButtonGroup的详细用法解析与应用实战记录

    python PyQt5中QButtonGroup的详细用法解析与应用实战记录

    在PyQt5中,QButtonGroup是一个用于管理按钮互斥性和信号槽连接的类,它可以将多个按钮划分为一个组,管理按钮的选中状态和ID,本文详细介绍了QButtonGroup的创建、使用方法和实际应用案例,适合需要在PyQt5项目中高效管理按钮组的开发者
    2024-10-10
  • 一文掌握Python爬虫XPath语法

    一文掌握Python爬虫XPath语法

    这篇文章主要介绍了一文掌握Python爬虫XPath语法,xpath是一门在XML和HTML文档中查找信息的语言,可用来在XML和HTML文档中对元素和属性进行遍历,XPath 通过使用路径表达式来选取 XML 文档中的节点或者节点集。下面会更学习的介绍,需要的朋友可以参考一下
    2021-11-11
  • 浅析Python函数式编程

    浅析Python函数式编程

    在本篇文章中我们给大家分享了关于Python函数式编程的相关知识点内容,有兴趣的朋友参考下。
    2018-10-10
  • Python2和Python3之间的str处理方式导致乱码的讲解

    Python2和Python3之间的str处理方式导致乱码的讲解

    今天小编就为大家分享一篇关于Python2和Python3之间的str处理方式导致乱码的讲解,小编觉得内容挺不错的,现在分享给大家,具有很好的参考价值,需要的朋友一起跟随小编来看看吧
    2019-01-01
  • Python协程操作之gevent(yield阻塞,greenlet),协程实现多任务(有规律的交替协作执行)用法详解

    Python协程操作之gevent(yield阻塞,greenlet),协程实现多任务(有规律的交替协作执行)用法详解

    这篇文章主要介绍了Python协程操作之gevent(yield阻塞,greenlet),协程实现多任务(有规律的交替协作执行)用法,结合实例形式较为详细的分析了协程的功能、原理及gevent、greenlet实现协程,以及协程实现多任务相关操作技巧,需要的朋友可以参考下
    2019-10-10
  • Python如何利用opencv实现手势识别

    Python如何利用opencv实现手势识别

    这篇文章主要介绍了Python如何利用opencv实现手势识别,文章围绕主题展开详细的内容介绍,具有一定的参考价值,需要的小伙可以参考一下
    2022-05-05

最新评论