PyMC3 не удается передать правильные размеры для вывода

#python #matrix #theano #bayesian #pymc3

#python #матрица #theano #байесовский #pymc3

Вопрос:

Я пытаюсь распространить идеи теории ответов на несколько ответов. Рассмотрим маркетинговый опрос, в котором клиентам задается вопрос: «Что является решающим фактором при покупке продукта X или нет?» Где ответы {0: цена, 1: долговечность, 2: простота использования}.

Вот некоторые синтетические данные (строки — это клиенты, столбцы — продукты, каждая ячейка — это ответ класса.)

 responses = np.array([ 
          [0,1,2,1,0],
          [1,1,1,1,1],
          [0,0,2,2,1],
          [1,1,2,2,1],
          [1,1,0,0,0]  
    ])

students = 5
questions = 5
categories = 3

with pm.Model() as model:
    z_student = pm.Normal("z_student", mu=0, sigma=1, shape=(students,categories))
    z_question = pm.Normal("z_question",mu=0, sigma=1, shape=(categories,questions))
    
    # Transformed parameter
    theta = pm.Deterministic("theta", tt.nnet.softmax(z_student - z_question))
     
    # Likelihood
    kij = pm.Categorical("kij", p=theta, observed=responses)
    trace = pm.sample(chains=4)

az.plot_trace(trace, var_names=["z_student", "z_question"], compact=False);
 

Этот код выдает следующую ошибку: ValueError: Input dimension mis-match. (input[0].shape[0] = 5, input[1].shape[0] = 3) .

Однако, когда я меняю строку theta на: theta = pm.Deterministic("theta", tt.nnet.softmax(z_student - z_question.transpose())) сэмплер не выходит из строя мгновенно, скорее, это неправильные выборки.

 az.summary(trace)

