Scala 3 Metaprogramming: реализация списка с известным на этапе компиляции размером

5e9834d5db983f72327e34b39056150b

Введение

Одна из проблем структур данных — отсутствие информации о размере на этапе компиляции. Из-за этого мы не можем быть наверняка уверены, можно ли выполнить над коллекцией определенные операции. Например, функция List[A].head выбросит исключение, если список пустой. Эта проблема легко решается на уровне значений: проверкой List[A].size > 0.

А можно ли определить, возможно ли запустить метод коллекции, на этапе компиляции?
Да, если добавить к типу списка его размер. С его помощью мы сможем понять, какие операции допустимы, а какие — нет.

Именно это я решил попробовать сделать со связным списком.

Сразу скажу, что я не эксперт в метапрограммировании, и данная статья ориентирована на «чайников». Далее я подробно, с большим количеством кода, опишу реализацию и расскажу о проблемах, с которыми столкнулся.

Статья будет интересна всем бездельникам любителям Scala, которым хочется потрогать метапрограммирование на Scala 3, но непонятно, с чего начать. Мы рассмотрим и воспользуемся следующими возможностями языка:

А чтобы убедить вас прочитать до конца, приведу пару примеров в коде:

// создание списков
val intList: SList[Int, 3] = 1 :: 2 :: 3 :: SNil
val stringList: SList[String, 2] = "foo" :: "bar" :: SNil

// выведение типа списка
val emptyList: SList[Int, 0] = SNil
intList.refined // SCons[Int, 2]
emptyList.refined // SNil.type

// безопасные head и tail
stringList.head // foo
stringList.tail // SList(bar): SList[String, 1]

// ошибка компиляции
stringList.tail.tail.head
stringList.tail.tail.tail

// использование map/flatMap

// SList(1foo,1bar,2foo,2bar,3foo,3bar)
val combinedList: SList[String, 6] = for {
	int <- intList
	string <- stringList
	resultValue = int.toString + string
} yield resultValue
val intList: SList[Int, 3] = 1 :: 2 :: 3 :: SNil
val stringList: SList[String, 2] = "foo" :: "bar" :: SNil

// SList(1foo,1bar,2foo,2bar,3foo,3bar)
val combinedList: SList[String, 6] = for {
	int <- intList
	string <- stringList
	resultValue = int.toString + string
} yield resultValue

ADT-модель и базовые операции

Список смоделируем аналогично List[A], добавив тип-параметр, описывающий его размер:

import scala.compiletime.ops.int.*

sealed trait SList[+A, N <: Int]

case object SNil extends SList[Nothing, 0]

case class SCons[+A, N <: Int](head: A, tail: SList[A, N]) 
	extends SList[A, S[N]]

Литеральные типы N и S[N]

В Scala 3 были добавлены литеральные типы, описывающие единственное значение. Например:

import scala.compiletime.ops.int.*

val three: 3 = 3
val four: 2 + 2 = 4
val myTrue: true = true
val myFalse: 3 < 2 = false
// compile error
val error: 4 = 3

Целочисленные литеральные типы являются подтипом Int, поэтому запись N <: Int означает целочисленный литерал.
S[N] описывает число, следующее за N.

Теперь мы можем реализовать методы SList.head и SList.tail, используя информацию о размере списка:

def head(using N > 0 =:= true): A = 
	this.asInstanceOf[SCons[A, N - 1]].head
def tail(using N > 0 =:= true): SList[A, N - 1] = 
	this.asInstanceOf[SCons[A, N - 1]].tail

Добавив неявный параметр N > 0 =:= true, мы гарантируем, что методы SList.head и SList.tail получится вызвать только в случае, если булев литерал N > 0 равен (=:=) true.
Компилятор автоматически выведет неявное значение N > 0 =:= true, если N известно на этапе компиляции и является положительным.

Также добавим оператор :: для удобства:

def ::[A1 >: A](x: A1): SList[A1, S[N]] = SCons(x, this)

Далее мы можем реализовать некоторые рекурсивные операции, такие как map или оператор :+:

def :+[A1 >: A](x: A1): SList[A1, S[N]] = this match {
    case SCons(v, vs) => v :: (vs :+ x)
    case SNil => x :: SNil
}

