keras分类模型中的输入数据与标签的维度实例

 更新时间:2020年07月03日 09:44:03   作者:xytywh  
这篇文章主要介绍了keras分类模型中的输入数据与标签的维度实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

在《python深度学习》这本书中。

一、21页mnist十分类

导入数据集
from keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

初始数据维度:
>>> train_images.shape
(60000, 28, 28)
>>> len(train_labels)
60000
>>> train_labels
array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)

数据预处理:
train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype('float32') / 255
train_labels = to_categorical(train_labels)
  
之后:
print(train_images, type(train_images), train_images.shape, train_images.dtype)
print(train_labels, type(train_labels), train_labels.shape, train_labels.dtype)
结果:
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]] <class 'numpy.ndarray'> (60000, 784) float32
[[0. 0. 0. ... 0. 0. 0.]
 [1. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 1. 0.]] <class 'numpy.ndarray'> (60000, 10) float32

二、51页IMDB二分类

导入数据:

from keras.datasets import imdb (train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=10000)

参数 num_words=10000 的意思是仅保留训练数据中前 10 000 个最常出现的单词。

train_data和test_data都是numpy.ndarray类型,都是一维的(共25000个元素,相当于25000个list),其中每个list代表一条评论,每个list中的每个元素的值范围在0-9999 ,代表10000个最常见单词的每个单词的索引,每个list长度不一,因为每条评论的长度不一,例如train_data中的list最短的为11,最长的为189。

train_labels和test_labels都是含25000个元素(元素的值要不0或者1,代表两类)的list。

数据预处理:

# 将整数序列编码为二进制矩阵
def vectorize_sequences(sequences, dimension=10000):
 # Create an all-zero matrix of shape (len(sequences), dimension)
 results = np.zeros((len(sequences), dimension))
 for i, sequence in enumerate(sequences):
  results[i, sequence] = 1. # set specific indices of results[i] to 1s
 return results


x_train = vectorize_sequences(train_data)
x_test = vectorize_sequences(test_data)

第一种方式:shape为(25000,)
y_train = np.asarray(train_labels).astype('float32') #就用这种方式就行了
y_test = np.asarray(test_labels).astype('float32')
第二种方式:shape为(25000,1)
y_train = np.asarray(train_labels).astype('float32').reshape(25000, 1)
y_test = np.asarray(test_labels).astype('float32').reshape(25000, 1)
第三种方式:shape为(25000,2)
y_train = to_categorical(train_labels) #变成one-hot向量
y_test = to_categorical(test_labels)

第三种方式,相当于把二分类看成了多分类,所以网络的结构同时需要更改,

最后输出的维度:1->2

最后的激活函数:sigmoid->softmax

损失函数:binary_crossentropy->categorical_crossentropy

预处理之后,train_data和test_data变成了shape为(25000,10000),dtype为float32的ndarray(one-hot向量),train_labels和test_labels变成了shape为(25000,)的一维ndarray,或者(25000,1)的二维ndarray,或者shape为(25000,2)的one-hot向量。

注:

1.sigmoid对应binary_crossentropy,softmax对应categorical_crossentropy

2.网络的所有输入和目标都必须是浮点数张量

补充知识:keras输入数据的方法:model.fit和model.fit_generator

1.第一种,普通的不用数据增强的

from keras.datasets import mnist,cifar10,cifar100
(X_train, y_train), (X_valid, Y_valid) = cifar10.load_data() 
model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch, shuffle=True,
    verbose=1, validation_data=(X_valid, Y_valid), )

2.第二种,带数据增强的 ImageDataGenerator,可以旋转角度、平移等操作。

from keras.preprocessing.image import ImageDataGenerator
(trainX, trainY), (testX, testY) = cifar100.load_data()
trainX = trainX.astype('float32')
testX = testX.astype('float32')
trainX /= 255.
testX /= 255.
Y_train = np_utils.to_categorical(trainY, nb_classes)
Y_test = np_utils.to_categorical(testY, nb_classes)
generator = ImageDataGenerator(rotation_range=15,
        width_shift_range=5./32,
        height_shift_range=5./32)
generator.fit(trainX, seed=0)
model.fit_generator(generator.flow(trainX, Y_train, batch_size=batch_size),
     steps_per_epoch=len(trainX) // batch_size, epochs=nb_epoch,
     callbacks=callbacks,
     validation_data=(testX, Y_test),
     validation_steps=testX.shape[0] // batch_size, verbose=1)

以上这篇keras分类模型中的输入数据与标签的维度实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python使用sorted函数对列表进行排序的方法

    python使用sorted函数对列表进行排序的方法

    这篇文章主要介绍了python使用sorted函数对列表进行排序的方法,涉及Python使用sorted函数的技巧,非常具有实用价值,需要的朋友可以参考下
    2015-04-04
  • Python pandas数据合并merge函数用法详解

    Python pandas数据合并merge函数用法详解

    这篇文章主要给大家介绍了关于Python pandas数据合并merge函数用法的相关资料,数据分析中经常会遇到数据合并的基本问题,文中通过示例代码介绍的非常详细,需要的朋友可以参考下
    2023-07-07
  • 教你用Pygame制作简单的贪吃蛇游戏

    教你用Pygame制作简单的贪吃蛇游戏

    贪吃蛇(也叫做贪食蛇)游戏是一款休闲益智类游戏,既简单又耐玩,唯一的目标就是做这条gai上最长(pang)的蛇(zhu),这篇文章主要给大家介绍了关于如何使用Pygame制作简单的贪吃蛇游戏的相关资料,需要的朋友可以参考下
    2022-06-06
  • Python编程实现两个文件夹里文件的对比功能示例【包含内容的对比】

    Python编程实现两个文件夹里文件的对比功能示例【包含内容的对比】

    这篇文章主要介绍了Python编程实现两个文件夹里文件的对比功能,包含内容的对比操作,涉及Python文件与目录的遍历、比较、运算等相关操作技巧,需要的朋友可以参考下
    2017-06-06
  • python实现FTP服务器服务的方法

    python实现FTP服务器服务的方法

    本篇文章主要介绍了python实现FTP服务器的方法,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2017-04-04
  • python执行js脚本报错CryptoJS is not defined问题

    python执行js脚本报错CryptoJS is not defined问题

    这篇文章主要介绍了python执行js脚本报错CryptoJS is not defined问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2024-05-05
  • Python导出DBF文件到Excel的方法

    Python导出DBF文件到Excel的方法

    这篇文章主要介绍了Python导出DBF文件到Excel的方法,实例分析了Python基于win32com模块实现文件导出与转换的相关技巧,需要的朋友可以参考下
    2015-07-07
  • OpenCV实现机器人对物体进行移动跟随的方法实例

    OpenCV实现机器人对物体进行移动跟随的方法实例

    这篇文章主要给大家介绍了关于OpenCV实现机器人对物体进行移动跟随的相关资料,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-11-11
  • OpenCV结合selenium实现滑块验证码

    OpenCV结合selenium实现滑块验证码

    本文主要介绍了OpenCV结合selenium实现滑块验证码,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2021-08-08
  • Python爬虫实例扒取2345天气预报

    Python爬虫实例扒取2345天气预报

    本篇文章给大家详细分析了通过Python爬虫如何采集到2345的天气预报信息,有兴趣的朋友参考学习下吧。
    2018-03-03

最新评论