Обзор Llemma: новая математическая open-source модель

Привет! Меня зовут Дарина, и я занимаюсь фундаментальными исследованиями в MTS AI. Основной фокус нашей работы сейчас — обучение больших языковых моделей, их тестирование и оптимизация.

Сегодня хочу сделать обзор на недавно вышедшую статью LLEMMA: an open language model for mathematics. Расскажу про обучение модели, новый датасет Proof-Pile-2 и в конце сравню ее с ChatGPT и GPT-4 на ЕГЭ заданиях по профильной математике.

Введение

За последнее время было выпущено много больших языковых моделей, которые умеют поддержать диалог, решить математическую задачку, помочь составить презентацию и т.д. Однако, если обучать или дообучать модель на определенную сферу знаний, то это принесет больше пользы и к тому же на это будет потрачено меньше ресурсов. Например, модель Galactica, обученная на научных данных, превосходит более модели GPT3 и BLOOM. Также недавно выпущенная CodeLlama показывает результаты лучше, чем ее базовая модель Llama 2.

Авторы статьи решили создать открытую языковую модель Llemma, умеющую решать математические задачи. Также они собрали датасет Proof-Pile-2, на котором обучались. Модели, код для обучения и датасет были выложены в открытый доступ на GitHub и HuggingFace.

Датасет

Данные, использованные во время обучения и структура датасета Proof-Pile-2, кол-во токенов и вес каждого датасета во время обучения.

Данные, использованные во время обучения и структура датасета Proof-Pile-2, кол-во токенов и вес каждого датасета во время обучения.

В качестве основного датасета был сформирован Proof-Pile-2, который состоит из датасетов поменьше:

  • AlgebraicStack. Авторы создали датасет из 11 миллиардов токенов исходного кода на 17 языках, связанных с математикой. Сюда входят языки программирования для доказательств теорем: Lean, Isabelle, Coq и другие. Из популярных языков программирования включены Python, Matlab, C и C++. Данные были взяты и отфильтрованы из the Stack, публичных гитхаб-репозиториев с помощью GitHub API. Отдельно данные для Lean и Isabelle были получены, соответственно, из Mathlib, архива формальных доказательств и стандартной библиотеки Isabelle. 

  • OpenWebMath. Датасет, состоящий из 15 миллиардов токенов включает в себя веб-страницы с математическим контентом.

  • Научные статьи из arXiv. Часть датасета RedPajama, состоящая из 29 миллиардов токенов. RedPajama — это датасет, воспроизводящий датасет для обучения Llama.

Proof-Pile-2 актуален на апрель 2023 года.

Также помимо Proof-Pile-2 модель была обучена на датасете the Pile и подмножестве GitHub из датасета RedPajama.

Модель и обучение

Каждая модель Llemma была инициализирована с CodeLlama — decoder-only трансформер, обученный на 500B токенов кода. Авторы продолжили обучение две CodeLLama модели 7B и 34B.

Llemma 7B обучалась на 200B токенах 23 тысячи A100-часов. Llemma 34B обучалась на 50B токенах 47 тысяч A100-часов. Обе модели имеют контекст 4096 токенов.

Во время обучения использовались tensor и data parallelism, а также ZeRO Stage 1. Не буду вдаваться в подробности что это такое, но советую почитать эту статью для понимания. Не обошлось также без FlashAttention2 и RoPE.

Оценка

Chain-of-thought промптинг

Сначала авторы статьи решили оценить способность модели решать математические задачи используя chain of thought reasoning. CoT — это промпт, который включает в себя обоснование данного ответа. Например, как показано на картинке ниже.

Изображение взято из Wei et al. (2022).

Для оценки использовались следующие датасеты:

  • MATH — датасет, включающий в себя 12.5 тысяч задач из соревнований по математике среди старших школ. Модели подается проблема, а ее ответ генерируется в виде Latex решения. В статье использовали 4-shot промптинг.

  • GSM8k — датасет из 8.5 тысячи математических задач уровня средней школы, написанными людьми. Оценка проводилась с помощью 8-shot промптинга.

  • OCWCourses — коллекция задач уровня бакалавриата полученных из OpenCourseWare от MIT.

  • MMLU-STEM — подмножества 18 предметных областей из 57 бенчмарка MMLU. Использовался 4-shot промптинг.

  • SAT — созданный авторами статьи, датасет, состоящий из 32 математических вопросов.

Пример генерации решения от Llemma 34B на задачу из датасета MATH, имеющую самую высокую сложность 5.

Пример генерации решения от Llemma 34B на задачу из датасета MATH, имеющую самую высокую сложность 5.

Целью было оценить Llemma как базовую модель, сравнивая ее с подобными, которые не дообучались (fine-tuning) на математических данных. В качестве главного конкурента была выбрана Minerva от Google Research, которая продолжила обучение PaLM. Самым главным преимуществом Llemma является ее открытость. Minerva же закрытая модель с закрытым датасетом.

Сравнение Llemma и Minerva.

Сравнение Llemma и Minerva.

Дополнительно модель сравнивалась с CodeLlama и Llama 2. В качестве метрики была выбрана точность совпадения строк или его SymPy эквивалента. Ниже можно увидеть результаты.

