Не удается запустить пример «Многомерное прогнозирование» (Multivariate Forecasting) из учебника Pyro

#runtime-error #forecasting #pyro

Вопрос:

Я пытаюсь просто запустить пример из учебника по адресу https://pyro.ai/examples/forecast_simple.html. Программа работает, пока не падает с ошибкой «RuntimeError: torch.linalg.cholesky: For batch 4284: U(2,2) is zero, singular U.» При каждом запуске она останавливается в одном и том же месте, всегда с «batch 4284».

Подскажите, пожалуйста, как это исправить?

Я использую следующие версии: Python 3.9.1, pyro-api 0.1.2, pyro-ppl 1.7.0, torch 1.9.0, Windows 10 Pro 64-бит (20H2), VSCode 1.60.0.


INFO     step    0 loss = 7.356
INFO     step   50 loss = 1.87751
INFO     step  100 loss = 1.55338
INFO     step  150 loss = 1.40953
INFO     step  200 loss = 1.31982
INFO     step  250 loss = 1.2017
INFO     step  300 loss = 1.1389
INFO     step  350 loss = 1.10407
INFO     step  400 loss = 1.07474
INFO     step  450 loss = 1.06728
INFO     step  500 loss = 1.0285
DEBUG    crps = 0.59017
DEBUG    mae = 0.866027
DEBUG    num_samples = 100
DEBUG    rmse = 1.02721
DEBUG    seed = 1.23457e+09
DEBUG    t0 = 0
DEBUG    t1 = 2160
DEBUG    t2 = 2496
DEBUG    test_walltime = 0.411458
DEBUG    train_walltime = 28.8177
DEBUG    AutoNormal.locs.obs_corr = -1.62159
DEBUG    AutoNormal.locs.trans_corr = 2.49729
DEBUG    AutoNormal.locs.trans_loc = 0.904184
DEBUG    AutoNormal.scales.obs_corr = 0.207397
DEBUG    AutoNormal.scales.trans_corr = 0.0915508
DEBUG    AutoNormal.scales.trans_loc = 0.0111603
INFO     Training on window [168:2328], testing on window [2328:2664]
INFO     step    0 loss = 7.37245
INFO     step   50 loss = 1.87162
     :
     :
     :
DEBUG    crps = 0.62036
DEBUG    mae = 0.907584
DEBUG    num_samples = 100
DEBUG    rmse = 1.08631
DEBUG    seed = 1.23457e+09
DEBUG    t0 = 1512
DEBUG    t1 = 3672
DEBUG    t2 = 4008
DEBUG    test_walltime = 0.404958
DEBUG    train_walltime = 26.7937
DEBUG    AutoNormal.locs.obs_corr = -0.889496
DEBUG    AutoNormal.locs.trans_corr = 1.85566
DEBUG    AutoNormal.locs.trans_loc = 0.903074
DEBUG    AutoNormal.scales.obs_corr = 0.247679
DEBUG    AutoNormal.scales.trans_corr = 0.0577488
DEBUG    AutoNormal.scales.trans_loc = 0.012068
INFO     Training on window [1680:3840], testing on window [3840:4176]
INFO     step    0 loss = 7.48406
INFO     step   50 loss = 1.92277
INFO     step  100 loss = 1.58563
INFO     step  150 loss = 1.52081
INFO     step  200 loss = 1.44076
INFO     step  250 loss = 1.38033
INFO     step  300 loss = 1.29202
INFO     step  350 loss = 1.26101
INFO     step  400 loss = 1.23141
INFO     step  450 loss = 1.23901
INFO     step  500 loss = 1.21247

