Самые быстрые алгоритмы распределенного и асинхронного обучения (с точки зрения теории)

Всем привет! Меня зовут Александр Тюрин, я руководитель группы «Методы оптимизации в машинном обучении» в AIRI и старший преподаватель Сколтеха. Мы с коллегами занимаемся оптимизацией распределённого обучения — это довольно актуальная проблема, учитывая, что современные модели обучаются на многих тысячах GPU.

За последние 2 года нам удалось сделать несколько открытий в асинхронных методах оптимизации, которые мы изложили в 5 статьях [1–5] на NeurIPS и ICLR. В этой статье я расскажу, в чём заключаются особенности распределённого обучения и что нового привнесли в него мы с точки зрения теории.

0e6fbaf2096550fa8b9c78bb8ca4746e.png

Введение

Как вы, наверное, знаете, задачи минимизации функций возникают в самых разных теоретических и практических сферах:  линейное программирование в экономике, нахождение оптимальных параметров при конструировании зданий, поиск самого быстрого маршрута и многих других.

Машинное и глубокое обучение не является исключением. Почти все современные задачи, связанные с обучением больших языковых моделей, моделей для распознания видео и фото, в конце концов, сводятся к следующей математической проблеме:

\min_{x \in \mathbb{R}^d } \left\{f(x) := \frac{1}{m} \sum_{i = 1}^m f(x;i) \right\}, \qquad (1)

где f(x;i) — функция потерь по одному i-му сэмплу набора данных размера m. Здесь мы хотим найти оптимальные веса x^* обучаемой модели.

Двумерная визуализация проблемы (1). В общем случае размерность  может быть очень и очень большой.
Двумерная визуализация проблемы (1). В общем случае размерность d может быть очень и очень большой.
Детальное пояснение

С точки зрения оптимизации данная формула означает, что мы хотим найти точку x^* \in \mathbb{R}^d, которая минимизирует функцию f. Функция f — это среднее функций f(x;i). Эквивалентно, с точки зрения машиного обучения мы хотим найти веса модели x^* \in \mathbb{R}^d размера d, которые минимизируют (1), где f(x;i) — это функция потерь, соответствующая i‑му элементу набора данных. Например, у нас есть набор картинок c классами \{a_i, y_i\}_{i = 1}^m и мы решаем задачу классификации с помощью логистического лосса и некоторой свёрточной нейронной сети. Тогдаf(x;i) = \log(1 + \exp(- y_i \times \text{CNN}(a_i))).

На практике количество данных m может быть очень большим — и даже бесконечным (например, в связи с аугментацией), поэтому мы записываем (1)через математическое ожидание:

\min_{x \in \mathbb{R}^d } \left\{f(x) := \mathbb{E}_{\xi} [f(x;\xi)] \right\}, \qquad (2)

где \xi — это некоторая случайная величина, отвечающая за распределение набора данных.(1)эквивалентно (2), когда \xi имеет дискретное равномерное распределение. Далее мы сфокусируемся на том, как можно решать (2) с помощью численных методов оптимизации.

Стохастический градиентный спуск (SGD)

На самом деле, сейчас ничего не придумали лучше, чем использовать алгоритм стохастического градиентного спуска (SGD), изобретенный еще Робинсом и Монро в 1951 году. SGD — это итеративный метод. На каждом шаге он сэмлируем одну случайную величину \xi^k (картинки из ImageNet / последовательность токенов), считает стохастический градиент с помощью алгоритма backpropagation и делает шаг стохастического градиентного спуска:

x^{k+1} = x^k - \gamma\nabla f(x^k;\xi^k). \qquad (\text{SGD})

Данный алгоритм является краеугольным камнем в задачах обучения. Например, библиотека PyTorch совершает сэмплирование с помощью torch.utils.data.DataLoader, а шаги SGD с помощью torch.optim.SGD.

Визуализация SGD
Визуализация SGD
О других алгоритмах

И да, конечно, есть более практические модификации SGD типа Adam, AdamW или Schedule Free. Но, во‑первых, с теоретической точки зрения их преимущество ещё не доказано и является открытым вопросом. Во‑вторых, все ниже сказанное может быть применимо для них. Для простоты мы фокусируемся на SGD.

Теоретические основы SGD

SGD давно зарекомендовал себя на практике, но для нас не менее важно подкрепить этот алгоритм теорией, которая бы объясняла успех SGD. С теоретической точки зрения этот алгоритм довольно неплохо изучен: существуют, без преувеличения, тысячи статей и десятки книг об SGD (могу посоветовать [6] для начального ознакомления).

