Что внутри XGBoost, и при чем здесь Go

В мире машинного обучения одними из самых популярных типов моделей являются решающее дерево и ансамбли на их основе. Преимуществами деревьев являются: простота интерпретации, нет ограничений на вид исходной зависимости, мягкие требования к размеру выборку. Деревья имеют и крупный недостаток — склонность к переобучению. Поэтому почти всегда деревья объединяют в ансамбли: случайный лес, градиентный бустинг и др. Сложной теоретической и практической задачей является составление деревьев и объединение их в ансамбли.

В данной же статье будут рассмотрены процедура формирования предсказаний по уже обученной модели ансамбля деревьев, особенности реализаций в популярных библиотеках градиентного бустинга XGBoost и LightGBM. А так же читатель познакомится с библиотекой leaves для Go, которая позволяет делать предсказания для ансамблей деревьев, не используя при этом C API оригинальных библиотек.

Откуда растут деревья?


Рассмотрим сначала общие положения. Обычно работают с деревьями, где:

  1. разбиение в узле происходит по одному признаку
  2. дерево бинарно — у каждого узла есть левый и правый потомок
  3. в случае вещественного признака решающее правило состоит из сравнения значения признака с пороговым значением


Данную иллюстрацию я взял из документации XGBoost

iskkwtpwjobce6ix23itjf0l48i.png

В данном дереве имеем 2 узла, 2 решающих правил и 3 листа. Под кружочками указаны значения — результат применения дерева к какому-то объекту. Обычно, к результату вычисления дерева или ансамбля деревьев применяют функцию трансформации. Например, сигмоиду для задачи бинарной классификации.

Для получения предсказаний от ансамбля деревьев, полученного методом градиентного бустинга, нужно сложить результаты предсказаний всех деревьев:

double pred = 0.0;
for (auto& tree: trees)
    pred += tree->Predict(feature_values);


Здесь и далее будет код на C++, т.к. именно на этом языке написаны XGBoost и LightGBM. Я буду опускать несущественные детали и стараться приводить максимально лаконичный код.

Далее рассмотрим, что скрывается в Predict, и как устроена структура данных дерева.

Деревья XGBoost


В XGBoost есть несколько классов (в смысле ООП) деревьев. Будем говорить об RegTree (см. include/xgboost/tree_model.h), которая со слов документации является основной. Если оставить только детали, важные для предсказаний, то члены класса выглядят максимально просто:

class RegTree {
  // vector of nodes
  std::vector nodes_;
};


Решающее правило реализовано в функции GetNext. Код немного видоизменен, без влияния на результат вычислений:

// get next position of the tree given current pid
int RegTree::GetNext(int pid, float fvalue, bool is_unknown) const {
  const auto& node = nodes_[pid]
  float split_value = node.info_.split_cond;
  if (is_unknown) {
    return node.DefaultLeft() ? node.cleft_ : node.cright_;
  } else {
    if (fvalue < split_value) {
      return node.cleft_;
    } else {
      return node.cright_;
    }
  }
}


Отсюда следуют две вещи:

  1. RegTree работает только с вещественными признаками (тип float)
  2. поддерживаются пропущенные значения признаков


Центральным местом является класс Node. В нем содержатся локальная структура дерева, решающее правило и значение листа:

class Node {
public:
  // feature index of split condition
  unsigned SplitIndex() const {
    return sindex_ & ((1U << 31) - 1U);
  }
  // when feature is unknown, whether goes to left child
  bool DefaultLeft() const {
    return (sindex_ >> 31) != 0;
  }
  // whether current node is leaf node
  bool IsLeaf() const {
    return cleft_ == -1;
  }
private:
  // in leaf node, we have weights, in non-leaf nodes, we have split condition
  union Info {
    float leaf_value;
    float split_cond;
  } info_;
  // pointer to left, right
  int cleft_, cright_;
  // split feature index, left split or right split depends on the highest bit
  unsigned sindex_{0};
};


Можно выделить следующие особенности:

  1. листы представлены как узлы, у которых cleft_ = -1
  2. поле info_ представлено как union, т.е. два типа данных (в данном случае одинаковые) делят один участок памяти в зависимости от типа узла
  3. старший бит в sindex_ отвечает за то, куда спускается объект, у которого значение признака пропущено


Для того, чтобы можно было проследить путь от вызова метода RegTree::Predict до получения ответа, приведу недостающие две функции:

float RegTree::Predict(const RegTree::FVec& feat, unsigned root_id) const {
  int pid = this->GetLeafIndex(feat, root_id);
  return nodes_[pid].leaf_value;
}

int RegTree::GetLeafIndex(const RegTree::FVec& feat, unsigned root_id) const {
  auto pid = static_cast(root_id);
  while (!nodes_[pid].IsLeaf()) {
    unsigned split_index = nodes_[pid].SplitIndex();
    pid = this->GetNext(pid, feat.Fvalue(split_index), feat.IsMissing(split_index));
  }
  return pid;
}


В функции GetLeafIndex мы в цикле спускаемся по узлам дерева, пока не попадем в лист.

Деревья LightGBM


В LightGBM нет структуры данных для узла. Вместо этого в структуре данных дерева Tree (файл include/LightGBM/tree.h) содержатся массивы значений, где в качестве индекса выступает номер узла. Значения в листьях также хранятся в отдельных массивах.

class Tree {
  // Number of current leaves
  int num_leaves_;
  // A non-leaf node's left child
  std::vector left_child_;
  // A non-leaf node's right child
  std::vector right_child_;
  // A non-leaf node's split feature, the original index
  std::vector split_feature_;
  //A non-leaf node's split threshold in feature value
  std::vector threshold_;
  std::vector cat_boundaries_;
  std::vector cat_threshold_;
  // Store the information for categorical feature handle and mising value handle.
  std::vector decision_type_;
  // Output of leaves
  std::vector leaf_value_;
};


