Pytorch 实现计算分类器准确率(总分类及子分类)
更新时间:2020年01月18日 11:27:49 作者:疯狂的小猪oO
今天小编就为大家分享一篇Pytorch 实现计算分类器准确率(总分类及子分类),具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
分类器平均准确率计算:
correct = torch.zeros(1).squeeze().cuda() total = torch.zeros(1).squeeze().cuda() for i, (images, labels) in enumerate(train_loader): images = Variable(images.cuda()) labels = Variable(labels.cuda()) output = model(images) prediction = torch.argmax(output, 1) correct += (prediction == labels).sum().float() total += len(labels) acc_str = 'Accuracy: %f'%((correct/total).cpu().detach().data.numpy())
分类器各个子类准确率计算:
correct = list(0. for i in range(args.class_num)) total = list(0. for i in range(args.class_num)) for i, (images, labels) in enumerate(train_loader): images = Variable(images.cuda()) labels = Variable(labels.cuda()) output = model(images) prediction = torch.argmax(output, 1) res = prediction == labels for label_idx in range(len(labels)): label_single = label[label_idx] correct[label_single] += res[label_idx].item() total[label_single] += 1 acc_str = 'Accuracy: %f'%(sum(correct)/sum(total)) for acc_idx in range(len(train_class_correct)): try: acc = correct[acc_idx]/total[acc_idx] except: acc = 0 finally: acc_str += '\tclassID:%d\tacc:%f\t'%(acc_idx+1, acc)
以上这篇Pytorch 实现计算分类器准确率(总分类及子分类)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。
相关文章
python3+selenium实现qq邮箱登陆并发送邮件功能
这篇文章主要为大家详细介绍了python3+selenium实现qq邮箱登陆,并发送邮件功能,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下2019-01-01使用Python代码进行PowerPoint演示文稿的合并与拆分
多个PowerPoint演示文稿的处理可能会成为非常麻烦的工作,有时需要将多个演示文稿合并为一个演示文稿,从而不用在演示时重复打开演示文稿,本文我们可以使用Python代码来快速、准确的执行PowerPoint演示文稿的合并于拆分操作,需要的朋友可以参考下2024-03-03
最新评论