Окрашивание изображений

Привет, Хабр. Сегодня мы будем раскрашивать.

image-loader.svg

8 лет назад НЛО выдало мне инвайт за статью про окрашивание изображений. Сегодня мы вернемся к этой теме и посмотрим на одну из свежих работ в этой области: Color2Embed: Fast Exemplar-Based Image Colorization using Color Embeddings. Здесь будет мой вольный пересказ с реализацией и комментариями, ну и много картинок.

TL; DR:

В статье я расскажу как переносить цвет с одной картинки на другую с помощью смеси из U-Net и StyleGAN v2. Если эти слова вам ни о чем не говорят — пролистайте вниз, там красивые картинки. Мой код здесь, код авторов тут, еще я сделал Google Colab тетрадку, можно попробовать свои картинки.

Инструкция по запуску Google Colab

Google Colab — Это такой способ от Гугла запускать Jupyter Notebook в браузере на их мощностях с доступом к GPU и TPU и делиться ими.

  1. Открыть ссылку;

  2. Нажать справа сверху Подключиться (Connect) если нужно, если нет там будет информация об ОЗУ и диске;

  3. Нажать в меню сверху Среда выполнения → выполнить все (Runtime → run all);

  4. Прокрутить вниз до куска с кодом:

    target_image_uploaded = files.upload()
    ref_image_uploaded = files.upload()
    target_image_path = list(target_image_uploaded.keys())[0]
    ref_image_path = list(ref_image_uploaded.keys())[0]
    predict_colors(target_image_path, ref_image_path, model)

    Подождать, пока выполнится до него (клонирование репозитория, загрузка весов и т.д., займет минуту-две) и появится кнопка:

    Кнопка выбора картинкиКнопка выбора картинки
  5. Выбираем сперва картинку (по одной за раз, там две кнопки на каждую картинку, вторая появится после выбора первой), которую хотим покрасить, потом картинку, с которой хотим взять цвета;

  6. Работает меньше секунды (~300 мс) после загрузки картинок. Пример, как это выглядит:

    image-loader.svg

Обучение, вид сверху

Процесс обучения, картинка из оригинальной статьи:

image-loader.svg

Для обучения за раз нужна только одна цветная картинка, GT(ground truth). Из нее делается два входа: с помощью аугментаций (об этом ниже) получаем измененную трехканальную RGB картинку, на схеме это R_{rgb}; вторую получаем с помощью преобразования GTв пространство Lab — берем только один канал L (lightness), на схеме это T_L. Каналы aи bтоже сохраняем, они пригодятся для функции потерь в дальнейшем, обозначаются ониGT_{ab}.

Выходом являются два предсказанных цветных канала для заданной T_L — каналы aи b, обозначаются они P_{ab}, или, когда соединяем с T_L, — P_{Lab}.

Подавать одну картинку на вход очень удобно, так мы можем брать базу без разметки, не нужны ни классы, ни какие-то другие дополнительные характеристики.

Система состоит из трех частей: Кодировщика цвета (Color encoder, E_c), кодировщика контента (Content encoder, E_s) и непереводимой PFFN (progressive feature formalisation network). Кодировщик цвета сжимает цветную картинку до вектора с информацией о цвете, Кодировщик контента сжимает одноканальную картинку до вектора «смысла», а PFFN на основе всего этого генерирует цветовые каналы.

Функций потерь две:

  1. \mathcal{L}_{recon}reconstruction loss, он же SmoothL1Loss, сравнивает GT_{ab}и P_{ab}.

  2. \mathcal{L}_{perc}perceptual loss, тот же L1Loss, но на выходах VGG сети. На вход получает исходную картинку в RGB GT и предсказанную картинку в RGB P_{rgb}, полученную из P_{Lab}.

Общий лосс является взвешенной суммой: \mathcal{L}_{total}=\mathcal{L}_{recon} + 0.1*\mathcal{L}_{perc}

Для себя во время реализации я нарисовал такую схему — может, кому-то она будет понятней:

image-loader.svg

Аугментации цветного изображения

Авторы выдвигают логичный тезис: если просто подавать цветную картинку как есть и пытаться ее же раскрасить, то у сетки появится соблазн переобучиться, так как появится однозначное соответствие пикселей. Поэтому они предлагают сделать ряд аугментаций, от простых (повороты и шум), до интересной: давайте применим к картинке thin plate spline (TSP), что позволяет достаточно забавно искажать её:

image-loader.svg

В моем блокноте можно посмотреть пример вызова из кода (секция TSP Example), сам код в tsp.py

Кодировщик цвета

Кодировщик цвета E_cявляется самой простой частью системы — это обычная сеть для классификации. Авторы изобрели что-то простое свое, я не увидел смысла в этом костылестроении и взял классический быстрый и маленький resnet18.

Выходом является вектор, содержащий информацию о цвете, обозначен Z. Это очень похоже на вектор стиля в StyleGAN. Размерность его авторы предлагают взять 512, тут я с ними спорить не стал.

Кодировщик контента+ PFFN

