[Перевод] Быстрая свёртка множеств (алгоритм)

Эту статью меня вдохновила написать задача с codeforces. В статье будет разобран алгоритм для решения задачи.

Даны f : 2^S \rightarrow R, g : 2^S \rightarrow R (пояснение) нужно найти q : 2^S \rightarrow R такую что:

q(s) = \sum_{s' \subset s} f(s') * g(s \setminus s')

За время O(N ^2 2^N)где N = |S|

Но для решения нам нужно будет построить небольшую математическую теорию))

Вводная задача

Дано множество Sразмера Nи функция f : 2^S \rightarrow R. Нужно для \forall s \subset Sнайти:

\hat f (s) = \sum_{s' \subset s} f(s)

Дальше для обозначения множество будем использовать int mask в его битном представление. Например у нас множество состоит из 20 элементов, то тогда первые 20 бит в mask будут обозначать есть ли этот элемент в множестве или нет в зависимости от 0 и 1.

Пример: mask = 0010010 это значит что s состоит из 2 и 5 элементов из S.

Тривиальный алгоритм

for (int mask = 0; mask < (1 << N); ++ mask) {
  hat_f[mask] = 0;
  for (int sub = 0; sub < (1 << N); ++ sub) {
    // проверяем является ли mask2 подмножеством mask1
    if (sub | mask == mask) {
      hat_f[mask] += f[sub];
    }
  }
}

Тут мы для каждого множества перебираем остальные множества и если оно оказалось подмножеством, то добавляем функцию к результату.

Работает за O(4^N)

Простой алгоритм

for (int mask = 0; mask < (1 << N); ++ mask) {
  hat_f[mask] = f[0];
  // сразу перебираем все подмножества mask1
  for (int sub = mask; sub > 0; (sub - 1) & mask) {
      hat_f[mask] += f[sub];
  }
}

Тут мы сразу перебираем все подмножества, без проверки лишних.

Асимптотика O(\sum_{s \subset S} \sum_{s' \subset s}1) = O(\sum_i^N C^i_N*2^i) = [\text{бином Ньютона}] = O(3^N)

Быстрый алгоритм

Давайте введем вспомогательную функцию:

S(mask, i) = \{x \subset mask: x \oplus mask < 2^i \}

Мы смотрим на подмножества которые отличаются от mask только в первых i битах.

Пример: S(\underline{101}01010, 5) = \{\underline{101}01010, \underline{101}01000, \underline{101}00010, \underline{101}00000\} тут мы зафиксировали последние 3 бита и перебираем подмножества по первым 5 битам.

Теперь заметим одно замечательное свойство:

S(mask, i) = \begin{cases} S(mask, i - 1)  & \quad \text{если  }i \text{ бит равен } 0 \\ S(mask, i - 1) \cup S(mask \oplus 2^{i-1}, i - 1) & \quad \text{если  }i \text{ бит равен } 1\end{cases}Схема для понимания вычислений
Схема для понимания вычислений

И тогда, получаем такую динамику:

for(int mask = 0; mask < (1 << N); ++ mask) {
  dp[mask][0] = f[mask];
  for (int i = 1; i <= N; ++ i) {
    if (mask & (1 << i - 1)) {
      dp[mask][i] = dp[mask][i - 1] + dp[mask^(1 << i - 1)][i - 1];   
    } else {
      dp[mask][i] = dp[mask][i - 1];
    }
  }
  hat_f[mask] = dp[mask][N];
}

Можно написать более просто:

for(int mask = 0; mask < (1 << N); ++ mask) 
  hat_f[mask] = f[mask];
for (int i = 0; i < N; ++ i) 
  for (int mask = 0; mask < (1 << N); ++ mask)
    if (mask & (1 << i))
      hat_f[mask] += hat_f[mask ^ (1 << i)];

Асимптотика O(N2^N)

Свертка множеств

Наша полученная функция \hat{f} называется сверткой (Zeta Transform) и обозначается z[f]. Но для того чтобы она была полноценной, нужна и развертка)

Рассмотрим трансформацию Мёбиуса (Mobius Transform):

\mu[f](s) = \sum_{s' \subset s} (-1)^{|s \setminus s'|} f(s')
трансформация Мёбиуса через Зетта трансформацию

Пусть \sigma [f] (s)= (-1)^{|s|}f(s).

Тогда \sigma z \sigma[f]  = \sigma[z[\sigma[f]]] = \mu[f]

Доказательство:

\sigma z \sigma[f](s) = (-1)^{|s|}\sum_{s' \subset s} (-1)^{|s'|}f(s') = \\ [\text{рассматривая разные четности } |s| \text{ и } |s'| \text{ увидим, что }(-1)^{|s'| + |s|}=(-1)^{|s\setminus s'|}] \\ = \sum_{s' \subset s} (-1)^{|s\setminus s'|}f(s') = \mu [f](s)
Код для трансформации Мёбиуса
// сразу заполняем с действием sigma
for (uint mask = 0; mask < (1 << N); ++ mask) {
  hat_f[mask] = ((popcount(mask) & 1) ? -1 : 1) * f[mask];
}
// z свертка
for (int i = 0; i < N; ++ i) 
  for (int mask = 0; mask < (1 << N); ++ mask)
    if (mask & (1 << i))
      hat_f[mask] += hat_f[mask ^ (1 << i)];
// sigma преобразование
for (uint mask = 0; mask < (1 << N); ++ mask) {
  hat_f[mask] *= (popcount(mask) & 1) ? -1 : 1;
}

popcount() считает количество 1ных бит в числе для с++20.

__builtin_popcount() тоже самое, но для более ранних версий с++.

Для неё выполняется\mu [z [f]] = f = z[\mu[f]], что как раз и является обратным преобразованием : D

Доказательство
\mu [z[f]](s) = \sum_{s' \subset s} (-1)^{|s\setminus s'|} \sum_{s'' \subset s'} f(s'') = \\ = \sum_{s'' \subset s} f(s'') \sum_{s'' \subset  s' \subset s} (-1)^{|s \setminus s'|} = \sum_{s'' \subset s} f(s'') \sum_{a \subset (s\setminus s'' )} (-1) ^{|a|} = \\ [\text{заметим что }\sum_{a \subset b}(-1)^{|a|} = 0 \text{ если } |b| > 0 \text { и } 1 \text{ иначе}] \\ = \sum_{s'' \subset s} f (s'') I[|s\setminus s''| = 0] = f (s)» src=«https://habrastorage.org/getpro/habr/upload_files/a73/0e1/eda/a730e1edaa4507d5bcfbcaf26b85a581.svg» /></div></div><p>Это очень напоминает FFT. Пусть у нас даны <img alt= и g:2^N \rightarrow R

Тогда q = FFT^{-1}[FFT[f] * FFT[g]]где * это почленное умножение, это тоже самое, что:

q(s) = \sum_{a, b \\ a + b = s} f(a) * g(b)

А если мы воспользуемся нашей сверткой q = z^{-1}[z[f] * z[g]]:

q(s) = \sum_{a, b \\ a | b = s} f(a) * g(b)

где a|b побитовое or, что эквивалентно a \cup b

Доказательство
q(s) = \sum_{s' \subset s} (-1)^{|s \setminus s'|} (\sum_{a \subset s'} f(a)) * (\sum_{b \subset s'} g(b)) = \\  =\sum_{s' \subset s}  \sum_{a \subset s'} \sum_{b \subset s'} (-1)^{|s \setminus s'|} f(a) g(b)  = \sum_{a \subset s \\ b \subset s} f(a)g(b) \sum_{c \subset (s \ (a \cup b))} (-1)^{|c|}  = \\ = [\text{заметим что }\sum_{a \subset b}(-1)^{|a|} = 0 \text{ если } |b| > 0 \text { и } 1 \text{ иначе}] = \\ = \sum_{a, b \\ a \cup b = s} f (a)g (a)» src=«https://habrastorage.org/getpro/habr/upload_files/7b3/b55/7f2/7b3b557f22ac911fc42601959928f208.svg» /></div></div><p>Подробнее про суммы с операциями <img alt=, or и and можно прочитать на этом сайте.

Возвращаемся к основной задаче

Но всё ещё нет решения на изначальную задачу. Предыдущая формула не подходит, так как для a \cup b = s может быть, что a \cap b \neq \varnothing, а нам нужно чтобы a и b не пересекались.

Но есть решение!

1) Введем новую функцию f_i(s) = \begin{cases} f(s) & \text{если } |s| = i \\ 0 & \text{иначе} \end{cases}

И для g ведем аналогичную g_i(s).

2) Пусть p_i(s) = z^{-1}[\sum_{j = 0}^i z[f_j] * z[g_{i - j}]](s) сумма в смысле почленного суммирования.

3) Тогда q(s) = p_{|s|}(s).

Это и будет решение нашей задачи потому, что:

q(s) = \sum_{s' \subset s} f(s') * g(s \setminus s')
Доказательство

1) Предположим что:

p_i(s) = \sum_{j =0}^i\sum_{a\subset s \\ |a| = j} \sum_{b \subset s \\ |b| =i -j \\ a \cup b = s} f(a)*g(b)

2) Тогда:

p_{|s|}(s) = \sum_{j =0}^{|s|}\sum_{a\subset s \\ |a| = j} \sum_{b \subset s \\ |b| =|s| -j \\ a \cup b = s} f(a)*g(b) =\sum_{s' \subset s} f(s') * g(s \setminus s')

3) Пусть h_i = \sum_{j = 0}^i z[f_j] * z[g_{i - j}], тогда:

h_i(s) = \sum_{j=0}^i (\sum_{a \subset s} f_j(a)) * (\sum_{b \subset s} g_{i-j}(b)) = \\ = \sum_{j=0}^i\sum_{a\subset s \\ b \subset s}f_j(a) * g_{i - j}(b) = \sum_{j=0}^i\sum_{a\subset s \\ b \subset s  \\ |a| = j \\ |b| =i -j}f(a) * g(b) = \\ = \sum_{s' \subset s} \sum^i_{j = 0}  \sum_{a\subset s' \\ b \subset s'  \\ |a| = j \\ |b| =i -j \\ a \cup b = s'}f(a) * g(b) = \sum_{s' \subset s} p_i(s')

4) Выводим что:

