Портируем ML модели на Java с помощью ONNX

Всем привет, меня зовут Евгений Мунин. Я Senior ML Engineer в Ad Tech в платформе ставок для Web рекламы и автор ТГ канала ML Advertising. Сегодня расскажу, как мы применяем ML модели в рекламных платформах с бэк‑ендом на JVM.

Как работают платформы ставок?

Цепочка доставки рекламы запросом от рекламодателей через посредники платформы ставок до пользователя

Цепочка доставки рекламы запросом от рекламодателей через посредники платформы ставок до пользователя

Все начинается с момента когда пользователь заходит на сайт, на котором содержатся рекламные слоты. В этот момент издатель, который владеет сайтом логирует вход пользователя и отправляет запрос в аггрегатор т.н. Prebid. Он в свою очередь контактирует с Supply Side платформой, которая проводит аукцион на продажу рекламного слота. Участники аукциона оценивают запрос, чтобы понять, соответствует ли он критериям рекламных кампаний клиентов, и если да, то какова его ценность для рекламодателя. После этого участники отправляют ответ с ставкой обратно. Победитель аукциона выставляет креатив своего клиента на купленное рекламное мести для показа.

Платформы ставок обрабатывают запросы пользователя, фильтруют их и делают рекомендации близко к реальному времени. Зачастую, весь процесс от запроса до размещения креатива на сайте должен занимать не более 150 мс, т.е. то время за которое у пользователя с хорошим интернетом загружается web-страница. Чтобы отвечать этому требованию бэк платформ, как правило, пишется на языках под JVM (встречал на своей практике на Java, Scala) или еще на Rust.

Так, а в чем же здесь проблема, спросите вы меня. А проблема кроется в том, платформа, обрабатывая входящий запрос от пользователя должна применять на нем ML модели, которые, например, фильтруют fraud, отсекает низкокачественный трафик или подкручивают монетизацию. Эти ML модели, которые предполагается использовать на платформах ставок, в большинстве случаев пишутся на Python фреймворках: Sklearn, PyTorch, etc. (кроме разве-что SparkML, у которого есть обертки и на Python, и на Scala). Библиотеки не особо часто запариваются над вопросами совместимости с другими языками.

Чтобы решить этот вопрос был разработан единый формат Open Neural Network Exchange (сокращенно ONNX), в который можно записать ML модели с разных библиотек и сделать их доступными для использования на платформах, в том числе под JVM.

Сегодня мы рассмотрим на примере простой модели логистической регрессии на Sklearn, которая предсказывает вероятность наличия ставки на рекламной платформе. Для того, чтобы обернуть ее в ONNX формат мы воспользуемся библиотекой ONNXRuntime и ее Java байндингом, чтобы запустить модель на JVM.

Напишем модель логистической регрессии

Для начала, нам нужно прописать сам пайплайн модели. Здесь мы воспользуемся sklearn.pipeline, который будет включать два этапа:

  • FeatureHasher, чтобы хешировать входящие признаки следующим образом: hash_feat = MurMurHash3(feat) % hash_size

  • LogisticRegression модель, которая предсказывает вероятность ставки на рекламный запрос

class GammaPipeline(Pipeline):
  def __init__(self):
    super().__init__([
        ("hasher", FeatureHasher(n_features=2**5, input_type="string", dtype=np.float32)),
        ("logreg",
            LogisticRegression(
                max_iter=10,
                random_state=42,
                penalty="l2",
                solver="lbfgs",
                C=1.0,
                tol=1e-4,
            ),
        ),
    ], verbose=True)

Когда мы задали пайплайн, создадим выборку каких-нибудь данных.

data = {
    "timestamp": [datetime.datetime.now() for _ in range(3)],
    "feature1": ["val10", "val11", "val12"],
    "feature2": ["val20", "val21", "val22"],
    "target": [True, False, False],
}

df = pd.DataFrame(data)

x = df.drop(columns=["target"])
y = df["target"]

Зафитим пайплайн и отобразим его структуру

pipeline = GammaPipeline()
pipeline.fit(x.values, y)

Структура пайплайна

Структура пайплайна

Сериализуем ONNX предиктор

Теперь, когда пайплайн обучен, будем кастовать его в ONNX формат. Здесь нам понадобится определить initial_type фичей и их количество в выборке. Также стоит отключить zipmap для меток классов логистической регрессии, поскольку мы ожидаем выход модель, ни как список словарей с метками классов, а просто, как вектор вероятностнй.

initial_type = [
    ('input', StringTensorType([None, len(x[0])])),
]

