Опыт дистилляции моделей распознавания речи

В области распознавания речи, если смотреть на SOTA решения, то заметно, что эти модели хоть и не сопоставимы по размерам с современными языковыми моделями на десятки миллиардов параметров, но инферить и тем более обучать эти модели на локальном устройстве без нормального железа довольно проблематично. А маленькие модели, хоть и летают на устройствах, но похвастаться хорошим качеством распознавания не могут.

Для решения такой пробемы нам может помочь дистилляция, хотя есть и другие техники ускорения обучения и уменьшения затрат памяти типа квантизации, прунинга и др. подходы, но в этой статье речь пойдет именно про неё.

Дистилляция по своей сути представляет процесс переноса «знаний» от модели «учителя» к модели «ученика». И даже не обязательно, чтобы ученик был меньше учителя. Более подробно и формально про дистилляцию можно почитать здесь. В этой же статье по большей части будет рассказана небольшая история моего проекта, в которой очень важное место играла дистилляция. Если очень вкратце про дистилляцию, то в результате её мы хотим, чтобы наш ученик выучил то же распределение, что и учитель, чтобы он учился подражать учителю, выучивая то, что самостоятельно выучить бы не смог. Большому учителю проще выучивать сложные и не очень зависимости, а затем мы эти выученные зависимости передаём ученику.

В ходе нашего проекта дистилляцию мы использовали аж два раза, в первый раз мы обучали большую модель учителя, для переноса знаний от другой большой модели. Во второй уже модель ученика была действительно маленькой моделью (меньше в 161 раз, и которая работатет приблизительно в 500 раз быстрее).

Дистилляция 1

Какая была задача:

Нам нужно было иметь отличное качество распознавание на своём датасете, и не просесть по метрикам WER и CERна датасете общего назначения (в качестве таких можно взять Common Voice или Golos.

Изначально мы имели довольно хорошую модель для распознавания речи, которую можно взять в открытом доступе и с которой не нужно было устраивать танцы с бубном, чтобы инферить или дообучать её. В качестве одной из таких моделей можно взять например эту. Мы использовали модель типа Wav2Vec2 . Ей довольно несложно дообучать, в сравнении с тем же виспером и к тому же у нас на начало проекта уже был неплохо обученный экземпляр такой модели)

Чтобы получить модель, которая хорошо игнорирует шумы, нами был написал аугментатор с тремя типами шумов: речеподобные, домашние шумы и звуки животных. Аугментатор можно найти по ссылке. Шумы накладывались в 2 режимах, либо шум накладывался в случайное место на аудио файле, либо он повторялся во время всего аудио сигнала. Чтобы понять на сколько запись зашумлена и на сколько её нужно зашумлять мы использовали SNR

snr = 10 \log_{10}\frac{P_{signal}^2}{P_{noise}^2}

где P_{signal}— мощность чистого звукового сигнала, аP_{noise}— мощность шума. Чем ниже эта величина, тем более зашумлен сигнал. Нулевое значение означает, что мощность сигнала равна мощности шума и если это речеподобный шум, то вне контекста понять, где шум, а где речь нельзя.

На первых этапах у нас возникала проблема, что нельзя просто так взять и дообучить модель на новых данных с шумами. Модель просто не могла ничего выучить, мы посчитали, что это из-за того, что мы заставляем модель прыгать с места в карьер давая ей слишком сложные сигналы, или если она выучивала что-то на наших данных, то качество на датасете общего назначения падало. Поэтому нами было принято решение попробовать дистилляцию.

В качестве учителя и ученика изначально была взята одна и та же модель. Только модель учителя мы зафиксировали, а модель ученика обучали. В качестве distillation_loss мы взяли KLDivLoss. Часто в качестве бейзлайна используют MSE лосс для дистилляции, но мы решили сразу использовать дивергенцию Кульбака Лейблера, т.к. те минусы, которые есть у этой функции потерь нам либо не страшны, либо не интересуют. Распределения у обоих моделей находятся близко друг к другу (в начале так они вообще совпадают), а отсутствие св-ва симметричности нам не вредит, ну и на практике KLDivLoss показывает себя лучше)

Ну и в качестве таргетной лосс функции мы взяли CTCLoss, который хорошо знаком всем кто занимается распознаванием речи. Суть её в том, что мы после того как получили наши логиты (распределения символов для каждого момента времени) матчим таргетную последовательность, фиксированной длинны со всеми возможными последовательностями (они могут иметь разные длины), которые могут получиться из логитов. Также к слову, мы пробовали использовать подход с выравниванием таргетов (force alignment), чтобы для каждого фрейма по времени у нас была метка класса (символа), и рассматривали задачу классификации используя CrossEntropyLoss. Но потом поняли, что этот подход был некорректен из-за того, что мы фиксировали только одну цепочку и говорили модели предсказывать именно её, лишая модель возможности пропускать градиенты по другим цепочкам.