mean    sd  hdi_3%  hdi_97% mcse_mean   mcse_sd ess_mean    ess_sd  ess_bulk    ess_tail    r_hat
z_student[0,0]  0.150   0.893   -1.620  1.752   0.012   0.013   5789.0  2327.0  5771.0  2991.0  1.0
z_student[0,1]  0.393   0.879   -1.319  1.980   0.012   0.012   5150.0  2610.0  5153.0  3195.0  1.0
z_student[0,2]  -0.591  0.915   -2.254  1.108   0.011   0.012   6408.0  2737.0  6415.0  2830.0  1.0
z_student[1,0]  -0.064  0.860   -1.676  1.538   0.011   0.014   5748.0  1942.0  5747.0  2850.0  1.0
z_student[1,1]  0.602   0.864   -0.982  2.185   0.012   0.011   4921.0  3028.0  4920.0  3269.0  1.0
z_student[1,2]  -0.548  0.906   -2.218  1.137   0.012   0.012   6076.0  2870.0  6083.0  3410.0  1.0
z_student[2,0]  -0.166  0.907   -1.974  1.450   0.013   0.014   4681.0  2121.0  4692.0  3108.0  1.0
z_student[2,1]  -0.188  0.875   -1.776  1.472   0.011   0.014   5923.0  2073.0  5945.0  3333.0  1.0
z_student[2,2]  0.344   0.865   -1.288  1.951   0.012   0.012   4828.0  2750.0  4822.0  3039.0  1.0
z_student[3,0]  -0.212  0.892   -1.980  1.395   0.011   0.013   6019.0  2504.0  5996.0  3391.0  1.0
z_student[3,1]  0.097   0.876   -1.573  1.713   0.012   0.013   5304.0  2252.0  5332.0  2971.0  1.0
z_student[3,2]  0.096   0.851   -1.583  1.645   0.011   0.012   5554.0  2678.0  5543.0  3288.0  1.0
z_student[4,0]  0.160   0.881   -1.367  1.947   0.012   0.013   5421.0  2189.0  5413.0  2927.0  1.0
z_student[4,1]  0.414   0.863   -1.255  2.026   0.012   0.012   4900.0  2548.0  4897.0  3248.0  1.0
z_student[4,2]  -0.558  0.901   -2.266  1.130   0.011   0.012   6551.0  2728.0  6582.0  3142.0  1.0
z_question[0,0] -0.179  0.883   -1.795  1.488   0.011   0.015   6317.0  1769.0  6315.0  3389.0  1.0
z_question[0,1] 0.107   0.886   -1.511  1.807   0.012   0.013   5236.0  2431.0  5209.0  3503.0  1.0
z_question[0,2] 0.164   0.878   -1.450  1.834   0.012   0.013   5131.0  2248.0  5106.0  3102.0  1.0
z_question[0,3] 0.186   0.904   -1.450  1.882   0.011   0.014   6228.0  2175.0  6219.0  3335.0  1.0
z_question[0,4] -0.187  0.877   -1.790  1.508   0.011   0.014   5819.0  2089.0  5834.0  3198.0  1.0
z_question[1,0] -0.389  0.849   -1.948  1.219   0.012   0.012   4726.0  2494.0  4713.0  3146.0  1.0
z_question[1,1] -0.600  0.858   -2.249  0.946   0.012   0.011   5093.0  3247.0  5116.0  3312.0  1.0
z_question[1,2] 0.179   0.868   -1.520  1.763   0.012   0.012   5204.0  2514.0  5201.0  3418.0  1.0
z_question[1,3] -0.103  0.862   -1.683  1.561   0.013   0.013   4608.0  2212.0  4615.0  3163.0  1.0
z_question[1,4] -0.381  0.866   -2.047  1.147   0.011   0.012   6181.0  2735.0  6188.0  3038.0  1.0
z_question[2,0] 0.565   0.908   -1.125  2.337   0.012   0.012   6022.0  2879.0  6045.0  3173.0  1.0
z_question[2,1] 0.536   0.923   -1.192  2.241   0.012   0.013   6041.0  2476.0  6046.0  3059.0  1.0
z_question[2,2] -0.325  0.856   -1.918  1.289   0.012   0.012   5429.0  2741.0  5418.0  3004.0  1.0
z_question[2,3] -0.107  0.881   -1.953  1.363   0.012   0.012   5834.0  2545.0  5841.0  3332.0  1.0
z_question[2,4] 0.576   0.910   -1.202  2.253   0.011   0.013   6385.0  2606.0  6371.0  2905.0  1.0
theta[0,0]  0.360   0.173   0.072   0.685   0.003   0.002   4309.0  3774.0  4256.0  2846.0  1.0
theta[0,1]  0.528   0.182   0.208   0.857   0.003   0.002   4949.0  4563.0  4908.0  3050.0  1.0
theta[0,2]  0.113   0.104   0.001   0.304   0.001   0.001   6095.0  4045.0  7146.0  2780.0  1.0
theta[1,0]  0.216   0.144   0.007   0.477   0.002   0.002   6149.0  4576.0  6493.0  3116.0  1.0
theta[1,1]  0.678   0.168   0.381   0.962   0.002   0.002   5954.0  5954.0  6180.0  3320.0  1.0
theta[1,2]  0.107   0.100   0.000   0.294   0.001   0.001   6321.0  3863.0  7623.0  3252.0  1.0
theta[2,0]  0.234   0.150   0.010   0.509   0.002   0.002   6154.0  4352.0  6684.0  3252.0  1.0
theta[2,1]  0.230   0.152   0.005   0.506   0.002   0.001   6885.0  5424.0  6459.0  2923.0  1.0
theta[2,2]  0.536   0.186   0.194   0.858   0.002   0.002   5595.0  5250.0  5622.0  2805.0  1.0
theta[3,0]  0.239   0.157   0.007   0.526   0.002   0.002   5843.0  4627.0  5789.0  2853.0  1.0
theta[3,1]  0.381   0.178   0.065   0.703   0.003   0.002   4927.0  4377.0  5009.0  3315.0  1.0
theta[3,2]  0.380   0.174   0.069   0.692   0.003   0.002   4653.0  4176.0  4624.0  2562.0  1.0
theta[4,0]  0.361   0.175   0.057   0.668   0.002   0.002   5185.0  4637.0  5269.0  2985.0  1.0
theta[4,1]  0.527   0.184   0.186   0.852   0.003   0.002   4614.0  4445.0  4668.0  2497.0  1.0
theta[4,2]  0.111   0.100   0.002   0.303   0.001   0.001   6159.0  3978.0  7520.0  3473.0  1.0
 

