Python利用DNN实现宝石识别

 更新时间:2022年01月28日 10:01:02   作者:Livingbody  
深度神经网络(Deep Neural Networks,简称DNN)是深度学习的基础,其结构为input、hidden(可有多层)、output,每层均为全连接。本文将利用DNN实现宝石识别,感兴趣的可以了解一下

任务描述

本次实践是一个多分类任务,需要将照片中的宝石分别进行识别,完成宝石的识别

实践平台:百度AI实训平台-AI Studio、PaddlePaddle1.8.0 动态图

深度神经网络(DNN)

深度神经网络(Deep Neural Networks,简称DNN)是深度学习的基础,其结构为input、hidden(可有多层)、output,每层均为全连接。

数据集介绍

  • 数据集文件名为archive_train.zip,archive_test.zip。
  • 该数据集包含25个类别不同宝石的图像。
  • 这些类别已经分为训练和测试数据。
  • 图像大小不一,格式为.jpeg。

# 查看当前挂载的数据集目录, 该目录下的变更重启环境后会自动还原
# View dataset directory. This directory will be recovered automatically after resetting environment. 
!ls /home/aistudio/data

data55032 dataset

#导入需要的包
import os
import zipfile
import random
import json
import cv2
import numpy as np
from PIL import Image
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import Linear
import matplotlib.pyplot as plt

1.数据准备

'''
参数配置
'''
train_parameters = {
    "input_size": [3, 64, 64],                           #输入图片的shape
    "class_dim": -1,                                     #分类数
    'augment_path' : '/home/aistudio/augment',           #数据增强图片目录
    "src_path":"data/data55032/archive_train.zip",       #原始数据集路径
    "target_path":"/home/aistudio/data/dataset",        #要解压的路径 
    "train_list_path": "./train_data.txt",              #train_data.txt路径
    "eval_list_path": "./val_data.txt",                  #eval_data.txt路径
    "label_dict":{},                                    #标签字典
    "readme_path": "/home/aistudio/data/readme.json",   #readme.json路径
    "num_epochs": 20,                                    #训练轮数
    "train_batch_size": 64,                             #批次的大小
    "learning_strategy": {                              #优化函数相关的配置
        "lr": 0.001                                     #超参数学习率
    } 
}
def unzip_data(src_path,target_path):

    '''
    解压原始数据集,将src_path路径下的zip包解压至data/dataset目录下
    '''

    if(not os.path.isdir(target_path)):    
        z = zipfile.ZipFile(src_path, 'r')
        z.extractall(path=target_path)
        z.close()
    else:
        print("文件已解压")
def get_data_list(target_path,train_list_path,eval_list_path, augment_path):
    '''
    生成数据列表
    '''
    #存放所有类别的信息
    class_detail = []
    #获取所有类别保存的文件夹名称
    data_list_path=target_path
    class_dirs = os.listdir(data_list_path)
    if '__MACOSX' in class_dirs:
        class_dirs.remove('__MACOSX')
    # #总的图像数量
    all_class_images = 0
    # #存放类别标签
    class_label=0
    # #存放类别数目
    class_dim = 0
    # #存储要写进eval.txt和train.txt中的内容
    trainer_list=[]
    eval_list=[]
    #读取每个类别
    for class_dir in class_dirs:
        if class_dir != ".DS_Store":
            class_dim += 1
            #每个类别的信息
            class_detail_list = {}
            eval_sum = 0
            trainer_sum = 0
            #统计每个类别有多少张图片
            class_sum = 0
            #获取类别路径 
            path = os.path.join(data_list_path,class_dir)
            # print(path)
            # 获取所有图片
            img_paths = os.listdir(path)
            for img_path in img_paths:                                  # 遍历文件夹下的每个图片
                if img_path =='.DS_Store':
                    continue
                name_path = os.path.join(path,img_path)                       # 每张图片的路径
                if class_sum % 15 == 0:                                 # 每10张图片取一个做验证数据
                    eval_sum += 1                                       # eval_sum为测试数据的数目
                    eval_list.append(name_path + "\t%d" % class_label + "\n")
                else:
                    trainer_sum += 1 
                    trainer_list.append(name_path + "\t%d" % class_label + "\n")#trainer_sum测试数据的数目
                class_sum += 1                                          #每类图片的数目
                all_class_images += 1                                   #所有类图片的数目 
            # ----------------------------------数据增强----------------------------------
            aug_path = os.path.join(augment_path, class_dir)
            for img_path in os.listdir(aug_path):                                  # 遍历文件夹下的每个图片
                name_path = os.path.join(aug_path,img_path)                       # 每张图片的路径
                trainer_sum += 1 
                trainer_list.append(name_path + "\t%d" % class_label + "\n")#trainer_sum测试数据的数目
                all_class_images += 1                                   #所有类图片的数目
            # ----------------------------------------------------------------------------
            # 说明的json文件的class_detail数据
            class_detail_list['class_name'] = class_dir             #类别名称
            class_detail_list['class_label'] = class_label          #类别标签
            class_detail_list['class_eval_images'] = eval_sum       #该类数据的测试集数目
            class_detail_list['class_trainer_images'] = trainer_sum #该类数据的训练集数目
            class_detail.append(class_detail_list)  
            #初始化标签列表
            train_parameters['label_dict'][str(class_label)] = class_dir
            class_label += 1
            
    #初始化分类数
    train_parameters['class_dim'] = class_dim
    print(train_parameters)
    #乱序  
    random.shuffle(eval_list)
    with open(eval_list_path, 'a') as f:
        for eval_image in eval_list:
            f.write(eval_image) 
    #乱序        
    random.shuffle(trainer_list) 
    with open(train_list_path, 'a') as f2:
        for train_image in trainer_list:
            f2.write(train_image) 

    # 说明的json文件信息
    readjson = {}
    readjson['all_class_name'] = data_list_path                  #文件父目录
    readjson['all_class_images'] = all_class_images
    readjson['class_detail'] = class_detail
    jsons = json.dumps(readjson, sort_keys=True, indent=4, separators=(',', ': '))
    with open(train_parameters['readme_path'],'w') as f:
        f.write(jsons)
    print ('生成数据列表完成!')
