Разворачиваем ML модель с использованием ONNX на Android в километре над землей

В свободное от работы время я летаю на параплане. Это такая штука, внешне похожая на парашют, но способная пролетать сотни километров маршрутов и висеть в воздухе часами. И это безо всякого мотора!

Парим в потоке на берегу Камы

Парим в потоке на берегу Камы

Полет происходит за счет поиска восходящих потоков и набора высоты в них. Вы наверное видели, как чайки или орлы какие‑нибудь кружат на одном месте, не маша крыльями, и поднимаются все выше. Вот так же и мы. Нашел поток, покружил, набрал высоту, полетел по маршруту до следующего потока, постепенно снижаясь.

Вот только проблема — потоков мы не видим, и их поиск — сложная и нетривиальная часть полетов, на обучение которой у людей уходят годы. При этом даже опытные пилоты часто не могут внятно сформулировать, почему они полетели искать поток именно вот туда. «Просто почувствовал что он вот там» на основании движения воздуха и крыла.

И пришла мне в голову мысль, что это хорошая задача чтобы попробовать использовать ML для помощи пилоту. Все это «просто почувствовал, не могу сформулировать как» — он прям просится, чтобы в этом месте попробовать заменить естественный интеллект на искусственный. По сути это задача классификации: по возмущениям крыла сказать, есть поток рядом или нет.

Пара слов про задачу

Более подробно про саму ML модель я расскажу как‑нибудь позже. Пока еще не понятно, будет ли вообще оно работать. Но вкратце идея такова:

  • Вокруг восходящего потока всегда есть определенная турбулентность. Есть нисходящие потоки, есть горизонтальные течения воздуха. Именно на это все и ориентируются опытные пилоты в своем поиске. Летишь‑летишь ровненько в спокойном воздухе, вдруг опа — чувствуешь, что крыло начало покачивать, оно начало само ускоряться — значит, поток рядом. Эти покачивания вполне видны на GPS треке полета

  • Окей, берем трек за N последних секунд полета и пытаемся по его форме понять, что поток рядом есть, в идеале еще бы понять где он и как далеко. Скармливаем его ML модели, которая нам это все и сообщает

  • Обучаем модель на базах треков полетов. Есть куча сайтов, куда пилоты заливают свои треки, так что накачать тысячи километров и сотни часов — не проблема.

Вроде все ясно, берем в руки Jupyter Notebook, всякие фреймворки для ML (я использовал сделанный у нас же в ИТМО FEDOT, который перебирает варианты классификаторов и их параметры в поисках лучшего для данного набора данных) и получаем какую‑то модель. Точнее три модели, я решил разбить отдельно на модель для получения наличия потока рядом, для расстояния до него и направления.

Окей, в блокноте у нас все работает, а дальше‑то что? Мне же надо не просто абстрактные циферки F1 и всяких пресижн‑реколлов, мне надо мою модель как‑то в полете использовать, а ноутбук с собой не утащишь (не, у нас был парень-трубач, который умудрялся со здоровенной валторной летать и трубить в воздухе, но я не такой акробат-экстремал).

Выглядит кстати оно примерно вот так:

Первая версия приложения в условиях тестового полета в Татарстане сообщает, что поток в 50-100 метрах впереди

Первая версия приложения в условиях тестового полета в Татарстане сообщает, что поток в 50–100 метрах впереди

Надо получить доступ к нашей модели, изначально созданной в Python фреймворке, с Android смартфона, который есть у каждого пилота и используется для всякого полетного софта. Не вопрос, подумал я. И быстро накидал на FastAPI веб‑сервис, внутрь которого засунул свою модель, выставил наружу API и сваял простой Android‑фронтенд к нему.

Но вышел облом. Потому что в воздухе интернета внезапно нету. Можно закладывать виражи прямо над вышкой сотовой связи (над ними, кстати, обычно хорошие потоки стоят), но самой связи‑то и не будет. Как я понимаю — потому, что антенны направлены в землю, а не в небо. В итоге на поле под антенной ловит отличный 4G и можно смотреть ютубчик в FullHD, а как поднимешься на 400 метров и выше — связи уже вообще нет. А мы летаем и на километре, и на двух, и даже бывает выше.

В общем, идея с доступом к Python модели по сети оказалась неудачной. Значит, надо завести модель на смартфоне. А как? Python на Android вроде особо не живет.

