Интеграция TFLite во Flutter: внедряем модели машинного обучения в мобильное приложение

Привет! Меня зовут Никита Грибков, я Flutter-разработчик в AGIMA. В этой статье расскажу про фреймворк TensorFlow Lite, который позволяет интегрировать в мобильное приложение модели машинного обучения. Это полезная штука, если нужно реализовать фичи, связанные с распознаванием речи или с классификацией изображений. Покажу, как обучать модели и как затем с ними работать.

0ef22decdcba2f32f313a625ef73cfae.png

Технология позволяет создавать персонализированные и интеллектуальные решения для пользователей, поэтому пользуется высоким спросом. Если наша цель — сделать приложение более удобным и инклюзивным, то, скорее всего, придется использовать ML.

Вот несколько примеров задач, для которых технология 100% подходит:

  • классификация изображений: чтобы приложение могло распознавать объекты на фотографиях или видео (например, Google Lens);

  • обработка естественного языка (NLP): в приложениях с голосовыми ассистентами или чат-ботами ML обрабатывает речь и тексты (например, Siri или Google Assistant);

  • персонализация: алгоритмы ML анализируют поведение пользователей и предлагают персонализированный контент или рекомендации;

  • распознавание голоса: используется в приложениях для конвертации речи в текст и команд.

Существует несколько способов, как интегрировать модели машинного обучения в приложение. Можно воспользоваться ML Kit от Firebase или библиотеки на Dart. Но лично я пробовал работать с TensorFlow Lite (TFLite). Этот фреймворк можно считать самым распространенным решением в данном случае.

Его главное (но не единственное) преимущество — что он может работать в офлайне, когда устройство не подключено к интернету. Также мне нравится, что TFLite оптимизирован для работы на устройствах с ограниченными ресурсами, это удобно. Разберем, как фреймворк работает.

Подготовка модели для использования с TFLite

Прежде чем интегрировать TFLite во Flutter-приложение, необходимо подготовить модель. Это предполагает её обучение в TensorFlow и конвертацию в формат .tflite.

Шаг 1. Создание и обучение модели в TensorFlow

Для работы с машинным обучением вы можете обучить модель с помощью TensorFlow. Вот простой пример создания и обучения модели на Python:

import tensorflow as tf
from tensorflow.keras import layers

# Создание простой модели для классификации изображений
model = tf.keras.Sequential([
    layers.Flatten(input_shape=(28, 28)),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax')
])

Компиляция модели:

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

Обучение модели на данных MNIST:

model.fit(train_images, train_labels, epochs=5)

Сохранение модели:

model.save("model.h5")

Сеть состоит из одного слоя для преобразования 28×28 пикселей в одномерный вектор, скрытого слоя с 128 нейронами и выходного слоя с 10 нейронами для 10 классов.

model = tf.keras.Sequential([
    layers.Flatten(input_shape=(28, 28)),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax')
])

Модель компилируется с использованием оптимизатора Adam и функции потерь Sparse Categorical Crossentropy.

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

Затем обучается на данных MNIST в течение 5 эпох и сохраняется в файл «model.h5».

model.fit(train_images, train_labels, epochs=5)

Шаг 2: Конвертация модели в формат TFLite

После обучения модели ее нужно преобразовать в формат .tflite с помощью TFLite-конвертера.

Пример кода для конвертации модели:

# Загрузка модели
model = tf.keras.models.load_model('model.h5')

# Конвертация модели в формат TFLite
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Сохранение модели в формате .tflite
with open('model.tflite', 'wb') as f:
    f.write(tflite_model)

Теперь у вас есть модель в формате .tflite, которую можно интегрировать в приложение на Flutter.

Интеграция TFLite в Flutter-приложение

Для работы с TFLite в Flutter нужно использовать плагин tflite_flutter. Этот репозиторий — управляемый TensorFlow форк проекта — управляемый TensorFlow форк проекта [tflite_flutter_plugin]

Шаг 1. Установка необходимых зависимостей

Откройте файл pubspec.yaml вашего Flutterпроекта и добавьте зависимости:

dependencies:
  flutter:
    sdk: flutter
  tflite_flutter: ^0.11.0
  tflite_flutter_helper_plus: ^0.0.2

Шаг 2. Подготовка модели

Скопируйте файл вашей модели model.tflite в папку проекта assets. Затем в файле pubspec.yaml укажите путь к модели в разделе assets:

flutter:
  assets:
    - assets/model.tflite
    - assets/labels.txt # если у вас есть файл с метками

Шаг 3. Загрузка и использование модели в коде Flutter

Теперь создадим код для загрузки модели и выполнения предсказаний на ее основе на стороне Flutter.

Импорт пакетов:

import 'package:tflite_flutter/tflite_flutter.dart';
import 'package:tflite_flutter_helper_plus/tflite_flutter_helper_plus.dart';

Загрузка модели:

late Interpreter interpreter;
Future loadModel() async {
  try {
    // Загружаем модель из assets
    interpreter = await Interpreter.fromAsset('model.tflite');
    print('Модель загружена успешно');
  } catch (e) {
    print('Ошибка загрузки модели: $e');
  }
}

