使用python如何将数据集划分为训练集、验证集和测试集

 更新时间:2023年09月09日 09:15:34   作者:肖申克的陪伴  
这篇文章主要介绍了使用python如何将数据集划分为训练集、验证集和测试集问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

python将数据集划分为训练集、验证集和测试集

划分数据集

众所周知,将一个数据集只区分为训练集和验证集是不行的,还需要有测试集,本博文针对上一篇没有分出测试集的不足,重新划分数据集

直接上代码:

#split_data.py
#划分数据集flower_data,数据集划分到flower_datas中,训练集:验证集:测试集比例为6:2:2
import os
import random
from shutil import copy2
# 源文件路径
file_path = r"D:/other/ClassicalModel/other/flower_data"
# 新文件路径
new_file_path = r"D:/other/ClassicalModel/other/flower_datas"
# 划分数据比例为6:2:2
split_rate = [0.6, 0.2, 0.2]
print("Starting...")
print("Ratio= {}:{}:{}".format(int(split_rate[0] * 10), int(split_rate[1] * 10), int(split_rate[2] * 10)))
class_names = os.listdir(file_path)
# 在目标目录下创建文件夹
split_names = ['train', 'val', 'test']
# 判断是否存在木匾文件夹
if os.path.isdir(new_file_path):
    pass
else:
    os.mkdir(new_file_path)
for split_name in split_names:
    # split_path = os.path.join(new_file_path, split_name)
    split_path = new_file_path + "/" + split_name
    if os.path.isdir(split_path):
        pass
    else:
        os.mkdir(split_path)
    # 然后在split_path的目录下创建类别文件夹
    for class_name in class_names:
        class_split_path = os.path.join(split_path, class_name)
        if os.path.isdir(class_split_path):
            pass
        else:
            os.mkdir(class_split_path)
