Как я могу выбрать строки тензора nd с индексами, хранящимися в другом тензоре?

#tensorflow

#tensorflow

Вопрос:

Я пытаюсь разрезать тензор формы (?, 32, 32) вдоль первого измерения. Я должен выбрать две строки с индексами, хранящимися в другом тензоре формы (1, 2) . Я хочу что-то вроде array[list of indexes, :, :] в numpy.

Как я могу это сделать? Мне нужна эта операция для вычисления потерь внутри model_fn функции, переданной моему пользовательскому оценщику Tensorflow.

Ответ №1:

Я решил это с помощью tf.gather_nd . Я изменил форму тензора, содержащего индексы, с помощью:

ids = tf.reshape(tensor_with_indexes, shape=(-1, 1))

и затем я применил:

new_tensor = tf.gather_nd(original_tensor, ids)