Зачем в Scala трамплины и как их использовать

36b8f7ea7eb485bd60bac4b6bd74b4ec.png

В этой статье директор департамента разработки российской компании «Криптонит» и «скалист» Алексей Шуксто рассказывает о специфической технике функционального программирования, которая называется «трамплин» (trampoline).

Если кратко, то «трамплин» — это постоянный вызов в цикле новых частей вычисления вплоть до получения конечного результата. Трамплин можно рассматривать как шаблон проектирования, который позволяет избежать переполнения стека при рекурсивных вызовах функций.

Достигается это следующим образом: когда функция вызывает саму себя, то вместо этого вызова управление передаётся другой функции — трамплину. Эта функция-трамплин вызывает исходную функцию с нужными параметрами и, если нужно, передаёт управление другой функции-трамплину. Таким образом, при рекурсивных вызовах функций никакая информация не сохраняется на стеке, а управление всегда передаётся между функциями-трамплинами.

Чтобы вникнуть в детали, поясним ещё несколько моментов:

  1. эффекты или контексты вычисления;

  2. программирование при помощи передачи продолжений;

  3. трамплин на основе эффекта Eval.

Эффекты или `Effect[F[_]]`

В качестве примера давайте посмотрим на разные способы решить простую задачу: найти длину гипотенузы по длине двух катетов.

Для начала рассмотрим «классический» вариант, написанный в императивном стиле.

def pyth(a: Double, b: Double): Double =
  require(a > 0)
  require(b > 0)

  var result = 0.0
  result += math.pow(a, 2.0)
  result += math.pow(b, 2.0)

  math.sqrt(result)

Код простой, логика ясна, результат понятен.

Попробуем переписать то же самое, но «функционально» — без изменяемого состояния, разделив проверки и саму логику умножения/сложения:

def squared(a: Double): Option[Double] =
  Option.when(a > 0)(math.pow(a, 2.0))

def pythMatch(a: Double, b: Double): Option[Double] =
  squared(a) match
    case None => None
    case Some(a) =>
      squared(b) match
        case None    => None
        case Some(b) => Some(math.sqrt(a + b))
end pythMatch

Из примера видно, что увеличилась вложенность кода, появился дрифт вправо — это не здорово. Что можно с этим сделать?

Первое: вспоминаем, что существует метод [`Option.flatMap`]. С его помощью можно переписать метод первый раз:

def pythFlatMap(a: Double, b: Double): Option[Double] =
  squared(a).flatMap { a =>
    squared(b).flatMap { b =>
      Some(math.sqrt(a + b))
    }
  }

Второе: вспоминаем такой механизм в Scala, как [for-comprehensions]:

def pythFor(a: Double, b: Double): Option[Double] = for
  a <- squared(a)
  b <- squared(b)
yield math.sqrt(a + b)

Получилось существенно короче! В целом даже похоже на «императивную» версию. Неплохо!

Однако мы можем пойти ещё дальше: для компилятора Scala [for-comprehensions] работают для любого типа, который определяет методы `.map, .flatMap, .withFilter` определённой сигнатуры.

Данные методы совпадают с характерными методами таких представителей теории категорий (CTFP), как Функтор, Аппликативный функтор и Монада:

trait Functor[F[_]]:
  extension [A](fa: F[A])
    def map[B](f: A => B): F[B]

trait Applicative[F[_]] extends Functor[F]:
  def pure[A](a: A): F[A]

trait Monad[F[_]] extends Applicative[F]:
  extension [A](fa: F[A])
    def flatMap[B](f: A => F[B]): F[B]

    def map[B](f: A => B): F[B] = 
      flatMap(a => pure(f(a)))

С использованием этих типов (возьмём реализации из cats, чтобы не изобретать всё самим) мы можем ещё раз переписать наш метод:

def squared[F[_]: Applicative](a: Double): F[Double] =
  math.pow(a, 2.0).pure[F]

def pythMonad[F[_]: Monad](a: Double, b: Double): F[Double] = for
  a2 <- squared[F](a)
  b2 <- squared[F](b)
yield math.sqrt(a2 + b2)

Интересным эффектом (pun intended) наших действий является то, что теперь мы можем использовать метод `pythMonad` для любого типа данных `F[_]`, для которых определён экземпляр `Monad[F]`.

Примерами таких типов могут быть `Id[A], Option[A], List[A], Set[A]` и другие: `pythMonad[Option](3.0, 4.0) → Some (5.0)`

