Лучший способ сопоставления переменных с различными входными слоями

#python #tensorflow

Вопрос:

У меня есть 25 смешанных переменных: некоторые из них двоичные, некоторые непрерывные, и большинство из них являются факторами высокого уровня, которые необходимо внедрить.

Затем есть моя модель глубокого обучения, которая использует множество входных данных и строит вокруг нее автоэнкодер, который можно увидеть ниже.

Мой вопрос в том, как мне сопоставить переменные из моего фрейма данных pandas с соответствующими входными слоями? Например, каждый фактор высокого уровня для перехода к нужному слою встраивания. Первые мысли-расположить набор данных в правильном порядке входных данных перед выполнением какого-либо обучения или каким-либо образом сопоставить переменные (например, provID — > input_provid).

 autoencoder.summary()
Model: "claims_ae"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_provid (InputLayer)       [(None, 1)]          0                                            
__________________________________________________________________________________________________
input_pos_code (InputLayer)     [(None, 1)]          0                                            
__________________________________________________________________________________________________
input_prindiag (InputLayer)     [(None, 1)]          0                                            
__________________________________________________________________________________________________
input_billtype2 (InputLayer)    [(None, 1)]          0                                            
__________________________________________________________________________________________________
input_lob (InputLayer)          [(None, 1)]          0                                            
__________________________________________________________________________________________________
input_ppg_code (InputLayer)     [(None, 1)]          0                                            
__________________________________________________________________________________________________
input_segment (InputLayer)      [(None, 1)]          0                                            
__________________________________________________________________________________________________
input_dofr (InputLayer)         [(None, 1)]          0                                            
__________________________________________________________________________________________________
embedding_8 (Embedding)         (None, 1, 70)        337820      input_provid[0][0]               
__________________________________________________________________________________________________
embedding_9 (Embedding)         (None, 1, 6)         168         input_pos_code[0][0]             
__________________________________________________________________________________________________
embedding_10 (Embedding)        (None, 1, 77)        447524      input_prindiag[0][0]             
__________________________________________________________________________________________________
embedding_11 (Embedding)        (None, 1, 5)         95          input_billtype2[0][0]            
__________________________________________________________________________________________________
embedding_12 (Embedding)        (None, 1, 3)         24          input_lob[0][0]                  
__________________________________________________________________________________________________
embedding_13 (Embedding)        (None, 1, 10)        930         input_ppg_code[0][0]             
__________________________________________________________________________________________________
embedding_14 (Embedding)        (None, 1, 2)         8           input_segment[0][0]              
__________________________________________________________________________________________________
embedding_15 (Embedding)        (None, 1, 3)         21          input_dofr[0][0]                 
__________________________________________________________________________________________________
input_number_features (InputLay [(None, 4)]          0                                            
__________________________________________________________________________________________________
input_binary_features (InputLay [(None, 12)]         0                                            
__________________________________________________________________________________________________
reshape_8 (Reshape)             (None, 70)           0           embedding_8[0][0]                
__________________________________________________________________________________________________
reshape_9 (Reshape)             (None, 6)            0           embedding_9[0][0]                
__________________________________________________________________________________________________
reshape_10 (Reshape)            (None, 77)           0           embedding_10[0][0]               
__________________________________________________________________________________________________
reshape_11 (Reshape)            (None, 5)            0           embedding_11[0][0]               
__________________________________________________________________________________________________
reshape_12 (Reshape)            (None, 3)            0           embedding_12[0][0]               
__________________________________________________________________________________________________
reshape_13 (Reshape)            (None, 10)           0           embedding_13[0][0]               
__________________________________________________________________________________________________
reshape_14 (Reshape)            (None, 2)            0           embedding_14[0][0]               
__________________________________________________________________________________________________
reshape_15 (Reshape)            (None, 3)            0           embedding_15[0][0]               
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 192)          0           input_number_features[0][0]      
                                                                 input_binary_features[0][0]      
                                                                 reshape_8[0][0]                  
                                                                 reshape_9[0][0]                  
                                                                 reshape_10[0][0]                 
                                                                 reshape_11[0][0]                 
                                                                 reshape_12[0][0]                 
                                                                 reshape_13[0][0]                 
                                                                 reshape_14[0][0]                 
                                                                 reshape_15[0][0]                 
__________________________________________________________________________________________________
dense_8 (Dense)                 (None, 16)           3088        concatenate_1[0][0]              
__________________________________________________________________________________________________
dense_9 (Dense)                 (None, 8)            136         dense_8[0][0]                    
__________________________________________________________________________________________________
dense_10 (Dense)                (None, 4)            36          dense_9[0][0]                    
__________________________________________________________________________________________________
dense_11 (Dense)                (None, 2)            10          dense_10[0][0]                   
__________________________________________________________________________________________________
dense_12 (Dense)                (None, 4)            12          dense_11[0][0]                   
__________________________________________________________________________________________________
dense_13 (Dense)                (None, 8)            40          dense_12[0][0]                   
__________________________________________________________________________________________________
dense_14 (Dense)                (None, 16)           144         dense_13[0][0]                   
__________________________________________________________________________________________________
dense_15 (Dense)                (None, 192)          3264        dense_14[0][0]                   
==================================================================================================
Total params: 793,320
Trainable params: 793,320
Non-trainable params: 0
 

Ответ №1:

Из документации tf.keras

x: Входные данные. Это может быть:

  • Массив Numpy (или подобный массиву) или список массивов (в случае, если модель имеет несколько входов).
  • Тензор тензорного потока или список тензоров (в случае, если модель имеет несколько входов).
  • Диктант, сопоставляющий имена входных данных с соответствующими массивами/тензорами, если модель имеет именованные входные данные.
  • tf.data Набор данных. Должен возвращать кортеж либо (inputs, targets) или (inputs, targets, sample_weights) .
  • Генератор или keras.utils.Sequence возвращающийся (inputs, targets) или (inputs, targets, sample_weights) .

Таким образом, вы можете передать attr name Input Слоям, а затем предоставить входные данные с dict keys именами, которые являются входными данными.