#python #tensorflow #classification #conv-neural-network #vgg-net
#python #tensorflow #классификация #conv-нейронная сеть #vgg-net
Вопрос:
В настоящее время я пытаюсь обучить сети классификации с использованием TensorFlow API (https://github.com/tensorflow/models). После создания TFRecords для моего набора данных (хранящихся в research / slim / data), я обучаю сети, используя следующую команду:
python research/slim/train_image_classifier.py
--train_dir=research/slim/training/current_model
--dataset_name=my_dataset
--dataset_split_name=train
--dataset_dir=research/slim/data
--model_name=vgg_16
--checkpoint_path=research/slim/training/vgg_16_2016_08_28/vgg_16.ckpt
--checkpoint_exclude_scopes=vgg_16/fc7,vgg_16/fc8
--trainable_scopes=vgg_16/fc7,vgg_16/fc8
--batch_size=5
--log_every_n_steps=10
--max_number_of_steps=1000
Это хорошо работает для нескольких сетей классификации (Inception, ResNet, MobileNet), но не так хорошо для VGG-Net. Я точно настраиваю следующую модель VGG-Net 16:
http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz
В общем, это работает для обучения этой модели, но когда я обучаю сеть, потери увеличиваются, а не уменьшаются. Возможно, это связано с моим выбором ‘checkpoint_exclude_scopes’.
Правильно ли использовать последний полностью подключенный слой в качестве checkpoint_exclude_scopes?
Тот же вопрос возникает при замораживании графика для параметра ‘output_node_names’. Например, для InceptionV3 это работает с ‘output_node_names=InceptionV3/Predictions/Reshape_1’. Но как установить этот параметр для VGG-Net. Я попробовал следующее:
python research/slim/freeze_graph.py
--input_graph=research/slim/training/current_model/graph.pb
--input_checkpoint=research/slim/training/current_model/model.ckpt
--input_binary=true
--output_graph=research/slim/training/current_model/frozen_inference_graph.pb
--output_node_names=vgg_16/fc8
Я не нашел ни одного слоя, содержащего «Предсказания» или «Логиты» в модели VGG-Net, поэтому я не уверен.
Спасибо за помощь!
Комментарии:
1. сработало ли это для MobileNet, если да, то какие значения вы передали в trainable_scopes, checkpoint_exclude_scopes и какой файл контрольных точек вы использовали в файле контрольных точек checkpoint_path (т. Е. файле контрольных точек нового набора данных или файле контрольных точек по умолчанию в Mobilenet? Не могли бы вы, пожалуйста, рассказать об этом
2. Почему вы не предоставляете скрипты для модели, с которой у вас возникли проблемы (VGG16), вместо InceptionV3?
3. @Anju Paul — Intel: Я только что обновил сообщение, указав именно те команды скрипта, которые я использовал для VGG16.
4. @Dinesh: Да, это работает для MobileNet. Вот параметры, которые я использовал для MobileNet v1: —trainable_scopes=MobileNetV1/Logits —checkpoint_exclude_scopes=MobileNetV1/Logits —checkpoint_path=mobilenet_v1_1.0_224 / mobilenet_v1_1.0_224.ckpt ___ И для замораживания графика я используется —output_node_names=MobileNetV1/Предсказания/Изменение формы_1
Ответ №1:
Я попытался запустить train_image_classifier.py как в вашем скрипте с несколькими изменениями, упомянутыми ниже:
- Изменил train_dir, dataset_dir и checkpoint_path на мой локальный путь
- Поскольку я работал на CPU, добавил
--clone_on_cpu=True
параметр в команду - Удалил параметр
dataset_name=my_dataset
, поскольку он выдавал ошибку для меня
Все прошло нормально. Потеря началась с 448, а затем постепенно уменьшилась и к концу 1000-го шага сократилась до 3.5. Он значительно колебался, но тенденция потерь была нисходящей. Не уверен, почему вы не смогли увидеть то же самое при попытке запуска.
Что касается вашего вопроса о checkpoint_exclude_scopes при обучении и output_node_names при замораживании графика, я думаю, что ваш выбор слоев абсолютно правильный. Однако я бы предпочел обучать только последний полностью подключенный уровень (fc8) для более быстрой конвергенции.
Комментарии:
1. Спасибо за ваш ответ и за вашу помощь. В принципе, вы сделали то же самое, что и я, только с другим набором данных. Приятно, что теперь это работает для вас. с помощью этих параметров. Затем я посмотрю дальше на набор данных — возможно, мой текущий набор данных слишком мал для VGG-Net.