#r #tidymodels #c5.0
Вопрос:
Я мог бы (и должен был) сделать reprex попроще, но этот пример взят прямо из моей работы. Как «увидеть» правила, сгенерированные моделью, после обучения модели C5.0 во фреймворке tidymodels? (В коде ниже на самом деле обучается модель Cubist через `cubist_rules()`, а не C5.0.) Я попытался воспроизвести то, что показано здесь:
https://www.tidyverse.org/blog/2020/05/rules-0-0-1/
но далеко не продвинулся (хотя уверен, что решение должно укладываться в одну строку).
Большое спасибо!
library(tidymodels)
#> ── Attaching packages ────────────────────────────────────── tidymodels 0.1.2 ──
#> ✔ broom 0.7.2 ✔ recipes 0.1.15
#> ✔ dials 0.0.9 ✔ rsample 0.0.8
#> ✔ dplyr 1.0.2 ✔ tibble 3.0.4
#> ✔ ggplot2 3.3.2 ✔ tidyr 1.1.2
#> ✔ infer 0.5.3 ✔ tune 0.1.2.9000
#> ✔ modeldata 0.1.0 ✔ workflows 0.2.1
#> ✔ parsnip 0.1.4.9000 ✔ yardstick 0.0.7
#> ✔ purrr 0.3.4
#> ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
#> ✖ purrr::discard() masks scales::discard()
#> ✖ dplyr::filter() masks stats::filter()
#> ✖ dplyr::lag() masks stats::lag()
#> ✖ recipes::step() masks stats::step()
library(rules)
#>
#> Attaching package: 'rules'
#> The following object is masked from 'package:dials':
#>
#> max_rules
# Example data (dput() output): 17 annual observations for a single country
# ("AT"), stored as a tibble. `year` runs 2002, 2004-2019 (2003 is absent).
# `berd` is the outcome modelled below; the other numeric columns are
# candidate predictors.
# Each lagged_* column holds the previous row's value of the matching column,
# with the first row repeating its own value; because 2003 is missing, the
# 2004 row's "lag" is the 2002 value.
# NOTE(review): the last three values of each turnover_* column are identical,
# which may indicate carried-forward values — confirm with the data source.
df_ini <- structure(list(year = c(2002, 2004, 2005, 2006, 2007, 2008, 2009,
2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019),
berd = c(3130.884, 3556.479, 4207.669, 4448.676, 4845.861,
5232.63, 5092.902, 5520.422, 5692.841, 6540.457, 6778.42,
7324.679, 7498.488, 7824.51, 7888.444, 8461.72, 8865.96),
gbaord = c(1537.89, 1537.89, 1619.74, 1697.55, 1770.144,
1986.775, 2149.787, 2269.986, 2428.143, 2452.955, 2587.586,
2647.489, 2744.844, 2875.706, 2889.779, 2913.369, 3081.087
), employment_be = c(2775.22, 2731.08, 2709.59, 2708.39,
2774.06, 2795.6, 2703.36, 2668.1, 2705.1, 2731.67, 2727.16,
2725.66, 2735.69, 2750.52, 2782.9, 2852.33, 2890.9), employment_c = c(2562.53,
2518.57, 2496.98, 2499.54, 2558.88, 2578, 2483.97, 2447.65,
2483.1, 2507.41, 2500.94, 2499.6, 2511.75, 2523.97, 2555.48,
2622.5, 2656.89), employment_j = c(400.75, 387.53, 384.64,
389.29, 385.77, 392.86, 383.91, 392.18, 410.85, 419.75, 427.59,
438.96, 440.33, 460.84, 473.4, 494.4, 513.62), employment_k = c(502.42,
504.63, 515.39, 523.45, 536.6, 550.14, 546.68, 539.96, 536.58,
534.98, 524.13, 518.89, 511.57, 505.32, 496.41, 495.4, 495.98
), employment_mn = c(1248.01, 1365.29, 1425.81, 1537.88,
1622.95, 1727.76, 1704.65, 1762.55, 1838.16, 1896.09, 1929.09,
1950.02, 1968.83, 2021.51, 2109.71, 2189.27, 2225), employment_oq = c(3241.36,
3317.09, 3401.65, 3476.63, 3508.01, 3577.75, 3683.85, 3759.23,
3798.35, 3850.17, 3877.24, 3924.06, 4002.74, 4095.59, 4171.72,
4238.87, 4284.27), employment_total = c(15113.52, 15307.28,
15491.61, 15762.92, 16050.92, 16356.53, 16269.97, 16392.87,
16647.79, 16820.66, 16879.06, 17039.6, 17142.13, 17365.32,
17650.21, 17951.61, 18156.52), value_be = c(47967.1, 50737.6,
52955.2, 56872.4, 60864.9, 61029, 56837.8, 58433.6, 61443,
63655.1, 64132.3, 65542.6, 67495.4, 71152.6, 72698.8, 75792.3,
77284), value_c = c(40192.9, 42014.6, 44229, 47735.5, 51552.4,
51165.9, 47129.7, 48759.3, 51467.7, 53234.6, 53431.4, 55169,
57458.7, 60962.8, 62196, 65063.5, 66063.6), value_j = c(7737.1,
7756.1, 8134.2, 8378.8, 8532.3, 8740, 8493.9, 8518.9, 9217.1,
9405.1, 9802.1, 10361.4, 10695.4, 11455.3, 11720.6, 12871,
13540.3), value_k = c(10225.2, 10541.9, 11005.3, 11912.3,
13102.7, 13205.2, 12123.9, 12113.2, 12952.8, 12254.9, 12796.6,
12962.4, 13482.9, 13236.4, 13744.1, 14152.6, 14739.1), value_mn = c(15074,
16569.1, 18008.7, 19576.6, 21317, 23189.8, 22490, 23255.2,
24895.4, 25988.7, 26998.2, 28027.3, 29207.9, 30737.7, 32259.6,
33781.9, 35152.9), value_oq = c(35065.6, 37329.6, 38288.8,
40003.1, 41511.4, 43761.3, 45817.8, 46996.6, 47980.9, 49381.5,
50261.7, 51624.3, 53715, 55926.4, 57637.1, 59648.9, 61680.1
), value_total = c(202353.5, 216098.3, 225888.1, 239076,
253604.6, 262414.7, 256671, 263633.5, 276404, 283548.2, 288624.3,
297230.1, 307037.7, 318952.7, 329396.1, 344338.6, 355359.1
), gdp_b1gq = c(226735.3, 242348.3, 254075, 267824.4, 283978,
293761.9, 288044.1, 295896.6, 310128.6, 318653.1, 323910.2,
333146.1, 344269.3, 357608, 369341.3, 385361.9, 397575.4),
gdp_p3 = c(164107.8, 176316.4, 185871.1, 194102, 200944.4,
208857.1, 213630.1, 218947.2, 227250.8, 233638.1, 238329.3,
243860.6, 249404.3, 257166.5, 265900.2, 274583.7, 282863.3
), gdp_p61 = c(74691.6, 83074.9, 90010.4, 100076.8, 110157.2,
113368.1, 91435.3, 111997.3, 123526.3, 125801.2, 123657.1,
126109.3, 129183.6, 131524, 140057.8, 150278.2, 152545.2),
gdp_p62 = c(28063.4, 30507, 33520.2, 36089.5, 39104, 43056.8,
38781.9, 39685.8, 43784.1, 46187.6, 49444.7, 51746, 53585.8,
55885.5, 59584.7, 64333.5, 68409.7), turnover_manu_dom = c(80,
87, 93.2, 99.9, 104.9, 113.6, 97, 100, 110, 112.1, 110.7,
107.1, 104.7, 102.9, 107.9, 107.9, 107.9), turnover_manu_non_dom = c(70.9,
81.9, 86.2, 95.3, 102.8, 106.5, 86.8, 100, 112.8, 111.7,
112.8, 114.9, 118.2, 120.1, 129.2, 129.2, 129.2), turnover_manu_tot = c(74.7,
84, 89.1, 97.2, 103.7, 109.4, 91.1, 100, 111.6, 111.9, 111.9,
111.7, 112.6, 112.9, 120.3, 120.3, 120.3), price_index = c(1.7,
2, 2.1, 1.7, 2.2, 3.2, 0.4, 1.7, 3.6, 2.6, 2.1, 1.5, 0.8,
1, 2.2, 2.1, 1.5), capital_n1132g = c(3638.4, 3633.3, 3616.2,
3450.7, 3596.8, 3867.2, 3372.5, 3722.9, 3808.5, 4005.6, 3718.6,
3467.9, 4214.2, 4237.4, 4450.2, 4598.6, 4721), capital_n117g = c(8369.6,
8679.9, 8938.9, 9823.8, 10467.1, 11047.1, 11554.3, 11849.9,
13465.4, 13927.5, 15510.2, 15754.4, 16584.7, 17647.1, 18273.8,
19642.4, 20713.1), capital_n11mg = c(18749.6, 19433.5, 20051.6,
20569.8, 22646.1, 23674.5, 21200.6, 20919.6, 23157.7, 23520.7,
24057.7, 23832.8, 25019.2, 27608.2, 29790.1, 30998, 32856.8
), lagged_gbaord = c(1537.89, 1537.89, 1537.89, 1619.74,
1697.55, 1770.144, 1986.775, 2149.787, 2269.986, 2428.143,
2452.955, 2587.586, 2647.489, 2744.844, 2875.706, 2889.779,
2913.369), lagged_employment_be = c(2775.22, 2775.22, 2731.08,
2709.59, 2708.39, 2774.06, 2795.6, 2703.36, 2668.1, 2705.1,
2731.67, 2727.16, 2725.66, 2735.69, 2750.52, 2782.9, 2852.33
), lagged_employment_c = c(2562.53, 2562.53, 2518.57, 2496.98,
2499.54, 2558.88, 2578, 2483.97, 2447.65, 2483.1, 2507.41,
2500.94, 2499.6, 2511.75, 2523.97, 2555.48, 2622.5), lagged_employment_j = c(400.75,
400.75, 387.53, 384.64, 389.29, 385.77, 392.86, 383.91, 392.18,
410.85, 419.75, 427.59, 438.96, 440.33, 460.84, 473.4, 494.4
), lagged_employment_k = c(502.42, 502.42, 504.63, 515.39,
523.45, 536.6, 550.14, 546.68, 539.96, 536.58, 534.98, 524.13,
518.89, 511.57, 505.32, 496.41, 495.4), lagged_employment_mn = c(1248.01,
1248.01, 1365.29, 1425.81, 1537.88, 1622.95, 1727.76, 1704.65,
1762.55, 1838.16, 1896.09, 1929.09, 1950.02, 1968.83, 2021.51,
2109.71, 2189.27), lagged_employment_oq = c(3241.36, 3241.36,
3317.09, 3401.65, 3476.63, 3508.01, 3577.75, 3683.85, 3759.23,
3798.35, 3850.17, 3877.24, 3924.06, 4002.74, 4095.59, 4171.72,
4238.87), lagged_employment_total = c(15113.52, 15113.52,
15307.28, 15491.61, 15762.92, 16050.92, 16356.53, 16269.97,
16392.87, 16647.79, 16820.66, 16879.06, 17039.6, 17142.13,
17365.32, 17650.21, 17951.61), lagged_value_be = c(47967.1,
47967.1, 50737.6, 52955.2, 56872.4, 60864.9, 61029, 56837.8,
58433.6, 61443, 63655.1, 64132.3, 65542.6, 67495.4, 71152.6,
72698.8, 75792.3), lagged_value_c = c(40192.9, 40192.9, 42014.6,
44229, 47735.5, 51552.4, 51165.9, 47129.7, 48759.3, 51467.7,
53234.6, 53431.4, 55169, 57458.7, 60962.8, 62196, 65063.5
), lagged_value_j = c(7737.1, 7737.1, 7756.1, 8134.2, 8378.8,
8532.3, 8740, 8493.9, 8518.9, 9217.1, 9405.1, 9802.1, 10361.4,
10695.4, 11455.3, 11720.6, 12871), lagged_value_k = c(10225.2,
10225.2, 10541.9, 11005.3, 11912.3, 13102.7, 13205.2, 12123.9,
12113.2, 12952.8, 12254.9, 12796.6, 12962.4, 13482.9, 13236.4,
13744.1, 14152.6), lagged_value_mn = c(15074, 15074, 16569.1,
18008.7, 19576.6, 21317, 23189.8, 22490, 23255.2, 24895.4,
25988.7, 26998.2, 28027.3, 29207.9, 30737.7, 32259.6, 33781.9
), lagged_value_oq = c(35065.6, 35065.6, 37329.6, 38288.8,
40003.1, 41511.4, 43761.3, 45817.8, 46996.6, 47980.9, 49381.5,
50261.7, 51624.3, 53715, 55926.4, 57637.1, 59648.9), lagged_value_total = c(202353.5,
202353.5, 216098.3, 225888.1, 239076, 253604.6, 262414.7,
256671, 263633.5, 276404, 283548.2, 288624.3, 297230.1, 307037.7,
318952.7, 329396.1, 344338.6), lagged_gdp_b1gq = c(226735.3,
226735.3, 242348.3, 254075, 267824.4, 283978, 293761.9, 288044.1,
295896.6, 310128.6, 318653.1, 323910.2, 333146.1, 344269.3,
357608, 369341.3, 385361.9), lagged_gdp_p3 = c(164107.8,
164107.8, 176316.4, 185871.1, 194102, 200944.4, 208857.1,
213630.1, 218947.2, 227250.8, 233638.1, 238329.3, 243860.6,
249404.3, 257166.5, 265900.2, 274583.7), lagged_gdp_p61 = c(74691.6,
74691.6, 83074.9, 90010.4, 100076.8, 110157.2, 113368.1,
91435.3, 111997.3, 123526.3, 125801.2, 123657.1, 126109.3,
129183.6, 131524, 140057.8, 150278.2), lagged_gdp_p62 = c(28063.4,
28063.4, 30507, 33520.2, 36089.5, 39104, 43056.8, 38781.9,
39685.8, 43784.1, 46187.6, 49444.7, 51746, 53585.8, 55885.5,
59584.7, 64333.5), lagged_turnover_manu_dom = c(80, 80, 87,
93.2, 99.9, 104.9, 113.6, 97, 100, 110, 112.1, 110.7, 107.1,
104.7, 102.9, 107.9, 107.9), lagged_turnover_manu_non_dom = c(70.9,
70.9, 81.9, 86.2, 95.3, 102.8, 106.5, 86.8, 100, 112.8, 111.7,
112.8, 114.9, 118.2, 120.1, 129.2, 129.2), lagged_turnover_manu_tot = c(74.7,
74.7, 84, 89.1, 97.2, 103.7, 109.4, 91.1, 100, 111.6, 111.9,
111.9, 111.7, 112.6, 112.9, 120.3, 120.3), lagged_price_index = c(1.7,
1.7, 2, 2.1, 1.7, 2.2, 3.2, 0.4, 1.7, 3.6, 2.6, 2.1, 1.5,
0.8, 1, 2.2, 2.1), lagged_capital_n1132g = c(3638.4, 3638.4,
3633.3, 3616.2, 3450.7, 3596.8, 3867.2, 3372.5, 3722.9, 3808.5,
4005.6, 3718.6, 3467.9, 4214.2, 4237.4, 4450.2, 4598.6),
lagged_capital_n117g = c(8369.6, 8369.6, 8679.9, 8938.9,
9823.8, 10467.1, 11047.1, 11554.3, 11849.9, 13465.4, 13927.5,
15510.2, 15754.4, 16584.7, 17647.1, 18273.8, 19642.4), lagged_capital_n11mg = c(18749.6,
18749.6, 19433.5, 20051.6, 20569.8, 22646.1, 23674.5, 21200.6,
20919.6, 23157.7, 23520.7, 24057.7, 23832.8, 25019.2, 27608.2,
29790.1, 30998), country = c("AT", "AT", "AT", "AT", "AT",
"AT", "AT", "AT", "AT", "AT", "AT", "AT", "AT", "AT", "AT",
"AT", "AT")), row.names = c(NA, -17L), class = c("tbl_df",
"tbl", "data.frame"))
# Hold-out split by row position: the first `nn - time_back` rows form the
# training set and the last `time_back` rows form the test set (the most
# recent years, since df_ini is sorted by year). The assessment indices start
# at `nn - time_back + 1` so the two sets are contiguous and do not overlap.
set.seed(1234)
nn <- nrow(df_ini)
time_back <- 1
# Both the analysis and the assessment set must be non-empty.
stopifnot(time_back >= 1, time_back < nn)
indices <- list(
  analysis = seq_len(nn - time_back),
  assessment = (nn - time_back + 1):nn
)
df_split <- make_splits(indices, df_ini)
df_train <- training(df_split)
df_test <- testing(df_split)
# 3-fold cross-validation on the training rows, used for tuning below.
# NOTE(review): vfold_cv() assigns rows to folds at random, so with
# time-ordered data later years can help predict earlier ones during tuning;
# rsample::rolling_origin() may be more appropriate — confirm intent.
folded_data <- vfold_cv(df_train, v = 3)
# Preprocessing: model `berd` from all other columns, treat `year` as an
# identifier instead of a predictor, and drop zero-variance predictors
# (such as `country`, which holds a single value).
# (An alternative step once considered: step_string2factor(one_of("country")).)
cubist_recipe <- recipe(berd ~ ., data = df_train)
cubist_recipe <- update_role(cubist_recipe, year, new_role = "ID")
cubist_recipe <- step_zv(cubist_recipe, all_predictors())

