Интеграция TFLite во Flutter: внедряем модели машинного обучения в мобильное приложение
Привет! Меня зовут Никита Грибков, я Flutter-разработчик в AGIMA. В этой статье расскажу про фреймворк TensorFlow Lite, который позволяет интегрировать в мобильное приложение модели машинного обучения. Это полезная штука, если нужно реализовать фичи, связанные с распознаванием речи или с классификацией изображений. Покажу, как обучать модели и как затем с ними работать.
Технология позволяет создавать персонализированные и интеллектуальные решения для пользователей, поэтому пользуется высоким спросом. Если наша цель — сделать приложение более удобным и инклюзивным, то, скорее всего, придется использовать ML.
Вот несколько примеров задач, для которых технология 100% подходит:
классификация изображений: чтобы приложение могло распознавать объекты на фотографиях или видео (например, Google Lens);
обработка естественного языка (NLP): в приложениях с голосовыми ассистентами или чат-ботами ML обрабатывает речь и тексты (например, Siri или Google Assistant);
персонализация: алгоритмы ML анализируют поведение пользователей и предлагают персонализированный контент или рекомендации;
распознавание голоса: используется в приложениях для конвертации речи в текст и команд.
Существует несколько способов, как интегрировать модели машинного обучения в приложение. Можно воспользоваться ML Kit от Firebase или библиотеки на Dart. Но лично я пробовал работать с TensorFlow Lite (TFLite). Этот фреймворк можно считать самым распространенным решением в данном случае.
Его главное (но не единственное) преимущество — что он может работать в офлайне, когда устройство не подключено к интернету. Также мне нравится, что TFLite оптимизирован для работы на устройствах с ограниченными ресурсами, это удобно. Разберем, как фреймворк работает.
Подготовка модели для использования с TFLite
Прежде чем интегрировать TFLite во Flutter-приложение, необходимо подготовить модель. Это предполагает её обучение в TensorFlow и конвертацию в формат .tflite.
Шаг 1. Создание и обучение модели в TensorFlow
Для работы с машинным обучением вы можете обучить модель с помощью TensorFlow. Вот простой пример создания и обучения модели на Python:
import tensorflow as tf
from tensorflow.keras import layers
# Создание простой модели для классификации изображений
model = tf.keras.Sequential([
layers.Flatten(input_shape=(28, 28)),
layers.Dense(128, activation='relu'),
layers.Dense(10, activation='softmax')
])
Компиляция модели:
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
Обучение модели на данных MNIST:
model.fit(train_images, train_labels, epochs=5)
Сохранение модели:
model.save("model.h5")
Сеть состоит из одного слоя для преобразования 28×28 пикселей в одномерный вектор, скрытого слоя с 128 нейронами и выходного слоя с 10 нейронами для 10 классов.
model = tf.keras.Sequential([
layers.Flatten(input_shape=(28, 28)),
layers.Dense(128, activation='relu'),
layers.Dense(10, activation='softmax')
])
Модель компилируется с использованием оптимизатора Adam и функции потерь Sparse Categorical Crossentropy.
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
Затем обучается на данных MNIST в течение 5 эпох и сохраняется в файл «model.h5».
model.fit(train_images, train_labels, epochs=5)
Шаг 2: Конвертация модели в формат TFLite
После обучения модели ее нужно преобразовать в формат .tflite с помощью TFLite-конвертера.
Пример кода для конвертации модели:
# Загрузка модели
model = tf.keras.models.load_model('model.h5')
# Конвертация модели в формат TFLite
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
# Сохранение модели в формате .tflite
with open('model.tflite', 'wb') as f:
f.write(tflite_model)
Теперь у вас есть модель в формате .tflite, которую можно интегрировать в приложение на Flutter.
Интеграция TFLite в Flutter-приложение
Для работы с TFLite в Flutter нужно использовать плагин tflite_flutter
. Этот репозиторий — управляемый TensorFlow форк проекта — управляемый TensorFlow форк проекта [tflite_flutter_plugin]
Шаг 1. Установка необходимых зависимостей
Откройте файл pubspec.yaml вашего Flutterпроекта и добавьте зависимости:
dependencies:
flutter:
sdk: flutter
tflite_flutter: ^0.11.0
tflite_flutter_helper_plus: ^0.0.2
Шаг 2. Подготовка модели
Скопируйте файл вашей модели model.tflite в папку проекта assets. Затем в файле pubspec.yaml укажите путь к модели в разделе assets:
flutter:
assets:
- assets/model.tflite
- assets/labels.txt # если у вас есть файл с метками
Шаг 3. Загрузка и использование модели в коде Flutter
Теперь создадим код для загрузки модели и выполнения предсказаний на ее основе на стороне Flutter.
Импорт пакетов:
import 'package:tflite_flutter/tflite_flutter.dart';
import 'package:tflite_flutter_helper_plus/tflite_flutter_helper_plus.dart';
Загрузка модели:
late Interpreter interpreter;
Future loadModel() async {
try {
// Загружаем модель из assets
interpreter = await Interpreter.fromAsset('model.tflite');
print('Модель загружена успешно');
} catch (e) {
print('Ошибка загрузки модели: $e');
}
}
Этот код преобразует изображение в массив Float32List. Он берет каждый пиксель изображения, извлекает значения красного, зеленого и синего каналов, нормализует их с помощью заданных mean и std, а затем заполняет массив.
Float32List imageToByteListFloat32(
img.Image image, int inputSize, double mean, double std) {
var convertedBytes = Float32List(1 * inputSize * inputSize * 3);
var buffer = Float32List.view(convertedBytes.buffer);
int pixelIndex = 0;
for (var i = 0; i < inputSize; i++) {
for (var j = 0; j < inputSize; j++) {
var pixel = image.getPixel(j, i);
buffer[pixelIndex++] = ((img.getRed(pixel) - mean) / std);
buffer[pixelIndex++] = ((img.getGreen(pixel) - mean) / std);
buffer[pixelIndex++] = ((img.getBlue(pixel) - mean) / std);
}
}
return convertedBytes;
}
Выполнение предсказаний
Для выполнения предсказаний нужно преобразовать входные данные в подходящий формат, например, изображение в тензор (массив данных).
Future classifyImage(File image) async {
// Преобразуем изображение в тензор
final img.Image imageInput = img.decodeImage(image.readAsBytesSync())!;
var inputImage = img.copyResize(imageInput, width: 28, height: 28);
var input = imageToByteListFloat32(inputImage, 28, 127.5, 127.5);
// Подготовка выходного тензора
var output = List.filled(10, 0).reshape([1, 10]);
// Выполнение предсказания
_interpreter.run(input, output);
setState(() {
_result = 'Предсказание: ${output.toString()}';
});
}
Flutter с использованием TFLite
import 'dart:typed_data';
import 'package:flutter/material.dart';
import 'package:tflite_flutter/tflite_flutter.dart';
import 'package:image_picker/image_picker.dart';
import 'dart:io';
import 'package:image/image.dart' as img;
class MyHomePage extends StatefulWidget {
@override
_MyHomePageState createState() => _MyHomePageState();
}
class _MyHomePageState extends State {
late Interpreter _interpreter;
File? _image;
final picker = ImagePicker();
String _result = 'Нет предсказаний';
@override
void initState() {
super.initState();
loadModel();
}
Future loadModel() async {
try {
_interpreter = await Interpreter.fromAsset('model.tflite');
print('Модель загружена');
} catch (e) {
print('Ошибка загрузки модели: $e');
}
}
Future pickImage() async {
final pickedFile = await picker.pickImage(source: ImageSource.gallery);
setState(() {
_image = File(pickedFile!.path);
});
if (_image != null) {
classifyImage(_image!);
}
}
Future classifyImage(File image) async {
// Преобразуем изображение в тензор
final img.Image imageInput = img.decodeImage(image.readAsBytesSync())!;
var inputImage = img.copyResize(imageInput, width: 28, height: 28);
var input = imageToByteListFloat32(inputImage, 28, 127.5, 127.5);
// Подготовка выходного тензора
var output = List.filled(10, 0).reshape([1, 10]);
// Выполнение предсказания
_interpreter.run(input, output);
setState(() {
_result = 'Предсказание: ${output.toString()}';
});
}
Float32List imageToByteListFloat32(
img.Image image, int inputSize, double mean, double std) {
var convertedBytes = Float32List(1 * inputSize * inputSize * 3);
var buffer = Float32List.view(convertedBytes.buffer);
int pixelIndex = 0;
for (var i = 0; i < inputSize; i++) {
for (var j = 0; j < inputSize; j++) {
var pixel = image.getPixel(j, i);
buffer[pixelIndex++] = ((img.getRed(pixel) - mean) / std);
buffer[pixelIndex++] = ((img.getGreen(pixel) - mean) / std);
buffer[pixelIndex++] = ((img.getBlue(pixel) - mean) / std);
}
}
return convertedBytes;
}
@override
Widget build(BuildContext context) {
return Scaffold(
appBar: AppBar(title: Text('TFLite Classifier')),
body: Column(
children: [
_image == null ? Text('Выберите изображение') : Image.file(_image!),
ElevatedButton(
onPressed: pickImage,
child: Text('Загрузить изображение'),
),
Text(_result),
],
),
);
}
}
Оптимизация модели для мобильных устройств
Чтобы повысить производительность на мобильных устройствах, можно использовать такие подходы:
Квантизация модели. Она уменьшает размер модели и ускоряет работу за счет уменьшения точности числовых представлений.
Параллельное выполнение. Использование многоядерных процессоров для ускорения предсказаний.
Пример кода для квантизации модели:
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
with open('quantized_model.tflite', 'wb') as f:
f.write(tflite_model)
Что в итоге
В итоге мы получаем приложение с уже работающими моделями ML. Самый долгий этап связан с обучением моделей, всё остальное — вопрос техники. Думаю, примеры выше помогут провести интеграцию быстро.
Если у вас остались вопросы — задавайте в комментариях, я отвечу. А вообще подписывайтесь на канал нашего коллеги Саши Ворожищева — он много пишет про Flutter и про мобильную разработку в целом.