[Перевод] Делаем печатные ссылки кликабельными с помощью TensorFlow 2 Object Detection API
TL; DR
В этой статье мы начнем решать проблему того, как сделать печатные ссылки в книгах или журналах кликабельными используя камеру смартфона.
С помощью TensorFlow 2 Object Detection API мы научим TensorFlow модель находить позиции и габариты строк https://
в изображениях (например в каждом кадре видео из камеры смартфона).
Текст каждой ссылки, расположенный по правую сторону от https://
, будет распознан с помощью библиотеки Tesseract. Работа с библиотекой Tesseract не является предметом этой статьи, но вы можете найти полный исходный код приложения в репозитории links-detector repository на GitHub.
Запустить Links Detector со смартфона, чтобы увидеть конечный результат.Открыть репозиторий links-detector на GitHub с полным исходным кодом приложения.
Вот так в итоге будет выглядеть процесс распознавания печатных ссылок:
На данный момент приложение находится в экспериментальной стадии и имеет множество недоработок и ограничений. Поэтому, до тех пор, пока вышеуказанные недоработки не будут ликвидированы, не ожидайте от приложения слишком многого. Также стоит отметить, что целью данной статьи является экспериментирование с TensorFlow 2 Object Detection API, а не создание production-ready приложения.В случае, если блоки с исходным кодом в этой статье будут отображаться без подсветки кода вы можете перейти на GitHub версию этой статьи
Проблема
Я работаю программистом, и в свободное от работы время учу Machine Learning в качестве хобби. Но проблема не в этом.
Я купил книгу по машинному обучению и, читая первые главы, столкнулся с множеством печатных ссылок на подобии https://tensorflow.org/
или https://some-url.com/which/may/be/even/longer?and_with_params=true
.
К сожалению, кликать по печатным ссылкам не представлялось возможным (спасибо, Кэп!). Чтобы открыть ссылки в браузере мне приходилось набирать их посимвольно в адресной строке, что было довольно медленно. К тому же опечатки никто не отменял.
Возможное решение
Я подумал, а что если, по аналогии с распознавателем QR кодов, мы «научим» смартфон (1) определять местоположение и (2) распознавать печатные гипер-ссылки и делать их кликабельными? В таком случае читатель делал бы всего один клик вместо посимвольного ввода с множеством нажатий на клавиши. Операционная сложность всей этой операции уменьшилась бы с O(N)
до O(1)
.
Вот так бы этот процесс выглядел:
Требования к решению
Как я уже упомянул выше, я не эксперт в машинном обучении. Для меня это больше как хобби. Поэтому и цель этой статьи заключается больше в экспериментировании и обучении работе с TensorFlow 2 Object Detection API, чем в попытке создания production-ready приложения.
С учетом вышесказанного, я упростил требования к финальному решению и свел их к следующим пунктам:
- Производительность процесса обнаружения и распознавания должна быть близка к реальному времени (например,
0.5-1
кадров в секунду на устройстве схожем по производительности с iPhone X). Это будет означать, что весь процесс обнаружения + распознавания должен происходить не более чем за2
секунды. - Должны поддерживаться только ссылки на английском языке.
- Должны поддерживаться только ссылки черного (темно-серого) цвета на белом (светло-сером) фоне.
- Должны поддерживаться только
https://
ссылки (допускается, чтоhttp://
,ftp://
,tcp://
и прочие ссылки не будут распознаны).
Находим решение
Общий подход
Вариант №1: Модель на стороне сервера
Алгоритм действий:
- Получаем видео-поток (кадр за кадром) на стороне клиента.
- Отправляем каждый кадр на сервер.
- Осуществляем обнаружение и распознавание ссылок на сервере и отправляем результат клиенту.
- Отображаем распознанные ссылки ни стороне клиента и делаем их кликабельными.
Преимущества:
- ✓ Скорость обнаружения и распознавания ссылок не ограничена производительностью клиентского устройства. При желании мы можем ускорить скорость обнаружения ссылок масштабируя наши сервера горизонтально (больше серверов) или вертикально (больше ядер и GPUs).
- ✓ Модель может иметь больший размер (и, возможно, большую точность), поскольку отсутствует необходимость ее загрузки на сторону клиента. Загрузить модель размером
~10Mb
на сторону клиента выглядит реалистичным, но все-же загрузить модель размером~100Mb
может быть довольно проблематичным с точки зрения пользовательского UX (user experience). - ✓ У нас появляется возможность контролировать доступ к модели. Поскольку модель «спрятана» за публичным API, мы можем контролировать каким клиентам она будет доступна.
Недостатки:
- ✗ Сложность системы растет. Вместо использования одного лишь
JavaScript
на стороне клиента нам необходимо будет так же создать, например,Python
инфраструктуру на стороне сервера. Нам так же будет необходимо позаботиться об автоматическом масштабировании сервиса. - ✗ Работа приложения в режиме оффлайн невозможна поскольку для работы приложения требуется доступ к интернету.
- ✗ Множество HTTP запросов к сервису со стороны клиента может стать слабым местом системы с точки зрения производительности. Предположим, мы хотим улучшить производительность обнаружения и распознавания ссылок с
1
до10+
кадров в секунду. В таком случае каждый клиент будет слать10+
запросов в секунду на сервер. Для10
клиентов, работающих одновременно, это уже будет означать100+
запросов в секунду. На помощь могут прийти двусторонний стримингHTTP/2
иgRPC
, но мы снова возвращаемся к первому пункту, связанному с растущей сложностью системы. - ✗ Стоимость системы растет. В основном это связано с оплатой за аренду серверов.
Вариант №2: Модель на стороне клиента
Алгоритм действий:
- Получаем видео-поток (кадр за кадром) на стороне клиента.
- Осуществляем обнаружение и распознавание ссылок на стороне клиента (без отправки на сервер).
- Отображаем распознанные ссылки ни стороне клиента и делаем их кликабельными.
Преимущества:
- ✓ Менее сложная система. Нет необходимости в разработке серверной части приложения и создания API.
- ✓ Приложение может работать в режиме оффлайн. Модель загружена на сторону клиента и нет необходимости в доступе к интернету (см. Progressive Web Application)
- ✓ Система «почти» автоматически масштабируема. Каждый новый клиент приложения «приходит» со своим процессором и видеокартой. Это конечно же неполноценное масштабирование (мы затронем причины ниже).
- ✓ Система гораздо дешевле. Нам необходимо заплатить только за сервер со статическими данными (
HTML
,JS
,CSS
, файлы модели и пр.). В случае с GitHub, такой сервер может быть предоставлен бесплатно. - ✓ Отсутствует (так же как и серверы) проблема большого количества HTTP запросов в секунду к серверам.
Недостатки:
- ✗ Возможно только горизонтальное масштабирование, когда каждый клиент автоматически имеет свои собственные процессоры и графическую карту. Вертикальное масштабирование невозможно поскольку мы не можем повлиять на производительность клиентского устройства. В результате мы не можем гарантировать быстрого обнаружения и распознавания ссылок для медленных устройств.
- ✗ Невозможно контролировать использование модели клиентами. Каждый может загрузить к себе модель и использовать ее где и как угодно.
- ✗ Скорость расхода батареи клиентского устройства может стать проблемой. Модель при работе потребляет вычислительные ресурсы. Пользователи приложения могут быть недовольны тем, что их iPhone становится все теплее и теплее во время работы.
Выбираем общий подход
Поскольку целю этой статьи и проекта в целом является обучение, а не создание приложения коммерческого уровня мы можем выбрать второй вариант и хранить модель на стороне клиента. Это сделает весь проект менее затратным и у нас будет возможность больше сфокусироваться на машинном обучении, а не на создании автоматически масштабируемой серверной инфраструктуры.
Углубляемся в детали
Итак, мы выбрали вариант приложения без серверной части. Предположим теперь, что у нас на входе есть изображение (кадр) из видео-потока камеры, который выглядит так:
Нам необходимо решить две подзадачи:
- Обнаружение ссылок (найти позицию и габариты ссылок на странице)
- Распознавание ссылок (распознать текст ссылок)
Вариант №1: Решение на основе библиотеки Tesseract
Первым и наиболее очевидным вариантом решением задачи оптического распознавания символов (OCR) может быть распознавания текста всего изображения с помощью, например, библиотеки Tesseract.js. Она принимает изображение на вход и выдает распознанные параграфы, текстовые строки, блоки текста и слова и вместе с габаритами и координатами.
Далее мы можем попытаться найти ссылки в распознанном тексте с помощью регулярного выражения похожего на это (пример на TypeScript):
const URL_REG_EXP = /https?:\/\/(www\.)?[-a-zA-Z0-9@:%._+~#=]{2,256}\.[a-z]{2,4}\b([-a-zA-Z0-9@:%_+.~#?&/=]*)/gi;
const extractLinkFromText = (text: string): string | null => {
const urls: string[] | null = text.match(URL_REG_EXP);
if (!urls || !urls.length) {
return null;
}
return urls[0];
};
✓ Похоже, что задача решена довольно прямолинейным и простым способом:
- Мы знаем габариты и координаты ссылок.
- Мы так же знаем текст ссылок и можем сделать их кликабельными.
✗ Проблема в том, что время обнаружения + распознавания может варьироваться от 2
до 20+
секунд в зависимости от размера изображения, его качества и «похожих на текст» объектов в изображении. В итоге будет очень сложно достичь той близкой к реальному времени производительности в 0.5-1
кадров в секунду.
✗ Также, если подумать, то мы просим библиотеку распознать весь текст на картинке, даже если в тексте совсем нет ссылок или если в тексте есть одна-две ссылки, которые составляют, пускай, ~10% от всего объема текста. Это звучит как неэффективная трата вычислительных ресурсов.
Вариант №2: Решение на основе библиотек Tesseract и TensorFlow (+1 модель)
Мы могли бы заставить Tesseract работать быстрее используя еще один дополнительный «алгоритм-советчик» перед тем, как приступить к распознаванию ссылок. Этот «алгоритм-советчик» должен обнаруживать (но не распознавать) начало ссылок (координаты самой левой границы ссылки) для каждой ссылки в изображении. Это позволит нам ускорить задачу распознавания текста ссылок, если мы будем следовать следующим правилам:
- Если изображение не содержит ни одной ссылки мы должны полностью избежать распознавания текста библиотекой Tesseract.
- Если изображение содержит ссылки, то мы должны «попросить» Tesseract распознать только те части изображения, которые содержат текст ссылок. Мы хотим тратить время на распознавание «полезного» для нашей задачи текста.
Этот «алгоритм-советчик», который будет срабатывать перед вызовом Tesseract должен выполняться каждый раз за одно и то же время, независимо от качества и содержимого изображения. Он также должен быть достаточно быстрым и должен определять наличие и позиции ссылок быстрее чем за 1
секунду (например, на iPhone X). В таком случае мы сможем попытаться заставить наше приложение работать в режиме близком к реальному времени (определения «близости» мы дали выше).
Итак, что если мы воспользуемся еще одним алгоритмом (еще одной моделью) обнаружения объектов, который поможет нам найти строкиhttps://
в изображении (каждая защищенная ссылка начинается сhttps://
, не так ли?). Тогда, зная расположение и габариты префиксовhttps://
в изображении, мы сможем отправить на распознавание текста с помощью библиотеки Tesseract только те части изображения, которые находятся по правую сторону от префиксовhttps://
и являются их продолжением.
Обратите внимание на изображение ниже:
На этом изображении можно заметить, что Tesseract будет выполнять гораздо меньше работы по распознаванию текста, если мы подскажем ему, где в тексте могут находиться ссылки (обратите внимание на количество голубых прямоугольников, чем не доказательство).
Итак, вопрос, на который нам необходимо ответить теперь, какую же модель обнаружения объектов нам выбрать и как «научить» ее находить на изображении префиксы https://
.
Наконец-то мы подобрались ближе к TensorFlow
Выбираем подходящую модель обнаружения объектов
Тренировка новой модели обнаружения объектов с нуля не является хорошим вариантом в нашем случае по следующим причинам:
- ✗ Тренировка может занять дни/недели и стоить много денег (за аренду тех-же серверов с GPU).
- ✗ У нас скорее всего не получится собрать набор данных, состоящий из сотен тысяч фотографий книг и журналов со ссылками. Тем-более, что нам нужны не только изображения, но еще и координаты префиксов
https://
для каждого из них. С другой стороны мы можем попытаться сгенерировать такой набор данных, но об этом ниже.
Итак, вместо создания новой модели обнаружения объектов, мы будем обучать уже существующую и натренированную модель обнаруживать новый для нее класс объектов (см. transfer learning). В нашем случае под «новым классом» объектов мы имеем в виду изображения префикса https://
. Такой подход имеет следующие преимущества:
- ✓ Набор данных может быть гораздо меньшим. Нет необходимости собирать сотни тысяч изображений с локализациями (координатами объектов в изображении). Вместо этого мы можем обойтись сотней изображений и сделать локализацию объектов вручную. Это возможно по той причине, что модель уже натренированна на общем наборе данных типа COCO и уже умеет извлекать основные характеристики изображения (научить «первокурсника» линейной алгебре, как правило, легче, чем «первоклассника»).
- ✓ Время тренировки так же будет гораздо меньшим (на GPU получим минуты/часы вместо дней/недель). Время сокращается за счет меньшего объема данных (меньших партий данных во время тренировки) и меньшего количества тренируемых параметров модели.
Мы можем выбрать существующую модель из «зоопарка» моделей TensorFlow 2, который представляет собой коллекцию моделей натренированных на наборе данных COCO 2017. На данный момент эта коллекция включает в себя ~40
разных вариаций моделей.
Для того, чтобы «научить» модель обнаруживать новые, ранее неизвестные ей объекты, мы можем воспользоваться TensorFlow 2 Object Detection API. TensorFlow Object Detection API — это фреймворк на основе TensorFlow, который позволяет конструировать и тренировать модели обнаружения объектов.
Если вы перейдете по ссылке на «зоопарк» моделей вы увидите, что для каждой модели там указана скорость и точность обнаружения объектов.
_Изображение взято с репозитория TensorFlow Model Zoo_
Конечно же, для того, чтобы выбрать подходящую модель, нам важно найти правильный баланс между скоростью и точностью обнаружения. Но что еще важнее в нашем случае, это размер модели, поскольку мы планируем загружать ее на сторону клиента.
Размер архива с моделью может варьироваться от ~20Mb
до ~1Gb
. Вот несколько примеров:
1386 (Mb)
centernet_hg104_1024x1024_kpts_coco17_tpu-32
330 (Mb)
centernet_resnet101_v1_fpn_512x512_coco17_tpu-8
195 (Mb)
centernet_resnet50_v1_fpn_512x512_coco17_tpu-8
198 (Mb)
centernet_resnet50_v1_fpn_512x512_kpts_coco17_tpu-8
227 (Mb)
centernet_resnet50_v2_512x512_coco17_tpu-8
230 (Mb)
centernet_resnet50_v2_512x512_kpts_coco17_tpu-8
29 (Mb)
efficientdet_d0_coco17_tpu-32
49 (Mb)
efficientdet_d1_coco17_tpu-32
60 (Mb)
efficientdet_d2_coco17_tpu-32
89 (Mb)
efficientdet_d3_coco17_tpu-32
151 (Mb)
efficientdet_d4_coco17_tpu-32
244 (Mb)
efficientdet_d5_coco17_tpu-32
376 (Mb)
efficientdet_d6_coco17_tpu-32
376 (Mb)
efficientdet_d7_coco17_tpu-32
665 (Mb)
extremenet
427 (Mb)
faster_rcnn_inception_resnet_v2_1024x1024_coco17_tpu-8
424 (Mb)
faster_rcnn_inception_resnet_v2_640x640_coco17_tpu-8
337 (Mb)
faster_rcnn_resnet101_v1_1024x1024_coco17_tpu-8
337 (Mb)
faster_rcnn_resnet101_v1_640x640_coco17_tpu-8
343 (Mb)
faster_rcnn_resnet101_v1_800x1333_coco17_gpu-8
449 (Mb)
faster_rcnn_resnet152_v1_1024x1024_coco17_tpu-8
449 (Mb)
faster_rcnn_resnet152_v1_640x640_coco17_tpu-8
454 (Mb)
faster_rcnn_resnet152_v1_800x1333_coco17_gpu-8
202 (Mb)
faster_rcnn_resnet50_v1_1024x1024_coco17_tpu-8
202 (Mb)
faster_rcnn_resnet50_v1_640x640_coco17_tpu-8
207 (Mb)
faster_rcnn_resnet50_v1_800x1333_coco17_gpu-8
462 (Mb)
mask_rcnn_inception_resnet_v2_1024x1024_coco17_gpu-8
86 (Mb)
ssd_mobilenet_v1_fpn_640x640_coco17_tpu-8
44 (Mb)
ssd_mobilenet_v2_320x320_coco17_tpu-8
20 (Mb)
ssd_mobilenet_v2_fpnlite_320x320_coco17_tpu-8
20 (Mb)
ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8
369 (Mb)
ssd_resnet101_v1_fpn_1024x1024_coco17_tpu-8
369 (Mb)
ssd_resnet101_v1_fpn_640x640_coco17_tpu-8
481 (Mb)
ssd_resnet152_v1_fpn_1024x1024_coco17_tpu-8
480 (Mb)
ssd_resnet152_v1_fpn_640x640_coco17_tpu-8
233 (Mb)
ssd_resnet50_v1_fpn_1024x1024_coco17_tpu-8
233 (Mb)
ssd_resnet50_v1_fpn_640x640_coco17_tpu-8
Модель ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8
выглядит наиболее подходящей в нашем случае:
- ✓ Она относительно небольшая —
20Mb
в архиве. - ✓ Она достаточно быстрая —
39ms
на одно обнаружение. - ✓ Она использует сеть MobileNet v2 в качестве экстрактора свойств изображения (feature extractor), которая в свою очередь оптимизирована под работу на мобильных устройствах и обеспечивает меньший расход батареи.
- ✓ Она производит обнаружение всех известных ей объектов в изображении за один проход независимо от содержимого изображения (отсутствует шаг regions proposal, что делает работу сети быстрее).
- ✗ В то же время это не самая точная модель (все является компромиссом ️)
Название модели включает в себя ее несколько важных характеристик, с которыми вы при желании можете ознакомиться детальнее:
- Ожидаемый размер изображения на входе —
640x640px
. - Модель построена на основе Single Shot MultiBox Detector (SSD) и Feature Pyramid Network (FPN).
- Сверточная нейронная сеть (CNN) MobileNet v2 используется в качестве экстрактора свойств изображения (feature extractor).
- Модель была обучена на наборе данных COCO
Устанавливаем Object Detection API
В этой статье мы будем устанавливать Tensorflow 2 Object Detection API в виде пакета Python. Это достаточно удобно, в случае если вы экспериментируете в Google Colab (предпочтительно) или в Jupyter. В обоих случаях вы можете избежать локальной инсталляции пакетов и проводить эксперименты непосредственно в браузере.
Также есть возможность установки Object Detection API используя Docker, о котором вы можете прочитать в документации.
Если у вас возникнут трудности во время установки API или во время создания набора данных (следующие разделы), вы можете обратиться к статье TensorFlow 2 Object Detection API tutorial, в которой есть много полезных деталей и советов.
Для начала давайте клонируем репозиторий с API:
git clone --depth 1 https://github.com/tensorflow/models
output →
Cloning into 'models'...
remote: Enumerating objects: 2301, done.
remote: Counting objects: 100% (2301/2301), done.
remote: Compressing objects: 100% (2000/2000), done.
remote: Total 2301 (delta 561), reused 922 (delta 278), pack-reused 0
Receiving objects: 100% (2301/2301), 30.60 MiB | 13.90 MiB/s, done.
Resolving deltas: 100% (561/561), done.
Теперь можем скомпилировать файлы-прототипы API в Python формат, используя protoc:
cd ./models/research
protoc object_detection/protos/*.proto --python_out=.
Следующим шагом будет установка API для версии TensorFlow 2 используя pip
и файл setup.py`:
cp ./object_detection/packages/tf2/setup.py .
pip install . --quiet
Если на этом шаге вы обнаружите ошибки, связанные установкой зависимых пакетов, попробуйте запустить pip install . --quiet
во второй раз.
Проверить успешность установки вы можете запустив тест:
python object_detection/builders/model_builder_tf2_test.py
В итоге вы должны будете увидеть в консоли, что-то вроде этого:
[ OK ] ModelBuilderTF2Test.test_unknown_ssd_feature_extractor
----------------------------------------------------------------------
Ran 20 tests in 45.072s
OK (skipped=1)
TensorFlow Object Detection API установлена! Теперь мы можем использовать скрипты, предоставляемы этой API, для обнаружения объектов в изображениях, тренировки или доработки моделей.
Загружаем заранее обученную модель
Давайте загрузим ранее выбранную нами модель ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8
из коллекции моделей TensorFlow и посмотрим, как мы можем использовать ее для обнаружения общих объектов, таких как «кот», «собака», «машина» и пр. (объектов с классами, поддерживаемыми набором данных COCO).
Мы воспользуемся утилитой TensorFlow get_file () для загрузки архивированной модели по URL и для дальнейшей ее распаковки.
import tensorflow as tf
import pathlib
MODEL_NAME = 'ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8'
TF_MODELS_BASE_PATH = 'http://download.tensorflow.org/models/object_detection/tf2/20200711/'
CACHE_FOLDER = './cache'
def download_tf_model(model_name, cache_folder):
model_url = TF_MODELS_BASE_PATH + model_name + '.tar.gz'
model_dir = tf.keras.utils.get_file(
fname=model_name,
origin=model_url,
untar=True,
cache_dir=pathlib.Path(cache_folder).absolute()
)
return model_dir
# Start the model download.
model_dir = download_tf_model(MODEL_NAME, CACHE_FOLDER)
print(model_dir)
output →
/content/cache/datasets/ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8
Вот как на данный момент выглядит структура папок:
Папка checkpoint
содержит «слепок» параметров обученной модели.
Файл pipeline.config
содержит настройки обнаружения. Мы еще вернемся к этому файлу ниже, когда будем обучать нашу модель.
Обнаружение объектов с помощью загруженной модели
На данный момент модель способна обнаруживать объекты классов, поддерживаемых набором данных COCO (их всего 90), таких, как car
, bird
, hot dog
и пр. Эти классы еще могут называть ярлыками (labels).
Источник изображения: сайт COCO
Попробуем, обнаружит ли модель объекты этих классов.
Загружаем ярлыки COCO
Object Detection API уже содержит файл с полным набор классов (ярлыков) COCO для нашего удобства.
import os
# Import Object Detection API helpers.
from object_detection.utils import label_map_util
# Loads the COCO labels data (class names and indices relations).
def load_coco_labels():
# Object Detection API already has a complete set of COCO classes defined for us.
label_map_path = os.path.join(
'models/research/object_detection/data',
'mscoco_complete_label_map.pbtxt'
)
label_map = label_map_util.load_labelmap(label_map_path)
# Class ID to Class Name mapping.
categories = label_map_util.convert_label_map_to_categories(
label_map,
max_num_classes=label_map_util.get_max_label_map_index(label_map),
use_display_name=True
)
category_index = label_map_util.create_category_index(categories)
# Class Name to Class ID mapping.
label_map_dict = label_map_util.get_label_map_dict(label_map, use_display_name=True)
return category_index, label_map_dict
# Load COCO labels.
coco_category_index, coco_label_map_dict = load_coco_labels()
print('coco_category_index:', coco_category_index)
print('coco_label_map_dict:', coco_label_map_dict)
output →
coco_category_index:
{
1: {'id': 1, 'name': 'person'},
2: {'id': 2, 'name': 'bicycle'},
...
90: {'id': 90, 'name': 'toothbrush'},
}
coco_label_map_dict:
{
'background': 0,
'person': 1,
'bicycle': 2,
'car': 3,
...
'toothbrush': 90,
}
Создаем функцию обнаружения
В этом разделе мы создадим так называемую функцию обнаружения, которая будет использовать загруженную нами ранее модель, собственно, для обнаружения объектов в изображении.
import tensorflow as tf
# Import Object Detection API helpers.
from object_detection.utils import config_util
from object_detection.builders import model_builder
# Generates the detection function for specific model and specific model's checkpoint
def detection_fn_from_checkpoint(config_path, checkpoint_path):
# Build the model.
pipeline_config = config_util.get_configs_from_pipeline_file(config_path)
model_config = pipeline_config['model']
model = model_builder.build(
model_config=model_config,
is_training=False,
)
# Restore checkpoints.
ckpt = tf.compat.v2.train.Checkpoint(model=model)
ckpt.restore(checkpoint_path).expect_partial()
# This is a function that will do the detection.
@tf.function
def detect_fn(image):
image, shapes = model.preprocess(image)
prediction_dict = model.predict(image, shapes)
detections = model.postprocess(prediction_dict, shapes)
return detections, prediction_dict, tf.reshape(shapes, [-1])
return detect_fn
inference_detect_fn = detection_fn_from_checkpoint(
config_path=os.path.join('cache', 'datasets', MODEL_NAME, 'pipeline.config'),
checkpoint_path=os.path.join('cache', 'datasets', MODEL_NAME, 'checkpoint', 'ckpt-0'),
)
Функция inference_detect_fn
принимает на входе изображение и возвращает информацию об обнаруженных в нем объектах.
Загружаем тестовые изображения
Давайте попробуем найти объекты на следующем изображении:
Для этого сохраним это изображение в папку inference/test/
нашего проекта. Если вы используете Google Colab, вы можете создать эту папку и произвести загрузку файла вручную.
Вот как структура папок должна выглядеть на данный момент:
import matplotlib.pyplot as plt
%matplotlib inline
# Creating a TensorFlow dataset of just one image.
inference_ds = tf.keras.preprocessing.image_dataset_from_directory(
directory='inference',
image_size=(640, 640),
batch_size=1,
shuffle=False,
label_mode=None
)
# Numpy version of the dataset.
inference_ds_numpy = list(inference_ds.as_numpy_iterator())
# You may preview the images in dataset like this.
plt.figure(figsize=(14, 14))
for i, image in enumerate(inference_ds_numpy):
plt.subplot(2, 2, i + 1)
plt.imshow(image[0].astype("uint8"))
plt.axis("off")
plt.show()
Запускаем обнаружение для тестового изображения
На данном этапе мы готовы запустить обнаружение. Первый элемент массива inference_ds_numpy[0]
содержит наше первое тестовое изображение в формате массива Numpy
.
detections, predictions_dict, shapes = inference_detect_fn(
inference_ds_numpy[0]
)
Проверим размерность массивов, которые нам вернула функция:
boxes = detections['detection_boxes'].numpy()
scores = detections['detection_scores'].numpy()
classes = detections['detection_classes'].numpy()
num_detections = detections['num_detections'].numpy()[0]
print('boxes.shape: ', boxes.shape)
print('scores.shape: ', scores.shape)
print('classes.shape: ', classes.shape)
print('num_detections:', num_detections)
output →
boxes.shape: (1, 100, 4)
scores.shape: (1, 100)
classes.shape: (1, 100)
num_detections: 100.0
Модель вернула нам массив со 100
«обнаружениями». Это не означает, что модель нашла 100
объектов в изображении. Это скорее говорит нам, что модель имеет 100
ячеек и поддерживает обнаружение максимум 100
объектов одновременно в одном изображении. Каждое «обнаружение» имеет соответствующий рейтинг (вероятность, score), который говорит об уверенности модели в том, что обнаружен именно этот объект. Габариты каждого найденного объекта хранятся в массиве boxes
. Рейтинг каждого обнаружения хранится в массиве scores
. Массив classes
хранит ярлыки для каждого «обнаружения».
Давайте проверим первые 5 таких «обнаружений»:
print('First 5 boxes:')
print(boxes[0,:5])
print('First 5 scores:')
print(scores[0,:5])
print('First 5 classes:')
print(classes[0,:5])
class_names = [coco_category_index[idx + 1]['name'] for idx in classes[0]]
print('First 5 class names:')
print(class_names[:5])
output →
First 5 boxes:
[[0.17576033 0.84654826 0.25642633 0.88327974]
[0.5187813 0.12410264 0.6344235 0.34545377]
[0.5220358 0.5181462 0.6329132 0.7669856 ]
[0.50933677 0.7045719 0.5619138 0.7446198 ]
[0.44761637 0.51942706 0.61237675 0.75963426]]
First 5 scores:
[0.6950246 0.6343004 0.591157 0.5827219 0.5415643]
First 5 classes:
[9. 8. 8. 0. 8.]
First 5 class names:
['traffic light', 'boat', 'boat', 'person', 'boat']
Модель видит светофор (traffic light
), три лодки (boats
) и человека (person
). И мы можем подтвердить, что эти объекты действительно существуют в изображении.
В массиве scores
мы видим, что модель наиболее уверенна (с 70% вероятностью) в найденном объекте класса traffic light
.
Каждый элемент массива boxes
представляет собой координаты [y1, x1, y2, x2]
, где (x1, y1)
и (x2, y2)
соответственно координаты левого верхнего и правого нижнего углов габаритного прямоугольника.
Попробуем визуализировать габаритные прямоугольники:
# Importing Object Detection API helpers.
from object_detection.utils import visualization_utils
# Visualizes the bounding boxes on top of the image.
def visualize_detections(image_np, detections, category_index):
label_id_offset = 1
image_np_with_detections = image_np.copy()
visualization_utils.visualize_boxes_and_labels_on_image_array(
image_np_with_detections,
detections['detection_boxes'][0].numpy(),
(detections['detection_classes'][0].numpy() + label_id_offset).astype(int),
detections['detection_scores'][0].numpy(),
category_index,
use_normalized_coordinates=True,
max_boxes_to_draw=200,
min_score_thresh=.4,
agnostic_mode=False,
)
plt.figure(figsize=(12, 16))
plt.imshow(image_np_with_detections)
plt.show()
# Visualizing the detections.
visualize_detections(
image_np=tf.cast(inference_ds_numpy[0][0], dtype=tf.uint32).numpy(),
detections=detections,
category_index=coco_category_index,
)
В итоге мы увидим:
В то же время, если мы попробуем обнаружить объекты на текстовом изображении мы увидим следующее:
Модель не смогла найти ничего в этом изображении. Это как-раз то, что мы собираемся исправить и чему хотим научить нашу модель — видеть приставки https://
в текстовых изображениях.
Подготавливаем набор данных для тренировки
Для того, чтобы научить модель ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8
обнаруживать объекты, которые не были описаны в наборе данных COCO нам необходимо подготовить свой набор данных и «доучить» модель на нем.
Наборы данных для задачи обнаружения объектов состоят из двух компонентов:
- Собственно само изображение (например, изображение печатной странички книги или журнала)
- Габаритные прямоугольники, которые показывают где именно в изображении расположены объекты.
В примере выше координаты левого верхнего
и правого нижнего
углов имеют абсолютные значения (в пикселях). Также существуют альтернативные способы записи параметров таких габаритных прямоугольников. Например, мы можем описать прямоугольник с помощью его координат центра
, а так же ширины
и высоты
. Мы также можем использовать относительные значения координат (процент от ширины или высоты изображения). Но в целом, думаю идея понятна: модель должна знать где именно в изображении находится тот или иной объект.
Вопрос в том, где же нам взять такие данные для тренировки. У нас есть три варианта:
- Воспользоваться имеющимся набором данных.
- Сгенерировать новый искусственный набор данных.
- Создать набор данных вручную путем фотографирования или загрузки реальных изображений с текстом и
https://
ссылками и дальнейшей аннотацией (указанием позиций объектов) каждого изображения вручную.
Вариант №1: Использование существующих наборов данных
Есть множество общедоступных наборов данных. Мы можем воспользоваться следующими ресурсами для поиска подходящего набора:
✓ Если у вас получится найти подходящий набор данных с лицензией, позволяющей его использовать, то это, пожалуй, наиболее быстрый способ начать тренировку модели.
✗ Но проблема в том, что мне не удалось найти набор данных, содержащий изображения книг со ссылками и их координатами.
Этот вариант нам прийдется пропустить.
Вариант №2: Генерирование искусственного набора данных
Существуют библиотеки (например keras_ocr), которые могли бы нам помочь сгенерировать случайный текст, поместить в него ссылку и отрисовать текст на различных фонах и с различными искажениями.
✓ Преимущество данного подхода заключается в том, что он дает нам возможность сгенерировать экземпляры данных с разными шрифтами, лигатурами, цветами текста и фона. Это помогло бы нам избежать проблемы переученности модели. Модель могла-бы легко обобщать свои «знания» в случае с изображениями, которые она не видела ранее.
✓ Этот подход дает нам возможность сгенерировать разные типы ссылок, таких как: http://
, http://
, ftp://
, tcp://
и пр. Ведь найти множество реальных изображений с разными типами ссылок могло бы стать проблемой.
✓ Еще одним преимуществом этого подхода является то, что мы можем сгенерировать столько изображений сколько хотим. Мы не ограничены количеством страниц со ссылками в книге, которую нам удалось найти. Увеличение набора данных может в итоге улучшить точность модели.
✗ С другой стороны, существует возможность неправильного использования такого генератора, что в итоге может привести к набору
данных, который будет существенно отличаться от реальных изображений. Например, мы можем ошибочно применить неправдоподобные изгибы страниц (волна вместо дуги) или неправдоподобные фоны. Модель в таком может не обобщить свои «знания» на изображения из реального мира.
Этот подход мне кажется очень многообещающим. Он может помочь нам преодолеть множество недостатков модели (о них мы упомянем ниже в статье). Я пока еще не пробовал применить этот подход, но, возможно, это будет предметом отдельной статьи.
Вариант №3: Создание набора данных вручную
Наиболее прямолинейный способ — это взять книгу (или книги), сфотографировать странички, содержащие ссылки и обозначить локации префиксов https://
для каждой странички вручную.
Хорошая новость в том, что набор данных, который нам нужен, может быть достаточно небольшим (сотни изображений будет достаточно). Это обусловлено тем, что мы не собираемся тренировать модель с нуля. Вместо этого мы будем «доучивать» уже обученную модель (см. transfer learning и few-shot learning).
✓ В данном случае набор данных будет максимально приближен к реальному миру. Мы в буквальном смысле возьмем книгу, сфотографируем странички с реальными шрифтами, изгибами, тенями и цветами.
✗ С другой стороны, даже с учетом того, что нам нужны всего сотни страничек, работа по сбору таких страничек и их дальнейшей аннотации может занять достаточно много времени.
✗ Тяжело найти разные книги и журналы с разными шрифтами, типами ссылок, с разными фонами и лигатурами. В итоге набора данных будет достаточно узконаправленным (у пользователей должны будут быть книги со шрифтами и фонами похожими на ваши).
Поскольку целью этой статьи, как было упомянуто выше, не является создание модели, которая должна выиграть соревнование по обнаружению объектов, мы можем пойти по пути создан