«ИИ без границ»: как научить Transformer обрабатывать длинные тексты

Дисклеймер

Статья посвящёна проблеме обработки длинных входных последовательностей нейросетевыми моделями на основе архитектуры Transformer. От читателя требуется понимание общих принципов работы Transformer и устройства self-attention. Если хочется сперва разобраться в существующих моделях от самых простых до более современных, можно заглянуть в этот обзор (часть более старых работ позаимствована оттуда). Если непонятно вообще всё, то лучше начать с основ.

Работы по теме в русскоязычном сегменте есть (например, эта, эта или эта) и их имеет смысл почитать. Текст статьи основан на моей лекции для студентов ВМК МГУ.

Введение

Одна из ключевых проблем Transformer — квадратичная сложность обработки последовательности слоем self-attention (механизм внимания), O(n^2 \cdot d), где n— длина последовательности, а d— размерность каждого её элемента. Из-за этого первые модели обычно ограничивались относительно небольшими длинами контекста (256, 512), да и сейчас основные LLM общего назначения обычно имеют контекст 2048 или 4096 (либо используют какие-то оптимизации из описываемых ниже). В ряде задач (суммаризация книг, анализ документации, ведение длинных диалогов и т.д.) длина последовательности имеет критическое значение, поэтому попытки каким-то образом расширить контекст начали предприниматься почти сразу после появления первых предобученных моделей.

В этой (первой) статье приведены основные идеи большинства популярных работ по теме оптимизации self-attention, опубликованных за последние годы. Рассматриваются разные подходы: приближённое вычисление внимания, иерархическая обработка последовательности, добавление рекурентности, математические преобразования формул self-attention и вычислительные оптимизации. Все изображения взяты из первоисточников, туда же рекомендуется идти за подробностями реализаций.

Ещё одной важной частью работы с длинным контекстом является выбор эффективного способа кодирования позиционной информации, этой теме будет посвящена вторая статья. В описаниях ниже для простоты эта информация не приводится или присутствует минимально.

Обзор работ

Transformer-XL, 2019

Архитектура декодировщика, входная последовательность разбивается на сегменты, обрабатываемые рекурентно:

  • выходы self-attention для текущего сегмента кэшируются во всех блоках Transformer

  • при обработке i-го сегмента используются кэшированные выходы (i-1)-го

  • выходы двух сегментов конкатенируются, по ним считаются ключи (K) и значения (V)

  • запросы (Q) считаются только по токенам текущего сегмента, градиент тоже течёт только через эти токены

c137e6de612807ae50abbbf09582f705.png

Такая схема обработки не может работать с абсолютным позиционным кодированием, оно заменяется на относительное: позиционные эмбеддинги убираются, вместо них информация о позиции добавляется напрямую в подсчёт self-attention. Для этого используются фиксированная матрица на основе значений синусов от расстояний между всевозможными позициями токенов и два обучаемых вектора, общих для всех позиций.

Sparse Transformer, 2019

Архитектура декодировщика, для ускорения используется локальный self-attention: при обработке токена внимание обращается не на всю последовательность, а на сегмент до текущего токена (если вход двумерный, то по вертикали и горизонтали)

Для того, чтобы информация о других частях последовательности тоже могла дойти до токена, в каждом сегменте используется C выделенных токенов, на которые могут смотреть при подсчёте внимания как токены текущего сегмента, так и токены всех последующих сегментов. Например, в каждом сегменте из 128 токенов можно брать по 8 таких токенов, т.е. токены из более поздних сегментов смогут взаимодействовать с большим числом выделенных токенов. Параметры можно подобрать так, чтобы свести сложность к O(n \cdot \sqrt n).

Слева маска обычного causal внимания, справа — sparse.

Слева маска обычного causal внимания, справа — sparse.

Reformer, 2020

В self-attention наибольший вклад для запроса вносят ключи, которые являются наиболее близкими к нему, поэтому предлагается производить вычисления только для близких пар K и V.

Реализуется это с помощью kNN:

  • обучается общая для ключей и запросов весовая матрица

  • ключи собираются в LSH-индекс

  • для запроса ищется ближайший в индексе центроид

  • внимание для запроса считается только с соседями центроида слева

ed7c4049eda8e7fc3c7fb4c42c964549.png

Для экономии памяти за счёт замедления обучения используются RevNet (пересчёт выходов слоя по выходам следующего в backprop) и блочные вычисления на каждом слое. Настройкой модели можно снизить сложность до O(n \cdot \log n), обучаются модели с контекстом до 64К токенов.

Longformer, 2020

