Как мы в МТС создали библиотеку для работы с графовыми нейронными сетями

Привет, Хабр! Меня зовут Диана Павликова, я работаю ML-инженером. Часто к нам приходят задачи, когда нужно повысить качество работы модели там, где обычными способами это сделать уже не получается. Мы решили применить что-то новое, поэтому обратились к теории графов и написали CoolGraph — open-source-библиотеку для работы с графовыми нейронными сетями. В этой статье я расскажу, как мы пришли к идее ее создания, как графы помогают улучшить результат, какую архитектуру мы выбрали и для каких задач подойдет этот инструмент. Все подробности — под катом.

Как родилась идея

5cb7ccfa1c11a1da5656787aace4e679.jpg

Как и некоторые из вас, мы сталкиваемся в своей практике с кредитным скорингом. С помощью ML эту задачу решили давно и уже не один раз, но все равно у бизнес-пользователей есть запрос на улучшение таких моделей разными способами.

В подобных задачах мы пытается предсказать риск дефолта клиента по кредитам в банке. Обычно условия получаются такие: мы предсказываем просрочку на k дней в пределах n месяцев. Нам надо было предсказать просрочку в 90 дней в пределах 12 месяцев.

Эта задача хорошо решается с помощью бустинга. Но нам нужно еще улучшить его результаты. Но как?

Часто при построении бустинга фичи собираются для объектов, которые рассматриваются отдельно друг от друга. Мы собираем все это в таблицу — и готово. Но что, если объекты могут быть связаны и даже зависеть друг от друга?

Здесь можно начать писать о теории графов, но на просторах интернета довольно много информации о ней. Например, на Хабре были хорошие статьи на эту тему «Теория графов. Термины и определения в картинках» и «Что такое графовые нейронные сети». На всякий случай дам краткое определение: граф (G) — это математическая абстракция, где объекты представляют собой вершины (V), а между ними есть связи (E). Все, теперь вы знаете, что граф — это не только титул.

43ba35868f607561a4cc362bc1623648.jpg

Наша команда в своей работе регулярно сталкивалась с этим понятием в разных задачах. И мы решили использовать наши навыки и знания в задачах кредитного скоринга.

Виды графов 

12186c4eeb6d77fd5b3118387dfca82c.jpg

Графы могут быть гомогенными или гетерогенными: в первом случае вершины одного типа, а во втором — разные. Например, гомогенным является граф общения в любой социальной сети между ее пользователями. Гетерогенный граф описывает взаимодействие между физическими и юридическими лицами.

С помощью графов можно показать связи между научными работами, когда авторы указывают ссылки на работы своих коллег. В этом случае, помимо разных типов вершин (ученые и их статьи), в графе будут разные типы связей: авторство статьи, цитирование, рецензия. В банковской сфере с помощью графов можно описать процесс перевода денег между клиентами разных банков: они будут представлять собой две связанные вершины, но с разным набором признаков.

Графы в ML

Все это мы решили использовать при построении скоринговой модели. Сама задача звучала так: есть обучающая выборка и есть четыре таргета. Мы предсказываем дефолт по нескольким продуктам: потребительскому кредиту, кредитной карте и так далее.

При наличии нескольких таргетов у нас есть два варианта действий. Мы можем использовать их как признаки при обучении модели или решать задачу мультитаргета. Мы выбрали второй вариант, так как с ним отлично справляются нейросети.

У нас есть размеченный датасет, который можно использовать как выборку для обучения. Каждый размеченный объект в датасете имеет связи на огромном графе. Например, возьмем для каждого объекта лишь 25 соседей. А что, если друг вашего друга — и ваш друг тоже? Поэтому берем еще 25 соседей ваших соседей. Таким образом, если обычный датасет состоит из N объектов, то при переходе к графам выборка растет экспоненциально. В примере выше для соседей 1 и 2 порядков мы получим 25 × 25 × N связей. Обычно берут глубину 2–3, то есть объект, его соседи, их соседи и, может, даже их соседи, что на несколько порядков увеличивает объем данных, участвующих в обучении, в сравнении с классическими моделями.

Мы получили очень много соседей для одного объекта. Что же дальше?

Все, что мы сделали, пока лишь привело к росту количества данных, при этом большинство объектов не имеет таргета. Кстати, именно поэтому GNN называют задачей semi-supervised learning.

35289f882d018386d8ef6f959ddcd1fa.png

