Вероятность тензорного потока — MCMC — проблемы с биекторами в переходном ядре?

#python #tensorflow #mcmc #tensorflow-probability #hierarchical-bayesian

#python #тензорный поток #mcmc #тензорный поток-вероятность #иерархический-байесовский

Вопрос:

Я создаю смесь моделей в tensorflow-probability. Совместное распределение для одной заданной модели:

 one_network_prior = tfd.JointDistributionNamed(
    model=dict(
        mu_g=tfb.Sigmoid(
            low=-1.0,
            high=1.0,
            validate_args=True,
            name="mu_g"
        )(
            tfd.Normal(
                loc=tf.zeros((D,)),
                scale=0.5,
                validate_args=True
            )
        ),
        epsilon=tfd.Gamma(
            concentration=0.4,
            rate=1.0,
            validate_args=True,
            name="epsilon"
        ),
        mu_s=lambda mu_g, epsilon: tfb.Sigmoid(
            low=-1.0,
            high=1.0,
            validate_args=True,
            name="mu_s"
        )(
            tfd.Normal(
                loc=tf.stack(
                    [
                        mu_g
                    ] * S
                ),
                scale=epsilon,
                validate_args=True
            )
        ),
        sigma=tfd.Gamma(
            concentration=0.3,
            rate=1.0,
            validate_args=True,
            name="sigma"
        ),
        mu_s_t=lambda mu_s, sigma: tfb.Sigmoid(
            low=-1.0,
            high=1.0,
            validate_args=True,
            name="mu_s_t"
        )(
            tfd.Normal(
                loc=tf.stack(
                    [
                        mu_s
                    ] * T
                ),
                scale=sigma,
                validate_args=True
            )
        )
    )
)
  

Затем мне нужно «перепутать» модели, но эта смесь довольно нестандартна, я делаю это вручную в пользовательской log_prob_fn логарифмической функции вероятности:

 def log_prob_fn(
    mu_g,
    epsilon,
    mu_s,
    sigma,
    mu_s_t,
    kappa,
    spatial,
    observed
):
    log_probs_per_network = []
    probs_per_network = []
    for l in range(L):
        log_probs_per_network.append(
            tf.reduce_sum(
                one_network_prior.log_prob(
                    {
                        "mu_g": mu_g[l],
                        "epsilon": epsilon[l],
                        "mu_s": mu_s[l],
                        "sigma": sigma[l],
                        "mu_s_t": mu_s_t[l]
                    }
                )
            )
        ) 

        dist = tfb.Sigmoid(
            low=-1.0,
            high=1.0,
            validate_args=True
        )(
            tfd.Normal(
                loc=tf.stack(
                    [
                        mu_s_t[l]
                    ] * N
                ),
                scale=kappa
            )
        )

        probs_per_network.append(
            tf.reduce_prod(            
                dist.prob(
                    observed
                ),
                axis=-1
            )
        )
    
    kappa_log_prob = kappa_prior.log_prob(
        kappa
    )

    mixed_probs = (
        spatial
        *
        tf.stack(
            probs_per_network,
            axis=-1
        )
    )
    margin_prob = tf.reduce_sum(
        mixed_probs,
        axis=-1
    )

    mix_log_prob = tf.reduce_sum(
        tf.math.log(
            margin_prob
        )
    )
    
    return (
        tf.reduce_sum(
            log_probs_per_network
        )
          kappa_log_prob
          mix_log_prob
    )
  

(Я знаю, что эта функция неэффективна, но я не могу легко выполнить выборку из моей предыдущей модели — пакетные формы — поэтому мне пока приходится перебирать модели)

Обратите внимание, что распределение dist создается «на лету» для каждой сети.

Тогда цель состоит в том, чтобы использовать эту модель и подогнать ее к данным. Я сгенерировал начальное состояние, используя one_network_prior , и я вручную смешал данные, чтобы получить (N, T, S, D) наблюдаемые данные, которые будут переданы в MCMC следующим образом:

 hmc_kernel = tfp.mcmc.HamiltonianMonteCarlo(
  target_log_prob_fn=lambda *params: log_prob_fn(
      *params,
      observed=observed
  ),
  step_size=0.065,
  num_leapfrog_steps=5
)

