Рекурретные нейронные сети наносят ответный удар
Рекуррентные нейронные сети (RNN), а также ее наследники такие, как LSTM и GRU, когда-то были основными инструментами для работы с последовательными данными. Однако в последние годы они были почти полностью вытеснены трансформерами (восхождение Attention is all you need), которые стали доминировать в областях от обработки естественного языка до компьютерного зрения. В статье »Were RNNs All We Needed?» коллектив авторов Лио Фэн, Фредерик Танг, Мохамед Осама Ахмед, Йошуа Бенжио и Хоссейн Хаджимирсадегх пересматривают потенциал RNN, адаптируя её под параллельные вычисления. Рассмотрим детальнее, в чем же они добились успеха.
Почему стоит цепляться за RNN?
Для начала стоит отметить, что рекуррентные сети обладают важным преимуществом: их требования к памяти линейны относительно длины последовательности на этапе обучения и остаются постоянными во время инференса. В противоположность этому, трансформеры имеют квадратичную сложность по памяти при обучении и линейную во время инференса, что особенно ощутимо на больших последовательностях данных. Это делает RNN более эффективными с точки зрения ресурсов при решении задач с длинными последовательностями.
Однако основным недостатком классических RNN было отсутствие возможности параллелизации обучения. Алгоритм backpropagation through time (BPTT) выполняется последовательно, что делает обучение на длинных последовательностях очень медленным. Именно это ограничение дало преимущество трансформерам, которые могут обучаться параллельно и, несмотря на более высокие требования к вычислительным ресурсам, значительно ускорили процесс обучения.
В чем состоит ключевое изменение?
В последние годы появилось несколько попыток устранить это ограничение RNN, такие как архитектуры LRU, Griffin, RWKV, Mamba и другие. Их объединяет использование алгоритма parallel prefix scan.
Устранив зависимости скрытого состояния от входных данных, забывания и обновления, мы позволяем LSTM и GRU больше не нуждаться в обучении через BPTT. Они могут эффективно обучаться с использованием вышеуказанного алгоритма. На основе этого подхода авторы упростили архитектуры LSTM и GRU, устранив ограничения на диапазон их выходных значений (например, избавившись от функции активации tanh) и обеспечив независимость выходных сигналов от времени по масштабу. Эти изменения привели к созданию «облегченных» версий (minLSTM и minGRU), которые используют значительно меньше параметров по сравнению с традиционными вариантами и могут обучаться параллельно.
Архитектуры minLSTM и minGRU
Для начала напишем оценки количества параметров в предыдущих моделях. Если размер скрытого состояния, то , а . Напомним, что улучшение произошло за счет того, что GRU использовало два типа ворот (gated) против трёх у LSTM, и обновляло скрытый слой напрямую, когда у LSTM было два состояние — ячейки и скрытое.
Рис. 1. Схема minGRU (стр. 4)
Они устранили зависимость update gate и скрытого состояния от предыдущего значения (z_t), полностью исключили reset gate и убрали нелинейную функцию активации tanh, т.к. уже избавлены от зависимости от скрытого слоя. Так они добились использования всего лишь параметров.
Рис. 2. Схема minLSTM (стр. 5)
В случае с LSTM изменения коснулись зависимости input и forget gate от предыдущего скрытого состояния (). Внизу происходит нормализация двух гейтов, и масштаб состояния ячейки LSTM становится независимым от времени. Обеспечив независимость масштаба скрытого состояния от времени, мы также исключаем output gate, который масштабирует скрытое состояние. Без output gate нормализованное скрытое состояние равно состоянию ячейки, что делает наличие как скрытого состояния, так и состояния ячейки избыточным. Таким образом, мы исключаем и состояние ячейки.
В результате «лайт» версия LSTM потребляет меньше параметров () по сравнению с оригинальной архитектурой.
Кратко про сравнение с другими моделями
Минимизированные версии LSTM и GRU показывают впечатляющие результаты: на последовательности в 512 элементов они быстрее оригинальных LSTM и GRU в 235 и 175 раз соответственно. Однако стоит отметить, что такой рост скорости достигается ценой увеличения требований к памяти: minGRU требует на 88% больше памяти, чем классическая GRU (для сравнения Mamba использует на 56% больше чем GRU).
Модели minLSTM и minGRU демонстрируют конкурентоспособные результаты на нескольких задачах. Например, они справились с задачей Selective Copy (авторы взяли ее из статьи про Mamba), в то время как другие конфигурации Mamba, такие как S4 и H3, лишь частично справлялись с задачей.
Рис. 3. Результаты задачи LM (стр. 9)
При проверке на задаче языкового моделирования с использованием nanoGPT (моделирование текста на уровне символов на произведениях Шекспира) minLSTM и minGRU также продемонстрировали отличные результаты, достигая минимального значения функции потерь быстрее, чем трансформеры. MinGRU и minLSTM выходят на оптимум за 575 и 625 шагов обучения соответственно, тогда как трансформеру требуется около 2000+ шагов. Mamba работает чуть хуже, но довольно быстро обучилась — всего лишь за 400 шагов.
Благодарю за внимание. Возможно нас ожидает локальный ренессанс RNN.