pytorch单元测试的实现示例

 更新时间:2024年04月18日 10:57:40   作者:Hi20240217  
单元测试是一种软件测试方法,本文主要介绍了pytorch单元测试的实现示例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧

希望测试pytorch各种算子、block、网络等在不同硬件平台,不同软件版本下的计算误差、耗时、内存占用等指标.

本文基于torch.testing._internal

一.公共模块[common.py]

import torch
from torch import nn
import math
import torch.nn.functional as F
import time
import os
import socket
import sys
from datetime import datetime
import numpy as np
import collections
import math
import json
import copy
import traceback
import subprocess
import unittest
import torch
import inspect
from torch.testing._internal.common_utils import TestCase, run_tests,parametrize,instantiate_parametrized_tests
from torch.testing._internal.common_distributed import MultiProcessTestCase
import torch.distributed as dist

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
os.environ["RANDOM_SEED"] = "0" 

device="cpu"
device_type="cpu"
device_name="cpu"

try:
    if torch.cuda.is_available():     
        device_name=torch.cuda.get_device_name().replace(" ","")
        device="cuda:0"
        device_type="cuda"
        ccl_backend='nccl'
except:
    pass

host_name=socket.gethostname()    
sdk_version=os.getenv("SDK_VERSION","")   						 #从环境变量中获取sdk版本号
metric_data_root=os.getenv("TORCH_UT_METRICS_DATA","./ut_data")  #日志存放的目录
device_count=torch.cuda.device_count()

if not os.path.exists(metric_data_root):
    os.makedirs(metric_data_root)

def device_warmup(device):
    '''设备warmup,确保设备已经正常工作,排除设备初始化的耗时'''
    left = torch.rand([128,512], dtype = torch.float16).to(device)
    right = torch.rand([512,128], dtype = torch.float16).to(device)
    out=torch.matmul(left,right)
    torch.cuda.synchronize()

torch.manual_seed(1) 
np.random.seed(1)

