Реализация мягкого argmax в tensorflow

#python-3.x #tensorflow #tensorflow2.0 #gradient

#python-3.x #тензорный поток #тензорный поток 2,0 #градиент

Вопрос:

Я пытаюсь реализовать мягкую операцию argmax в Tensorflow. Я хочу, чтобы при прямом проходе был нормальный argmax, а при обратном-приближение softmax. Чтобы быть точным, мой ввод в формате NCHW, где C-2 канала. Проблема, с которой я сейчас сталкиваюсь, заключается в уменьшении размерности из-за нарезки. Мой градиент и вывод имеют одинаковую форму (выход 1 канала), однако они отличаются от ввода, который, насколько я понимаю, не разрешен Tensorflow.

tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 409600 values, but the requested shape has 819200

Моя нынешняя лучшая попытка:

 @tf.custom_gradient def soft_argmax(x):  out_no_grad = tf.argmax(x, axis=1)   @tf.function  def argmax_soft(x):  out = tf.nn.softmax(x, axis=1)  return   def grad(dy):  gradient = tf.gradients(argmax_soft(x), x)[0][:,0,:,:]  return gradient * dy return out_no_grad, grad  

Как я мог бы разработать такую функцию? Обратите внимание, что я специально ищу что-то, что определяет пользовательский градиент или функцию. Я знаю , что мог бы это сделать grad tf.stop_gradient(out_no_grad - grad) , но это не то, чего я хочу.