Тензорный поток отфильтровывает тензоры без нуля

#python-3.x #tensorflow #tensor

Вопрос:

У меня есть пакетные тензоры X и Y вроде этого

 X = tf.constant([[[1,-2], [2,0],  [-2,2], [4,-1]],
                 [[3,1],  [4,1],  [**0**,1], [-5,3]],
                 [[5,-4], [6,-2], [-2,1], [-2,2]]], dtype=tf.float16)
Y = tf.constant([[1], [43], [2]])
 

X в действительности имеет измерение TensorShape([512, 30, 57]) .

Я хочу отфильтровать элементы в измерении 0, которые имеют нуль в любом из первых элементов в измерении 2 (проверьте выделенный ноль выше).

 X = tf.constant([[[1,-2], [2,0],  [-2,2], [4,-1]],
                 [[5,-4], [6,-2], [-2,1], [-2,2]]], dtype=tf.float16)
Y = tf.constant([[1], [2]])
 

На данный момент у меня есть следующий код

 idx = [k for k, v in enumerate(X) if 0 not in v[:, 0]]
X_clean = [X[x, :, :] for x in idx]
X_clean = tensorflow.stack(X_clean)
Y_clean = tensorflow.stack([Y[x] for x in idx])
 

Это очень медленно, и для каждой итерации требуется около 2 секунд. Как я могу ускорить эту работу?

Ответ №1:

Вы можете достичь более эффективного решения , используя tf.where tf.reduce_all и tf.gather :

 # getting the index of the valid elements batch wise
# X[...,0]!=0 checks that the first element in the last dimension is not 0 
# reduce_all cheks that this is true for every element along dimension 1 
# where gives the index of those valid elements
valid_element_idxs = tf.squeeze(tf.where(tf.reduce_all(X[...,0]!=0,axis=-1)))
X_clean = tf.gather(X, valid_element_idxs)
Y_clean = tf.gather(Y, valid_element_idxs)
 

Сравнивая ваш подход и этот с %времени на 2 небольших тензорах, которые вы привели в качестве примера:

 >>> %timeit list_comp(X,Y)
2.82 ms ± 15.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
>>> %timeit tf_native(X,Y)
263 µs ± 2.19 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
 

Вы можете немного повысить производительность, используя tf.function :

 >>> %timeit tf_native_decorated(X,Y)
206 µs ± 6.31 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
 

Определение функций для справки:

 def list_comp(X,Y):
    idx = [k for k, v in enumerate(X) if 0 not in v[:, 0]]
    X_clean = [X[x, :, :] for x in idx]
    X_clean = tf.stack(X_clean)
    Y_clean = tf.stack([Y[x] for x in idx])
    return X_clean, Y_clean

def tf_native(X,Y):
    valid_elements_idx = tf.squeeze(tf.where(tf.reduce_all(X[...,0]!=0,axis=-1)))
    X_clean = tf.gather(X, valid_elements_idx)
    Y_clean = tf.gather(Y, valid_elements_idx)
    return X_clean, Y_clean

@tf.function
def tf_native_decorated(X,Y):
    valid_elements_idx = tf.squeeze(tf.where(tf.reduce_all(X[...,0]!=0,axis=-1)))
    X_clean = tf.gather(X, valid_elements_idx)
    Y_clean = tf.gather(Y, valid_elements_idx)
    return X_clean, Y_clean