def data_reader(file_list):
    '''
    自定义data_reader
    '''
    def reader():
        with open(file_list, 'r') as f:
            lines = [line.strip() for line in f]
            for line in lines:
                img_path, lab = line.strip().split('\t')
                img = Image.open(img_path) 
                if img.mode != 'RGB': 
                    img = img.convert('RGB') 
                img = img.resize((64, 64), Image.BILINEAR)
                img = np.array(img).astype('float32') 
                img = img.transpose((2, 0, 1))  # HWC to CHW 
                img = img/255                   # 像素值归一化 
                yield img, int(lab) 
    return reader
!pip install Augmentor
Looking in indexes: https://mirror.baidu.com/pypi/simple/
Requirement already satisfied: Augmentor in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (0.2.8)
Requirement already satisfied: tqdm>=4.9.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Augmentor) (4.36.1)
Requirement already satisfied: future>=0.16.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Augmentor) (0.18.0)
Requirement already satisfied: numpy>=1.11.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Augmentor) (1.16.4)
Requirement already satisfied: Pillow>=5.2.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Augmentor) (7.1.2)

'''
参数初始化
'''
src_path=train_parameters['src_path']
target_path=train_parameters['target_path']
train_list_path=train_parameters['train_list_path']
eval_list_path=train_parameters['eval_list_path']
batch_size=train_parameters['train_batch_size']
augment_path = train_parameters['augment_path']
'''
解压原始数据到指定路径
'''
unzip_data(src_path,target_path)

文件已解压

def proc_img(src):
    for root, dirs, files in os.walk(src):
        if '__MACOSX' in root:continue
        for file in files:            
            src=os.path.join(root,file)
            img=Image.open(src)
            if img.mode != 'RGB': 
                    img = img.convert('RGB') 
                    img.save(src)            


if __name__=='__main__':
    proc_img(r"data/dataset")
import os, Augmentor
import shutil, glob

if not os.path.exists(augment_path): # 控制不重复增强数据
    for root, dirs, files in os.walk("data/dataset", topdown=False):
        for name in dirs:
            path_ = os.path.join(root, name)
            if '__MACOSX' in path_:continue
            print('数据增强:',os.path.join(root, name))
            print('image:',os.path.join(root, name))
            p = Augmentor.Pipeline(os.path.join(root, name),output_directory='output')
            p.rotate(probability=0.6, max_left_rotation=2, max_right_rotation=2)
            p.zoom(probability=0.6, min_factor=0.9, max_factor=1.1)
            p.random_distortion(probability=0.4, grid_height=2, grid_width=2, magnitude=1)

            count = 1000 - len(glob.glob(pathname=path_+'/*.jpg'))
            p.sample(count, multi_threaded=False)
            p.process()

    print('将生成的图片拷贝到正确的目录')
    for root, dirs, files in os.walk("data/dataset", topdown=False):
        for name in files:
            path_ = os.path.join(root, name)
            if path_.rsplit('/',3)[2] == 'output':
                type_ = path_.rsplit('/',3)[1]
                dest_dir = os.path.join(augment_path ,type_) 
                if not os.path.exists(dest_dir):os.makedirs(dest_dir) 
                dest_path_ = os.path.join(augment_path ,type_, name) 
                shutil.move(path_, dest_path_)
    print('删除所有output目录')
    for root, dirs, files in os.walk("data/dataset", topdown=False):
        for name in dirs:
            if name == 'output':
                path_ = os.path.join(root, name)
                shutil.rmtree(path_)
    print('完成数据增强')