Разделение внимания на локальное и глобальное, у каждого свои весовые матрицы. Для глобального self-attention выбирается фиксированный набор токенов (могут быть важные, типа CLS или токенов вопроса для QA), на эти токены могут смотреть все остальные токены, и эти токены могут взаимодействовать со всеми остальными.

В локальном внимании токен может смотреть на N токенов вокруг себя, от слоя к слою захватывая всё более длинные зависимости. При этом N токенов могут быть как ближайшими к текущему обрабатываемому, так и идти с заданным шагом (Dilated Sliding Window, допустимо на верхних слоях). Разные головы self-attention могут иметь шаблоны внимания.

ab3ce60d11ba150ac26b634662bb5b25.png

Модели тренируются с относительным позиционным кодированием, n при обучении наращивается от 2К до 23К, тестируется на 32K.

Performer, 2020

Вместо упрощения подсчёта self-attention производится его аппроксимация. За счёт использования специальных ядер получаются новые матрицы K' и Q' с новой небольшой размерностью признаков r < d. С их помощью которых можно получить схожий результат со сложностью О(n \cdot r \cdot d):

\mathrm{softmax}(Q_{n\times d} K^{T}_{n \times d})V_{n \times d} \approx Q'_{n \times r} (K^{'T}_{n \times r} V_{n \times d})L на изображении — это n в нашей нотации. A — результат выход softmax.

L на изображении — это n в нашей нотации. A — результат выход softmax.

Модели учатся для разных задач с длиной контекста от 8K до 12K.

Linformer, 2020

Ещё один подход к приближённому вычислению self-attention. Предлагается понижать ранг матриц K и V, что позволяет уменьшить размерность длины последовательности с nдо kи понизить сложность с O(n^2 \cdot d)до O(n \cdot k \cdot d):

\texttt{softmax}(Q_{n\times d} K^{T}_{n \times d})V_{n \times d} \rightarrow \texttt{softmax}(Q_{n \times d} K^{T}_{k \times d})V_{k \times d}

Проекция векторов матриц K и V в меньшую размерность k производится с помощью обучаемых линейных слоёв. Модели обучаются для разные задачи с длиной контекста от 8K до 12K.

Big Bird, 2020

Продолжение идей self-attention с разными масками внимания. Предлагается использовать три типа внимания:

  • локальное в пределах контекста токена

  • внимание на случайный разреженный набор токенов по всей последовательности

  • глобальное внимание (как в Longformer)

b1ca003572e07fedbfd0cfd0876e6d35.png

В применении к модели RoBERTa удалось увеличить размер контекста с 512 до 4096.

LongT5, 2021

Sequence-to-sequence модель на базе Т5 с локальным и глобальным вниманием. Локальное — в пределах контекста токена (127 в каждую сторону).

Глобальное внимание:

  • вход делится на блоки по 16 токенов

  • глобальный токен блока равен сумме токенов блока

  • все токены смотрят на все глобальные токены

6ef58a1b293c0e5e99548b41322e5356.png

На предобучении n = 4K, на этапе fine-tuning в разных задачах n = 4–47K.

Top Down Transformer, 2022

В основе тоже полный Transformer, предлагается двухуровневая обработка входа для модели суммаризации.

Шаг 1:

  • локальный self-attention на последовательности (w соседей у каждого токена)

  • пулинг поверх выходов для получения m \ll nглобальных токенов

  • полный self-attention на глобальных токенах

Шаг 2:

  • на основе выходов локального self-attention получаются Q (их nштук)

  • выходы полного self-attention дают K и V (их по mштук)

  • self-attention с этими Q, K и V даёт итоговый выход кодировщика

Получается качественная модель суммаризации глав книг размером меньше 500М параметров.

Получается качественная модель суммаризации глав книг размером меньше 500М параметров.

Далее выходы кодировщика идут как обычно в cross-attention в декодировщик. Пулинг может быть усреднением или более сложным и обучаемым. Сложность (опустим dдля наглядности) понижается до O(n \cdot w + m^2 + n \cdot m).

Memorizing Transformer, 2022

В основе лежит Transformer-XL с сегментами по 512 токенов. В последнем блоке добавляется слой kNN-augmented attention:

  • для слоя заводится блок памяти на M пар векторов ключей и значений

  • текущие K и V self-attention слоя добавляются в конец памяти

  • на ключах памяти запускается kNN, свой для запроса из текущего Q

  • self-attention для запроса считается только с ближайшими соседями

Результаты обычного и kNN-augmented внимания складываются с обучаемым весом. Память своя для каждой головы self-attention, если в ней не хватает места, вытесняются наиболее старые пары.

d6f1128d3e7b209cd55c3243ad17f039.png

