Возврат определенных элементов с помощью Dataset api

#python-3.x #tensorflow #tensorflow-datasets #tfrecord

#python-3.x #tensorflow #tensorflow-наборы данных #tfrecord

Вопрос:

я написал файл tfrecord, в котором у меня есть изображения и их метки.Затем я могу забрать их с помощью

     def parserTrain(record):
        keys_to_features = {
            "image_raw": tf.FixedLenFeature((), tf.string, default_value=""),
            "label": tf.FixedLenFeature((), tf.int64,
                                        default_value=tf.zeros([], dtype=tf.int64)),
        }
        parsed = tf.parse_single_example(record, keys_to_features)

        # Perform additional preprocessing on the parsed data.
        image = tf.image.decode_jpeg(parsed["image_raw"])
        image = tf.reshape(image, [256, 256, 3])

        image = tf.transpose(image, perm=[2, 0, 1])  # channels first
        image = tf.truediv(image, 255.0)
        label = tf.cast(parsed["label"], tf.int32)

        return {"image": image}, label

    # Set up training input function.
    def train_input_fn():
        """Prepare data for training."""
        train_tfrecord = 'Dataset/train_images.tfrecords'

        dataset = tf.data.TFRecordDataset(train_tfrecord)
        dataset = dataset.map(parserTrain)
  

после этого я хочу отфильтровать некоторые примеры, используя, вероятно, что-то вроде этого:

 def f(x):
    return x[1] == 1


ds1 = dataset.filter(f)
  

но я получаю эту ошибку:

TypeError: f() принимает 1 позиционный аргумент, но было задано 2

Ответ №1:

Итак, учитывая, что у вас есть набор данных (например, a TFRecordDataset ), вы можете отфильтровать примеры следующим образом:

   dataset = tf.data.TFRecordDataset(filenames=files)
  dataset = dataset.filter(lambda example: example["value"] == value and example["label"] == label)
  dataset = ...
  

Комментарии:

1. Я думаю, что я не совсем понимаю вопрос. Итак, вам нужен генератор, который возвращает пакет, где каждый образец имеет одинаковую метку. Но разные метки для каждого пакета?

2. да, это то, что мне нужно. Единственный способ, которым я могу придумать, как это сделать, если я создам цикл for, а затем создам dataset на основе определенной метки, но я не знаю, насколько это эффективно

3. @christk Пожалуйста, добавьте дополнительные подробности к вашему вопросу напрямую.

4. @lhlmgr я полностью изменил вопрос, и это прямо к делу, вы можете взглянуть 🙂

Ответ №2:

Отвечая на мой вопрос, поскольку я нашел ответ. правильный синтаксис для функции фильтрации набора данных кортежей следующий:

 def f(im, label):
    return tf.equal(label, 1)


ds1 = dataset.filter(f)