Обзор Llemma: новая математическая open-source модель
Привет! Меня зовут Дарина, и я занимаюсь фундаментальными исследованиями в MTS AI. Основной фокус нашей работы сейчас — обучение больших языковых моделей, их тестирование и оптимизация.
Сегодня хочу сделать обзор на недавно вышедшую статью LLEMMA: an open language model for mathematics. Расскажу про обучение модели, новый датасет Proof-Pile-2 и в конце сравню ее с ChatGPT и GPT-4 на ЕГЭ заданиях по профильной математике.
Введение
За последнее время было выпущено много больших языковых моделей, которые умеют поддержать диалог, решить математическую задачку, помочь составить презентацию и т.д. Однако, если обучать или дообучать модель на определенную сферу знаний, то это принесет больше пользы и к тому же на это будет потрачено меньше ресурсов. Например, модель Galactica, обученная на научных данных, превосходит более модели GPT3 и BLOOM. Также недавно выпущенная CodeLlama показывает результаты лучше, чем ее базовая модель Llama 2.
Авторы статьи решили создать открытую языковую модель Llemma, умеющую решать математические задачи. Также они собрали датасет Proof-Pile-2, на котором обучались. Модели, код для обучения и датасет были выложены в открытый доступ на GitHub и HuggingFace.
Датасет
Данные, использованные во время обучения и структура датасета Proof-Pile-2, кол-во токенов и вес каждого датасета во время обучения.
В качестве основного датасета был сформирован Proof-Pile-2, который состоит из датасетов поменьше:
AlgebraicStack. Авторы создали датасет из 11 миллиардов токенов исходного кода на 17 языках, связанных с математикой. Сюда входят языки программирования для доказательств теорем: Lean, Isabelle, Coq и другие. Из популярных языков программирования включены Python, Matlab, C и C++. Данные были взяты и отфильтрованы из the Stack, публичных гитхаб-репозиториев с помощью GitHub API. Отдельно данные для Lean и Isabelle были получены, соответственно, из Mathlib, архива формальных доказательств и стандартной библиотеки Isabelle.
OpenWebMath. Датасет, состоящий из 15 миллиардов токенов включает в себя веб-страницы с математическим контентом.
Научные статьи из arXiv. Часть датасета RedPajama, состоящая из 29 миллиардов токенов. RedPajama — это датасет, воспроизводящий датасет для обучения Llama.
Proof-Pile-2 актуален на апрель 2023 года.
Также помимо Proof-Pile-2 модель была обучена на датасете the Pile и подмножестве GitHub из датасета RedPajama.
Модель и обучение
Каждая модель Llemma была инициализирована с CodeLlama — decoder-only трансформер, обученный на 500B токенов кода. Авторы продолжили обучение две CodeLLama модели 7B и 34B.
Llemma 7B обучалась на 200B токенах 23 тысячи A100-часов. Llemma 34B обучалась на 50B токенах 47 тысяч A100-часов. Обе модели имеют контекст 4096 токенов.
Во время обучения использовались tensor и data parallelism, а также ZeRO Stage 1. Не буду вдаваться в подробности что это такое, но советую почитать эту статью для понимания. Не обошлось также без FlashAttention2 и RoPE.
Оценка
Chain-of-thought промптинг
Сначала авторы статьи решили оценить способность модели решать математические задачи используя chain of thought reasoning. CoT — это промпт, который включает в себя обоснование данного ответа. Например, как показано на картинке ниже.
Для оценки использовались следующие датасеты:
MATH — датасет, включающий в себя 12.5 тысяч задач из соревнований по математике среди старших школ. Модели подается проблема, а ее ответ генерируется в виде Latex решения. В статье использовали 4-shot промптинг.
GSM8k — датасет из 8.5 тысячи математических задач уровня средней школы, написанными людьми. Оценка проводилась с помощью 8-shot промптинга.
OCWCourses — коллекция задач уровня бакалавриата полученных из OpenCourseWare от MIT.
MMLU-STEM — подмножества 18 предметных областей из 57 бенчмарка MMLU. Использовался 4-shot промптинг.
SAT — созданный авторами статьи, датасет, состоящий из 32 математических вопросов.
Пример генерации решения от Llemma 34B на задачу из датасета MATH, имеющую самую высокую сложность 5.
Целью было оценить Llemma как базовую модель, сравнивая ее с подобными, которые не дообучались (fine-tuning) на математических данных. В качестве главного конкурента была выбрана Minerva от Google Research, которая продолжила обучение PaLM. Самым главным преимуществом Llemma является ее открытость. Minerva же закрытая модель с закрытым датасетом.
Сравнение Llemma и Minerva.
Дополнительно модель сравнивалась с CodeLlama и Llama 2. В качестве метрики была выбрана точность совпадения строк или его SymPy эквивалента. Ниже можно увидеть результаты.
Также оценка проводилась с помощью majority voting или maj@k. Это способ выбора самого популярного ответа среди k сгенерированных ответов, вместо greedy decoding, который просто выбирает самый вероятный. На рисунке ниже показан пример.
Как можно заметить, это довольно долгий процесс оценки, поэтому авторы решили сделать оценку только для Minerva и Llemma. Для бенчмарка MATH k равнялся 256, а для GSM8k и OCW k = 100. В случае SAT и MATH генерировалось всего 16 сэмплов. После выбора самого популярного ответа, использовался nucleus sampling с p = 0.95. Результаты можно увидеть ниже.
Результаты оценки maj@k на пяти бенчмарках для моделей Minerva и Llemma.
Proof assistant
Proof assistant — это интерактивная программа для доказательств теорем. Обычно такая программа имеет свой собственный язык. Как уже было выше сказано, авторы создали датасет AlgebraicStack, который включает в себя 1.5 миллиарда токенов таких языков и решили проверить свою модель на двух задачах:
Informal-to-formal. Перевод из неформально описанной задачи и ее решения на языке Latex в формальный язык Isabelle. Чтобы оценить правильность, был использован бенчмарк miniF2F.
Formal-to-formal. Генерация продолжения доказательства на основе предыдущих шагов для программы Lean 4. Иными словами, Copilot для этого языка. Также использовался бенчмарк miniF2F.
Слева: дана проблема, неформальное док-во и формальное описание проблемы. Модель должна сгенерировать формальное док-во. Справа: дано несколько строчек кода (выделено серым цветом), модель сгенерировала следующий шаг док-ва (например, rw […])
Результаты получились такими:
Llemma превосходит все сравниваемые модели для бенчмарка miniF2F.
Тестирование
Я решила проверить Llemma 7B на 21 задании ЕГЭ профильной математики. С помощью Яндекс Переводчика я перевела вариант на английский и глазами проверила ответы. Вариант на английском лежит тут. Я использовала 1-shot промптинг, взяв только первую задачу из промпта для бенчмарка MATH, используемого в статье. Из-за вычислительных ограничений я не тестировала Llemma 34B.
Для сравнения я также протестировала на этих заданиях GPT-4 и ChatGPT (gpt-3.5-turbo). Получились такие результаты:
Модель | Кол-во верных ответов |
Llemma 7B | 6 |
ChatGPT | 11 |
GPT-4 | 12 |
Перебор параметров генерации Llemma для достижения лучшего результата занял довольно долгое время. Она повторялась, пыталась генерировать сама вопросы и ответы на них, ошибалась в знаках. В итоге у меня так и не получилось найти золотую середину, где недостатков совсем не было. Я остановилась на параметрах repetition_penalty=1.1 и temperature=0.2.
GPT-3 и ChatGPT справились без дополнительного промпта, не повторялись и подробно объясняли свой ответ.
Заключение
В этом материале я рассмотрела модели Llemma 7B и 34B, продолжающие обучение CodeLlama на математических данных. Рассказала, как обучались модели и какие данные для этого использовались. Были рассмотрены методы оценки модели на основных математических бенчмарках, а также как модель справляется с генерацией кода для математических доказательств.
Дополнительно я провела свою оценку на 21 задаче из ЕГЭ по профильной математике, сравнивая модель с ChatGPT и GPT-4. В результате, Llemma показала худшие результаты. В защиту хотелось бы сказать, что тестирование проводилось на маленькой модели. Также цель авторов статьи заключалась в создании открытой базовой модели, которую лучше дообучить на определенной сфере математики.