Оптимизируем Shuffle в Spark

Привет, Хабр!

Меня зовут Сергей Смирнов, я аналитик в продукте CVM в X5 Tech. Я занимаюсь разработкой инструмента анализа A/B экспериментов. Мы ежедневно считаем десятки метрик для сотен экспериментов на десятки миллионов клиентов — это терабайты данных, поэтому наш инструмент разработан на Spark.

В последнее время мы заметили, что существенную часть времени работы наших Spark-приложений занимает обмен данными (Shuffle) между исполнителями. В этой статье я расскажу о том, какие оптимизации помогли нам избавиться от самых тяжёлых операций Shuffle. Речь пойдёт не только о BroadcastJoin, но и о двух других неочевидных методах — предварительное репартицирование и бакетирование.

Что такое Shuffle

Shuffle — это операция перераспределения данных между партициями датафрейма, которая требуется для выполнения широких трансформаций (wide transformations), таких как join,  groupBy,  distinct, dropDuplicates и оконных функций. В любом Spark-приложении операция Shuffle практически неизбежна. Несмотря на это, Shuffle является очень затратной по времени и ресурсам операцией.

Рассмотрим этапы, из которых состоит Shuffle, подробнее:

  1. Вычисление хеша ключа трансформации: для каждой строки данных Spark вычисляет хеш ключа трансформации. Например, для трансформации groupBy("customer_id") Spark вычислит хеш от колонки customer_id.

  2. Сжатие данных: перед обменом данными между исполнителями, Spark сериализует и сжимает их для уменьшения нагрузки на сеть и диски.

  3. Обмен данными: данные перераспределяются между исполнителями таким образом, чтобы все строки с одинаковым хешем оказались в одной партиции на одном исполнителе. Этот процесс часто требует записи всех данных на диск и последующего чтения этих данных в нужном порядке.

  4. Распаковка и преобразование данных: после завершения обмена Spark распаковывает данные и преобразует их в RDD или DataFrame для дальнейшей обработки.

684d9204d973d4c0bd2708a0ddcca65c.png

Shuffle создаёт большую нагрузку на вычислительные ресурсы (сериализация, сжатие, распаковка и десериализация данных), а также нагружает сеть и диск (во время обмена данными). Информацию о количестве передаваемых данных можно найти в Spark UI:

72ca2c764b64a5bc4ccc86e05a2a2691.png

Часто бывает так, что больше всего времени выполнения приложения занимает именно Shuffle. В этой статье поговорим о трёх методах преобразования запросов, которые  позволят избавиться от некоторых операций Shuffle:

  • BroadcastJoin: подсказка .hint("broadcast") убирает Shuffle при джойне маленького датафрейма.

  • Репартицирование: инструкция .repartition (), вызванная в правильном месте, может избавить сразу от нескольких Shuffle.

  • Бакетирование: способ хранения таблиц, который позволяет избежать Shuffle при её чтении.

Подробнее о каждом из них расскажу далее.

BroadcastJoin

Операция Join — это широкая трансформация, требующая перераспределения данных между партициями для обоих датафреймов, что мы можем увидеть в плане выполнения:

Spark репартицирует (Shuffle) оба датафрейма по ключу джойна для того, чтобы гарантировать, что строки с одинаковым хешем находятся в одной партиции, а уже затем выполняет SortMergeJoin

Spark репартицирует (Shuffle) оба датафрейма по ключу джойна для того, чтобы гарантировать, что строки с одинаковым хешем находятся в одной партиции, а уже затем выполняет SortMergeJoin

Когда один из датафреймов очень мал, Spark оптимизирует план выполнения, и вместо обычного Join выполняет BroadcastJoin. В этом случае Spark передаёт меньший по размеру датафрейм на все исполнители, что позволяет избежать Shuffle для другого соединяемого датафрейма. Эта оптимизация контролируется порогом spark.sql.autoBroadcastJoinThreshold, который по умолчанию равен 10 МБ.

Как Spark оценивает размер датафрейма? Давайте посмотрим на примерах:

# DataFrame[id: bigint]
# Точная оценка: 3000000 * 8 B
df = spark.range(3_000_000)            # 22.9 MB


# DataFrame[id: bigint, id: bigint]
# Оценка сверху: 24000000 * 24000000 B
df.join(df2, on=df.id==df.id)          # 523.9 TB