# Cubist rule-based regression with both `committees` and `neighbors` left
# as tuning placeholders.
cubist_spec <- set_engine(
  cubist_rules(committees = tune(), neighbors = tune()),
  "Cubist"
)

# Bundle preprocessing and model into one workflow.
cubist_workflow <- add_model(add_recipe(workflow(), cubist_recipe), cubist_spec)

# Candidate grid: committees 1-9 and 10-50 in steps of 10, crossed with
# 0, 3, 6 or 9 neighbours.
committee_values <- c(seq_len(9), seq(10, 50, by = 10))
neighbor_values <- seq(0, 9, by = 3)
cubist_grid <- tidyr::crossing(
  committees = committee_values,
  neighbors = neighbor_values
)

# Evaluate every grid point on the cross-validation folds.
cubist_tune <- tune_grid(
  cubist_workflow,
  resamples = folded_data,
  grid = cubist_grid
)
#>
#> Attaching package: 'rlang'
#> The following objects are masked from 'package:purrr':
#>
#> %@%, as_function, flatten, flatten_chr, flatten_dbl, flatten_int,
#> flatten_lgl, flatten_raw, invoke, list_along, modify, prepend,
#> splice
#>
#> Attaching package: 'vctrs'
#> The following object is masked from 'package:tibble':
#>
#> data_frame
#> The following object is masked from 'package:dplyr':
#>
#> data_frame
#> Loading required package: lattice
# Pick the hyperparameter combination with the lowest cross-validated RMSE
# and substitute it for the tune() placeholders in the workflow.
# `metric` is passed by name: newer versions of tune expect it to be named
# rather than given positionally.
best_cub <- select_best(cubist_tune, metric = "rmse")
final_cub <- finalize_workflow(
  cubist_workflow,
  best_cub
)
# Print the finalized (not yet trained) workflow to check the chosen values.
final_cub
#> ══ Workflow ════════════════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: cubist_rules()
#>
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> 1 Recipe Step
#>
#> ● step_zv()
#>
#> ── Model ───────────────────────────────────────────────────────────────────────
#> Cubist Model Specification (regression)
#>
#> Main Arguments:
#> committees = 1
#> neighbors = 3
#>
#> Computational engine: Cubist
# Train the finalized workflow on the whole training set and display it.
fit_model <- fit(final_cub, data = df_train)
print(fit_model)
#> ══ Workflow [trained] ══════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: cubist_rules()
#>
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> 1 Recipe Step
#>
#> ● step_zv()
#>
#> ── Model ───────────────────────────────────────────────────────────────────────
#>
#> Call:
#> cubist.default(x = x, y = y, committees = 1)
#>
#> Number of samples: 16
#> Number of predictors: 52
#>
#> Number of committees: 1
#> Number of rules: 1
### At this point, how do I see the rules of the model trained on the data?
Создано 2020-12-10 пакетом reprex (версия 0.3.0)
Ответ №1:
Признаю, что решение, которое tidymodels сейчас предлагает для извлечения правил, не идеально. На мой взгляд, лучший способ получить правила модели на данный момент — извлечь базовый объект fit, который находится на несколько уровней глубже внутри рабочего процесса (workflow), и вызвать для него summary(). То есть нужно выполнить: summary(fit_model$fit$fit$fit).
library(tidymodels)
library(rules)
#>
#> Attaching package: 'rules'
#> The following object is masked from 'package:dials':
#>
#> max_rules
# Example data (dput() output): 17 annual observations for a single country
# ("AT"), stored as a tibble. `year` runs 2002, 2004-2019 (2003 is absent).
# `berd` is the outcome modelled below; the other numeric columns are
# candidate predictors.
# Each lagged_* column holds the previous row's value of the matching column,
# with the first row repeating its own value; because 2003 is missing, the
# 2004 row's "lag" is the 2002 value.
# NOTE(review): the last three values of each turnover_* column are identical,
# which may indicate carried-forward values — confirm with the data source.
df_ini <- structure(list(year = c(2002, 2004, 2005, 2006, 2007, 2008, 2009,
2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019),
berd = c(3130.884, 3556.479, 4207.669, 4448.676, 4845.861,
5232.63, 5092.902, 5520.422, 5692.841, 6540.457, 6778.42,
7324.679, 7498.488, 7824.51, 7888.444, 8461.72, 8865.96),
gbaord = c(1537.89, 1537.89, 1619.74, 1697.55, 1770.144,
1986.775, 2149.787, 2269.986, 2428.143, 2452.955, 2587.586,
2647.489, 2744.844, 2875.706, 2889.779, 2913.369, 3081.087
), employment_be = c(2775.22, 2731.08, 2709.59, 2708.39,
2774.06, 2795.6, 2703.36, 2668.1, 2705.1, 2731.67, 2727.16,
2725.66, 2735.69, 2750.52, 2782.9, 2852.33, 2890.9),
employment_c = c(2562.53,
2518.57, 2496.98, 2499.54, 2558.88, 2578, 2483.97, 2447.65,
2483.1, 2507.41, 2500.94, 2499.6, 2511.75, 2523.97, 2555.48,
2622.5, 2656.89),
employment_j = c(400.75, 387.53, 384.64,
389.29, 385.77, 392.86, 383.91, 392.18, 410.85, 419.75, 427.59,
438.96, 440.33, 460.84, 473.4, 494.4, 513.62),
employment_k = c(502.42,
504.63, 515.39, 523.45, 536.6, 550.14, 546.68, 539.96, 536.58,
534.98, 524.13, 518.89, 511.57, 505.32, 496.41, 495.4, 495.98
),
employment_mn = c(1248.01, 1365.29, 1425.81, 1537.88,
1622.95, 1727.76, 1704.65, 1762.55, 1838.16, 1896.09, 1929.09,
1950.02, 1968.83, 2021.51, 2109.71, 2189.27, 2225),
employment_oq = c(3241.36,
3317.09, 3401.65, 3476.63, 3508.01, 3577.75, 3683.85, 3759.23,
3798.35, 3850.17, 3877.24, 3924.06, 4002.74, 4095.59, 4171.72,
4238.87, 4284.27),
employment_total = c(15113.52, 15307.28,
15491.61, 15762.92, 16050.92, 16356.53, 16269.97, 16392.87,
16647.79, 16820.66, 16879.06, 17039.6, 17142.13, 17365.32,
17650.21, 17951.61, 18156.52),
value_be = c(47967.1, 50737.6,
52955.2, 56872.4, 60864.9, 61029, 56837.8, 58433.6, 61443,
63655.1, 64132.3, 65542.6, 67495.4, 71152.6, 72698.8, 75792.3,
77284),
value_c = c(40192.9, 42014.6, 44229, 47735.5, 51552.4,
51165.9, 47129.7, 48759.3, 51467.7, 53234.6, 53431.4, 55169,
57458.7, 60962.8, 62196, 65063.5, 66063.6),
value_j = c(7737.1,
7756.1, 8134.2, 8378.8, 8532.3, 8740, 8493.9, 8518.9, 9217.1,
9405.1, 9802.1, 10361.4, 10695.4, 11455.3, 11720.6, 12871,
13540.3),
value_k = c(10225.2, 10541.9, 11005.3, 11912.3,
13102.7, 13205.2, 12123.9, 12113.2, 12952.8, 12254.9, 12796.6,
12962.4, 13482.9, 13236.4, 13744.1, 14152.6, 14739.1),
value_mn = c(15074,
16569.1, 18008.7, 19576.6, 21317, 23189.8, 22490, 23255.2,
24895.4, 25988.7, 26998.2, 28027.3, 29207.9, 30737.7, 32259.6,
33781.9, 35152.9),
value_oq = c(35065.6, 37329.6, 38288.8,
40003.1, 41511.4, 43761.3, 45817.8, 46996.6, 47980.9, 49381.5,
50261.7, 51624.3, 53715, 55926.4, 57637.1, 59648.9, 61680.1
),
value_total = c(202353.5, 216098.3, 225888.1, 239076,
253604.6, 262414.7, 256671, 263633.5, 276404, 283548.2, 288624.3,
297230.1, 307037.7, 318952.7, 329396.1, 344338.6, 355359.1
),
gdp_b1gq = c(226735.3, 242348.3, 254075, 267824.4, 283978,
293761.9, 288044.1, 295896.6, 310128.6, 318653.1, 323910.2,
333146.1, 344269.3, 357608, 369341.3, 385361.9, 397575.4),
gdp_p3 = c(164107.8, 176316.4, 185871.1, 194102, 200944.4,
208857.1, 213630.1, 218947.2, 227250.8, 233638.1, 238329.3,
243860.6, 249404.3, 257166.5, 265900.2, 274583.7, 282863.3
), gdp_p61 = c(74691.6, 83074.9, 90010.4, 100076.8, 110157.2,
113368.1, 91435.3, 111997.3, 123526.3, 125801.2, 123657.1,
126109.3, 129183.6, 131524, 140057.8, 150278.2, 152545.2),
gdp_p62 = c(28063.4, 30507, 33520.2, 36089.5, 39104, 43056.8,
38781.9, 39685.8, 43784.1, 46187.6, 49444.7, 51746, 53585.8,
55885.5, 59584.7, 64333.5, 68409.7),
turnover_manu_dom = c(80,
87, 93.2, 99.9, 104.9, 113.6, 97, 100, 110, 112.1, 110.7,
107.1, 104.7, 102.9, 107.9, 107.9, 107.9),
turnover_manu_non_dom = c(70.9,
81.9, 86.2, 95.3, 102.8, 106.5, 86.8, 100, 112.8, 111.7,
112.8, 114.9, 118.2, 120.1, 129.2, 129.2, 129.2),
turnover_manu_tot = c(74.7,
84, 89.1, 97.2, 103.7, 109.4, 91.1, 100, 111.6, 111.9, 111.9,
111.7, 112.6, 112.9, 120.3, 120.3, 120.3),
price_index = c(1.7,
2, 2.1, 1.7, 2.2, 3.2, 0.4, 1.7, 3.6, 2.6, 2.1, 1.5, 0.8,
1, 2.2, 2.1, 1.5),
capital_n1132g = c(3638.4, 3633.3, 3616.2,
3450.7, 3596.8, 3867.2, 3372.5, 3722.9, 3808.5, 4005.6, 3718.6,
3467.9, 4214.2, 4237.4, 4450.2, 4598.6, 4721),
capital_n117g = c(8369.6,
8679.9, 8938.9, 9823.8, 10467.1, 11047.1, 11554.3, 11849.9,
13465.4, 13927.5, 15510.2, 15754.4, 16584.7, 17647.1, 18273.8,
19642.4, 20713.1), capital_n11mg = c(18749.6, 19433.5, 20051.6,
20569.8, 22646.1, 23674.5, 21200.6, 20919.6, 23157.7, 23520.7,
24057.7, 23832.8, 25019.2, 27608.2, 29790.1, 30998, 32856.8
),
lagged_gbaord = c(1537.89, 1537.89, 1537.89, 1619.74,
1697.55, 1770.144, 1986.775, 2149.787, 2269.986, 2428.143,
2452.955, 2587.586, 2647.489, 2744.844, 2875.706, 2889.779,
2913.369),
lagged_employment_be = c(2775.22, 2775.22, 2731.08,
2709.59, 2708.39, 2774.06, 2795.6, 2703.36, 2668.1, 2705.1,
2731.67, 2727.16, 2725.66, 2735.69, 2750.52, 2782.9, 2852.33
), lagged_employment_c = c(2562.53, 2562.53, 2518.57, 2496.98,
2499.54, 2558.88, 2578, 2483.97, 2447.65, 2483.1, 2507.41,
2500.94, 2499.6, 2511.75, 2523.97, 2555.48, 2622.5),
lagged_employment_j = c(400.75,
400.75, 387.53, 384.64, 389.29, 385.77, 392.86, 383.91, 392.18,
410.85, 419.75, 427.59, 438.96, 440.33, 460.84, 473.4, 494.4
),
lagged_employment_k = c(502.42, 502.42, 504.63, 515.39,
523.45, 536.6, 550.14, 546.68, 539.96, 536.58, 534.98, 524.13,
518.89, 511.57, 505.32, 496.41, 495.4),
lagged_employment_mn = c(1248.01,
1248.01, 1365.29, 1425.81, 1537.88, 1622.95, 1727.76, 1704.65,
1762.55, 1838.16, 1896.09, 1929.09, 1950.02, 1968.83, 2021.51,
2109.71, 2189.27),
lagged_employment_oq = c(3241.36, 3241.36,
3317.09, 3401.65, 3476.63, 3508.01, 3577.75, 3683.85, 3759.23,
3798.35, 3850.17, 3877.24, 3924.06, 4002.74, 4095.59, 4171.72,
4238.87),
lagged_employment_total = c(15113.52, 15113.52,
15307.28, 15491.61, 15762.92, 16050.92, 16356.53, 16269.97,
16392.87, 16647.79, 16820.66, 16879.06, 17039.6, 17142.13,
17365.32, 17650.21, 17951.61),
lagged_value_be = c(47967.1,
47967.1, 50737.6, 52955.2, 56872.4, 60864.9, 61029, 56837.8,
58433.6, 61443, 63655.1, 64132.3, 65542.6, 67495.4, 71152.6,
72698.8, 75792.3),
lagged_value_c = c(40192.9, 40192.9, 42014.6,
44229, 47735.5, 51552.4, 51165.9, 47129.7, 48759.3, 51467.7,
53234.6, 53431.4, 55169, 57458.7, 60962.8, 62196, 65063.5
),
lagged_value_j = c(7737.1, 7737.1, 7756.1, 8134.2, 8378.8,
8532.3, 8740, 8493.9, 8518.9, 9217.1, 9405.1, 9802.1, 10361.4,
10695.4, 11455.3, 11720.6, 12871),
lagged_value_k = c(10225.2,
10225.2, 10541.9, 11005.3, 11912.3, 13102.7, 13205.2, 12123.9,
12113.2, 12952.8, 12254.9, 12796.6, 12962.4, 13482.9, 13236.4,
13744.1, 14152.6),
lagged_value_mn = c(15074, 15074, 16569.1,
18008.7, 19576.6, 21317, 23189.8, 22490, 23255.2, 24895.4,
25988.7, 26998.2, 28027.3, 29207.9, 30737.7, 32259.6, 33781.9
),
lagged_value_oq = c(35065.6, 35065.6, 37329.6, 38288.8,
40003.1, 41511.4, 43761.3, 45817.8, 46996.6, 47980.9, 49381.5,
50261.7, 51624.3, 53715, 55926.4, 57637.1, 59648.9),
lagged_value_total = c(202353.5,
202353.5, 216098.3, 225888.1, 239076, 253604.6, 262414.7,
256671, 263633.5, 276404, 283548.2, 288624.3, 297230.1, 307037.7,
318952.7, 329396.1, 344338.6),
lagged_gdp_b1gq = c(226735.3,
226735.3, 242348.3, 254075, 267824.4, 283978, 293761.9, 288044.1,
295896.6, 310128.6, 318653.1, 323910.2, 333146.1, 344269.3,
357608, 369341.3, 385361.9),
lagged_gdp_p3 = c(164107.8,
164107.8, 176316.4, 185871.1, 194102, 200944.4, 208857.1,
213630.1, 218947.2, 227250.8, 233638.1, 238329.3, 243860.6,
249404.3, 257166.5, 265900.2, 274583.7),
lagged_gdp_p61 = c(74691.6,
74691.6, 83074.9, 90010.4, 100076.8, 110157.2, 113368.1,
91435.3, 111997.3, 123526.3, 125801.2, 123657.1, 126109.3,
129183.6, 131524, 140057.8, 150278.2),
lagged_gdp_p62 = c(28063.4,
28063.4, 30507, 33520.2, 36089.5, 39104, 43056.8, 38781.9,
39685.8, 43784.1, 46187.6, 49444.7, 51746, 53585.8, 55885.5,
59584.7, 64333.5),
lagged_turnover_manu_dom = c(80, 80, 87,
93.2, 99.9, 104.9, 113.6, 97, 100, 110, 112.1, 110.7, 107.1,
104.7, 102.9, 107.9, 107.9),
lagged_turnover_manu_non_dom = c(70.9,
70.9, 81.9, 86.2, 95.3, 102.8, 106.5, 86.8, 100, 112.8, 111.7,
112.8, 114.9, 118.2, 120.1, 129.2, 129.2),
lagged_turnover_manu_tot = c(74.7,
74.7, 84, 89.1, 97.2, 103.7, 109.4, 91.1, 100, 111.6, 111.9,
111.9, 111.7, 112.6, 112.9, 120.3, 120.3),
lagged_price_index = c(1.7,
1.7, 2, 2.1, 1.7, 2.2, 3.2, 0.4, 1.7, 3.6, 2.6, 2.1, 1.5,
0.8, 1, 2.2, 2.1), lagged_capital_n1132g = c(3638.4, 3638.4,
3633.3, 3616.2, 3450.7, 3596.8, 3867.2, 3372.5, 3722.9, 3808.5,
4005.6, 3718.6, 3467.9, 4214.2, 4237.4, 4450.2, 4598.6),
lagged_capital_n117g = c(8369.6, 8369.6, 8679.9, 8938.9,
9823.8, 10467.1, 11047.1, 11554.3, 11849.9, 13465.4, 13927.5,
15510.2, 15754.4, 16584.7, 17647.1, 18273.8, 19642.4),
lagged_capital_n11mg = c(18749.6,
18749.6, 19433.5, 20051.6, 20569.8, 22646.1, 23674.5, 21200.6,
20919.6, 23157.7, 23520.7, 24057.7, 23832.8, 25019.2, 27608.2,
29790.1, 30998), country = c("AT", "AT", "AT", "AT", "AT",
"AT", "AT", "AT", "AT", "AT", "AT", "AT", "AT", "AT", "AT",
"AT", "AT")),
row.names = c(NA, -17L), class = c("tbl_df",
"tbl", "data.frame"))
# Hold-out split by row position: the first `nn - time_back` rows form the
# training set and the last `time_back` rows form the test set (the most
# recent years, since df_ini is sorted by year). The assessment indices start
# at `nn - time_back + 1` so the two sets are contiguous and do not overlap.
set.seed(1234)
nn <- nrow(df_ini)
time_back <- 1
# Both the analysis and the assessment set must be non-empty.
stopifnot(time_back >= 1, time_back < nn)
indices <- list(
  analysis = seq_len(nn - time_back),
  assessment = (nn - time_back + 1):nn
)
df_split <- make_splits(indices, df_ini)
df_train <- training(df_split)
df_test <- testing(df_split)
# 3-fold cross-validation on the training rows, used for tuning below.
# NOTE(review): vfold_cv() assigns rows to folds at random, so with
# time-ordered data later years can help predict earlier ones during tuning;
# rsample::rolling_origin() may be more appropriate — confirm intent.
folded_data <- vfold_cv(df_train, v = 3)
# Preprocessing: model `berd` from all other columns, treat `year` as an
# identifier instead of a predictor, and drop zero-variance predictors
# (such as `country`, which holds a single value).
# (An alternative step once considered: step_string2factor(one_of("country")).)
cubist_recipe <- recipe(berd ~ ., data = df_train)
cubist_recipe <- update_role(cubist_recipe, year, new_role = "ID")
cubist_recipe <- step_zv(cubist_recipe, all_predictors())

