Kornia — библиотека компьютерного зрения

Приветствую, Хабр! Меня зовут Дмитрий. На момент написания этой статьи я являюсь магистрантом второго курса в ИТМО. Компьютерным зрением начал заниматься четыре года назад. В этой статье я хочу рассказать вам о Kornia.

Kornia это open source библиотека для решения задач компьютерного зрения. Она использует PyTorch в качестве основного бэкенда и состоит из набора дифференцируемых процедур и модулей. Создатели библиотеки вдохновлялись OpenCV, и поэтому Kornia является его аналогом, но при этом в некоторых моментах превосходит. Основным преимуществом Kornia по сравнению с тем же OpenCV, scikit-image или с Albumentations является возможность обрабатывать изображения батчами, а не по одному изображению и возможность обрабатывать данные на GPU. Об этом даже упоминал один из создателей библиотеки Albumentations в своей статье, вот ссылка на неё. Также из любопытного, например, команда из Sber AI для ускорения своего пайплайна в задаче faceswap добавляла обработку кадров батчами и переписывала часть функций на Kornia, вот ссылка на их статью.

Помимо Kornia есть и другие инструменты для решения задач, связанных с компьютерным зрением:

Создатели Kornia в своей статье приводят таблицу сравнения нескольких библиотек для задач компьютерного зрения.

Таблица сравнения библиотек компьютерного зрения

Таблица сравнения библиотек компьютерного зрения

Как видим, по всем заявленным в таблице пунктам с Kornia может сравниться только Tensorflow.image.

Кроме того, создатели Kornia отдельно проводили замеры скорости аугментаций производимых с помощью разных библиотек.

Libraries

TorchVision

Albumentations

Kornia (GPU)

Batch Size

1

1

1

32

128

RandomPerspective

4.88±1.82

4.68±3.60

4.74±2.84

0.37±2.67

0.20±27.00

ColorJiggle

4.40±2.88

3.58±3.66

4.14±3.85

0.90±24.68

0.83±12.96

RandomAffine

3.12±5.80

2.43±7.11

3.01±7.80

0.30±4.39

0.18±6.30

RandomVerticalFlip

0.32±0.08

0.34±0.16

0.35±0.82

0.02±0.13

0.01±0.35

RandomHorizontalFlip

0.32±0.08

0.34±0.18

0.31±0.59

0.01±0.26

0.01±0.37

RandomRotate

1.82±4.70

1.59±4.33

1.58±4.44

0.25±2.09

0.17±5.69

RandomCrop

4.09±3.41

4.03±4.94

3.84±3.07

0.16±1.17

0.08±9.42

RandomErasing

2.31±1.47

1.89±1.08

2.32±3.31

0.44±2.82

0.57±9.74

RandomGrayscale

0.41±0.18

0.43±0.60

0.45±1.20

0.03±0.11

0.03±7.10

RandomResizedCrop

4.23±2.86

3.80±3.61

4.07±2.67

0.23±5.27

0.13±8.04

CenterCrop

2.93±1.29

2.81±1.38

2.88±2.34

0.13±2.20

0.07±9.41

Таблица показывает, что Kornia делает аугментации быстрее только если обрабатывать изображения батчами. Если делать это по одной картинке, то большой разницы по скорости нет.

Рассмотрим какие модули входят в состав Kornia.

