python numpy фильтрует ndarray на основе значения элемента

#python #numpy #numpy-ndarray

#python #numpy #numpy-ndarray

Вопрос:

У меня есть numpy ndarray, как показано ниже. Я хочу фильтровать строки, столбцы, где четвертая координата не равна 1. то есть, где ndarary[0][0][-1] != 1

 >>> print(ndarray)
array([[[0, 0, 0, 1],
        [0, 0, 0, 1],
        [0, 0, 0, 1],
        ...,
        [0, 0, 0, 1],
        [0, 0, 0, 1],
        [0, 0, 0, 1]],

       [[0, 0, 0, 1],
        [0, 0, 0, 1],
        [0, 0, 0, 1],
        ...,
        [0, 0, 0, 1],
        [0, 0, 0, 1],
        [0, 0, 0, 1]],

       [[0, 0, 0, 1],
        [0, 0, 0, 1],
        [0, 0, 0, 1],
        ...,
        [0, 0, 0, 1],
        [0, 0, 0, 1],
        [0, 0, 0, 1]],

       ...,

       [[0, 0, 0, 1],
        [0, 0, 0, 1],
        [0, 0, 0, 1],
        ...,
        [0, 0, 0, 1],
        [0, 0, 0, 1],
        [0, 0, 0, 1]],

       [[0, 0, 0, 1],
        [0, 0, 0, 1],
        [0, 0, 0, 1],
        ...,
        [0, 0, 0, 1],
        [0, 0, 0, 1],
        [0, 0, 0, 1]],

       [[0, 0, 0, 1],
        [0, 0, 0, 1],
        [0, 0, 0, 1],
        ...,
        [0, 0, 0, 1],
        [0, 0, 0, 1],
        [0, 0, 0, 1]]], dtype=uint8)
  

Код, который я пробовал и который работал:

 row_cols = []

for ir, row in enumerate(ndarray):
  for ic, col in enumerate(row):
    if col[-1] != 1:
      row_cols.append((ir,ic))
  

Но это решение O (N ^ 2) и занимает много времени, поскольку ndarray имеет форму (800,1280 * 4), и я должен выполнить это для нескольких тысяч массивов.

Есть ли лучший способ фильтрации?

Ответ №1:

Операционная функция numpy slicing np.where поможет вам:

 np.random.seed(2020)
array = np.random.randint(0, 3, 36).reshape([4, 3, 3])
  

где array :

 array([[[0, 0, 2],
    [1, 0, 1],
    [0, 0, 0]],
   [[2, 1, 2],
    [2, 2, 1],
    [0, 0, 0]],
   [[0, 2, 0],
    [1, 1, 1],
    [2, 1, 2]],
   [[1, 1, 2],
    [2, 2, 2],
    [1, 0, 2]]])
  

Результаты вашего кода:

 [(0, 0), (0, 2), (1, 0), (1, 2), (2, 0), (2, 2), (3, 0), (3, 1), (3, 2)]
  

Используя нарезку и np.where :

 simple_array = array[..., -1]
ir, ic = np.where(simple_array != 1)
  

Это ir :

 array([0, 0, 1, 1, 2, 2, 3, 3, 3], dtype=int64)
  

Это ic :

 array([0, 2, 0, 2, 0, 2, 0, 1, 2], dtype=int64)
  

Производительность:

 import numpy as np
from time import time

array = np.random.randint(0, 4, 800 * 1280 * 4).reshape([800, 1280, 4])
start = time()
row_cols = []

for ir, row in enumerate(array):
    for ic, col in enumerate(row):
        if col[-1] != 1:
            row_cols.append((ir, ic))
print(time() - start)  # 0.6560261249542236
start = time()
simple_array = array[..., -1]
ir, ic = np.where(simple_array != 1)
print(time() - start)  # 0.02800583839416504