Каков правильный способ вычисления правильных прогнозов в Tensorflow?

#tensorflow

#tensorflow

Вопрос:

Я загружаю простой ConvNet в Tensorflow, используя файл tfrecords, содержащий изображения в оттенках серого в качестве входных данных и метки целых классов.

моя потеря определяется как loss = tf.nn.sparse_softmax_cross_entropy_with_logits(y_conv, label_batch)

где y_conv=tf.matmul(h_fc1_drop,W_fc2) b_fc2

и label_batch является тензором размера [batch_size] .

Я пытаюсь вычислить точность, используя

 correct_prediction = tf.equal(tf.argmax(label_batch,1),tf.argmax(y_conv, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  

Это correct_prediction утверждение выдает ошибку:

 InvalidArgumentError (see above for traceback): Minimum tensor rank: 2 but got: 1
  

Я немного смущен тем, как именно вычисляются правильные прогнозы в TF.

Ответ №1:

Вероятно, вы захотите использовать 0 в качестве аргумента измерения в tf.argmax, поскольку label_batch и y_conv являются векторами. Использование dimension = 1 подразумевает ранг тензора не менее 2. Смотрите документацию для параметра dimension в argmax здесь.

Я надеюсь, что это поможет!

Комментарии:

1. Первое измерение — это выборки. Измерение = 1 является правильным измерением при вычислении точности.

Ответ №2:

Для вашего y_conv вы все делаете правильно — это матрица формы, (batch_size, n_classes) где для каждого образца и для каждого класса у вас есть вероятность, что это класс, к которому принадлежит изображение. Итак, чтобы получить фактический предсказанный класс, который вам нужно вызвать argmax .

Однако ваши метки являются целыми числами и имеют форму just (batch_size,) , поскольку класс изображения известен и нет причин указывать n_classes вероятности, единственное целое число может точно так же содержать фактический класс. Таким образом, вам не нужно вызывать argmax его для преобразования вероятностей в класс, у него уже есть класс. Чтобы исправить это, просто выполните

 correct_prediction = tf.equal(label_batch, tf.argmax(y_conv, 1))