[Перевод] Высокоэффективная генерация изображений на KerasCV с помощью Stable Diffusion

544297fa0f9ed573e5d2ad1c9f36b863.png

Сегодня покажем, как генерировать новые изображения по текстовому описанию при помощи KerasCV, stability.ai и Stable Diffusion. Материал подготовлен к старту нашего флагманского курса по Data Science.

Stable Diffusion — это мощная модель генерации изображений по текстовым описаниям с открытым исходным кодом. Решений с открытым кодом для генерации изображений по описаниям немало, но KerasCV выделяется из них рядом преимуществ, в том числе компилляцией XLA (ускоренной линейной алгебры) и поддержкой «смешанной точности». Вместе они позволяют достичь очень высокой скорости генерации.

Сегодня мы разберём реализацию Stable Diffusion от KerasCV, покажем, как использовать эти мощные средства повышения производительности, и изучим преимущества, которые они дают.

Сначала установим пакеты зависимости и разберёмся с модулями:

!pip install --upgrade keras-cv
import time
import keras_cv
from tensorflow import keras
import matplotlib.pyplot as plt


Введение

В отличие от других обучающих материалов, здесь вначале объясняется, как происходит внедрение. В случае генерации изображений по текстовым описаниям это проще показать, чем рассказать.

Посмотрим, насколько сильна keras_cv.models.StableDiffusion().

Сначала построим модель:

model = keras_cv.models.StableDiffusion(img_width=512, img_height=512)

Затем создадим текстовое описание. Например, «фото астронавта верхом на коне» (оригинальный запрос — «photograph of an astronaut riding a horse»):

images = model.text_to_image("photograph of an astronaut riding a horse", batch_size=3)


def plot_images(images):
    plt.figure(figsize=(20, 20))
    for i in range(len(images)):
        ax = plt.subplot(1, len(images), i + 1)
        plt.imshow(images[i])
        plt.axis("off")


plot_images(images)
25/25 [==============================] - 19s 317ms/step

5511cb3ed9a41966baabd30bda565e38.png

Просто невероятно!

Но это ещё не все возможности модели. Попробуем ввести запрос посложнее. Например, «симпатичная волшебная летающая собака, фэнтези-арт», «золотого цвета, высокое качество, высокая детализация, изящная форма, чёткая фокусировка», «концептуальный дизайн, концепция персонажа, цифровая живопись, тайна, приключение»:

images = model.text_to_image(
    "cute magical flying dog, fantasy art, "
    "golden color, high quality, highly detailed, elegant, sharp focus, "
    "concept art, character concepts, digital painting, mystery, adventure",
    batch_size=3,
)
plot_images(images)
25/25 [==============================] - 8s 316ms/step

6f04b237bc2d8bb68e9ece3667772d4b.png

Возможностям просто нет предела (как минимум они раскрывают всё скрытое многообразие Stable Diffusion).


Стоп, как это вообще работает?

Что бы вы сейчас себе ни вообразили, ничего волшебного в Stable Diffusion нет. Это что-то вроде «латентной диффузионной модели». Давайте докопаемся до смысла этого термина. Возможно, вам знаком принцип сверхвысокого разрешения (super-resolution): можно обучить модель глубокого обучения удалению шума на исходном изображении. Тем самым мы превратим изображение в версию с высоким разрешением. Модель глубокого обучения не может по волшебству «вытащить» утраченную информацию из шума на фото с низким разрешением. Вместо этого она дорисовывает полученные при обучении данные так, чтобы создать иллюзию наиболее вероятных деталей. Подробнее о сверхвысоком разрешении — в следующих материалах Keras.io:

utt99qe4quhe13zcrc2vpmqecva.png

И, если вы выжали из этой идеи всё, то спросите себя:, а что, если запустить модель на «чистом шуме»? Тогда ей придётся «удалить шум у шума» и создать полностью новое изображение. Повторим это много раз, и маленький фрагмент шума будет превращаться в удивительно чёткое искусственное изображение с высоким разрешением.

Ключевая мысль латентной диффузии изложена в High-Resolution Image Synthesis with Latent Diffusion Models в 2020 году. Чтобы лучше понять принципы диффузии, можно посмотреть обучающий материал на Keras.io Неявные модели диффузии для удаления шума (Denoising Diffusion Implicit Models).

2c67e69399087ea9720df8d86f4eee5f.gif

Для перехода от латентной диффузии к системе создания изображения по текстовому описанию нужно добавить всего одно ключевое свойство: управление генерируемыми изображениями через ключевые слова текстового описания. В этом поможет «стабилизация» (conditioning) — классический метод, который состоит в привязке к фрагменту шума вектора, представляющего собой кусочек текста, а затем обучения модели на наборе данных пар изображений и описаний {image: caption}.

Это даёт начало архитектуре Stable Diffusion, которая состоит из трёх блоков:


  • кодер текста. Блок преобразует текстовое описание в латентный вектор (latent vector);
  • диффузионная модель. Она многократно удаляет шум с фрагмента изображения 64×64 в латентном состоянии;
  • декодер превращает фрагмент конечного изображения 64×64 в изображение с разрешением 512×512.