# Cubist rule-based regression with both `committees` and `neighbors` left
# as tuning placeholders.
cubist_spec <- set_engine(
  cubist_rules(committees = tune(), neighbors = tune()),
  "Cubist"
)

# Bundle preprocessing and model into one workflow.
cubist_workflow <- add_model(add_recipe(workflow(), cubist_recipe), cubist_spec)

# Candidate grid: committees 1-9 and 10-50 in steps of 10, crossed with
# 0, 3, 6 or 9 neighbours.
committee_values <- c(seq_len(9), seq(10, 50, by = 10))
neighbor_values <- seq(0, 9, by = 3)
cubist_grid <- tidyr::crossing(
  committees = committee_values,
  neighbors = neighbor_values
)

# Evaluate every grid point on the cross-validation folds.
cubist_tune <- tune_grid(
  cubist_workflow,
  resamples = folded_data,
  grid = cubist_grid
)
#>
#> Attaching package: 'rlang'
#> The following objects are masked from 'package:purrr':
#>
#> %@%, as_function, flatten, flatten_chr, flatten_dbl, flatten_int,
#> flatten_lgl, flatten_raw, invoke, list_along, modify, prepend,
#> splice
#>
#> Attaching package: 'vctrs'
#> The following object is masked from 'package:tibble':
#>
#> data_frame
#> The following object is masked from 'package:dplyr':
#>
#> data_frame
#> Loading required package: lattice
# Pick the hyperparameter combination with the lowest cross-validated RMSE,
# finalize the workflow with it and train on the whole training set.
# `metric` is passed by name: newer versions of tune expect it to be named
# rather than given positionally.
best_cub <- select_best(cubist_tune, metric = "rmse")
final_cub <- finalize_workflow(
  cubist_workflow,
  best_cub
)
fit_model <- final_cub %>%
  fit(df_train)
