максимальное количество меток в классификации corss_entropy равно 300?

#tensorflow

#tensorflow

Вопрос:

Я обнаружил, что с помощью sparse_softmax_cross_entropy_with_logits у подразделения может быть не более 300 меток?

Я ничего не нашел по этому поводу.

Что, если у меня их больше?

Редактировать:

Если я не ограничиваюсь 300 классами, я получаю evrytime следующую трассировку:

 2019-03-05 15:24:17.899610: W tensorflow/core/framework/op_kernel.cc:1273] OP_REQUIRES failed at sparse_xent_op.cc:90 : Invalid argument: Received a label value of 428 which is outside the valid range of [0, 300).  Label values: 428 262
Traceback (most recent call last):
  File "C:Users\Anaconda3libsite-packagestensorflowpythonclientsession.py", line 1334, in _do_call
    return fn(*args)
  File "C:Users\Anaconda3libsite-packagestensorflowpythonclientsession.py", line 1319, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "C:Users\Anaconda3libsite-packagestensorflowpythonclientsession.py", line 1407, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Received a label value of 428 which is outside the valid range of [0, 300).  Label values: 428 262
         [[{{node QAModel/loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits}} = SparseSoftmaxCrossEntropyWithLogits[T=DT_FLOAT, Tlabels=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](QAModel/StartDist/SimpleSoftmaxLayer/Add, _arg_QAModel/Placeholder_4_0_5)]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "main.py", line 236, in <module>
    tf.app.run()
  File "C:Users\Anaconda3libsite-packagestensorflowpythonplatformapp.py", line 125, in run
    _sys.exit(main(argv))
  File "main.py", line 194, in main
    qa_model.train(sess, train_context_path, train_qn_path, train_ans_path, dev_qn_path, dev_context_path, dev_ans_path)
  File "C:Users\IBQA-Models-Bidafcodeqa_model.py", line 764, in train
    loss, global_step, param_norm, grad_norm = self.run_train_iter(session, batch, summary_writer)
  File "C:Users\IBQA-Models-Bidafcodeqa_model.py", line 359, in run_train_iter
    [_, summaries, loss, global_step, param_norm, gradient_norm] = session.run(output_feed, input_feed)
  File "C:Users\Anaconda3libsite-packagestensorflowpythonclientsession.py", line 929, in run
    run_metadata_ptr)
  File "C:Users\Anaconda3libsite-packagestensorflowpythonclientsession.py", line 1152, in _run
    feed_dict_tensor, options, run_metadata)
  File "C:Users\Anaconda3libsite-packagestensorflowpythonclientsession.py", line 1328, in _do_run
    run_metadata)
  File "C:Users\Anaconda3libsite-packagestensorflowpythonclientsession.py", line 1348, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Received a label value of 428 which is outside the valid range of [0, 300).  Label values: 428 262
         [[node QAModel/loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits (defined at C:Users\IBQA-Models-Bidafcodeqa_model.py:318)  = SparseSoftmaxCrossEntropyWithLogits[T=DT_FLOAT, Tlabels=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](QAModel/StartDist/SimpleSoftmaxLayer/Add, _arg_QAModel/Placeholder_4_0_5)]]

Caused by op 'QAModel/loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits', defined at:
  File "main.py", line 236, in <module>
    tf.app.run()
  File "C:Users\Anaconda3libsite-packagestensorflowpythonplatformapp.py", line 125, in run
    _sys.exit(main(argv))
  File "main.py", line 165, in main
    qa_model = QAModel(FLAGS, id2word, word2id, emb_matrix)
  File "C:Users\IBQA-Models-Bidafcodeqa_model.py", line 64, in __init__
    self.add_loss()
  File "C:Users\IBQA-Models-Bidafcodeqa_model.py", line 318, in add_loss
    loss= tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=self.ans_span) # loss_start has shape (batch_size)
  File "C:Users\Anaconda3libsite-packagestensorflowpythonopsnn_ops.py", line 2049, in sparse_softmax_cross_entropy_with_logits
    precise_logits, labels, name=name)
  File "C:Users\Anaconda3libsite-packagestensorflowpythonopsgen_nn_ops.py", line 8063, in sparse_softmax_cross_entropy_with_logits
    labels=labels, name=name)
  File "C:Users\Anaconda3libsite-packagestensorflowpythonframeworkop_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "C:Users\Anaconda3libsite-packagestensorflowpythonutildeprecation.py", line 488, in new_func
    return func(*args, **kwargs)
  File "C:Users\Anaconda3libsite-packagestensorflowpythonframeworkops.py", line 3274, in create_op
    op_def=op_def)
  File "C:Users\Anaconda3libsite-packagestensorflowpythonframeworkops.py", line 1770, in __init__
    self._traceback = tf_stack.extract_stack()

InvalidArgumentError (see above for traceback): Received a label value of 428 which is outside the valid range of [0, 300).  Label values: 428 262
         [[node QAModel/loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits (defined at C:Users\IBQA-Models-Bidafcodeqa_model.py:318)  = SparseSoftmaxCrossEntropyWithLogits[T=DT_FLOAT, Tlabels=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](QAModel/StartDist/SimpleSoftmaxLayer/Add, _arg_QAModel/Placeholder_4_0_5)]]
  

Каждый раз, когда я увеличиваю этот диапазон до 300. Почему?!

 valid range of [0, 300)
  

Комментарии:

1. я смотрю sparse_softmax_cross_entropy_with_logits в nn_ops.py и gen_nn_ops.py нет ничего о пределе или о чем-то подобном, почему вы думаете, что есть предел?

2. Я отредактировал и отправил свою трассировку обратно

3. Какова форма логитов? Согласно документам, они должны быть of shape [batch_size, num_classes]

4. логиты имеют [batch, len], но len не коррелирует с метками моего класса. это проблема?

Ответ №1:

Когда я принимал участие в конкурсе YouTube8M Kaggle, в нем было около 5 тысяч классов, и мы использовали потери, предоставленные организаторами конкурса, то есть Google, взгляните https://github.com/mpekalski/Y8M/blob/master/video_level_code/losses.py

Комментарии:

1. Я отредактировал свой пост. Почему тогда я получаю каждый раз этот диапазон до 300?