Сначала текстовое описание проектируется в предварительно изученное пространство собственных векторов, языковую модель с «замороженными весами». Затем такой собственный вектор связывается со случайно генерируемым фрагментом шума, который проходит итерации «удаления шума» в декодере. По умолчанию их 50. И чем больше, тем красивее и чётче конечное изображение. По умолчанию фрагмент проходит 50 итераций.

Наконец, скрытое изображение 64×64 проходит через декодер, что даёт корректную визуализацию высокого разрешения.

Архитектура Stable Diffusion

Архитектура Stable Diffusion

Это очень простая система: для реализации Keras достаточно четырёх файлов на 500 строк в общей сложности:

Однако после изучения миллиардов изображений и их текстовых описаний работа этой простой системы кажется волшебством. Фейнман о Вселенной сказал: «Она не сложна, её просто много!»


«Плюшки» KerasCV

Почему из нескольких общедоступных реализаций Stable Diffusion стоит использовать именно keras_cv.models.StableDiffusion?

Помимо простого API модель Stable Diffusion от KerasCV даёт важные преимущества:


  • реализацию в графовом режиме;
  • компиляцию XLA при помощи jit_compile=True;
  • поддержку вычислений со смешанной точностью.

Объединяя все эти преимущества, модель Stable Diffusion от KerasCV работает на порядки быстрее наивных реализаций. В этом разделе описываются активация всех этих функций и повышение эффективности при их использовании.

Ради интереса мы сравнили по времени реализации Stable Diffusion для «диффузоров» от HuggingFace и от KerasCV. В обоих случаях поставили задачу сгенерировать 3 изображения с 50 итерациями для каждого. Тестирование проводилось на Tesla T4.

Исходный код сравнительных тестов находится в открытом доступе на GitHub, его можно перезапускать на Colab и воспроизвести результаты. Вот результаты тестов на время генерации:

На Tesla T4 генерация ускоряется на 30%! Хотя улучшение результата на V100 не столь впечатляет, в целом мы ожидаем, что результаты таких тестов на всех графических процессорах NVIDIA всегда будут в пользу KerasCV.

Ради полноты картины мы привели время генерации при холодном и при прогретом старте. Время холодного старта включает единовременные затраты на создание и компиляцию модели, поэтому в производственной среде, где вы будете много раз использовать один и тот же экземпляр модели, оно ничтожно мало. Тем не менее приводим цифры для холодного старта:

Хотя результаты выполнения задач из этого руководства могут быть разными, у нас в тесте реализация Stable Diffusion на KerasCV оказалась гораздо быстрее, чем на PyTorch. В значительной степени это может быть связано с компиляцией XLA.


Эффективность улучшается при каждой оптимизации и может заметно меняться от одной конфигурации «железа» к другой.

Для начала давайте испытаем нашу неоптимизированную модель с текстовым описанием «Симпатичная выдра держит ракушки в радужном водовороте. Акварель»:

benchmark_result = []
start = time.time()
images = model.text_to_image(
    "A cute otter in a rainbow whirlpool holding shells, watercolor",
    batch_size=3,
)
end = time.time()
benchmark_result.append(["Standard", end - start])
plot_images(images)

print(f"Standard model: {(end - start):.2f} seconds")
keras.backend.clear_session()  # Clear session to preserve memory.
25/25 [==============================] - 8s 316ms/step
Standard model: 8.17 seconds

3_qycibe1h8emp-ybpthhlhoqdw.png


Смешанная точность

«Смешанная точность» использует вычисления с точностью float16, храня при этом веса в формате float32. Благодаря этому операции float16 поддерживаются гораздо более быстрыми ядрами, чем аналоги для float32 на современных графических процессорах NVIDIA.

Использовать смешанную точность в Keras (в том числе для keras_cv.models.StableDiffusion) просто:

keras.mixed_precision.set_global_policy("mixed_float16")
INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: NVIDIA A100-SXM4-40GB, compute capability 8.0

Просто работает.

odel = keras_cv.models.StableDiffusion()

print("Compute dtype:", model.diffusion_model.compute_dtype)
print(
    "Variable dtype:",
    model.diffusion_model.variable_dtype,
)
Compute dtype: float16
Variable dtype: float32

Как видите, собранная выше модель использует расчёты со смешанной точностью. А мы используем скорость операции float16 при хранении переменных с точностью float32.

# Warm up model to run graph tracing before benchmarking.
model.text_to_image("warming up the model", batch_size=3)

start = time.time()
images = model.text_to_image(
    "a cute magical flying dog, fantasy art, "
    "golden color, high quality, highly detailed, elegant, sharp focus, "
    "concept art, character concepts, digital painting, mystery, adventure",
    batch_size=3,
)
end = time.time()
benchmark_result.append(["Mixed Precision", end - start])
plot_images(images)

