Выход модели VGG становится постоянным после обучения, а потери/точность не улучшаются

#python #tensorflow #keras #vgg-net #image-classification

Вопрос:

Я пытаюсь реализовать немного уменьшенную версию VGG16 и обучить ее с нуля на наборе данных из примерно 6000 изображений (5400 для обучения и 600 для проверки). Я выбрал размер пакета 30, чтобы он мог аккуратно вписаться в набор данных, иначе я получил бы его с ошибкой несовместимой формы во время обучения.

После прохождения 15-20 эпох начинается ранний обратный вызов и прекращается обучение.

С этой моделью я сталкиваюсь с двумя проблемами

  1. После этого, когда я передаю тестовые изображения в модель, результат, кажется, остается постоянным. По крайней мере, я ожидаю, что для imageA прогнозируемый результат должен отличаться от imageB. Я не могу понять, почему это так
  2. Потеря и точность, похоже, не сильно меняются. Я ожидал, что точность возрастет, по крайней мере, примерно до 50% для числа эпох, но она не превышает 23%. Я попытался включить steps_per_epoch, ReduceLROnPlateau, но они, похоже, не оказывают никакого влияния.

Результаты Обучения:

 Epoch 1/50
180/180 [==============================] - 50s 278ms/step - loss: 1.6095 - categorical_accuracy: 0.1987 - val_loss: 1.6109 - val_categorical_accuracy: 0.1267

Epoch 00001: val_loss improved from inf to 1.61094, saving model to vgg16.h5
Epoch 2/50
180/180 [==============================] - 51s 285ms/step - loss: 1.6095 - categorical_accuracy: 0.2044 - val_loss: 1.6107 - val_categorical_accuracy: 0.2133

Epoch 00002: val_loss improved from 1.61094 to 1.61067, saving model to vgg16.h5
Epoch 3/50
180/180 [==============================] - 51s 285ms/step - loss: 1.6098 - categorical_accuracy: 0.1946 - val_loss: 1.6106 - val_categorical_accuracy: 0.1400

Epoch 00003: val_loss improved from 1.61067 to 1.61059, saving model to vgg16.h5
Epoch 4/50
180/180 [==============================] - 52s 286ms/step - loss: 1.6095 - categorical_accuracy: 0.1928 - val_loss: 1.6098 - val_categorical_accuracy: 0.2000

Epoch 00004: val_loss improved from 1.61059 to 1.60983, saving model to vgg16.h5

Epoch 00004: ReduceLROnPlateau reducing learning rate to 2.5000001187436283e-05.
Epoch 5/50
180/180 [==============================] - 52s 286ms/step - loss: 1.6093 - categorical_accuracy: 0.2033 - val_loss: 1.6103 - val_categorical_accuracy: 0.1467

Epoch 00005: val_loss did not improve from 1.60983
Epoch 6/50
180/180 [==============================] - 51s 286ms/step - loss: 1.6094 - categorical_accuracy: 0.1989 - val_loss: 1.6106 - val_categorical_accuracy: 0.1400

Epoch 00006: val_loss did not improve from 1.60983
Epoch 7/50
180/180 [==============================] - 51s 286ms/step - loss: 1.6094 - categorical_accuracy: 0.2069 - val_loss: 1.6098 - val_categorical_accuracy: 0.1733

Epoch 00007: val_loss improved from 1.60983 to 1.60978, saving model to vgg16.h5

Epoch 00007: ReduceLROnPlateau reducing learning rate to 1e-05.
Epoch 8/50
180/180 [==============================] - 52s 286ms/step - loss: 1.6093 - categorical_accuracy: 0.2076 - val_loss: 1.6103 - val_categorical_accuracy: 0.1600

Epoch 00008: val_loss did not improve from 1.60978
Epoch 9/50
180/180 [==============================] - 51s 286ms/step - loss: 1.6095 - categorical_accuracy: 0.2006 - val_loss: 1.6097 - val_categorical_accuracy: 0.2200

Epoch 00009: val_loss improved from 1.60978 to 1.60975, saving model to vgg16.h5
Epoch 10/50
180/180 [==============================] - 52s 287ms/step - loss: 1.6095 - categorical_accuracy: 0.2043 - val_loss: 1.6101 - val_categorical_accuracy: 0.1667

Epoch 00010: val_loss did not improve from 1.60975
Epoch 11/50
180/180 [==============================] - 51s 286ms/step - loss: 1.6094 - categorical_accuracy: 0.2009 - val_loss: 1.6102 - val_categorical_accuracy: 0.1800