На помощь пришел ONNX — Open Neural Network Exchange (https://onnx.ai). Это открытый формат для сохранения и обмена ML моделями между различными фреймворками. Кроме самого формата файлов еще есть ONNX‑Runtime: набор библиотек, позволяющих выполнять инференс (то есть использование уже готовой обученной модели) на тех устройствах, под которые этот рантайм есть. Под Android, к счастью, есть.

Процесс работы выглядит как:

  1. Конвертируем нашу модель в onnx формат

  2. Подключаем рантайм к нашему Android приложению

  3. Загружаем модель

  4. Делаем предсказания

Конвертация модели

Модели свои я в итоге сделал в Scikit-Learn, для этого фреймворка есть готовые скрипты (модуль называется skl2onnx), которые умеют конвертировать все стандартные классификаторы. Документация по ним доступна тут https://onnx.ai/sklearn-onnx/introduction.html

Пример кода:

clf = RandomForestClassifier()
# .... Обучение классификатора, X - наш датафрейм с признаками

from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType

initial_type = [('float_input', FloatTensorType([None, X.shape[1]]))]
onnx_model = convert_sklearn(clf, initial_types=initial_type)

onnx_model_file = 'model.onnx'
with open(onnx_model_file, 'wb') as f:
    f.write(onnx_model.SerializeToString())

Подключение в Android и загрузка модели

Само подключение делается тривиально. Добавляем зависимость:

   implementation 'com.microsoft.onnxruntime:onnxruntime-android:1.17.1'

Создаем сессию:

val ortEnvironment = OrtEnvironment.getEnvironment()
val modelBytes: ByteArray = ... // читаем модель откуда-нибудь в ByteArray
val session = ortEnvironment.createSession( modelBytes )

И тут вляпываемся в проблему. Мои модели оказались довольно крупными (.onnx файл получался размером 300–400 Мб). В сэмплах код загрузки обычно читает файл модели в массив байт и затем скармливает его методу открытия сессии. Вот только в Android довольно ограниченный размер памяти у джава-приложений, в итоге выделить 300 Мб буфер оно мне не дало.

У onnx-runtime есть метод для загрузки модели напрямую по имени файла. Сам рантайм является оберткой вокруг нативного кода, и если передать имя файла ему, то открывается он в памяти за пределами JVM. Но тут опять начинаются любимые приколы Android с доступом к файлам. Нельзя так просто взять и прочитать файл из Downloads. В итоге пришлось создавать папку приложения из кода и кидать в нее через Android Studio (потому что ни в каком файловом браузере она у меня не отображалась) файлы моделей.

val onnxDir = context.getDir("onnx", Context.MODE_PRIVATE)
val session = ortEnvironment.createSession( File(onnxDir, "model.onnx").absoluteFile.absolutePath )

Правильным подходом в таком случае наверное будет скачивание актуальной версии модели откуда-то по сети из самого приложения и сохранения ее во внутреннюю память, но для прототипа это все лишний код и возня, которой очень не хотелось заниматься

Использование модели

После того, как модель загружена, можем использовать. Для использования нам надо создать объект класса Tensor. В моем случае у модели 160 входов, представляющих всякие вещественные числа (координаты, высоты за последние 30 секунд). Я сперва складываю их в список, затем создаю из них Tensor.

Код подготовки данных выглядит примерно так:

val columnValues = mutableListOf()
columnValues.add(lastPoint.point.lat.toFloat())
columnValues.add(lastPoint.point.lon.toFloat())
columnValues.add(lastPoint.x.toFloat())
// Тут еще 100500 значений

// Делаем буфер
val bufferInputs = FloatBuffer.wrap( columnValues.toTypedArray().toFloatArray() )
// Превращаем в тензор
val tensor = OnnxTensor.createTensor( ortEnvironment , bufferInputs , longArrayOf( 1, columnValues.size.toLong() ) )

И теперь мы можем делать предсказания. У меня в приложении используется два типа моделей. Одна классифицирует по бинарному признаку (есть или нет поток), в этом случае выходным узлом ONNX модели является тип Long, который возвращает 0 или 1:

fun makePrediction(ortSession: OrtSession, tensor: OnnxTensor): Boolean {
    // Получаем имя входного узла (его мы указывали в Python экспорте модели в файл)
    val inputName = ortSession.inputNames?.iterator()?.next()
    // Выполняем классификацию
    val results = ortSession.run( mapOf( inputName to tensor ) )
    // Получаем результат
    val output = results[0].value as LongArray
    return output[0] != 0
}

В случае с мультиклассовой классификацией ONNX модель нам вернет строку с именем класса:

fun makeStringPrediction(ortSession: OrtSession, tensor: OnnxTensor): String {
    val inputName = ortSession.inputNames?.iterator()?.next()
    val results = ortSession.run( mapOf( inputName to tensor ) )
    val output = results[0].value as Array
    return output[0]
}

Вуаля, мы запустили нашу ML модель на Android. При этом еще и получилось быстрее: сетевой запрос к моему Python FastAPI сервису (когда все-таки удавалось его выполнить в полете) занимал 2–3 секунды, а локально модель отрабатывает за 100–200 мс.

Проверка совпадения результатов

В ходе тестирования модели на земле у меня возникло ощущение, что оно работает как-то не так. Я ходил по земле пешочком и ожидал, что хотя бы иногда оно мне выдаст False Positive и нарисует поток на экране, но оно всегда выдавало 0 из первого классификатора. Я решил проверить, что моя ONNX модель работает точно как надо и совпадает с sklearn-моделью в питоне. А не всегда тупо возвращает 0.

ONNX Runtime есть и под Python. Так что можно загрузить свою модель и сравнить результаты исходной и сконвертированной. Этот код мне писал ChatGPT, за что ему большое спасибо (я ж не настоящий сварщикML-щик, я эту маску на стройке нашел).

Тут на входе у нас два файла моделей (питоновский pkl и onnx)

# Метод для загрузки оригинальной модели из .pkl файла
def load_pickle_model(model_path):
    with open(model_path, 'rb') as f:
        model = joblib.load(f)
    return model

# Метод для загрузки onnx модели
def load_onnx_model(model_path):
    session = ort.InferenceSession(model_path)
    return session

# Код предсказания
def predict(model, input_data):
    if isinstance(model, ort.InferenceSession):
        input_name = model.get_inputs()[0].name
        output_name = model.get_outputs()[0].name
        return model.run([output_name], {input_name: input_data.astype(np.float32)})[0]
    else:
        return model.predict(input_data)

pickle_model_path = 'model.pkl'
onnx_model_path = 'model.onnx'

# Загружаем две модели из разных файлов

pickle_model = load_pickle_model(pickle_model_path)
onnx_model = load_onnx_model(onnx_model_path)

# ... тут еще грузим input_data из файлов с треками

# Предсказываем
onnx_predictions = predict(onnx_model, input_data)
pickle_predictions = predict(pickle_model, input_data)

# Если хоть один элемент массивов не совпадет - выдастся предупреждение
if np.allclose(pickle_predictions, onnx_predictions):
    print("Результаты предсказаний моделей совпадают.")
else:
    print("Результаты предсказаний моделей не совпадают.")

У меня результаты совпали. А как я позже понял, моя хитрая модель просто раскусила, что ниже 200–300 метров значимые восходящие потоки никогда не образуются, поэтому и выдавала мне всегда 0, когда я ходил по земле. Стоило вручную подхачить высоту, и она начала иногда выдавать ожидаемые FP результаты. Вот за такие моменты я люблю ML — когда сперва думаешь что ничего не работает, а оказывается это твоя модель просто стала умнее тебя.

Заключение

Питон это хорошо, но ML модели могут работать в самых разных окружениях, в том числе и там где его нет. ONNX оказался неплохим выбором — рантаймы под него есть под разные платформы, в том числе Android.

Проблемы могут возникнуть со сложными моделями, так как скрипты конвертации или рантайм могут их не поддерживать. Но в некоторых ситуациях выбора не остается. Как когда вы висите в воздухе на большой высоте, у вас только смартфон, и интернета нет.

Полеты в предгорьях Гималаев. И как бы тут Jupyter Notebook запустить?

Полеты в предгорьях Гималаев. И как бы тут Jupyter Notebook запустить?

Скоро в Питере начнется сезон, испытаю свои модели в реальном полете. Хитрый план — чтобы моя чудо-технология помогла мне на соревнованиях в мае.

Напишу потом статью, если оно сработает. Пока посчитанные метрики показывают, что оно хотя бы в теории работоспособно, и какие-то признаки потоков из трека вытащить можно, это не полный рандом. Осталось проверить, будет ли оно практически полезно.

© Habrahabr.ru