Модули Kornia:

  • kornia.augmentation — модуль для аугментации данных.

  • kornia.color — модуль для преобразования цветового пространства (rgb to grayscale, rgb to bgr и т.д.).

  • kornia.contrib — экспериментальные модули. Есть модули для классификации изображений, сегментации, детекции. Например есть MobileViT, TinyViT.

  • kornia.enhance — инструменты для изменения яркости, контрастности, резкости и т.д. Есть выравнивание гистограммы, в том числе CLAHE.

  • kornia.feature — инструменты для детекции признаков, нахождения углов, линий и т.д. на изображениях.

  • kornia.filters — инструменты для размытия изображения и нахождения границ (Canny, Sobel и т.д.).

  • kornia.geometry — инструменты для геометрических преобразований: вращение, сдвиг, отражение и т.д.

  • kornia.sensors — экспериментальный модуль для работы с видео камерами.

  • kornia.io — модуль для загрузки и сохранения изображений. Для работы требует установки kornia_rs, который написан на Rust. Работает только на Linux.

  • kornia.image — инструменты модуля позволяют получать информацию об изображении: количество каналов, формат пикселей, позволяет клонировать изображение, загружать из файла, переводить в torch tensor из numpy array и обратно и т.д.

  • kornia.losses — модуль содержит имплементации различных функций ошибок.

  • kornia.metrics — имплементация различных метрик (accuracy, mean_iou и т.д.).

  • kornia.morphology — инструменты для морфологических операций:  сужение (erosion),  расширение (dilation),  открытие (opening),  закрытие (closing) и т.д…

  • kornia.tracking модуль для отслеживания целевого объекта по локальным признакам в последовательности кадров.

  • kornia.testing — модуль содержит утилиты для тестирования (сейчас их там всего три).

  • kornia.utils — инструменты для отрисовки линий, прямоугольников, выпуклых многоугольников, сохранения и загрузки облаков точек и т.д.

  • kornia.x — экспериментальное API для обучения моделей в Kornia. Позволяет учить модели для классификации изображений, семантической сегментации и детекции объектов.

Kornia обладает очень обширным функционалом и практически по каждому её модулю можно писать отдельную и очень большую статью. Рассказать обо всём в рамках одной статьи не представляется возможным. Поэтому рассмотрим примеры лишь некоторых возможностей Kornia.

Установка

Для установки нужно выполнить :

pip install kornia

Либо, если будете устанавливать через репозиторий гитхаба:

pip install git+https://github.com/kornia/kornia

Если будете использовать kornia.io для загрузки/сохранения изображений, то нужно будет установить kornia_rs, который реализован на языке Rust. Но работает kornia_rs только на Linux.

pip install kornia_rs

Кроме того, если планируете использовать экспериментальный Training API, то нужно ещё установить kornia[x].

pip install kornia[x]

Некоторые базовые возможности Kornia

Рассмотрим сначала более простые возможности Kornia, а потом уже перейдём к более сложным примерам. Так выглядит исходное изображение к которому мы будем применять некоторые базовые преобразования.

Исходное изображение

Исходное изображение

kornia.color

Сначала попробуем поменять цветовую схему изображения. Сделаем его серым.

# Функция для открытия изображений
def load_torch_image(file_name):
    image = K.image_to_tensor(cv2.imread(file_name), False).float() / 255
    image = K.color.bgr_to_rgb(image)
    return image


# Функция для совмещения двух изображений в одно (по горизонтали).
def concat_images_h(image1, image2):
    image1 = transforms.ToPILImage()(torch.squeeze(image1))
    image2 = transforms.ToPILImage()(torch.squeeze(image2))
    concatenated_images = Image.new('RGB', (image1.width + image2.width, image1.height))
    concatenated_images.paste(image1, (0, 0))
    concatenated_images.paste(image2, (image1.width, 0))
    return concatenated_images


# Загружаем изображение.
image_1 = load_torch_image("image_sea.jpg")

# Переведем изображение в серый цвет.
image_gray = K.color.rgb_to_grayscale(image_1)

# Объединим исходное изображение и его серый вариант в одну картинку.
concatenated_images_contrast = concat_images_h(image_1, image_gray)

После выполнения кода получаем следующее.

Применим к исходному изображению фильтр Sepia.

image_sepia = K.color.Sepia(rescale=False)(image_1)

concatenated_image_sepia = concat_images_h(image_1, image_sepia)

Получаем такой результат.

47bdd0d05baed30a61a43a62f5602581.jpg

kornia.enhance

Теперь попробуем изменить контраст.

# Изменяем контраст.
image_contrast = K.enhance.adjust_contrast(image_1, 1.5)

concatenated_images_contrast = concat_images_h(image_1, image_contrast)

После выполнения кода получается следующее изображение.

7408719c7c809beed3848e6cfeea3767.jpg

