使用Pytorch训练分类问题时,分类准确率的计算方式
更新时间:2023年09月14日 14:24:58 作者:jayus丶
这篇文章主要介绍了使用Pytorch训练分类问题时,分类准确率的计算方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
Pytorch训练分类问题时,分类准确率的计算
作者记录方便查询
使用条件
真实标签与预测标签都是tensor。
使用方法
#标签情况 print(y) tensor([[1, 1, 0, 0]]) print(pred) tensor([[1, 0, 1, 0]]) # 比较真实与预测 print(y==pred) tensor([[ True, False, False, True]]) # 对正确元素求和,sum会自动计算True的个数 print((y==pred).sum()) tensor(2)
因此在每个epoch开始时,只需要初始化一个计数器accuracy,对每次的正确元素进行累加,在除以训练元素的总数,便获得了每个epoch的准确率。
for epoch in range(epochs): accuracy=0 for i, (x,y) in enumerate(train_loader, 1): pred = net(x) loss = loss_function(pred.to(torch.float32),y.to(torch.float32)) optimizer.zero_grad() loss.backward() #反向传播 optimizer.step() #更新梯度 loss_steps[epoch]=loss.item()#保存loss running_loss = loss.item() accuracy += (pred == y).sum() acc = float(accuracy*100)/float(len(train_ids))# 除以元素总数,可以用其他方式获取 print(f"第{epoch}次训练,loss={running_loss:.4f},Accuracy={acc:.3f}".format(epoch,running_loss,acc))
结果
Pytorch 计算分类器准确率(总分类及子分类)
分类器平均准确率计算
correct = torch.zeros(1).squeeze().cuda() total = torch.zeros(1).squeeze().cuda() for i, (images, labels) in enumerate(train_loader): images = Variable(images.cuda()) labels = Variable(labels.cuda()) output = model(images) prediction = torch.argmax(output, 1) correct += (prediction == labels).sum().float() total += len(labels) acc_str = 'Accuracy: %f'%((correct/total).cpu().detach().data.numpy())
分类器各个子类准确率计算
correct = list(0. for i in range(args.class_num)) total = list(0. for i in range(args.class_num)) for i, (images, labels) in enumerate(train_loader): images = Variable(images.cuda()) labels = Variable(labels.cuda()) output = model(images) prediction = torch.argmax(output, 1) res = prediction == labels for label_idx in range(len(labels)): label_single = label[label_idx] correct[label_single] += res[label_idx].item() total[label_single] += 1 acc_str = 'Accuracy: %f'%(sum(correct)/sum(total)) for acc_idx in range(len(train_class_correct)): try: acc = correct[acc_idx]/total[acc_idx] except: acc = 0 finally: acc_str += '\tclassID:%d\tacc:%f\t'%(acc_idx+1, acc)
总结
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。
最新评论