Следующий шаг включал в себя подготовку и подачу данных в модель. Изначальная модель (или модель учителя, на всех этапах) хорошо распознавала чистую речь, но с шумами справлялась крайне плохо. Метрика WER на чистых сигналах была 20%, а на зашумленных она возрастала до 60%, что было, мягко говоря, неприемлимо. Главной фичей, которую мы провернули, чтобы модель ученика все таки научилась распознавать шумы, был подход, когда мы на модель учителя подаем чистые данные, а на модель ученика точно такие же данные, но с наложенными поверх шумами. Таким образом мы заставляем модель выучивать распределение чистых данных, игнорируя шумы.

baf4f0291f11a60767710a04a049b628.png

Но и это ещё не всё. Нами была выдвинута гипотеза о том, что модель должна учиться поэтапно, переходить от более простых сигналов к более сложным. Конечно, определение что такое сложность аудио сигнала — тема для отдельного обсуждения, но мы пошли самым простым путём и меняли значение SNR. Начинали с 30 и постепенно уменьшали этот показатель, тем самым увеличивая мощность шума по отношению к чистому сигналу. И результат не заставил себя долго ждать. При таком обучении нам удалось достичь следующих результатов:

SNR

WER

CER

base_dataset

15

21%

15%

specialized_dataset

15

11.99%

3.18%

0

22%

6%

Т.е. нам удалось приблизиться к показателям изначальной модели на общем домене, но на очень шумных данных. Нас этот результат польностью устроил и мы пошли дальше. А дельше нам нужно было эту модель как-то уменьшить? Как её можно уменьшить? Конечно ещё одной дистилляцией!

Дистилляция 2

После того, как нами была получена мощная и вообще самая класная (лично по нашему мнению) модель, мы приступили ко второй части. В качестве ученика уже брать модель типа wav2vec2 не было возможности. Она не подходила нам ни по вычислительным соображениями, ни по архитекуре, т.к. мы хотелм в будущем допилить стриминговый инференс. А для стриминга есть модели и по лучше, например конформер!

  1. Доступна из коробки в торче.

  2. Модель типа трансформер.

  3. Может быть очень маленькой.

  4. Легко обучать. (для меня, наверное, самый важный пункт)

С точки зрения дистилляции обучение конформера выглядело очень похоже на обучение учителя другим учителем.

  1. Также чистый сигнал подаем на учителя (решили повторить этот мув, хотя модель учителя уже достаточно хорошо справляется с шумами), а ученик получает зашумленный

  2. Постепенно уменьшали SNR.

  3. В качестве distillation_loss также брали KLDivLoss.

В результате обучения мы получили следующие результаты:

CER

WER

specialized_dataset

6.7%

29%

specialized_dataset_with_noises

3.5%

15%

Что сопоставимо с результатами, которые были получены большой моделью на базе wav2vec2. Победа!

Заключение и перспективы

По итогу, изначально имели необученную модель на 315 млн. параметров весом в 1.2 гб, а получили маленькую, быструю модель весом 8 мб, которая полностью устраивает нас по качеству.

Возможно, у вас возникли вопросы о том, как мы можем говорить, что наша модель такая крутая, если мы её толком не протестировали, но я вам отвечу, что согласен с вами). Наша главная задача была получить не универсальную модель, устойчивую к сдвигу данных, а спроецировать те знания, которые мы считаем важными в большой модели в маленькую модель, чтобы выиграть по скорости и качеству. И это у нас очень даже получилось. Остается ещё очень много вопросов, касательно различных гипотез, которые мы выдвигали в ходе экспериментов, которые мы обязательно проверим, но уже не в этот раз. А что касается дистилляции, то это очень крутая техника, в которой и по сей день остается куча не решенных вопросов. Но, лично для меня, самое главное в ней, что её можно попробовать не только для исследовательских целях, но для практических нужд.

P.S. Спасибо, что дочитали до конца! Это моя первая статья на хабре, поэтому если у вас будут любые вопросы или пожелания, то я буду рад ответить на них в комментариях. Изначально хотел написать рассуждение почему и зачем мы что-то делали, но потом понял, что очень много гипотез мы ставили во время обучения, которые должным образом не были проверены и рассуждать на тему отчего это хорошо работает я не стал, расчитываю на ваше понимание!

© Habrahabr.ru