p_i(s) = z^{-1}[\sum_{j = 0}^i z[f_j] * z[g_{i - j}]](s) = z^{-1}[h_i](s) = \\ = z^{-1}[\sum_{s' \subset s} p_i(s')] = p_i(s)

Значит p_i действительно подходит

5) и последний шаг из пункта 2.

\mu(s) = p_{|s|}(s) = \sum_{s' \subset s} f(s') * g(s \setminus s')

Что и требовалось доказать.

Пишем код

// заполняем f_i и g_i
for (uint mask = 0; mask < (1 << N); ++ mask) {
  hat_f[popcount(mask)][mask] = f[mask];
  hat_g[popcount(mask)][mask] = g[mask];
}


// применяем z свертку к f_i и g_i
for (int i = 0; i < N; ++ i)
  for (int j = 0; j < N; ++ j) 
    for (int mask = 0; mask < (1 << N); ++ mask)
      if (mask & (1 << j)){
        hat_f[i][mask] += hat_f[i][mask ^ (1 << j)];
        hat_g[i][mask] += hat_g[i][mask ^ (1 << j)];
      }


// делаем внутрение преобразование
for (int i = 0; i < N; ++ i)
  for (int j = 0; j <= i; ++ j)
    for(int mask = 0; mask < (1 << N); ++ mask)
      hat_mu[i][mask] += hat_f[j][mask] * hat_g[i - j][mask];


// применяем преобразование Мёбиуса hat_mu
// sigma преобразование
for (int i = 0; i < N; ++ i)
  for(uint mask = 0; mask < (1 << N); ++ mask)
    hat_mu[i][mask] *= (popcount(mask) & 1) ? -1 : 1;
// z cсвертка
for (int i = 0; i < N; ++ i)
  for (int j = 0; j < N; ++ j)
    for(int mask = 0; mask < (1 << N); ++ mask)
      if (mask & (1 << j))
        hat_mu[i][mask] += hat_mu[i][mask ^ (1 << j)];
// sigma преобразование
for (int i = 0; i < N; ++ i)
  for(uint mask = 0; mask < (1 << N); ++ mask)
    hat_mu[i][mask] *= (popcount(mask) & 1) ? -1 : 1;


// Запись в ответ
for(uint mask = 0; mask < (1 << N); ++ mask)
  mu[mask] = hat_mu[popcount(mask)][mask];

popcount() считает количество 1ных бит в числе для с++20.

__builtin_popcount() тоже самое, но для более ранних версий с++.

Применение

Этот алгоритм очень полезен для работы с множествами. Ну и конечно в олимпиадном программирование иногда (очень редко) встречается)))

Ссылки

Источник для основного алгоритма

Источник для первой задачи

Статья с другими свертками множеств

Быстрое преобразование Фурье на википедии и алгоритмике

Спасибо за то что прочитали, всем хорошего настроения!

© Habrahabr.ru