#python #numpy #numpy-ndarray
#python #numpy #numpy-ndarray
Вопрос:
У меня есть numpy ndarray, как показано ниже. Я хочу фильтровать строки, столбцы, где четвертая координата не равна 1. то есть, где ndarary[0][0][-1] != 1
>>> print(ndarray)
array([[[0, 0, 0, 1],
[0, 0, 0, 1],
[0, 0, 0, 1],
...,
[0, 0, 0, 1],
[0, 0, 0, 1],
[0, 0, 0, 1]],
[[0, 0, 0, 1],
[0, 0, 0, 1],
[0, 0, 0, 1],
...,
[0, 0, 0, 1],
[0, 0, 0, 1],
[0, 0, 0, 1]],
[[0, 0, 0, 1],
[0, 0, 0, 1],
[0, 0, 0, 1],
...,
[0, 0, 0, 1],
[0, 0, 0, 1],
[0, 0, 0, 1]],
...,
[[0, 0, 0, 1],
[0, 0, 0, 1],
[0, 0, 0, 1],
...,
[0, 0, 0, 1],
[0, 0, 0, 1],
[0, 0, 0, 1]],
[[0, 0, 0, 1],
[0, 0, 0, 1],
[0, 0, 0, 1],
...,
[0, 0, 0, 1],
[0, 0, 0, 1],
[0, 0, 0, 1]],
[[0, 0, 0, 1],
[0, 0, 0, 1],
[0, 0, 0, 1],
...,
[0, 0, 0, 1],
[0, 0, 0, 1],
[0, 0, 0, 1]]], dtype=uint8)
Код, который я пробовал и который работал:
row_cols = []
for ir, row in enumerate(ndarray):
for ic, col in enumerate(row):
if col[-1] != 1:
row_cols.append((ir,ic))
Но это решение O (N ^ 2) и занимает много времени, поскольку ndarray имеет форму (800,1280 * 4), и я должен выполнить это для нескольких тысяч массивов.
Есть ли лучший способ фильтрации?
Ответ №1:
Операционная функция numpy slicing np.where
поможет вам:
np.random.seed(2020)
array = np.random.randint(0, 3, 36).reshape([4, 3, 3])
где array
:
array([[[0, 0, 2],
[1, 0, 1],
[0, 0, 0]],
[[2, 1, 2],
[2, 2, 1],
[0, 0, 0]],
[[0, 2, 0],
[1, 1, 1],
[2, 1, 2]],
[[1, 1, 2],
[2, 2, 2],
[1, 0, 2]]])
Результаты вашего кода:
[(0, 0), (0, 2), (1, 0), (1, 2), (2, 0), (2, 2), (3, 0), (3, 1), (3, 2)]
Используя нарезку и np.where
:
simple_array = array[..., -1]
ir, ic = np.where(simple_array != 1)
Это ir
:
array([0, 0, 1, 1, 2, 2, 3, 3, 3], dtype=int64)
Это ic
:
array([0, 2, 0, 2, 0, 2, 0, 1, 2], dtype=int64)
Производительность:
import numpy as np
from time import time
array = np.random.randint(0, 4, 800 * 1280 * 4).reshape([800, 1280, 4])
start = time()
row_cols = []
for ir, row in enumerate(array):
for ic, col in enumerate(row):
if col[-1] != 1:
row_cols.append((ir, ic))
print(time() - start) # 0.6560261249542236
start = time()
simple_array = array[..., -1]
ir, ic = np.where(simple_array != 1)
print(time() - start) # 0.02800583839416504