def map[B](f: A => B): SList[B, N] = this match {
    case SCons(v, vs) => f(v) :: vs.map(f)
    case SNil => SNil
}

Аналогично List[A] можно реализовать и другие операции, итерирующиеся по списку, например foldLeft.

Ниже продемонстрирована работа реализованных функций:

// val list: SList[Int, 2] = SCons(1,SCons(2,SNil))
val list = 1 :: 2 :: SNil

// val res0: Int = 1
list.head

// res1: SList[Int, 1] = SCons(2,SNil)
list.tail

// compile error: Cannot prove that (0 : Int) > (0 : Int) =:= (true : Boolean).
list.tail.tail.head

// compile error: Cannot prove that (0 : Int) > (0 : Int) =:= (true : Boolean).
list.tail.tail.tail

// val res2: SList[String, 2] = SCons(2,SCons(4,SNil))
list.map(x => (x * 2).toString)

Challenge #1: А может ли N быть меньше нуля?

Формально — да. Попробуем ограничить размер списка SList[A, N] неотрицательными числами.
На первый взгляд кажется, что достаточно, по аналогии с head/tail, добавить к трейту неявный параметр:

sealed trait SList[+A, N <: Int](using N >= 0 =:= true) { /* */ }

case object SNil extends SList[Nothing, 0]

// compile error: Cannot prove that S[N] >= (0 : Int) =:= (true : Boolean)
case class SCons[+A, N <: Int](head: A, tail: SList[A, N]) 
	extends SList[A, S[N]]

Однако на практике компилятор Scala не может доказать, что S[N] >= 0, если N >= 0. Следующие попытки вывода также не приводят к успеху, выбрасывая ошибку java.lang.AssertionError: assertion failed while typechecking:

given [N <: Int](using N >= 0 =:= true): (S[N] >= 0 =:= true) = summon

import scala.compiletime.summonInline
inline given [N <: Int](using N >= 0 =:= true): (S[N] >= 0 =:= true) = 
	summonInline

summonInline позволяет на этапе компиляции захватывать значения неявных параметров

В конце концов я, устав бороться с компилятором, попросил помощь на /r/Scala, и, несколько модифицировав ответ одного из пользователей Reddit, пришел к описанию ограничений N >= 0 и N > 0 в виде GeqZ[N] и GtZ[N] соответственно:

import scala.compiletime.ops.int.*

sealed trait GeqZ[N <: Int]
object GeqZ {
    given[N <: Int](using N >= 0 =:= true): GeqZ[N] = new GeqZ[N] {}
}

sealed trait GtZ[N <: Int] extends GeqZ[N]
object GtZ {
    given[N <: Int](using N > 0 =:= true): GtZ[N] = new GtZ[N] {}

    // Явное доказательство, что S[N] > 0, если N >= 0
    def snNext[N <: Int](using GeqZ[N]): GtZ[S[N]] = new GtZ[S[N]] {}
}

Таким образом, если значение GeqZ[N] можно вывести, то N — неотрицательное число. Это можно сделать двумя способами:

  • Если число N — неотрицательное и известно на этапе компиляции, т.е. если компилятор может вывести N >= 0 =:= true, создаем инстанс GeqZ[N].

  • Если для числа N можно вывести GeqZ[N], то для числа S[N] можно вывести GtZ[N] через функцию GtZ.snNext. GtZ[N] означает, что N > 0.

Почему snNext не задана как given-функция?

Потому что если сделать так, компилятор для всех N > 0 игнорирует первый, «эффективный» способ вывода GtZ[N] и пользуется вторым.
В таком случае для вывода GtZ[N] компилятору понадобится вывести еще N - 1 значений (вспоминаем аксиоматику Пеано), из-за чего он радостно выбрасывает StackOverflowError при попытке, например, вывести summon[GtZ[1000]].

Модифицируем ADT-модель, используя GeqZ[N], гарантируя таким образом, что N >= 0, на этапе компиляции:

sealed trait SList[+A, N <: Int](using GeqZ[N]) {
	// head/tail можно вызвать, если N > 0, т.е. можно вывести GtZ[N]
	def head(using GtZ[N]): A = /* */
    def tail(using GtZ[N]): SList[A, N - 1] = /* */
}
case object SNil extends SList[Nothing, 0]
case class SCons[+A, N <: Int](head: A, tail: SList[A, N])(using GeqZ[N]) 
	extends SList[A, S[N]](using GtZ.snNext)