df.write.saveAsTable("saved_df")

# Таблица сохранена в Hive и хранится в сжатом формате parquet
# Spark получает размер от Hive
df = spark.read("saved_df")            # 11.5 MB

# Оценка сверху
df = df.filter(F.col("id") % 30 == 0)  # 11.5 MB

# Точная оценка: 3000000 / 30 * 8 B
df.cache().count()

df                                     # 781.3 KB

Итак, Spark точно знает размер датафрейма в случае если:

  • Датафрейм — это результат чтения таблицы из Hive.

  • Датафрейм сгенерирован, например, используя spark.range().

  • Датафрейм закеширован.

В остальных случаях Spark даёт грубую оценку сверху. Поскольку Spark не перестраивает план выполнения на ходу, то в случае, когда мы уверены, что в ходе вычисления какой-нибудь из промежуточных датафреймов будет достаточно мал для BroadcastJoin, нам необходимо указать на это явно, используя подсказку .hint("broadcast").

df_receipts = spark.table("receipts")
df_milk_products = spark.table("products").filter(
    col("category_name").isin(["Молоко"])
)

# Spark оценивает размер правого датафрейма более чем 10 MB
# Будет произведен shuffle обоих датафреймов (SortMergeJoin)
df_receipts.join(df_milk_products, on="product_id")

# Подсказываем Spark выполнить BroadcastJoin правого датафрейма,
# даже если он займет больше 10 MB памяти. Таким образом
# избегаем shuffle левого (очень большого!) датафрейма
df_only_milk_receipts = (
    df_receipts
    .join(df_milk_products.hint("broadcast"), on="product_id")
)

Слева – граф вычислений обычного Join, сначала происходит Shuffle обеих таблиц, затем сортировка, а затем SortMergeJoin. Справа – граф вычислений для BroadcastJoin: для левой (большей по размеру) таблицы не требуется Shuffle, а вместо SortMergeJoin – теперь BroadcastHashJoin.

Слева — граф вычислений обычного Join, сначала происходит Shuffle обеих таблиц, затем сортировка, а затем SortMergeJoin. Справа — граф вычислений для BroadcastJoin: для левой (большей по размеру) таблицы не требуется Shuffle, а вместо SortMergeJoin — теперь BroadcastHashJoin.

Применение BroadcastJoin существенно уменьшает время выполнения, при этом нужно помнить о его особенностях:

  • Датафрейм для бродкаста должен быть действительно мал, чтобы поместиться в память каждого исполнителя.

  • Даже если фактически датафрейм очень мал, Spark может считать совершенно иначе в случаях, когда датафрейм не материализован (т.е. если не закеширован и не является таблицей), поэтому нужно явно указывать на применение BroadcastJoin, используя конструкцию .hint("broadcast").

  • BroadcastJoin неприменим для Full Outer Join.

  • BroadcastJoin неприменим для Left Join, если маленький датафрейм слева, и для Right Join, если маленький датафрейм справа.

Предварительное репартицирование

Как уже было сказано выше, операция Shuffle требуется не только для Join, но и для всех остальных широких трансформаций. Для примера рассмотрим следующий код с двумя последовательными операциями GroupBy, а также план выполнения запроса:

# Чеки в категории "Молоко"
df = df_only_milk_receipts

# Средний чек в категории "Молоко" в разрезе по неделям
stats = (
    df

    # Группировка №1
    .groupby("week", "receipt_id")
    .agg(sum("amount").alias("sum_amount"))

    # Группировка №2
    .groupby("week")
    .agg(mean("sum_amount").alias("avg_amount"))

    # Выполнение и получение статистики
    .collect()
)

51eafa16d176ca2dc516a8ebdb8c2f66.png

Давайте представим, что датафрейм df изначально партиционирован по полю "week". Это означало бы, что все чеки за одну неделю также находятся в одной и той же партиции, а значит и все строки, принадлежащие одному чеку, также находятся в одной партиции. Здравый смысл подсказывает, что при таком сценарии никаких перетасовок данных не потребуется.

Давайте проверим, умеет ли Spark исключать ненужные Shuffle: добавим предварительное репартицирование датафрейма df по полю "week" в начале запроса:

df = df_only_milk_receipts