Чтобы анализировать поведение SGD на задачеf нам надо ввести два предположения о функции f и о дисперсии шума, возникающего при сэмлировании \xi и подсчете стохастического градиента\nabla f(x;\xi):

  • Функция f является L-гладкой:

    \|\nabla f(x) - \nabla f(y)\| \leq L \|x - y\|.
  • Стохастический градиент \nabla f(x;\xi)несмещенный и имеет ограниченную дисперсию:

    \mathbb{E}[\nabla f(x;\xi)] = \nabla f(x), \qquad \mathbb{E}[\|\nabla f(x;\xi) - \nabla f(x)\|^2] \leq \sigma^2

У нас возникают две константы L и \sigma^2, которые характеризуют проблему. L отвечает за сложность функции f, a \sigma^2— это дисперсия стохастического градиента, которая характеризует разброс собранного набора данных. Интуитивно, чем больше наши данные имеют дисперсию, тем сложнее решать оптимизационную задачу (2).

Мы можем доказать следующую теорему:

Теорема 1. Если функция f является L‑гладкой, и стохастиеческий градиент \nabla f(x;\xi) несмещенный и имеет ограниченную дисперсию, то SGD находит решение с точностью \varepsilon не более чем за 

\qquad\quad\qquad\qquad\qquad\qquad K_{\text{SGD}} = \mathcal{O}\left(\frac{L}{\varepsilon} + \frac{L \sigma^2}{\varepsilon^2}\right)

итераций.

Таким образом, у нас есть теоретические гарантии того, что нам достаточнопосчитать K_{\text{SGD}} стохастических градиентов и сделать K_{\text{SGD}} итераций x^{k+1} = x^k - \gamma\nabla f(x^k;\xi^k), чтобы найти точку, близкую к решению. Параметр \varepsilon отвечает за то, насколько близко мы хотим приблизиться к нему. По формуле для K_{\text{SGD}} видно, что, чем больше шум \sigma^2 и меньше \varepsilon, тем больше потребуется итераций.

Может возникнуть вопрос, а насколько K_{\text{SGD}} хорошая сложность? Можно ли придумать какой‑то ещё метод, имеющий лучшие гарантии?

Чтобы на него ответить, давайте отойдем ненадолго в сторону и вспомним проблему сортировки n элементов в массиве. Для этой задачи имеется множество алгоритмов, включая quicksort и mergesort. Хорошо известно, что mergesort гарантирует решить задачу за \mathcal{O}(n \log n) сравнений. При этом существуют нижние оценки, показывающие, что \Omega(n \log n) нельзя улучшить в худшем случае никаким другим алгоритмом.

Возвращаясь к SGD и K_{\text{SGD}}, можно точно так же доказать, что сложность K_{\text{SGD}}не улучшаема: SGD — это оптимальный метод.

Распределенная оптимизация

Кажется, проблема решена, так как мы знаем, что SGD оптимален, и сложность K_{\text{SGD}}не может быть улучшена. Так оно и есть, но здесь есть важный нюанс. Мы предполагаем, что за раз, в один момент времени мы можем считать только один стохастический градиент. Этот результат полезен в случае, когда мы запускаем алгоритм на одном компьютере с одной доступной GPU/CPU. Но современные модели обучаются на многих тысячах GPU.

Далее мы будем предполагать, что у нас есть n устройств (это могут быть GPU, CPU, или серверы), и они могут делать подсчеты стохастических градиентов параллельно.

Minibatch SGD

Давайте начнем с самого простого метода — Minibatch SGD. В отличие от SGD, на каждом шаге каждое устройство сэмлирует одну случайную величину \xi^k_i и считает стохастический градиент. Далее эти градиенты агрегируются (например, с помощью all‑reduce), чтобы сделать шаг

x^{k+1} = x^k - \gamma \frac{1}{n} \sum_{i=1}^n \nabla f(x^k;\xi^k_i). \qquad (\text{Minibatch SGD})

В первом приближении так работает torch.nn.parallel.DistributedDataParallel в PyTorch. Для этого метода можно доказать, что он сходится за 

K_{\text{MB}} = \mathcal{O}\left(\frac{L}{\varepsilon} + \frac{L \sigma^2}{n \varepsilon^2}\right)

