#python #tensorflow
Вопрос:
У меня есть 25 смешанных переменных: некоторые из них двоичные, некоторые непрерывные, и большинство из них являются факторами высокого уровня, которые необходимо внедрить.
Затем есть моя модель глубокого обучения, которая использует множество входных данных и строит вокруг нее автоэнкодер, который можно увидеть ниже.
Мой вопрос в том, как мне сопоставить переменные из моего фрейма данных pandas с соответствующими входными слоями? Например, каждый фактор высокого уровня для перехода к нужному слою встраивания. Первые мысли-расположить набор данных в правильном порядке входных данных перед выполнением какого-либо обучения или каким-либо образом сопоставить переменные (например, provID — > input_provid).
autoencoder.summary()
Model: "claims_ae"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_provid (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
input_pos_code (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
input_prindiag (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
input_billtype2 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
input_lob (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
input_ppg_code (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
input_segment (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
input_dofr (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
embedding_8 (Embedding) (None, 1, 70) 337820 input_provid[0][0]
__________________________________________________________________________________________________
embedding_9 (Embedding) (None, 1, 6) 168 input_pos_code[0][0]
__________________________________________________________________________________________________
embedding_10 (Embedding) (None, 1, 77) 447524 input_prindiag[0][0]
__________________________________________________________________________________________________
embedding_11 (Embedding) (None, 1, 5) 95 input_billtype2[0][0]
__________________________________________________________________________________________________
embedding_12 (Embedding) (None, 1, 3) 24 input_lob[0][0]
__________________________________________________________________________________________________
embedding_13 (Embedding) (None, 1, 10) 930 input_ppg_code[0][0]
__________________________________________________________________________________________________
embedding_14 (Embedding) (None, 1, 2) 8 input_segment[0][0]
__________________________________________________________________________________________________
embedding_15 (Embedding) (None, 1, 3) 21 input_dofr[0][0]
__________________________________________________________________________________________________
input_number_features (InputLay [(None, 4)] 0
__________________________________________________________________________________________________
input_binary_features (InputLay [(None, 12)] 0
__________________________________________________________________________________________________
reshape_8 (Reshape) (None, 70) 0 embedding_8[0][0]
__________________________________________________________________________________________________
reshape_9 (Reshape) (None, 6) 0 embedding_9[0][0]
__________________________________________________________________________________________________
reshape_10 (Reshape) (None, 77) 0 embedding_10[0][0]
__________________________________________________________________________________________________
reshape_11 (Reshape) (None, 5) 0 embedding_11[0][0]
__________________________________________________________________________________________________
reshape_12 (Reshape) (None, 3) 0 embedding_12[0][0]
__________________________________________________________________________________________________
reshape_13 (Reshape) (None, 10) 0 embedding_13[0][0]
__________________________________________________________________________________________________
reshape_14 (Reshape) (None, 2) 0 embedding_14[0][0]
__________________________________________________________________________________________________
reshape_15 (Reshape) (None, 3) 0 embedding_15[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, 192) 0 input_number_features[0][0]
input_binary_features[0][0]
reshape_8[0][0]
reshape_9[0][0]
reshape_10[0][0]
reshape_11[0][0]
reshape_12[0][0]
reshape_13[0][0]
reshape_14[0][0]
reshape_15[0][0]
__________________________________________________________________________________________________
dense_8 (Dense) (None, 16) 3088 concatenate_1[0][0]
__________________________________________________________________________________________________
dense_9 (Dense) (None, 8) 136 dense_8[0][0]
__________________________________________________________________________________________________
dense_10 (Dense) (None, 4) 36 dense_9[0][0]
__________________________________________________________________________________________________
dense_11 (Dense) (None, 2) 10 dense_10[0][0]
__________________________________________________________________________________________________
dense_12 (Dense) (None, 4) 12 dense_11[0][0]
__________________________________________________________________________________________________
dense_13 (Dense) (None, 8) 40 dense_12[0][0]
__________________________________________________________________________________________________
dense_14 (Dense) (None, 16) 144 dense_13[0][0]
__________________________________________________________________________________________________
dense_15 (Dense) (None, 192) 3264 dense_14[0][0]
==================================================================================================
Total params: 793,320
Trainable params: 793,320
Non-trainable params: 0
Ответ №1:
Из документации tf.keras
x: Входные данные. Это может быть:
- Массив Numpy (или подобный массиву) или список массивов (в случае, если модель имеет несколько входов).
- Тензор тензорного потока или список тензоров (в случае, если модель имеет несколько входов).
- Диктант, сопоставляющий имена входных данных с соответствующими массивами/тензорами, если модель имеет именованные входные данные.
tf.data
Набор данных. Должен возвращать кортеж либо(inputs, targets)
или(inputs, targets, sample_weights)
.- Генератор или
keras.utils.Sequence
возвращающийся(inputs, targets)
или(inputs, targets, sample_weights)
.
Таким образом, вы можете передать attr name
Input
Слоям, а затем предоставить входные данные с dict
keys
именами, которые являются входными данными.