#tensorflow #tensorflow2.0 #tensorflow-datasets
Вопрос:
Учтите, что у меня есть файл data.csv
, который содержит:
feature0,feature1,label
True,0.1,class_1
False,2.7,class_2
False,10.1,class_3
Я хотел бы загрузить это как набор данных и преобразовать label
в логическое значение, чтобы оно было истинным для class_1
и ложным в противном случае. Вот мой код:
import tensorflow as tf
data = tf.data.experimental.make_csv_dataset(
'data.csv',
32,
label_name='label',
shuffle=False,
num_epochs=1)
def view(ds, num_batches=1):
for f, l in ds.take(num_batches):
print('Features:')
print(f)
print('Labels:')
print(l)
def process_labels(features, label):
if label == 'class_1':
label = True
else:
label = False
# label = label=='class_1'
return features, label
view(data.map(process_labels))
Это приводит к ошибке: InvalidArgumentError: Input to reshape is a tensor with 3 values, but the requested shape has 1 [[{{node Reshape}}]]
. Это почему? Это тем более сбивает с толку, что , когда я заменяю if~else на однострочный, который прокомментирован label = label=='class_1'
, проблема исчезает. Что здесь происходит?
Я использую TensorFlow 2.4.1 и Python 3.8.5.
Ответ №1:
Вторым аргументом tf.data.experimental.make_csv_dataset
является размер пакета, что означает, что созданный набор данных имеет следующую форму: (batch_size, feature)
. Любая функция, отображаемая в этом наборе данных, должна работать с пакетом данных, а не только с одним элементом набора данных.
label = label=='class_1'
работает из-за трансляции, но ваша предыдущая функция этого не делает.
У вас есть два способа заставить эту функцию работать:
- либо напишите функцию, которая обрабатывает пакеты данных (т. е. ваше рабочее решение).
- вызовите
unbatch
набор данных. Это может иметь негативные последствия для производительности, как указано в документе:
Примечание. Для разборки требуется копия данных, чтобы разделить пакетированный тензор на более мелкие, несвязанные тензоры. При оптимизации производительности старайтесь избегать ненужного использования unbatch.
data.unbatch().map(process_labels).batch(32)