Создание списков, а также функции head/tail работают, как полагается:

val list = 1 :: 2 :: SNil
list.head // 1
list.tail // SList(2)

// compile error: No given instance of type GtZ[(0 : Int)] was found
// for parameter x$1 of method head in trait SList.
list.tail.tail.head

Если мы знаем N, то мы ведь знаем и тип списка, правда?

Звучит логично, однако мы не можем оперировать типом N как числом напрямую.

И здесь нам поможет метапрограммирование, а именно — inline и constValue[N]:

  • inline позволит нам вместо вызова функции подставить ее тело;

  • constValue[N]: N (reference) позволит нам сгенерировать число на основе литерального типа.

Размер списка можно определить следующим образом:

inline def size: N = constValue[N]

Легко и просто. Без inline сгенерировать constValue[N] не получится: это значение не живет в рантайме, поэтому его нужно сразу подставить в код.

Зная размер N, мы можем определить, с каким наследником SList мы работаем: с SCons или с SNil. Напишем функцию refined, которая, в зависимости от значения N, преобразует SList либо в SCons, либо в SNil:

transparent inline def refined: SCons[A, N - 1] | SNil.type = 
	inline if(constValue[N] == 0) SNil
	else this.asInstanceOf[SCons[A, N - 1]]

В этой простой функции мы сталкиваемся еще с парой возможностей Scala 3:

  • При использовании inline if компилятор вычисляет условие и, в зависимости от результата, подставляет при генерации кода нужную ветку if.

  • Модификатор inline-функции transparent позволяет уточнить тип результата, насколько это возможно. Указанный возвращаемый тип «прозрачной» функции — это upper bound.

  • A | B — union-тип, обозначающий значение, которое может иметь как тип A, так и тип B.

Явный каст this нужен, чтобы компилятор воспринимал значение в else-ветке как SCons[A, N - 1]. На этапе компиляции мы не знаем точный подтип this, и если вернуть просто this, возвращаемый тип будет SList[A, N].

Работа функций size и refined продемонстрирована ниже. Обратите внимание, что возвращаемый тип refined — не SCons | SNil, а более конкретный, определяемый на этапе компиляции:

// list: SList[Int, 2] = SCons(1,SCons(2,SNil))
val list = 1 :: 2 :: SNil

// val res0: Int = 2
list.size

// val res1: SCons[Int, 1] = SCons(1,SCons(2,SNil))
list.refined

// val nil: SList[Int, 0] = SNil
val nil: SList[Int, 0] = SNil

// val res2: SNil.type = SNil
nil.refined

Модификация refined

Функцию refined можно улучшить:

  • Хотелось бы задать upper bound refined как SList[A, N]: тип SCons[A, N - 1] | SNil.type не является подтипом SList[A, N], так как у этих двух подтипов разные значения N (N - 1 и 0 соответственно).

  • Текущая реализация не работает для списков SList[A, ?], размер которых неизвестен на этапе компиляции, поскольку сгенерировать constValue[N] не получится.

Решим эти проблемы, используя summonFrom и constValueOpt[N]:

transparent inline def refined: SList[A, N] = inline constValueOpt[N] match {
    case Some(n) if n == 0 =>
        summonFrom {
            case given (SNil.type <:< SList[A, N]) => SNil
        }
    case Some(_) =>
        summonFrom {
            case given (SCons[A, N - 1] <:< SList[A, N]) => 
            	this.asInstanceOf[SCons[A, N - 1]]
        }
    case None => this
}
  • constValueOpt[N] возвращает Some(n), если тип N известен на этапе компиляции, и None в противном случае. Благодаря этому мы можем при неизвестном N возвратить this, не уточняя тип.

  • summonFrom позволяет искать неявные значения в области видимости. В данном случае мы пытаемся найти инстанс MyList <:< SList[A, N], указывающий на то, что MyList — подтип SList[A, N].

    • Обладая этой информацией, компилятор не жалуется на то, что, например, SNil не является SList[A, N] при N == 0.

    • Если не использовать имплисит <:< и просто написать case Some(n) if n == 0 => SNil, то компиляция не выполнится из-за type mismatch.

  • Паттерн-матчинг можно инлайнить по аналогии с inline if, используя конструкцию inline value match.