# 按照比例划分数据集,并进行数据图片的复制
# 首先进行分类遍历
for class_name in class_names:
    current_class_data_path = os.path.join(file_path, class_name)
    current_all_data = os.listdir(current_class_data_path)
    current_data_length = len(current_all_data)
    current_data_index_list = list(range(current_data_length))
    random.shuffle(current_data_index_list)
    train_path = os.path.join(os.path.join(new_file_path, 'train'), class_name)
    val_path = os.path.join(os.path.join(new_file_path, 'val'), class_name)
    test_path = os.path.join(os.path.join(new_file_path, 'test'), class_name)
    train_stop_flag = current_data_length * split_rate[0]
    val_stop_flag = current_data_length * (split_rate[0] + split_rate[1])
    current_idx = 0
    train_num = 0
    val_num = 0
    test_num = 0
    for i in current_data_index_list:
        src_img_path = os.path.join(current_class_data_path, current_all_data[i])
        if current_idx <= train_stop_flag:
            copy2(src_img_path, train_path
            train_num = train_num + 1
        elif (current_idx > train_stop_flag) and (current_idx <= val_stop_flag):
            copy2(src_img_path, val_path)
            val_num = val_num + 1
        else:
            copy2(src_img_path, test_path
            test_num = test_num + 1
        current_idx = current_idx + 1
    print("<{}> has {} pictures,train:val:test={}:{}:{}".format(class_name, current_data_length, train_num, val_num,
                                                              test_num))
print("Done")

输出结果:

注意:

只需要修改file_path(源文件夹)和new_file_path(新生成的文件夹)

其次是修改split_rate

python自动划分训练集和测试集

在进行深度学习的模型训练时,我们通常需要将数据进行划分,划分成训练集和测试集,若数据集太大,数据划分花费的时间太多!!!

不多说,上代码(python代码)

代码

# *_*coding: utf-8 *_*
import os
import random
import shutil
import time
def copyFile(fileDir,origion_path1,class_name):
    name = class_name
    path = origion_path1
    image_list = os.listdir(fileDir) # 获取图片的原始路径
    image_number = len(image_list)
    train_number = int(image_number * train_rate)
    train_sample = random.sample(image_list, train_number) # 从image_list中随机获取0.75比例的图像.
    test_sample = list(set(image_list) - set(train_sample))
    sample = [train_sample, test_sample]
    # 复制图像到目标文件夹
    for k in range(len(save_dir)):
            if os.path.isdir(save_dir[k]) and os.path.isdir(save_dir1[k]):
                for name in sample[k]:
                    name1 = name.split(".")[0] + '.xml'
                    shutil.copy(os.path.join(fileDir, name), os.path.join(save_dir[k], name))
                    shutil.copy(os.path.join(path, name1), os.path.join(save_dir1[k], name1))
            else:
                os.makedirs(save_dir[k])
                os.makedirs(save_dir1[k])
                for name in sample[k]:
                    name1 = name.split(".")[0] + '.xml'
                    shutil.copy(os.path.join(fileDir, name), os.path.join(save_dir[k], name))
                    shutil.copy(os.path.join(path, name1), os.path.join(save_dir1[k], name1))
if __name__ == '__main__':
    time_start = time.time()
    # 原始数据集路径
    origion_path = './JPEGImages/'
    origion_path1 = './Annotations/'
    # 保存路径
    save_train_dir = './train/JPEGImages/'
    save_test_dir = './test/JPEGImages/'
    save_train_dir1 = './train/Annotations/'
    save_test_dir1 = './test/Annotations/'
    save_dir = [save_train_dir, save_test_dir]
    save_dir1 = [save_train_dir1, save_test_dir1]
    # 训练集比例
    train_rate = 0.75
    # 数据集类别及数量
    file_list = os.listdir(origion_path)
    num_classes = len(file_list)
    for i in range(num_classes):
        class_name = file_list[i]
    copyFile(origion_path,origion_path1,class_name)
    print('划分完毕!')
    time_end = time.time()
    print('---------------')
    print('训练集和测试集划分共耗时%s!' % (time_end - time_start))

1.需要修改的地方

  • origion_path:图片路径
  • origion_path1:xml文件路径
  • train_rate:训练集比例

2.执行文件deal.py后生成

  • train-img:训练集图片数据
  • train-xml:训练集xml数据
  • test-img:测试集图片数据
  • test-xml:测试及xml数据

3.train_rate可以根据实际情况进行调整,一般train:test是3:1

注:每次划分数据都是随机的,每次执行时将之前划分好的数据保存或者重命名,不然会重复写入到4个文件夹中

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python输入二维数组方法

    Python输入二维数组方法

    下面小编就为大家分享一篇Python输入二维数组方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-04-04
  • Python复制Word内容并使用格式设字体与大小实例代码

    Python复制Word内容并使用格式设字体与大小实例代码

    这篇文章主要介绍了Python复制Word内容并使用格式设字体与大小实例代码,小编觉得还是挺不错的,具有一定借鉴价值,需要的朋友可以参考下
    2018-01-01
  • python连接PostgreSQL数据库的过程详解

    python连接PostgreSQL数据库的过程详解

    这篇文章主要介绍了python连接PostgreSQL数据库的过程详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-09-09
  • Python3如何在服务器打印资产信息

    Python3如何在服务器打印资产信息

    这篇文章主要介绍了Python3如何在服务器打印资产信息,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-08-08
  • Python有参函数使用代码实例

    Python有参函数使用代码实例

    这篇文章主要介绍了Python有参函数使用代码实例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-01-01
  • Python中的左斜杠、右斜杠(正斜杠和反斜杠)

    Python中的左斜杠、右斜杠(正斜杠和反斜杠)

    这篇文章主要介绍了Python中的左斜杠、右斜杠(正斜杠和反斜杠)的相关资料,非常不错,具有参考借鉴价值,需要的朋友可以参考下
    2016-08-08
  • python中的np.round()函数示例详解

    python中的np.round()函数示例详解

    np.round()是NumPy库中的一个函数,用于对数组或单个数值进行四舍五入,该函数返回一个与输入类型相同的数组或数值,并可以通过可选的参数来指定保留的小数位数,这篇文章主要介绍了python中的np.round()函数,需要的朋友可以参考下
    2024-06-06
  • Python程序实现向MySQL存放图片

    Python程序实现向MySQL存放图片

    这篇文章主要介绍了Python程序实现向MySQL存放图片,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-03-03
  • python处理大数字的方法

    python处理大数字的方法

    这篇文章主要介绍了python处理大数字的方法,涉及Python递归操作的相关技巧,需要的朋友可以参考下
    2015-05-05
  • python里面单双下划线的区别详解

    python里面单双下划线的区别详解

    本文主要介绍了python里面单双下划线的区别详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-04-04

最新评论