#python #numpy #tensorflow #keras #loss-function
Вопрос:
Я использую пользовательскую функцию потерь на основе дивергенции Коши-Шварца, доступную по адресу https://gist.github.com/Jarino/cb6d9b39abcf773a1fb0e9a90ee67db9 для обучения модели DL для задачи многоклассовой классификации в Tensorflow 2.4 с Keras 2.4.0. Метки y_true закодированы в одну горячую. Функция потерь приведена ниже:
from math import sqrt
from math import log
from scipy.stats import gaussian_kde
from scipy import special
def cs_divergence(p1, p2):
"""p1 (numpy array): first pdfs, p2 (numpy array): second pdfs, Returns:float: CS divergence"""
r = range(0, p1.shape[0])
p1_kernel = gaussian_kde(p1)
p2_kernel = gaussian_kde(p2)
p1_computed = p1_kernel(r)
p2_computed = p2_kernel(r)
numerator = sum(p1_computed * p2_computed)
denominator = sqrt(sum(p1_computed ** 2) * sum(p2_computed**2))
return -log(numerator/denominator)
sgd = SGD(lr=0.0001, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(optimizer=sgd,
loss=[cs_divergence],
metrics=['accuracy'])
При запуске обучающего кода возникает ошибка, как показано ниже:
File "C:Userscodescustomloss.py", line 741, in <module>
verbose=1)
File "c:usersappdatalocalcontinuumanaconda3envstf_2.4libsite-packagestensorflowpythonkerasenginetraining.py", line 1100, in fit
tmp_logs = self.train_function(iterator)
File "c:usersappdatalocalcontinuumanaconda3envstf_2.4libsite-packagestensorflowpythoneagerdef_function.py", line 828, in __call__
result = self._call(*args, **kwds)
File "c:usersappdatalocalcontinuumanaconda3envstf_2.4libsite-packagestensorflowpythoneagerdef_function.py", line 871, in _call
self._initialize(args, kwds, add_initializers_to=initializers)
File "c:usersappdatalocalcontinuumanaconda3envstf_2.4libsite-packagestensorflowpythoneagerdef_function.py", line 726, in _initialize
*args, **kwds))
File "c:usersappdatalocalcontinuumanaconda3envstf_2.4libsite-packagestensorflowpythoneagerfunction.py", line 2969, in _get_concrete_function_internal_garbage_collected
graph_function, _ = self._maybe_define_function(args, kwargs)
File "c:usersappdatalocalcontinuumanaconda3envstf_2.4libsite-packagestensorflowpythoneagerfunction.py", line 3361, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "c:usersappdatalocalcontinuumanaconda3envstf_2.4libsite-packagestensorflowpythoneagerfunction.py", line 3206, in _create_graph_function
capture_by_value=self._capture_by_value),
File "c:usersappdatalocalcontinuumanaconda3envstf_2.4libsite-packagestensorflowpythonframeworkfunc_graph.py", line 990, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "c:usersappdatalocalcontinuumanaconda3envstf_2.4libsite-packagestensorflowpythoneagerdef_function.py", line 634, in wrapped_fn
out = weak_wrapped_fn().__wrapped__(*args, **kwds)
File "c:usersappdatalocalcontinuumanaconda3envstf_2.4libsite-packagestensorflowpythonframeworkfunc_graph.py", line 977, in wrapper
raise e.ag_error_metadata.to_exception(e)
TypeError: in user code:
c:usersappdatalocalcontinuumanaconda3envstf_2.4libsite-packagestensorflowpythonkerasenginetraining.py:805 train_function *
return step_function(self, iterator)
C:Userscodescustom_loss.py:295 cs_divergence *
r = range(0, p1.shape[0])
c:usersappdatalocalcontinuumanaconda3envstf_2.4libsite-packagestensorflowpythonautographoperatorspy_builtins.py:365 range_ **
return _py_range(start_or_stop, stop, step)
c:usersappdatalocalcontinuumanaconda3envstf_2.4libsite-packagestensorflowpythonautographoperatorspy_builtins.py:389 _py_range
return range(start_or_stop, stop)
TypeError: 'NoneType' object cannot be interpreted as an integer
Комментарии:
1. Вы должны реализовать свой проигрыш только с помощью функций tensorflow, он не будет работать, если вы используете на нем какую-либо функцию numpy или scipy.
2. @Dr. Snoopy Не могли бы вы помочь с версией tensorflow функции пользовательских потерь?
3. Привет@шива ,не могли бы вы попробовать еще раз изменить код с loss=[cs_divergence] на loss=cs_divergence . Пожалуйста, предоставьте фрагмент кода для model (), чтобы продолжить . Ссылка : keras.io/api/losses