Сжимаем трансформеры: простые, универсальные и прикладные способы cделать их компактными и быстрыми
Сейчас в сфере ML постоянно слышно про невероятные «успехи» трансформеров в разных областях. Но появляется все больше статей о том, что многие из этих успехов мягко говоря надуманы (из недавнего помню статью про пре-тренировку больших CNN в компьютерном зрении, огромную MLP сетку, статью про деконструкцию достижений в сфере трансформеров).
Если очень коротко просуммировать эти статьи — примерно все более менее эффективные нерекуррентные архитектуры на схожих вычислительных бюджетах, сценариях и данных будут показывать примерно похожие результаты.
Тем не менее у self-attention
модуля есть ряд плюсов: (i) относительная простота при правильной реализации (ii) простота квантизации (iii) относительная эффективность на коротких (до нескольких сотен элементов) последовательностях и (iv) относительная популярность (но большая часть имплементаций имеет код раздутый раз в 5).
Также есть определенный пласт статей про улучшение именно асимптотических свойств self-attention модуля (например Linformer и его аналоги). Но несмотря на это, если например открыть список пре-тренированных языковых моделей на основе self-attention модулей, то окажется, что «эффективных» моделей там буквально пара штук и они были сделаны довольно давно. Да и последовательности длиннее 500 символов нужны не очень часто (если вы не Google).
Попробуем ответить на вопрос —, а как существенно снизить размер и ускорить self-attention модуль и при этом еще удовлетворить ряду production-ready требований:
- Не нужно сильно менять свой код;
- Не нужно инвестировать много ресурсов в дистилляцию;
- Приросты должны быть кратными, а еще лучше суммарно на порядок;
- При этом качество итоговых должно оставаться примерно сопоставимым;
- Все эти оптимизации должны быть готовыми для продакшена, а не сугубо теоретическими;
- Все подходы должны переноситься между разными доменами;
Тут важно еще сделать оговорку, что проседание качества будет тем сильнее, чем сложнее ваша задача. Например на бинарной классификации ужаться можно сколько угодно (да им может проще использовать более простые методы), а вот на sequence-to-sequence задачах будут моменты.
Простейшие оптимизации
Какое-то время назад в сети проскакивала такая презентация. Если абстрагироваться от ее «космической» (когда я читаю такие материалы, мне кажется что авторы строят башню на луну) академической составляющей, то между строк можно найти такую информацию:
- Self-attention модули о 2 головах показывают примерно такие же результаты как 8-головые модули;
- По умолчанию внутри self-attention модуля активации в 2 раза «шире», чем снаружи, этот параметр тоже можно настраивать;
В целом, вопрос состоит в том, стоит ли сразу тренировать более компактные модели или все-таки дистиллировать, но это зависит уже от вашего конкретного кейса.
Плюсы:
- Заранее убираем лишние параметры и сложность;
- Очень легко имплементировать (меняем сути 2 параметра);
Минусы:
- Модель может дольше сходиться или надо дистиллировать. В теории можно и не тренировать модель, если грамотно выбрать «нужные» головы;
- Скорее всего вам придется либо самим имплементировать модуль, или найти менее ужасную имплементацию (можно начать отсюда или открыть пакет
x-transformers
);
Квантизация
В PyTorch где-то примерно начиная с версии 1.3 завезли динамическую квантизацию Linear
и LSTM
модулей. Не считая подготовительного кода она действительно (без преувеличения) делается в одну строчку кода:
quantized_model = torch.quantization.quantize_dynamic(
model, {torch.nn.Linear}, dtype=torch.qint8
)
Примеры можно посмотреть тут и тут.
Плюсы:
- В отличие от статической квантизации ее очень просто применять и даже для sequence-to-sequence (шутка ли, но мне кажется народ просто не разобрался как ее готовить, и поэтому в PyTorch начали пилить целое экспериментальное АПИ трансформации моделей). Квантизованные модели имеют почти такое же качество, посмотрите примеры например тут (суффикс
_q
означает квантизацию); - Сокращение размера квантизованных модулей в 4 раза (
float32
=>int8
); - Примерно двукратное ускорение на процессорах Intel (с AMD все сложнее, в зависимости от модели они работают от 10% до 70% медленнее).
- Очень легко имплементировать (1 строчка кода);
- Не надо тренировать;
Минусы:
- Квантизованные модели (на момент написания этой статьи) не запускаются на GPU;
- У новых фич, подобной этой, всегда несколько остает downstream поддержка, например при экспорте в ONNX;
- На процессорах AMD скорость моделей может сильно проседать (и без квантизации, от 10% до 70%);
- На момент написания статьи модуль self-attention по умолчанию из PyTorch квантизацию не поддерживает (уже 3 или 4 мажорных релиза!);
- Скорее всего вам придется либо самим имплементировать модуль, см. выше;
Факторизация
Я кратко описывал описывал этот подход в заметке на канале в телеграме. Идея состоит в том, чтобы взять готовую сетку, применить Singular Value Decomposition (который недавно тоже завезли в PyTorch) к матрице весов, заранее выбрав нужный уровень разреженности.
Чтобы не плодить лишние классы, для начала можно элегантно сделать monkey-patching своей модели при ее загрузке, заменив Linear
на FactorizedLinear
модуль.
Плюсы:
- Можно заранее выбрать уровень разреженности и теоретически снизить размер модели до нужного уровня;
- Легко имплементировать (70 строчек кода, но скорее всего можно и короче);
- Я не делал точных замеров, но мне показалось, что прирост скорости был только на GPU;
- SVD надо сделать только один раз;
Минусы:
- Модель точно надо дотюнивать;
- Из всех методов, как мне показалось, этот сильнее всего бьет по метрикам;
- Приведенная по ссылке выше имплементация не квантизуется, будет интересно смогут ли читатели догадаться как устранить этот недостаток;
- Формально этот подход снижает размер модели, но он как бы разбивает одно матричное перемножение на два более маленьких, и прирост скорости не гарантирован;
Прунинг и дистилляция
Я пробовал дистиллировать большие модели напрямую в более маленькие, но особых успехов я не достиг, так как делал этот на весьма экзотических задачах и лоссах. Так что тут не особо могу поделиться успехами.
FNet
Недавно появилась вот такая статья. В ней по сути предложили заменить self-attention
на разложение Фурье. Получается, что размер модели снижается в два раза, скорость на GPU становится меньше чуть ли не на порядок (чего нельзя сказать про CPU). Якобы на задачах авторов потеря качества в районе 10%.
Плюсы:
- В PyTorch есть даже FFT2 модуль для двумерных фич (хотя он просто вызывает два обычных подряд);
- Сетка реально становится в два раза меньше, а не «виртуально»;
- Внезапно это реально работает, вопрос только насколько хорошо;
- Мы избавляемся от самой затратной части self-attention модуля;
Минусы:
- Формально я не проверил еще насколько этот метод сжатия сравним по качеству и как дружит с квантизацией, но если продолжить тренды, то он требует больше ресурсов чем SVD, но точно работает;
- Более ярко-выраженный прирост на скорости на GPU, прирост заметен больше для больших и огромных моделей, для малых — все более скромно;
На десерт — оптимизации самого PyTorch
Все знают про fusion сверток, бнорма и relu. Но в недавней версии 1.9
добавили еще «заморозку» модели и inference mode. Я протестировал их, и при прочих равных inference mode не добавил скорости, а заморозка модели докинула 14% скорости. Тут важно отметить, что все очевидные оптимизации с моделями уже были проделаны, поэтому приросты уже небольшие. Еще важный момент состоит в том, что если вы возьмете медленную модель и примените какой-либо хак из этой статьи, вы получите условно заявленные x2. Но если вы примените много хаков сразу, то в какой-то момент начнет показывать себя закон убывающей полезности.
Плюсы:
- Просто работает из коробки;
- В следующих версиях скорее всего добавятся в бету «народные» АПИ для упрощения моделей;
Минусы:
- Нужно делать апдейт окружений;
- Заморозка не только выкидывает внутреннюю структуру модели, но и запоминает устройство и окружение (вероятно делая внутренние оптимизации), где она была заморожена (возможно стоит морозить модель при запуске);
Примерное сравнение
Качество итоговых моделей
Поскольку все очень зависит от сложности вашей задачи и конкретики, точных цифр приводить не буду, скорее отранжирую способы по качеству и сложности достижения:
Краткий итог
Просуммируем все вышесказанное. Если вынести за скобки подкрутку гипер-параметров самой модели, то методы оптимизации можно разделить на два класса (i) работающие почти из коробки, но с меньшим эффектом (ii) и требующие тюнинга, но суммарно дающие больше результата.
К первому можно отнести заморозку и квантизацию. В сумме они дают приятное уменьшение размера модели (4x) и ускорение в районе 2–3x на CPU.
Ко второму можно отнести факторизацию и FFT. Их стоит рассматривать как некую дополнительную оптимизацию, причем они скорее всего исключают друг друга. В сумме с первым типов методов можно получить суммарное снижение размера модели почти на порядок и ускорение тоже почти на порядок. Если при этом еще подкрутить гипер-параметры модели, то «порядок» в принципе не кажется недостижимым.
Как сделать ускорение на два порядка, если честно я не знаю. Возможно вы знаете?