Ценность подхода в том, что его можно применять не только на стадии предобучения, но и при дообучении модели, за счёт чего удаётся быстро достигнуть того же качества, что и при обучении с нуля.

SLED, 2022

В работе предлагается подход по увеличению длины обрабатываемого контекста в sequence-to-sequence моделях на этапе fine-tuning:

  • вход (16К) нарезается на сегменты с контекстом с обеих сторон (256)

  • опционально к каждому сегменту добавляется префикс (промпт, вопрос)

  • каждый сегмент обрабатывается кодировщиком независимо, токены внутри него могут обращать внимание друг на друга, на токены префикса и контекста вокруг

  • из результатов удаляются лишние токены, закодированная последовательность идёт в cross-attention в декодировщик

bed66a68bd56a3167b3b94727be10b74.png

FlashAttention, 2022

2b22b7b567beacd035f26e77b1e7c80d.png

Одна из наиболее успешных работ по оптимизации работы моделей Transformer. Вместо аппроксимации полного подсчёта self-attention можно оптимизировать само его вычисление.

Слева показана иерархия типов памяти при работе с GPU: помимо известных всем DRAM и VRAM (GPU HBM) есть ещё GPU SRAM, представляющая собой аналог кэша в CPU. Это очень маленькая и быстрая память в которой производятся вычисления.

Проблема вычисления self-attention в том, что на каждом его шаге требуется множество пересылок данных между HBM и SRAM, на это тратится значительное время:

На каждом шаге вычисляемые матрицы выгружаются, чтобы потом быть подгруженными снова.

На каждом шаге вычисляемые матрицы выгружаются, чтобы потом быть подгруженными снова.

Первое нововведение состоит в по-блочном вычислении softmax (Tiling), это возможно при использовании дополнительной памяти O (N) для специальных счётчиков. В этом случае вычисления идут в двух циклах: во внешнем идёт итерация по блокам K и V, во внутреннем — по Q, что позволяет минимизировать количество операций обмена данными между кэшем и основной памятью GPU:

Слева схема нового подсчёта self-attention, справа — результат её применения.

Слева схема нового подсчёта self-attention, справа — результат её применения.

Выполнение многих операций за раз (умножение матриц, вычисление softmax, маскирование, дропаут) позволяет выполнять их одним CUDA ядром (fusing), что тоже значительно ускоряет вычисления.

Второй применяемой техникой в статье является пересчёт промежуточных значений на шаге backward (Recomputation):

  • для обычного вычисления градиентов требуются промежуточные матрицы размера O(n^2)(входы и выходы softmax)

  • их можно не хранить, а вычислять на лету, имея выходы self-attention и дополнительные статистики (считаются при tiling), в статье приводятся полные формулы обратного шага

  • получается вариация gradient checkpointing, но за счёт уменьшения переноса данных в SRAM она не только экономит память, но ещё и ускоряет подсчёт внимания

Работа стала стандартом в мире обучения и инференса LLM, поскольку внимание вычисляется точно, а не приближённо, и функции из FlashAttention легко подменяют собой оригинальную реализацию. Модификация сразу стала внедряться в разные модели, а её версия 2.0 (2023) с дополнительными вычислительными оптимизациями на GPU вошла в библиотеку transformers. Для крупных моделей FlashAttention даёт и ускорение, и экономию потребления памяти, что позволяет сильно увеличивать число параметров и длину обрабатываемого контекста в рамках тех же вычислительных ресурсов.

Unlimiformer, 2023

Ориентация на полный Transformer, можно использовать как с дообучением, так и без. Вход кодировщика разбивается на сегменты с пересечением, обрабатываемые независимо, в конце контекст удаляется. Результаты всего сохраняются и используются в cross-attention декодировщика с kNN-индесом (как в Memorizing Transformers):

87b9efacadbe031579569a58c0118314.png

Есть проблема: на каждом слое и в каждой голове нужен свой kNN-индекс, который нужно перестраивать для каждого набора пар ключ-значение. Это очень затратно по памяти и времени, поэтому в Memorizing Transformers модифицированный слой добавляется только в последнем блоке.

Предлагается следующее решение. Переписывается формула подсчёта весов внимания (h_eи h_d— выходы кодировщика и декодировщика соответственно, W_qи W_k— матрицы весов внимания):

QK^T = (h_dW_q)(h_eW_k)^T = (h_dW_q)(W^T_kh^T_e)= (h_dW_qW_k^T)h_e^T

В поисковый индекс кладутся не ключи, а выходы кодировщика, которые являются общими для всех слоёв. Запросы для kNN-индекса формируются на каждом слое-голове, для извлечённых h_eнесложно считаются значения V.