Применим адаптивное выравнивание гистограммы с помощью фильтра CLAHE.

# Применяем фильтр CLAHE с настройками по умолчанию к изображению.
image_clahe = K.enhance.equalize_clahe(image_1)

concatenated_images_clahe = concat_images_h(image_1, image_clahe)

save_tensor_as_image('concatenated_images_clahe.jpg', concatenated_images_clahe)

Получаем следующий результат.

3c8504248e3e8355c26642271b2303d1.jpg

kornia.filters

Теперь найдем на изображении границы с помощью алгоритма Canny.

# Применяем алгоритм Canny к изображению.
image_canny = K.filters.canny(image_1)[0]

concatenated_images_canny = concat_images_h(image_1, image_canny)

Получаем следующее.

c4ca68ebe1b3dec44ed31cd24ef56478.jpg

А теперь попробуем размытие Гаусса.

# Фильтр Гаусса.
gaussian_blur = K.filters.GaussianBlur2d((5, 5), (4.5, 4.5))

# Применяем фильтр Гаусса к изображению.
image_gaussian_blur = gaussian_blur(image_1)

concatenated_images_gaussian_blur = concat_images_h(image_1, image_gaussian_blur)

Получаем размытое изображение.

f2ed84450ecef5051072139687043f9c.jpg

kornia.losses

Отдельно хочу заострить внимание на том, что в Kornia есть имплементация пары десятков функций ошибок. Например, для задачи семантической сегментации предложено шесть лосс-функций. Я их перечислю.

Конечно можно написать функцию ошибки самому, но если есть готовая реализация и она подходит, зачем изобретать велосипед?

kornia.utils

Kornia позволяет рисовать на изображениях линии и прямоугольники. Нарисуем прямоугольник.

image_with_rectangle = K.utils.draw_rectangle(image_1, torch.tensor([[[10, 10, 400, 400]]]),
                                              color=torch.tensor([1,0, 255]), fill=None)

14d04eea6162ce3dddf659fef62161fd.jpg

А теперь нарисуем линию.

image_with_line = K.utils.draw_line(torch.squeeze(image_1),
                                    torch.tensor([10, 10]),
                                    torch.tensor([500, 250]),
                                    torch.tensor([1,0,255]))

3cdfa3aaa4b4e97872af2f411994e4e2.jpg

kornia.morphology

Теперь применим к изображению морфологические операции: градиент, открытие, закрытие. Есть две базовых морфологических операции: эрозия и дилатация. Градиент, открытие и закрытие являются их линейными комбинациями.

kernel = torch.tensor([[0, 1, 0], [1, 1, 1], [0, 1, 0]])

# Операция градиента это разность результата эрозии и дилатации.
graded_image = K.morphology.gradient(image_1, kernel)

72aed7e1d9c14918452a16d8a49d5981.jpg

# Операция открытия это применение операции эрозии, а после неё операции дилатации.
opened_image = K.morphology.opening(image_1, kernel)

0235d14cbcdbfa7600d9bd31d555ae3d.jpg

# Операция закрытия это применение операции дилатации, а после неё операции эрозии.
closed_image = K.morphology.closing(image_1, kernel)

2607edc6d0eb77befe73d852f8e62848.jpg

kornia.augmentation

Применим к исходному изображению некоторые аугментации доступные в Kornia. В качестве примера возьмём ColorJitter, RandomAffine, RandomPerspective.

from kornia.augmentation import AugmentationSequential

# Сделаем набор аугментаций и применим их последовательно. 
aug_list = AugmentationSequential(
    K.augmentation.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
    K.augmentation.RandomAffine(360, [0.1, 0.1], [0.7, 1.2], [30.0, 50.0], p=1.0),
    K.augmentation.RandomPerspective(0.5, p=1.0),
    data_keys=["input"], same_on_batch=False)

augmented_image = aug_list(image_1)

После применения аугментаций получается такое изображение.

38ce02b0021a732e3c3434cc4bd46b5c.jpg

kornia.geometry

Рассмотрим пару примеров геометрических преобразований: зеркальное отражение и поворот.