Этот код преобразует изображение в массив Float32List. Он берет каждый пиксель изображения, извлекает значения красного, зеленого и синего каналов, нормализует их с помощью заданных mean и std, а затем заполняет массив.

Float32List imageToByteListFloat32(
      img.Image image, int inputSize, double mean, double std) {
    var convertedBytes = Float32List(1 * inputSize * inputSize * 3);
    var buffer = Float32List.view(convertedBytes.buffer);
    int pixelIndex = 0;
    for (var i = 0; i < inputSize; i++) {
      for (var j = 0; j < inputSize; j++) {
        var pixel = image.getPixel(j, i);
        buffer[pixelIndex++] = ((img.getRed(pixel) - mean) / std);
        buffer[pixelIndex++] = ((img.getGreen(pixel) - mean) / std);
        buffer[pixelIndex++] = ((img.getBlue(pixel) - mean) / std);
      }
    }
    return convertedBytes;
  }

Выполнение предсказаний

Для выполнения предсказаний нужно преобразовать входные данные в подходящий формат, например, изображение в тензор (массив данных).

 Future classifyImage(File image) async {
    // Преобразуем изображение в тензор
    final img.Image imageInput = img.decodeImage(image.readAsBytesSync())!;
    var inputImage = img.copyResize(imageInput, width: 28, height: 28);
    var input = imageToByteListFloat32(inputImage, 28, 127.5, 127.5);

    // Подготовка выходного тензора
    var output = List.filled(10, 0).reshape([1, 10]);

    // Выполнение предсказания
    _interpreter.run(input, output);

    setState(() {
      _result = 'Предсказание: ${output.toString()}';
    });
  }

Flutter с использованием TFLite

import 'dart:typed_data';
import 'package:flutter/material.dart';
import 'package:tflite_flutter/tflite_flutter.dart';
import 'package:image_picker/image_picker.dart';
import 'dart:io';
import 'package:image/image.dart' as img;

class MyHomePage extends StatefulWidget {
  @override
  _MyHomePageState createState() => _MyHomePageState();
}

class _MyHomePageState extends State {
  late Interpreter _interpreter;
  File? _image;
  final picker = ImagePicker();
  String _result = 'Нет предсказаний';

  @override
  void initState() {
    super.initState();
    loadModel();
  }

  Future loadModel() async {
    try {
      _interpreter = await Interpreter.fromAsset('model.tflite');
      print('Модель загружена');
    } catch (e) {
      print('Ошибка загрузки модели: $e');
    }
  }

  Future pickImage() async {
    final pickedFile = await picker.pickImage(source: ImageSource.gallery);

    setState(() {
      _image = File(pickedFile!.path);
    });

    if (_image != null) {
      classifyImage(_image!);
    }
  }

  Future classifyImage(File image) async {
    // Преобразуем изображение в тензор
    final img.Image imageInput = img.decodeImage(image.readAsBytesSync())!;
    var inputImage = img.copyResize(imageInput, width: 28, height: 28);
    var input = imageToByteListFloat32(inputImage, 28, 127.5, 127.5);

    // Подготовка выходного тензора
    var output = List.filled(10, 0).reshape([1, 10]);

    // Выполнение предсказания
    _interpreter.run(input, output);

    setState(() {
      _result = 'Предсказание: ${output.toString()}';
    });
  }

  Float32List imageToByteListFloat32(
      img.Image image, int inputSize, double mean, double std) {
    var convertedBytes = Float32List(1 * inputSize * inputSize * 3);
    var buffer = Float32List.view(convertedBytes.buffer);
    int pixelIndex = 0;
    for (var i = 0; i < inputSize; i++) {
      for (var j = 0; j < inputSize; j++) {
        var pixel = image.getPixel(j, i);
        buffer[pixelIndex++] = ((img.getRed(pixel) - mean) / std);
        buffer[pixelIndex++] = ((img.getGreen(pixel) - mean) / std);
        buffer[pixelIndex++] = ((img.getBlue(pixel) - mean) / std);
      }
    }
    return convertedBytes;
  }

  @override
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(title: Text('TFLite Classifier')),
      body: Column(
        children: [
          _image == null ? Text('Выберите изображение') : Image.file(_image!),
          ElevatedButton(
            onPressed: pickImage,
            child: Text('Загрузить изображение'),
          ),
          Text(_result),
        ],
      ),
    );
  }
}

Оптимизация модели для мобильных устройств

Чтобы повысить производительность на мобильных устройствах, можно использовать такие подходы:

  • Квантизация модели. Она уменьшает размер модели и ускоряет работу за счет уменьшения точности числовых представлений.

  • Параллельное выполнение. Использование многоядерных процессоров для ускорения предсказаний.

Пример кода для квантизации модели:

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

with open('quantized_model.tflite', 'wb') as f:
    f.write(tflite_model)

Что в итоге

В итоге мы получаем приложение с уже работающими моделями ML. Самый долгий этап связан с обучением моделей, всё остальное — вопрос техники. Думаю, примеры выше помогут провести интеграцию быстро.

Если у вас остались вопросы — задавайте в комментариях, я отвечу. А вообще подписывайтесь на канал нашего коллеги Саши Ворожищева — он много пишет про Flutter и про мобильную разработку в целом.

Что еще почитать

© Habrahabr.ru