Как получить метки из набора данных tensorflow

#tensorflow #tensorflow2.0

#tensorflow #tensorflow2.0

Вопрос:

 ds_test = tf.data.experimental.make_csv_dataset(
    file_pattern = "./dfj_test/part-*.csv.gz",
    batch_size=batch_size, num_epochs=1, 
    #column_names=use_cols, 
    label_name='label_id',
    #select_columns= select_cols,
    num_parallel_reads=30, compression_type='GZIP',
    shuffle_buffer_size=12800)
 

Это мой tesetset во время обучения. После завершения модели я хочу заархивировать столбцы прогнозов и меток для df_test .

 preds = model.predict(df_test)
 

Получить прогнозы довольно просто, и они имеют формат numpy array. Однако я не знаю, как получить соответствующие метки из df_test.
Я хочу zip (preds, labels) для дальнейшего анализа.
Любой намек? Спасибо.

(версия tf 2.3.1)

Ответ №1:

Вы можете сопоставить каждый пример, чтобы вернуть нужное поле

 # load some exemplary data
TRAIN_DATA_URL = "https://storage.googleapis.com/tf-datasets/titanic/train.csv"
train_file_path = tf.keras.utils.get_file("train.csv", TRAIN_DATA_URL)
dataset = tf.data.experimental.make_csv_dataset(train_file_path, batch_size=100, num_epochs=1) 
 
 # get field by unbatching 
labels_iterator= dataset.unbatch().map(lambda x: x['survived']).as_numpy_iterator()
labels = np.array(list(labels_iterator))
 
 # get field by concatenating batches
labels_iterator= dataset.map(lambda x: x['survived']).as_numpy_iterator()
labels = np.concatenate(list(labels_iterator))
 

Комментарии:

1. Опция unbatch() работает. Спасибо. Кстати, знаете ли вы какой-либо метод эффективного подсчета всех выборок набора данных?

2. для num _ в enumerate(ds_train): pass .