Более того, данная иерархия также позволяет нам «ограничивать» наши возможности по композиции операций над типам, «расширяя» тем самым список типов, с которыми возможны те или иные действия.

Так, мы можем заменить использование `.flatMap` в одном из наших первых примеров на использование `match`:

def pythTuple(a: Double, b: Double): Option[Double] =
  (squared(a), squared(b)) match
    case (Some(a), Some(b)) => Some(math.sqrt(a + b))
    case _                  => None

Аппликативный функтор (Applicative) предоставляет для всех своих типов операции `.mapN` над кортежем любого порядка:

def pythApplicative[F[_]: Applicative](
  a: Double, b: Double
): F[Double] = 
  (squared[F](a), squared[F](b)).mapN { (a2, b2) =>
    math.sqrt(a2 + b2)
  }

На основе вышесказанного дадим своё простое определение эффекта:

Эффект или контекст выполнения — это часть логики программы, обладающая средствами композиции с другими такими же частями.

Программирование при помощи передачи продолжений

Давайте рассмотрим ещё одну задачу: определение существования чётного числа в последовательности. Что важно: нужно уметь работать и с потенциально бесконечными последовательностями чисел.

Первой снова будет выступать императивная версия. Она простая, понятная, рабочая!

def hasEvenImperative(is: Iterable[Int]): Boolean =
  val iter = is.iterator
  while iter.hasNext do if iter.next % 2 == 0 then return true
  false

Дальше попробуем сделать всё то же самое через свёртку слева, «функционально»:

def hasEvenFL(is: Iterable[Int]): Boolean =
  is.foldLeft(false)((acc, i) => acc || (i % 2 == 0))

Вроде тоже получилось недлинно и несложно, но для бесконечной последовательности код, увы, не сработает.

Свёртка справа из стандартной библиотеки Scala нам тоже не поможет, так как из-за особенностей записи замыканий в Scala (параметры замыканий всегда определяются строго) является строгой функцией.

Наконец, попробуем переписать в виде хвостовой рекурсии:

def hasEvenTailRec(is: Iterable[Int]): Boolean =
  @tailrec def loop(rest: Iterable[Int]): Boolean =
    if rest.isEmpty then false
    else if rest.head % 2 == 0 then true
    else loop(rest.tail)

  loop(is)
end hasEvenTailRec

Ура! Однако с хвостовой рекурсией есть проблемы:

  • во-первых, не всегда это так просто (хотя компилятор + `@tailrec` спасает!);

  • во-вторых, приходится писать под каждый случай применения в отдельности.

Есть ли у нас возможность научиться писать что-то похожее на хвостовую рекурсию, но «универсально»? Оказывается, есть!

Чтобы понять, откуда ноги растут, перепишем наш самый первый пример про гипотенузу в весьма странном виде:

def pow2[R](a: Double): (Double => R) => R =
  next => next(math.pow(a, 2.0))

def sqrt[R](a: Double): (Double => R) => R =
  next => next(math.sqrt(a))

def add[R](a: Double, b: Double): (Double => R) => R =
  next => next(a + b)

def pyth[R](a: Double, b: Double): (Double => R) => R = 
  next =>
    pow2(a) { a2 =>
      pow2(b) { b2 =>
        add(a2, b2) { anb =>
          sqrt(anb)(next)
        }
      }
    }

Здесь каждая функция не возвращает результат сразу, а принимает некоторое замыкание (которое уже потом когда-то вернёт  результат) и передаёт в это замыкание результат собственных вычислений.

Записанная таким образом теорема Пифагора выглядит не так уж далеко от наших «функциональных» версий из первой части статьи. Не хватает только .flatMap, не так ли?

Действительно, данный способ записи функций можно оформить в виде типа `Cont[R, A]`, для которого могут быть определены операции `.pure, .map, .flatMap`, превращающие его в монаду (для любого заранее фиксированного типа результата `R`).

type Cont[R, A] = (A => R) => R
object Cont:
  def apply[R, A](f: (A => R) => R): Cont[R, A] = f
  def pure[R, A](a: A): Cont[R, A]              = apply(cb => cb(a))
end Cont

extension [R, A](cont: Cont[R, A])
  def map[B](f: A => B): Cont[R, B] = Cont { cb =>
    cont(f.andThen(cb))
  }

  def flatMap[B](f: A => Cont[R, B]): Cont[R, B] = Cont { cb =>
    val run: Cont[R, B] => R = c => c(cb)
    cont(f.andThen(run))
  }
end extension