Epoch 00011: val_loss did not improve from 1.60975
Epoch 12/50
180/180 [==============================] - 51s 286ms/step - loss: 1.6095 - categorical_accuracy: 0.2041 - val_loss: 1.6115 - val_categorical_accuracy: 0.1600

Epoch 00012: val_loss did not improve from 1.60975
Epoch 13/50
180/180 [==============================] - 52s 286ms/step - loss: 1.6095 - categorical_accuracy: 0.1989 - val_loss: 1.6108 - val_categorical_accuracy: 0.1867

Epoch 00013: val_loss did not improve from 1.60975
Epoch 14/50
180/180 [==============================] - 52s 286ms/step - loss: 1.6094 - categorical_accuracy: 0.2009 - val_loss: 1.6102 - val_categorical_accuracy: 0.1733

Epoch 00014: val_loss did not improve from 1.60975
Epoch 15/50
180/180 [==============================] - 51s 286ms/step - loss: 1.6093 - categorical_accuracy: 0.2074 - val_loss: 1.6113 - val_categorical_accuracy: 0.1467

Epoch 00015: val_loss did not improve from 1.60975
Epoch 16/50
180/180 [==============================] - 52s 286ms/step - loss: 1.6098 - categorical_accuracy: 0.1983 - val_loss: 1.6105 - val_categorical_accuracy: 0.1867

Epoch 00016: val_loss did not improve from 1.60975
Epoch 17/50
180/180 [==============================] - 52s 286ms/step - loss: 1.6095 - categorical_accuracy: 0.2056 - val_loss: 1.6119 - val_categorical_accuracy: 0.1667

Epoch 00017: val_loss did not improve from 1.60975
Epoch 18/50
180/180 [==============================] - 52s 286ms/step - loss: 1.6093 - categorical_accuracy: 0.1994 - val_loss: 1.6110 - val_categorical_accuracy: 0.1800

Epoch 00018: val_loss did not improve from 1.60975
Epoch 19/50
180/180 [==============================] - 51s 286ms/step - loss: 1.6095 - categorical_accuracy: 0.2026 - val_loss: 1.6103 - val_categorical_accuracy: 0.1667

Epoch 00019: val_loss did not improve from 1.60975
Restoring model weights from the end of the best epoch.
Epoch 00019: early stopping
 

Код, используемый для получения прогнозов:

 predictions = []
actuals=[]

for i, (images, labels) in enumerate( test_datasource):
  if i > 2:
    break
  pred = model_2(images)
  print(labels.shape, pred.shape)
  for j in range(len(labels)):
    actuals.append( labels[j])
    predictions.append(pred[j])
    print(labels[j].numpy(), "t", pred[j].numpy())

 

