Условная карта набора данных TensorFlow несовместима с «базовым» поведением Python

#tensorflow #tensorflow2.0 #tensorflow-datasets

Вопрос:

Учтите, что у меня есть файл data.csv , который содержит:

 feature0,feature1,label
True,0.1,class_1
False,2.7,class_2
False,10.1,class_3
 

Я хотел бы загрузить это как набор данных и преобразовать label в логическое значение, чтобы оно было истинным для class_1 и ложным в противном случае. Вот мой код:

 import tensorflow as tf

data = tf.data.experimental.make_csv_dataset(
    'data.csv',
    32,
    label_name='label',
    shuffle=False,
    num_epochs=1)

def view(ds, num_batches=1):
    
    for f, l in ds.take(num_batches):
        print('Features:')
        print(f)
        print('Labels:')
        print(l)

def process_labels(features, label):
    
    if label == 'class_1':
        label = True
    else:
        label = False
    
    # label = label=='class_1'
    
    return features, label

view(data.map(process_labels))
 

Это приводит к ошибке: InvalidArgumentError: Input to reshape is a tensor with 3 values, but the requested shape has 1 [[{{node Reshape}}]] . Это почему? Это тем более сбивает с толку, что , когда я заменяю if~else на однострочный, который прокомментирован label = label=='class_1' , проблема исчезает. Что здесь происходит?

Я использую TensorFlow 2.4.1 и Python 3.8.5.

Ответ №1:

Вторым аргументом tf.data.experimental.make_csv_dataset является размер пакета, что означает, что созданный набор данных имеет следующую форму: (batch_size, feature) . Любая функция, отображаемая в этом наборе данных, должна работать с пакетом данных, а не только с одним элементом набора данных.

label = label=='class_1' работает из-за трансляции, но ваша предыдущая функция этого не делает.

У вас есть два способа заставить эту функцию работать:

  • либо напишите функцию, которая обрабатывает пакеты данных (т. е. ваше рабочее решение).
  • вызовите unbatch набор данных. Это может иметь негативные последствия для производительности, как указано в документе:

    Примечание. Для разборки требуется копия данных, чтобы разделить пакетированный тензор на более мелкие, несвязанные тензоры. При оптимизации производительности старайтесь избегать ненужного использования unbatch.

     data.unbatch().map(process_labels).batch(32)