RuntimeError: torch.linalg.cholesky: For batch 4284: U(2,2) is zero, singular U.
 

     RuntimeError                              Traceback (most recent call last)
    ~\AppData\Local\Temp/ipykernel_16928/3907557438.py in <module>
         15 
         16     args = parser.parse_args()
    ---> 17     main(args)
    
    ~\AppData\Local\Temp/ipykernel_16928/4270697941.py in main(args)
         24     }
         25 
    ---> 26     metrics = backtest(
         27         data,
         28         covariates,
    
    c:\Users\9033113\venv\lib\site-packages\pyro\contrib\forecast\evaluate.py in backtest(data, covariates, model_fn, forecaster_fn, metrics, transform, train_window, min_train_window, test_window, min_test_window, stride, seed, num_samples, batch_size, forecaster_options)
        199         while True:
        200             try:
    --> 201                 pred = forecaster(
        202                     train_data,
        203                     test_covariates,
    
    c:\Users\9033113\venv\lib\site-packages\pyro\contrib\forecast\forecaster.py in __call__(self, data, covariates, num_samples, batch_size)
        359         :rtype: ~torch.Tensor
        360         """
    --> 361         return super().__call__(data, covariates, num_samples, batch_size)
        362 
        363     def forward(self, data, covariates, num_samples, batch_size=None):
    
    c:\Users\9033113\venv\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
       1050                 or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1051             return forward_call(*input, **kwargs)
       1052         # Do not call functions when jit is used
       1053         full_backward_hooks, non_full_backward_hooks = [], []
    
    c:\Users\9033113\venv\lib\site-packages\pyro\contrib\forecast\forecaster.py in forward(self, data, covariates, num_samples, batch_size)
        388                     stack.enter_context(poutine.replay(trace=tr.trace))
        389                 with pyro.plate("particles", num_samples, dim=dim):
    --> 390                     return self.model(data, covariates)
        391 
        392 
    
    c:\Users\9033113\venv\lib\site-packages\pyro\nn\module.py in __call__(self, *args, **kwargs)
        424     def __call__(self, *args, **kwargs):
        425         with self._pyro_context:
    --> 426             return super().__call__(*args, **kwargs)
        427 
        428     def __getattr__(self, name):
    
    c:\Users\9033113\venv\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
       1050                 or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1051             return forward_call(*input, **kwargs)
       1052         # Do not call functions when jit is used
       1053         full_backward_hooks, non_full_backward_hooks = [], []
    
    c:\Users\9033113\venv\lib\site-packages\pyro\contrib\forecast\forecaster.py in forward(self, data, covariates)
        183             self._forecast = None
        184 
    --> 185             self.model(zero_data, covariates)
        186 
        187             assert self._forecast is not None, ".predict() was not called by .model()"
    
    ~\AppData\Local\Temp/ipykernel_16928/1541431941.py in model(self, zero_data, covariates)
         76 
         77         # The final statement registers our noise model and prediction.
    ---> 78         self.predict(noise_model, prediction)
    
    c:\Users\9033113\venv\lib\site-packages\pyro\contrib\forecast\forecaster.py in predict(self, noise_dist, prediction)
        155             # PrefixConditionMessenger is handled outside of the .model() call.
        156             self._prefix_condition_data["residual"] = data - left_pred
    --> 157             noise = pyro.sample("residual", noise_dist)
        158             del self._prefix_condition_data["residual"]
        159 
    
    c:\Users\9033113\venv\lib\site-packages\pyro\primitives.py in sample(name, fn, *args, **kwargs)
        162         }
        163         # apply the stack and return its return value
    --> 164         apply_stack(msg)
        165         return msg["value"]
        166 
    
    c:\Users\9033113\venv\lib\site-packages\pyro\poutine\runtime.py in apply_stack(initial_msg)
        215             break
        216 
    --> 217     default_process_message(msg)
        218 
        219     for frame in stack[-pointer:]:
    
    c:\Users\9033113\venv\lib\site-packages\pyro\poutine\runtime.py in default_process_message(msg)
        176         return msg
        177 
    --> 178     msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
        179 
        180     # after fn has been called, update msg to prevent it from being called again.
    
    c:\Users\9033113\venv\lib\site-packages\pyro\distributions\torch_distribution.py in __call__(self, sample_shape)
         46         """
         47         return (
    ---> 48             self.rsample(sample_shape)
         49             if self.has_rsample
         50             else self.sample(sample_shape)
    
    c:\Users\9033113\venv\lib\site-packages\pyro\distributions\hmm.py in rsample(self, sample_shape)
        582         )
        583         trans = trans.expand(trans.batch_shape[:-1] + (self.duration,))
    --> 584         z = _sequential_gaussian_filter_sample(self._init, trans, sample_shape)
        585         x = self._obs.left_condition(z).rsample()
        586         return x
    
    c:\Users\9033113\venv\lib\site-packages\pyro\distributions\hmm.py in _sequential_gaussian_filter_sample(init, trans, sample_shape)
        142         joint = (x + y).event_permute(perm)
        143         tape.append(joint)
    --> 144         contracted = joint.marginalize(left=state_dim)
        145         if time > even_time:
        146             contracted = Gaussian.cat((contracted, gaussian[..., -1:]), dim=-1)
    
    c:\Users\9033113\venv\lib\site-packages\pyro\ops\gaussian.py in marginalize(self, left, right)
        242         P_ba = self.precision[..., b, a]
        243         P_bb = self.precision[..., b, b]
    --> 244         P_b = cholesky(P_bb)
        245         P_a = triangular_solve(P_ba, P_b, upper=False)
        246         P_at = P_a.transpose(-1, -2)
    
    c:\Users\9033113\venv\lib\site-packages\pyro\ops\tensor_utils.py in cholesky(x)
        398     if x.size(-1) == 1:
        399         return x.sqrt()
    --> 400     return torch.linalg.cholesky(x)
        401 
        402 
    
    RuntimeError: torch.linalg.cholesky: For batch 4284: U(2,2) is zero, singular U.

Комментарии:

1. Я использую следующие версии. Python 3.9.1 pyro-api 0.1.2