stats = (
    df

    # Добавляем репартиртицирование по ключу, который является
    # подмножеством для обоих ключей дальнейших группировок
    .repartition("week")

    # Группировка №1
    .groupby("week", "receipt_id")
    .agg(sum("amount").alias("sum_amount"))

    # Группировка №2
    .groupby("week")
    .agg(mean("sum_amount").alias("avg_amount"))

    # Выполнение и получение статистики
    .collect()
)

52efd6d08299415bd85489b65dd82290.png

Действительно, ценой добавления одного предварительного репартицирования нам удалось избавиться от двух Shuffle, предшествующих двум группировкам.

Простое объяснение этому состоит в том, что для каждой операции Spark сравнивает два  партицирования:

  • Партицирование входного датафрейма. В примере выше df предварительно партицирован по набору полей {week}.

  • Требуемое партицирование для выполнения операции. В примере выше агрегация требует датафрейм, партицированный по набору полей  {week, receipt_id}. Если ключ входного партицирования является подмножеством требуемого, то Spark не добавляет операцию Shuffle. Так и произошло в нашем примере.

Иногда удаётся обнаружить длинные участки кода, которые можно оптимизировать добавлением одной строки .repartition(...). Для наглядности — пример из реального проекта:

keys = ["store_id", "customer_id"]
window_1 = Window.partitionBy(*keys, "receipt_id")
window_2 = Window.partitionBy(*keys).orderBy("time")

result = (
    df

    # Первый и единственный shuffle в плане выполнения
    .repartition(*keys)

    # Благодаря BroadcastJoin не репартицируем датафрейм df
    .join(df_stores.hint("broadcast"), on="store_id")
    .join(df_products.hint("broadcast"), on="product_id")

    # Ключ партиции оконной функции включает в себя поля, по
    # которым партицирован датафрейм df
    .withColumn("quantity_sum",
        F.sum("quantity").over(window_1)
    )
    .withColumn("rto_sum",
        F.sum("price").over(window_1)
    )
    .filter(...)

    # Ключ партиции оконной функции включает в себя поля, по
    # которым партицирован датафрейм df
    .withColumn("rank",
        F.rank().over(window_2)
    )
    .filter(...)

    # Ключ группировки включает в себя поля, по которым
    # партицирован датафрейм df
    .groupby(*keys, "receipt_id")
    .agg( # ...
    )
    .groupby(*keys)
    .agg( # ...
    )

    # Ключ джойна совпадает с набором полей, по которому
    # партицирован датафрейм df. Для датафрейма big_df будет
    # добавлен shuffle по полям ["store_id", "customer_id"].
    .join(big_df, on=keys)
)

# Датафрейм result по-прежнему партицирован по полям
# ["store_id", "customer_id"]
result

В данном примере количество уникальных пар ["store_id", "customer_id"] достаточно велико, сами группы достаточно малы, а значит можно не беспокоиться о том, что после .repartition(*keys) данные будут сильно перекошены.