Обновленная функция refined работает штатно:

// val list: SList[Int, 1] = SCons(1,SNil)
val list: SList[Int, 1] = 1 :: SNil

// val res0: SCons[Int, 0] = SCons(1,SNil)
list.refined

// val nil: SList[Int, 0] = SNil
val nil: SList[Int, 0] = SNil

// val res1: SNil.type = SNil
nil.refined

// val unsized: SList[Int, ?] = SCons(1,SNil)
val unsized: SList[Int, ?] = list

// val res2: SList[Int, ?] = SCons(1,SNil)
unsized.refined

Логичные (на первый взгляд), но нерабочие варианты с использованием `match`

Почему бы не реализовать refined следующим образом, с помощью паттерн-матчинга?

transparent inline def refined: SCons[A, N - 1] | SNil.type = 
	this match {
    	case cons: SCons[A, N - 1] => cons
    	case SNil => SNil
	}

Данная версия, к сожалению, не работает для непустых списков в принципе:

val list: SList[Int, 2] = 1 :: 2 :: SNil

list.refined

-- [E007] Type Mismatch Error: -------------------------------------------------
1 |list.refined
  |^^^^^^^^^^^^
  |Found:    SCons[Int, (1 : Int)] | SNil.type
  |Required: SList[Int, ? >: (2 : Int) & (0 : Int) <: (2 : Int) | (0 : Int)]

и работает с пустыми списками… весьма посредственно:

val nil: SList[Int, 0] = SNil

//val res0:
//  SList[Int, ?
//     >: scala.compiletime.ops.int.S[-1] & 0 <: scala.compiletime.ops.int.S[-1] |
//       0
//  ] = SNil

nil.refined

Вероятно, такое поведение связано с тем, что match не заинлайнен и поэтому transparent не может вывести тип возвращаемого выражения.

А если заинлайнить?

transparent inline def refined: SCons[A, N - 1] | SNil.type = 
	inline this match {
    	case cons: SCons[A, N - 1] => cons
    	case SNil => SNil
	}

Снова сталкиваемся с проблемой, но уже другого характера:

val list: SList[Int, 2] = 1 :: 2 :: SNil
list.refined

-- Error: ----------------------------------------------------------------------
 1 |list.refined
   |^^^^^^^^^^^^
   |cannot reduce inline match with
   | scrutinee:  SList_this : (SList_this : (list : SList[Int, (2 : Int)]))
   | patterns :  case cons @ _:SCons[Int, scala.compiletime.ops.int.-[N, 1.type]]
   |             case SNil

val nil: SList[Int, 0] = SNil
nil.refined
-- Error: ----------------------------------------------------------------------
 1 |nil.refined
   |^^^^^^^^^^^
   |cannot reduce inline match with
   | scrutinee:  SList_this : (SList_this : (nil : SList[Int, (0 : Int)]))
   | patterns :  case cons @ _:SCons[Int, scala.compiletime.ops.int.-[N, 1.type]]
   |             case SNil

Суть ошибки: inline match не получается сократить на этапе компиляции, так как ни одна из его веток не подходит.

Почему, ведь у list тип SCons[Int, 2 — 1]?
Потому что мы не знаем об этом на этапе компиляции; всё, что нам известно — это то, что list имеет тип SList[Int, 2]:

val list: SList[Int, 2] = ???

Challenge #2: сложение списков

Чтобы реализовать flatMap и использовать синтаксис for-comprehension для SList, нам нужно сначала научиться складывать списки. Для удобства реализуем сложение в синглтоне-компаньоне.
Попробуем сначала реализовать сложение по аналогии с обычными списками:

def add[A, N1 <: Int, N2 <: Int](
	list1: SList[A, N1], list2: SList[A, N2]
): SList[A, N1 + N2] = list1 match {
    case SCons(x, xs) => x :: add(xs, list2)
    case SNil => list2
}

И получим от компилятора ответ:

