Как заменить определенные значения на основе условия с помощью tf.where()

#python #tensorflow

#python #тензорный поток

Вопрос:

Я хотел бы заменить значения в соответствии с условием.
Версия NumPy будет выглядеть следующим образом

 intensity=np.where(
  np.abs(intensity)<1e-4,
  1e-4,
  intensity)
  

Но TensorFlow использует немного другое значение для tf.where()
Когда я попробовал это

 intensity=tf.where(
  tf.math.abs(intensity)<1e-4,
  1e-4,
  intensity)
  

Я получил эту ошибку

 ValueError: Shapes must be equal rank, but are 0 and 4 for 'Select' (op: 'Select') with input shapes: [?,512,512,1], [], [?,512,512,1].
  

Означает ли это, что я должен использовать 4-мерный тензор для 1e-4 ?

Ответ №1:

Следующий код передал ошибку

 # Create an array which has small value (1e-4),  
# whose shape is (2,512,512,1)
small_val=np.full((2,512,512,1),1e-4).astype("float32")

# Convert numpy array to tf.constant
small_val=tf.constant(small_val)

# Use tf.where()
intensity=tf.where(
  tf.math.abs(intensity)<1e-4,
  small_val,
  intensity)

# Error doesn't occur
print(intensity.shape)
# (2, 512, 512, 1)
  

Комментарии:

1. Вы также можете сделать это напрямую small_val = tf.full(tf.shape(intensity), tf.constant(1e-4, dtype=intensity.dtype)) или даже просто small_val = 1e-4 tf.zeros_(intensity) .