Как выполнить выборку из нормального распределения случайных значений внутри диапазона с помощью Tensorflow?

#python #tensorflow #tensorflow-probability

#python #tensorflow #tensorflow-вероятность

Вопрос:

У меня есть две переменные mean , stddev которые являются тензорами формы (1,), и они представляют множество нормальных распределений со средним, скажем, средним [i] и stardard deviation stddev [i] .

Из этих распределений я хочу выбрать одно значение в диапазоне в [ low , up ,] для всех, а затем я хочу получить логарифмические вероятности выбранных значений.

Из документов я обнаружил, что этот experimental_sample_and_log_prob метод почти для меня, потому что он не выбирает элементы в диапазоне значений (low, up), которые я хотел бы иметь.

Итак, я закодировал несколько строк, но это работает не очень хорошо, естественно, потому что это так дорого в вычислительном отношении.

 import tensorflow as tf
from tensorflow_probability import distributions as tfd


def sample_and_log_prob(dist, up, down):
    samples = dist.sample()
    accepted = False
    print("Is {} accepted? {}".format(samples, accepted))
    while not accepted:
        # sample < up
        cond1 = tf.less_equal(samples, up)
        # sample > down
        cond2 = tf.greater_equal(samples, down)
        # if down < sample < up
        accepted = tf.logical_and(cond1, cond2) 
        samples = tf.where(
            tf.logical_not(accepted),
            samples,
            dist.sample())
        print("Is {} accepted? {}".format(samples, accepted))
    
    return samples, dist.log_prob(samples)


distribution = tfd.Normal(
    loc=mean ,
    scale=stddev,
    validate_args=True,
    allow_nan_stats=False)

samples, log_probs = sample_and_log_prob(distribution, up=-1, down=1)
 

Какие-либо советы по ее решению?

Комментарии:

1. «это работает не очень хорошо» — это не программный оператор. Что это значит? В чем конкретная проблема, с которой вы столкнулись?

2. Я имею в виду, что это очень дорого с точки зрения вычислений.

Ответ №1:

Похоже, вам нужен TruncatedNormal дистрибутив.