Возможны разные варианты применения Unlimiformer к модели:

  • fine-tuning модели на задачу с ограниченной длиной сэмплов + Unlimiformer на тесте

  • то же самое, но с Unlimiformer на валидации для возможности раннего останова и выбора более хорошего чекпойнта

  • то же самое, но тексты не обрезаются, а нарезаются на сегменты максимальной длины и подаются как отдельные сэмплы

  • fine-tuning с Unlimiformer с контекстом 8–16K + неограниченный контекст на тесте

  • то же самое, но при обучении вместо kNN выбираются случайные токены

  • чередование двух подходов: первый учит, второй выступает регуляризатором и не даёт модели зацикливать своё внимание на топ-k ключах

LongNet, 2023

Последовательность разбивается на сегменты длины w, обрабатываемые независимо. Внутри сегмента self-attention считается разреженно (dilated), участвуют только 1 / rтокенов с индексом i : i \mod r = 0:

c041ae3adb344605adad5c9aae8030a0.png

При этом в одной голове несколько подсчётов с разными (w, r), результаты суммируются с весами, пропорциональными их знаменателям softmax. Во всех головах внимания используется один и тот же шаблон маски, и для того, чтобы улавливать различные зависимости, предлагается в каждой голове делать свой сдвиг маски:

3d6bc1a7c0e3a06f563f957ca0f144c0.png

RMT, 2023

Архитектура кодировщика, последовательность разбивается на сегменты, обрабатываемые последовательно. В начало каждой последовательности добавляются M«векторов памяти». Выходные представления векторов памяти сегмента i идут на вход сегменту i+1 (как вектор состояния в RNN). Механизм рекурентности дополняет модель, не требуя архитектурных изменений.

cd5c56ca496e7f8afcda74e400742b16.png

Focused Transformer (LongLLaMA), 2023

Идея схожа с Memorizing Transformers. Вход нарезается на сегменты, обрабатываемые последовательно, на инференсе у модели в части слоёв есть кэши для пар K и V. При подсчёте внимания для запроса q используются и K и V из текущего сегмента, и наиболее близкие пары (по ключу к q) из кэша (можно брать и все пары, это не сильно медленнее и проще в реализации). Кэш с вытеснением, пополняется по мере обработки сегментов.

b3687ada0dd8833b23c671260aa51751.png

На обучении описанные кэши не используется, но модель нужно научить смотреть на много пар K и V и уметь выделять нужные. Для этого на обучении заводится свой кэш:

  • порядок документов в батче фиксированный (сегменты одного документа всегда в той же позиции)

  • кэш перезаписывается для каждого сегмента и содержит K и V для предыдущего сегмента этого документа

  • дополнительно в него пишутся дистракторы: ключи и значения предыдущего сегмента dслучайных документов

  • в результате при обработке сегмента документа есть возможность заглядывать в кэш, но нужно учиться вытаскивать из него именно нужные данные

e46ba00e8fdf2b67c64b4e2e1af959aa.png

Выбор количества дистракторов имеет значение, предлагается начинать с 2 и постепенно повышать dдо 64. Обученная при такой стратегии модель на слое внимания с памятью даёт большие вероятности парам K и V из целевого документа.

Grouped-Query Attention

Обобщение предложенной ранее идеи MQA (Multi-query attention), в рамках которой self-attention для разных голов рассчитывается со своими векторами Q и с общими векторами K и V. Предлагается Grouped-Query Attention, в котором ключи и значения свои в каждой группе голов внимания:

2eb4ae323b43866c4cdb9469d3cf3a84.png

Переход от обученной модели к новому подсчёту делается в два шага: сперва веса конвертируются (в рамках группы голов их матрицы для получения K и V усредняются в одну), затем обновлённая модель дообучается (до \alpha= 5%вычислений от предобучения). В результате получается при правильном подборе числа групп и итераций дообучения достичь ускорения при генерации последовательности и сохранить качество почти неизменным (для набора генеративных задач):

Слева:  качество модели в зависимости от длины дообучения для 8 групп в модели T5 XXL, справа — скорость обработки одной последовательности в зависимости от числа групп.

Слева: качество модели в зависимости от длины дообучения для 8 групп в модели T5 XXL, справа — скорость обработки одной последовательности в зависимости от числа групп.

RetNet (2023)

Инференс моделей с длинным контекстом является отдельной проблемой, которую в работе предлагается решать, заменив блок multi-head self-attention (MHA) на multi-scale retention MSR. С математической и практической точек зрения механизм retention близок к attention, но у него есть важное преимущество: возможность параллельного и рекурентного представления.

