Каковы фактические метки классов при использовании SparseCategoricalCrossEntropy loss для многоклассовой классификации в keras?

#python #tensorflow #keras #tensorflow2.0

Вопрос:

Я пытаюсь использовать: [SparseCategoricalCrossEntropy][https://www.tensorflow.org/api_docs/python/tf/keras/losses/SparseCategoricalCrossentropy] для многоклассовой классификации

введите описание изображения здесь

Это даст мне последнее измерение в виде числа классов (N_CLASSES). Но я хочу восстановить фактические метки классов из прогнозов.

В принципе, если у меня есть 5 классов (N_CLASSES=5), то у меня есть 5 столбцов, каждый из которых содержит вероятность данного класса. Но я не знаю, какой столбец относится к какой фактической метке. Как мне получить фактические метки классов ?

Например, если у меня есть мои фактические ярлыки классов, такие как [1.03, 2.07, -2.09, -974, 366], затем из вывода shape (Нет, 5) как я узнаю, какой столбец представляет какой класс?

Примечание: Я не могу использовать категориальную перекрестную энтропию и передать фактическое целевое представление с одним горячим кодом из-за проблем с памятью.

Любая помощь будет очень признательна

Ответ №1:

На самом деле это довольно просто. Давайте предположим, что ваша модель выводит predictions = [1.03, 2.07, -2.09, -974, 366] . Эти 5 чисел отражают уверенность вашей модели в том, что ваши входные данные соответствуют каждому из 5 различных классов. Если вы затем применитесь np.argmax к своим прогнозам, которые вернут индекс максимального значения в predictions :

np.argmax(predictions)

вы получите индекс 4. Предполагая, что каждая метка в вашем наборе данных является целым числом от 0 до 4 , и поскольку вы используете SparseCategoricalCrossEntropy , вы можете сказать, что ваша модель наиболее уверена в том, что ваши входные данные принадлежат классу 4 (каким бы ни был класс 4). Надеюсь, вы поняли мою идею.