Портируем ML модели на Java с помощью ONNX
Всем привет, меня зовут Евгений Мунин. Я Senior ML Engineer в Ad Tech в платформе ставок для Web рекламы и автор ТГ канала ML Advertising. Сегодня расскажу, как мы применяем ML модели в рекламных платформах с бэк‑ендом на JVM.
Как работают платформы ставок?
Цепочка доставки рекламы запросом от рекламодателей через посредники платформы ставок до пользователя
Все начинается с момента когда пользователь заходит на сайт, на котором содержатся рекламные слоты. В этот момент издатель, который владеет сайтом логирует вход пользователя и отправляет запрос в аггрегатор т.н. Prebid. Он в свою очередь контактирует с Supply Side платформой, которая проводит аукцион на продажу рекламного слота. Участники аукциона оценивают запрос, чтобы понять, соответствует ли он критериям рекламных кампаний клиентов, и если да, то какова его ценность для рекламодателя. После этого участники отправляют ответ с ставкой обратно. Победитель аукциона выставляет креатив своего клиента на купленное рекламное мести для показа.
Платформы ставок обрабатывают запросы пользователя, фильтруют их и делают рекомендации близко к реальному времени. Зачастую, весь процесс от запроса до размещения креатива на сайте должен занимать не более 150 мс, т.е. то время за которое у пользователя с хорошим интернетом загружается web-страница. Чтобы отвечать этому требованию бэк платформ, как правило, пишется на языках под JVM (встречал на своей практике на Java, Scala) или еще на Rust.
Так, а в чем же здесь проблема, спросите вы меня. А проблема кроется в том, платформа, обрабатывая входящий запрос от пользователя должна применять на нем ML модели, которые, например, фильтруют fraud, отсекает низкокачественный трафик или подкручивают монетизацию. Эти ML модели, которые предполагается использовать на платформах ставок, в большинстве случаев пишутся на Python фреймворках: Sklearn, PyTorch, etc. (кроме разве-что SparkML, у которого есть обертки и на Python, и на Scala). Библиотеки не особо часто запариваются над вопросами совместимости с другими языками.
Чтобы решить этот вопрос был разработан единый формат Open Neural Network Exchange (сокращенно ONNX), в который можно записать ML модели с разных библиотек и сделать их доступными для использования на платформах, в том числе под JVM.
Сегодня мы рассмотрим на примере простой модели логистической регрессии на Sklearn, которая предсказывает вероятность наличия ставки на рекламной платформе. Для того, чтобы обернуть ее в ONNX формат мы воспользуемся библиотекой ONNXRuntime и ее Java байндингом, чтобы запустить модель на JVM.
Напишем модель логистической регрессии
Для начала, нам нужно прописать сам пайплайн модели. Здесь мы воспользуемся sklearn.pipeline, который будет включать два этапа:
FeatureHasher
, чтобы хешировать входящие признаки следующим образом:hash_feat = MurMurHash3(feat) % hash_size
LogisticRegression
модель, которая предсказывает вероятность ставки на рекламный запрос
class GammaPipeline(Pipeline):
def __init__(self):
super().__init__([
("hasher", FeatureHasher(n_features=2**5, input_type="string", dtype=np.float32)),
("logreg",
LogisticRegression(
max_iter=10,
random_state=42,
penalty="l2",
solver="lbfgs",
C=1.0,
tol=1e-4,
),
),
], verbose=True)
Когда мы задали пайплайн, создадим выборку каких-нибудь данных.
data = {
"timestamp": [datetime.datetime.now() for _ in range(3)],
"feature1": ["val10", "val11", "val12"],
"feature2": ["val20", "val21", "val22"],
"target": [True, False, False],
}
df = pd.DataFrame(data)
x = df.drop(columns=["target"])
y = df["target"]
Зафитим пайплайн и отобразим его структуру
pipeline = GammaPipeline()
pipeline.fit(x.values, y)
Структура пайплайна
Сериализуем ONNX предиктор
Теперь, когда пайплайн обучен, будем кастовать его в ONNX формат. Здесь нам понадобится определить initial_type
фичей и их количество в выборке. Также стоит отключить zipmap для меток классов логистической регрессии, поскольку мы ожидаем выход модель, ни как список словарей с метками классов, а просто, как вектор вероятностнй.
initial_type = [
('input', StringTensorType([None, len(x[0])])),
]
options = {LogisticRegression: {'zipmap': False}}
onnx_model = convert_sklearn(
pipeline,
initial_types=initial_type,
options=options
)
with open(path_data + "models/onnx_log_reg.onnx", "wb") as f:
f.write(onnx_model.SerializeToString())
После того, как модель записана в ONNX, можем ее визуализировать с помощью netron.app.
Как подкапотно выглядит пайплайн. Все что слева от LinearClassifier’а — это этап хеширования фичей
Читаем ONNX предиктор и запускаем на Java
Теперь, когда модель записана в формат ONNX, можем ее десериализовать на JVM. В данном примере, я использовал Java 17.
Для начала создадим OrtEnvironment
, откроем сессию для предсказания с помощью OrtSession
класса и передадим путь до .onnx предиктора. Когда предиктор прочитан, запускаем инференс с помощью session.run
. Поскольку модель ожидает на входе структуру в виде Map
, ее ключи колонок соответствовать названием входных колонок в пайплайн, а сами колонки должны иметь тип OnnxTensor
.
public class OnnxModelRunner {
private OrtSession session;
private OrtEnvironment environment;
public OnnxModelRunner(String modelPath) throws OrtException {
environment = OrtEnvironment.getEnvironment();
OrtSession.SessionOptions options = new OrtSession.SessionOptions();
session = environment.createSession(modelPath, options);
}
public OrtSession.Result runModel(String[][] bidders) throws OrtException {
OnnxTensor inputTensor = OnnxTensor.createTensor(OrtEnvironment.getEnvironment(), bidders);
return session.run(Collections.singletonMap("input", inputTensor));
}
}
После того, как мы определили класс OnnxModelRunner, можем запускать предсказание. Подготовим фичи и создадим объект modelRunner
и запустим инференс. Также здесь же пропишем извлечение вероятностей на выходе модели.
public class ApplicationOnnx {
public static void main(String[] args) throws OrtException {
String[][] bidders = {
{"val11", "val21"},
{"val12", "val22"},
{"val13", "val23"}
};
OnnxModelRunner modelRunner = new OnnxModelRunner("onnx_log_reg_v1_2.onnx");
OrtSession.Result results = modelRunner.runModel(bidders);
StreamSupport.stream(results.spliterator(), false)
.filter(onnxItem -> Objects.equals(onnxItem.getKey(), "probabilities"))
.forEach(onnxItem -> {
OnnxValue onnxValue = onnxItem.getValue();
OnnxTensor tensor = (OnnxTensor) onnxValue;
try {
float[][] probas = (float[][]) tensor.getValue();
System.out.println(
" tensor.getValue(): " + tensor.getValue() +
"\n probas: " + Arrays.deepToString(probas)
);
} catch (OrtException e) {
throw new RuntimeException(e);
}
});
}
}
В OrtSession.Result у нас Map из двух элементов:
для лейблов |
|
для вероятностей |
|
Отфильтруем мапу по ключу probabilities
, выбрав только вероятности. Значения имеет тип OnnxValue
, который мы сначана скастуем в OnnxTensor
, а потом в массив чисел float[][]
. Расперсенный выход модели выглядит следующим образом:
Output: probabilities: OnnxTensor(info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[3, 2]))
tensor.getValue(): [[F@6438a396
probas: [[0.29334846, 0.70665157], [0.853327, 0.14667302], [0.853327, 0.14667302]]
В заключение
Мы рассмотрели простой пример, как портировать Sklearn пайплайн в ONNX, потом прочитать его и запустить предсказания на Java. Естественно, в зависимости от усложнения задач будут подниматься новые вопросы, например:
Как записывать в ONNX пайплайн с кастомными преобразованиями, и как задавать
shape_calculator
иconverter
Что делать, если требуется сохранить тип разных входных колонок, вместо то, чтобы их переводить все в string, float?
Как сделать, чтобы ONNX модель могла запускаться на предсказании на GPU
Но данного примера уже достаточно, чтобы решить проблему совместимости ML фреймворков и платформ.
Если вас интересует тема ML моделей в рекламных платформах, и как катить модели в прод, то заходите на мой ТГ канал ML Advertising!
Спасибо что дочитали о конца!