Важно учитывать, что у нас есть две группы вершин с разным набором признаков. Кроме этого, у нас есть фичи ребер, и их мы тоже хотим учитывать при обучении сети. Нам остается лишь собрать их. Для исходной задачи скоринга у нас вышло 300 фичей — по 150 на каждую группу, так как мы решаем задачу с гетерогенным графом.

По факту мы оказались с:

  • огромным количеством данных, где большая часть не имеет таргета;

  • разными фичами на вершинах;

  • фичами на ребрах;

  • мультитаргетом.

Не забываем, что GNN позволяет работать с огромным количеством графовых данных. Именно GNN эффективно собирают информацию о соседних вершинах и могут выявлять паттерны без хождения по всему графу. Это снижает вычислительные затраты по сравнению с более простыми методами, которые задействуют все связи и хождения по графу. GNN могут адаптироваться к графам разных размеров, при этом такие архитектуры, как свертки (далее речь пойдет о них), подразумевают под собой уменьшение размера графа за счет агрегирования данных, а это уменьшает вычислительную нагрузку и память, при этом нет таких вершин, которые бы «затерялись».

Так выглядит графовая свертка:

2c4468dab9820b37d47a78fcdd40097f.png

Выбор архитектуры GNN

Первый вопрос, который возник на данном этапе: какую именно архитектуру будем использовать? На Pytorch Geometric их множество — мы выбрали NNConv и GraphConv.

Обе архитектуры представляют из себя графовые свертки.

Расскажу об основном отличии — это зависимость от ребер. Не всегда есть данные на ребрах, но в тех случаях, когда они есть, рекомендуется использовать именно NNConv.

9ae8d42d184e2c5ce05476a8331fab85.png

Классическая реализация GraphConv подразумевает под собой фиксированные веса на ребрах. В этом случае основное внимание уделяется агрегации данных на соседних вершинах. Признаки на ребрах никак не учитываются.

1412fc36ecdfc3fc7d4cd525b9489cdf.png

Как я уже сказала выше, нам важны признаки на ребрах. Поэтому в своей задаче мы использовали именно NNConv. Архитектурно это выглядит вот так:

b3c2aa8cefc88a9bf7c23f1e084b29f1.png

Мы получаем полносвязный слой из наших признаков на разных вершинах, затем делаем агрегации. У нас это сумма, среднее и максимальное значения. К ним добавляем наши attention scalars и получаем слой эмбеддингов, который можем использовать в других задачах. Но мы пошли дальше и получили скоры на каждый таргет.

Мы поняли, что можем поделиться своими знаниями и наработками, и решили создать библиотеку CoolGraph. Дальше буду рассказывать о ней.

CoolGraph — начинаем строить GNN с нуля легко и просто

03816bed2cbd18398893fe58e1e01fdf.jpg

GNN редко используют в работе, даже при наличии нужных данных: это долго, дорого и непонятно. С помощью CoolGraph обучить свою графовую сетку вы можете с двух строк. Но об этом ниже. На данный момент в библиотеке представлены две архитектуры: NNConv и GraphConv, но мы работаем над добавлением новых. 

Основные особенности библиотеки:  

  • Предобработка, разделение на батчи, сэмплирование соседей происходит «под капотом» CoolGraph с помощью двух строчек кода.

  • Изначально подобранные архитектуры, которые не раз показали высокое качество. 

  • Подбор гиперпараметров «под капотом» с помощью Optuna.

  • Работа с гомогенными и гетерогенными графами.

  • Решение задач классификации: мультитаргета и мультилейбла.

  • Обучение не только с фичами вершин, но и ребер.

  • Поддержка категориальных фичей.

  • Если у вас нет фичей, то CoolGraph соберет их за вас. Фичи вершин: degree centrality (нормированная степень) и page rank. Фичи ребер: сумма степеней вершин, которые соединяет ребро и количество их общих соседей.

  • Можно отслеживать процесс обучение с помощью Mlflow и смотреть как меняется метрика.

То есть главным преимуществом CoolGraph можно назвать простоту запуска обучения: все работает автоматически. Optuna сама подберет лучшие параметры для обучения и лучшие параметры архитектур. 

Особенности библиотеки

Для описания каждого пункта и рассказа про все наши решения может понадобиться серия статей (пишите в комментарии, если ждете такую серию от нас). Сейчас же я кратко остановлюсь на нескольких основных моментах.

При старте работы обычно встает вопрос: как собрать соседей? Количество соседей лучше ограничивать, как сделали мы в своей задаче: брали по 25 соседей для вершин и по 25 соседей для каждого соседа. Но иногда могут понадобиться все соседи вершины, для этого стоит указать параметр num_neighbors = [-1, -1] при обучении.

