#pytorch #padding #jax
Вопрос:
Я пытаюсь научиться использовать Jax и наткнулся на проблему преобразования torch.nn.functionnal.pad
функции в Jax. Существует функция для выполнения заполнения, но я хотел бы так же, как и в PyTorch, использовать отрицательные числа в заполнении (например F.pad(array, [-1,-1])
).
У кого-нибудь есть идея или была такая же проблема ?
Ответ №1:
jax.lax.pad
Функция принимает отрицательные индексы заполнения, хотя API немного отличается от torch.nn.functional.pad
API . Например:
from jax import lax
import jax.numpy as jnp
x = jnp.ones((2, 3))
y = lax.pad(x, padding_config=[(0, 0, 0), (1, 1, 0)], padding_value=0.0)
print(y)
# [[0. 1. 1. 1. 0.]
# [0. 1. 1. 1. 0.]]
x = lax.pad(y, padding_config=[(0, 0, 0), (-1, -1, 0)], padding_value=0.0)
print(x)
# [[1. 1. 1.]
# [1. 1. 1.]]
Если вы хотите, вы можете обернуть это функцией, которая имеет семантику, аналогичную версии torch. Вот небольшая попытка:
def jax_pad(input, pad, mode='constant', value=0):
"""JAX implementation of torch.nn.functional.pad
Warning: this has not been thoroughly tested!
"""
if mode != 'constant':
raise NotImplementedError("Only mode='constant' is implemented")
assert len(pad) % 2 == 0
assert len(pad) // 2 <= input.ndim
pad = list(zip(*[iter(pad)]*2))
pad = [(0, 0)] * (input.ndim - len(pad))
return lax.pad(
input,
padding_config=[(i, j, 0) for i, j in pad[::-1]],
padding_value=jnp.array(value, input.dtype))
x = jnp.ones((2, 3))
y = jax_pad(x, (1, 1))
print(y)
# [[0. 1. 1. 1. 0.]
# [0. 1. 1. 1. 0.]]
x = jax_pad(y, (-1, -1))
print(x)
# [[1. 1. 1.]
# [1. 1. 1.]]