LightGBM поддерживает категориальные признаки. Поддержка осуществляется с помощью битового поля, которое хранится в cat_threshold_ для всех узлов. В cat_boundaries_ хранит, к какому узлу какая часть битового поля соответствует. Поле threshold_ для категориального случая переводится в int и соответсвует индексу в cat_boundaries_ для поиска начала битового поля.

Рассмотрим решающее правило для категориального признака:

int CategoricalDecision(double fval, int node) const {
  uint8_t missing_type = GetMissingType(decision_type_[node]);
  int int_fval = static_cast(fval);
  if (int_fval < 0) {
    return right_child_[node];;
  } else if (std::isnan(fval)) {
    // NaN is always in the right
    if (missing_type == 2) {
      return right_child_[node];
    }
    int_fval = 0;
  }
  int cat_idx = static_cast(threshold_[node]);
  if (FindInBitset(cat_threshold_.data() + cat_boundaries_[cat_idx],
                  cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx], int_fval)) {
    return left_child_[node];
  }
  return right_child_[node];
}


Видно, что в зависимости от missing_type значение NaN автоматически спускает решение по правой ветви дерева. Иначе NaN заменяется на 0. Поиск значения в битовом поле осуществляется достаточно просто:

bool FindInBitset(const uint32_t* bits, int n, int pos) {
  int i1 = pos / 32;
  if (i1 >= n) {
    return false;
  }
  int i2 = pos % 32;
  return (bits[i1] >> i2) & 1;
}


т.е., например, для категориального признака int_fval=42 проверяется, выставлен ли 41-ый (нумерация с 0) бит в массиве.

Этот подход имеет один существенный недостаток: если категориальный признак может принимать большие значения, например 100500, то для каждого решающего правила для этого признака будет создано битовое поле размером до 12564 байт!

Поэтому желательно перенумеровать значения категориальных признаков, чтобы они шли непрерывно от 0 до максимального значения.

Со своей стороны я внес поясняющие правки в LightGBM и их приняли.

Работа с вещественными признаками мало чем отличается от XGBoost, и я пропущу это для краткости.

leaves — библиотека для предсказаний в Go


XGBoost и LightGBM очень мощные библиотеки для построения моделей градиентного бустинга на решающих деревьях. Для их использования в backend сервисе, где необходимы алгоритмы машинного обучения, необходимо решить следующие задачи:

  1. Периодическое обучение моделей в оффлайн
  2. Доставка моделей в backend сервис
  3. Опрос моделей онлайн


Для написания нагруженного backend сервиса популярным языком является Go. Тащить XGBoost или LightGBM через C API и cgo является не самым простым решением — усложняется сборка программы, из-за неосторожного обращения можно словить SIGTERM, проблемы с количеством системных потоков (OpenMP внутри библиотек vs потоки go runtime).

Поэтому я решил написать библиотеку на чистом Go для предсказаний с помощью моделей, построенных в XGBoost или LightGBM. Она называется leaves.

leaves

Основные возможности библиотеки:

  • Для LightGBM моделей
    • Чтение моделей из стандартного формата (текстовый)
    • Поддержка вещественных и категориальных признаков
    • Поддержка пропущенных значений
    • Оптимизация работы с категориальными переменными
    • Оптимизация предсказаний за счет структур данных, рассчитанных только на предсказания

  • Для XGBoost моделей
    • Чтение моделей из стандартного формата (бинарный)
    • Поддержка пропущенных значений
    • Оптимизация предсказаний


Приведу здесь минимальную программу на Go, которая загружает модель с диска и выводит на экран предсказание:

package main

import (
        "bufio"
        "fmt"
        "os"
        "github.com/dmitryikh/leaves"
)

func main() {
        // 1. Открываем файл с моделью
        path := "lightgbm_model.txt"
        reader, err := os.Open(path)
        if err != nil {
                panic(err)
        }
        defer reader.Close()

        // 2. Читаем модель LightGBM
        model, err := leaves.LGEnsembleFromReader(bufio.NewReader(reader))
        if err != nil {
                panic(err)
        }

        // 3. Делаем предсказание!
        fvals := []float64{1.0, 2.0, 3.0}
        p := model.Predict(fvals, 0)
        fmt.Printf("Prediction for %v: %f\n", fvals, p)
}


API библиотеки минималистичен. Для использования модели XGBoost достаточно вызвать метод leaves.XGEnsembleFromReader, вместо приведенного выше. Предсказания можно делать пачками, вызывая методы PredictDense или model.PredictCSR. Больше сценариев использования можно найти в тестах к leaves.

Несмотря на то, что язык Go работает медленней C++ (в основном из-за более тяжелого runtime и проверок времени выполнения), благодаря ряду оптимизаций удалось достичь скорости предсказаний, сопоставимой с вызовом C API оригинальных библиотек.
xlrwbiikgtfpor6luj5_n0gm4c8.png

Более подробно о результатах и способе сравнений есть в репозитории на github.

Зри в корень


Надеюсь, данной статьей я приоткрыл дверцу в реализации деревьев в библиотеках XGBoost и LightGBM. Как видите, основные конструкции довольно просты, и я призываю читателей пользоваться преимуществом open source — изучать код, когда есть вопросы о том, как он работает.

Для тех же, кому интересна тема применения моделей градиентного бустинга в их сервисах на языке Go, рекомендую ознакомиться с библиотекой leaves. С помощью leaves можно довольно просто использовать leading edge решения в машинном обучении в вашей production среде, практически не проигрывая по скорости в сравнении с оригинальными реализациям на C++.

Успехов!

© Habrahabr.ru