Как работает DALL-E

ffqhnotdffvev2ewkwcekssyula.png


В Январе 2021 года Open AI показали DALL-E, генерирующую любые изображения по текстовому описанию, на хабре уже успели разобрать тему генерации изображений нейросетями, но мне захотелось разобраться в теме более детально и показать всё в коде.

Сейчас мы разберём, как работает Text2Image нейросеть DALL-E, посмотрим на хардкор математики, а также убедимся, что это сможет повторить каждый, написав реализацию DALL-E почти с нуля.

GPT


DALL-E состоит из двух нейросетей, одна из них — это GPT.

GPT пытается предсказывать последовательность токенов, на основе данной ей последовательности. Модель представляет собой архитектуру Transformers, состоящую только из Декодера. GPT учится фокусировать внимание на предыдущих словах, которые наиболее важны для предсказания следующего слова в предложении, используя механизм внимания.

▍ Attention


Механизм Attention является аналогом когнитивного внимания и позволяет находить связи между токенами, а также предсказывать продолжение этих токенов.

Для каждого токена создаётся набор векторов:

  1. Key — значимость токена в последовательности, если мы смотрим извне.
  2. Query — значимость токена в последовательности, если мы смотрим из этого токена.
  3. Value — репрезентация токена


Умножение вектора Query последнего токена на каждый вектор Key последовательности токенов с применением функции SoftMax даст коэффициенты значимости каждого токена.

Умножив Value каждого токена на его коэффициент, мы получим сумму Внимания всей последовательности.

Получившийся вектор мы умножаем на матрицу эмбендингов модели (всех возможных токенов словаря) и, таким образом, получаем коэффициент каждого токена словаря на соответствие входной последовательности токенов.

image

Она хорошо показала себя в генерации довольно натуральных текстов, которые порой нельзя отличить от написанных человеком.

image

VQ-GAN


image

VQ-GAN способна сжимать изображение в сетку векторов (токенов) и реконструировать обратно в изображение. Нейросеть состоит из энкодера, декодера и дискриминатора.

VQ-GAN является совмещением идей VQVAE и GAN.

▍ VQ-VAE (Vector Quantized Variational Autoencoder)


Идея VQ-VAE в том, чтобы научится эффективно сжимать изображение в более низкоразмерное скрытое пространство и разжимать в скрытое пространство изображения с наименьшими потерями.

Энкодер берёт изображение размером 512×512 и сжимает (чаще всего обычными свёрточными сетями) до 256×1, ищет ближайший вектор от получившегося в скрытом пространстве, а после декодер пытается такой маленький вектор разжать обратно в изображение 512×512.

▍ VQ-VAE + GAN = VQ-GAN


Энкодер всё так же принимает на вход изображение и кодирует его в сжатый вектор, после чего подменяет на ближайшего соседа в скрытном пространстве, а Декодер разжимает этот ближайший вектор в изображение.

Тут в игру входит частичка GAN, а именно дискриминатор. Дискриминатор в процессе обучения сравнивает, насколько исходное изображение похоже на реконструированное и возвращает градиент потерь реконструкции, это помогает в дальнейшем понять Энкодеру и Декодеру, как обмануть Дискриминатор и в разы повысить качество конечных изображений.

Как это работает внутри


Энкодер преобразует изображение X в сетку скрытых векторов

image

Каждый вектор в сетке подвергается квантованию с использованием функции q (ẑ)

image

Декодер реконструируют квантованную сетку

image

Для обучения VQ-GAN реконструкции картинки нам нужно как-то оценивать его работу и для этого у нас есть целых два Лосса (Функции потерь):

Эта функция потерь высчитывает, насколько хорошо модели удалось реконструировать оригинал.

image

Дискриминатор делает бррр…

Кстати, дискриминатор позволяет добиться лучшей реконструкции изображений, даже из очень сжатой сетки векторов.

image

Теперь с помощью этой функции мы можем обучить наш декодер:

image

GPT + VQ-GAN = DALL-E


«Мы умеем моделировать последовательность токенов текста с помощью GPT и умеем восстанавливать изображения из набора токенов автоэнкодера, почему бы не совместить?» — подумали разработчики из Open AI и сделали DALL-E.

Далее достаточно было обучить GPT в ответ на входной текст, генерировать токены подобные VQ-GAN соответствующего изображения. На удивление — это сработало, и она научилась «переводить» естественный язык в язык модели синтеза изображений.

Сначала мы собираем датасет в формате: Текст Ӏ Картинка.

Преобразуем картинки обучающего датасета, посредством энкодера VQGAN в матрицу токенов 32×32 и учим GPT сопоставлять 128 токенов текста с соответствующими 1024 токенами изображения.

После обучения, GPT принимая на вход только текст, сможет сгенерировать матрицу токенов, которую посредством декодера VQGAN мы сможем преобразовать в картинку, соответствующую входному тексту.

image

Что по коду?


И, конечно, после теоретической части хорошо бы перейти к практике.

Сейчас мы воспроизведём все эти математические фокусы в код. Это легче, чем вы думаете!

(Для воспроизведения кода нужно использовать Google Colab)