# fit_model$fit$fit$fit reaches the underlying model object created by the
# Cubist engine, nested several levels inside the fitted workflow; summary()
# on it prints the learned rules and their linear models.
# In newer workflows releases, extract_fit_engine(fit_model) returns the same
# engine object without relying on the workflow's internal layout.
summary(fit_model$fit$fit$fit)
#>
#> Call:
#> cubist.default(x = x, y = y, committees = 1)
#>
#>
#> Cubist [Release 2.07 GPL Edition] Thu Dec 10 16:52:59 2020
#> ---------------------------------
#>
#> Target attribute `outcome'
#>
#> Read 16 cases (53 attributes) from undefined.data
#>
#> Model:
#>
#> Rule 1: [16 cases, mean 5877.817, range 3130.884 to 8461.72, est err 251.023]
#>
#> outcome = -5043.087 + 0.0357 gdp_b1gq
#>
#>
#> Evaluation on training data (16 cases):
#>
#> Average |error| 196.045
#> Relative |error| 0.14
#> Correlation coefficient 0.99
#>
#>
#> Attribute usage:
#> Conds Model
#>
#> 100% gdp_b1gq
#>
#>
#> Time: 0.0 secs
Создано 2020-12-10 пакетом reprex (версия 0.3.0.9001)
Если вам нужны коэффициенты для дальнейшей обработки, посмотрите, что возвращает as_tibble(fit_model$fit$fit$fit$coefficients).