Выбор представления определяет способ подсчёта заменителя внимания. В параллельном варианте получается конструкция, схожая с обычным подсчётом self-attention, что даёт возможность эффективно (по сравнению с RNN) учить модели. В рекурентном же представлении модель может эффктивно генерировать тексты без необходимости использовать большой кэш с помощью передаваемого между токенами скрытого состояния (как в RNN). Математические детали retention можно разобрать в оригинальной работе

Двойственность представления retention. Все обозначения — из формул работы, Q, K и V по смыслу похожи на те, что используются в обычном attention, GN — Group Normalization.

Двойственность представления retention. Все обозначения — из формул работы, Q, K и V по смыслу похожи на те, что используются в обычном attention, GN — Group Normalization.

Оставаясь достаточно эффективным на обучении, RetNet показывает впечатляющие результаты на инференсе:

Wps — words per second.

Wps — words per second.

Qwen (2023)

LLM общего назначения с контекстом 8К, для поддержки которого использует две модификации внимания: LogN-Scaling (масштабирует числовые значения в self-attention для стабилизации внимания при росте длины контекста) и локальное внимание (как в Longfomer). В последнем было отмечено более сильное влияние расширения контекста на нижние слои, поэтому у каждого слоя модели для токенов в self-attention используются окна разного размера: более короткие для первых слоёв, более длинные для последующих. Реализация активно использует FlashAttention.

LongLoRA, 2023

LoRA — популярный метод эффективного дообучения моделей Transformer, использующий низкоранговое разложение обучаемых матриц-добавок к замороженным основным весовым матрицам нейросети. В случае длинного контекста обычная LoRA слабо повышает эффективность дообучения из-за квадратичности внимания. Кроме того, качество такого дообучения с т.з. перплексии оказывается сильно уже нормального fine-tuning. LongLoRA соединяет в себе LoRA и эффективный подсчёт self-attention:

5c62417c73c9ff18a1c2475837b6baed.png

Внимание считается локально внутри заданных окон по двум шаблонам: обычный и сдвинутый на половину окна. Веса внимания дообучаются с помощью LoRA, дополнительно обычным образом дообучаются веса слоёв нормализации и входные эмбеддинги (в большой модели их доля от общего числа параметров невелика).

da82225f515832996c87a22ca3a96ef4.png

Новый подсчёт self-attention реализуется так:

  • Все головы внимания делятся на две части

  • Выбирается размер окна (группы), во второй половине голов векторы сдвигаются на полгруппы

  • В каждой группе self-attention считается как обычно (по всем головам)

  • Результаты для «сдвинутых» голов сдвигаются обратно

f27016700f65d4776aabbba49fbc25f8.png

В таком подходе сама схема вычислений не меняется — можно использовать готовые оптимизации (например, Flash Attention).

Mistral (2023)

LLM общего назначения с контекстом до 32К, использующая набор описанных подходов: Grouped-Query Attention, локальное внимание из Longformer (Sliding Window Attention) и FlashAttention (что стандартно для моделей последних месяцев).

Дополнительно используется Rolling Buffer Cache (фиксированный размер окна позволяет при генерации текста вытеснять из кэша декодировщика векторы, оказавшиеся за пределами контекста внимания):

Кэш с окном в 4 токена. Перезаписываются наиболее старые значения, оранжевым выделены скрытые состояния, соответствующие последнему сгенерированному токену.

Кэш с окном в 4 токена. Перезаписываются наиболее старые значения, оранжевым выделены скрытые состояния, соответствующие последнему сгенерированному токену.

На момент написания статьи Mistral 7B — одна из самых сильный open-source моделей, превосходящая по качеству на многих задачах LLaMA 2 13B.

Заключение

Тема расширения контекста для LLM на подъёме, предлагается много разных идей, наибольшее распространение получают самые простые (локальное внимание) и универсальные, не требующие настройки и не снижающие качество (FlashAttention). С этих подходов и рекомендуется начинать любую работу по серьёзному раширению контекста имеющейся модели с помощью дообучения. Также актуальны исследования, связанные с кэшированием информации из более старой части последовательности, тут ещё большое поле для экспериментов.

Стоит отметить, что в задачах генерации короткого текста по длинному контексту хорошо показывают себя и модели на основе декодировщика Transformer, и полный Transformer. Но если речь идёт о задачах типа суммаризации, то при фиксированных размерах моделей лучше могут справиться именно полные Transformer-ы, что подтверждается интересом именно к этой архитектуре в работах, направленных на увеличение длины последовательности для seq-to-seq.

Спасибо за внимание и успехов!

© Habrahabr.ru