▍ Установим доступную реализацию GPT

!pip install datasets &>> install.log
!git clone https://github.com/karpathy/minGPT &>> install.log



!cd minGPT; git checkout 3ed14b2cec0dfdad3f4b2831f2b4a86d11aef150
!git clone https://github.com/karpathy/minGPT &>> install.log
import sys
sys.path.append('minGPT')
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
from torch.utils.data import Dataset
from mingpt.model import GPT, GPTConfig
from mingpt.trainer import Trainer, TrainerConfig
from mingpt.utils import sample


import logging
logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
)


▍ Установка VQ-GAN

# Это немного запутанно, но цель здесь
# - это просто для того, чтобы подготовить всё к предстоящей демонстрации.

# Загрузка предварительно обученной модели VQGAN
print('Downloading VQGAN model weights')
!curl -L 'https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fckpts%2Flast.ckpt&dl=1' > vqgan_im1024.ckpt
!curl -L 'https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1' > vqgan_im1024.yaml
#Установите требования к VQGAN
print('Installing requirements')
!git clone https://github.com/CompVis/taming-transformers &> /dev/null
!pip install einops omegaconf pytorch_lightning &> /dev/null

# Настройка VQGAN
import sys, einops, torch
sys.path.append('./taming-transformers')
from omegaconf import OmegaConf
from taming.models import cond_transformer, vqgan
from PIL import Image
from matplotlib import pyplot as plt

def load_vqgan_model(config_path, checkpoint_path):
    config = OmegaConf.load(config_path)
    if config.model.target == 'taming.models.vqgan.VQModel':
        model = vqgan.VQModel(**config.model.params)
        model.eval().requires_grad_(False)
        model.init_from_ckpt(checkpoint_path)
    elif config.model.target == 'taming.models.cond_transformer.Net2NetTransformer':
        parent_model = cond_transformer.Net2NetTransformer(**config.model.params)
        parent_model.eval().requires_grad_(False)
        parent_model.init_from_ckpt(checkpoint_path)
        model = parent_model.first_stage_model
    else:
        raise ValueError(f'unknown model type: {config.model.target}')
    del model.loss
    return model

print('Loading VQGAN model')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
vqgan_model = load_vqgan_model('vqgan_im1024.yaml', 'vqgan_im1024.ckpt').to(device)  

print('Dataset Prep')
!pip install openimages
!wget https://storage.googleapis.com/localized-narratives/annotations/open_images_validation_captions.jsonl

import json

with open('/content/open_images_validation_captions.jsonl', 'r') as json_file:
    json_list = list(json_file)

for json_str in json_list[:5]:
    result = json.loads(json_str)

import urllib3
import boto3
import botocore
import concurrent.futures
import os
from tqdm.notebook import tqdm

def _download_single_image(arguments):
    if os.path.exists(arguments["dest_file_path"]):
        return

    try:
        with open(arguments["dest_file_path"], "wb") as dest_file:
            arguments["s3_client"].download_fileobj(
                "open-images-dataset",
                arguments["image_file_object_path"],
                dest_file,
            )

    except urllib3.exceptions.ProtocolError as error:
        _logger.warning(
            f"Unable to download image {arguments['image_file_object_path']} -- skipping",
            error,
        )


def download_images_by_id(image_ids,section, images_directory):

    # мы загрузим изображения из AWS S3, поэтому нам понадобится клиент boto S3
    s3_client = boto3.client(
        's3',
        config=botocore.config.Config(signature_version=botocore.UNSIGNED),
    )

   # создайте повторяющийся список аргументов функции
   # который мы сопоставим с функцией загрузки
    download_args_list = []
    for image_id in image_ids:
        image_file_name = image_id + ".jpg"
        download_args = {
            "s3_client": s3_client,
            "image_file_object_path": section + "/" + image_file_name,
            "dest_file_path": os.path.join(images_directory, image_file_name),
        }
        download_args_list.append(download_args)
        
    # используйте ThreadPoolExecutor для параллельной загрузки изображений
    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:

        # используйте исполнитель для сопоставления функции загрузки с итерируемым набором аргументов
        list(tqdm(executor.map(_download_single_image, download_args_list),
                  total=len(download_args_list), desc="Downloading images"))
!mkdir -p ims
def get_openimages(n_images=2000):
  print('Downloading images')
  download_images_by_id([json.loads(s)['image_id'] for s in json_list[:n_images]], 'validation', 'ims/')
  data = [json.loads(s) for s in json_list[:n_images]]
  return data
type(vqgan_model) # Наша модель vqgan загружена


▍ Настраиваем датасет

# @title изменённый набор данных теперь включает текстовые маркеры перед графическими
from torch.utils.data import Dataset
import numpy as np

max_text_length = 128
vocab = 'abcdefghijklmnopqrstuvwxzy '

def encode_char(c):
  if c in vocab:
    return vocab.index(c)
  return 50 # 'special character'

