Простая нейронная сеть keras застряла с точностью 87,5%

#python #machine-learning #keras #neural-network

#python #машинное обучение #keras #нейронная сеть

Вопрос:

Я изучаю библиотеку Keras и столкнулся с проблемой. Я создаю очень простую нейронную сеть. Он принимает ввод из трех цифр (например, 010) и выводит одну цифру (например,1). Он должен выводить только 1, если во входных трех цифрах есть 1.

Однако, когда я запускаю свой код, точность остается на уровне 87,5% в течение полных пяти эпох. Это говорит мне, что это просто не обучение для экземпляра 000. Почему это не меняется? Я не понимаю, что я сделал не так: (

 from keras.datasets import mnist
from keras import models
from keras import layers
from keras.utils import to_categorical
import time
import numpy as np

train_images = np.array([[0,0,0],[0,0,1],[0,1,0],[0,1,1],[1,0,0],[1,0,1],[1,1,0],[1,1,1]])
train_labels = np.array([[0],[1],[1],[1],[1],[1],[1],[1]])


train_images = train_images.reshape((8, 3))


model = models.Sequential()

model.add(layers.Dense(6, input_dim=3, activation='relu'))
model.add(layers.Dense(6, activation='relu'))
model.add(layers.Dense(1, activation='softmax'))




model.compile(optimizer='adam',
                loss='categorical_crossentropy',
                metrics=['accuracy'])

model.fit(train_images, train_labels, epochs=5, batch_size=4)
 

Комментарии:

1. Ну, для начала, ваш набор данных очень мал, а затем он крайне несбалансирован — у вас есть только 1 экземпляр метки ‘0’. Попробуйте добавить несколько дополнительных записей меток [0, 0, 0] и [0] и посмотреть, изменится ли что-нибудь.

Ответ №1:

Модель застряла на уровне 87,5%, потому что она всегда предсказывает class=1 , как только с первой эпохи. Учитывая, что у вас всего 8 точек данных, сети достаточно простой эпохи, чтобы узнать как можно больше, вот почему вы не видите прогресса между первой и последней эпохой.

У вас есть 8 точек данных, 7 из них являются классом 1 , а 1 — классом 0 . Ваша модель, всегда предсказывающая class = 1, означает, что точность равна 7/8 = 87,5%

Что касается самой модели (ее архитектуры), я считаю, что она имеет 49 параметров, что, на мой взгляд, слишком много, и, как я упоминал ранее, достаточно для нее столько, сколько она может за одну эпоху. Кроме того, если вы столкнулись с проблемой двоичной классификации, вы можете захотеть:

  • Изменить softmax на sigmoid
  • Измените categorical crossentropy на binary crossentropy .