# Сделаем зеркальное изображение.
fliped_image = K.geometry.transform.Vflip()(image_1)

# Объединим исходное и зеркальное изображение для наглядности.
concatenated_fliped_image = concat_images_h(image_1, fliped_image)

После выполнения кода получаем следующее.

12ba6e69bce8f5c248e15ccb40f5c8c8.jpg

# Повернём изображение на 45 градусов.
rotated_image = K.geometry.transform.Rotate(torch.tensor([45.]))(image_1)

# Объединим исходное и повёрнутое изображение для наглядности.
concatenated_rotated_image = concat_images_h(image_1, rotated_image)

После выполнения кода получаем вот это.

d7293a4af4349d39f2a5364b7aacd430.jpg

Более сложные примеры

Теперь рассмотрим более сложные примеры возможностей Kornia.

Face Detection

В Kornia встроен детектор лиц, использующий модель YuNet. Давайте посмотрим, как он работает. Допустим, у нас есть такое изображение.

Ищем здесь лицо

Ищем здесь лицо

Попробуем найти на нём лицо.

import kornia as K
import torch
import cv2
from kornia.contrib import FaceDetector, FaceDetectorResult, FaceKeypoint
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image


# Функция для отрисовки ключевых точек лица на фото.
def draw_keypoint(img: np.ndarray, det: FaceDetectorResult, kpt_type: FaceKeypoint) -> np.ndarray:
    
    kpt = det.get_keypoint(kpt_type).int().tolist()
    
    return cv2.circle(img, kpt, 2, (255, 0, 0), 2)


# Функция для нахождения лица на фото.
def detect(img_raw):

    if img_raw is not None:
        img = K.image_to_tensor(img_raw, keepdim=False)
        img = K.color.bgr_to_rgb(img.float())

        # Применим детектор лиц к изображению.
        face_detection = FaceDetector()
        with torch.no_grad():
            dets = face_detection(img)
        dets = [FaceDetectorResult(o) for o in dets[0]]

        img_vis = img_raw.copy()

        vis_threshold = 0.95

        # Оставляем только те рамки, уверенность модели в которых превышает 0.95.
        for b in dets:
            if b.score < vis_threshold:
                continue
            
            # Нарисуем bounding box по найденным координатам
            img_vis = cv2.rectangle(img_vis, b.top_left.int().tolist(), b.bottom_right.int().tolist(), (255, 0, 0), 4)
            
            # Нарисуем найденные на лице ключевые точки
            img_vis = draw_keypoint(img_vis, b, FaceKeypoint.EYE_LEFT)
            img_vis = draw_keypoint(img_vis, b, FaceKeypoint.EYE_RIGHT)
            img_vis = draw_keypoint(img_vis, b, FaceKeypoint.NOSE)
            img_vis = draw_keypoint(img_vis, b, FaceKeypoint.MOUTH_LEFT)
            img_vis = draw_keypoint(img_vis, b, FaceKeypoint.MOUTH_RIGHT)
        

        return img_vis

# Функция для вывода изображения.
def image_show(image):
    plt.imshow(image)
    plt.axis('off')
    plt.show()



image = np.asarray(Image.open("image1.png"))

img_vis = detect(image)

image_show(img_vis)

После выполнения кода получаем следующее.

Лицо найдено

Лицо найдено

Видим, что не только Терминатор смог что-то обнаружить на фото)).

Image matching

Теперь попробуем найти похожие точки на двух следующих изображениях с помощью метода LoFTR: Detector-Free Local Feature Matching with Transformers доступного в Kornia.

Первое изображение

Первое изображение

Второе изображение

Второе изображение

# Загрузим наши изображения
image_1 = transforms.PILToTensor()(Image.open("c1.jpg")).float().unsqueeze(dim=0)/255
image_2 = transforms.PILToTensor()(Image.open("c2.jpg")).float().unsqueeze(dim=0)/255


# Применим Detector-Free Local Feature Matching with Transformers
matcher = KF.LoFTR(pretrained="outdoor")