Processing kunzite_20.jpg:   1%|          | 11/968 [00:00<00:14, 65.61 Samples/s]

数据增强: data/dataset/Kunzite
image: data/dataset/Kunzite
Initialised with 32 image(s) found.
Output directory set to data/dataset/Kunzite/output.

Processing kunzite_14.jpg:   2%|▏         | 24/968 [00:00<00:17, 54.43 Samples/s]Processing kunzite_15.jpg: 100%|██████████| 968/968 [00:15<00:00, 61.57 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=350x366 at 0x7F7060EB06D0>: 100%|██████████| 32/32 [00:00<00:00, 269.33 Samples/s]                  
Processing almandine_5.jpg:   1%|          | 6/969 [00:00<00:20, 45.91 Samples/s] 

数据增强: data/dataset/Almandine
image: data/dataset/Almandine
Initialised with 31 image(s) found.
Output directory set to data/dataset/Almandine/output.

Processing almandine_2.jpg:   1%|▏         | 14/969 [00:00<00:27, 34.12 Samples/s] Processing almandine_25.jpg: 100%|██████████| 969/969 [00:22<00:00, 42.25 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=225x225 at 0x7F705E020C90>: 100%|██████████| 31/31 [00:00<00:00, 173.21 Samples/s]                
Processing emerald_2.jpg:   1%|          | 10/964 [00:00<00:16, 58.72 Samples/s]

数据增强: data/dataset/Emerald
image: data/dataset/Emerald
Initialised with 36 image(s) found.
Output directory set to data/dataset/Emerald/output.

Processing emerald_36.jpg:   2%|▏         | 20/964 [00:00<00:17, 54.08 Samples/s]Processing emerald_15.jpg: 100%|██████████| 964/964 [00:26<00:00, 36.49 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=460x460 at 0x7F705DED0110>: 100%|██████████| 36/36 [00:00<00:00, 149.48 Samples/s]                   
Processing sapphire blue_9.jpg:   1%|          | 10/966 [00:00<00:13, 68.91 Samples/s]

数据增强: data/dataset/Sapphire Blue
image: data/dataset/Sapphire Blue
Initialised with 34 image(s) found.
Output directory set to data/dataset/Sapphire Blue/output.

Processing sapphire blue_16.jpg:   2%|▏         | 22/966 [00:00<00:16, 56.52 Samples/s]Processing sapphire blue_30.jpg: 100%|██████████| 966/966 [00:18<00:00, 53.08 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=450x450 at 0x7F706885B810>: 100%|██████████| 34/34 [00:00<00:00, 177.29 Samples/s]                  
Processing malachite_2.jpg:   1%|          | 10/972 [00:00<00:20, 47.64 Samples/s]

数据增强: data/dataset/Malachite
image: data/dataset/Malachite
Initialised with 28 image(s) found.
Output directory set to data/dataset/Malachite/output.

Processing malachite_16.jpg:   2%|▏         | 18/972 [00:00<00:20, 47.14 Samples/s]Processing malachite_22.jpg: 100%|██████████| 972/972 [00:18<00:00, 52.32 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=376x262 at 0x7F7060E93D10>: 100%|██████████| 28/28 [00:00<00:00, 173.34 Samples/s]
Processing alexandrite_0.jpg:   1%|          | 6/966 [00:00<00:24, 39.61 Samples/s] 

数据增强: data/dataset/Alexandrite
image: data/dataset/Alexandrite
Initialised with 34 image(s) found.
Output directory set to data/dataset/Alexandrite/output.

Processing alexandrite_23.jpg:   2%|▏         | 18/966 [00:00<00:21, 44.52 Samples/s]Processing alexandrite_20.jpg: 100%|██████████| 966/966 [00:20<00:00, 48.06 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=500x500 at 0x7F705E025B10>: 100%|██████████| 34/34 [00:00<00:00, 129.49 Samples/s]                 
Processing zircon_8.jpg:   1%|          | 5/967 [00:00<00:33, 28.43 Samples/s] 

数据增强: data/dataset/Zircon
image: data/dataset/Zircon
Initialised with 33 image(s) found.
Output directory set to data/dataset/Zircon/output.

Processing zircon_23.jpg:   1%|          | 6/967 [00:00<00:33, 28.43 Samples/s]Processing zircon_24.jpg: 100%|██████████| 967/967 [00:24<00:00, 38.88 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=500x500 at 0x7F705DEAC3D0>: 100%|██████████| 33/33 [00:00<00:00, 134.76 Samples/s]                 
Processing onyx black_16.jpg:   1%|          | 8/972 [00:00<00:13, 69.17 Samples/s]

数据增强: data/dataset/Onyx Black
image: data/dataset/Onyx Black
Initialised with 28 image(s) found.
Output directory set to data/dataset/Onyx Black/output.

Processing onyx black_6.jpg:   2%|▏         | 18/972 [00:00<00:18, 51.84 Samples/s] Processing onyx black_2.jpg: 100%|██████████| 972/972 [00:18<00:00, 53.19 Samples/s] 
Processing <PIL.Image.Image image mode=RGB size=290x290 at 0x7F705DEE1910>: 100%|██████████| 28/28 [00:00<00:00, 131.50 Samples/s]                 
Processing rhodochrosite_29.jpg:   1%|          | 10/971 [00:00<00:18, 53.20 Samples/s]

数据增强: data/dataset/Rhodochrosite
image: data/dataset/Rhodochrosite
Initialised with 29 image(s) found.
Output directory set to data/dataset/Rhodochrosite/output.

Processing rhodochrosite_21.jpg:   2%|▏         | 21/971 [00:00<00:16, 58.01 Samples/s]Processing rhodochrosite_15.jpg: 100%|██████████| 971/971 [00:20<00:00, 46.42 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=373x356 at 0x7F705E011910>: 100%|██████████| 29/29 [00:00<00:00, 243.76 Samples/s]                  
Processing diamond_16.jpg:   1%|          | 5/969 [00:00<00:28, 34.31 Samples/s]

数据增强: data/dataset/Diamond
image: data/dataset/Diamond
Initialised with 31 image(s) found.
Output directory set to data/dataset/Diamond/output.

Processing diamond_6.jpg:   1%|          | 11/969 [00:00<00:26, 35.79 Samples/s] Processing diamond_20.jpg: 100%|██████████| 969/969 [00:24<00:00, 40.22 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=400x400 at 0x7F705DE6CCD0>: 100%|██████████| 31/31 [00:00<00:00, 150.83 Samples/s]                
Processing benitoite_29.jpg:   1%|          | 7/969 [00:00<00:15, 63.04 Samples/s]

数据增强: data/dataset/Benitoite
image: data/dataset/Benitoite
Initialised with 31 image(s) found.
Output directory set to data/dataset/Benitoite/output.

Processing benitoite_2.jpg:   2%|▏         | 24/969 [00:00<00:16, 57.15 Samples/s] Processing benitoite_12.jpg: 100%|██████████| 969/969 [00:17<00:00, 55.09 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=472x433 at 0x7F705DFE9290>: 100%|██████████| 31/31 [00:00<00:00, 178.70 Samples/s]                 
Processing pearl_0.jpg:   1%|          | 6/967 [00:00<00:25, 38.13 Samples/s] 

数据增强: data/dataset/Pearl
image: data/dataset/Pearl
Initialised with 33 image(s) found.
Output directory set to data/dataset/Pearl/output.

Processing pearl_32.jpg:   2%|▏         | 21/967 [00:00<00:20, 47.09 Samples/s]Processing pearl_12.jpg: 100%|██████████| 967/967 [00:17<00:00, 54.49 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=301x301 at 0x7F705E020A50>: 100%|██████████| 33/33 [00:00<00:00, 205.47 Samples/s]                 
Processing beryl golden_39.jpg:   1%|          | 11/964 [00:00<00:12, 79.36 Samples/s]

数据增强: data/dataset/Beryl Golden
image: data/dataset/Beryl Golden
Initialised with 36 image(s) found.
Output directory set to data/dataset/Beryl Golden/output.

Processing beryl golden_29.jpg:   2%|▏         | 22/964 [00:00<00:14, 63.92 Samples/s]Processing beryl golden_2.jpg: 100%|██████████| 964/964 [00:16<00:00, 58.61 Samples/s] 
Processing <PIL.Image.Image image mode=RGB size=290x290 at 0x7F705DE6F910>: 100%|██████████| 36/36 [00:00<00:00, 273.71 Samples/s]                
Processing labradorite_16.jpg:   1%|          | 9/960 [00:00<00:17, 55.49 Samples/s]

数据增强: data/dataset/Labradorite
image: data/dataset/Labradorite
Initialised with 40 image(s) found.
Output directory set to data/dataset/Labradorite/output.

Processing labradorite_17.jpg:   2%|▏         | 20/960 [00:00<00:18, 52.03 Samples/s]Processing labradorite_11.jpg: 100%|██████████| 960/960 [00:21<00:00, 45.63 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=400x400 at 0x7F705DE70F10>: 100%|██████████| 40/40 [00:00<00:00, 117.40 Samples/s]                 
Processing fluorite_23.jpg:   1%|          | 11/968 [00:00<00:14, 65.24 Samples/s]

数据增强: data/dataset/Fluorite
image: data/dataset/Fluorite
Initialised with 32 image(s) found.
Output directory set to data/dataset/Fluorite/output.

Processing fluorite_4.jpg:   1%|▏         | 14/968 [00:00<00:19, 49.03 Samples/s] Processing fluorite_4.jpg: 100%|██████████| 968/968 [00:21<00:00, 44.39 Samples/s] 
Processing <PIL.Image.Image image mode=RGB size=500x442 at 0x7F705DE87CD0>: 100%|██████████| 32/32 [00:00<00:00, 169.43 Samples/s]                 
Processing iolite_2.jpg:   1%|          | 7/968 [00:00<00:24, 39.15 Samples/s] 

数据增强: data/dataset/Iolite
image: data/dataset/Iolite
Initialised with 32 image(s) found.
Output directory set to data/dataset/Iolite/output.

Processing iolite_35.jpg:   2%|▏         | 23/968 [00:00<00:18, 51.39 Samples/s]Processing iolite_23.jpg: 100%|██████████| 968/968 [00:16<00:00, 57.22 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=290x290 at 0x7F705DE764D0>: 100%|██████████| 32/32 [00:00<00:00, 373.16 Samples/s]                  
Processing quartz beer_24.jpg:   1%|          | 12/965 [00:00<00:16, 57.87 Samples/s]

数据增强: data/dataset/Quartz Beer
image: data/dataset/Quartz Beer
Initialised with 35 image(s) found.
Output directory set to data/dataset/Quartz Beer/output.

Processing quartz beer_28.jpg:   2%|▏         | 24/965 [00:00<00:14, 65.30 Samples/s]Processing quartz beer_30.jpg: 100%|██████████| 965/965 [00:16<00:00, 59.48 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=300x300 at 0x7F705DE82DD0>: 100%|██████████| 35/35 [00:00<00:00, 173.58 Samples/s]                 
Processing garnet red_21.jpg:   1%|          | 7/964 [00:00<00:34, 27.76 Samples/s]

数据增强: data/dataset/Garnet Red
image: data/dataset/Garnet Red
Initialised with 36 image(s) found.
Output directory set to data/dataset/Garnet Red/output.

Processing garnet red_2.jpg:   2%|▏         | 17/964 [00:00<00:28, 33.50 Samples/s] Processing garnet red_2.jpg: 100%|██████████| 964/964 [00:20<00:00, 46.97 Samples/s] 
Processing <PIL.Image.Image image mode=RGB size=301x301 at 0x7F705E020090>: 100%|██████████| 36/36 [00:00<00:00, 197.00 Samples/s]                 
Processing danburite_35.jpg:   1%|          | 8/968 [00:00<00:16, 58.65 Samples/s]

数据增强: data/dataset/Danburite
image: data/dataset/Danburite
Initialised with 32 image(s) found.
Output directory set to data/dataset/Danburite/output.

Processing danburite_32.jpg:   2%|▏         | 17/968 [00:00<00:19, 49.88 Samples/s]Processing danburite_23.jpg: 100%|██████████| 968/968 [00:19<00:00, 50.58 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=225x225 at 0x7F705DE78390>: 100%|██████████| 32/32 [00:00<00:00, 144.25 Samples/s]                 
Processing cats eye_7.jpg:   1%|          | 8/969 [00:00<00:24, 39.01 Samples/s] 

数据增强: data/dataset/Cats Eye
image: data/dataset/Cats Eye
Initialised with 31 image(s) found.
Output directory set to data/dataset/Cats Eye/output.

Processing cats eye_26.jpg:   2%|▏         | 15/969 [00:00<00:23, 41.33 Samples/s]Processing cats eye_33.jpg: 100%|██████████| 969/969 [00:25<00:00, 38.19 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=401x401 at 0x7F706AF09510>: 100%|██████████| 31/31 [00:00<00:00, 214.03 Samples/s]                 
Processing hessonite_1.jpg:   0%|          | 3/970 [00:00<00:33, 28.84 Samples/s] 

数据增强: data/dataset/Hessonite
image: data/dataset/Hessonite
Initialised with 30 image(s) found.
Output directory set to data/dataset/Hessonite/output.

Processing hessonite_19.jpg:   1%|▏         | 13/970 [00:00<00:31, 30.34 Samples/s]Processing hessonite_33.jpg: 100%|██████████| 970/970 [00:20<00:00, 47.73 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=301x301 at 0x7F705E020610>: 100%|██████████| 30/30 [00:00<00:00, 162.33 Samples/s]                 
Processing carnelian_12.jpg:   1%|          | 5/967 [00:00<00:28, 34.19 Samples/s]

数据增强: data/dataset/Carnelian
image: data/dataset/Carnelian
Initialised with 33 image(s) found.
Output directory set to data/dataset/Carnelian/output.

Processing carnelian_32.jpg:   1%|          | 12/967 [00:00<00:29, 32.65 Samples/s]Processing carnelian_31.jpg: 100%|██████████| 967/967 [00:24<00:00, 39.93 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=425x425 at 0x7F705DE840D0>: 100%|██████████| 33/33 [00:00<00:00, 147.85 Samples/s]                 
Processing jade_26.jpg:   1%|          | 9/972 [00:00<00:25, 38.24 Samples/s]

数据增强: data/dataset/Jade
image: data/dataset/Jade
Initialised with 28 image(s) found.
Output directory set to data/dataset/Jade/output.

Processing jade_20.jpg:   2%|▏         | 22/972 [00:00<00:19, 47.93 Samples/s]Processing jade_18.jpg: 100%|██████████| 972/972 [00:18<00:00, 51.18 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=290x290 at 0x7F705DE8B050>: 100%|██████████| 28/28 [00:00<00:00, 331.02 Samples/s]                  
Processing variscite_22.jpg:   1%|          | 5/970 [00:00<00:25, 37.31 Samples/s]

数据增强: data/dataset/Variscite
image: data/dataset/Variscite
Initialised with 30 image(s) found.
Output directory set to data/dataset/Variscite/output.

Processing variscite_10.jpg:   1%|▏         | 13/970 [00:00<00:26, 35.70 Samples/s]Processing variscite_31.jpg: 100%|██████████| 970/970 [00:21<00:00, 45.58 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=225x225 at 0x7F705DE7BE50>: 100%|██████████| 30/30 [00:00<00:00, 157.22 Samples/s]                 
Processing tanzanite_2.jpg:   1%|          | 5/964 [00:00<00:31, 30.52 Samples/s] 

数据增强: data/dataset/Tanzanite
image: data/dataset/Tanzanite
Initialised with 36 image(s) found.
Output directory set to data/dataset/Tanzanite/output.

Processing tanzanite_15.jpg:   2%|▏         | 15/964 [00:00<00:25, 36.60 Samples/s]Processing tanzanite_37.jpg: 100%|██████████| 964/964 [00:25<00:00, 38.41 Samples/s]
Processing <PIL.Image.Image image mode=RGB size=225x225 at 0x7F705E00E4D0>: 100%|██████████| 36/36 [00:00<00:00, 144.18 Samples/s]                 


将生成的图片拷贝到正确的目录
删除所有output目录
完成数据增强


#每次生成数据列表前,首先清空train.txt和eval.txt
with open(train_list_path, 'w') as f: 
    f.seek(0)
    f.truncate() 
with open(eval_list_path, 'w') as f: 
    f.seek(0)
    f.truncate() 
    
#生成数据列表   
get_data_list(target_path,train_list_path,eval_list_path,augment_path)

'''
构造数据提供器
'''
train_reader = paddle.batch(data_reader(train_list_path),
                            batch_size=batch_size,
                            drop_last=True)
eval_reader = paddle.batch(data_reader(eval_list_path),
                            batch_size=batch_size,
                            drop_last=True)
{'input_size': [3, 64, 64], 'class_dim': 25, 'augment_path': '/home/aistudio/augment', 'src_path': 'data/data55032/archive_train.zip', 'target_path': '/home/aistudio/data/dataset', 'train_list_path': './train_data.txt', 'eval_list_path': './val_data.txt', 'label_dict': {'0': 'Kunzite', '1': 'Almandine', '2': 'Emerald', '3': 'Sapphire Blue', '4': 'Malachite', '5': 'Alexandrite', '6': 'Zircon', '7': 'Onyx Black', '8': 'Rhodochrosite', '9': 'Diamond', '10': 'Benitoite', '11': 'Pearl', '12': 'Beryl Golden', '13': 'Labradorite', '14': 'Fluorite', '15': 'Iolite', '16': 'Quartz Beer', '17': 'Garnet Red', '18': 'Danburite', '19': 'Cats Eye', '20': 'Hessonite', '21': 'Carnelian', '22': 'Jade', '23': 'Variscite', '24': 'Tanzanite'}, 'readme_path': '/home/aistudio/data/readme.json', 'num_epochs': 20, 'train_batch_size': 64, 'learning_strategy': {'lr': 0.001}}
生成数据列表完成!
Batch=0
Batchs=[]
all_train_accs=[]
def draw_train_acc(Batchs, train_accs):
    title="training accs"
    plt.title(title, fontsize=24)
    plt.xlabel("batch", fontsize=14)
    plt.ylabel("acc", fontsize=14)
    plt.plot(Batchs, train_accs, color='green', label='training accs')
    plt.legend()
    plt.grid()
    plt.show()

all_train_loss=[]
def draw_train_loss(Batchs, train_loss):
    title="training loss"
    plt.title(title, fontsize=24)
    plt.xlabel("batch", fontsize=14)
    plt.ylabel("loss", fontsize=14)
    plt.plot(Batchs, train_loss, color='red', label='training loss')
    plt.legend()
    plt.grid()
    plt.show()

2.定义模型

###在以下cell中完成DNN网络的定义###

#定义网络
class MyDNN(fluid.dygraph.Layer):
    '''
    卷积神经网络
    '''
    def __init__(self):
        super(MyDNN,self).__init__()
        self.hidden1=fluid.dygraph.Linear(3*64*64,1000, act='relu')
        self.hidden2=fluid.dygraph.Linear(1000,500, act='relu')
        self.hidden3=fluid.dygraph.Linear(500,100, act='relu')
        self.out = fluid.dygraph.Linear(input_dim=100, output_dim=25, act='softmax')

    def forward(self,input):
        x = fluid.layers.reshape(input,shape=[-1,3*64*64])
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.hidden3(x)
        x = self.out(x)
        return x

3.训练模型

with fluid.dygraph.guard(place = fluid.CUDAPlace(0)):
    print(train_parameters['class_dim'])
    print(train_parameters['label_dict'])
    model=MyDNN() #模型实例化
    model.train() #训练模式
    opt=fluid.optimizer.SGDOptimizer(learning_rate=train_parameters['learning_strategy']['lr'], parameter_list=model.parameters())#优化器选用SGD随机梯度下降,学习率为0.001.
    epochs_num=train_parameters['num_epochs'] #迭代次数
    
    for pass_num in range(epochs_num):
        for batch_id,data in enumerate(train_reader()):
            images = np.array([x[0] for x in data]).astype('float32').reshape(-1, 3,64,64)
            labels = np.array([x[1] for x in data]).astype('int64')
            labels = labels[:, np.newaxis]

            image=fluid.dygraph.to_variable(images)
            label=fluid.dygraph.to_variable(labels)

            predict=model(image) #数据传入model
            
            loss=fluid.layers.cross_entropy(predict,label)
            avg_loss=fluid.layers.mean(loss)#获取loss值
            
            acc=fluid.layers.accuracy(predict,label)#计算精度
            
            if batch_id!=0 and batch_id%5==0:
                Batch = Batch+5 
                Batchs.append(Batch)
                all_train_loss.append(avg_loss.numpy()[0])
                all_train_accs.append(acc.numpy()[0])
                
                print("train_pass:{},batch_id:{},train_loss:{},train_acc:{}".format(pass_num,batch_id,avg_loss.numpy(),acc.numpy()))
            
            avg_loss.backward()       
            opt.minimize(avg_loss)    #优化器对象的minimize方法对参数进行更新 
            model.clear_gradients()   #model.clear_gradients()来重置梯度
    fluid.save_dygraph(model.state_dict(),'MyDNN')#保存模型

draw_train_acc(Batchs,all_train_accs)
draw_train_loss(Batchs,all_train_loss)

train_pass:19,batch_id:400,train_loss:[0.24890603],train_acc:[0.96875]

4.模型评估

#模型评估
with fluid.dygraph.guard():
    accs = []
    model_dict, _ = fluid.load_dygraph('MyDNN')
    model = MyDNN()
    model.load_dict(model_dict) #加载模型参数
    model.eval() #训练模式
    for batch_id,data in enumerate(eval_reader()):#测试集
        images = np.array([x[0] for x in data]).astype('float32').reshape(-1, 3,64,64)
        labels = np.array([x[1] for x in data]).astype('int64')
        labels = labels[:, np.newaxis]
        image=fluid.dygraph.to_variable(images)
        label=fluid.dygraph.to_variable(labels)       
        predict=model(image)       
        acc=fluid.layers.accuracy(predict,label)
        accs.append(acc.numpy()[0])
        avg_acc = np.mean(accs)
    print(avg_acc)

0.96875

5.模型预测

import os
import zipfile

def unzip_infer_data(src_path,target_path):
    '''
    解压预测数据集
    '''
    if(not os.path.isdir(target_path)):     
        z = zipfile.ZipFile(src_path, 'r')
        z.extractall(path=target_path)
        z.close()


def load_image(img_path):
    '''
    预测图片预处理
    '''
    img = Image.open(img_path) 
    if img.mode != 'RGB': 
        img = img.convert('RGB') 
    img = img.resize((64, 64), Image.BILINEAR)
    img = np.array(img).astype('float32') 
    img = img.transpose((2, 0, 1))  # HWC to CHW 
    img = img/255                # 像素值归一化 
    return img


infer_src_path = '/home/aistudio/data/data55032/archive_test.zip'
infer_dst_path = '/home/aistudio/data/archive_test'
unzip_infer_data(infer_src_path,infer_dst_path)
label_dic = train_parameters['label_dict']

'''
模型预测
'''
with fluid.dygraph.guard():
    model_dict, _ = fluid.load_dygraph('MyDNN')
    model = MyDNN()
    model.load_dict(model_dict) #加载模型参数
    model.eval() #训练模式
    
    #展示预测图片
    infer_path='data/archive_test/alexandrite_3.jpg'
    img = Image.open(infer_path)
    plt.imshow(img)          #根据数组绘制图像
    plt.show()               #显示图像

    #对预测图片进行预处理
    infer_imgs = []
    infer_imgs.append(load_image(infer_path))
    infer_imgs = np.array(infer_imgs)
   
    for i in range(len(infer_imgs)):
        data = infer_imgs[i]
        dy_x_data = np.array(data).astype('float32')
        dy_x_data=dy_x_data[np.newaxis,:, : ,:]
        img = fluid.dygraph.to_variable(dy_x_data)
        out = model(img)
        lab = np.argmax(out.numpy())  #argmax():返回最大数的索引

        print("第{}个样本,被预测为:{},真实标签为:{}".format(i+1,label_dic[str(lab)],infer_path.split('/')[-1].split("_")[0]))
        
print("结束")

第1个样本,被预测为:Malachite,真实标签为:alexandrite 结束

以上就是Python利用DNN实现宝石识别的详细内容,更多关于Python DNN宝石识别的资料请关注脚本之家其它相关文章!

相关文章

  • 详解python中TCP协议中的粘包问题

    详解python中TCP协议中的粘包问题

    这篇文章主要介绍了python中TCP协议中的粘包问题,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-03-03
  • Python实现Linux中的du命令

    Python实现Linux中的du命令

    这篇文章主要介绍了Python实现Linux中简单du命令,需要的朋友可以参考下
    2017-06-06
  • Python教程之基本运算符的使用(下)

    Python教程之基本运算符的使用(下)

    Python运算符通常用于对值和变量执行操作。这些是用于逻辑和算术运算的标准符号。在本文中,我们将研究运算符的优先级和关联性,感兴趣的可以了解一下
    2022-09-09
  • 详解Python中的测试工具

    详解Python中的测试工具

    本文介绍了两个Python中的测试工具: doctest和unittest,并配以简单的例子来说明这两个测试模块的使用方法,需要的朋友可以参考下
    2019-06-06
  • Python实现绘图散点图(scatter)

    Python实现绘图散点图(scatter)

    这篇文章主要介绍了Python实现绘图散点图方式(scatter),具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2024-06-06
  • kaggle数据分析家庭电力消耗过程详解

    kaggle数据分析家庭电力消耗过程详解

    这篇文章主要为大家介绍了kaggle数据分析家庭电力消耗示例详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2022-12-12
  • Python 获得13位unix时间戳的方法

    Python 获得13位unix时间戳的方法

    本篇文章主要介绍了Python 获得13位unix时间戳的方法,非常具有实用价值,需要的朋友可以参考下
    2017-10-10
  • pytorch的梯度计算以及backward方法详解

    pytorch的梯度计算以及backward方法详解

    今天小编就为大家分享一篇pytorch的梯度计算以及backward方法详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-01-01
  • Python keras.metrics源代码分析

    Python keras.metrics源代码分析

    最近在用keras写模型的时候,参考别人代码时,经常能看到各种不同的metrics,因此会产生几个问题,下面主要介绍了Python keras.metrics源代码分析
    2022-11-11
  • Python cookbook(数据结构与算法)实现对不原生支持比较操作的对象排序算法示例

    Python cookbook(数据结构与算法)实现对不原生支持比较操作的对象排序算法示例

    这篇文章主要介绍了Python cookbook(数据结构与算法)实现对不原生支持比较操作的对象排序算法,结合实例形式分析了Python针对类实例进行排序相关操作技巧,需要的朋友可以参考下
    2018-03-03

最新评论