Output of the above code:

 (30, 5) (30, 5)
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 0. 1.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 0. 1.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 0. 1.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 0. 1.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 0. 1.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
(30, 5) (30, 5)
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 0. 1.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 0. 1.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 0. 1.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 0. 1.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 0. 1.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 0. 1.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
(30, 5) (30, 5)
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 0. 1.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 0. 1.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 0. 1.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 0. 1.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[1. 0. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 1. 0. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 0. 1. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
[0. 0. 1. 0. 0.]     [0.19907779 0.20320047 0.1968051  0.20173152 0.19918515]
 

Вот краткое описание модели:

 Model: "vgg16"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_3 (InputLayer)         [(30, 224, 224, 3)]       0         
_________________________________________________________________
conv_1_1 (Conv2D)            (30, 224, 224, 32)        896       
_________________________________________________________________
conv_1_2 (Conv2D)            (30, 224, 224, 32)        9248      
_________________________________________________________________
maxpool_1 (MaxPooling2D)     (30, 112, 112, 32)        0         
_________________________________________________________________
conv_2_1 (Conv2D)            (30, 112, 112, 64)        18496     
_________________________________________________________________
conv_2_2 (Conv2D)            (30, 112, 112, 64)        36928     
_________________________________________________________________
maxpool_2 (MaxPooling2D)     (30, 56, 56, 64)          0         
_________________________________________________________________
conv_3_1 (Conv2D)            (30, 56, 56, 128)         73856     
_________________________________________________________________
conv_3_2 (Conv2D)            (30, 56, 56, 128)         147584    
_________________________________________________________________
conv_3_3 (Conv2D)            (30, 56, 56, 128)         147584    
_________________________________________________________________
maxpool_3 (MaxPooling2D)     (30, 28, 28, 128)         0         
_________________________________________________________________
conv_4_1 (Conv2D)            (30, 28, 28, 256)         295168    
_________________________________________________________________
conv_4_2 (Conv2D)            (30, 28, 28, 256)         590080    
_________________________________________________________________
conv_4_3 (Conv2D)            (30, 28, 28, 256)         590080    
_________________________________________________________________
maxpool_4 (MaxPooling2D)     (30, 14, 14, 256)         0         
_________________________________________________________________
conv_5_1 (Conv2D)            (30, 14, 14, 256)         590080    
_________________________________________________________________
conv_5_2 (Conv2D)            (30, 14, 14, 256)         590080    
_________________________________________________________________
conv_5_3 (Conv2D)            (30, 14, 14, 256)         590080    
_________________________________________________________________
maxpool_5 (MaxPooling2D)     (30, 7, 7, 256)           0         
_________________________________________________________________
flatten (Flatten)            (30, 12544)               0         
_________________________________________________________________
fc_1 (Dense)                 (30, 4096)                51384320  
_________________________________________________________________
dropout_1 (Dropout)          (30, 4096)                0         
_________________________________________________________________
fc_2 (Dense)                 (30, 4096)                16781312  
_________________________________________________________________
dropout_2 (Dropout)          (30, 4096)                0         
_________________________________________________________________
output (Dense)               (30, 5)                   20485     
=================================================================
Total params: 71,866,277
Trainable params: 71,866,277
Non-trainable params: 0
 

Код находится здесь, в Google Colab: https://colab.research.google.com/drive/1AWe87Zb3MvF90j3RS7sv3OiSgR86q4j_

Я попробовал две версии VGG-16, одну с половиной глубины фильтров, чем в оригинале, и вторую с четвертью глубины фильтров.

Комментарии:

1. похоже, что даже ваши потери в тренировках не улучшаются после каждой эпохи, в то время как тренировки почти одинаковы после каждой эпохи, пожалуйста, еще раз проверьте тренировочный конвейер или, пожалуйста, упомяните тренировочный конвейер здесь, чтобы другие могли его увидеть

2. Под обучающим конвейером вы подразумеваете вывод модели.резюме? Извините, я все еще изучаю жаргоны этой области

3. извините, не заметил, что вы уже предоставили ссылку на Google colab

4. можете ли вы попробовать увеличить скорость обучения, а также проверить, что ваша функция потерь получает правильный ввод, который она ожидает, tensorflow.org/api_docs/python/tf/keras/losses/…

5. Я преобразую метки в векторы с одним горячим кодированием, используя label_mode= «категориальный» при загрузке данных. метки действительно генерируются таким образом, как видно выше, где я сравниваю метки с прогнозируемыми результатами. Категориальная перекрестная энтропия, по-видимому, является потерей, которая обычно используется в сценариях с одним горячим кодированием. Есть ли какая-либо другая причина, по которой вы просите проверить функцию потерь? Сейчас я постараюсь увеличить скорость обучения. добавлю комментарий, как только закончу

Ответ №1:

У меня сложилось впечатление, что проблема заключалась в том, что потери и точность были неизменными, и именно поэтому при использовании модели для прогнозов выходные данные были фиксированными и неизменными независимо от того, какие входные данные в нее подавались.

Когда я повторно инициализировал модель,вызвав функцию Model(входы…, выходы…), и передал ей входные данные без обучения, выходные данные, по крайней мере, изменились.

Я пробовал использовать несколько скоростей обучения и оптимизаторов без каких-либо изменений в поведении модели.

После еще нескольких поисков в Google я случайно наткнулся на эти статьи:

  1. https://www.quora.com/Why-does-my-convolutional-neural-network-always-produce-the-same-outputs
  2. https://www.quora.com/Why-does-my-own-neural-network-give-me-the-same-output-for-different-input-sets
  3. https://datascience.stackexchange.com/questions/5706/what-is-the-dying-relu-problem-in-neural-networks

Я внес два изменения в код, чтобы он наконец заработал.

  1. Первоначально я делил массив изображений только на 255, чтобы значения находились в диапазоне от 0 до 1, а затем вычитал 0,5 из результата, чтобы значения находились в диапазоне от -0,5 до 0,5. Это было изменено на использование tf.image.per_image_standardization(изображения-127) и деление результата на максимальное значение в каждом изображении. В результате значения изображения упали между -1 и 1
  2. Другой основной причиной фиксированных выходов является то, что блоки relu модели умирают (или насыщаются) во время обучения. функция активации relu по своей сути имеет эту проблему, когда, как только вес переменной становится равным 0, она не восстанавливается после этого. Хотя говорят, что высокая скорость обучения является причиной этой проблемы, я не смог найти скорость обучения, которая облегчила бы эту проблему. Другим решением является изменение функции активации на дырявый relu или elu (Экспоненциальный relu), которые имеют встроенный механизм для устранения этой проблемы

С этими изменениями потери модели снизились до < 1, а точность обучения-до