# LofTR не работает с цветными изображениями. Поэтому сделаем их серыми.
input_dict = {
    "image0": K.color.rgb_to_grayscale(image_1),  
    "image1": K.color.rgb_to_grayscale(image_2),
}

# Для тех объектов первого и второго изображения, которые
# модель посчитала похожими получим координаты.
with torch.inference_mode():
    correspondences = matcher(input_dict)


# Отфильтруем координаты так, чтобы остались только те точки,
# в схожести которых модель уверена более чем на 0.8
mask = correspondences['confidence'] > 0.8
indices = torch.nonzero(mask)
correspondences['confidence'] = correspondences['confidence'][indices]
correspondences['keypoints0'] = correspondences['keypoints0'][indices]
correspondences['keypoints1'] = correspondences['keypoints1'][indices]
correspondences['batch_indexes'] = correspondences['batch_indexes'][indices]


mkpts0 = correspondences["keypoints0"].cpu().numpy()
mkpts1 = correspondences["keypoints1"].cpu().numpy()

# Применим дополнительный фильтр findFundamentalMat() чтобы отсеять неподходящие точки.
# Лишние точки на результирующем изображении будут отмечены синим цветом.
Fm, inliers = cv2.findFundamentalMat(mkpts0, mkpts1,
                                     cv2.USAC_MAGSAC, 0.5, 0.999, 100000)
inliers = inliers > 0


# Теперь выведем оба изображения и точки, которые модель посчитала похожими,
# соединив их линиями.
draw_LAF_matches(
    KF.laf_from_center_scale_ori(
        torch.from_numpy(mkpts0).view(1, -1, 2),
        torch.ones(mkpts0.shape[0]).view(1, -1, 1, 1),
        torch.ones(mkpts0.shape[0]).view(1, -1, 1),
    ),
    KF.laf_from_center_scale_ori(
        torch.from_numpy(mkpts1).view(1, -1, 2),
        torch.ones(mkpts1.shape[0]).view(1, -1, 1, 1),
        torch.ones(mkpts1.shape[0]).view(1, -1, 1),
    ),
    torch.arange(mkpts0.shape[0]).view(-1, 1).repeat(1, 2),
    K.tensor_to_image(image_1),
    K.tensor_to_image(image_2),
    inliers,
    draw_dict={"inlier_color": (0.2, 1, 0.2), "tentative_color": None,
               "feature_color": (0.2, 0.5, 1), "vertical": False} )

После выполнения кода получается такое изображение.

2d3610a381100cf41db8f53f0d6215c3.png

Точки на изображениях, которые алгоритм посчитал схожими соединены линиями. Синие точки это то, что дополнительно фильтровал метод findFundamentalMat.

Image Stitching

Kornia также позволяет объединить несколько изображений имеющих пересекающиеся области чтобы получить одно общее изображение.

На сайте НАСА я нашёл фото сделанные марсоходом Perseverance около кратера Езеро (Jezero). Давайте потренируемся на них.

Окрестности Марса около кратера Езеро.

Окрестности Марса около кратера Езеро.

И снова окрестности Марса.

И снова окрестности Марса.

# Загрузим наши изображения
image_1 = transforms.PILToTensor()(Image.open("mars1.jpg")).float().unsqueeze(dim=0)/255
image_2 = transforms.PILToTensor()(Image.open("mars2.jpg")).float().unsqueeze(dim=0)/255

# Добавим оба изображения в список.
images = [image_1, image_2]

# Сшиватель изображений.
IS = ImageStitcher(KF.LoFTR(pretrained="outdoor"), estimator="ransac")

# Применим сшиватель изображений к списку наших фото.
with torch.no_grad():
    out = IS(*images)

# Выведем получившееся фото на экран.
plt.imshow(K.tensor_to_image(out))
plt.show()

.

Вот, что мы получим после выполнения кода.

Объединение изображений

Объединение изображений

Похоже, что каким-то подобным образом снимки, сделанные марсоходами, объединяют чтобы получить панорамные изображения Марса. Теперь мы умеем делать что-то наподобие).

Visual Prompting

