#scala #continuations #continuation-passing
#scala #продолжения #продолжение-передача
Вопрос:
Я пытаюсь реализовать пример в:
https://portal.klewel.com/watch/webcast/scala-days-2019/talk/37/
использование продолжения scala:
object ReverseGrad_CPSImproved {
import scala.util.continuations._
case class Num(
x: Double,
var d: Double = 0.0
) {
def (that: Num) = shift { (cont: Num => Unit) =>
val y = Num(x that.x)
cont(y)
this.d = y.d
that.d = y.d
}
def *(that: Num) = shift { (cont: Num => Unit) =>
val y = Num(x * that.x)
cont(y)
this.d = that.x * y.d
that.d = this.x * y.d
}
}
object Num {
implicit def fromX(x: Double): Num = Num(x)
}
def grad(f: Num => Num @cps[Unit])(x: Double): Double = {
val _x = Num(x)
reset { f(_x).d = 1.0 }
_x.d
}
}
Это работает до тех пор, пока я использую простое выражение:
it("simple") {
val fn = { x: Num =>
val result = (x 3) * (x 4)
result
}
val gg = grad(fn)(3)
println(gg)
}
Но как только я начал использовать цикл, все разваливается:
it("benchmark") {
import scala.util.continuations._
for (i <- 1 to 20) {
val n = Math.pow(2, i).toInt
val fn = { x: Num =>
var result = x 1
for (j <- 2 to n) {
result = result * (x j)
}
result
}
val nanoFrom = System.nanoTime()
val gg = grad(fn)(3)
val nanoTo = System.nanoTime()
println(s"diff = $gg,t time = ${nanoTo - nanoFrom}")
}
}
[Error] /home/peng/git-spike/scalaspike/meta/src/test/scala/com/tribbloids/spike/meta/multistage/lms/ReverseGrad_CPSImproved.scala:78: found cps expression in non-cps position
one error found
У меня сложилось впечатление, что библиотека продолжения должна иметь свою собственную реализацию цикла, которую можно переписать в рекурсию, но я не могу найти ее нигде в последней версии (scala 2.12). Какой самый простой способ использовать цикл в этом случае?
Комментарии:
1. Почему вы используете CPS, если вы все равно изменяете состояние (
var d: Double = 0.0
,this.d = that.x * y.d
)?
Ответ №1:
В CPS вам нужно переписать свой код так, чтобы вы НЕ выполняли вложенный / итеративный / рекурсивный вызов в том же контексте, а вместо этого выполняли только один шаг вычисления и передавали частичный результат вперед.
Например, если вы хотите вычислить произведение чисел A на B, вы могли бы реализовать его таким образом:
import scala.util.continuations._
case class Num(toDouble: Double) {
def get = shift { cont: (Num => Num) =>
cont(this)
}
def (num: Num) = reset {
val a = num.get
Num(toDouble a.toDouble)
}
def *(num: Num) = reset {
val a = num.get
Num(toDouble * a.toDouble)
}
}
// type annotation required because of recursive call
def product(from: Int, to: Int): Num @cps[Num] = reset {
if (from > to) Num(1.toDouble)
else Num(from.toDouble) * product(from 1, to)
}
def run: Num = reset {
product(2, 10)
}
println(run)
(смотрите эту статью).
Наиболее интересным является этот фрагмент:
reset {
if (from > to) Num(1.toDouble)
else Num(from.toDouble) * product(from 1, to)
}
Здесь компилятор (плагин) переписывает это, чтобы быть чем-то похожим на:
input: (Num => Num) => {
if (from > to) Num(1.toDouble)
else {
Num(from.toDouble) * product(from 1, to) // this is virtually (Num => Num) => Num function!
} (input)
}
Компилятор может это сделать, потому что:
- он отслеживает содержимое
shift
иreset
вызывает- оба создают что-то, что принимает некоторый параметр
A
и возвращает промежуточный результатB
(используемый, например, внутри этого или другогоreset
) и конечный результатC
(то, что вы получаете при запуске конечного результата композиции) (обозначается какA @ cpsParam[B, C]
— еслиB =:= C
вы можете использовать псевдоним типаA @ cps[A]
) reset
упрощает передачу параметров, поскольку он обрабатывает получение параметра (A
inA @ cpsParam[B, C]
) и передачу его всем вложенным вызовам CPS и получение промежуточного результата (soB
inA @ cpsParam[B, C]
) и создание целого блока, возвращающего конечный результат —C
A @ cpsParam[B, C]
shift
поднимает функцию(A => B) => C
вA @ cpsParam[B, C]
- оба создают что-то, что принимает некоторый параметр
- когда он видит, что возвращаемый тип
Input @cpsParam[Output1, Output2]
равен, он знает, что он должен переписать код, чтобы ввести параметр и передать его туда
На практике это немного сложнее, но в основном это так.
Тем временем вы делаете свой
for (j <- 2 to n) {
result = result * (x j)
}
вне этого контекста, где компилятор не выполняет никаких преобразований. Вы должны, по крайней мере, составить все операции CPS внутри reset
. (Кроме того, вы запускаете вещи в цикле и мутации, которые также могут быть делегированы CPS).
Тем не менее, CPS (например, в этой конкретной реализации) мертв. Он был удален в Scala 2.13, его никто не поддерживает, и использование некоторой монады на основе батута (например Cont
, от Cats) намного проще для понимания, поэтому единственные места, где я все еще вижу это, — это устаревшие курсы или статьи об исторических мелочах.