Результаты моделей на пяти бенчмарках с использованием chain-of-thought reasoning. Семплы генерировались с помощью greedy decoding.

Результаты моделей на пяти бенчмарках с использованием chain-of-thought reasoning. Семплы генерировались с помощью greedy decoding.

Также оценка проводилась с помощью majority voting или maj@k. Это способ выбора самого популярного ответа среди k сгенерированных ответов, вместо greedy decoding, который просто выбирает самый вероятный. На рисунке ниже показан пример.

Изображение взято из Wing et al., 2023

Как можно заметить, это довольно долгий процесс оценки, поэтому авторы решили сделать оценку только для Minerva и Llemma. Для бенчмарка MATH k равнялся 256, а для GSM8k и OCW k = 100. В случае SAT и MATH генерировалось всего 16 сэмплов. После выбора самого популярного ответа, использовался nucleus sampling с p = 0.95. Результаты можно увидеть ниже.

Результаты оценки maj@k на пяти бенчмарках для моделей Minerva и Llemma.

Результаты оценки maj@k на пяти бенчмарках для моделей Minerva и Llemma.

Proof assistant

Proof assistant — это интерактивная программа для доказательств теорем. Обычно такая программа имеет свой собственный язык. Как уже было выше сказано, авторы создали датасет AlgebraicStack, который включает в себя 1.5 миллиарда токенов таких языков и решили проверить свою модель на двух задачах:

  1. Informal-to-formal. Перевод из неформально описанной задачи и ее решения на языке Latex в формальный язык Isabelle. Чтобы оценить правильность, был использован бенчмарк miniF2F.

  2. Formal-to-formal. Генерация продолжения доказательства на основе предыдущих шагов для программы Lean 4. Иными словами, Copilot для этого языка. Также использовался бенчмарк miniF2F.

Слева: дана проблема, неформальное док-во и формальное описание проблемы. Модель должна сгенерировать формальное док-во. Справа: дано несколько строчек кода (выделено серым цветом), модель сгенерировала следующий шаг док-ва (например, rw […])

Слева: дана проблема, неформальное док-во и формальное описание проблемы. Модель должна сгенерировать формальное док-во. Справа: дано несколько строчек кода (выделено серым цветом), модель сгенерировала следующий шаг док-ва (например, rw […])

Результаты получились такими:

Слева: informal-to-formal док-во в Isabelle, показывающее процент доказанных теорем с greedy decoding. Справа: formal-to-formal док-во в Lean, показывающее процент доказанных теорем при заданном числе попыток × кол-во генераций и время ожидания 10 минут. Sledgehammer - встроенная автоматизация в Isabelle. ReProver - это модель, основанная на обучении с учителем и расширенная поиском. COPRA - это метод на основе GPT-4 с применением поиска.

Слева: informal-to-formal док-во в Isabelle, показывающее процент доказанных теорем с greedy decoding. Справа: formal-to-formal док-во в Lean, показывающее процент доказанных теорем при заданном числе попыток × кол-во генераций и время ожидания 10 минут. Sledgehammer — встроенная автоматизация в Isabelle. ReProver — это модель, основанная на обучении с учителем и расширенная поиском. COPRA — это метод на основе GPT-4 с применением поиска.

Llemma превосходит все сравниваемые модели для бенчмарка miniF2F.

Тестирование

Я решила проверить Llemma 7B на 21 задании ЕГЭ профильной математики. С помощью Яндекс Переводчика я перевела вариант на английский и глазами проверила ответы. Вариант на английском лежит тут. Я использовала 1-shot промптинг, взяв только первую задачу из промпта для бенчмарка MATH, используемого в статье. Из-за вычислительных ограничений я не тестировала Llemma 34B.

Для сравнения я также протестировала на этих заданиях GPT-4 и ChatGPT (gpt-3.5-turbo). Получились такие результаты:

Модель

Кол-во верных ответов

Llemma 7B

6

ChatGPT

11

GPT-4

12

Перебор параметров генерации Llemma для достижения лучшего результата занял довольно долгое время. Она повторялась, пыталась генерировать сама вопросы и ответы на них, ошибалась в знаках. В итоге у меня так и не получилось найти золотую середину, где недостатков совсем не было. Я остановилась на параметрах repetition_penalty=1.1 и temperature=0.2. 

GPT-3 и ChatGPT справились без дополнительного промпта, не повторялись и подробно объясняли свой ответ.

Заключение

В этом материале я рассмотрела модели Llemma 7B и 34B, продолжающие обучение CodeLlama на математических данных. Рассказала, как обучались модели и какие данные для этого использовались. Были рассмотрены методы оценки модели на основных математических бенчмарках, а также как модель справляется с генерацией кода для математических доказательств.

Дополнительно я провела свою оценку на 21 задаче из ЕГЭ по профильной математике, сравнивая модель с ChatGPT и GPT-4. В результате, Llemma показала худшие результаты. В защиту хотелось бы сказать, что тестирование проводилось на маленькой модели. Также цель авторов статьи заключалась в создании открытой базовой модели, которую лучше дообучить на определенной сфере математики.

© Habrahabr.ru