Использование tf, py_func с файлами pickle в Dataset API

#tensorflow #pickle

#tensorflow #pickle

Вопрос:

Я пытаюсь использовать Dataset API с моим набором данных, которые являются файлами pickle. Эти файлы содержат мои данные, которые представляют собой вектор с плавающей точкой, и метки, которые являются одним горячим вектором.

Я пытался использовать tf.py_func для загрузки функций, но я не могу этого сделать, так как у меня нет совпадающих фигур. Поскольку я использую эти файлы pickle, которые также включают метку, я не могу передать ее непосредственно кортежу в качестве примера здесь. Итак, я немного не понимаю, как продолжить.

Пока это мой код

 
path = "my_dir_to_pkl_files"

pkl_files = glob.glob((path "*.pkl"))
dataset = tf.data.Dataset.from_tensor_slices((pkl_files))
dataset = dataset.map(
               lambda filename: tuple(tf.py_func(
               load_features, [filename], [tf.float32])))
  

И вот моя функция python для чтения функций.

 def load_features(name):
    decoded = name.decode("UTF-8")
    if os.path.exists(decoded):
        with open(decoded, 'rb') as f:
            file = pickle.load(f)
            return file['features']
            # I have commented the line below but this should return
            # the features and the label in a one hot vector
            # return file['features'], file['targets']
    else:
        print("Something went wrong!")
        exit(-1)
  

Я ожидал бы, что Dataset API вернет кортеж с N объектами и 1 горячим вектором для каждого образца в моей партии. Вместо этого я получаю

InvalidArgumentError: pyfunc_0 возвращает 30 значений, но ожидает увидеть 1 значение.

Есть предложения? Спасибо.

Редактировать: я показываю, как выглядит мой файл pickle. Вектор объектов имеет форму [30,100]. Я также прикрепляю тот же файл здесь.

 {'features': array([[0.64864044, 0.71419346, 0.35874235, ..., 0.66058507, 0.89013242,
        0.67564707],
       [0.15958826, 0.38115951, 0.46636267, ..., 0.49682084, 0.08863887,
        0.17142761],
       [0.26925915, 0.27901399, 0.91624607, ..., 0.30269212, 0.47494327,
        0.43265325],
       ...,
       [0.50405357, 0.7441127 , 0.04308265, ..., 0.06766902, 0.87449393,
        0.31018099],
       [0.44777562, 0.30836258, 0.48148097, ..., 0.74899213, 0.97264324,
        0.43391464],
       [0.50583501, 0.56803691, 0.61290449, ..., 0.8350931 , 0.52897295,
        0.23731264]]), 'targets': array([0, 0, 1, 0])}

  

Ошибка, которую я получил, возникла после того, как я попытался получить элемент для набора данных

 dataset.make_one_shot_iterator()
next_element = iterator.get_next()
print(sess.run(next_element))
  

Комментарии:

1. Вы должны предоставить мини-файл pickle, чтобы мы могли вам помочь.

2. @giser_yugang Я отредактировал вопрос. Я загрузил файл и опубликовал его содержимое.

3. Странно, что я не столкнулся с вашей ошибкой в tensorflow=1.12 and python=3.6 . Когда я использую код dataset = dataset.map(lambda filename: tuple(tf.py_func(load_features, [filename], [tf.float64]))) , я могу его запустить.

4. Я получаю сообщение об ошибке при попытке получить элемент из набора данных. dataset.make_one_shot_iterator() next_element = итератор.get_next() печать (sess.run(next_element)) Может ли это быть так?

5. Код, который вы получаете из элемента dataset, работает на моей машине правильно. Я могу распечатать содержимое вашего файла.