Доказательство монадных законов оставим в качестве домашнего задания любопытному читателю.

Итак, `Cont[R, A]` это тип данных, который позволяет представить композицию функций как композицию монад, а также обладает свойством отложенного вычисления.

При помощи этих операций мы можем ещё раз записать нашу теорему Пифагора:

def contPow2[R](a: Double): Cont[R, Double] = 
  cb => cb(math.pow(a, 2.0))

def contSqrt[R](a: Double): Cont[R, Double] = 
  cb => cb(math.sqrt(a))

def contAdd[R](a: Double, b: Double): Cont[R, Double] =
  cb => cb(a + b)

def contPyth[R](a: Double, b: Double): Cont[R, Double] = for
  a2  <- contPow2(a)
  b2  <- contPow2(b)
  anb <- contAdd(a2, b2)
  c   <- contSqrt(anb)
yield c
object Cont:
  def callCC[R, A, B](
    f: (A => Cont[R, B]) => Cont[R, A]
  ): Cont[R, A] = apply { cb =>
    val cont = f(a => apply(_ => cb(a)))
    cont(cb)
  }
end Cont

Неплохо, но не только ради этого же мы все затевали?

Нет, для `Cont[R, A]` также можно определить операцию `callCC`, которая, несмотря на свой более чем странный вид, является аналогом операции `break` для императивного цикла. Ну, или `goto end`:

С использованием этой операции мы можем легко и «просто» записать наш поиск чётного числа:

scala> def hasEven[R](is: Iterable[Int]): Cont[R, Boolean] =
     |   Cont.callCC { (exit: Boolean => Cont[R, Boolean]) =>
     |     if is.isEmpty then exit(false)
     |     else if is.head % 2 == 0 then exit(true)
     |     else hasEven(is.tail)
     |   }
def hasEven[R](is: Iterable[Int]): Cont[R, Boolean]

scala> hasEven(List(1, 3, 5, 20))(identity)
val res1: Boolean = true

scala> hasEven(LazyList.continually(Random.nextInt(1024)))(identity)
val res2: Boolean = true

Работает, и даже не то чтобы очень сложно выглядит…, но можно сделать и ещё проще!

Кроме того, приведённая выше реализация Cont пусть и ленива, но не стекобезопасна.  Если бы наш генератор псевдослучайных чисел отказался генерировать true достаточно быстро, мы бы получили StackOverflowException — нехорошо.

`Eval` — возможно, самый простой трамплин

Давайте попробуем определить `Eval[A]` — тип данных для выполнения отложенных вычислений, который к тому же (в отличие от `cats.data.Eval`, чтобы было интересней), поддерживает канал передачи ошибок:

sealed trait Eval[A]:
  def result: Either[Throwable, A]

  def value: A = result match
    case Right(value) => value
    case Left(err)    => throw err

  def map[B](f: A => B): Eval[B] = flatMap(a => later(f(a)))
  def flatMap[B](f: A => Eval[B]): Eval[B] = ???
end Eval

object Eval:
  def now[A](a: A): Eval[A]            = Now(a)
  def later[A](thunk: => A): Eval[A]   = Later(() => thunk)
  def fail[A](err: Throwable): Eval[A] = Failure(err)
end Eval

Самыми интересными в реализации являются:

1. Определение подтипа `Eval.FlatMap`, который нужен для хранения в памяти операции монадического связывания:

abstract class FlatMap[A] extends Eval[A]:
  type Head

  val head: () => Eval[Head]
  val tail: Either[Throwable, Head] => Eval[A]

  def result: Either[Throwable, A] = evaluate(this)
end FlatMap

Фактически, данный подтип разделяет каждую операцию `flatMap[Head](f: Head => Eval[A]): Eval[A]` на два этапа: а) отложенное вычисление аргумента функции `f` и б) само преобразование аргумента в тип результата. При этом данная конструкция позволяет нам записывать последовательность операций вида `Eval.now (1).flatMap (i => Eval (i.toString)).flatMap (s => s»$s + $s = 2 * $s»)` в виде дерева объектов.

sealed trait Eval[A]:
  self =>
  …
  def flatMap[B](f: A => Eval[B]): Eval[B] = self match
    case fm: FlatMap[A] =>
      new FlatMap[B]:
        type Head = fm.Head

        val head = fm.head
        val tail = hd =>
          new FlatMap[B]:
            type Head = A
            val head = () => fm.tail(hd)
            val tail = _.fold(fail, f)

    case _ =>
      new FlatMap[B]:
        type Head = A
        val head = () => self
        val tail = _.fold(fail, f)
  end flatMap
  …
