Выход модели Pytorch — только сохраняйте оценки выше 0,3

#python #pandas #numpy #pytorch

Вопрос:

У меня есть большой список (на самом деле только один элемент с 3 словарями), как показано ниже. Это результат предварительной подготовки модели pytorch для одного экземпляра набора тестов. В списке есть три атрибута (поля, метки, оценки), все они имеют тип tensor. Каждая коробка имеет соответствующую оценку и этикетку. Всего здесь 100 коробок. Есть ли какой-нибудь быстрый способ сохранить только те коробки, этикетки и оценки, где оценка превышает 0,3? Таким образом, в этом примере должно быть только 5 коробок с соответствующей оценкой и этикеткой.

    output = [{'boxes': tensor([[0.0000e 00, 2.9095e 01, 7.3249e 01, 1.1387e 02],
              [7.8610e 01, 1.9392e 01, 1.6580e 02, 1.0291e 02],
              [3.6086e-01, 2.9609e 01, 1.0292e 02, 2.0285e 02],
              [1.8569e 02, 2.3418e 01, 2.4397e 02, 1.4092e 02],
              [1.9678e-03, 0.0000e 00, 5.8328e 01, 1.7467e 02],
              [1.4161e 02, 1.5196e 02, 2.2797e 02, 2.3690e 02],
              [1.5630e 02, 5.4246e 01, 2.1178e 02, 1.7170e 02],
              [5.3407e 01, 6.4962e 01, 1.0892e 02, 1.8180e 02],
              [1.0011e 02, 1.5188e 02, 1.8732e 02, 2.3737e 02],
              [1.5080e 02, 3.9219e 01, 2.3776e 02, 1.2494e 02],
              [8.9806e 01, 1.3143e 02, 1.7610e 02, 2.1669e 02],
              [1.3518e 02, 1.2713e 02, 1.9257e 02, 2.4350e 02],
              [1.1423e 02, 1.4989e 01, 1.7153e 02, 1.3093e 02],
              [7.9036e 01, 1.1927e 00, 1.9153e 02, 1.7694e 02],
              [8.4356e 01, 2.3523e 01, 1.4035e 02, 1.4181e 02],
              [6.9645e 01, 1.5251e 02, 1.5582e 02, 2.3697e 02],
              [1.4163e 02, 1.2086e 02, 2.2753e 02, 2.0553e 02],
              [8.3618e 01, 1.0583e 02, 1.4110e 02, 2.2334e 02],
              [3.2450e-01, 7.1444e 01, 7.1565e 01, 1.5488e 02],
              [7.2167e 00, 4.9198e 01, 9.3541e 01, 1.3515e 02],
              [3.8690e 01, 3.7546e 01, 1.2640e 02, 1.2457e 02],
              [1.0393e 02, 8.4865e 01, 1.6160e 02, 2.0193e 02],
              [9.6637e 00, 1.2074e 02, 9.1465e 01, 2.0829e 02],
              [2.6140e 00, 8.6522e 01, 5.9267e 01, 2.0357e 02],
              [1.6260e 02, 6.0580e 01, 2.4744e 02, 1.4646e 02],
              [1.7624e 02, 7.5614e 01, 2.3260e 02, 1.9287e 02],
              [1.2096e 02, 2.8686e 01, 2.0757e 02, 1.1476e 02],
              [1.0993e 02, 1.1107e 02, 1.9594e 02, 1.9697e 02],
              [3.8821e 01, 1.6499e 00, 1.5277e 02, 1.7680e 02],
              [1.3592e 02, 1.7006e 00, 2.5177e 02, 1.7528e 02],
              [4.3270e 01, 8.5363e 01, 9.8090e 01, 2.0313e 02],
              [3.9082e 01, 1.6281e 02, 1.2582e 02, 2.4565e 02],
              [1.0941e 02, 9.5967e 00, 1.9760e 02, 9.2006e 01],
              [9.4279e 01, 5.4012e 01, 1.5129e 02, 1.7273e 02],
              [1.7610e 02, 1.1657e 02, 2.3257e 02, 2.3306e 02],
              [1.8356e 02, 1.2347e 02, 2.5438e 02, 2.0546e 02],
              [1.4145e 00, 7.8904e 01, 7.9495e 01, 2.5600e 02],
              [5.7602e 01, 9.8933e 01, 1.7596e 02, 2.5600e 02],
              [1.3184e 02, 9.0243e 01, 2.1412e 02, 1.7617e 02],
              [1.3507e 02, 4.4525e 01, 1.9094e 02, 1.6303e 02],
              [1.0465e 01, 1.5353e 02, 9.2553e 01, 2.3779e 02],
              [1.8336e 02, 1.5361e 02, 2.5428e 02, 2.3656e 02],
              [1.9591e 02, 7.6679e 01, 2.5259e 02, 1.9309e 02],
              [9.8446e 01, 7.9284e 01, 2.1062e 02, 2.5600e 02],
              [1.3960e 02, 9.7938e 00, 2.2880e 02, 9.1881e 01],
              [9.0553e 01, 0.0000e 00, 1.7578e 02, 7.1238e 01],
              [1.8702e 00, 1.2331e 02, 4.8407e 01, 2.4824e 02],
              [0.0000e 00, 4.9428e-01, 1.2664e 02, 1.4013e 02],
              [7.9054e 01, 1.5236e 00, 2.3079e 02, 1.0113e 02],
              [1.6006e 02, 6.4527e 01, 2.5467e 02, 2.5600e 02],
              [0.0000e 00, 1.7543e 02, 1.8282e 02, 2.5560e 02],
              [1.7264e 00, 1.7961e 02, 7.0644e 01, 2.5600e 02],
              [1.8063e 02, 9.7504e 00, 2.5516e 02, 9.3108e 01],
              [5.0636e 01, 1.3299e 02, 1.3221e 02, 2.1604e 02],
              [3.1850e 01, 5.4289e 01, 8.8847e 01, 1.7312e 02],
              [2.0640e 02, 3.2171e 01, 2.5485e 02, 1.5160e 02],
              [1.9062e 01, 1.2459e 00, 1.7152e 02, 1.0271e 02],
              [8.0108e 01, 1.8195e 02, 1.6510e 02, 2.5523e 02],
              [6.4087e 00, 1.3263e 00, 9.6138e 01, 7.0220e 01],
              [2.1170e 01, 2.3619e 00, 7.9999e 01, 1.0748e 02],
              [5.7921e 01, 6.5922e-01, 1.4736e 02, 8.1241e 01],
              [1.1025e 02, 1.8136e 02, 1.9421e 02, 2.5513e 02],
              [6.1567e 01, 1.7640e 02, 2.5535e 02, 2.5600e 02],
              [3.9355e 01, 3.4047e 00, 1.0319e 02, 8.7065e 01],
              [5.0878e 01, 1.0217e 02, 1.3312e 02, 1.8664e 02],
              [7.4605e 01, 5.4398e 01, 1.2882e 02, 1.7320e 02],
              [1.7292e 02, 1.8397e 02, 2.5249e 02, 2.5558e 02],
              [1.8037e 01, 9.5900e 01, 1.3505e 02, 2.5211e 02],
              [1.4013e 02, 1.9098e 02, 2.2696e 02, 2.5415e 02],
              [6.2275e 01, 8.4387e 01, 9.0523e 01, 1.4154e 02],
              [1.7307e 01, 1.9287e 02, 1.1007e 02, 2.5532e 02],
              [7.2651e 01, 6.8909e 01, 1.0101e 02, 1.2587e 02],
              [1.6461e 02, 7.4065e 01, 1.9301e 02, 1.3071e 02],
              [7.7585e 01, 5.3052e 01, 1.0652e 02, 1.0954e 02],
              [1.6948e 02, 6.2818e 01, 1.9857e 02, 1.2058e 02],
              [5.7015e 01, 1.0044e 02, 8.5367e 01, 1.5666e 02],
              [7.7270e 01, 8.4819e 01, 1.0621e 02, 1.4148e 02],
              [6.7998e-01, 5.9344e 01, 3.2433e 01, 1.0408e 02],
              [5.7399e 01, 5.8315e 01, 8.5744e 01, 1.1517e 02],
              [1.5450e 02, 5.2688e 01, 1.8301e 02, 1.1024e 02],
              [6.7396e 01, 5.3385e 01, 9.6198e 01, 1.0940e 02],
              [5.2431e 01, 1.1456e 02, 8.0534e 01, 1.7151e 02],
              [5.2424e 01, 7.2901e 01, 8.0602e 01, 1.3118e 02],
              [5.4905e 01, 7.4931e 01, 9.8882e 01, 1.1913e 02],
              [1.7986e 02, 6.2481e 01, 2.0913e 02, 1.2018e 02],
              [6.7338e 01, 1.0491e 02, 9.5506e 01, 1.6148e 02],
              [1.7451e 02, 8.3800e 01, 2.0362e 02, 1.4150e 02],
              [4.9071e 01, 4.8894e 01, 9.4031e 01, 9.3116e 01],
              [1.0840e 00, 1.1331e 01, 3.7614e 01, 1.2899e 02],
              [8.2344e 01, 6.8916e 01, 1.1157e 02, 1.2634e 02],
              [1.6138e 02, 5.9034e 01, 2.0742e 02, 1.0366e 02],
              [1.4473e 01, 2.2719e 01, 2.1446e 02, 1.3569e 02],
              [4.7128e 01, 8.9411e 01, 7.5240e 01, 1.4684e 02],
              [1.8501e 02, 1.1383e 02, 2.1348e 02, 1.7167e 02],
              [5.4657e 01, 1.2145e 02, 9.8843e 01, 1.6540e 02],
              [3.7407e 00, 2.7927e 01, 4.8678e 01, 7.2360e 01],
              [1.6647e 02, 8.0292e 01, 2.1177e 02, 1.2410e 02],
              [3.4396e 01, 6.4177e 01, 7.8244e 01, 1.0873e 02],
              [8.7888e 01, 5.2212e 01, 1.1652e 02, 1.1079e 02],
              [1.7443e 02, 1.1404e 02, 2.0353e 02, 1.7198e 02]], device='cuda:0'),
      'labels': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              1, 1, 1, 1], device='cuda:0'),
      'scores': tensor([0.3317, 0.3235, 0.3208, 0.3157, 0.3108, 0.2977, 0.2974, 0.2955, 0.2938,
              0.2904, 0.2902, 0.2836, 0.2797, 0.2794, 0.2787, 0.2784, 0.2782, 0.2746,
              0.2739, 0.2695, 0.2658, 0.2655, 0.2647, 0.2628, 0.2622, 0.2596, 0.2596,
              0.2593, 0.2591, 0.2584, 0.2574, 0.2559, 0.2550, 0.2529, 0.2526, 0.2429,
              0.2428, 0.2408, 0.2397, 0.2381, 0.2370, 0.2344, 0.2302, 0.2296, 0.2292,
              0.2260, 0.2258, 0.2252, 0.2201, 0.2166, 0.2125, 0.2063, 0.2056, 0.2054,
              0.2050, 0.2032, 0.2023, 0.2021, 0.1985, 0.1956, 0.1943, 0.1776, 0.1739,
              0.1708, 0.1700, 0.1665, 0.1657, 0.1595, 0.1588, 0.1561, 0.1553, 0.1553,
              0.1484, 0.1426, 0.1419, 0.1416, 0.1289, 0.1265, 0.1250, 0.1248, 0.1226,
              0.1219, 0.1216, 0.1208, 0.1197, 0.1186, 0.1182, 0.1164, 0.1164, 0.1157,
              0.1133, 0.1109, 0.1097, 0.1086, 0.1055, 0.1055, 0.1054, 0.1047, 0.1026,
              0.1020], device='cuda:0')}]
 