[error] -- [E007] Type Mismatch Error: SList.scala 
[error] 46 |        case SCons(x, xs) => x :: add(xs, list2)
[error]    |                             ^^^^^^^^^^^^^^^^^^^
[error]    |Found:    SList[A, compiletime.ops.int.S[Int + N2]]
[error]    |Required: SList[A, N1 + N2]
[error]    |
[error]    |where:    N1 is a type in method add which is an alias of compiletime.ops.int.S[Int]
[error]    |          N2 is a type in method add with bounds <: Int
[error]    |
[error] -- [E007] Type Mismatch Error: SList.scala
[error] 47 |        case SNil => list2
[error]    |                     ^^^^^
[error]    |     Found:    (list2 : SList[A, N2])
[error]    |     Required: SList[A, N1 + N2]
[error]    |
[error]    |     where:    N1 is a type in method add which is an alias of (0 : Int)
[error]    |               N2 is a type in method add with bounds <: Int

Компилятор снова не может самостоятельно вывести равенство литеральных типов:

  • В case SNil компилятору не нравится, что мы предоставили SList[A, N2], а не SList[A, N1 + N2]; при этом известно, что N1 = 0.

  • В case SCons компилятор вроде бы и понимает, что загадочный Int — это число N0 такое, что S[N0] = N, но всё равно использует в возвращаемом типе общий Int, а не конкретное число. И даже если бы он этого не делал, он всё равно не смог бы вывести, что S[N0 + N2] = S[N0] + N2.

Мы уже сталкивались с подобной проблемой ранее, при реализации refined. Реализуем сложение списков похожим образом, используя паттерн-матчинг по constValueOpt[N]:

inline def add[A, N1 <: Int, N2 <: Int](
	list1: SList[A, N1], list2: SList[A, N2]
): SList[A, N1 + N2] =
    inline constValueOpt[N1] match {
        case Some(n) if n == 0 =>
            summonFrom {
                case given (SList[A, N2] =:= SList[A, N1 + N2]) => 
                	list2
            }
        case Some(_) =>
            val SCons(x, xs) = list1.asInstanceOf[SCons[A, N1 - 1]]
            summonFrom {
                case given (SList[A, S[N1 - 1 + N2]] =:= SList[A, N1 + N2]) => 
                	x :: add(xs, list2)
            }
        case _ => ???
    }

Здесь мы ищем неявные инстансы MyList =:= SList[A, N1 + N2], чтобы компилятор смог вывести равенство типов, опираясь на конкретные значения N1 и N2.

Используя refined, мы можем по-другому реализовать функцию add:

inline def addUsingRefined[A, N1 <: Int, N2 <: Int](
	list1: SList[A, N1], list2: SList[A, N2]
): SList[A, N1 + N2] =
    inline list1.refined match {
        case cons: SCons[A, N1 - 1] =>
            summonFrom {
                case given (SList[A, S[N1 - 1 + N2]] =:= SList[A, N1 + N2]) => 
                	cons.head :: add(cons.tail, list2)
            }
        case _: SNil.type =>
            summonFrom {
                case given (SList[A, N2] =:= SList[A, N1 + N2]) => 
                	list2
            }
    }

Поскольку тип list1.refined известен на этапе компиляции, если известно N, мы можем проводить inline matching по нему.

Обе функции работают идентично:

val list1: SList[Int, 3] = 1 :: 2 :: 3 :: SNil
val list2: SList[Int, 3] = 4 :: 5 :: 6 :: SNil
val nil: SList[Int, 0] = SNil

//val res0: SList[Int, 6] = SList(1,2,3,4,5,6)
SList.add(list1, list2)
//val res1: SList[Int, 3] = SList(4,5,6)
SList.add(nil, list2)
//val res3: SList[Int, 3] = SList(1,2,3)
SList.add(list1, nil)

// true
SList.add(list1, list2) == SList.addUsingRefined(list1, list2)
SList.add(nil, list2) == SList.addUsingRefined(nil, list2)
SList.add(list1, nil) == SList.addUsingRefined(list1, nil)

На основе add реализуем операторы сложения списков в трейте SList. Важно здесь и далее использовать inline, потому что иначе не получится вывести constValue[N]:

inline def :::[A1 >: A, N1 <: Int](begin: SList[A1, N1]): SList[A1, N1 + N] = 
	SList.add(begin, this)
inline def ++[A1 >: A, N1 <: Int](end: SList[A1, N1]): SList[A1, N + N1] = 
	SList.add(this, end)

Альтернативная реализация сложения без метапрограммирования

