keras的siamese(孪生网络)实现案例

 更新时间:2020年06月12日 14:20:25   作者:李上花开  
这篇文章主要介绍了keras的siamese(孪生网络)实现案例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

代码位于keras的官方样例,并做了微量修改和大量学习?。

最终效果:

import keras
import numpy as np
import matplotlib.pyplot as plt

import random

from keras.callbacks import TensorBoard
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Input, Flatten, Dense, Dropout, Lambda
from keras.optimizers import RMSprop
from keras import backend as K

num_classes = 10
epochs = 20


def euclidean_distance(vects):
 x, y = vects
 sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
 return K.sqrt(K.maximum(sum_square, K.epsilon()))


def eucl_dist_output_shape(shapes):
 shape1, shape2 = shapes
 return (shape1[0], 1)


def contrastive_loss(y_true, y_pred):
 '''Contrastive loss from Hadsell-et-al.'06
 http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
 '''
 margin = 1
 sqaure_pred = K.square(y_pred)
 margin_square = K.square(K.maximum(margin - y_pred, 0))
 return K.mean(y_true * sqaure_pred + (1 - y_true) * margin_square)


def create_pairs(x, digit_indices):
 '''Positive and negative pair creation.
 Alternates between positive and negative pairs.
 '''
 pairs = []
 labels = []
 n = min([len(digit_indices[d]) for d in range(num_classes)]) - 1
 for d in range(num_classes):
  for i in range(n):
   z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
   pairs += [[x[z1], x[z2]]]
   inc = random.randrange(1, num_classes)
   dn = (d + inc) % num_classes
   z1, z2 = digit_indices[d][i], digit_indices[dn][i]
   pairs += [[x[z1], x[z2]]]
   labels += [1, 0]
 return np.array(pairs), np.array(labels)


def create_base_network(input_shape):
 '''Base network to be shared (eq. to feature extraction).
 '''
 input = Input(shape=input_shape)
 x = Flatten()(input)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 return Model(input, x)


def compute_accuracy(y_true, y_pred): # numpy上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 pred = y_pred.ravel() < 0.5
 return np.mean(pred == y_true)


def accuracy(y_true, y_pred): # Tensor上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))

def plot_train_history(history, train_metrics, val_metrics):
 plt.plot(history.history.get(train_metrics), '-o')
 plt.plot(history.history.get(val_metrics), '-o')
 plt.ylabel(train_metrics)
 plt.xlabel('Epochs')
 plt.legend(['train', 'validation'])


# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
input_shape = x_train.shape[1:]

# create training+test positive and negative pairs
digit_indices = [np.where(y_train == i)[0] for i in range(num_classes)]
tr_pairs, tr_y = create_pairs(x_train, digit_indices)

digit_indices = [np.where(y_test == i)[0] for i in range(num_classes)]
te_pairs, te_y = create_pairs(x_test, digit_indices)

# network definition
base_network = create_base_network(input_shape)

input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)

# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)

distance = Lambda(euclidean_distance,
     output_shape=eucl_dist_output_shape)([processed_a, processed_b])

model = Model([input_a, input_b], distance)
keras.utils.plot_model(model,"siamModel.png",show_shapes=True)
model.summary()

# train
rms = RMSprop()
model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
history=model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y,
   batch_size=128,
   epochs=epochs,verbose=2,
   validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y))

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plot_train_history(history, 'loss', 'val_loss')
plt.subplot(1, 2, 2)
plot_train_history(history, 'accuracy', 'val_accuracy')
plt.show()


# compute final accuracy on training and test sets
y_pred = model.predict([tr_pairs[:, 0], tr_pairs[:, 1]])
tr_acc = compute_accuracy(tr_y, y_pred)
y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
te_acc = compute_accuracy(te_y, y_pred)

print('* Accuracy on training set: %0.2f%%' % (100 * tr_acc))
print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))

以上这篇keras的siamese(孪生网络)实现案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python字节串类型bytes及用法

    Python字节串类型bytes及用法

    这篇文章介绍了Python字节串类型bytes及用法,文中通过示例代码介绍的非常详细。对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2022-05-05
  • Python实现单项链表的最全教程

    Python实现单项链表的最全教程

    单向链表也叫单链表,是链表中最简单的一种形式,它的每个节点包含两个域,一个信息域(元素域)和一个链接域,这个链接指向链表中的下一个节点,而最后一个节点的链接域则指向一个空值,这篇文章主要介绍了Python实现单项链表,需要的朋友可以参考下
    2023-01-01
  • Keras使用ImageNet上预训练的模型方式

    Keras使用ImageNet上预训练的模型方式

    这篇文章主要介绍了Keras使用ImageNet上预训练的模型方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-05-05
  • Pytorch中的torch.where函数使用

    Pytorch中的torch.where函数使用

    这篇文章主要介绍了Pytorch中的torch.where函数使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2024-02-02
  • Python编程itertools模块处理可迭代集合相关函数

    Python编程itertools模块处理可迭代集合相关函数

    本篇博客将为你介绍Python函数式编程itertools模块中处理可迭代集合的相关函数,有需要的朋友可以借鉴参考下,希望可以有所帮助
    2021-09-09
  • python制作一个桌面便签软件

    python制作一个桌面便签软件

    这篇文章主要介绍了python制作一个桌面便签软件分别给大家附上ubuntu和windows版的程序及源码,有需要的小伙伴可以参考下。
    2015-08-08
  • darknet框架中YOLOv3对数据集进行训练和预测详解

    darknet框架中YOLOv3对数据集进行训练和预测详解

    这篇文章主要为大家介绍了darknet框架中YOLOv3对数据集进行训练和预测使用详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2022-11-11
  • python中通过Django捕获所有异常的处理

    python中通过Django捕获所有异常的处理

    诚然,每个人都会写bug,程序抛异常是一件很正常的事;既然异常总是会抛,那就想办法在抛出后,尽早解决才是王道。不能老是等待用户反馈异常和问题,万一用户懒得反馈了,岂不很尴尬
    2021-09-09
  • Python数据可视化常用4大绘图库原理详解

    Python数据可视化常用4大绘图库原理详解

    这篇文章主要介绍了Python数据可视化常用4大绘图库原理详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-10-10
  • set在python里的含义和用法

    set在python里的含义和用法

    在本篇内容中我们给大家整理了关于set在python里的用法含义等相关知识点内容,有兴趣的朋友们可以学习下。
    2019-06-06

最新评论