Ответ №1:

Значит, вы ищете что-то подобное?

 mask = output[0]['scores'] > 0.3
for key,val in output[0].items():
    output[0][key] = val[mask]
output[0]
 
 {'boxes': tensor([[0.0000e 00, 2.9095e 01, 7.3249e 01, 1.1387e 02],
         [7.8610e 01, 1.9392e 01, 1.6580e 02, 1.0291e 02],
         [3.6086e-01, 2.9609e 01, 1.0292e 02, 2.0285e 02],
         [1.8569e 02, 2.3418e 01, 2.4397e 02, 1.4092e 02],
         [1.9678e-03, 0.0000e 00, 5.8328e 01, 1.7467e 02]]),
 'labels': tensor([1, 1, 1, 1, 1]),
 'scores': tensor([0.3317, 0.3235, 0.3208, 0.3157, 0.3108])}
 

Ответ №2:

Вы можете просмотреть словари и применить маску к трем тензорам:

 result = []
for d in output:
  boxes, labels, scores = d['boxes'], d['labels'], d['scores']
  m = scores > .3
  result.append(dict(boxes=boxes[m], labels=labels[m], scores=scores[m]))
 

Или с помощью диктанта:

 result = []
for d in output:
  m = d['scores'] > .3
  result.append({k: v[m] for k, v in d.items()})
 

Вы получите:

 >>> result
[{'boxes': tensor([[0.0000e 00, 2.9095e 01, 7.3249e 01, 1.1387e 02],
                   [7.8610e 01, 1.9392e 01, 1.6580e 02, 1.0291e 02],
                   [3.6086e-01, 2.9609e 01, 1.0292e 02, 2.0285e 02],
                   [1.8569e 02, 2.3418e 01, 2.4397e 02, 1.4092e 02],
                   [1.9678e-03, 0.0000e 00, 5.8328e 01, 1.7467e 02]]),
  'labels': tensor([1, 1, 1, 1, 1]),
  'scores': tensor([0.3317, 0.3235, 0.3208, 0.3157, 0.3108])}]
 

Комментарии:

1. Разве не довольно расточительно вычислять маску 3 раза?

2. @user2640045 Я вычисляю маску dict не для каждого ключа, так что только один раз. (мой for цикл проходит по словарям — в приведенном примере OP есть только один диктант)

3. О, ты прав. Твоя петля » за » только по одной вещи сбила меня с толку. Кстати, зачем тебе это делать?

4. «У меня большой список (на самом деле только один элемент с 3 словарями) «, возможно, ОП имел в виду только один элемент с 3 элементами