Кодировщик контента и PFFN вместе представляют собой похожую на U-Net сеть с одним большим отличием: в части «развертки», кроме информации с предыдущих этапов, нам добавляется информация о цвете с кодировщика цвета. Авторы называют это PFFN, идея взята из StyleGANv2, и сердцем ее является нечто под название Modulated Convolution. Modulated Convolution — это такая свертка, с весами которой мы немного поиграем.Остановимся на этом подробнее, схема из статьи:

image-loader.svgКод на PyTorch

class ModulatedConv2d(nn.Module):
    def __init__(
            self,
            in_channel,
            out_channel,
            kernel_size,
            style_dim
    ):
        # Part from https://github.com/rosinality/stylegan2-pytorch/blob/a2f38914bb5049894c37f2d7a9854bc130cf8a27/model.py
        super().__init__()

        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel

        fan_in = in_channel * kernel_size ** 2
        self.scale = 1 / math.sqrt(fan_in)
        self.padding = kernel_size // 2

        self.weight = nn.Parameter(
            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
        )
        self.weight_linear = nn.Parameter(torch.randn(in_channel, style_dim))
        self.bias_linear = nn.Parameter(torch.zeros(in_channel).fill_(1))

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size})"
        )

    def forward(self, input_, style):
        batch, in_channel, height, width = input_.shape

        # Linear
        style = F.linear(style, self.weight_linear * self.scale, bias=self.bias_linear)

        # Dot
        weight = self.scale * self.weight * style.view(batch, 1, in_channel, 1, 1) 

        # Norm
        Fnorm = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
        weight = weight * Fnorm.view(batch, self.out_channel, 1, 1, 1)

        # Convolve
        weight = weight.view(batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size)
        input_ = input_.view(1, batch * in_channel, height, width)
        out = F.conv2d(input_, weight, padding=self.padding,  groups=batch)

        _, _, height, width = out.shape
        out = out.view(batch, self.out_channel, height, width)
        return out

Входа у нас два: вектор цвета Zи выход предыдущего слоя-свертки d. Кроме того у нас есть набор весов: веса свертки и веса линейного слоя. Линейный слой нам нужен, чтобы изменить размерность вектора Z, так как количество каналов в свертке зависит от того, на каком мы этапе (чем ближе к выходу, тем меньше каналов). То есть, после него будет вектор цвета, но уже размерностью не 512, а C.

Второе наше действие ключевое: мы умножаем веса свертки wна полученный вектор цвета (dot product), получаем w'. Кроме того, после умножения мы нормируем полученные веса, получаем w''.

Нам осталось только одно действие — непосредственно свертка, полученные фичи gидут в следующий этап.

Подробный разбор этого приема описан в статье, там он был придуман для удаления артефактов при генерации изображений.

Modulated Convolution вставляется после каждого блока «развертки» U-Net, и только этим PFFN отличается от классического варианта.

Предсказание

На этапе обучения для получения цвета мы берем ту же картинку, что и хотим окрасить. Очевидно, на предсказании так сделать не получится. Но наши аугментации оказали нам услугу: если взять похожую цветную картинку, то все хорошо работает, при этом совсем не обязательно брать полутоновое изображение для окраски (хотя информация о цвете не будет никак учитываться).

Посмотрим, как воспроизвелись результаты из оригинальной статьи:

Хорошие результатыХорошие результаты

К чести авторов, моя реализация воспроизводит результаты из статьи практически один в один. Правда, есть ложка дёгтя: в статье приведены хорошие примеры, а если взять примеры из другой статьи, которую авторы улучшают и с которой сравниваются, то проблемы становятся видны:

Не очень хорошие результатыНе очень хорошие результаты

Также можно перекрасить одну картинку разными:

Когда я пробовал другие картинки, заметил, что найти удачные примеры сильно сложнее, чем неудачные. Многие картинки остаются бесцветными, или скатываются в сепию, например, как этот кот:

Бонус: картинка, которую я красил 8 лет назад все еще красится хорошо:

image-loader.svg

Напомню, что в colab можно попробовать свои картинки.

Критика

  1. Результаты, показанные в статье, можно воспроизвести — это очень здорово и не часто встречается в наше время.

  2. Использование той же самой картинки в обучении — это очень удобно, но, я думаю, это является причиной узкого диапазона работ: картинки должны быть очень похожи для хорошего результата. Предыдущая идея — брать картинку из того же класса — кажется более робастной.

  3. Скорость является киллер-фичей работы, она в разы выше, чем у конкурентов, составляющие части маленькие и быстрые.

  4. Мы используем только одну картинку, с которой будем брать цвет. Идея со статьи предшественника — использовать несколько картинок-источников цвета — кажется, может достаточно легко улучшить качество, технических трудностей для реализации немного.

Полезные ссылки:

  1. Сама статья: https://arxiv.org/abs/2106.08017

  2. Статья-ориентир: https://arxiv.org/abs/1807.06587

  3. StyleGANv2 (Modulated Convolutional): https://arxiv.org/abs/1912.04958

  4. Код авторов (PyTorch): https://github.com/zhaohengyuan1/Color2Embed

  5. Код мой (PyTorch): https://github.com/Kwentar/Color2Embed_pytorch

  6. Colab: https://colab.research.google.com/drive/1Xyq-kuTWzvoQH4r7d5C7YN7sVe19pUv0#scrollTo=6Nm25AlJzmyn

© Habrahabr.ru