options = {LogisticRegression: {'zipmap': False}}

onnx_model = convert_sklearn(
    pipeline,
    initial_types=initial_type,
    options=options
)

with open(path_data + "models/onnx_log_reg.onnx", "wb") as f:
  f.write(onnx_model.SerializeToString())

После того, как модель записана в ONNX, можем ее визуализировать с помощью netron.app.

Как подкапотно выглядит пайплайн. Все что слева от LinearClassifier'а - это этап хеширования фичей

Как подкапотно выглядит пайплайн. Все что слева от LinearClassifier’а — это этап хеширования фичей

Читаем ONNX предиктор и запускаем на Java

Теперь, когда модель записана в формат ONNX, можем ее десериализовать на JVM. В данном примере, я использовал Java 17.

Для начала создадим OrtEnvironment, откроем сессию для предсказания с помощью OrtSession класса и передадим путь до .onnx предиктора. Когда предиктор прочитан, запускаем инференс с помощью session.run. Поскольку модель ожидает на входе структуру в виде Map, ее ключи колонок соответствовать названием входных колонок в пайплайн, а сами колонки должны иметь тип OnnxTensor .

public class OnnxModelRunner {
    private OrtSession session;
    private OrtEnvironment environment;

    public OnnxModelRunner(String modelPath) throws OrtException {
        environment = OrtEnvironment.getEnvironment();
        OrtSession.SessionOptions options = new OrtSession.SessionOptions();
        session = environment.createSession(modelPath, options);
    }

    public OrtSession.Result runModel(String[][] bidders) throws OrtException {
        OnnxTensor inputTensor = OnnxTensor.createTensor(OrtEnvironment.getEnvironment(), bidders);
        return session.run(Collections.singletonMap("input", inputTensor));
    }
}

После того, как мы определили класс OnnxModelRunner, можем запускать предсказание. Подготовим фичи и создадим объект modelRunner и запустим инференс. Также здесь же пропишем извлечение вероятностей на выходе модели.

public class ApplicationOnnx {

    public static void main(String[] args) throws OrtException {

        String[][] bidders = {
            {"val11", "val21"},
            {"val12", "val22"},
            {"val13", "val23"}
        };

        OnnxModelRunner modelRunner = new OnnxModelRunner("onnx_log_reg_v1_2.onnx");

        OrtSession.Result results = modelRunner.runModel(bidders);

        StreamSupport.stream(results.spliterator(), false)
                .filter(onnxItem -> Objects.equals(onnxItem.getKey(), "probabilities"))
                .forEach(onnxItem -> {
                    OnnxValue onnxValue = onnxItem.getValue();
                    OnnxTensor tensor = (OnnxTensor) onnxValue;
                    try {
                        float[][] probas = (float[][]) tensor.getValue();
                        System.out.println(
                                "    tensor.getValue(): " + tensor.getValue() +
                                        "\n    probas: " + Arrays.deepToString(probas)
                        );
                    } catch (OrtException e) {
                        throw new RuntimeException(e);
                    }
                });
    }
}

В OrtSession.Result у нас Map из двух элементов:

для лейблов

OnnxTensor(info=TensorInfo(javaType=INT64,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,shape=[3]))

для вероятностей

OnnxTensor(info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[3, 2]))

Отфильтруем мапу по ключу probabilities, выбрав только вероятности. Значения имеет тип OnnxValue, который мы сначана скастуем в OnnxTensor, а потом в массив чисел float[][]. Расперсенный выход модели выглядит следующим образом:

Output: probabilities: OnnxTensor(info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[3, 2]))
    tensor.getValue(): [[F@6438a396
    probas: [[0.29334846, 0.70665157], [0.853327, 0.14667302], [0.853327, 0.14667302]]

В заключение

Мы рассмотрели простой пример, как портировать Sklearn пайплайн в ONNX, потом прочитать его и запустить предсказания на Java. Естественно, в зависимости от усложнения задач будут подниматься новые вопросы, например:

  • Как записывать в ONNX пайплайн с кастомными преобразованиями, и как задавать shape_calculator и converter

  • Что делать, если требуется сохранить тип разных входных колонок, вместо то, чтобы их переводить все в string, float?

  • Как сделать, чтобы ONNX модель могла запускаться на предсказании на GPU

Но данного примера уже достаточно, чтобы решить проблему совместимости ML фреймворков и платформ.

Если вас интересует тема ML моделей в рекламных платформах, и как катить модели в прод, то заходите на мой ТГ канал ML Advertising!

Спасибо что дочитали о конца!

© Habrahabr.ru