Одна из сложностей при работе с данными — это определение значения batch_size. Наша библиотека может подобрать его сама. Нужно указать лишь batch_size = «auto». Он подбирается автоматически, исходя из доступных ресурсов. 

Мы постарались найти такие параметры архитектуры, при которых запуск будет наиболее удачным. Но, конечно же, вы можете менять их под конкретную задачу. Все параметры можно подобрать с помощью Optuna. 

Решаем задачу мультитаргета с CoolGraph

Напоминаю, у нас есть задача кредитного скоринга (мультитаргет), с двумями видами вершин. Для примера пришлось сильно урезать данные, но обычно это около 15–20 млн вершин на 1 месяц. 

Наши данные выглядят так:  

Data(x=[569945, 300], edge_index=[2, 652147], edge_attr=[652147, 26], group_mask=[569945], label_mask=[569945], index=[569945], prod_1_target=[569945], prod_2_target=[569945], prod_3_target=[569945], prod_4_target=[569945])

Рассмотрим подробнее вершины:

print(data.x)
 #РЕЗУЛЬТАТ ПРИНТА 
tensor([[-0.1452, -2.4077, -0.4667,  ...,     nan,     nan,  1.0000],
        [-1.3054, -2.4077, -1.4242,  ...,     nan,     nan,  1.0000],
        [-1.3054, -2.4077, -1.4242,  ...,     nan,     nan,  1.0000],
        ...,
        [    nan,     nan,     nan,  ...,  0.0000,  0.0000,  0.0000],
        [    nan,     nan,     nan,  ...,  0.0000,  0.0000,  0.0000],
        [    nan,     nan,     nan,  ...,  0.0000,  0.0000,  0.0000]])

Почему присутствуют nan в признаках? Как уже было сказано, у нас две группы вершин с разным набором признаков. Мы их никак не изменяем. В конфигах, которые мы используем для обучения, нужно указывать признаки на каждую группу. Как работать с конфигами, мы описали в туториалах — это важная часть библиотеки.

print(data.x.shape)

#РЕЗУЛЬТАТ ПРИНТА 
torch.Size([569945, 300])

Всего у нас 569945 вершин и 300 признаков на них. 

Рассмотрим подробнее ребра

print(data.edge_index)
#РЕЗУЛЬТАТ ПРИНТА
tensor([[126334,  99288,  65081,  ..., 517312, 306248, 310050],
        [149463, 215336,   3520,  ..., 568729, 568729, 143931]])

Этот тензор представляют из себя индексы вершин, которые между собой связаны попарно. Как уже сказала выше, у нашего графа есть также фичи на ребрах, вот они:  

print(data.edge_attr)
#РЕЗУЛЬТАТ ПРИНТА
tensor([[-1.2995,  0.3142,  0.1702,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.4447,  1.9529, -1.5087,  ...,  1.1576,  0.5353,  0.5239],
        [ 0.0208, -0.1040, -0.5799,  ...,  0.0000,  0.0278,  0.0309],
        ...,
        [ 0.2748,  0.6105, -0.7103,  ...,  0.3561,  0.0000,  0.0692],
        [-0.9308, -1.3162,  0.7105,  ...,  0.0000,  0.0000,  0.8563],
        [ 0.4504,  0.8123, -0.7103,  ...,  0.0000,  0.0000,  0.0000]])

У нас есть две группы вершин (так и назовем их, group_1 и group_2), их мы маскируем:  

print(data.group_mask)
tensor([1, 1, 1,  ..., 0, 0, 0], dtype=torch.int32)

#РЕЗУЛЬТАТ ПРИНТА
tensor(0.2539)

75% относится к группе 1, остальные — к группе 2.

Как выглядят таргеты?

Напомню, помимо обучающей выборки (вершины), где у нас известен таргет, у нас есть и соседи этих вершин, у которых нет таргета. Их мы тоже маскируем. 

print(data.label_mask)
#РЕЗУЛЬТАТ ПРИНТА
tensor([False, False, False,  ..., False, False, False])

Баланс классов выглядит так:  

print(data.label_mask.float().mean())
#РЕЗУЛЬТАТ ПРИНТА
tensor(0.0029) #размеченные вершины

Да, именно у такого % вершин есть таргет, выглядит страшно, да?  

Напоминаю, у нас мультитаргет, все четыре таргета представляют из себя банковский продукт, по которому мы должны предсказать наступление дефолта. Поэтому сами таргеты выглядят так:

print(data.prod_1_target, data.prod_2_target, data.prod_3_target, data.prod_4_target)
#РЕЗУЛЬТАТ ПРИНТА
(tensor([-100, -100, -100,  ..., -100, -100, -100]),
 tensor([-100, -100, -100,  ..., -100, -100, -100]),
 tensor([-100, -100, -100,  ..., -100, -100, -100]),
 tensor([-100, -100, -100,  ..., -100, -100, -100])) # -100 — это отсутствие таргета у вершины

В нашем случае у нас есть готовое разделение на train/test от нашего заказчика. 

С помощью NeighborLoader сэмплируем соседей для каждой вершины. Такой подход описан в статье Inductive Representation Learning on Large Graphs. 

from torch_geometric.loader import NeighborLoader

loader_train = NeighborLoader(data, 
                             num_neighbors=[25, 25],
                              batch_size=250,
                              shuffle=True, 
                              input_nodes=input_nodes_train # список индексов вершин для обучения 
) 

 loader_test = NeighborLoader(data,
                              num_neighbors=[25, 250],
                              batch_size=250,
                              shuffle=True, 
                              input_nodes=input_nodes_test # список индексов вершин для теста
)

Loader_train и loader_test представляют из себя разделенный на батчи граф. 

Для этого мы итерируемся по нашим loaders (каждому) и получаем списки с батчами:  

print(train_list_loader)
#РЕЗУЛЬТАТ ПРИНТА
[Data(edge_index=[2, 101368], edge_attr=[101368, 26], group_mask=[89244], label_mask=[89244], index=[89244], prod_1_target=[89244], prod_2_target=[89244], prod_3_target=[89244], prod_4_target=[89244], input_id=[250], batch_size=250, group_2=[66428, 150], group_1=[22816, 150]),
 Data(edge_index=[2, 98442], edge_attr=[98442, 26], group_mask=[86739], label_mask=[86739], index=[86739], prod_1_target=[86739], prod_2_target=[86739], prod_3_target=[86739], prod_4_target=[86739], input_id=[250], batch_size=250, group_2=[64905, 150], group_1=[21834, 150]),
 Data(edge_index=[2, 98212], edge_attr=[98212, 26], group_mask=[86236], label_mask=[86236], index=[86236], prod_1_target=[86236], prod_2_target=[86236], prod_3_target=[86236], prod_4_target=[86236], input_id=[250], batch_size=250, group_2=[63976, 150], group_1=[22260, 150]),
 Data(edge_index=[2, 100990], edge_attr=[100990, 26], group_mask=[88026], label_mask=[88026], index=[88026], prod_1_target=[88026], prod_2_target=[88026], prod_3_target=[88026], prod_4_target=[88026], input_id=[250], batch_size=250, group_2=[64934, 150], group_1=[23092, 150]), …] 

Так как у нас маленький датасет, длина списка с train_loaders будет 15, а с test_loaders всего 5. Для примера этого достаточно. 

Данные готовы, осталось обучить GNN с помощью CoolGraph:  

from cool_graph.runners import MultiRunner # импортируем библиотеку

runner = MultiRunner(train_loader=train_list_loader,
test_loader=test_list_loader)
res = runner.run()

Теперь точно все!

Посмотрим метрики:

print(res["best_loss"])
 #РЕЗУЛЬТАТ ПРИНТА
{'prod_1_target__roc_auc__group_2’: 0.623,
 'prod_1_target__roc_auc__group_1’: 0.688,
 'prod_2_target__roc_auc__group_2': 0.613,
 'prod_2_target__roc_auc__group_1': 0.677,
 'prod_3_target__roc_auc__group_2': 0.62,
 'prod_3_target__roc_auc__group_1': 0.632,
 'prod_4_target__roc_auc__group_2': 0.689,
 'prod_4_target__roc_auc__group_1': 0.693,
 'calc_time': 0.033,
 'epoch': 0}

Обучали только одну эпоху, чтоб показать для примера. Итак, мы с двух строчек обучили сеть там, где обычно нужно самому писать архитектуру, выбирать свертки и так далее.

CoolGraph на данных Amazon Computers 

Не все данные такие сложные. Данные в нашей исходной задаче выглядят сложно: и группы вершин, и фичи на ребрах, и огромное количество данных. Хочу показать на более простом примере, а именно на открытом датасете Amazon Computers https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.datasets.Amazon.html:

Описание данных:  
nodes: 13752
edges: 491722
features: 767 (в данном датасете фичи есть только на вершинах)
classes: 10 

