Умножение Монтгомери
Деление целых чисел — это долго и сложно. Вычислять остаток от деления — нисколько не проще. При этом в спортивном программировании, да и в прикладной математике типа криптографии, задача умножения чисел по модулю встречается повсеместно.
Один из вариантов эффективного решения — умножать по модулю, вообще при этом не используя операции деления, с помощью алгоритма Монтгомери.
Про него я и хотел бы поговорить.
Постановка задачи
Для простоты изложения, я буду приводить алгоритм для 32-битных целых чисел. Про более широкие числа тоже будут определённые заметки, просто без кода.
Итак, положим, что у нас есть нечётное число int N
, по модулю которого необходимо производить вычисления. Например, это может быть знакомое многим число 1_000_000_007
. Алгоритмы для чётных N
можно свести к алгоритмам для нечётных, просто потребуется заметное количество дополнительного кода для контроля младших бит, так что оставим это на другой раз.
Наша цель — иметь эффективный аналог следующего кода:
int mulMod(int a, int b, int N) {
long longA = Integer.toUnsignedLong(a);
long longB = Integer.toUnsignedLong(b);
long longN = Integer.toUnsignedLong(N);
return (int) Long.remainderUnsigned(longA * longB, longN);
}
Код будем писать для беззнаковых 32-битных целых, так интереснее. Для них в Java мы вынуждены использовать знаковый int
и специальные методы. Конвертация в long
же необходима для контроля за переполнением при умножении.
Часто на протяжении статьи я буду передавать 32-битные числа как long
только ради того, чтобы не повторять Integer.toUnsignedLong
множество раз. Это вопрос в основном краткости кода, а не того, что мне нужны беззнаковые числа. Да и конвертация в long
не бесплатная, если что.
Так же у нас есть число R=232, обозначающее разрядность чисел, с которыми мы работаем в коде. Именования N
и R
будут сохранены до конца статьи.
В теории можно брать любое значение R
, большее N
и взаимно-простое с ним. На практике удобно брать такое R
, что умножение по модулю R
эффективно реализуемо в коде. Умножение 32-битных чисел по модулю 232 в Java (да и в других языках) — это просто умножение в int
, но в теории мы могли бы брать и 264, и 2128, и другие разрядности.
Форма Монтгомери
Вместо того, чтобы работать с непосредственно значениями, над которыми необходимо производить вычисления, мы будем использовать так называемую «форму Монтгомери» для этих значений:
От условного a % N
она отличается предварительным умножением на R
. Данная форма является линейным отображением, т.е. её можно спокойно складывать и вычитать:
А с умножением не всё так просто, ведь . Вместо этого имеем следующее:
Другими словами, . Если научиться эффективно вычислять , то убьём двух зайцев одним ударом — и умножать научимся, и сможем преобразовывать результат в первоначальный вид.
Также было бы здорово уметь вычислять саму форму Монтгомери, не прибегая к делению. Об этом мы тоже поговорим, просто чуть позже.
Редукция Монтгомери (REDC)
Вычисление принято называть редукцией. На просторах интернета существует 2 версии данного алгоритма, мало отличающиеся друг от друга.
Параметр я назвал T
для того, чтобы текст был ближе к Википедии. Он может быть равен форме Монтгомери какого-то числа, если нам нужно преобразовать его в исходную форму (вычисление ), либо же он может быть равен произведению двух форм Монтгомери (вычисление ). Т.е. в нашем примере T
— это 64-битное число, т.к. нельзя терять биты переполнения от произведения.
Первым я опишу как раз алгоритм, который приведён в Википедии. Для начала положим, что мы нашли число такое, что , которое в коде я буду обозначать M
. С помощью него нужно выполнить следующую процедуру:
Операции вроде x mod R
и x / R
— тривиальны, и для R=232 представляют собойx & 0xFFFFFFFFL
и x >>> 32
соответственно. Для других разрядностей это тоже будут конъюнкция с маской и сдвиг. Если перевести это на привычный язык программного кода, то получится следующее:
static long redc(long T, long N, int M) {
long m = Integer.toUnsignedLong(((int) T) * M);
long t = (T + m * N) >>> 32;
return t >= N ? t - N : t;
}
Помним, что N
передаётся как long
для того, чтобы каждый раз не приходилось вызывать Integer.toUnsignedLong
. M
же передавать в виде long
необязательно.
Обоснование правильности приведённых формул довольно-таки скучное, при желании его можно найти самостоятельно на той же Википедии. Главное тут вот что:
Именно тот факт, что 0 <= t < 2*N
, позволяет заменить деление сравнением и вычитанием.
Для второй версии алгоритма нам понадобится другое значение M
(), а именно такое, которое соответствует уравнению , без минуса. В этом случае код будет следующим:
static long redc(long T, long N, int M) {
long m = Integer.toUnsignedLong(((int) T) * M);
long t = (T - m * N) >>> 32;
return t < 0 ? t + N : t;
}
Видите разницу? + m * N
заменилось на - m * N
, ведь у m
из-за умножение на M
, условно говоря, «противоположный знак». Кавычки потому, что все числа тут неотрицательные. Кроме t
, конечно — его знак нужен.
Это в свою очередь приводит к тому, что -N < t < N
, и здесь уже деление заменяется на сложение, а не вычитание.
Интуитивно кажется, что второй алгоритм на практике чуть более эффективен, ведь выглядит проще, чем , отличие аж на целый минус!
Как вычислить M
M
— это значение, которое для каждого N
нужно вычислить лишь один раз. Тем не менее, подходы к его вычислению весьма интересны, поэтому я заострю на них особое внимание.
Расширенный алгоритм Евклида
Обычно в интернете для этого предлагают использовать расширенный алгоритм Евклида. Мол «какой‑то алгоритм существует, а дальше сами разбирайтесь». Подробно объяснять его, опять же, долго и очень скучно. Плюс он должен входить в школьную программу по математике, насколько я помню.
Суть здесь в том, чтобы найти представление gcd(a, b) = s * a + t * b
, где gcd
— наибольший общий делитель (greatest common divisor). В нашем случае в качестве a
и b
выступают R
и N
.
По условию gcd(R, N) = 1
. Более того, в 32-битной арифметике умножение на R
всегда даёт 0
, ведь оно эквивалентно сдвигу влево на 32 бита. Учитывая это, алгоритм Евклида фактически позволит нам найти представление 1 = t * N
, т.е. буквально обратное по модулю R
. Приведу код, чтобы не пришлось писать самим:
static int inverseExtendedEuclid(long N) {
long old_r = 1L << 32, r = N;
long old_t = 0, t = 1;
while (r != 0) {
long q = old_r / r;
long tmp0 = old_r;
old_r = r;
r = tmp0 - q * r;
long tmp1 = old_t;
old_t = t;
t = tmp1 - q * t;
// Контроль инвариантов, для понятности.
assert (int) r == (int) t * N;
assert (int) old_r == (int) old_t * N;
}
// Ещё один контроль инвариантов.
assert r == 0; // Условие выхода из цикла, мог бы и не писать.
assert old_r == 1; // Значение gcd(R, N).
return (int) old_t;
}
Малая теорема Ферма
Согласно известной теореме Эйлера, являющейся обобщением малой теоремы Ферма, для любых двух взаимно простых чисел (в нашем случае N
и R
) выполнено:
— это значение функции Эйлера, так же известной под именем totient
, и равно оно количеству целых чисел, меньших R
и при этом взаимно простых с R
. Если — простое число, то , данная формула вполне известна и по ссылке на Википедии можно найти ей объяснение. Привожу я её для того, чтобы мы смогли вычислить , она как раз подходит:
Раз нам известно, что , то легко заметить, что . А значит можно воспользоваться быстрым возведением в степень, заранее зная, что 231–1 состоит из 31-го единичного бита:
static int inverseEuler(int N) {
int M = N;
// 30 потому, что 1-я итерация уже выполнена в момент присваивания выше.
for (int i = 0; i < 30; i++) {
M = M * M * N;
}
return M;
}
Здесь N
можно передать как int
, поскольку знак N
никак не влияет на вычисление 32-битных произведений.
Данный код уже не содержит делений, зато количество умножений в нём достаточно большое — O(log(R))
, 60 штук для 32-битных чисел. Это слишком много.
Ещё можно добавить, что раз , то на него можно спокойно умножать:
При явном возведении в степень эта формула приведёт в двум лишним умножениям. Тем не менее, именно такое представление является базисом для более продвинутых алгоритмов.
Метод Ньютона
Алгоритм берёт своё название от известного метода Ньютона, но дословно его не повторяет.
Суть в том, чтобы построить рекуррентную формулу, сходящуюся к верному решению. Имея изначально формулу , мы можем преобразовать её сперва в , а уже после этого в . На основании этого равенства можно построить следующее рекуррентное соотношение:
Для того, чтобы найти , достаточно всего 5-ти итераций, или если точнее, то итераций, что уже значительно меньше, чем в предыдущем способе. Для 64-битных чисел было бы 6 итераций, а для 128-битных — 7. Код для int
:
static int inverseNewton(int N) {
int M = 2 - N;
// 4 итерации потому, что 1-я итерация уже выполнена в момент присваивания выше.
for (int i = 0; i < 4; i++) {
M = M * (2 - N * M);
}
return M;
}
Ссылку на доказательство предоставлю позже, чтобы вы по ней спойлеров не начитались.
Алгоритм Дюма
Метод Ньютона — хороший, и с точки зрения оценки трудоёмкости — оптимальный. Всё, что пойдёт далее — это точечные улучшения. Первое из них выглядит, на первый взгляд, не интуитивно. Да и на второй тоже:
static int inverseDumas(int N) {
int M = 2 - N;
int y = N - 1;
// 4 итерации потому, что 1-я итерация уже выполнена в момент присваивания выше.
for (int i = 0; i < 4; i++) {
y = y * y;
M = M * (1 + y);
}
return M;
}
Формулы для y
подобраны таким образом, чтобы M
на каждой итерации был точно таким же, как в методе Ньютона, это несложно доказать с помощью математической индукции. Цель такого изменения станет понятна, если вручную развернуть цикл:
static int inverseDumas(int N) {
int M = 2 - N;
int y = N - 1;
y = y * y;
M = M * (1 + y);
y = y * y;
M = M * (1 + y);
y = y * y;
M = M * (1 + y);
y = y * y;
M = M * (1 + y);
return M;
}
В методе Ньютона выражение M = M * (2 - N * M)
обязано быть вычисленным по порядку, ведь результат каждой из операций является операндом для следующей операции. В методе Дюма же есть 2 независимых выражения — M = M * (1 + y)
и следующий за ним y = y * y
, и ваш процессор может вычислять их буквально одновременно, значительно ускоряя весь процесс.
Оптимизация первых итераций
Описывая трудоёмкость, я уже упомянул, что, увеличивая число итераций, можно получить алгоритмы для 64 или 128-битных чисел, к примеру. Аналогично этому, ничего не мешает нам уменьшать число итераций и получать алгоритмы для 16, 8 или даже 4-битных чисел:
static int inverseDumas4(int N) {
int M = 2 - N;
int y = N - 1;
y = y * y;
M = M * (1 + y);
// Можно обрезать лишние биты, если нужно.
// А можно и не обрезать, если не нужно.
return M & 0xF;
}
Это, конечно, забавное свойство, но зачем оно нам? А затем, что для 4-битных чисел есть более эффективные алгоритмы!
Первый такой алгоритм был предложен самим Монтгомери:
int inverseMontgomery4(int N) {
return 3 * N ^ 2;
}
Умножение на 3 можно вообще не считать за умножение, ведь оно эквивалентно выражению N + N + N
, которое наверняка вычисляется более эффективно.
Альтернативный вариант был найден неуказанным автором с помощью брутфорса:
int inverseBruteforce4(int N) {
return (N ^ 2) - 2 * N;
}
Скобки обязательны, поскольку у -
приоритет выше, чем у ^
. Комментарий про умножение тут имеет ещё больше смысла, его ещё и на сдвиг заменить можно.
Оба этих алгоритма доказываются полным перебором. К счастью, существует всего лишь 16 различных 4-битных чисел, так что перебор выходит небольшим.
Если объединить всё сказанное, то получим следующий код:
private static int inverseDumas(int N) {
int M = 3 * N ^ 2;
int y = 1 - N * M;
M = M * (1 + y);
y = y * y;
M = M * (1 + y);
y = y * y;
M = M * (1 + y);
return M;
}
Вот и обещанная ссылка на источник, из которого взяты эти методы, там же можно найти доказательства. Почему в начале y = 1 - N * M
там тоже есть, пересказывать не буду.
Судя по тому, что там сказано, это самый быстрый из известных способов найти , что, на мой взгляд, очень круто!
Как вычислить форму Монтгомери
Забыл сказать важную вещь: если умножение по модулю нужно сделать всего несколько раз, то, может, и алгоритм Монтгомери не нужен, ведь вычисление формы Монтгомери имеет свою цену.
Но вот что ещё интересно: в зависимости от того, для скольки различных чисел нужно найти форму Монтгомери, мы тоже можем использовать для этого разные подходы.
Самый очевидный подход — тупо взять и посчитать:
static long m(int a, long N) {
// a * R % N;
return Long.remainderUnsigned(((long) a) << 32, N);
}
Для данного кода существует один интересный частный случай — вычисление :
static long m1(int N) {
return Integer.remainderUnsigned(-1, N) + 1;
}
Данный код является лишь адаптацией формулы , которая в свою очередь справедлива потому что , а значит
Тут имеется ввиду округление вверх, так же известное как ceil
. Повторяя рассуждения, которые были чуть выше, можно заметить, что это значение совпадает с
Напомню, что 2128–1 — это 128-битное число, состоящее из 128 единиц.
Результат выражения в коде по ссылке называют M
, я повторю это именование, думаю, путаницы не будет. Алгоритмы деления длинных чисел явно выходят за рамки данной статьи, их много и они сложные. Поэтому в данном конкретном случае обойдёмся чем-нибудь простым, например ручным делением с помощью цикла.
static UUID M(long N) {
// Старшие байты деления совпадут с делением старшей половины числа на N.
long msb = Long.divideUnsigned(-1L, N);
// Long.remainderUnsigned(-1L, N);
long r = -1L - N * msb;
// Младшие биты будем накапливать в цикле.
long lsb = 0;
for (int i = 0; i < 64; i++) {
lsb <<= 1;
r = (r << 1) + 1;
if (N < r) {
r -= N;
lsb |= 1;
}
}
// Округление вверх, тот самый ceil.
lsb++;
// Обработка переполнения младших байтов.
if (lsb == 0) msb++;
// Да, использую UUID для 128-битных чисел, и что?
return new UUID(msb, lsb);
}
Неэффективный алгоритм деления не особо страшен, поскольку это буквально единственное деление, которое нам нужно.
Что дальше делать с этим числом, спросите? Подставить его в эту формулу:
Для того, чтобы вычислить эту формулу и ничего не потерять, нам понадобится:
— 64-битное число.
— это 192-битное число, но нам достаточно вычислить младшие 128 бит. Остальные исчезнут при вычислении .
— это тоже 192-битное число, но нам достаточно вычислить старшие 64 бита, остальные будут проигнорированы при делении на 2128.
Из результирующих 64 бит нам в реальности понадобятся только младшие 32. Тем не менее, я раньше уже оговаривал, что беззнаковые 32 битные целые нам проще будет хранить в
long
, чтобы реже вызыватьInteger.toUnsignedLong
.
В коде это будет выглядеть следующим образом:
static long m(long a, long N, UUID M) {
long lsb_M = M.getLeastSignificantBits();
long msb_M = M.getMostSignificantBits();
// A = a * R;
long A = a << 32;
// L = A * M;
long lsb_L = lsb_M * A;
long msb_L = msb_M * A + Math.unsignedMultiplyHigh(lsb_M, A);
// Дальше идёт то, что у Лемира названо "mul128_u64", а именно вычисление
// старших 64 бит произведения 128-битной L и 64-битной N.
long lsb_Bottom = Math.unsignedMultiplyHigh(lsb_L, N);
long lsb_Top = msb_L * N;
long msb_Top = Math.unsignedMultiplyHigh(msb_L, N);
if ((lsb_Bottom & lsb_Top) < 0 || (lsb_Bottom | lsb_Top) < 0 && (lsb_Bottom + lsb_Top) >= 0) {
msb_Top++;
}
return msb_Top & 0xFFFFFFFFL;
}
Данный код требует некоторых пояснений, поскольку он не очень простой.
Во-первых, как вычисляется . Если положить, что — разложение на старшие и младшие байты, то получим
Именно это написано в коде.
Во-вторых, Math.unsignedMultiplyHigh
может показаться чем-то незнакомым. Дело в том, что он был добавлен в Java только в 18-й версии, так что если хотите позапускать код под версией поменьше, то воспользуйтесь, пожалуйста, этой копией реализации:
public static long unsignedMultiplyHigh(long x, long y) {
// Compute via multiplyHigh() to leverage the intrinsic
long result = Math.multiplyHigh(x, y);
result += (y & (x >> 63)); // equivalent to `if (x < 0) result += y;`
result += (x & (y >> 63)); // equivalent to `if (y < 0) result += x;`
return result;
}
В-третьих, что происходит в mul128_u64
. Суть там та же, что и при вычислении , просто гораздо больше движущихся частей. По этой причине формулу я писать не буду, сильно уж большая.
В любом случае, самая загадочная часть данного кода — это условие в if
.
if ((a & b) < 0 || (a | b) < 0 && (a + b) >= 0) ...
Данный код проверяет, случится ли целочисленное переполнение, если вычислить сумму a
и b
как беззнаковое число. Это случается в двух случаях:
Когда оба числа отрицательные, т.е. имеют в качества старшего бита единицы — это левая часть дизъюнкции.
Когда оба числа имеют разный знак, т.е. у них отличаются старшие биты, и сумма чисел имеет в старшем бите
0
— это правая часть дизъюнкции.
Наверняка существует более элегантный способ проверки переполнения, но мне он в голову не пришёл. Если условие выполнено, то сложение чисел приводит к переполнению и выставлению carry
флага, который нужно не забыть прибавить к старшим байтам вычисляемого значения. Это и делается кодом msb_Top++
. Надеюсь, ничего не напутал, тесты у меня вроде проходят.
Удивительно, но этот код действительно быстрее, чем 64-битное деление, причём минимум процентов на 25. Во всяком случае, согласно моим тестам. С трудом верится, но что поделать. Данный алгоритм обобщается и на другие разрядности, вероятно там он тоже должен побеждать деление.
Заключение
В первую очередь, хочу выразить благодарность пользователям @encyclopedist, @YouDontKnowMe и @Ruimteschroot, оставившим комментарии к моей предыдущей статье и побудивших меня разобраться в теме.
Что в итоге. По моему опыту, в задачах спортивного программирования всё зависит от трудоёмкости алгоритма, а не от низкоуровнего тюнинга. Много лет я писал a % N
и не испытывал никаких проблем, укладываясь в нужные лимиты по времени. Большей проблемой для меня был парсинг ввода (java.util.Scanner
— тормозное зло, не вздумайте использовать).
С другой стороны, есть спортивное программирование, в котором люди соревнуются в скорости программ, а не только в корректности, как например тут: https://highload.fun/tasks/10. Здесь уже использование обычного деления не прокатит, так что какая-то ниша у подобных алгоритмов всё-же есть.
Ну и естественно криптография. Возводить длинные числа в степени по модулю во время шифрования — это прямо норма.