Создание функции Python для перебора списка/фрейма данных (VIF)

#python #pandas #machine-learning #statistics #statsmodels

Вопрос:

У меня есть набор данных, и я хочу выбрать подмножество переменных с VIF(Коэффициент инфляции дисперсии) меньше определенного порогового значения. Моя идея состояла в том, чтобы вычислить VIF для каждой переменной, затем извлечь переменную с наибольшим значением (если оно превышает определенный порог), пересчитать VIF для каждой оставшейся переменной и повторять процесс до тех пор, пока VIF не превысит пороговое значение.

В этом подходе нет никакой новой идеи, но я не мог пройти мимо определенного момента, чтобы создать функцию для автоматизации этого процесса в Python.

x — это набор данных с удаленной целевой переменной

 import pandas as pd
import numpy as np
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.tools.tools import add_constant

x_vif = add_constant(x)

vif = pd.DataFrame([variance_inflation_factor(x_vif.values, i) for i in range(x_vif.shape[1])], index=x_vif.columns)
 

vif также может быть списком. Итак, есть ли какой-либо пакет, который делает это автоматически, или не могли бы вы дать мне представление о том, как создать эту функцию ?

Я нашел библиотеку R (thinXwithVIF), которая могла бы сделать это автоматически, но я не смог заставить rpy2 работать с версией python, которую мне нужно использовать.

Ответ №1:

Возможно, имело бы смысл удалить переменную с наибольшим значением vif в каждом раунде, подмножество фрейма данных и остановиться, когда все переменные будут ниже вашего порогового значения. Я не думаю, что vif был бы всем и всем, и вам действительно нужно посмотреть на данные, чтобы решить, что включать и т. Д.

 import statsmodels.api as sm
import pandas as pd
from statsmodels.stats.outliers_influence import variance_inflation_factor

data = sm.datasets.get_rdataset('mtcars')

x_vif = data.data.iloc[:,1:]
y = data.data['mpg']

thres = 10

while True:
    Cols = range(x_vif.shape[1])
    
    vif = np.array([variance_inflation_factor(x_vif.values, i) for i in Cols])
    if all(vif < thres):
        break
    else:
        Cols = np.delete(Cols,np.argmax(vif))
        x_vif = x_vif.iloc[:,Cols]