Не стоит забывать, что датафрейм может быть партицирован и без ключа, вот несколько примеров:

  • df.repartition(200) распределит датафрейм равномерно на 200 партиций без ключа.

  • Даже если таблица сохранена в Hive в партицированном виде, spark.table("table") не унаследует партицирование. Подробнее об этом — в конце статьи в разделе про бакетирование.

  • df.union(df) размножит партиции и увеличит их количество в два раза, а значит нарушится правило «строки с одинаковым хешем лежат в одной партиции». В таком тривиальном случае union можно переписать на df.withColumn("n", explode(array(lit(1), lit(2)).drop("n"), сохранив количество партиций и ключ партицирования.

Кроме того, есть пара особенностей, связанных с countDistinct и join, из-за которых предварительное партицирование не сработает. О проблемах и вариантах их решения — ниже.

Проблема с двумя и более агрегациями countDistinct ()

Распределённый countDistinct имеет не самый очевидный алгоритм вычисления. Подробнее о нем можно прочитать, например, здесь. В этой статье рассмотрим случай, когда в одной агрегации встречаются два или более countDistinct:

(
    df
    .repartition("week") # Предварительное репартицирование
    .groupby("week", "receipt_id")
    .agg(
        countDistinct("product_id").alias("products"),
        countDistinct("brand_name").alias("brands")
    )
    .head()
)

ed5b334327ae9fa5a7d6064b71ec72b3.png

По аналогии с предыдущими кейсами мы ожидали всего один Shuffle, но получилось целых три. Как такое произошло?

Посмотрев на план выполнения, заметим, что:

  • Появился новый оператор Expand, который размножает данные. В нашем случае в 2 раза — по количеству функций countDistinct().

  • Информация о ключе партицирования датафрейма не сохранилась после применения оператора Expand (по аналогии с union). А значит любая последующая широкая трансформация неизбежно потребует новый Shuffle, что мы и видим в плане выполнения.

Для того, чтобы избежать лишних Shuffle, можно воспользоваться одним из хаков:

# 1. Замена countDistinct на collect_set + size
# Для очень больших датафреймов может вызвать ошибку
# java.lang.IllegalArgumentException: Cannot grow BufferHolder by size XXXX
# because the size after growing exceeds size limitation 2147483632
(
    df
    .repartition("week")
    .groupby("week", "receipt_id")
    .agg(
        size(collect_set("product_id")).alias("products"),
        size(collect_set("brand_name")).alias("brands")
    )
    .head()
)


# 2. С помощью оконных функций
# Требует две разных сортировки, что негативно сказывается
# на времени выполнения
from pyspark.sql import Window
window = Window.partitionBy("week", "receipt_id")
(
    df
    .repartition("week")
    .withColumn("product_id_dense_rank",
                dense_rank().over(window.orderBy("product_id")))
    .withColumn("brand_name_dense_rank",
                dense_rank().over(window.orderBy("brand_name")))
    .groupby("week", "receipt_id")
    .agg(
        max("product_id_dense_rank").alias("products"),
        max("brand_name_dense_rank").alias("brands")
    )
    .head()
)

Слева – план выполнения для замены countDistinct на size + collect_set, справа – для оконной функции dense_rank + max.

Слева — план выполнения для замены countDistinct на size + collect_set, справа — для оконной функции dense_rank + max.

Проблема с ключом Join

Мы выяснили, что если ключ партицирования является подмножеством ключа группировки, то GroupBy не требует дополнительного Shuffle. Мы могли бы ожидать такого поведения и от Join, но по какой-то причине это не так: в случае с Join необходимо, чтобы эти два ключа полностью совпадали. И это сильно усложняет нам жизнь, например:

left = spark.table("left").repartition("key")
right = spark.table("right").repartition("key")

# 1. Ключ джойна в точности совпадает с ключом партицирования
# обоих датафреймов. Дополнительный shuffle не требуется.
joined = (
    left.join(right, on="key")
    .head()
)

# 2. Ключ джойна является надмножеством ключа партицирования,
# но Spark все равно вставляет дополнительный shuffle
joined = (
    left.join(right, on=["key", "key_2"])
    .head()
)

План выполнения второго запроса выглядит следующим образом:

Два Shuffle подряд – это точно не то, что мы хотим

Два Shuffle подряд — это точно не то, что мы хотим

Для Inner Join существует известный хак: перенести часть ключа джойна в .filter. Для Outer Join простых способов избежать Shuffle не существует.

left = spark.table("left").repartition("key")
right = spark.table("right").repartition("key")

(
    left
    .join(right, on="key")
    .filter(
	  # Условие на равенство (left.key_2 == right.key_2) будет проброшено
	  # оптимизатором в ключи джойна, поэтому Spark нужно обмануть:
        (left.key_2 <= right.key_2)
        & (left.key_2 >= right.key_2)
    )
    .head()
)

Благодаря правилу PushPredicateThroughJoin оптимизатора Spark, условие из фильтра будет применяться прямо во время склеивания строк в SortMergeJoin

Благодаря правилу PushPredicateThroughJoin оптимизатора Spark, условие из фильтра будет применяться прямо во время склеивания строк в SortMergeJoin

Бакетирование таблиц

На примерах выше можно заметить, что операция Shuffle следует сразу после каждого чтения таблицы. Допустим, мы знаем, что к данным из таблицы всегда будет применяться одна и та же трансформация (например, GroupBy). Можем ли мы организовать хранение таблицы в партицированном виде так, чтобы партицирование сохранялось и при её чтении? Это позволило бы избавиться от одного бесполезного Shuffle.

Spark действительно умеет записывать партицированные таблицы:

# По умолчанию используется формат файлов parquet
df.write.partitionBy("store_id").saveAsTable("datamart.receipts")

При таком способе записи файлы будут разложены на поддиректории в следующем виде:

/user/hive/warehouse/datamart.db/receipts/
|-- store_id=1
|   `-- part-aaaaa-...-aaaaaaaaaaaa.c000.snappy.parquet
|-- store_id=2
|   `-- part-bbbbb-...-bbbbbbbbbbbb.c000.snappy.parquet
`-- store_id=3
    `-- part-ccccc-...-cccccccccccc.c000.snappy.parquet

Кажется логичным, чтобы Spark наследовал партицирование таблиц в процессе чтения, но это не так, и для этих целей в Spark предусмотрен другой способ записи таблиц, который называется бакетированием. Для этого при сохранении таблицы в Hive необходимо указать инструкцию .bucketBy:

N = df.rdd.getNumPartitions()
numBuckets = 200
df.write.bucketBy(numBuckets, "store_id").saveAsTable("datamart.receipts")

При таком способе записи таблица будет поделена на N ⨉ numBuckets файлов, где N — количество партиций в датафрейме df:

/user/hive/warehouse/datamart.db/receipts/
|-- part-11111-...-111111111111_00000.c000.snappy.parquet
|-- part-11111-...-111111111111_00001.c000.snappy.parquet
|-- ...
|-- part-11111-...-111111111111_00199.c000.snappy.parquet
|-- part-22222-...-222222222222_00000.c000.snappy.parquet
|-- ...
`-- part-NNNNN-...-NNNNNNNNNNNN_00199.c000.snappy.parquet

При чтении такой таблицы Spark сформирует датафрейм ровно с 200 партициями и будет знать, что датафрейм партицирован по полю "store_id". С некоторыми условностями можно сказать, что следующие два примера дадут идентичные датафреймы:

# 1. Бакетирование
df.write.bucketBy(200, "store_id").saveAsTable("datamart.receipts")
df = spark.table("datamart.receipts")

# 2. Репартиционирование
df = df.repartition(200, "store_id")

И теперь, применяя .groupBy, мы не видим предшествующий ему Shuffle:

# Таблица datamart.receipts бакетирована на 200 бакетов
# по полю "store_id", поэтому датафрейм df имеет 200 партиций
# с партицированием по полю "store_id"
df = spark.table("datamart.receipts")

# План выполнения следующего запроса не будет содержать
# ни одной операции shuffle
stats = (
    df
    .groupby("store_id", "receipt_id")
    .agg(sum("amount").alias("sum_amount"))
    .groupby("store_id")
    .agg(mean("sum_amount").alias("avg_amount"))
    .collect()
)

Благодаря бакетированию не осталось ни одного Shuffle

Благодаря бакетированию не осталось ни одного Shuffle

Бакетирование имеет свои недостатки:

  • Необходимо указывать количество бакетов (аргумент numBuckets). Если количество бакетов меньше, чем количество исполнителей, часть исполнителей вообще не получат данных и будут простаивать.

  • В худшем случае бакетирование таблицы приводит к созданию N ⨉ numBuckets файлов, где N — количество партиций в датафрейме df. Этого можно избежать путём репартицирования датафрейма по тем же колонкам перед записью: df.repartition(200, *keys).write.bucketBy(200, *keys).saveAsTable(...).

Заключение

Операция Shuffle является неотъемлемой частью любого Spark-приложения, но она отнимает время и создаёт большую нагрузку на сеть. Важно минимизировать количество Shuffle, чтобы сократить время выполнения задач. 

В этой статье мы рассмотрели следующие методы преобразования запросов:

  • BroadcastJoin: подсказка .hint("broadcast") сообщает Spark о том, что маленький датафрейм можно разослать на все исполнители. Это позволяет избежать Shuffle в операциях Join.

  • Предварительное репартицирование датафреймов: устранит лишние Shuffle для последовательных трансформаций с одним и тем же ключом. Для этого нужно добавить вызов .repartition. Имейте ввиду, что такое репартицирование может привести к перекосу данных.

  • Бакетирование таблиц: позволяет организовать данные так, чтобы избегать Shuffle после их чтения. Бакетирование особенно полезно в сценариях, когда к данным из таблицы всегда применяются одни и те же трансформации.

Также хочу поблагодарить Ilya Tkachev, Ilia Chernikov и Andrey Mazur за поддержку и вклад в создание этой статьи.

© Habrahabr.ru