Как устроена языковая модель без перемножения матриц
Нейросети любой архитектуры построены на перемножении матриц. Например, трансформеры из входного вектора создают (тоже перемножением) три матрицы, Q, K и V, которые затем несколько раз перемножаются на пути к получению выходного вектора. Именно умножение матриц с его кубической сложностью (если не прибегать к разного рода ухищрениям) занимает большую часть вычислительных мощностей.
Поэтому возникает естественное желание облегчить эту самую массивную часть вычислений и ускорить таким образом любую архитектуру. Периодически возникают новые подходы и идеи, тут расскажем о последней громкой статье по этой теме — Scalable MatMul-free Language Modeling.
Но сначала небольшое отступление (вслед за авторами) к BitNet от Майкрософт и коллег. Там смогли частично избавиться от перемножения, правда только векторов на матрицы. Все плотные вектора заменили на бинарные (-1,1) или тернарные (-1,0,1) значения, и умножения векторов на матрицы превратилось в сложение. Но механизмы self-attention, внутри которого перемножаются матрицы Q и K, остались без изменения. Поэтому авторы новой статьи задались вопросом, а возможно ли расширить этот подход полностью избавиться от матричного умножения в LLM?
Допустим, мы умножаем входной вектор x на матрицу весов W. Теперь продолжая идею BitNet, ограничим допустимые значения W тремя значениями: -1, 0, 1. Тогда перемножение превращается в сложение (если Wij равен 1) или вычитание (если Wij равен -1) элементов входного вектора.
Таким образом мы избавляемся от перемножения чисел с плавающей точкой (входной х при этом остался без изменений, это может быть число с любой точностью, квантизация касается только весов).
Однако если просто взять BitNet и расширить его на произведение матрица-матрица, то ничего хорошего не получается, модель становится хуже и вообще перестает сходиться (Transformer ++ на графике ниже). Авторы считают, что BitNet дал правильную идею, но неправильно или не до конца её реализовал. Чтобы довести её до завершения, предлагается еще два усовершенствования — одно аппаратное и одно концептуальное.
Первый пункт — оптимизация с точки зрения аппаратного обеспечения. В современных GPU есть двухступенчатая иерархия: большая и глобальная высокопропускная память HBM (high bandwidth memory) и более быстрая и мелкая статическая память с произвольным доступом SRAM (static random access memory). BitNet построен таким образом, что к HBM обращаются несколько раз на каждом слое — сначала чтение, потом запись обратно результата RMSNorm, затем снова чтение для квантизации, снова запись и снова чтение для линейных операций. Авторы новой статьи оптимизируют этот процесс — чтение происходит один раз, а RMSNorm и квантизация объединены в одну операцию на SRAM.
Второй пункт — концептуальный. Простая замена перемножения матриц в модуле self-attention на тернарные операции (сложение и вычитание) не работает. Это ожидаемо, потому что квантизация хорошо работает пока она не слишком строгая. Обрубив все значения до -1 и 1 мы теряем слишком много информации и модель становится просто бессмысленной. Раз успешно сделать это в рамках self-attention не получилось, авторы предлагают сделать два шага назад и вернуться к GRU (Gated Recurrent Unit). Среди различных RNN она отличается простотой и эффективностью, поэтому выбор авторов пал на нее.
Главная черта GRU заключается в объединении векторов входа и фильтра забывания в единый блок «утечки» (leakage). Он помогает удерживать нужную информацию предыдущих скрытых состояний и сохранять новую. Но что самое важное, это происходит с помощью обычного поэлементного умножения. Если теперь избавиться от нелинейности классического GRU, то получается перейти к MatMul-free модели. Для этого нужно убрать веса, которые зависят от скрытого состояния. Таким образом исчезают перемножения матриц и появляется возможность параллельных вычислений, как для трансформеров.
Затем вычисление возможного нового скрытого состояния упростили до линейных преобразований, и заменили все оставшиеся веса на тернарные матрицы.
Итоговая архитектура выглядит так (все матрицы W приведены к тернарным значениям). :
Экспериментальные результаты показывают, что без MatMul-free модель работает плюс-минус на равных с полноценными трансформерами, но экономит 61% памяти. Хотя, похоже главный результат даже не в этом.
Если всё верно, то главные преимущества MatMul-free подхода мы увидим позднее. На графике сверху две нижние прямые показывают закон масштабирования для трансформеров и MatMul-free архитектуры. Пока что трансформеры все-таки лучше. По крайней мере на той области, где провели эксперименты. Но по мере увеличения FLOP прямые сближаются и пересекаются где-то в районе 1023, что близко с тем, что использует Llama 2. Важная оговорка: при таком положении прямых точка пересечения определяется неустойчиво. Так что при каком там порядке они действительно пересекаются — это вопрос. Однако за этой точкой (авторы обозначили ее звездочкой), а еще лучше где-то справа от графика прямые начнут расходиться и тут-то мы должны ощутить весь положительный эффект MatMul-free моделей. Впрочем, это не кажется слишком далеким будущим, проверить сможем совсем скоро.