#tensorflow #pickle
#tensorflow #pickle
Вопрос:
Я пытаюсь использовать Dataset API с моим набором данных, которые являются файлами pickle. Эти файлы содержат мои данные, которые представляют собой вектор с плавающей точкой, и метки, которые являются одним горячим вектором.
Я пытался использовать tf.py_func для загрузки функций, но я не могу этого сделать, так как у меня нет совпадающих фигур. Поскольку я использую эти файлы pickle, которые также включают метку, я не могу передать ее непосредственно кортежу в качестве примера здесь. Итак, я немного не понимаю, как продолжить.
Пока это мой код
path = "my_dir_to_pkl_files"
pkl_files = glob.glob((path "*.pkl"))
dataset = tf.data.Dataset.from_tensor_slices((pkl_files))
dataset = dataset.map(
lambda filename: tuple(tf.py_func(
load_features, [filename], [tf.float32])))
И вот моя функция python для чтения функций.
def load_features(name):
decoded = name.decode("UTF-8")
if os.path.exists(decoded):
with open(decoded, 'rb') as f:
file = pickle.load(f)
return file['features']
# I have commented the line below but this should return
# the features and the label in a one hot vector
# return file['features'], file['targets']
else:
print("Something went wrong!")
exit(-1)
Я ожидал бы, что Dataset API вернет кортеж с N объектами и 1 горячим вектором для каждого образца в моей партии. Вместо этого я получаю
InvalidArgumentError: pyfunc_0 возвращает 30 значений, но ожидает увидеть 1 значение.
Есть предложения? Спасибо.
Редактировать: я показываю, как выглядит мой файл pickle. Вектор объектов имеет форму [30,100]. Я также прикрепляю тот же файл здесь.
{'features': array([[0.64864044, 0.71419346, 0.35874235, ..., 0.66058507, 0.89013242,
0.67564707],
[0.15958826, 0.38115951, 0.46636267, ..., 0.49682084, 0.08863887,
0.17142761],
[0.26925915, 0.27901399, 0.91624607, ..., 0.30269212, 0.47494327,
0.43265325],
...,
[0.50405357, 0.7441127 , 0.04308265, ..., 0.06766902, 0.87449393,
0.31018099],
[0.44777562, 0.30836258, 0.48148097, ..., 0.74899213, 0.97264324,
0.43391464],
[0.50583501, 0.56803691, 0.61290449, ..., 0.8350931 , 0.52897295,
0.23731264]]), 'targets': array([0, 0, 1, 0])}
Ошибка, которую я получил, возникла после того, как я попытался получить элемент для набора данных
dataset.make_one_shot_iterator()
next_element = iterator.get_next()
print(sess.run(next_element))
Комментарии:
1. Вы должны предоставить мини-файл pickle, чтобы мы могли вам помочь.
2. @giser_yugang Я отредактировал вопрос. Я загрузил файл и опубликовал его содержимое.
3. Странно, что я не столкнулся с вашей ошибкой в
tensorflow=1.12
andpython=3.6
. Когда я использую кодdataset = dataset.map(lambda filename: tuple(tf.py_func(load_features, [filename], [tf.float64])))
, я могу его запустить.4. Я получаю сообщение об ошибке при попытке получить элемент из набора данных. dataset.make_one_shot_iterator() next_element = итератор.get_next() печать (sess.run(next_element)) Может ли это быть так?
5. Код, который вы получаете из элемента dataset, работает на моей машине правильно. Я могу распечатать содержимое вашего файла.