Visual Prompting представляет собой использование так называемых визуальных подсказок — ключевых точек или ограничивающих рамок, которые вместе с искомым изображением передаются в предобученную модель, например в визуальный трансформер. И модель пытается по этим подсказкам произвести, например, сегментацию области, точки которой похожи на те ключевые точки, что мы передали модели в качестве примера. В качестве модели будем использовать предобученный визуальный трансформер ViT-H SAM. Дальше будет много кода, но большая его часть это вспомогательные функции для отрисовки ключевых точек и вывода изображения на экран.

import kornia as K
import torch
import cv2
import matplotlib.pyplot as plt
from kornia.contrib.visual_prompter import VisualPrompter
from kornia.geometry.keypoints import Keypoints
from kornia.geometry.boxes import Boxes
from kornia.contrib.models import SegmentationResults
from kornia.utils import tensor_to_image


# Определяем девайс: видеокарту или процессор.
if torch.cuda.is_available():
   print("Training on GPU")
   device = torch.device("cuda:0")
else:
     print("Training on CPU")
     device = torch.device("cpu")


# Функция для открытия изображения.
def load_torch_image(file_name):
    image = K.image_to_tensor(cv2.imread(file_name), False).float() / 255
    image = K.color.bgr_to_rgb(image)
    return image


# Функция для вывода изображения на экран.
def show_image(tensor):
    image = K.utils.tensor_to_image(tensor)
    plt.imshow(image)
    plt.axis('off')
    plt.show()


# Загрузим наши изображения
image_1 = load_torch_image("image_sea.jpg")

# Посмотрим на открытое изображение.
show_image(image_1)

Сначала загружаем VisualPrompter. По умолчанию это будет ViT-H SAM, который весит почти 2,5 гигабайта. Но можно загрузить ViT-L SAM он весит в два раза меньше, либо ViT-B SAM -он весит 348 мегабайт. Потом передаём в загруженную модель наше изображение.

prompter = VisualPrompter()
prompter.set_image(torch.squeeze(image_1))

Теперь создаем визуальные подсказки — ключевые точки по которым модель будет пытаться сегментировать области. Первые четыре ключевых точки это точки относящиеся только к воде, и на финальном изображении они будут отмечены зелёными крестиками. Остальные четыре точки относятся к небу, облакам, солнцу и на финальном изображении они будут отмечены красными крестиками. Дальше создаём метки 0 и 1. За единицу возьмём точки относящиеся к воде, остальные будут нулём.

keypoints = Keypoints(torch.tensor([[[100.0, 100.0], [225.0, 475.0], [485.0, 450.0], [547.0, 384.0],
                                     [10.0, 600.0], [100.0, 590.0], [200.0, 580.0], [485.0, 566.0]]],
                                    device=device, dtype=torch.float32))

labels = torch.zeros(keypoints.shape[:2], device=device, dtype=torch.float32)

labels_to_query = labels.clone()
labels_to_query[..., 4:] = 1

Теперь запускаем модель, передав ей ключевые точки и метки.

prediction = prompter.predict(
   keypoints=keypoints,
   keypoints_labels=labels_to_query,
   multimask_output=False, output_original_size=True)

Дальше идут вспомогательные функции для отображения маски, ключевых точек и вывода результата на экран.

def colorize_masks(binary_masks: torch.Tensor, merge: bool = True, alpha: None | float = None) -> list[torch.Tensor]:
    """Convert binary masks (B, C, H, W), boolean tensors, into masks with colors (B, (3, 4) , H, W) - RGB or RGBA. Where C refers to the number of masks.
    Args:
        binary_masks: a batched boolean tensor (B, C, H, W)
        merge: If true, will join the batch dimension into a unique mask.
        alpha: alpha channel value. If None, will generate RGB images

    Returns:
        A list of `C` colored masks.
    """
    B, C, H, W = binary_masks.shape
    OUT_C = 4 if alpha else 3

    output_masks = []

    for idx in range(C):
        _out = torch.zeros(B, OUT_C, H, W, device=binary_masks.device, dtype=torch.float32)
        for b in range(B):
            color = torch.rand(1, 3, 1, 1, device=binary_masks.device, dtype=torch.float32)
            if alpha:
                color = torch.cat([color, torch.tensor([[[[alpha]]]], device=binary_masks.device, dtype=torch.float32)], dim=1)

            to_colorize = binary_masks[b, idx, ...].view(1, 1, H, W).repeat(1, OUT_C, 1, 1)
            _out[b, ...] = torch.where(to_colorize, color, _out[b, ...])
        output_masks.append(_out)

    if merge:
        output_masks = [c.max(dim=0)[0] for c in output_masks]

    return output_masks


