Понимаем обычное дерево отрезков
Всем привет! Изучив несколько статей по этой теме, у меня остались вопросы, и некоторые моменты по-прежнему были не понятны, поэтому я решил написать свою, которая, как мне кажется, была бы понятна тем, кто не силен в спортивном программировании. В ней я объясняю, как устроено дерево отрезков. Примеры с кодом будут приведены на языке C++, однако на объяснение это не влияет.
Вступление
Пусть у нас есть задача: поступают запросы двух видов для исходного массива. Первый вид — это замена конкретного элемента на другой. Второй — вычисление суммы/минимума или другой операции на диапазоне. Требуется обработать эти запросы. Дерево отрезков решает эту задачу и позволяет обновить элемент, и дать ответ по диапазону за логарифмическое время O (logN). Сама структура данных строится за линейное время O (N). Затраты по памяти — 4N. Почему это так рассмотрим в конце статьи.
Вводим в курс дела
Сразу скажу, что будем решать задачу для поиска суммы. На концепцию дерева это не влияет. Какие функции могут находиться на месте суммы мы рассматриваем ближе к концу статьи.
Пусть нам дан массив:
Дерево отрезков работает только с массивами, длина которых равна степени двойки, если это не так (как в нашем примере), то мы просто добавляем в конец массива нейтральные элементы (о них рассказано в последней главе), пока его размеры не будут равны степени двойки. Для задачи поиска суммы нейтральный элемент — это 0. Таким образом, наш исходный массив превращается в: [7, 3, 2, 4, 5, 6, 1, 0]. Дерево отрезков будем представлять просто в виде обычного массива, где ровно во второй половине располагаются элементы нашего исходного массива, выше располагаются элементы, которые равны сумме двух своих дочерних элементов и так далее до элемента с индексом 1 (он равен сумме всех элементов исходного массива). Получаем такую древовидную структуру:
Или в виде обычного массива:
Нулевой элемент мы не используем, а начинаем с 1 индекса, так как это удобно при индексировании и при «арифметике» дерева отрезков. Под словом арифметика я понимаю следующее: чтобы перейти из родителя в левый дочерний элемент, нужно умножать индекс на 2, а чтобы перейти в правый дочерний элемент, нужно умножать индекс на 2 и добавлять 1, из этого также следует, что все левые элементы (не считая вершины дерева) имеют четный индекс, а все правые элементы (не считая вершины дерева) имеют нечетный индекс. И еще, чтобы перейти обратно из любого дочернего элемента в родителя, надо целочисленно делить текущий индекс на 2. Именно на этих правилах и строится взаимодействие элементов между собой в дереве отрезков.
В следующих главах я подробно объясняю саму суть дерева отрезков и пишу код. Я привел две реализации: одну, состоящую из полностью рекурсивных функций, другую нет.
Рекурсивная реализация
Построение
Мы начинаем построение с вершины дерева, то есть с корня. Чтобы вычислить текущий элемент, мы должны знать его дочерние элементы, поэтому в рекурсивной форме мы будем спускаться вниз по дереву и вычислять их. Это будет происходить до тех пор, пока мы не опустимся до листьев дерева. Когда мы дойдем до них, то сработает условие конца рекурсии и мы сможем «распутывать» нашу рекурсию уже снизу вверх, ведь у самых нижних элементов дерева нет сыновей — они просто равны соответствующим элементам исходного массива.
Данная функция ничего не возвращает, а только заполняет массив. Во первых, в качестве параметра мы будем передавать ей индекс текущего элемента самого дерева (так как мы начинаем сверху, то этот индекс равен 1). Во-вторых, нам необходимо содержать границы исходного массива, которые охватывает наш текущий элемент дерева (так, как мы начинаем сверху, и текущий элемент охватывает весь исходный массив, то и границы будут с 0 до n-1, где n — размер исходного массива).
Здесь важно обратить внимание на то, что во всех трех функциях мы будем использовать данную связку из трех переменных, где первая переменная — это индекс текущего элемент дерева отрезков, который мы хотим вычислить, а две другие — границы исходного, повторяю исходного массива, которые охватывает этот текущий элемент дерева отрезков.
На гифке показано, как связаны элементы дерева отрезков с оригинальным массивом:
Мы будем вычислять значения сначала в левой части, а затем в правой (то есть сначала будем вычислять значение левого сына, а затем правого). Мы это делаем, так как наш родитель равен сумме двух его дочерних элементов. Таким образом, мы будем делить так до тех пор, пока размер нашего диапазона не будет равен 1, а когда это случится — присваиваем текущему элементу дерева, соответствующее ему значение оригинального массива. Как только мы вычислили листья дерева, то можем подниматься на уровень выше и так далее до самой вершины.
Теперь мы можем написать данную часть кода, где i — индекс, текущего элемента дерева (так как мы начинаем рекурсию с самого верхнего элемента, то i = 1. Помним, что 0 индекс дерева мы не используем), right = 0 и left = n-1 (так как элемент дерева с индексом i = 1 охватывает весь оригинальный массив):
// i - индекс элемента дерева отрезков, который мы хотим сейчас вычислить
// right, left - диапазон оригинального массива, который охватывает элемент
// дерева отрезков с индексом i
void build_tree(int i, int right, int left) {
if (right == left)
segment_tree[i] = input_array[right];
else {
// later
}
}
Осталось сделать нашу функцию рекурсивной и сохранять промежуточные значения дерева. Будем вызывать нашу рекурсивную функцию для двух половин дерева отрезков (будем делить текущий диапазон на два):
int mid = (left + right) / 2;
build_tree(i*2, left, mid);
build_tree(i*2 + 1, mid + 1, right);
Вспоминая арифметику дерева, нам не составит труда вычислить текущий элемент:
segment_tree[i] = segment_tree[i*2] + segment_tree[i*2 + 1];
Построение завершено:
void build_tree(int i, int left, int right) {
if (left == right)
segment_tree[i] = input_array[right];
else {
int mid = (left + right) / 2;
build_tree(i*2, left, mid);
build_tree(i*2 + 1, mid + 1, right);
segment_tree[i] = segment_tree[i*2] + segment_tree[i*2 + 1];
}
}
Обновление элемента
Обновление элемента схоже с построением дерева. Действительно, они отличаются лишь тем, что в построении дерева мы обновляем (вычисляем) каждый элемент нашего дерева, а в обновлении, мы обновляем лишь элементы на конкретной ветке, то есть logN + 1 элементов (каждый элемент на уровне).
У данной функции такая же сигнатура с двумя дополнительными параметрами: индексом заменяемого значения и, собственно, с заменяемым значением. По аналогии с построением, если наш диапазон из left и right захлопнется, то есть его размеры будут равны 1, то мы достигли нужного элемента. Можно его обновлять, подниматься вверх и изменять все, зависящие от него, элементы.
// i - индекс текущего элемента дерева отрезков
// right, left - диапазон исходного массива, который охватывает элемент
// дерева отрезков c индексом i
// update_index - индекс заменяемого элемента исходного массива
// value - заменяемое значение
void update_tree(int i, int left, int right, int update_index, int value) {
if (left == right)
segment_tree[i] = value;
else {
// later
}
}
В else-блоке все практически также, как и в построении. Единственно, мы должны спускаться лишь по ветке, которая зависит от обновляемого значения. Чтобы это сделать будем просто с помощью if-else условия спрашивать в каком диапазоне находится обновляемый элемент и спускаться в ту половину, где он находится.
Получаем полную функцию обновления элемента:
void update_tree(int i, int left, int right, int update_index, int value) {
if (left == right)
segment_tree[i] = value;
else {
int mid = (left + right) / 2;
if (update_index <= mid)
update_tree(i*2, left, mid, update_index, value);
else
update_tree(i*2, mid + 1, right, update_index, value);
segment_tree[i] = segment_tree[i*2] + segment_tree[i*2 + 1];
}
}
Запрос суммы
Итак, мы подошли к вычислению самой суммы. Как и раньше начинать будем с верхнего элемента. Наша функция будет как-обычно содержать 3 параметра, которые были во всех предыдущих функция: индекс текущей вершины дерева и диапазон покрытия этой вершины оригинального массива. В дополнении к этому, нам дается еще диапазон, на котором надо найти сумму.
Получаем такую сигнатуру:
// i - индекс текущего элемента дерева отрезков
// right, left - диапазон исходного массива, который охватывает элемент
// дерева отрезков c индексом i
// input_left, input_right - диапазон исходного массива, на котором необходимо
// найти сумму
int sum_on_range(int i, int left, int right, int input_left, int input_right) {
// later
}
Что мы будем делать? Все просто: если мы начинаем с самого верхнего элемента, то, логично, наш диапазон, на котором требуется найти сумму, может содержаться либо в обеих половинах, либо в одной их них. Поэтому возможны два варианта: либо пойти в одну из половин, либо в обе. Под словами пойти в какую-либо половину я подразумеваю следующее: уменьшаем размер вспомогательного диапазона (left и right) в 2 раза и в зависимости в какую половину пойдем: правую или левую задаем текущему элементу индекс 2 * i или 2 * i + 1. На самом деле проще всего не смотреть в какую половину заходить, а в какую нет, а просто всегда заходить в обе половины, однако добавить условие выхода: если индексы переданного нами диапазона противоречат друг другу, то возвращаем 0 (zero on-english).
if (input_left > input_right)
return 0;
Теперь пропишем удачное завершение рекурсии, а именно, если наш диапазон, принадлежащий к дереву отрезков в точности равен диапазону оригинального массива, то возвращаем текущий элемент дерева.
if (left == input_left && right == input_right)
return segment_tree[i];
Теперь осталось прописать то, как мы будет двигаться вниз по дереву. А тут по анологии с уже прописанными функциями.
int mid = (left + right) / 2;
int left_son = sum_on_range(i*2, left, mid, input_left, std::min(mid, input_right));
int right_son = sum_on_range(i*2 + 1, mid + 1, right, std::max(mid, input_left, input_right);
segment_tree[i] = left_son + right_son;
Тут важно понимать следующий момент: мы берем границы input_left и min (mid, input_right), потому что мы всегда заходим в две половины и если правая из них не содержит диапазона обрубаем ее условием выхода, однако левая половина может содержать диапазон, меньший, чем mid, поэтому, если мы брали бы диапазон input_left и mid, то мы могли посчитать лишние элементы. Тоже самое и с границей max (mid, input_left) и input_right: можем взять лишние элементы.
int sum_on_range(int i, int left, int right, int input_left, int input_right) {
if (input_left > input_right)
return 0;
if (left == input_left && right == input_right)
return segment_tree[i];
int mid = (left + right) / 2;
int left_son = sum_on_range(i*2, left, mid, input_left, std::min(mid, input_right));
int right_son = sum_on_range(i*2 + 1, mid + 1, right, std::max(mid, input_left, input_right);
segment_tree[i] = left_son + right_son;
}
Полная рекурсивная реализация
Так как я использую данную структуру данных в рамках спортивного программирования, то и конечный исходный код будет соответствовать его нормам. Данный код решает задачу https://cses.fi/problemset/task/1648/ .
#include
typedef long long ll;
constexpr ll MAX_N = 1000006;
ll input_array[MAX_N];
ll segment_tree[MAX_N*4];
void build_tree(ll i, ll left, ll right) {
if (right == left)
segment_tree[i] = input_array[right];
else {
ll mid = (right + left) / 2;
build_tree(i*2, left, mid);
build_tree(i*2 + 1, mid + 1, right);
segment_tree[i] = segment_tree[i*2] + segment_tree[i*2 + 1];
}
}
void update_tree(ll i, ll left, ll right, ll update_index, ll value) {
if (right == left)
segment_tree[i] = value;
else {
ll mid = (right + left) / 2;
if (update_index <= mid)
update_tree(i*2, left, mid, update_index, value);
else
update_tree(i*2 + 1, mid + 1, right, update_index, value);
segment_tree[i] = segment_tree[i*2] + segment_tree[i*2 + 1];
}
}
ll sum_on_range(ll i, ll left, ll right, ll input_left, ll input_right) {
if (input_left > input_right)
return 0;
if (left == input_left && right == input_right)
return segment_tree[i];
ll mid = (left + right) / 2;
ll left_son = sum_on_range(i*2, left, mid, input_left, std::min(mid, input_right));
ll right_son = sum_on_range(i*2 + 1, mid + 1, right, std::max(mid + 1, input_left), input_right);
return left_son + right_son;
}
int main() {
// optimization
std::cin.tie(0);
std::ios_base::sync_with_stdio(false);
std::memset(input_array, 0, sizeof(input_array));
ll n, q;
std::cin >> n >> q;
for (ll i = 0; i < n; i++)
std::cin >> input_array[i];
build_tree(1, 0, n - 1);
while (q--) {
ll type;
std::cin >> type;
ll x, y;
std::cin >> x >> y;
if (type == 1)
update_tree(1, 0, n - 1, x - 1, y);
else
std::cout << sum_on_range(1, 0, n - 1, x - 1, y - 1) << '\n';
}
}
Нерекурсивная реализация
Построение
Данная функция интуитивно понятна. Мы вначале заполняем нижний слой (он начинается ровно во второй половине дерева отрезков). Далее идем в обратную сторону и заполняем оставшиеся элементы с помощью арифметики дерева отрезков.
Имеем:
// n - размер исходного массива
void build(int n) {
for (int i = 0; i < n; i++) seg_tree[i + n] = input_arr[i];
for (int i = n - 1; i > 0; i--) seg_tree[i] = seg_tree[i*2] + seg_tree[i*2 + 1];
}
Обновление элемента
По аналогии с рекурсивной функцией данная функция схожа с построением дерева. Мы также обновляем конкретный элемент, а далее идем по элементам, которые зависят от него и пересчитываем их:
void update(int n, int index, int value) {
seg_tree[index + n] = value;
for (index += n; index > 1; index /= 2)
seg_tree[index / 2] = seg_tree[index] + seg_tree[index^1];
}
Очень удобно здесь использовать исключающее или: index^1. Данная запись означает, что если index четный (левый), то index^1 будет нечетным (правым) и наоборот.
Запрос суммы
Идея тут состоит в том, что если правая граница нашего диапазона (в нерекурсивной реализация речь идет только о диапазоне запроса. Вспомогательный диапазон, как в рекурсивной реализации, нам не нужен здесь) относится к правому элементу, то мы поднимаемся вверх по дереву, не прибавляя к ответу наш текущий элемент, потому что родитель этого правого элемента точно содержит наш диапазон. То же можно сказать и про левую границу. Если наша левая граница принадлежит левому элементу, то можно переходить в родителя. Помните как переходить в родительный элемент? Просто делить на два. Следующий случай: если наша правая граница принадлежит к левому элементу, то, логично, наш родитель содержит лишний элемент, поэтому прибавлять к результату родителя нельзя — ответ будет больше. Нам остается прибавить текущий элемент к ответу и перейти ближе к левой границе путем вычитания 1 из текущего индекса (мы переходим на один элемент влево, чтобы поменять родительный элемент, ведь родитель левого элемента будет содержать, как раз, часть или весь диапазон, который мы ищем). Аналогично и с левой границей: если она принадлежит к правому элементу, то его родитель будет содержать лишний элемент, тогда мы прибавляем текущий элемент к ответу и переходим на соседний правый элемент (путем прибавления 1 к индексу), так как родитель соседнего элемента точно содержит либо целый , либо часть оставшегося диапазона.
На вход функция принимает диапазон запроса и размер исходного массива. Так как мы работаем с деревом отрезков, а не с исходным массивом, то, чтобы перейти в листья дерева, мы должны прибавить к соответствующим индексам исходного массива n (r +=n и l += n). Все, что я описал абзацем выше, поместим в цикл while (l <= r) - это будет наше условие выхода. На каждой итерации будем либо прибавлять текущий элемент, либо нет, а в конце всегда переходим в родителя.
Приведем полный код запроса:
// l, r - диапазон суммы
// n - размер исходного массива
void sum(int l, int r, int n) {
r += n; l += n;
ll res = 0;
while (l <= r) {
if (l %= 1) res += seg_tree[l--];
if (r %= 0) res += seg_tree[r++];
l /= 2;
r /= 2;
}
}
Обращаю внимание, что в нерекурсивных функциях мы работаем исключительно с деревом отрезков, не прибегая к помощи исходного массива, как было с рекурсивной реализацией. Именно поэтому, чтобы попасть в самый нижний слой дерева мы прибавляем к полученным границам размер исходного массива (n).
Если вы еще не поняли, как работает запрос суммы, то давайте разберем с вами конкретные примеры:
Если дан такой диапазон, то мы просто прибавляем к ответу наши части и переходим в их родителей (мы видим, что правая граница относится к левому элементу, а левая граница относится к правому элементу). Индексы поменялись местами и l стал больше r, а этого быть не может. На этот случай у нас срабатываем условие while (l <= r).
Cледующий пример с диапазоном размера 1. Тут все совсем просто. Срабатывает одно из условий (левая граница относится к правому элементу), мы прибавляем к ответу текущий элемент. Далее r уходит вверх, а l вбок и тоже вверх, и соответственно срабатывает условие цикла: while (l <= r).
Контрольный пример. Вначале мы поднимаемся вверх, ничего не прибавляя. Далее к ответу добавляем текущий элемент. Затем левую границу выкидывает дальше, а правый переходит вверх и мы заканчиваем условием while (l >= r).
Полная нерекурсивная реализация
Данный код решает задачу https://cses.fi/problemset/task/1648/ .
#include
typedef long long ll;
constexpr ll MAX_N = 1000006;
ll input_arr[MAX_N];
ll seg_tree[4*MAX_N];
ll n;
void build() {
for (int i = 0; i < n; i++) seg_tree[i + n] = input_arr[i];
for (int i = n - 1; i > 0; i--) seg_tree[i] = seg_tree[i*2] + seg_tree[i*2 + 1];
}
void update(ll index, ll value) {
seg_tree[index + n] = value;
for (index += n; index > 1; index /= 2)
seg_tree[index / 2] = seg_tree[index] + seg_tree[index^1];
}
ll sum(ll l, ll r) {
l += n; r += n;
ll res = 0;
while (l <= r) {
if (l % 2 == 1) res += seg_tree[l++];
if (r % 2 == 0) res += seg_tree[r--];
l /= 2; r /= 2;
}
return res;
}
void solve() {
ll q;
std::cin >> n >> q;
for (ll i = 0; i < n; i++)
std::cin >> input_arr[i];
ll N = 0;
while (1 << (N) < n) N++;
n = 1 << (N);
build();
while (q--) {
int a, b, c;
std::cin >> a >> b >> c;
if (a == 1)
update(b-1, c);
else
std::cout << sum(b-1, c-1) << '\n';
}
}
int main() {
// optimization
std::cin.tie(0);
std::ios_base::sync_with_stdio(false);
std::memset(input_arr, 0, sizeof(input_arr));
solve();
return 0;
}
Затраты по времени и памяти
Начнем с затрат по памяти. В лучшем случае они будут равны 2N. Лучший случай — это когда наш размер является степенью двойки. Почему 2N? Так как наш исходный массив имеет размеры n и дополнительные элементы c 0 элементом имеют размеры n. Однако если размер не является степенью двойки, тогда нам придется добавлять элементы до степени двойки (наш массив может увеличиться почти в 2 раза) — уже получаем 2n. И еще 2n на дополнительные элементы нашей березы отрезков. Собирая все вместе, получаем 4N.
С памятью разобрались. Теперь со временем. Построение будет иметь время
O (N), просто потому что мы будем идти по каждому элементу бревенчатой структуры. Далее обновление элемента — время O (logN), потому что мы будем идти только по ветке, в которой изменяем конкретный элемент, а именно от этого элемента зависят logN + 1 элементов, то есть кол-во уровней нашего дерева. Последняя операция — запрос о сумме. Время также O (logN), потому что мы также идем не по всему дереву, а по его уровням.
Подводя итоги
Вы скажете: «постой, но ты же разобрал только дерево суммы». А я скажу: «вы уже все знаете». И это правда. Чтобы переделать данную структуру в то, что вы хотите надо просто заменить функцию вместо сложения на ту, которая вам нужна. Однако эта функция должна быть ассоциативной и обладать нейтральным элементом. Ассоциативность выражается так: (a + b) + c = a + (b + c) или в виде функций f (f (x, y), z) = f (x, f (y, z)). Нейтральный элемент — это элемент, при добавлении которого результат не меняется. В случае сложения это 0. Именно нейтральным элементом мы заполняем оставшиеся ячейки нашего дерева.
Что лучше использовать: рекурсию или нет? Тут все зависит от задач, которые вы хотите решать. Что касается времени, то, конечно, нерекурсивная реализация будет работать чуть быстрее, так как нерекурсивных функции работают «на месте», а в случае с рекурсией придется прыгать по функциям, а это занимает время. Например, на сайте cses время 1 теста было 0.01 секунды и 0.13 секунд для 2 теста для нерекурсивной реализации и 0.01 секунды для 1 теста и 0.17секунд для 2 теста для рекурсивной. Многие люди также просто избегают рекурсию, однако не все так с ней критично. Во-первых в решение олимпиадных задач она просто необходима и существует множество задач, где без нее не обойтись или код с ней будет гораздо проще. Во-вторых, касательно текущей темы, я считаю, что, реализовав рекурсивный подход, вы лучше поймете данную тему.
Я надеюсь вы поняли, что я пытался вам донести и теперь, даже если вы не пишите на С++ сможете реализовать дерево отрезков, ведь тут код почти не отличается от других языков, а даже если где-то он вам не понятен, то главное — это понимание самой структуры данных, а уже потом ее реализация с помощью каких либо инструментов. Теперь для проверки я желаю вам решить самостоятельно задачи с сайта cses. А именно 3 задачи, одну из которых на сумму мы уже решили.