Как проверить свои модели ONNX на Python: кратко
Привет, Хабр!
В этой статье разберем, что такое ONNX, как экспортировать модели в этот универсальный формат и, что самое главное, как протестировать их с помощью Python.
Экспорт моделей в формат ONNX
Если вы разрабатываете модели на PyTorch, экспортировать их в ONNX — это проще простого.
Перед экспортом нужно, чтобы модель и входные данные готовы. Обычно достаточно просто создать экземпляр модели и подготовить данные. Например, если вы работаете с простой нейронной сетью, то это может выглядеть так:
import torch
import torchvision.models as models
# Загружаем предобученную модель ResNet
model = models.resnet18(pretrained=True)
model.eval() # Переключаем модель в режим оценки
# Создаем случайный тензор для входных данных
dummy_input = torch.randn(1, 3, 224, 224) # Формат: [batch_size, channels, height, width]
Теперь, когда модель готова, пришло время экспортировать ее. Для этого используем метод torch.onnx.export()
:
torch.onnx.export(model,
dummy_input,
"model.onnx",
export_params=True,
opset_version=11, # Версия ONNX
do_constant_folding=True, # Оптимизация
input_names=['input'], # Названия входных слоев
output_names=['output']) # Названия выходных слоев
print("Модель успешно экспортирована в формат ONNX!")
Пару советов
Модель должна находится в режиме оценки model.eval()
, чтобы отключить такие слои, как Dropout
и BatchNorm
, которые могут привести к нестабильности при тестировании.
Если модель использует кастомные операции, нужно чтобы они поддерживались в ONNX, иначе может возникнуть ошибка при экспорте.
Теперь посмотрим на процесс экспорта из TensorFlow. Начнем с простого примера с использованием Keras.
Создадим простую модель с использованием Keras:
import tensorflow as tf
# Создаем простую модель
model = tf.keras.Sequential([
tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# Создаем случайный тензор для входных данных
dummy_input = tf.random.normal([1, 784]) # Формат: [batch_size, features]
Для экспорта модели из TensorFlow в ONNX, используем библиотеку tf2onnx
, которая позволяет конвертировать модели TensorFlow в формат ONNX:
pip install tf2onnx
Теперь можем экспортировать модель:
import tf2onnx
# Экспортируем модель
onnx_model = tf2onnx.convert.from_keras(model, output_path="model.onnx")
print("Модель успешно экспортирована в формат ONNX!")
Тестирование моделей ONNX
Прежде чем запускать тесты, необходимо определиться с тестовыми данными. Здесь важно учитывать несколько факторов:
Тип данных: Если вы работаете с изображениями, используйте наборы данных, которые были использованы для обучения модели. Для текстовых данных — нужно разнообразие, чтобы модель могла обрабатывать различные случаи.
Размер тестового набора: Тестовый набор должен быть велик, чтобы покрыть все возможные сценарии. Чем больше данных, тем выше вероятность выявить проблемы.
Критерии оценки должны включать метрики производительности, такие как:
Точность: Сравните предсказания вашей модели с реальными значениями. Используйте метрики: точность, полнота и F1-мера.
Время выполнения: Измерьте время, необходимое для выполнения предсказаний на тестовых данных. Можно использовать встроенные функции времени в Python, например как
time.time()
.Использование памяти: Проверьте, насколько хорошо модель использует ресурсы. Можно использовать библиотеку
memory-profiler
, для отслеживания использования памяти вашей модели во время выполнения.
На рынке существует несколько инструментов для тестирования моделей ONNX, и мы сосредоточимся на двух основных:
ONNX Runtime: Это высокопроизводительное исполнение моделей ONNX, которое обеспечивает быструю и эффективную работу.
NumPy: Для анализа и обработки данных, а также для сравнения предсказаний с реальными значениями.
Теперь перейдем к практике.
pip install onnxruntime numpy
Начнем с загрузки нашей модели и подготовки тестовых данных. Допустим, есть модель, которая принимает на вход изображения размером 224×224 пикселя:
import onnx
import onnxruntime as ort
import numpy as np
from PIL import Image
# Загрузка модели
onnx_model = onnx.load("model.onnx")
ort_session = ort.InferenceSession("model.onnx")
# Функция для подготовки изображения
def preprocess_image(image_path):
image = Image.open(image_path)
image = image.resize((224, 224))
image = np.array(image).astype(np.float32) # Преобразуем в float32
image = image.transpose(2, 0, 1) # Изменяем порядок осей на [C, H, W]
image = np.expand_dims(image, axis=0) # Добавляем размерность [1, C, H, W]
return image
# Пример загрузки изображения
test_image = preprocess_image("test_image.jpg")
Теперь можно сделать предсказание на основе подготовленного изображения:
# Выполняем предсказание
outputs = ort_session.run(None, {onnx_model.graph.input[0].name: test_image})
# Вывод предсказания
predicted_class = np.argmax(outputs[0]) # Получаем класс с максимальной вероятностью
print(f"Предсказанный класс: {predicted_class}")
Теперь протестируем, как быстро работает наша модель на нескольких изображениях. Создадим цикл для тестирования:
import time
test_images = ["image1.jpg", "image2.jpg", "image3.jpg"] # Список изображений для тестирования
total_time = 0
for image_path in test_images:
image = preprocess_image(image_path)
start_time = time.time()
outputs = ort_session.run(None, {onnx_model.graph.input[0].name: image})
end_time = time.time()
predicted_class = np.argmax(outputs[0])
total_time += end_time - start_time
print(f"Изображение: {image_path}, Предсказанный класс: {predicted_class}, Время: {end_time - start_time:.4f} секунд")
print(f"Среднее время предсказания: {total_time / len(test_images):.4f} секунд")
Заключение
Тестирование моделей в формате ONNX позволяет уверенно запускать ваши решения в продакшн. Не забывайте, что качественное тестирование — это залог успеха вашего проекта.
Чтобы получить актуальные знания по ML, приходите на ближайшие открытые уроки:
7 октября: «Word2Vec — классика векторных представлений слов для решения задач текстовой обработки». Записаться
8 октября: «Spark ML: обзор, разработка модели на Spark ML, вывод модели в промышленное использование». Записаться
10 октября: «Тестирование торговых стратегий с помощью инструмента «Backtrading». Записаться