def show_binary_masks(binary_masks: torch.Tensor, axes) -> None:
    """plot binary masks, with shape (B, C, H, W), where C refers to the number of masks.

    will merge the `B` channel into a unique mask.
    Args:
        binary_masks: a batched boolean tensor (B, C, H, W)
        ax: a list of matplotlib axes with lenght of C
    """
    colored_masks = colorize_masks(binary_masks, True, 1.0)

    for ax, mask in zip(axes, colored_masks):
        ax.imshow(tensor_to_image(mask))


def show_boxes(boxes: Boxes, ax) -> None:
    boxes_tensor = boxes.to_tensor(mode="xywh").detach().cpu().numpy()
    for batched_boxes in boxes_tensor:
        for box in batched_boxes:
            x0, y0, w, h = box
            ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor="orange", facecolor=(0, 0, 0, 0), lw=2))


def show_points(points: tuple[Keypoints, torch.Tensor], ax, marker_size=200):
    coords, labels = points
    pos_points = coords[labels == 1].to_tensor().detach().cpu().numpy()
    neg_points = coords[labels == 0].to_tensor().detach().cpu().numpy()

    ax.scatter(pos_points[:, 0], pos_points[:, 1], color="green", marker="+", s=marker_size, linewidth=2)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color="red", marker="x", s=marker_size, linewidth=2)



def show_predictions(
    image: torch.Tensor,
    predictions: SegmentationResults,
    points: tuple[Keypoints, torch.Tensor] | None = None,
    boxes: Boxes | None = None,
) -> None:
    n_masks = predictions.logits.shape[1]

    fig, axes = plt.subplots(1, n_masks, figsize=(21, 16))
    axes = [axes] if n_masks == 1 else axes

    for idx, ax in enumerate(axes):
        score = predictions.scores[:, idx, ...].mean()
        ax.imshow(tensor_to_image(image))
        ax.set_title(f"Mask {idx+1}, Score: {score:.3f}", fontsize=18)

        if points:
            show_points(points, ax)

        if boxes:
            show_boxes(boxes, ax)

        ax.axis("off")

    show_binary_masks(predictions.binary_masks, axes)
    plt.show()

Код вывода результата на экран.

show_predictions(image_1, prediction, points=(keypoints, labels_to_query))

Получаем изображение с найденной маской. По четырём ключевым точкам, относящимся к воде модель сегментировала всю воду на исходной фотографии. Дополнительно на фото отрисованы использованные ключевые точками (зелёные и красные крестики).

21474aa9499fe4665e8b0c2a09cf6e41.png

Посмотреть другие примеры использования визуальных подсказок можно по этой ссылке.

Заключение

В этой статье мы рассмотрели далеко не весь перечень возможностей Kornia. Их гораздо больше. Создатели Kornia обещали в будущем добавить алгоритмы для следующих областей: Super Resolution, Deep Edge detection, Stereo and Optical flow and camera calibration, Neural Rendering. Что же, ждём. А пока, скажем им спасибо за то, что уже есть.

Получить более подробную информацию о Kornia можно по следующим ссылкам.

Ссылки:

  1. Репозиторий Kornia

  2. Официальная документация по Kornia

  3. Туториал

Код и изображения из этой статьи доступны в ноутбуках по ссылке.

Большое спасибо, что прочитали! Если вы уже знакомы с Kornia, поделитесь пожалуйста своими впечатлениями или опытом использования в комментариях.

© Habrahabr.ru