Как использовать пользовательскую функцию потерь на основе дивергенции Коши-Шварца для обучения модели Keras?

#python #numpy #tensorflow #keras #loss-function

Вопрос:

Я использую пользовательскую функцию потерь на основе дивергенции Коши-Шварца, доступную по адресу https://gist.github.com/Jarino/cb6d9b39abcf773a1fb0e9a90ee67db9 для обучения модели DL для задачи многоклассовой классификации в Tensorflow 2.4 с Keras 2.4.0. Метки y_true закодированы в одну горячую. Функция потерь приведена ниже:

 from math import sqrt
from math import log
from scipy.stats import gaussian_kde
from scipy import special

def cs_divergence(p1, p2):    
    """p1 (numpy array): first pdfs, p2 (numpy array): second pdfs, Returns:float: CS divergence"""    
    r = range(0, p1.shape[0])
    p1_kernel = gaussian_kde(p1)
    p2_kernel = gaussian_kde(p2)
    p1_computed = p1_kernel(r)
    p2_computed = p2_kernel(r)
    numerator = sum(p1_computed * p2_computed)
    denominator = sqrt(sum(p1_computed ** 2) * sum(p2_computed**2))
    return -log(numerator/denominator)
sgd = SGD(lr=0.0001, decay=1e-6, momentum=0.9, nesterov=True) 
model.compile(optimizer=sgd,
              loss=[cs_divergence], 
              metrics=['accuracy'])
 

При запуске обучающего кода возникает ошибка, как показано ниже:

 File "C:Userscodescustomloss.py", line 741, in <module>
    verbose=1)

  File "c:usersappdatalocalcontinuumanaconda3envstf_2.4libsite-packagestensorflowpythonkerasenginetraining.py", line 1100, in fit
    tmp_logs = self.train_function(iterator)

  File "c:usersappdatalocalcontinuumanaconda3envstf_2.4libsite-packagestensorflowpythoneagerdef_function.py", line 828, in __call__
    result = self._call(*args, **kwds)

  File "c:usersappdatalocalcontinuumanaconda3envstf_2.4libsite-packagestensorflowpythoneagerdef_function.py", line 871, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)

  File "c:usersappdatalocalcontinuumanaconda3envstf_2.4libsite-packagestensorflowpythoneagerdef_function.py", line 726, in _initialize
    *args, **kwds))

  File "c:usersappdatalocalcontinuumanaconda3envstf_2.4libsite-packagestensorflowpythoneagerfunction.py", line 2969, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)

  File "c:usersappdatalocalcontinuumanaconda3envstf_2.4libsite-packagestensorflowpythoneagerfunction.py", line 3361, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)

  File "c:usersappdatalocalcontinuumanaconda3envstf_2.4libsite-packagestensorflowpythoneagerfunction.py", line 3206, in _create_graph_function
    capture_by_value=self._capture_by_value),

  File "c:usersappdatalocalcontinuumanaconda3envstf_2.4libsite-packagestensorflowpythonframeworkfunc_graph.py", line 990, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)

  File "c:usersappdatalocalcontinuumanaconda3envstf_2.4libsite-packagestensorflowpythoneagerdef_function.py", line 634, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)

  File "c:usersappdatalocalcontinuumanaconda3envstf_2.4libsite-packagestensorflowpythonframeworkfunc_graph.py", line 977, in wrapper
    raise e.ag_error_metadata.to_exception(e)

TypeError: in user code:

    c:usersappdatalocalcontinuumanaconda3envstf_2.4libsite-packagestensorflowpythonkerasenginetraining.py:805 train_function  *
        return step_function(self, iterator)
    C:Userscodescustom_loss.py:295 cs_divergence  *
        r = range(0, p1.shape[0])
    c:usersappdatalocalcontinuumanaconda3envstf_2.4libsite-packagestensorflowpythonautographoperatorspy_builtins.py:365 range_  **
        return _py_range(start_or_stop, stop, step)
    c:usersappdatalocalcontinuumanaconda3envstf_2.4libsite-packagestensorflowpythonautographoperatorspy_builtins.py:389 _py_range
        return range(start_or_stop, stop)

    TypeError: 'NoneType' object cannot be interpreted as an integer
 

Комментарии:

1. Вы должны реализовать свой проигрыш только с помощью функций tensorflow, он не будет работать, если вы используете на нем какую-либо функцию numpy или scipy.

2. @Dr. Snoopy Не могли бы вы помочь с версией tensorflow функции пользовательских потерь?

3. Привет@шива ,не могли бы вы попробовать еще раз изменить код с loss=[cs_divergence] на loss=cs_divergence . Пожалуйста, предоставьте фрагмент кода для model (), чтобы продолжить . Ссылка : keras.io/api/losses