итераций. Обратите внимание, что последний член в формуле в n раз меньше, чем K_{\text{SGD}}. Это улучшение ожидаемо в связи с тем, что задачу решают n устройств вместо одного. Хорошо, а есть ли другие подходы и можно ли сделать что‑то лучше?

Схема работы Minibatch SGD. Обратите внимание, что часть устройств простаивает, пока ждёт самого медленного.
Схема работы Minibatch SGD. Обратите внимание, что часть устройств простаивает, пока ждёт самого медленного.

Asynchronous SGD

В Minibatch SGD необходимо синхронизировать все устройства, чтобы сделать шаг. Идея: давайте не дожидаться других устройств и сразу делать шаг, как только когда хотя бы одно устройство посчитает стохастический градиент. Ниже представлена имплементация этой идеи:

Алгоритм работы Asynchronous SGD
Алгоритм работы Asynchronous SGD

В этом методе стоит выделить шаг

x^{k+1} = x^k - \gamma_k \nabla f(x^{k - \delta^k};\xi^{k - \delta^k}_i).

Обратите внимание, что из‑за асинхронности подсчета стохастических градиентов мы берем градиент в точке x^{k-\delta^k} вместо x^k, где \delta^k — это задержка вычислений.

Схема работы Asynchronous SGD
Схема работы Asynchronous SGD

Сравнивая схему работы Minibatch SGD и Asynchronous SGD, видно, что Asynchronous SGD лучше всего использует устройства, так как он не ждет медленных вычислений, и не возникают «дырки» из‑за синхронизаций.

Формализация временной сложности

SGD — оптимальный метод со сложностью K_{\text{SGD}}, когда мы работаем с одним устройствам. Возникает логичный вопрос: какой метод будет оптимальный при наличии n устройств и чему будет равна соответствующая сложность алгоритма? Мы уже представили Minibatch SGD и Asynchronous SGD. Но можем ли мы придумать более быстрые методы? Можно ли доказать нижнюю оценку сложности? Оказывается, да, можно!

Давайте предположим, что у каждого устройства время вычисления стохастического градиента фиксированное, но разное. Следующее предположение можно обобщить на случай, когда времена вычислений не ограничены фиксированными значениями и могут изменяться произвольным/хаотичным образом во времени [5]. Но для краткости повествования давайте остановимся на более простом случае.

Предполагается, что устройство i выполняет вычисление одного стохастического градиента не более чем за τ_i секунд (без ограничения общности примем, что 0< \tau_1 \leq \tau_2 \leq \cdots \leq \tau_n). Используя это предположение, давайте сравним Minibatch SGD и Asynchronous SGD.

Мы уже знаем, что Minibatch SGD требует K_{\text{MB}} итераций, чтобы найти точку, близкую к решению. Время одной итерации занимает\tau_n = \max \tau_i, так как мы ждем самого медленного устройства. Таким образом, общее время работы равно

T_{\text{MB}} = \tau_n \times K_{\text{MB}} = \mathcal{O}\left(\tau_n  \left(\frac{L}{\varepsilon} + \frac{L \sigma^2}{n \varepsilon^2}\right)\right)

секунд. Такое же упражнение можно проделать с Asynchronous SGD и показать, что его время работы равно

T_{\text{A}} = \mathcal{O}\left(\left(\frac{1}{n} \sum\limits_{i=1}^{n} \frac{1}{\tau_{i}}\right)^{-1} \left(\frac{L}{\varepsilon} + \frac{L \sigma^2}{n \varepsilon^2}\right)\right)

секунд [8,1]. Ранее мы обсуждали, что Asynchronous SGD интуитивно лучше использует устройства. Так вот, полученная сложность T_{\text{A}} формализует это. Легко показать, что T_{\text{A}} \leq T_{\text{MB}}. Более того, T_{\text{A}} может быть сильно меньше, так как оно зависит от \tau_i как среднее гармоническое, в то время как T_\text{MB} зависит от максимума \tau_i. Например, возьмите\tau_n \to \infty, тогда T_{\text{MB}} \to \infty в то время как T_{\text{A}} < \infty.

Оптимальные методы распределенной оптимизации

Оказывается, что T_{\text{A}}— это не предел. В работе [1] мы предлагаем новый метод, Rennala SGD, который может быть представлен следующим образом:

Алгоритм работы Rennala SGD
Алгоритм работы Rennala SGD

