#type-conversion #julia #automatic-differentiation #autodiff
#преобразование типов #julia #автоматическое дифференцирование #автодифф
Вопрос:
Я просто хотел бы использовать ForwardDiff.jl
функциональность для определения функции и построения ее градиента (вычисляется с использованием ForwardDiff.gradient
). Кажется, это не работает, потому что вывод ForwardDiff.gradient
— это странный Dual
тип, и его нелегко преобразовать в желаемый тип (в моем случае, одномерный массив Float32).
using Plots
using ForwardDiff
my_func(x::Array{Float32,1}) = 1f0. / (1f0 . exp(3f0 .* x)) # doesn't matter what this is, just a sigmoid function here
grad_f(x::Array{Float32,1}) = ForwardDiff.gradient(my_func, x)
x_values = collect(Float32,0:0.01:10)
plot(x_values,my_func(x_values)); # this works fine
plot!(x_values,grad_f(x_values)); # this throws an error
И это ошибка, которую я получаю:
ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{typeof(g),Float32},Float64,12})
Когда я проверяю тип grad_f(x_values)
, я получаю это:
Array{Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(g),Float32},Float32,12},1},1}
Почему этого не происходит, например, в примере в документации для ForwardDiff? Смотрите здесь: https://github.com/JuliaDiff/ForwardDiff.jl
Заранее спасибо.
РЕДАКТИРОВАТЬ: после комментариев Кристоффера Карлссона: я пробовал это, но это все еще не работает. Я не понимаю, что так отличается от того, что я пробовал здесь, по сравнению с тем, что он предложил:
function g(x::Float32)
return x / (1f0 exp(10f0 * (x - 5f0)))
end
function ∂g∂x(x::Float32)
return ForwardDiff.derivative(g, x)
end
x_vals = collect(Float32,0:0.01:10)
plot(x_vals,g.(x_vals))
plot!(x_vals,∂g∂x.(x_vals))
С ошибкой, которая теперь:
no method matching g(::ForwardDiff.Dual{ForwardDiff.Tag{typeof(g),Float32},Float32,1})
И эта ошибка возникает только при вызове ∂g∂x(x)
, независимо от того, использую я широковещательную версию или нет ∂g∂x.(x)
. Я предполагаю, что это как-то связано с определением функции, но я не вижу, как способ, которым я его определил, отличается от версии Кристоффера, кроме того, что он не определен в одной строке…Это так сбивает с толку.
Это должно сработать, потому что, согласно ForwardDiff
документации, вам просто нужно, чтобы типы входных данных были подтипом Real
— и Float32
являются подтипом Real .
РЕДАКТИРОВАТЬ: я понимаю, что теперь, прочитав комментарии других, вам нужно ограничить ваши функции, чтобы они были достаточно общими, чтобы принимать любые входные данные абстрактного типа Real
, которые я не совсем понял из документации. Приносим извинения за путаницу.
Комментарии:
1. «Для меня это имеет смысл 0, учитывая то, что сказано в документации ForwardDiff.jl об отсутствии ограничений на то, какой подтип Real вы используете». Нет, документы в juliadiff.org/ForwardDiff.jl/stable/user/limitations явно скажите: «Целевая функция должна быть написана достаточно обобщенно, чтобы принимать числа типа
T<:Real
в качестве входных данных», и если вы ограничите этоFloat32
, то это неверно.2. Ах, я не понимал, что это значит — я думал, что обозначение
T<:Real
означает, что вы можете ограничить любой тип T, который является подтипомReal
, и поэтому они могут быть ограничены таким образом. Моя ошибка, спасибо за разъяснение.
Ответ №1:
Вы определяете функции в массивах вместо скаляров, а также слишком сильно ограничиваете типы ввода. Кроме того, для скалярных функций вы должны использовать ForwardDiff.derivative
. Попробуйте что-то вроде:
using Plots
using ForwardDiff
my_func(x::Real) = 1f0 / (1f0 exp(3f0 * x))
my_func_derivative(x::Real) = ForwardDiff.derivative(my_func, x)
plot(my_func, xlimits = (0, 10))
plot!(my_func_derivative)
предоставление:
Комментарии:
1. Спасибо за совет — однако я попробовал это немного по-другому, и это не сработало. Вы знаете, почему? Я отредактирую исходное сообщение, чтобы показать, что я пробовал.
2. Возможно, вам придется прочитать о том, как работает ForwardDiff — он вызывает вашу функцию с двойными числами для отслеживания производной. Если вы ограничите свою функцию вводом типа
Float32
, это не сработает. Вам нужно ослабить егоReal
, посколькуDual
тип ForwardDiff является подтипомReal
.3. Я бы добавил, что, если вы не планируете определять другой метод для этой функции (чтобы вести себя по-другому с другими типами), вам, вероятно, не следует применять какой-либо тип ввода. Код будет легче читать, он будет таким же производительным и, вероятно, будет сочетаться с другими пакетами с нулевыми усилиями.