print(f"Mixed precision model: {(end - start):.2f} seconds")
keras.backend.clear_session()
25/25 [==============================] - 15s 226ms/step
25/25 [==============================] - 6s 226ms/step
Mixed precision model: 6.02 seconds

yoqdp-e27-p6z1bjyiwls8qojls.png


Компиляция XLA

TensorFlow включает встроенный компиллятор XLA XLA: Accelerated Linear Algebra — ускоренная линейная алгебра. keras_cv.models.StableDiffusion поддерживает работу аргумента jit_compile «из коробки». Значение True активирует компиляцию XLA для этого аргумента, что приводит к значительному ускорению.

Воспользуемся этим в примере с «креслом цвета авокадо»:

# Set back to the default for benchmarking purposes.
keras.mixed_precision.set_global_policy("float32")

model = keras_cv.models.StableDiffusion(jit_compile=True)
# Before we benchmark the model, we run inference once to make sure the TensorFlow
# graph has already been traced.
images = model.text_to_image("An avocado armchair", batch_size=3)
plot_images(images)
25/25 [==============================] - 36s 245ms/step


image

Оценим эффективность нашей модели XLA:

start = time.time()
images = model.text_to_image(
    "A cute otter in a rainbow whirlpool holding shells, watercolor",
    batch_size=3,
)
end = time.time()
benchmark_result.append(["XLA", end - start])
plot_images(images)

print(f"With XLA: {(end - start):.2f} seconds")
keras.backend.clear_session()
25/25 [==============================] - 6s 245ms/step
With XLA: 6.27 seconds

dukfkvv4aradbf85jbei_izq-zq.png

На графическом процессоре A100 получилось примерно двукратное ускорение. Чудеса!


Соберём всё воедино

И как же собрать самый производительный в мире (по состоянию на сентябрь 2022 года) стабильный конвейер диффузионного вывода?

С помощью двух строк кода:

keras.mixed_precision.set_global_policy("mixed_float16")
model = keras_cv.models.StableDiffusion(jit_compile=True)

и текстового описания «Плюшевые мишки ведут исследования в области машинного обучения»:

# Let's make sure to warm up the model
images = model.text_to_image(
    "Teddy bears conducting machine learning research",
    batch_size=3,
)
plot_images(images)
25/25 [==============================] - 39s 157ms/step

4d6601216059384e42ab7fb2717d04a9.png

Насколько быстро это работает? Сейчас разберёмся! Пусть теперь у нас «Таинственный тёмный незнакомец посещает египетские пирамиды», «высокое качество, высокая детализация, изящная форма, чёткая фокусировка», «концептуальный дизайн, концепция персонажа, цифровая живопись»:

start = time.time()
images = model.text_to_image(
    "A mysterious dark stranger visits the great pyramids of egypt, "
    "high quality, highly detailed, elegant, sharp focus, "
    "concept art, character concepts, digital painting",
    batch_size=3,
)
end = time.time()
benchmark_result.append(["XLA + Mixed Precision", end - start])
plot_images(images)

print(f"XLA + mixed precision: {(end - start):.2f} seconds")
25/25 [==============================] - 4s 158ms/step
XLA + mixed precision: 4.25 seconds

544297fa0f9ed573e5d2ad1c9f36b863.png

Оценим результаты:

print("{:<20} {:<20}".format("Model", "Runtime"))
for result in benchmark_result:
    name, runtime = result
    print("{:<20} {:<20}".format(name, runtime))
Model                 Runtime             
Standard              8.17177152633667    
Mixed Precision       6.022329568862915   
XLA                   6.265935659408569   
XLA + Mixed Precision 4.252242088317871   

Полностью оптимизированной модели хватило четырёх секунд, чтобы сгенерировать три новых изображения из текстового запроса на графическом процессоре A100.


Заключение

Благодаря XLA KerasCV позволяет создать Stable Diffusion нового поколения. А благодаря смешанной точности и XLA мы получаем самый быстрый конвейер Stable Diffusion на сентябрь 2022 года.

В конце руководств по keras.io мы обычно рекомендуем несколько тем для дальнейшего изучения. На этот раз мы лишь ограничимся одним призывом:

Прогоните свои описания через эту модель! Это просто бомба!

Если у вас графический процессор от NVIDIA GPU или же M1 MacBookPro, можно запустить генерацию на своей машине. (Отметим, что при старте на M1 MacBookPro активировать смешанную точность не нужно: эппловский Metal пока не очень хорошо поддерживает её).


Научим разрабатывать генеративные сети, работать с данными, чтобы вы прокачали карьеру или стали востребованным IT-специалистом:
Чтобы посмотреть все курсы, кликните по баннеру:

cqna880todtt287i6ffb12uzzwk.png

Краткий каталог курсов
Data Science и Machine Learning
Python, веб-разработка
Мобильная разработка
Java и C#
От основ — в глубину
А также

© Habrahabr.ru