Rennala SGD является полуасинхронным и может рассматриваться как Minibatch SGD в сочетании с асинхронным механизмом сбора мини‑батчей. Из‑за условия \delta^{k_b} = 0, которое игнорирует все стохастические градиенты, вычисленные в предыдущих точках, Rennala SGD выполняет следующий шаг:

x^{k+1} = x^{k} - \gamma \frac{1}{B} \sum_{j=1}^{B} \nabla f(x^{k}; \xi^{k_j}).

Обратите внимание, что все устройства вычисляют стохастические градиенты в одной и той же точке x^k, при этом устройство с номером i вычисляет B_i \geq 0 градиентов так, что \sum B_i = B. Можно показать, что сложность работы Rennala SGD равна

T_{\text{R}} = \mathcal{O}\left(\min\limits_{m \in [n]} \left[\left(\frac{1}{m} \sum\limits_{i=1}^{m} \frac{1}{\tau_{i}}\right)^{-1} \left(\frac{L \Delta}{\varepsilon} + \frac{\sigma^2 L \Delta}{m \varepsilon^2}\right)\right]\right).

Более того, мы показали [1,5], что T_{\text{R}}— это оптимальная сложность, которая не может быть улучшена никаким другим алгоритмом. Легко доказать, что T_{\text{R}} \leq T_{\text{A}} \leq T_{\text{MB}}. Кроме этого, T_{\text{R}} может быть сколь угодно меньше.

Совсем недавно [7] мы придумали еще один оптимальный метод, Ringmaster ASGD. Он более похож на Asynchronous SGD. Но в нем есть небольшая, но важная модификация, которая позволила нам сделать полностью асинхронный метод оптимальным.

Учет времени коммуникации между устройствами

Проблема разработки теоретически оптимальных методов с n устройствами, как мне кажется, решена. Но есть много факторов, которые я не учел в повествовании.

Например, когда мы сравниваем методы, мы только учитываем времена \tau_iвычисленийстохастических градиентов. В реальности нам надо учитывать не только время вычислений, но и время передачи информации между устройствами. Чтобы выполнять шаги во всех методах, устройства должны передавать вектора друг другу, что тоже может занимать ненулевое время. В работах [2,4] мы точно также формализовали время передачи, разработали новые методы и доказали их оптимальность.

Заключение

В этой статье я постарался написать введение в текущее теоретическое понимание асинхронных и параллельных методов. Начал с основ: напомнил об SGD и показал его оптимальность, когда мы работаем с одним устройством. Далее, я представил актуальные результаты по разработке асинхронных методов. Используя предположение, что устройства требуют разное время вычислений стохастических градиентов, я сравнил различные подходы (Minibatch SGD, Asynchronous SGD, Rennala SGD) и привел оптимальную временную сложность.

Список литературы

[1]: Tyurin A., Richtárik P. Optimal Time Complexities of Parallel Stochastic Optimization Methods Under a Fixed Computation Model // In Advances in Neural Information Processing Systems 36 (NeurIPS 2023)

[2]: Tyurin A., Pozzi M., Ilin I., Richtárik P. Shadowheart SGD: Distributed Asynchronous SGD with Optimal Time Complexity Under Arbitrary Computation and Communication Heterogeneity // In Advances in Neural Information Processing Systems 37 (NeurIPS 2024)

[3]: Tyurin A., Gruntkowska K., Richtárik P. Freya PAGE: First Optimal Time Complexity for Large-Scale Nonconvex Finite-Sum Optimization with Heterogeneous Asynchronous Computations // In Advances in Neural Information Processing Systems 37 (NeurIPS 2024)

[4]: Tyurin A., Richtárik P. On the Optimal Time Complexities in Decentralized Stochastic Asynchronous Optimization // In Advances in Neural Information Processing Systems 37 (NeurIPS 2024)

[5]: Tyurin A. Tight Time Complexities in Parallel Stochastic Optimization with Arbitrary Computation Dynamics // In International Conference on Learning Representations. 2025. (ICLR 2025)

[6]:  Bubeck S. Convex optimization: Algorithms and complexity. // Foundations and Trends® in Machine Learning 8.3–4 (2015): 231–357.

[7]: Maranjyan A., Tyurin A., Richtárik P. Ringmaster ASGD: The First Asynchronous SGD with Optimal Time Complexity // arXiv:2501.16168

[8]: Mishchenko K., Bach F., Even M., Woodworth B. Asynchronous SGD Beats Minibatch SGD Under Arbitrary Delays // In Advances in Neural Information Processing Systems 35 (NeurIPS 2022)

© Habrahabr.ru