Зачем в Scala трамплины и как их использовать
В этой статье директор департамента разработки российской компании «Криптонит» и «скалист» Алексей Шуксто рассказывает о специфической технике функционального программирования, которая называется «трамплин» (trampoline).
Если кратко, то «трамплин» — это постоянный вызов в цикле новых частей вычисления вплоть до получения конечного результата. Трамплин можно рассматривать как шаблон проектирования, который позволяет избежать переполнения стека при рекурсивных вызовах функций.
Достигается это следующим образом: когда функция вызывает саму себя, то вместо этого вызова управление передаётся другой функции — трамплину. Эта функция-трамплин вызывает исходную функцию с нужными параметрами и, если нужно, передаёт управление другой функции-трамплину. Таким образом, при рекурсивных вызовах функций никакая информация не сохраняется на стеке, а управление всегда передаётся между функциями-трамплинами.
Чтобы вникнуть в детали, поясним ещё несколько моментов:
эффекты или контексты вычисления;
программирование при помощи передачи продолжений;
трамплин на основе эффекта
Eval
.
Эффекты или `Effect[F[_]]`
В качестве примера давайте посмотрим на разные способы решить простую задачу: найти длину гипотенузы по длине двух катетов.
Для начала рассмотрим «классический» вариант, написанный в императивном стиле.
def pyth(a: Double, b: Double): Double =
require(a > 0)
require(b > 0)
var result = 0.0
result += math.pow(a, 2.0)
result += math.pow(b, 2.0)
math.sqrt(result)
Код простой, логика ясна, результат понятен.
Попробуем переписать то же самое, но «функционально» — без изменяемого состояния, разделив проверки и саму логику умножения/сложения:
def squared(a: Double): Option[Double] =
Option.when(a > 0)(math.pow(a, 2.0))
def pythMatch(a: Double, b: Double): Option[Double] =
squared(a) match
case None => None
case Some(a) =>
squared(b) match
case None => None
case Some(b) => Some(math.sqrt(a + b))
end pythMatch
Из примера видно, что увеличилась вложенность кода, появился дрифт вправо — это не здорово. Что можно с этим сделать?
Первое: вспоминаем, что существует метод [`Option.flatMap`]. С его помощью можно переписать метод первый раз:
def pythFlatMap(a: Double, b: Double): Option[Double] =
squared(a).flatMap { a =>
squared(b).flatMap { b =>
Some(math.sqrt(a + b))
}
}
Второе: вспоминаем такой механизм в Scala, как [for-comprehensions]:
def pythFor(a: Double, b: Double): Option[Double] = for
a <- squared(a)
b <- squared(b)
yield math.sqrt(a + b)
Получилось существенно короче! В целом даже похоже на «императивную» версию. Неплохо!
Однако мы можем пойти ещё дальше: для компилятора Scala [for-comprehensions] работают для любого типа, который определяет методы `.map, .flatMap, .withFilter` определённой сигнатуры.
Данные методы совпадают с характерными методами таких представителей теории категорий (CTFP), как Функтор, Аппликативный функтор и Монада:
trait Functor[F[_]]:
extension [A](fa: F[A])
def map[B](f: A => B): F[B]
trait Applicative[F[_]] extends Functor[F]:
def pure[A](a: A): F[A]
trait Monad[F[_]] extends Applicative[F]:
extension [A](fa: F[A])
def flatMap[B](f: A => F[B]): F[B]
def map[B](f: A => B): F[B] =
flatMap(a => pure(f(a)))
С использованием этих типов (возьмём реализации из cats, чтобы не изобретать всё самим) мы можем ещё раз переписать наш метод:
def squared[F[_]: Applicative](a: Double): F[Double] =
math.pow(a, 2.0).pure[F]
def pythMonad[F[_]: Monad](a: Double, b: Double): F[Double] = for
a2 <- squared[F](a)
b2 <- squared[F](b)
yield math.sqrt(a2 + b2)
Интересным эффектом (pun intended) наших действий является то, что теперь мы можем использовать метод `pythMonad` для любого типа данных `F[_]`, для которых определён экземпляр `Monad[F]`.
Примерами таких типов могут быть `Id[A], Option[A], List[A], Set[A]` и другие: `pythMonad[Option](3.0, 4.0) → Some (5.0)`
Более того, данная иерархия также позволяет нам «ограничивать» наши возможности по композиции операций над типам, «расширяя» тем самым список типов, с которыми возможны те или иные действия.
Так, мы можем заменить использование `.flatMap` в одном из наших первых примеров на использование `match`:
def pythTuple(a: Double, b: Double): Option[Double] =
(squared(a), squared(b)) match
case (Some(a), Some(b)) => Some(math.sqrt(a + b))
case _ => None
Аппликативный функтор (Applicative) предоставляет для всех своих типов операции `.mapN` над кортежем любого порядка:
def pythApplicative[F[_]: Applicative](
a: Double, b: Double
): F[Double] =
(squared[F](a), squared[F](b)).mapN { (a2, b2) =>
math.sqrt(a2 + b2)
}
На основе вышесказанного дадим своё простое определение эффекта:
Эффект или контекст выполнения — это часть логики программы, обладающая средствами композиции с другими такими же частями.
Программирование при помощи передачи продолжений
Давайте рассмотрим ещё одну задачу: определение существования чётного числа в последовательности. Что важно: нужно уметь работать и с потенциально бесконечными последовательностями чисел.
Первой снова будет выступать императивная версия. Она простая, понятная, рабочая!
def hasEvenImperative(is: Iterable[Int]): Boolean =
val iter = is.iterator
while iter.hasNext do if iter.next % 2 == 0 then return true
false
Дальше попробуем сделать всё то же самое через свёртку слева, «функционально»:
def hasEvenFL(is: Iterable[Int]): Boolean =
is.foldLeft(false)((acc, i) => acc || (i % 2 == 0))
Вроде тоже получилось недлинно и несложно, но для бесконечной последовательности код, увы, не сработает.
Свёртка справа из стандартной библиотеки Scala нам тоже не поможет, так как из-за особенностей записи замыканий в Scala (параметры замыканий всегда определяются строго) является строгой функцией.
Наконец, попробуем переписать в виде хвостовой рекурсии:
def hasEvenTailRec(is: Iterable[Int]): Boolean =
@tailrec def loop(rest: Iterable[Int]): Boolean =
if rest.isEmpty then false
else if rest.head % 2 == 0 then true
else loop(rest.tail)
loop(is)
end hasEvenTailRec
Ура! Однако с хвостовой рекурсией есть проблемы:
во-первых, не всегда это так просто (хотя компилятор + `@tailrec` спасает!);
во-вторых, приходится писать под каждый случай применения в отдельности.
Есть ли у нас возможность научиться писать что-то похожее на хвостовую рекурсию, но «универсально»? Оказывается, есть!
Чтобы понять, откуда ноги растут, перепишем наш самый первый пример про гипотенузу в весьма странном виде:
def pow2[R](a: Double): (Double => R) => R =
next => next(math.pow(a, 2.0))
def sqrt[R](a: Double): (Double => R) => R =
next => next(math.sqrt(a))
def add[R](a: Double, b: Double): (Double => R) => R =
next => next(a + b)
def pyth[R](a: Double, b: Double): (Double => R) => R =
next =>
pow2(a) { a2 =>
pow2(b) { b2 =>
add(a2, b2) { anb =>
sqrt(anb)(next)
}
}
}
Здесь каждая функция не возвращает результат сразу, а принимает некоторое замыкание (которое уже потом когда-то вернёт результат) и передаёт в это замыкание результат собственных вычислений.
Записанная таким образом теорема Пифагора выглядит не так уж далеко от наших «функциональных» версий из первой части статьи. Не хватает только .flatMap
, не так ли?
Действительно, данный способ записи функций можно оформить в виде типа `Cont[R, A]`, для которого могут быть определены операции `.pure, .map, .flatMap`, превращающие его в монаду (для любого заранее фиксированного типа результата `R`).
type Cont[R, A] = (A => R) => R
object Cont:
def apply[R, A](f: (A => R) => R): Cont[R, A] = f
def pure[R, A](a: A): Cont[R, A] = apply(cb => cb(a))
end Cont
extension [R, A](cont: Cont[R, A])
def map[B](f: A => B): Cont[R, B] = Cont { cb =>
cont(f.andThen(cb))
}
def flatMap[B](f: A => Cont[R, B]): Cont[R, B] = Cont { cb =>
val run: Cont[R, B] => R = c => c(cb)
cont(f.andThen(run))
}
end extension
Доказательство монадных законов оставим в качестве домашнего задания любопытному читателю.
Итак, `Cont[R, A]` это тип данных, который позволяет представить композицию функций как композицию монад, а также обладает свойством отложенного вычисления.
При помощи этих операций мы можем ещё раз записать нашу теорему Пифагора:
def contPow2[R](a: Double): Cont[R, Double] =
cb => cb(math.pow(a, 2.0))
def contSqrt[R](a: Double): Cont[R, Double] =
cb => cb(math.sqrt(a))
def contAdd[R](a: Double, b: Double): Cont[R, Double] =
cb => cb(a + b)
def contPyth[R](a: Double, b: Double): Cont[R, Double] = for
a2 <- contPow2(a)
b2 <- contPow2(b)
anb <- contAdd(a2, b2)
c <- contSqrt(anb)
yield c
object Cont:
def callCC[R, A, B](
f: (A => Cont[R, B]) => Cont[R, A]
): Cont[R, A] = apply { cb =>
val cont = f(a => apply(_ => cb(a)))
cont(cb)
}
end Cont
Неплохо, но не только ради этого же мы все затевали?
Нет, для `Cont[R, A]` также можно определить операцию `callCC`, которая, несмотря на свой более чем странный вид, является аналогом операции `break` для императивного цикла. Ну, или `goto end`:
С использованием этой операции мы можем легко и «просто» записать наш поиск чётного числа:
scala> def hasEven[R](is: Iterable[Int]): Cont[R, Boolean] =
| Cont.callCC { (exit: Boolean => Cont[R, Boolean]) =>
| if is.isEmpty then exit(false)
| else if is.head % 2 == 0 then exit(true)
| else hasEven(is.tail)
| }
def hasEven[R](is: Iterable[Int]): Cont[R, Boolean]
scala> hasEven(List(1, 3, 5, 20))(identity)
val res1: Boolean = true
scala> hasEven(LazyList.continually(Random.nextInt(1024)))(identity)
val res2: Boolean = true
Работает, и даже не то чтобы очень сложно выглядит…, но можно сделать и ещё проще!
Кроме того, приведённая выше реализация Cont пусть и ленива, но не стекобезопасна. Если бы наш генератор псевдослучайных чисел отказался генерировать true достаточно быстро, мы бы получили StackOverflowException — нехорошо.
`Eval` — возможно, самый простой трамплин
Давайте попробуем определить `Eval[A]` — тип данных для выполнения отложенных вычислений, который к тому же (в отличие от `cats.data.Eval`, чтобы было интересней), поддерживает канал передачи ошибок:
sealed trait Eval[A]:
def result: Either[Throwable, A]
def value: A = result match
case Right(value) => value
case Left(err) => throw err
def map[B](f: A => B): Eval[B] = flatMap(a => later(f(a)))
def flatMap[B](f: A => Eval[B]): Eval[B] = ???
end Eval
object Eval:
def now[A](a: A): Eval[A] = Now(a)
def later[A](thunk: => A): Eval[A] = Later(() => thunk)
def fail[A](err: Throwable): Eval[A] = Failure(err)
end Eval
Самыми интересными в реализации являются:
1. Определение подтипа `Eval.FlatMap`, который нужен для хранения в памяти операции монадического связывания:
abstract class FlatMap[A] extends Eval[A]:
type Head
val head: () => Eval[Head]
val tail: Either[Throwable, Head] => Eval[A]
def result: Either[Throwable, A] = evaluate(this)
end FlatMap
Фактически, данный подтип разделяет каждую операцию `flatMap[Head](f: Head => Eval[A]): Eval[A]` на два этапа: а) отложенное вычисление аргумента функции `f` и б) само преобразование аргумента в тип результата. При этом данная конструкция позволяет нам записывать последовательность операций вида `Eval.now (1).flatMap (i => Eval (i.toString)).flatMap (s => s»$s + $s = 2 * $s»)` в виде дерева объектов.
sealed trait Eval[A]:
self =>
…
def flatMap[B](f: A => Eval[B]): Eval[B] = self match
case fm: FlatMap[A] =>
new FlatMap[B]:
type Head = fm.Head
val head = fm.head
val tail = hd =>
new FlatMap[B]:
type Head = A
val head = () => fm.tail(hd)
val tail = _.fold(fail, f)
case _ =>
new FlatMap[B]:
type Head = A
val head = () => self
val tail = _.fold(fail, f)
end flatMap
…
end Eval
2. Реализация самой операции связывания при помощи этого подтипа.
sealed trait Eval[A]:
self =>
…
def flatMap[B](f: A => Eval[B]): Eval[B] = self match
case fm: FlatMap[A] =>
new FlatMap[B]:
type Head = fm.Head
val head = fm.head
val tail = hd =>
new FlatMap[B]:
type Head = A
val head = () => fm.tail(hd)
val tail = _.fold(fail, f)
case _ =>
new FlatMap[B]:
type Head = A
val head = () => self
val tail = _.fold(fail, f)
end flatMap
…
end Eval
Данная реализация как раз и обрабатывает ситуацию, когда происходит связывание нескольких (потенциально — бесконечно много) операций `.map, .flatMap` и аналогичных.
3. Реализация получения значения вычисления для подтипа `Eval.FlatMap`:
object Eval:
…
sealed trait Stack[A, B]
final class One[A, B](val f: A => B) extends Stack[A, B]
final class Many[A, B, C](
val head: Either[Throwable, A] => Eval[B],
val tail: Stack[B, C]
) extends Stack[A, C]
def evaluate[A](eval: Eval[A]): Either[Throwable, A] =
@tailrec def loop[A1](eval: Eval[A1], stack: Stack[A1, A]): Either[Throwable, A] = eval match
case f: Failure[A1] => Left(f.err)
case fm: FlatMap[A1] =>
fm.head() match
case f: Failure[fm.Head] => Left(f.err)
case fm1: FlatMap[fm.Head] => loop(fm1.head(), Many(fm1.tail, Many(fm.tail, stack)))
case inner => loop(fm.tail(inner.result), stack)
case _ =>
stack match
case o: One[A1, A] => eval.result.map(o.f)
case m: Many[A1, b, A] => loop(m.head(eval.result), m.tail)
loop(eval, One(identity))
end evaluate
end Eval
Данный метод является самым алгоритмически сложным: он разбирает древовидную структуру подтипов конкретного экземпляра `Eval[A]` и превращает её в стек операций, который растёт, если очередная обрабатываемая операция — это `FlatMap[_]`, и уменьшается, если обрабатываемая операция — это один из листьев (`Now[_], Later[_], Failure[_]`).
Таким образом, даже в случае рекурсивных методов, которые возвращают потенциально бесконечное дерево `Eval[_]`, в каждый момент времени размер «стека» наших операций ограничен только количеством оперативной памяти нашей виртуальный машины.
Какие же возможности даёт нам этот подтип?
Во-первых, мы можем переписать нашу реализацию определения наличия чётного числа ещё проще — при помощи `Eval[A]` мы наконец-то можем определить свою версию свёртки справа, поддерживающую ленивые вычисления и преждевременное их завершение:
def foldRight[A, B](ia: Iterable[A], zero: Eval[B])(
f: (A, Eval[B]) => Eval[B]
): Eval[B] =
def loop(ia: Iterable[A]): Eval[B] =
Eval.later(ia.isEmpty).flatMap {
case true => zero
case false => f(ia.head, loop(ia.tail))
}
loop(ia)
end foldRight
С её использованием, реализация поиска чётного числа становится совсем тривиальной:
scala> def hasEvenFold(is: Iterable[Int]): Eval[Boolean] =
| foldRight(is, Eval.now(false)) { (i, acc) =>
| if i % 2 == 0 then Eval.now(true)
| else acc
| }
def hasEvenFold(is: Iterable[Int]): Eval[Boolean]
scala> hasEvenFold(LazyList.continually(Random.nextInt(1024)))
val res1: Eval[Boolean] = Eval.FlatMap(.., .. => ..)
scala> res1.result
val res2: Boolean = true
Во-вторых, нам становится доступной стекобезопасная взаимная рекурсия. Компилятор Scala отлично оптимизирует хвостовую рекурсию одной функции, но с двумя и больше справиться, увы, не способен. Eval
нам в помощь!
def isEven(i: Int): Eval[Boolean] =
Eval.later(i == 0).flatMap {
case true => Eval.now(true)
case false => isOdd(i - 1)
}
def isOdd(i: Int): Eval[Boolean] =
Eval.later(i == 0).flatMap {
case true => Eval.now(false)
case false => isEven(i - 1)
}
val odd = isOdd(100501).result
val even = isEven(100501).result
Раз нам доступна взаимная рекурсия, нам оказываются внезапно доступны для реализации всякие вещи типа реализации рекурсивных грамматик, таких как арифметические калькуляторы:
expression | terms | term |
+ terms | term + terms | factor * term |
— terms | term — terms | factor / term |
terms | term | factor |
factor | primary |
primary | (expression) |
primary ^ factor | number |
Данная грамматика определяет вполне функциональный калькулятор с приоритетом операций, скобками и возведением в степень. Конечно, мы могли бы реализовать калькулятор и более широко известными методами, например, через стек операндов и параметров. Однако получившийся код выглядел бы более сложно, фактически вся реализации состоит из записи формальной грамматики в виде операций `Eval`:
object Calc:
final case class Parsed(result: Int, leftovers: List[String])
def apply(s: String): Eval[Int] =
tokenize(s).flatMap(expression).flatMap {
case Parsed(result, Nil) => Eval.now(result)
case Parsed(result, leftovers) =>
Eval.fail(
new IllegalArgumentException(
s"Leftover tokens for ‘$result’:" ++
s" [${leftovers.mkString(" ")}]"
)
)
}
def tokenize(s: String): Eval[List[String]] =
Eval.later(s.split(" ").toList)
def expression(tokens: List[String]): Eval[Parsed] =
tokens match
case "+" :: tail =>
terms(tail)
case "-" :: tail =>
terms(tail).map(p => p.copy(result = -p.result))
case _ =>
terms(tokens)
def terms(tokens: List[String]): Eval[Parsed] =
term(tokens).flatMap {
case Parsed(r, "+" :: tail) =>
terms(tail).map(p => p.copy(result = r + p.result))
case Parsed(r, "-" :: tail) =>
terms(tail).map(p => p.copy(result = r - p.result))
case p => Eval.now(p)
}
def term(tokens: List[String]): Eval[Parsed] = ???
def factor(tokens: List[String]): Eval[Parsed] = ???
def primary(tokens: List[String]): Eval[Parsed] = tokens match
case "(" :: tail => expression(tail).flatMap {
case Parsed(r, ")" :: tail) => Eval.now(Parsed(r, tail))
case _ => Eval.fail(new IllegalArgumentException(
s"Unmatched braces in: [${tokens.mkString(" ")}]"
))
}
case head :: tail => Eval.later(Parsed(head.toInt, tail))
case _ => Eval.fail(new IllegalArgumentException(s"Unable to parse primary from: [${tokens.mkString(" ")}]"))
end Calc
scala> Calc("( 1 + 3 ) * 4 - ( 6 / ( 3 ^ 0 + 1 ) )")
val res8: Eval[Int] = Eval.FlatMap(.., .. => ..)
scala> res8.result
val res9: Either[Throwable, Int] = Right(13)
scala> Calc("( 1 + 3 ) * 4 ( 6 / ( 3 ^ 0 + 1 ) )")
val res10: Eval[Int] = Eval.FlatMap(.., .. => ..)
scala> res10.result
val res11: Either[Throwable, Int] = Left(java.lang.IllegalArgumentException: Leftover tokens for '16': [( 6 / ( 3 ^ 0 + 1 ) )])
Более того, расширение грамматики новыми операторами или поддержкой переменных так же не составляет труда.
Итак, «трамплин» — это методика, позволяющая представить потенциально бесконечную последовательность вызова функций (в т. ч. рекурсивную) в виде ленивой и стекобезопасной монадической композиции этих самых функций.
Описанные тут подходы лежат в основе как монады `Free`, так и IO-подобных монад — `CEIO` (cats.effect.IO), ` ZIO` и многих других.
Ссылки
Free: https://typelevel.org/cats/datatypes/freemonad.html
CEIO: https://github.com/typelevel/…
ZIO: https://github.com/zio/…
Слайды: https://slides.com/seigert/eat
Код и скрипты: https://github.com/seigert/eat
PS: В команде «скалистов» компании «Криптонит» прямо сейчас открыты две вакансии:
Scala Team Lead — https://career.kryptonite.ru/vacancies/scala-team-lead-2/
Scala Developer / Senior Scala Developer — https://career.kryptonite.ru/vacancies/scala-developer/
Присоединяйтесь!