class PatchDataset(Dataset):

    def __init__(self, image_fns, labels, block_size=255, max_text_length=128):
        # Перебирайте изображения, получая токены VQGAN из версий размером 256 пикселей и встраиваний эмбендингов CLIP
        self.ims = [] # # Изображения, закодированные VQGAN
        self.conds = []
        
        for fn, caption in tqdm(zip(image_fns, labels)):
            self.ims.append(fn)
            self.conds.append(caption)

        # 1024 возможных кода VQGAN + наша кодировка текста
        chars = range(1024+53)
        data_size, vocab_size = len(image_fns), len(chars)
        print('data has %d characters, %d unique.' % (data_size, vocab_size))
        self.stoi = { ch:i for i,ch in enumerate(chars) }
        self.itos = { i:ch for i,ch in enumerate(chars) }
        self.block_size = block_size # + max_text_length << TODO?
        self.vocab_size = vocab_size
    
    def __len__(self):
        return len(self.ims)# was len(self.data) - self.block_size
        
    def __getitem__(self, idx):
        
        fn = self.ims[idx]
        caption = self.conds[idx]
        
        # Закодируйте изображение с помощью vegan
        pil_im = Image.open(fn).convert('RGB').resize((256, 256))
        im_tensor = torch.tensor(np.array(pil_im)).permute(2, 0, 1) / 255
        with torch.no_grad():
            z, a, b = vqgan_model.encode(im_tensor.to(device).unsqueeze(0) * 2 - 1)
        im_idxs = b[-1] # 16*16
        
        # Закодируйте текст:
        char_idxs = [encode_char(c) for c in caption.lower()[:max_text_length]]
        while len(char_idxs) < max_text_length:
            char_idxs += [51]
        # 52 - это конец текстовых токенов.
        char_idxs += [52]
        # На данный момент они будут конфликтовать с токенами от vegan, поэтому мы добавляем 1024
        char_idxs = [c+1024 for c in char_idxs]
        
        # Комбинируем
        dix = [self.stoi[int(s)] for s in char_idxs]
        dix += [self.stoi[int(s)] for s in im_idxs]
        
        # Разделить на x и y
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        return x, y
data = get_openimages(n_images=6000) # 2 тысячи на быструю демонстрацию, 20 тысяч на лучшую попытку обучения
# Наш новый набор данных
image_fns = ['ims/'+d['image_id']+'.jpg' for d in data]
labels = [d['caption'] for d in data]
dset = PatchDataset(image_fns, labels, max_text_length=max_text_length)
x, y = dset[0]
x.shape, y.shape, x[-3:], y[-3:] # Y - это x, смещённый на 1.


▍ Обучение модели делает брррр…


Можете попробовать разные параметры обучения, чем больше max_epochs и n_images (на этом датасете до 20000) — тем лучше, но дольше обучение.

Так-же советую в будущем попробовать использовать более большие датасеты для лучших результатов.

block_size=255+max_text_length+1 # Чтобы также установить кондиционирование
mconf = GPTConfig(dset.vocab_size, block_size,
                  n_layer=8, n_head=8, n_embd=512)
model = GPT(mconf)

# Обучение
tconf = TrainerConfig(max_epochs=10, batch_size=32, learning_rate=6e-4,
                      lr_decay=True, warmup_tokens=512*20, final_tokens=2*len(dset)*block_size,
                      num_workers=0) # num_workers=0, чтобы избежать некоторых ошибок многопроцессорной обработки
trainer = Trainer(model, dset, None, tconf)
trainer.train()


▍ Наконец генерируем картинки

prompt = 'Green cat' #@param
 
# Закодируйте промт так, как мы это делаем в наборе данных
char_idxs = [encode_char(c) for c in prompt.lower()[:max_text_length]]
while len(char_idxs) < max_text_length:
    char_idxs += [51]
char_idxs += [52]
char_idxs = [c+1024 for c in char_idxs]

# Брррр...
fig, axs = plt.subplots(3, 3, figsize=(12, 12))
for i in tqdm(range(9)):
    x = torch.tensor([dset.stoi[s] for s in char_idxs], dtype=torch.long)[None,...].to(device)
    y = sample(model, x, 256, temperature=1., sample=True, top_k=200)[0]
    completion = [dset.itos[int(i)] for i in y]
    ccc = completion[-256:]
    ccc = [min(c, 1023) for c in ccc]
    with torch.no_grad():
        zs = vqgan_model.quantize.get_codebook_entry(torch.tensor(ccc).to(device), (1, 16, 16, 256))
        axs[i%3, i//3].imshow(vqgan_model.decode(zs).add(1).div(2).cpu().squeeze().permute(1, 2, 0).clip(0, 1))


Выводы


Работа Open AI над первой версией DALL-E дала толчок к бурному развитию Text-to-image, и позволило нам уже сейчас увидеть реально эффективные модели с разными подходами, приспособленные к коммерческому использованию:

DALL-E 2, Midjourney, Imagen, Stable Diffusion и т.д.

Сегодня мы рассмотрели эту модель со всех сторон и получили чуть больше ответов на ваши вопросы, которые вы не могли найти.

Потыкать код и обучить модельку можно в этом колабе.

image
Результат из очень большой модели RuDALL-E от Сбера

0×0A-0×5B=?

sz7jpfj8i1pa6ocj-eia09dev4q.png

© Habrahabr.ru