end Eval

2. Реализация самой операции связывания при помощи этого подтипа.

sealed trait Eval[A]:
  self =>
  …
  def flatMap[B](f: A => Eval[B]): Eval[B] = self match
    case fm: FlatMap[A] =>
      new FlatMap[B]:
        type Head = fm.Head

        val head = fm.head
        val tail = hd =>
          new FlatMap[B]:
            type Head = A
            val head = () => fm.tail(hd)
            val tail = _.fold(fail, f)

    case _ =>
      new FlatMap[B]:
        type Head = A
        val head = () => self
        val tail = _.fold(fail, f)
  end flatMap
  …
end Eval

Данная реализация как раз и обрабатывает ситуацию, когда происходит связывание нескольких (потенциально — бесконечно много) операций `.map, .flatMap` и аналогичных.

3. Реализация получения значения вычисления для подтипа `Eval.FlatMap`:

object Eval:
  …
  sealed trait Stack[A, B]
  final class One[A, B](val f: A => B) extends Stack[A, B]
  final class Many[A, B, C](
      val head: Either[Throwable, A] => Eval[B],
      val tail: Stack[B, C]
  ) extends Stack[A, C]

  def evaluate[A](eval: Eval[A]): Either[Throwable, A] =
    @tailrec def loop[A1](eval: Eval[A1], stack: Stack[A1, A]): Either[Throwable, A] = eval match
      case f: Failure[A1]  => Left(f.err)
      case fm: FlatMap[A1] =>
        fm.head() match
          case f: Failure[fm.Head]   => Left(f.err)
          case fm1: FlatMap[fm.Head] => loop(fm1.head(), Many(fm1.tail, Many(fm.tail, stack)))
          case inner                 => loop(fm.tail(inner.result), stack)

      case _ =>
        stack match
          case o: One[A1, A]     => eval.result.map(o.f)
          case m: Many[A1, b, A] => loop(m.head(eval.result), m.tail)

    loop(eval, One(identity))
  end evaluate
end Eval

Данный метод является самым алгоритмически сложным: он разбирает древовидную структуру подтипов конкретного экземпляра `Eval[A]` и превращает её в стек операций, который растёт, если очередная обрабатываемая операция — это `FlatMap[_]`, и уменьшается, если обрабатываемая операция — это один из листьев (`Now[_], Later[_], Failure[_]`).

Таким образом, даже в случае рекурсивных методов, которые возвращают потенциально бесконечное дерево `Eval[_]`, в каждый момент времени размер «стека» наших операций ограничен только количеством оперативной памяти нашей виртуальный машины.

Какие же возможности даёт нам этот подтип?

Во-первых, мы можем переписать нашу реализацию определения наличия чётного числа ещё проще — при помощи `Eval[A]` мы наконец-то можем определить свою версию свёртки справа, поддерживающую ленивые вычисления и преждевременное их завершение:

def foldRight[A, B](ia: Iterable[A], zero: Eval[B])(
    f: (A, Eval[B]) => Eval[B]
): Eval[B] =
  def loop(ia: Iterable[A]): Eval[B] =
    Eval.later(ia.isEmpty).flatMap {
      case true  => zero
      case false => f(ia.head, loop(ia.tail))
    }

  loop(ia)
end foldRight

С её использованием, реализация поиска чётного числа становится совсем тривиальной:

scala> def hasEvenFold(is: Iterable[Int]): Eval[Boolean] =
     |   foldRight(is, Eval.now(false)) { (i, acc) =>
     |     if i % 2 == 0 then Eval.now(true)
     |     else acc
     |   }
def hasEvenFold(is: Iterable[Int]): Eval[Boolean]

scala> hasEvenFold(LazyList.continually(Random.nextInt(1024)))
val res1: Eval[Boolean] = Eval.FlatMap(.., .. => ..)

scala> res1.result
val res2: Boolean = true

Во-вторых, нам становится доступной стекобезопасная взаимная рекурсия. Компилятор Scala отлично оптимизирует хвостовую рекурсию одной функции, но с двумя и больше справиться, увы, не способен. Eval нам в помощь!

def isEven(i: Int): Eval[Boolean] =
  Eval.later(i == 0).flatMap {
    case true  => Eval.now(true)
    case false => isOdd(i - 1)
  }

def isOdd(i: Int): Eval[Boolean] =
  Eval.later(i == 0).flatMap {
    case true  => Eval.now(false)
    case false => isEven(i - 1)
  }