def loop_decorator(loops,rank=0):
    '''循环装饰器,用于统计函数的执行时间,内存占用等'''
    def decorator(func):
        def wrapper(*args,**kwargs):
            latency=[]
            memory_allocated_t0=torch.cuda.memory_allocated(rank)
            for _ in range(loops):
                input_copy=[x.clone() for x in args]
                beg= datetime.now().timestamp() * 1e6
                pred= func(*input_copy)
                gt=kwargs["golden"]
                torch.cuda.synchronize()
                end=datetime.now().timestamp() * 1e6
                mse = torch.mean(torch.pow(pred.cpu().float()- gt.cpu().float(), 2)).item()
                latency.append(end-beg)
            memory_allocated_t1=torch.cuda.memory_allocated(rank)
            avg_latency=np.mean(latency[len(latency)//2:]).round(3)
            first_latency=latency[0]
            return { "first_latency":first_latency,"avg_latency":avg_latency,
                      "memory_allocated":memory_allocated_t1-memory_allocated_t0,
                      "mse":mse}
        return wrapper
    return decorator

class TorchUtMetrics:
    '''用于统计测试结果,比较之前的最小值'''
    def __init__(self,ut_name,thresold=0.2,rank=0):
        self.ut_name=f"{ut_name}_{rank}"
        self.thresold=thresold
        self.rank=rank
        self.data={"ut_name":self.ut_name,"metrics":[]}
        self.metrics_path=os.path.join(metric_data_root,f"{self.ut_name}_{self.rank}.jon")
        try:
            with open(self.metrics_path,"r") as f:
                self.data=json.loads(f.read())
        except:
            pass

    def __enter__(self):
        self.beg= datetime.now().timestamp() * 1e6
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):        
        self.report()
        self.save_data()

    def save_data(self):
        with open(self.metrics_path,"w") as f:
            f.write(json.dumps(self.data,indent=4))

    def set_metrics(self,metrics):
        self.end=datetime.now().timestamp() * 1e6
        item=collections.OrderedDict()
        item["time"]=datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
        item["sdk_version"]=sdk_version
        item["device_name"]=device_name
        item["host_name"]=host_name
        item["metrics"]=metrics
        item["metrics"]["e2e_time"]=self.end-self.beg
        self.cur_item=item
        self.data["metrics"].append(self.cur_item)

    def get_metric_names(self):
        return self.data["metrics"][0]["metrics"].keys()

    def get_min_metric(self,metric_name,devicename=None):
        min_value=0
        min_value_index=-1
        for idx,item in enumerate(self.data["metrics"]):
            if devicename and (devicename!=item['device_name']):                
                continue            
            val=float(item["metrics"][metric_name])
            if min_value_index==-1 or val<min_value:
                min_value=val
                min_value_index=idx
        return min_value,min_value_index

    def get_metric_info(self,index):
        metrics=self.data["metrics"][index]
        return f'{metrics["device_name"]}@{metrics["sdk_version"]}'

    def report(self):
        assert len(self.data["metrics"])>0
        for metric_name in self.get_metric_names():
            min_value,min_value_index=self.get_min_metric(metric_name)
            min_value_same_dev,min_value_index_same_dev=self.get_min_metric(metric_name,device_name)
            cur_value=float(self.cur_item["metrics"][metric_name])
            print(f"-------------------------------{metric_name}-------------------------------")
            print(f"{cur_value}#{device_name}@{sdk_version}")
            if min_value_index_same_dev>=0:
                print(f"{min_value_same_dev}#{self.get_metric_info(min_value_index_same_dev)}")
            if min_value_index>=0:
                print(f"{min_value}#{self.get_metric_info(min_value_index)}")

二.普通算子测试[test_clone.py]

from common import *
class TestCaseClone(TestCase):
    #如果不满足条件,则跳过这个测试
    @unittest.skipIf(device_count>1, "Not enough devices") 
    def test_todo(self):
        print(".TODO")

    #框架会自动遍历以下参数组合
    @parametrize("shape", [(10240,20480),(128,256)])
    @parametrize("dtype", [torch.float16,torch.float32])
    def test_clone(self,shape,dtype):
        
        #让这个函数循环执行loops次,统计第一次执行的耗时、后半段的平均时间、整个执行过程总的GPU内存使用量
        @loop_decorator(loops=5)
        def run(input_dev):
            output=input_dev.clone()
            return output
        
        #记录整个测试的总耗时,保存统计量,输出摘要(self._testMethodName:测试方法,result:函数返回值,metrics:统计量)
        with TorchUtMetrics(ut_name=self._testMethodName,thresold=0.2) as m:
            input_host=torch.ones(shape,dtype=dtype)*np.random.rand()
            input_dev=input_host.to(device)
            metrics=run(input_dev,golden=input_host.cpu())
            m.set_metrics(metrics)
            assert(metrics["mse"]==0)
        
instantiate_parametrized_tests(TestCaseClone)

if __name__ == "__main__":
    run_tests()

三.集合通信测试[test_ccl.py]

from common import *
class TestCCL(MultiProcessTestCase):
    '''CCL测试用例'''
    def _create_process_group_vccl(self, world_size, store):
        dist.init_process_group(
            ccl_backend, world_size=world_size, rank=self.rank, store=store
        )        
        pg = dist.distributed_c10d._get_default_group()
        return pg

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @property
    def world_size(self):
        return 4
      
    #框架会自动遍历以下参数组合
    @unittest.skipIf(device_count<4, "Not enough devices") 
    @parametrize("op",[dist.ReduceOp.SUM])
    @parametrize("shape", [(1024,8192)])
    @parametrize("dtype", [torch.int64])
    def test_allreduce(self,op,shape,dtype):
        if self.rank >= self.world_size:
            return
        
        store = dist.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_vccl(self.world_size, store)
        if not torch.distributed.is_initialized():
            return
    
        torch.cuda.set_device(self.rank)
        device = torch.device(device_type,self.rank)
        device_warmup(device)
        #让这个函数循环执行loops次,统计第一次执行的耗时、后半段的平均时间、整个执行过程总的GPU内存使用量
        @loop_decorator(loops=5,rank=self.rank)
        def run(input_dev):
            dist.all_reduce(input_dev, op=op)
            return input_dev
        
        #记录整个测试的总耗时,保存统计量,输出摘要(self._testMethodName:测试方法,result:函数返回值,metrics:统计量)
        with TorchUtMetrics(ut_name=self._testMethodName,thresold=0.2,rank=self.rank) as m:
            input_host=torch.ones(shape,dtype=dtype)*(100+self.rank)
            gt=[torch.ones(shape,dtype=dtype)*(100+i) for i in range(self.world_size)]
            gt_=gt[0]
            for i in range(1,self.world_size):
                gt_=gt_+gt[i]
            input_dev=input_host.to(device)
            metrics=run(input_dev,golden=gt_)
            m.set_metrics(metrics)
            assert(metrics["mse"]==0)
        dist.destroy_process_group(pg)
    
instantiate_parametrized_tests(TestCCL)

if __name__ == "__main__":
    run_tests()

四.测试命令

# 运行所有的测试
pytest -v -s -p no:warnings --html=torch_report.html --self-contained-html --capture=sys ./

# 运行某一个测试
python3 test_clone.py -k "test_clone_shape_(128, 256)_float32"

五.测试报告

在这里插入图片描述

到此这篇关于pytorch单元测试的实现示例的文章就介绍到这了,更多相关pytorch单元测试内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家! 

相关文章

  • python 模拟银行转账功能过程详解

    python 模拟银行转账功能过程详解

    这篇文章主要介绍了python 模拟银行转账功能过程详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-08-08
  • python广度优先搜索得到两点间最短路径

    python广度优先搜索得到两点间最短路径

    这篇文章主要为大家详细介绍了python广度优先搜索得到两点间最短路径,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2019-01-01
  • python如何导入依赖包

    python如何导入依赖包

    在本篇文章里小编给大家整理的是关于python导入依赖包的方法,需要的朋友们学习下。
    2020-07-07
  • Python求两个list的差集、交集与并集的方法

    Python求两个list的差集、交集与并集的方法

    这篇文章主要介绍了Python求两个list的差集、交集与并集的方法,是Python集合数组操作中常用的技巧,需要的朋友可以参考下
    2014-11-11
  • Python+OpenCV之图像梯度详解

    Python+OpenCV之图像梯度详解

    这篇文章主要为大家详细介绍了Python OpenCV中图像梯度(Sobel算子、Scharr算子和Laplacian算子)的实现,感兴趣的小伙伴可以了解一下
    2022-09-09
  • python 中使用yagmail 发送邮件功能

    python 中使用yagmail 发送邮件功能

    这篇文章主要介绍了python 中使用yagmail 发送邮件功能,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2021-12-12
  • python二进制读写及特殊码同步实现详解

    python二进制读写及特殊码同步实现详解

    这篇文章主要介绍了python二进制读写及特殊码同步实现详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-10-10
  • python strip()函数 介绍

    python strip()函数 介绍

    Python strip() 方法用于移除字符串头尾指定的字符,需要的朋友可以参考下
    2013-05-05
  • 如何解决pycharm中用matplotlib画图不显示中文的问题

    如何解决pycharm中用matplotlib画图不显示中文的问题

    这篇文章主要介绍了如何解决pycharm中用matplotlib画图不显示中文的问题,文章围绕主题展开详细的内容介绍,具有一定的参考价值,感兴趣的小伙伴可以参考一下
    2022-06-06
  • python中list常用操作实例详解

    python中list常用操作实例详解

    这篇文章主要介绍了python中list常用操作,以实例形式较为详细的分析了列表list中常用的建立、添加、删除、搜索、过滤等操作技巧,需要的朋友可以参考下
    2015-06-06

最新评论