pytorch分类模型绘制混淆矩阵以及可视化详解

 更新时间:2022年04月07日 11:22:58   作者:王延凯的博客  
混淆矩阵是ROC曲线绘制的基础,同时它也是衡量分类型模型准确度中最基本,最直观,计算最简单的方法,下面这篇文章主要给大家介绍了关于pytorch分类模型绘制混淆矩阵以及可视化的相关资料,需要的朋友可以参考下

Step 1. 获取混淆矩阵

#首先定义一个 分类数*分类数 的空混淆矩阵
 conf_matrix = torch.zeros(Emotion_kinds, Emotion_kinds)
 # 使用torch.no_grad()可以显著降低测试用例的GPU占用
    with torch.no_grad():
        for step, (imgs, targets) in enumerate(test_loader):
            # imgs:     torch.Size([50, 3, 200, 200])   torch.FloatTensor
            # targets:  torch.Size([50, 1]),     torch.LongTensor  多了一维,所以我们要把其去掉
            targets = targets.squeeze()  # [50,1] ----->  [50]

            # 将变量转为gpu
            targets = targets.cuda()
            imgs = imgs.cuda()
            # print(step,imgs.shape,imgs.type(),targets.shape,targets.type())
            
            out = model(imgs)
            #记录混淆矩阵参数
            conf_matrix = confusion_matrix(out, targets, conf_matrix)
            conf_matrix=conf_matrix.cpu()

混淆矩阵的求取用到了confusion_matrix函数,其定义如下:

def confusion_matrix(preds, labels, conf_matrix):
    preds = torch.argmax(preds, 1)
    for p, t in zip(preds, labels):
        conf_matrix[p, t] += 1
    return conf_matrix

在当我们的程序执行结束 test_loader 后,我们可以得到本次数据的 混淆矩阵,接下来就要计算其 识别正确的个数以及混淆矩阵可视化:

conf_matrix=np.array(conf_matrix.cpu())# 将混淆矩阵从gpu转到cpu再转到np
corrects=conf_matrix.diagonal(offset=0)#抽取对角线的每种分类的识别正确个数
per_kinds=conf_matrix.sum(axis=1)#抽取每个分类数据总的测试条数

 print("混淆矩阵总元素个数:{0},测试集总个数:{1}".format(int(np.sum(conf_matrix)),test_num))
 print(conf_matrix)

 # 获取每种Emotion的识别准确率
 print("每种情感总个数:",per_kinds)
 print("每种情感预测正确的个数:",corrects)
 print("每种情感的识别准确率为:{0}".format([rate*100 for rate in corrects/per_kinds]))

执行此步的输出结果如下所示:

Step 2. 混淆矩阵可视化

对上边求得的混淆矩阵可视化

# 绘制混淆矩阵
Emotion=8#这个数值是具体的分类数,大家可以自行修改
labels = ['neutral', 'calm', 'happy', 'sad', 'angry', 'fearful', 'disgust', 'surprised']#每种类别的标签

# 显示数据
plt.imshow(conf_matrix, cmap=plt.cm.Blues)

# 在图中标注数量/概率信息
thresh = conf_matrix.max() / 2	#数值颜色阈值,如果数值超过这个,就颜色加深。
for x in range(Emotion_kinds):
    for y in range(Emotion_kinds):
        # 注意这里的matrix[y, x]不是matrix[x, y]
        info = int(conf_matrix[y, x])
        plt.text(x, y, info,
                 verticalalignment='center',
                 horizontalalignment='center',
                 color="white" if info > thresh else "black")
                 
plt.tight_layout()#保证图不重叠
plt.yticks(range(Emotion_kinds), labels)
plt.xticks(range(Emotion_kinds), labels,rotation=45)#X轴字体倾斜45°
plt.show()
plt.close()

好了,以下就是最终的可视化的混淆矩阵啦:

其它分类指标的获取

例如 F1分数、TP、TN、FP、FN、精确率、召回率 等指标, 待补充哈(因为暂时还没用到)~

总结

到此这篇关于pytorch分类模型绘制混淆矩阵以及可视化详的文章就介绍到这了,更多相关pytorch绘制混淆矩阵内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Python面向对象的程序设计详情

    Python面向对象的程序设计详情

    这篇文章主要介绍了Python面向对象的程序设计详情,面向对象的程序设计在Python中具有非常重要的地位,熟练的使用面向对象编程能够为我们的Python编程提供很多的便利之处,希望您阅读完本文后能够有所收获
    2022-01-01
  • python+JS 实现逆向 SMZDM 的登录加密

    python+JS 实现逆向 SMZDM 的登录加密

    这篇文章主要介绍了python+JS 实现逆向 SMZDM 的登录加密,文章通过利用SMZDM平台展开详细的内容介绍,需要的小伙伴可以参考一下
    2022-05-05
  • Pandas 实现分组计数且不计重复

    Pandas 实现分组计数且不计重复

    这篇文章主要介绍了Pandas 实现分组计数且不计重复的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-03-03
  • Python利用lxml库实现XML处理

    Python利用lxml库实现XML处理

    lxml库是Python中处理XML和HTML文档的强大库,提供了丰富的API以进行各种操作,本文将讨论如何使用lxml库,包括如何创建XML文档,如何使用XPath查询,以及如何解析大型XML文档,需要的可以参考下
    2023-08-08
  • Python如何爬取51cto数据并存入MySQL

    Python如何爬取51cto数据并存入MySQL

    这篇文章主要介绍了Python如何爬取51cto数据并存入MySQL,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-08-08
  • Python计算序列相似度的算法实例

    Python计算序列相似度的算法实例

    这篇文章主要介绍了Python计算序列相似度的算法实例,求两个序列转换的最少交换步骤和最小交换距离,本文提供了部分实现代码与解决思路,对开发非常有帮助,需要的朋友可以参考下
    2023-07-07
  • Python开发最牛逼的IDE——pycharm

    Python开发最牛逼的IDE——pycharm

    这篇文章给大家介绍了Python开发最牛逼的IDE——pycharm,主要是介绍python IDE pycharm的安装与使用教程,非常不错,具有一定的参考借鉴价值,需要的朋友参考下吧
    2018-08-08
  • 关于torch中tensor数据类型的转换

    关于torch中tensor数据类型的转换

    这篇文章主要介绍了关于torch中tensor数据类型的转换方式,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-11-11
  • python3下实现搜狗AI API的代码示例

    python3下实现搜狗AI API的代码示例

    这篇文章主要介绍了python3下实现搜狗AI API的代码示例,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2018-04-04
  • python库使用Fire库生成命令行参数

    python库使用Fire库生成命令行参数

    Python Fire是一个开源库,能把Python对象转换为命令行界面,Fire库是一个非常有用的工具,它可以帮助开发人员创建命令行界面,并且可以将任何Python对象转换为命令行界面,这篇文章主要介绍了python库使用Fire库生成命令行参数,需要的朋友可以参考下
    2024-02-02

最新评论