unconstraining_bijectors = [
    tfb.Sigmoid(
        low=-1.0,
        high=1.0
    ),
    tfb.Softplus(),
    tfb.Sigmoid(
        low=-1.0,
        high=1.0
    ),
    tfb.Softplus(),
    tfb.Sigmoid(
        low=-1.0,
        high=1.0
    ),
    tfb.Softplus(),
    tfb.SoftmaxCentered()
]

transformed_kernel = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=hmc_kernel,
    bijector=unconstraining_bijectors
)

adapted_kernel = tfp.mcmc.SimpleStepSizeAdaptation(
    inner_kernel=transformed_kernel,
    num_adaptation_steps=400,
    target_accept_prob=0.65
)

@tf.function
def run_chain(initial_state, num_results=1000, num_burnin_steps=500):
  return tfp.mcmc.sample_chain(
    num_results=num_results,
    num_burnin_steps=num_burnin_steps,
    current_state=initial_state,
    kernel=adapted_kernel
  )

samples, kernel_results = run_chain(
    initial_state=init_state,
    num_results=20000,  
    num_burnin_steps=5000
)
  

Но когда я запускаю run_chain функцию, после нескольких итераций я получаю ошибку:

 ---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-45-2ff348713067> in <module>
----> 1 samples, kernel_results = run_chain(
      2     initial_state=init_state,
      3     num_results=20000,
      4     num_burnin_steps=5000
      5 )

~/.pyenv/versions/3.8.0/envs/Kong2019-env/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    778       else:
    779         compiler = "nonXla"
--> 780         result = self._call(*args, **kwds)
    781 
    782       new_tracing_count = self._get_tracing_count()

~/.pyenv/versions/3.8.0/envs/Kong2019-env/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    844               *args, **kwds)
    845       # If we did not create any variables the trace we have is good enough.
--> 846       return self._concrete_stateful_fn._filtered_call(canon_args, canon_kwds)  # pylint: disable=protected-access
    847 
    848     def fn_with_cond(*inner_args, **inner_kwds):

~/.pyenv/versions/3.8.0/envs/Kong2019-env/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _filtered_call(self, args, kwargs, cancellation_manager)
   1841       `args` and `kwargs`.
   1842     """
-> 1843     return self._call_flat(
   1844         [t for t in nest.flatten((args, kwargs), expand_composites=True)
   1845          if isinstance(t, (ops.Tensor,

~/.pyenv/versions/3.8.0/envs/Kong2019-env/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1921         and executing_eagerly):
   1922       # No tape is watching; skip to running the function.
-> 1923       return self._build_call_outputs(self._inference_function.call(
   1924           ctx, args, cancellation_manager=cancellation_manager))
   1925     forward_backward = self._select_forward_and_backward_functions(

~/.pyenv/versions/3.8.0/envs/Kong2019-env/lib/python3.8/site-packages/tensorflow/python/eager/function.py in call(self, ctx, args, cancellation_manager)
    543       with _InterpolateFunctionError(self):
    544         if cancellation_manager is None:
--> 545           outputs = execute.execute(
    546               str(self.signature.name),
    547               num_outputs=self._num_outputs,

~/.pyenv/versions/3.8.0/envs/Kong2019-env/lib/python3.8/site-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     57   try:
     58     ctx.ensure_initialized()
---> 59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:

InvalidArgumentError:  assertion failed: [Argument `scale` must be positive.] [Condition x > 0 did not hold element-wise:] [x (mcmc_sample_chain/trace_scan/while/smart_for_loop/while/simple_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/strided_slice_1:0) = ] [-nan]
     [[{{node mcmc_sample_chain/trace_scan/while/body/_415/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/body/_2366/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/simple_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/body/_3200/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/simple_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/JointDistributionNamed/log_prob/Normal/assert_positive/assert_less/Assert/AssertGuard/else/_3580/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/simple_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/JointDistributionNamed/log_prob/Normal/assert_positive/assert_less/Assert/AssertGuard/Assert}}]] [Op:__inference_run_chain_169987]

Function call stack:
run_chain
  

Я понимаю, что отрицательный kappa передается на dist , но после прохождения Softplus биектора это не должно быть возможным? И при инвертировании всех моих биекторов функция все еще работала, что странно, потому что размеры должны быть нарушены из-за SoftmaxCentered .

Итак, у меня такое чувство, что мои биекторы просто игнорируются. Что я пропустил?

Заранее спасибо за помощь 🙂