Обратите внимание, пожалуйста, обратитесь к изученным значениям theta. Их имена включают: Theta [0,0] …Theta[0,2],…Тета [4,2]. Итак, в первом примере PyMC3 узнал о силе связи между (z_student[0] - z_question[0]) и классом / ответом 0 .

Это не тот эффект, которого я хочу достичь, я хочу изучить 3D-тензор, учитывающий все возможные пары {student, question, category}; должно быть 74 тета, а не 15, где Тета [0,0,0] относится к изученному значению {student_0, question_0, response_0} . Однако мой код в настоящее время не достигает этого эффекта.

Есть идеи?

Редактировать: Совсем недавно я создал функцию в Theano, чтобы продемонстрировать свою цель:

 responses = np.array([ 
          [0,1,2,2,2],
          [0,1,2,1,1],
          [0,1,2,0,0],
          [0,1,2,0,1],
          [0,1,2,1,0]  
    ])

students = 5
questions = 5
categories = 3

a = tensor.matrix()
b = tensor.matrix()
elem_sub = a[0,0] - b[0,0], a[0,1] - b[1,0], a[0,2] - b[2,0]  
function = theano.function([a,b], elem_sub)

with pm.Model() as model:
    z_student = pm.Normal("student_dim1", mu=0, sigma=1, shape=(students,categories))
    z_question = pm.Normal("question_dim1", mu=0, sigma=1, shape=(categories,questions))
    # Transformed parameter
    theta = pm.Deterministic("theta", tt.nnet.softmax(function(z_student,z_question)))   
    # Likelihood
    kij = pm.Categorical("kij", p=theta, observed=responses)
 

Однако возникает следующая ошибка:

 TypeError: Bad input argument with name "z_student" to theano function with name "<ipython-input-2-2a16f255dca1>:23" at index 0 (0-based).  
Backtrace when that variable is created:
.
.
.
Expected an array-like object, but found a Variable: maybe you are trying to call a function on a (possibly shared) variable instead of a numeric array?
 

Комментарии:

1. кроме того, какова связь между действием и тем, как вы ожидаете, что оно будет выглядеть? z_student - z_question это больше похоже на вопрос numpy, чем на вопрос pymc3. z_student - z_question.transpose() имеет форму (5, 3), какую форму вы бы хотели, чтобы она была?

2. @ignoring_gravity, хорошие вопросы — я хочу смоделировать одну тэту для каждого возможного {ученик, вопрос, ответ}, поскольку она должна иметь форму (5,5,3) с общим количеством переменных 75. PyMC3 в настоящее время смотрит только по диагонали z_student[0] - z_question[0] z_student[1] - z_question[1] и т.д. (Таким образом, это не учитывает производительность ученика 0 по вопросу 1 или любой другой недиагональной паре.) Есть ли в этом смысл?

3. ХОРОШО — для заданного z_student и заданного z_question можете ли вы показать, как бы вы хотели, чтобы этот 3D-тензор выглядел?

4. @ignoring_gravity, да — [0,0.25,0.75] будет представлять тэта-симплекс, соответствующий произвольной паре {студент, вопрос}. В выходных az.summary(trace) данных я хотел бы theta[0,0,0] получить эту информацию (учитывая 0-й z_student, 0-й z_question и 0-й ответ)