Multilabel-классификация знаний школьников
«Устный счёт», Богданов‑Бельский, 1895.
Привет, Хабр! Меня зовут Егор, сейчас я учусь на четвёртом курсе кафедры математических методов прогнозирования (ММП) ВМК МГУ и изучаю машинное обучение, в том числе, обработку естественных языков (Natural Language Processing). Этим летом я стажировался в Лаборатории искусственного интеллекта, в центре Инструментов машинного обучения, где смог применить свои знания для решения практических задач. Об одной из них я и хочу рассказать.
Задача
Использование методов машинного обучения и нейронных сетей в сфере образования является актуальной темой для исследований. Передо мной стояла задача оценивания свободных ответов пользователей на вопросы по школьной программе. Для решения этой задачи у меня была выборка, которая состояла из вопросов, соответствующих им критериев, и ответов. Для каждого ответа была представлена разметка соответствия критериям. Всего в обучающем наборе было 46 вопросов, для каждого вопроса — в среднем 267 ответов, при этом для одного вопроса количество критериев варьировалось от 2 до 13. Суммарно выборка состояла из 12 281 объектов. Набор оказался сильно несбалансированным в сторону отрицательных критериев: около 76 процентов ответов «Нет», и оставшиеся 24 процента — «Да».
В качестве бизнес-метрики рассчитывалась доля объектов, на которых модель ошиблась меньше, чем на 20% критериев. Если у вопроса менее пяти критериев, все они должны быть оценены корректно.
Таким образом, мне предстояло решить задачу multilabel-классификации для каждого из ответов. Один из вариантов решения — fine-tuning энкодерных архитектур (например, BERT’а). Однако итоговая модель должна иметь хорошую обобщающую способность в том числе для пар вопросов и критериев, ранее не встречавшихся в наборе данных. Поэтому я решил использовать zero-shot / few-shot классификацию с помощью больших языковых моделей (LLM), которые при правильно поданном prompt’e могут дать корректный ответ в задаче с теми вопросами, которые ранее не встречались.
Выбор модели и её fine-tuning
18 июля вышла языковая модель LLaMA-2, которая заняла первое место на Open LLM Leaderboard. LLaMA-2 является мультилингвальной моделью, в предобучении которой использовался в том числе и русский язык. Чат версия этой модели дообучена на выполнение инструкций и поддержание диалога, поэтому я использовал именно её. Однако на практике модель оказалась склонна давать ответы на английском, из-за чего потребовалось адаптировать её под русский язык. Поскольку модели очень тяжеловесны, для полного fine-tuning даже версии 7B требуется большое количество данных и ресурсов. Одним из способов, как можно избежать полного дообучения моделей на этом шаге, является использование адаптеров, в частности, QLoRA.
QLoRA как расширение LoRA-подхода состоит в том, что к исходным весам LLM прибавляется дополнительная малоранговая обучаемая матрица, которая является произведением двух матриц, а веса основной модели заморожены и квантизованы. Добавочные веса позволяют сократить количество обучаемых параметров до 1–3% от количества исходных весов. Но даже в этом случае модель потребляет большое количество памяти видеокарты как для обучения, так и для инференса. Поэтому я использовал четырёх- и восьмибитную квантизацию весов основной модели. При использовании восьмибитной квантизации количество используемой видеопамяти уменьшается примерно в два раза по сравнению с обучением без квантизации, при использовании четырехбитной квантизации — примерно в три раза, но при этом качество модели ухудшается. Впрочем, добавление адаптера к квантизованным весам позволяет частично или полностью компенсировать потерю качества в фиксированной задаче.
Необходимое количество видеопамяти для обучения (в гигабайтах) | LLaMA 2 7b | LLaMA 2 13b |
4-битная квантизация | 19,9 | 29,6 |
8-битная квантизация | 24,7 | 37 |
Необходимое количество видеопамяти для инференса (в гигабайтах) | LLaMA 2 7b | LLaMA 2 13b |
4-битная квантизация | 4,9 | 8,1 |
8-битная квантизация | 8 | 14,4 |
Для адаптации модели под русский язык я взял наборы данных, собранные Ильёй Гусевым. Практически все они сгенерированы с помощью ChatGPT на основе специального prompt’а для инструктивного fine-tuning’а новой адаптированной языковой модели. Я поменял системный prompt Ильи Гусева на prompt, под который была обучена LLaMA 2, то есть вместо спецтокенов user и bot я использовал спецтокены [INST] и [/INST].
Основная гипотеза состояла в том, что модели с большей квантизацией, меньшем количеством параметров и меньшим рангом должны показать качество хуже, чем аналогичные модели с меньшей квантизацией. При этом они должны занимать меньше памяти, то есть быть удобнее в обучении и инференсе.
Для эксперимента обучим QLoRA-адаптер для LLaMA-2-chat с 7 и 13 миллиардами параметров, с четырёх- и восьмибитной квантизацией. А зависимость от ранга исследуем отдельно на примере восьмибитной квантизации. Обучение длилось 10 эпох c 32-битным оптимайзером AdamW и косинусным lr_scheduler, размер батча был равен 128, шаг обучения — 1e-4.
Для оценки обученных языковых моделей будем использовать метрики RussianSuperGLUE (Total Score) и LiDiRus, которые показывают, насколько хорошо модели могут извлекать информацию из текста, делать логические выводы, понимать значение слов и так далее.
Модель | Квантизация | Ранг | Total Score | LiDiRus |
LLaMA 2 7B | 8 бит | w/o SFT | 0,387 | 0,132 |
LLaMA 2 7B | 4 бит | 16 | 0,5 | 0,195 |
LLaMA 2 7B | 8 бит | 4 | 0,536 | 0,231 |
LLaMA 2 7B | 8 бит | 16 | 0,526 | 0,3 |
LLaMA 2 7B | 8 бит | 64 | 0,519 | 0,287 |
LLaMA 2 13B | 8 бит | w/o SFT | 0,402 | 0,137 |
LLaMA 2 13B | 4 бит | 16 | 0,528 | 0,249 |
LLaMA 2 13B | 8 бит | 16 | 0,557 | 0,283 |
Для моделей с 7 миллиардами параметров мы перебрали значения ранга, и модели с разными рангами получили схожие результаты (у одних больше общий результат (Total Score), у других — LiDiRus). Тем не менее, мы выбрали ранг 16, так как он даёт наилучший результат по совокупности обеих метрик при меньших затратах по времени обучения и объёму памяти. Стоит отметить, что для 13-ти миллиардных моделей результаты минимально отличаются от Saiga 13B, размещенных в RussianSuperGLUE Leaderboard. Это может быть связано с пределом, который мы можем достичь при тюнинге LLaMA-2 чекпоинтов на данном фиксированном инструктивном наборе данных.
Обучение адаптеров при всех изученных параметрах даёт прирост метрик по сравнению с исходными квантизованными моделями. Эксперименты с четырёхбитной квантизацией показали качество хуже, чем аналогичные модели с восьмибитной квантизацией, что неудивительно — квантизация понижает качество моделей. Наилучший результат показала модель с 13 миллиардами параметров и восьмибитной квантизацией, однако обучение одной эпохи для этой модели заняло существенно больше времени, чем для других. В то же время обучение одной эпохи моделей с четырёхбитной квантизацией более чем в два раза быстрее, чем обучение аналогичных моделей с восьмибитной квантизацией.
Дополнительная адаптация под задачу
Из таблицы 2 видно, что обучение адаптера под русский язык дало сравнимый результат с открытыми и коммерческими моделями, в случае сравнения с Saiga и Saiga 2 это во многом связано с тем, что дообучение производилось на одинаковом инструктивном наборе.
Стоит отметить, что во время тестирования zero/few shot подходов я столкнулся с проблемой парсинга ответов при промптинге как открытых (Saiga, Saiga 2), так и коммерческих моделей (Gigachat). Это может существенно понижать значение метрик качества решения данной задачи, особенно бизнес-метрику, так как она является более дискретной и соответственно менее устойчивой.
По этой причине я решил обучить дополнительный адаптер непосредственно под оценивание ответов.
В этом случае, как и ранее, я воспользовался QLoRA-адаптером. Из-за нехватки разнообразия данных нам необходимо собрать искусственный инструктивный набор, который поможет модели лучше оценивать ответы пользователей.
Для этого были использованы наборы данных ms_marco, ru_sberquad_long_answers и MMLU. Ниже приведены детали подготовки каждого из них.
ms_marco — набор, собранный с помощью Bing Question, каждый вопрос которого содержит несколько (в среднем 8) ответов, при этом только один ответ является правильным. Так как данные англоязычные, был использован перевод на русский язык с помощью Google Translate.
ru_sberquad_long_answers содержит вопросы с контекстом, в котором содержится правильный ответ. Ответы отдельно были сгенерированы моделью T5. В наборе около 50 тысяч объектов. В отличие от предыдущего, этот набор русскоязычный, то есть перевод для него не требуется, но при этом он не содержит неправильных ответов. Неправильные ответы необходимо генерировать самостоятельно. В качестве процедуры генерации использовался Gigachat: для каждого вопроса генерировался набор ответов, далее, используя модель для sentence‑эмбеддингов, правильные и неправильные ответы отбирались по косинусному расстоянию: ответ считался правильным, если косинусное расстояние между данным и истинным ответами было меньше 0,25, ответ считался неправильным, если косинусное расстояние было больше 0,25, но меньше 0,8. Таким образом были добавлены «сложные» неправильные ответы в выборку.
MMLU содержит вопросы и по четыре ответа на них, из которых только один является правильным. Изначально, как и ms_marco, набор англоязычный, однако существует его переведенная на русский язык версия. В итоговый набор было собрано примерно по 150 вопросов на каждую из 57 дисциплин, от астрономии до социологии, всего около 15 тысяч вопросов.
Из полученных выше наборов данных было сэмплировано 200 тысяч объектов, сбалансированных по классам. Итоговый набор содержит «грязные» данные, так как при переводе и генерации ответов часть информации теряется или искажается. Для дальнейшего обучения необходимо было произвести разделение на обучающую и валидационную выборки. Для этого я реализовал стратификацию по тематикам и классам, предварительно кластеризовав все вопросы с использованием KMeans алгоритма.
В качестве итоговой модели, к первоначальному адаптеру я добавил ещё один, чтобы обучить его на собранном инструктивном наборе. Для этого я взял модель с 13 миллиардами параметров, 8-битной квантизацией и рангом 16 — она лучше всего показала себя по совокупности метрик в эксперименте по адаптации к русскому языку.
Дообучение длилось 5 эпох, использовался тот же 32-битный оптимайзер AdamW и косинусный lr_scheduler, размер батча также был равен 128, шаг обучения был равен 1e-4.
Для сравнения моделей использовались accuracy и описанная выше бизнес-метрика. Сравним нашу модель с Gigachat’ом и открытыми русскоязычными моделями Saiga и Saiga 2:
Точность | Бизнес-метрика | |
Saiga | 0,6 | 0,24 |
Saiga 2 | 0,39 | 0,13 |
Gigachat | 0,63 | 0,23 |
LLaMA-2-Chat + адаптер для русского языка | 0,57 | 0,26 |
LLaMA-2-Chat + адаптер для русского языка + адаптер под задачу | 0,62 | 0,33 |
Наша модель на узконаправленной̆ задаче показала более высокое качество по бизнес-метрике, чем Gigachat, Saiga и Saiga 2. Данный прирост может быть обусловлен двумя факторами: отсутствием большого числа ошибок парсинга и добавлением данных, максимально приближенных к целевой выборке.
По результату стажировки в Лаборатории искусственного интеллекта мне удалось повысить бизнес-метрику в узконаправленной̆ задаче в сравнении с коммерческими и открытыми аналогами. А добиться этого я смог благодаря обучению серии адаптеров над LLaMA-2-chat для русского языка.
Лаборатория искусственного интеллекта занимается научными исследованиями и разработкой фреймворков для машинного обучения. Все библиотеки Лаборатории выложены в Open Source и доступны в GitHub по ссылке https://github.com/sb-ai-lab.