# импортируем датасет с torch geometric
from torch_geometric import datasets
# importing Runner
from cool_graph.runners import Runner
dataset = datasets.Amazon(root='./data/Amazon', name='Computers')
data = dataset.data

print(data)
#РЕЗУЛЬТАТ ПРИНТА
Data(x=[13752, 767], edge_index=[2, 491722], y=[13752])

Начинаем обучение

# initializing Runner
runner = Runner(data, conv_type=”GraphConv”) # здесь можно использовать эту архитектуру
# running
result = runner.run()

вот так выглядит

Sample data: 100%|██████████| 42/42 [00:04<00:00,  9.78it/s]
Sample data: 100%|██████████| 14/14 [00:01<00:00,  9.95it/s]
//2024-08-10 11:36:07 - epoch 0 test:
 //{'accuracy': 0.671, 'cross_entropy': 1.035, 'f1_weighted': 0.614, 'calc_time': 0.006, 'main_metric': 0.671}
2024-08-10 11:36:08 - epoch 0 train:
 {'accuracy': 0.661, 'cross_entropy': 1.039, 'f1_weighted': 0.604, 'calc_time': 0.017, 'main_metric': 0.661}
2024-08-10 11:36:17 - epoch 5 test:
 {'accuracy': 0.911, 'cross_entropy': 0.299, 'f1_weighted': 0.91, 'calc_time': 0.006, 'main_metric': 0.911}
2024-08-10 11:36:18 - epoch 5 train:
 {'accuracy': 0.92, 'cross_entropy': 0.246, 'f1_weighted': 0.918, 'calc_time': 0.017, 'main_metric': 0.92}
2024-08-10 11:36:27 - epoch 10 test:
 {'accuracy': 0.929, 'cross_entropy': 0.265, 'f1_weighted': 0.929, 'calc_time': 0.006, 'main_metric': 0.929}
2024-08-10 11:36:28 - epoch 10 train:
 {'accuracy': 0.956, 'cross_entropy': 0.143, 'f1_weighted': 0.956, 'calc_time': 0.017, 'main_metric': 0.956}
2024-08-10 11:36:37 - epoch 15 test:
 {'accuracy': 0.923, 'cross_entropy': 0.31, 'f1_weighted': 0.922, 'calc_time': 0.006, 'main_metric': 0.923}
2024-08-10 11:36:38 - epoch 15 train:
 {'accuracy': 0.96, 'cross_entropy': 0.125, 'f1_weighted': 0.959, 'calc_time': 0.017, 'main_metric': 0.96}
2024-08-10 11:36:47 - epoch 20 test:
 {'accuracy': 0.924, 'cross_entropy': 0.317, 'f1_weighted': 0.923, 'calc_time': 0.006, 'main_metric': 0.924}
2024-08-10 11:36:49 - epoch 20 train:
 {'accuracy': 0.968, 'cross_entropy': 0.097, 'f1_weighted': 0.967, 'calc_time': 0.019, 'main_metric': 0.968}

Смотрим результат:

print(result["best_loss”])]
#РЕЗУЛЬТАТ ПРИНТА
{'accuracy': 0.929,
 'cross_entropy': 0.265,
 'f1_weighted': 0.929,
 'calc_time': 0.006,
 'main_metric': 0.929,
 'tasks': {'y': {'accuracy': 0.928737638161722,
   'cross_entropy': 0.26483264565467834,
   'f1_weighted': 0.9286287704372129}},
 'epoch': 10}

Заключение

Есть и другие примеры использования CoolGraph. В этой статье я рассказала, как мы пришли к идее такой библиотеки и какие задачи она может решать. 

Сама библиотека распространяется по лицензии MIT. У нас есть репозиторий https://github.com/MobileTeleSystems/CoolGraph, куда вы можете репортить или задавать вопросы. Там несколько примеров, туториалы по разным частям библиотеки. В readme можно найти таблицу, где мы сравниваем нашу библиотеку с SOTA. Мы не завершаем разработку и готовы расти и развиваться.

Я с радостью помогу пользователям с вопросами, которые возникнут при ее использовании. Если есть вопросы, задавайте их в комментариях. Всем пока!

P.S. Огромное спасибо моим коллегам: Сергею Кулиеву, Владу Кузнецову, Алексею Пристайко, Игорю Иноземцеву за помощь с разработкой библиотеки, Никите Зелинскому за консультирование и поддержку и Александру Шойко за ревью статьи.

© Habrahabr.ru