Большие языковые модели гораздо линейнее, чем мы думали

Хабр, привет! Это снова Антон Разжигаев, аспирант Сколтеха и научный сотрудник лаборатории FusionBrain в Институте AIRI, где мы продолжаем углубляться в изучение языковых моделей. В прошлый раз мы выяснили, что эмбеддинги трансформеров-декодеров сильно анизотропны. На этот раз я бы хотел рассказать об их удивительной линейности, ведь нашу статью про обнаруженный эффект («Your Transformer is Secretly Linear») несколько дней назад приняли на международную конференцию ACL!

Линейность считается свойством самых слабых моделей, ведь они могут решать только простейшие задачи, для которых зачастую и ML-то особо не нужен (см. картинку ниже). Поэтому принято считать, что НЕлинейность — это краеугольный камень сложных вычислений и преобразований внутри больших нейронных сетей, и, в особенности, трансформеров. 

Справа пример задачи, которую линейные модели решить не могут.

Справа пример задачи, которую линейные модели решить не могут.

Однако в нашей последней работе, мы обнаружили, что для больших языковых моделей (LLM) декодеров это совсем не так! Информация от слоя к слою практически не испытывает нелинейных преобразований, а каждый отдельный блок трансформера можно заменить всего лишь на один линейный слой без потери качества! Правда звучит интригующе? Ниже я коротко расскажу про наши главные выводы.

Как вообще оценить линейность отдельного слоя модели?

Для оценки линейности отдельных слоев модели мы использовали метод, который можно назвать обобщением Procrustes Similarity. Это звучит сложно, но суть проще чем кажется: мы берём два набора векторов (выходы двух последовательных слоев), центрируем их (чтобы убрать смещение) и нормализуем. Затем мы ищем такое линейное преобразование, которое минимизирует MSE между этими двумя наборами векторов. Если среднюю ошибку такой аппроксимации вычесть из 1, то получим коэффициент линейности (1 — наивысшая линейность, 0 — полное отсутствие линейности).

На графике ниже видно, что линейность всех LLM близка к 100%. Исключением являются только первый и последний слои, а также самая крошечная модель Pythia-70M.

Степень линейности каждого слоя во всевозможных LLM.

Степень линейности каждого слоя во всевозможных LLM.

Что происходит с линейностью во время претрейна и дообучения?

Мы взяли все открытые языковые модели с опубликованными промежуточными весами и посмотрели, как менялась линейность, усреднённая по слоям от чекпоинта к чекпоинту. 

Похоже, что по мере изначального обучения, у всех этих моделей линейность постепенно падала (то есть их нелинейные свойства проявлялись всё сильнее). При этом во время дообучения на любых задачах линейность всегда растёт, вне зависимости от типа задачи или модели.

Средняя линейность (без учёта residual) по мере обучения моделей.

Средняя линейность (без учёта residual) по мере обучения моделей.

Прирост линейности после файнтюнинга.

Прирост линейности после файнтюнинга.

Мы предполагаем, что это связано с уменьшением обобщающей способности, то есть рост линейности — это то, что приводит к катастрофическому забыванию. Грубо говоря, модель постепенно откидывает старые знания, которые не пригодились во время файнтюнинга (RLHF не исключение).

Регуляризация, усиливающая нелинейность

Раз линейность — это плохо, то, может быть, её можно как-то контролировать при помощи регуляризации? Мы попытались повлиять на степень линейности архитектуры Mistral во время претрейна, испробовали кучу разных вариантов, и ни один из них не работал — языковая модель почему-то сопротивлялась и старалась оставаться максимально линейной. 

Но в какой-то момент мы случайно перепутали знак в регуляризационном лоссе, и всё заработало! Почему-то к уменьшению линейности привёл именно тот лосс, который «стягивал» эмбеддинги с последовательных слоёв друг к другу при помощи косинуса, то есть нам помогла случайная ошибка. 

На первый взгляд, такая регуляризация наоборот должна учить модель «ничего не делать» на каждом отдельном слое, однако языковая модель отреагировала абсолютно противоположным образом. При этом с ростом нелинейных свойств её слоёв подросли и метрики, модель стала лучше решать некоторые задачи, писать более качественный текст, а её эмбеддинги стали более экспрессивными (то есть полезнее для downstream задач).

Эффект косинусной регуляризации на среднюю линейность (без учёта residual).

Эффект косинусной регуляризации на среднюю линейность (без учёта residual).

Косинусная регуляризация приводит к увеличению экспрессивности эмбеддингов. То есть растёт точность linear probing.

Косинусная регуляризация приводит к увеличению экспрессивности эмбеддингов. То есть растёт точность linear probing.

Результаты валидации через GPT-4 на Tiny-Stories. Модель с косинусной реугуляризацией генерирует более связный и качественный текст.

Результаты валидации через GPT-4 на Tiny-Stories. Модель с косинусной реугуляризацией генерирует более связный и качественный текст.

Линейный прунинг

Один из первых вопросов, возникших у нас в голове — если отдельные слои LLM на 99% линейны, то почему бы просто не заменить их на один‑единственный nn.Linear (), выкинув при этом весь этэншн, feed‑forward и тп?

Да, оказалось, что так можно! При этом точность модели и её качество практически не падают, но подменить таким образом можно только небольшое количество слоёв (~15%), а дальше ошибка линейной аппроксимации накапливается, и качество начинает ухудшаться.

Перплексия для OPT-1.3B на WikiText при выкидывании части слоёв из модели, при линейной замене слоёв и при последующем дообучении всех новых линейных слоёв одновременно.

Перплексия для OPT-1.3B на WikiText при выкидывании части слоёв из модели, при линейной замене слоёв и при последующем дообучении всех новых линейных слоёв одновременно.

Перплексия для Llama-2-7B на WikiText при выкидывании части слоёв из модели, при линейной замене слоёв и при последующем дообучении всех новых линейных слоёв одновременно

Перплексия для Llama-2–7B на WikiText при выкидывании части слоёв из модели, при линейной замене слоёв и при последующем дообучении всех новых линейных слоёв одновременно

Заключение

Обнаруженный эффект кажется очень контринтуитивным, он противоречит многим нашим представлениям о глубоком обучении. Откуда такая сильная линейность в, казалось бы, одной из самых мощных и изученных архитектур? Мы точно не знаем, но предполагаем, что это связано с режимом триггеринга фичей. То есть нелинейные свойства «вспыхивают» очень редко, а на большинстве входных токенов модель работает в около‑линейном режиме. Что‑то похожее было обнаружено в статье  Deja Vu где изучали мёртвые нейроны в языковых моделях.

Подписывайтесь на каналы авторов в телеграме AbstractDL, CompleteAI

Habrahabr.ru прочитано 2766 раз