Можно реализовать сложение списков, не используя метапрограммирование, стерев размеры списков и выполнив в конце рантайм-каст:

def addRuntime[A, N1 <: Int, N2 <: Int](
	list1: SList[A, N1], list2: SList[A, N2]
): SList[A, N1 + N2] = {
    def addUnsized(
    	list1: SList[A, ?], 
    	list2: SList[A, ?]
    ): SList[A, ?] = list1 match {
        case SCons(x, xs) => x :: addUnsized(xs, list2)
        case SNil => list2
    }
    
    addUnsized(list1, list2).asInstanceOf[SList[A, N1 + N2]]
}

В отличие от inline-версии, компилятору не потребуется разворачивать рекурсию на этапе компиляции, и поэтому он не упадет из-за переполнения стека на больших N1, N2.

Реализация flatten и flatMap

Используя сложение списков, паттерн-матчинг по refined и summonFrom, мы можем легко реализовать функцию flatten, практически аналогично обычным спискам:

inline def flatten[A, N0 <: Int, N <: Int](
	list: SList[SList[A, N], N0]
): SList[A, N0 * N] =
    inline list.refined match {
        case cons: SCons[SList[A, N], N0 - 1] =>
            summonFrom {
                case given (SList[A, N + (N0 - 1) * N] =:= SList[A, N0 * N]) => 
                	cons.head ::: flatten(cons.tail)
            }
        case _: SNil.type =>
            summonFrom {
                case given (SNil.type <:< SList[A, N0 * N]) => 
                	SNil
            }
    }

Как и при сложении, в каждом кейсе мы подсказываем компилятору, что возвращаемое нами значение имеет тип SList[A, N0 * N].

С помощью flatten и map функция flatMap реализуется тривиально:

inline def flatten[B, N1 <: Int](using ev: A <:< SList[B, N1]): SList[B, N * N1] = 
	SList.flatten(this.map(ev))
inline def flatMap[B, N1 <: Int](f: A => SList[B, N1]): SList[B, N * N1] = 
	this.map(f).flatten

В вызове flatten мы используем неявный параметр ev: A <:< SList[B, N1], чтобы функцию можно было вызвать только для SList с вложенностью, и используем ev.apply в map, чтобы привести значения типа A к SList[B, N1]. В итоге получаем список SList[SList[B, N1], N], который можно прокинуть во flatten в синглтоне.

Теперь мы можем использовать for-comprehension:

// val intList: SList[Int, 3] = SList(1,2,3)
val intList = 1 :: 2 :: 3 :: SNil

// val stringList: SList[String, 2] = SList(foo,bar)
val stringList = "foo" :: "bar" :: SNil

// val combinedList: SList[String, 6] = SList(1foo,1bar,2foo,2bar,3foo,3bar)
val combinedList = for {
	int <- intList
	string <- stringList
	resultValue = int.toString + string
} yield resultValue

Заключение

Mission completed! Мы успешно реализовали список SList[A, N] с:

  • размером N, известным на этапе компиляции;

  • безопасными функциями head/tail;

  • базовыми функциями ::, :+, :::, ++, map, flatten, flatMap, сохраняющими информацию о размере списка;

  • функцией refined, уточняющей тип списка на основе его размера.

И по ходу дела разобрались с литеральными типами, compile-time операциями и inline-возможностями Scala 3 на примере.

Полная реализация списка доступна на Github.

Стоит отметить, что данный список — это скорее демонстрация возможности Scala 3, малоприменимая на практике, потому что:

  • обычно нас интересует не точный размер списка, а минимально гарантированное число элементов, как в котовском NonEmptyList.

  • inline-функции сложения, flatten и flatMap реализованы рекурсивно, поэтому компилятор быстро упирается в предел рекурсии в inline-функциях (по умолчанию — 32).

Это моя первая статья на Хабре, так что прошу вас, дорогой читатель, написать в комментариях:

  • оказалась ли статья для вас полезной;

  • интересен ли вам формат а-ля «разжевываем метапрограммирование на практических примерах с кучей кода».

Любая конструктивная критика также весьма приветствуется :)

Если вам будет интересно, могу продолжить цикл статей на тему метапрограммирования для «чайников»…, а заодно и сам разберусь.

Да пребудет с вами мета-сила!

© Habrahabr.ru