#python #tensorflow #keras #tensorflow2.0
Вопрос:
Я пытаюсь использовать: [SparseCategoricalCrossEntropy][https://www.tensorflow.org/api_docs/python/tf/keras/losses/SparseCategoricalCrossentropy] для многоклассовой классификации
Это даст мне последнее измерение в виде числа классов (N_CLASSES). Но я хочу восстановить фактические метки классов из прогнозов.
В принципе, если у меня есть 5 классов (N_CLASSES=5), то у меня есть 5 столбцов, каждый из которых содержит вероятность данного класса. Но я не знаю, какой столбец относится к какой фактической метке. Как мне получить фактические метки классов ?
Например, если у меня есть мои фактические ярлыки классов, такие как [1.03, 2.07, -2.09, -974, 366], затем из вывода shape (Нет, 5) как я узнаю, какой столбец представляет какой класс?
Примечание: Я не могу использовать категориальную перекрестную энтропию и передать фактическое целевое представление с одним горячим кодом из-за проблем с памятью.
Любая помощь будет очень признательна
Ответ №1:
На самом деле это довольно просто. Давайте предположим, что ваша модель выводит predictions = [1.03, 2.07, -2.09, -974, 366]
. Эти 5 чисел отражают уверенность вашей модели в том, что ваши входные данные соответствуют каждому из 5 различных классов. Если вы затем применитесь np.argmax
к своим прогнозам, которые вернут индекс максимального значения в predictions
:
np.argmax(predictions)
вы получите индекс 4. Предполагая, что каждая метка в вашем наборе данных является целым числом от 0 до 4 , и поскольку вы используете SparseCategoricalCrossEntropy
, вы можете сказать, что ваша модель наиболее уверена в том, что ваши входные данные принадлежат классу 4 (каким бы ни был класс 4). Надеюсь, вы поняли мою идею.