Очень медленная jit-компиляция для XLA при использовании jax

#compilation #gpu #jit #xla #jax

#Сборник #графический процессор #jit #xla #jax

Вопрос:

Я использую Jax для выполнения некоторых заданий машинного обучения. Jax использует XLA для выполнения некоторой своевременной компиляции для ускорения, но сама компиляция слишком медленная для процессора. Моя ситуация такова, что процессор будет использовать только одно ядро для выполнения компиляции, что совсем неэффективно.

Я нашел несколько ответов о том, что это может быть очень быстро, если я могу использовать GPU для компиляции. Кто-нибудь может сказать мне, как использовать GPU для выполнения части компиляции? Поскольку я не делал никаких настроек для компиляции. Спасибо!

Некоторое дополнение к вопросу: я использую Jax для вычисления grad и hessian, что делает компиляцию очень медленной. Код похож:

     ## get results from model ##
    def get_model_value(images):
        return jnp.sum(model(images))

    def get_model_grad(images):
        images = jnp.expand_dims(images, axis=0)
        image_grad = jacfwd(get_model_value)(images)
        return image_grad
    
    def get_model_hessian(images):
        images = jnp.expand_dims(images, axis=0)
        image_hess = jacfwd(jacrev(get_model_value))(images)
        return image_hess
  
    # get value
    model_value = model(dis_img)
    FR_value = jnp.expand_dims(FR_value, axis=1)
    value_loss = crit_mse(model_value, FR_value)
    
    # get grad
    vmap_model_grad = jax.vmap(get_model_grad)
    model_grad = vmap_model_grad(dis_img)
    
    # get hessian
    vmap_model_hessian = vmap(get_model_hessian)
    model_hessian = vmap_model_hessian(dis_img)
  

Комментарии:

1. Часто медленные компиляции можно устранить, написав свой код способом, более подходящим для модели вычислений JAX. Можете ли вы поделиться примером кода, который компилируется слишком медленно?

2. Конечно, я только что привел пример своего кода. Не могли бы вы, пожалуйста, помочь с этим?