Операция, обратная заполнению в Jax

#pytorch #padding #jax

Вопрос:

Я пытаюсь научиться использовать Jax и наткнулся на проблему преобразования torch.nn.functionnal.pad функции в Jax. Существует функция для выполнения заполнения, но я хотел бы так же, как и в PyTorch, использовать отрицательные числа в заполнении (например F.pad(array, [-1,-1]) ).

У кого-нибудь есть идея или была такая же проблема ?

Ответ №1:

jax.lax.pad Функция принимает отрицательные индексы заполнения, хотя API немного отличается от torch.nn.functional.pad API . Например:

 from jax import lax
import jax.numpy as jnp

x = jnp.ones((2, 3))
y = lax.pad(x, padding_config=[(0, 0, 0), (1, 1, 0)], padding_value=0.0)
print(y)
# [[0. 1. 1. 1. 0.]
#  [0. 1. 1. 1. 0.]]

x = lax.pad(y, padding_config=[(0, 0, 0), (-1, -1, 0)], padding_value=0.0)
print(x)
# [[1. 1. 1.]
#  [1. 1. 1.]]
 

Если вы хотите, вы можете обернуть это функцией, которая имеет семантику, аналогичную версии torch. Вот небольшая попытка:

 def jax_pad(input, pad, mode='constant', value=0):
  """JAX implementation of torch.nn.functional.pad

  Warning: this has not been thoroughly tested!
  """
  if mode != 'constant':
    raise NotImplementedError("Only mode='constant' is implemented")
  assert len(pad) % 2 == 0
  assert len(pad) // 2 <= input.ndim
  pad = list(zip(*[iter(pad)]*2))
  pad  = [(0, 0)] * (input.ndim - len(pad))
  return lax.pad(
      input,
      padding_config=[(i, j, 0) for i, j in pad[::-1]],
      padding_value=jnp.array(value, input.dtype))

x = jnp.ones((2, 3))
y = jax_pad(x, (1, 1))
print(y)
# [[0. 1. 1. 1. 0.]
#  [0. 1. 1. 1. 0.]]

x = jax_pad(y, (-1, -1))
print(x)
# [[1. 1. 1.]
#  [1. 1. 1.]]