val odd  = isOdd(100501).result
val even = isEven(100501).result

Раз нам доступна взаимная рекурсия, нам оказываются внезапно доступны для реализации всякие вещи типа реализации рекурсивных грамматик, таких как арифметические калькуляторы:

expression

terms

term

+ terms

term + terms

factor * term

— terms

term — terms

factor / term

terms

term

factor

factor

primary

primary

(expression)

primary ^ factor

number

Данная грамматика определяет вполне функциональный калькулятор с приоритетом операций, скобками и возведением в степень. Конечно, мы могли бы реализовать калькулятор и более широко известными методами, например, через стек операндов и параметров. Однако получившийся код выглядел бы более сложно, фактически вся реализации состоит из записи формальной грамматики в виде операций `Eval`:

object Calc:
  final case class Parsed(result: Int, leftovers: List[String])

  def apply(s: String): Eval[Int] =
    tokenize(s).flatMap(expression).flatMap {
      case Parsed(result, Nil) => Eval.now(result)
      case Parsed(result, leftovers) =>
          Eval.fail(
            new IllegalArgumentException(
              s"Leftover tokens for ‘$result’:" ++
              s" [${leftovers.mkString(" ")}]"
            )
          )
      }

def tokenize(s: String): Eval[List[String]] =
    Eval.later(s.split(" ").toList)

def expression(tokens: List[String]): Eval[Parsed] = 
    tokens match
      case "+" :: tail => 
        terms(tail)
      case "-" :: tail => 
        terms(tail).map(p => p.copy(result = -p.result))
      case _ =>
        terms(tokens)

def terms(tokens: List[String]): Eval[Parsed] = 
    term(tokens).flatMap {
      case Parsed(r, "+" :: tail) => 
        terms(tail).map(p => p.copy(result = r + p.result))
      case Parsed(r, "-" :: tail) => 
        terms(tail).map(p => p.copy(result = r - p.result))
      case p => Eval.now(p)
  }

def term(tokens: List[String]): Eval[Parsed] = ???
  def factor(tokens: List[String]): Eval[Parsed] = ???

def primary(tokens: List[String]): Eval[Parsed] = tokens match
    case "(" :: tail => expression(tail).flatMap {
      case Parsed(r, ")" :: tail) => Eval.now(Parsed(r, tail))
      case _ => Eval.fail(new IllegalArgumentException(
        s"Unmatched braces in: [${tokens.mkString(" ")}]"
      ))
    }
    case head :: tail => Eval.later(Parsed(head.toInt, tail))
    case _ => Eval.fail(new IllegalArgumentException(s"Unable to parse primary from: [${tokens.mkString(" ")}]"))

end Calc
scala> Calc("( 1 + 3 ) * 4 - ( 6 / ( 3 ^ 0 + 1 ) )")
val res8: Eval[Int] = Eval.FlatMap(.., .. => ..)

scala> res8.result
val res9: Either[Throwable, Int] = Right(13)

scala> Calc("( 1 + 3 ) * 4 ( 6 / ( 3 ^ 0 + 1 ) )")
val res10: Eval[Int] = Eval.FlatMap(.., .. => ..)

scala> res10.result
val res11: Either[Throwable, Int] = Left(java.lang.IllegalArgumentException: Leftover tokens for '16': [( 6 / ( 3 ^ 0 + 1 ) )])

Более того, расширение грамматики новыми операторами или поддержкой переменных так же не составляет труда.

Итак, «трамплин» — это методика, позволяющая представить потенциально бесконечную последовательность вызова функций (в т. ч. рекурсивную) в виде ленивой и стекобезопасной монадической композиции этих самых функций.

Описанные тут подходы лежат в основе как монады `Free`, так и IO-подобных монад — `CEIO` (cats.effect.IO), ` ZIO` и многих других.

Ссылки

Free: https://typelevel.org/cats/datatypes/freemonad.html

CEIO: https://github.com/typelevel/…

ZIO: https://github.com/zio/…

Слайды: https://slides.com/seigert/eat

Код и скрипты: https://github.com/seigert/eat

PS: В команде «скалистов» компании «Криптонит» прямо сейчас открыты две вакансии:

Scala Team Lead — https://career.kryptonite.ru/vacancies/scala-team-lead-2/

Scala Developer / Senior Scala Developer — https://career.kryptonite.ru/vacancies/scala-developer/

Присоединяйтесь!

© Habrahabr.ru