Обучение и тестирование нейронных сетей на PyTorch с помощью Ignite

Привет, Хабр, в этой статье я расскажу про библиотеку ignite, с помощью которой можно легко обучать и тестировать нейронные сети, используя фреймворк PyTorch.

С помощью ignite можно писать циклы для обучения сети буквально в несколько строк, добавлять из коробки расчет стандартных метрик, сохранять модель и т.д. Ну, а для тех кто переехал с TF на PyTorch, можно сказать, что библиотека ignite — Keras для PyTorch.

В статье будет детально разобран пример обучения нейронной сети для задачи классификации, используя ignite

35arafc8y9aicrbpz5unazs-y-a.png


Добавим еще больше огня в PyTorch

Не буду тратить время, рассказывая о том, насколько крутой фреймворк PyTorch. Тот, кто им уже пользовался, понимает, о чём я пишу. Но, при всех его достоинствах, он все же является низкоуровневым в плане написания циклов для обучения, проверки, тестирования нейронных сетей.

Если мы посмотрим официальные примеры использования фреймворка PyTorch, то увидим в коде обучения сетки как минимум два цикла итераций по эпохам и по батчам обучающей выборки:

for epoch in range(1, epochs + 1):
    for batch_idx, (data, target) in enumerate(train_loader):
        # ...

Основная идея библиотеки ignite заключается в том, чтобы факторизовать эти циклы в единый класс, при этом позволив пользователю взаимодействовать с этими циклами с помощью обработчиков событий.

В итоге, в случае стандартных задач глубокого обучения мы можем неплохо сэкономить на количестве строк кода. Меньше строк — меньше ошибок!

К примеру, для сравнения, слева код для обучения и валидации модели, используя ignite, а справа — на чистом PyTorch:
image

Итак, ещё раз, чем же хорош ignite?


  • больше не нужно писать для каждой задачи циклы for epoch in range(n_epochs) и for batch in data_loader.
  • позволяет лучше факторизовать код
  • позволяет вычислять базовые метрики из коробки
  • предоставляет «плюшки» типа
    • сохранение последней и лучших моделей (также оптимизатора и learning rate scheduler) во время обучения,
    • ранняя остановка обучения
    • итд
  • легко интегрируется с инструментами визуализации: tensorboardX, visdom, …

В каком-то смысле, как уже было упомянуто, библиотеку ignite можно сравнить со всем известным Keras и его API для обучения и тестирования сетей. Также, библиотека ignite с первого взгляда очень похожа на библиотеку tnt, поскольку изначально обе библиотеки преследовали единые цели и имеют схожие идеи по их реализации.

Итак, зажигаем:

pip install pytorch-ignite

или

conda install ignite -c pytorch

Далее на конкретном примере мы ознакомимся с API библиотеки ignite.


Задача классификации с ignite

В этой части статьи рассмотрим школьный пример обучения нейронной сети для задачи классификации, используя библиотеку ignite.

Итак, возьмём простой датасет с картинками фруктов с kaggle. Задача заключается в том, чтобы каждой картинке с фруктом сопоставить соответствующий класс.

Прежде чем использовать ignite, давайте определим основные компоненты:

Поток данных (dataflow):


  • загрузчик батчей обучающей выборки, train_loader
  • загрузчик батчей проверочной выборки, val_loader

Модель:


  • возьмем маленькую сетку SqueezeNet из torchvision

Алгоритм оптимизации:


  • возьмем SGD

Функция потерь:


  • Cross-Entropy


Код
from pathlib import Path

import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import Subset
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, RandomResizedCrop, RandomVerticalFlip, RandomHorizontalFlip
from torchvision.transforms import ColorJitter, ToTensor, Normalize

FRUIT360_PATH = Path(".").resolve().parent / "input" / "fruits-360_dataset" / "fruits-360"

device = "cuda"

train_transform = Compose([
    RandomHorizontalFlip(),    
    RandomResizedCrop(size=32),
    ColorJitter(brightness=0.12),
    ToTensor(),
    Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

val_transform = Compose([
    RandomResizedCrop(size=32),
    ToTensor(),
    Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

batch_size = 128
num_workers = 8

train_dataset = ImageFolder((FRUIT360_PATH /"Training").as_posix(), transform=train_transform, target_transform=None)
val_dataset = ImageFolder((FRUIT360_PATH /"Test").as_posix(), transform=val_transform, target_transform=None)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, 
                          num_workers=num_workers,
                          drop_last=True, pin_memory="cuda" in device)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True,
                        num_workers=num_workers,
                        drop_last=False, pin_memory="cuda" in device)

import torch.nn as nn
from torchvision.models.squeezenet import squeezenet1_1

model = squeezenet1_1(pretrained=False, num_classes=81)
model.classifier[-1] = nn.AdaptiveAvgPool2d(1)
model = model.to(device)
import torch.nn as nn
from torch.optim import SGD

optimizer = SGD(model.parameters(), lr=0.01, momentum=0.5)
criterion = nn.CrossEntropyLoss()

Итак, теперь пришло время запускать ignite:

from ignite.engine import Engine, _prepare_batch

def process_function(engine, batch):
    model.train()
    optimizer.zero_grad()
    x, y = _prepare_batch(batch, device=device)
    y_pred = model(x)
    loss = criterion(y_pred, y)
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(process_function)

Давайте разберемся, что означает этот код.


Движок Engine

Класс ignite.engine.Engine — каркас библиотеки, а объект этого класса trainer:

trainer = Engine(process_function)

определен со входной функцией process_function для обработки одного батча и служит для реализации проходов по обучающей выборке. Внутри класса ignite.engine.Engine происходит следующее:

while epoch < max_epochs:
    # run once on data
    for batch in data:
        output = process_function(batch)

Вернемся к функции process_function:

def process_function(engine, batch):
    model.train()
    optimizer.zero_grad()
    x, y = _prepare_batch(batch, device=device)
    y_pred = model(x)
    loss = criterion(y_pred, y)
    loss.backward()
    optimizer.step()
    return loss.item()

Мы видим, что внутри функции мы, как обычно в случае обучения модели, вычисляем предсказания y_pred, рассчитываем функцию потерь loss и градиенты. Последние позволяют обновить веса модели: optimizer.step().

В общем случае, нет никаких ограничений на код функции process_function. Отметим только, что она принимает на вход два аргумента: объект Engine (в нашем случае trainer) и батч от загрузчика данных. Поэтому, например, для тестирования нейронной сети мы можем определить другой объект класса ignite.engine.Engine, в котором входная функция просто вычисляет предсказания, и реализовать проход по проверочной выборке один единственный раз. Об этом читайте далее.

Итак, выше приведенный код лишь только определяет необходимые объекты без запуска обучения. В принципе, в минимальном примере, можно вызвать метод:

trainer.run(train_loader, max_epochs=10)

и данного кода достаточно, чтобы «тихо» (без какого-либо вывода промежуточных результатов) обучить модель.


Заметка

Отметим также, что для задач такого типа в библиотеке есть удобный метод создания объекта trainer:

from ignite.engine import create_supervised_trainer

trainer = create_supervised_trainer(model, optimizer, criterion, device)

Конечно, на практике вышеприведенный пример представляет мало интереса, поэтому давайте добавим следующие опции для «тренера»:


  • вывод на экран значения функции потерь через каждые 50 итераций
  • запуск расчета метрик на обучающей выборке при фиксированной модели
  • запуск расчета метрик на проверочной выборке после каждой эпохи
  • сохранение параметров модели после каждой эпохи
  • сохранение трёх лучших моделей
  • изменение скорости обучения в зависимости от эпохи (learning rate scheduling)
  • ранняя остановка обучения (early-stopping)


События и обработчики событий

Чтобы добавить вышеперечисленные опции для «тренера» в библиотеке ignite предусмотрена система событий и запуск пользовательских обработчиков событий. Таким образом, пользователь может управлять объектом класса Engine на каждом этапе:


  • движок запустился/завершил запуск
  • эпоха началась/завершилась
  • батч итерация началась/завершилась

и запускать свой код на каждом событии.


Вывод на экран значения функции потерь

Для этого нужно просто определить функцию, в которой будет происходит вывод на экран, и добавить ее к «тренеру»:

from ignite.engine import Events

log_interval = 50

@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(engine):
    iteration = (engine.state.iteration - 1) % len(train_loader) + 1
    if iteration % log_interval == 0:
        print("Epoch[{}] Iteration[{}/{}] Loss: {:.4f}"
              .format(engine.state.epoch, 
                         iteration, 
                         len(train_loader), 
                         engine.state.output))

На самом деле есть два способа добавить обработчик событий: через add_event_handler, либо через декоратор on. Тоже самое, что и выше, можно сделать так:

from ignite.engine import Events

log_interval = 50

def log_training_loss(engine):
    # ...

trainer.add_event_handler(Events.ITERATION_COMPLETED, log_training_loss)

Заметим, что в функцию обработки события можно передать любые аргументы. В общем случае, такая функция будет выглядеть так:

def custom_handler(engine, *args, **kwargs):
        pass

trainer.add_event_handler(Events.ITERATION_COMPLETED, custom_handler, *args, **kwargs)
# или
@trainer.on(Events.ITERATION_COMPLETED, *args, **kwargs)
def custom_handler(engine, *args, **kwargs):
    pass

Итак, давайте запустим обучение на одной эпохе и посмотрим, что будет:

output = trainer.run(train_loader, max_epochs=1)
Epoch[1] Iteration[50/322] Loss: 4.3459
Epoch[1] Iteration[100/322] Loss: 4.2801
Epoch[1] Iteration[150/322] Loss: 4.2294
Epoch[1] Iteration[200/322] Loss: 4.1467
Epoch[1] Iteration[250/322] Loss: 3.8607
Epoch[1] Iteration[300/322] Loss: 3.6688

Неплохо! Пойдем далее.


Запуск расчета метрик на обучающей и тестовой выборках

Давайте будем вычислять следующие метрики: средняя точность, средняя полнота после каждой эпохи на части обучающей и всей тестовой выборках. Заметим, что мы будем вычислять метрики на части обучающей выборки после каждой эпохи обучения, а не во время обучения. Таким образом замер эффективности будет более точным, поскольку модель не изменяется во время вычисления.

Итак, определим метрики:

from ignite.metrics import Loss, CategoricalAccuracy, Precision, Recall

metrics = {
    'avg_loss': Loss(criterion),
    'avg_accuracy': CategoricalAccuracy(),
    'avg_precision': Precision(average=True),
    'avg_recall': Recall(average=True)
}

Далее мы создадим два движка для оценки модели, используя ignite.engine.create_supervised_evaluator:

from ignite.engine import create_supervised_evaluator

# Напомним, что device = "cuda” был определен выше
train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
val_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)

Мы создаем два движка для того, чтобы на один из них (val_evaluator) далее прицепить дополнительные обработчики событий для сохранения модели и ранней остановки обучения (обо всем этом далее).

Давайте также более детально рассмотрим, как определен движок для оценки модели, а именно, как определена входная функция process_function для обработки одного батча:

def create_supervised_evaluator(model, metrics={}, device=None):
    if device:
        model.to(device)

    def _inference(engine, batch):
        model.eval()
        with torch.no_grad():
            x, y = _prepare_batch(batch, device=device)
            y_pred = model(x)
            return y_pred, y

    engine = Engine(_inference)

    for name, metric in metrics.items():
        metric.attach(engine, name)

    return engine

Продолжаем далее. Выберем случайным образом часть обучающей выборки, на которой будем вычислять метрики:

import numpy as np
from torch.utils.data.dataset import Subset

indices = np.arange(len(train_dataset))
random_indices = np.random.permutation(indices)[:len(val_dataset)]
train_subset = Subset(train_dataset, indices=random_indices)

train_eval_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True,
                                num_workers=num_workers, 
                                drop_last=True, pin_memory="cuda" in device)

Далее, давайте определим в какой момент обучения мы будем запускать вычисление метрик и будем производить вывод на экран:

@trainer.on(Events.EPOCH_COMPLETED)
def compute_and_display_offline_train_metrics(engine):
    epoch = engine.state.epoch
    print("Compute train metrics...")
    metrics = train_evaluator.run(train_eval_loader).metrics
    print("Training Results - Epoch: {}  Average Loss: {:.4f} | Accuracy: {:.4f} | Precision: {:.4f} | Recall: {:.4f}"
          .format(engine.state.epoch, 
                      metrics['avg_loss'], 
                      metrics['avg_accuracy'], 
                      metrics['avg_precision'], 
                      metrics['avg_recall']))

@trainer.on(Events.EPOCH_COMPLETED)
def compute_and_display_val_metrics(engine):
    epoch = engine.state.epoch
    print("Compute validation metrics...")
    metrics = val_evaluator.run(val_loader).metrics
    print("Validation Results - Epoch: {}  Average Loss: {:.4f} | Accuracy: {:.4f} | Precision: {:.4f} | Recall: {:.4f}"
          .format(engine.state.epoch, 
                      metrics['avg_loss'], 
                      metrics['avg_accuracy'], 
                      metrics['avg_precision'], 
                      metrics['avg_recall']))

Можно запускать!

output = trainer.run(train_loader, max_epochs=1)

Получаем на экране

Epoch[1] Iteration[50/322] Loss: 3.5112
Epoch[1] Iteration[100/322] Loss: 2.9840
Epoch[1] Iteration[150/322] Loss: 2.8807
Epoch[1] Iteration[200/322] Loss: 2.9285
Epoch[1] Iteration[250/322] Loss: 2.5026
Epoch[1] Iteration[300/322] Loss: 2.1944
Compute train metrics...
Training Results - Epoch: 1  Average Loss: 2.1018 | Accuracy: 0.3699 | Precision: 0.3981 | Recall: 0.3686
Compute validation metrics...
Validation Results - Epoch: 1  Average Loss: 2.0519 | Accuracy: 0.3850 | Precision: 0.3578 | Recall: 0.3845

Уже лучше!

Немного деталей
Давайте немного разберемся в предыдущем коде. Читатель, возможно, обратил внимание на следующую строку кода:

metrics = train_evaluator.run(train_eval_loader).metrics

и, вероятно, был вопрос о типе объекта, полученного из функции train_evaluator.run(train_eval_loader), у которого есть атрибут metrics.

На самом деле, у класса Engine содержится структура под названием state (тип State) для того, чтобы была возможность передавать данные между обработчиками событий. Этот атрибут state содержит базовую информацию о текущей эпохе, итерации, о количестве эпох и т.д. Также его можно использовать для передачи любых пользовательских данных, в том числе и результатов расчета метрик.

state = train_evaluator.run(train_eval_loader)
metrics = state.metrics
# или просто
train_evaluator.run(train_eval_loader)
metrics = train_evaluator.state.metrics


Расчет метрик во время обучения

Если в задаче огромная обучающая выборка и расчет метрик после каждой эпохи обучения стоит дорого, а при этом все же хотелось бы видеть изменение некоторых метрик во время обучения, то можно использовать из коробки следующий обработчик событий RunningAverage. Например, мы хотим рассчитывать и выводить на экран точность классификатора:

acc_metric = RunningAverage(CategoryAccuracy(...), alpha=0.98)
acc_metric.attach(trainer, 'running_avg_accuracy')

@trainer.on(Events.ITERATION_COMPLETED)
def log_running_avg_metrics(engine):
    print("running avg accuracy:", engine.state.metrics['running_avg_accuracy'])

Чтобы использовать функционал RunningAverage, то нужно установить ignite из исходников:

pip install git+https://github.com/pytorch/ignite


Изменение скорости обучение (learning rate scheduling)

Есть несколько способов изменять скорость обучения с помощью ignite. Далее рассмотрим самый простой способ, вызывая функцию lr_scheduler.step() в начале каждой эпохи.

from torch.optim.lr_scheduler import ExponentialLR

lr_scheduler = ExponentialLR(optimizer, gamma=0.8)

@trainer.on(Events.EPOCH_STARTED)
def update_lr_scheduler(engine):
    lr_scheduler.step()
    # Вывод значений скорости обучения:
    if len(optimizer.param_groups) == 1:
        lr = float(optimizer.param_groups[0]['lr'])
        print("Learning rate: {}".format(lr))
    else:
        for i, param_group in enumerate(optimizer.param_groups):
            lr = float(param_group['lr'])
            print("Learning rate (group {}): {}".format(i, lr))    


Сохранение лучших моделей и других параметров во время обучения

Во время обучения было бы здорово записывать на диск веса лучшей модели, а также периодически сохранять веса модели, параметры оптимизатора и параметры изменения скорости обучения. Последнее может быть полезно для того, чтобы возобновить обучение из последнего сохраненного состояния.

В ignite для этого есть специальный класс ModelCheckpoint. Итак, давайте создадим обработчик событий ModelCheckpoint и будем сохранять лучшую модель по значению точности на проверочной выборке. В таком случае, определим score_function функцию, которая выдает значение точности в обработчик событий и он решает нужно ли сохранять модель или нет:

from ignite.handlers import ModelCheckpoint

def score_function(engine):
    val_avg_accuracy = engine.state.metrics['avg_accuracy']
    return val_avg_accuracy

best_model_saver = ModelCheckpoint("best_models",  
                                   filename_prefix="model",
                                   score_name="val_accuracy",  
                                   score_function=score_function,
                                   n_saved=3,
                                   save_as_state_dict=True,
                                   create_dir=True)
# "best_models" - Папка куда сохранять 1 или несколько лучших моделей
# Имя файла -> {filename_prefix}_{name}_{step_number}_{score_name}={abs(score_function_result)}.pth
# save_as_state_dict=True,  # Сохранять как `state_dict`

val_evaluator.add_event_handler(Events.COMPLETED, 
                                best_model_saver, 
                                {"best_model": model})

Теперь создадим еще один обработчик событий ModelCheckpoint для того, чтобы сохранять состояние обучения через каждые 1000 итераций:

training_saver = ModelCheckpoint("checkpoint",
                             filename_prefix="checkpoint",
                             save_interval=1000,
                             n_saved=1,
                             save_as_state_dict=True,
                             create_dir=True)

to_save = {"model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler} 
trainer.add_event_handler(Events.ITERATION_COMPLETED, training_saver, to_save)

Итак, уже почти все готово, добавим последний элемент:


Ранняя остановка обучения (early-stopping)

Давайте добавим еще один обработчик событий, который остановит обучение, если не будет происходить улучшение качества модели в течение 10 эпох. Качество модели будем снова оценивать с помощью фунцкии score_function.

from ignite.handlers import EarlyStopping

early_stopping = EarlyStopping(patience=10, 
                              score_function=score_function, 
                              trainer=trainer)

val_evaluator.add_event_handler(Events.EPOCH_COMPLETED, early_stopping)


Запуск обучения

Для того, чтобы запустить обучение нам достаточно вызвать метод run(). Будем обучать модель в течение 10 эпох:

max_epochs = 10
output = trainer.run(train_loader, max_epochs=max_epochs)


Вывод на экран
Learning rate: 0.01
Epoch[1] Iteration[50/322] Loss: 2.7984
Epoch[1] Iteration[100/322] Loss: 1.9736
Epoch[1] Iteration[150/322] Loss: 4.3419
Epoch[1] Iteration[200/322] Loss: 2.0261
Epoch[1] Iteration[250/322] Loss: 2.1724
Epoch[1] Iteration[300/322] Loss: 2.1599
Compute train metrics...
Training Results - Epoch: 1  Average Loss: 1.5363 | Accuracy: 0.5177 | Precision: 0.5477 | Recall: 0.5178
Compute validation metrics...
Validation Results - Epoch: 1  Average Loss: 1.5116 | Accuracy: 0.5139 | Precision: 0.5400 | Recall: 0.5140
Learning rate: 0.008
Epoch[2] Iteration[50/322] Loss: 1.4076
Epoch[2] Iteration[100/322] Loss: 1.4892
Epoch[2] Iteration[150/322] Loss: 1.2485
Epoch[2] Iteration[200/322] Loss: 1.6511
Epoch[2] Iteration[250/322] Loss: 3.3376
Epoch[2] Iteration[300/322] Loss: 1.3299
Compute train metrics...
Training Results - Epoch: 2  Average Loss: 3.2686 | Accuracy: 0.1977 | Precision: 0.1792 | Recall: 0.1942
Compute validation metrics...
Validation Results - Epoch: 2  Average Loss: 3.2772 | Accuracy: 0.1962 | Precision: 0.1628 | Recall: 0.1918
Learning rate: 0.006400000000000001
Epoch[3] Iteration[50/322] Loss: 0.9016
Epoch[3] Iteration[100/322] Loss: 1.2006
Epoch[3] Iteration[150/322] Loss: 0.8892
Epoch[3] Iteration[200/322] Loss: 0.8141
Epoch[3] Iteration[250/322] Loss: 1.4005
Epoch[3] Iteration[300/322] Loss: 0.8888
Compute train metrics...
Training Results - Epoch: 3  Average Loss: 0.7368 | Accuracy: 0.7554 | Precision: 0.7818 | Recall: 0.7554
Compute validation metrics...
Validation Results - Epoch: 3  Average Loss: 0.7177 | Accuracy: 0.7623 | Precision: 0.7863 | Recall: 0.7611
Learning rate: 0.005120000000000001
Epoch[4] Iteration[50/322] Loss: 0.8490
Epoch[4] Iteration[100/322] Loss: 0.8493
Epoch[4] Iteration[150/322] Loss: 0.8100
Epoch[4] Iteration[200/322] Loss: 0.9165
Epoch[4] Iteration[250/322] Loss: 0.9370
Epoch[4] Iteration[300/322] Loss: 0.6548
Compute train metrics...
Training Results - Epoch: 4  Average Loss: 0.7047 | Accuracy: 0.7713 | Precision: 0.8040 | Recall: 0.7728
Compute validation metrics...
Validation Results - Epoch: 4  Average Loss: 0.6737 | Accuracy: 0.7778 | Precision: 0.7955 | Recall: 0.7806
Learning rate: 0.004096000000000001
Epoch[5] Iteration[50/322] Loss: 0.6965
Epoch[5] Iteration[100/322] Loss: 0.6196
Epoch[5] Iteration[150/322] Loss: 0.6194
Epoch[5] Iteration[200/322] Loss: 0.3986
Epoch[5] Iteration[250/322] Loss: 0.6032
Epoch[5] Iteration[300/322] Loss: 0.7152
Compute train metrics...
Training Results - Epoch: 5  Average Loss: 0.5049 | Accuracy: 0.8282 | Precision: 0.8393 | Recall: 0.8314
Compute validation metrics...
Validation Results - Epoch: 5  Average Loss: 0.5084 | Accuracy: 0.8304 | Precision: 0.8386 | Recall: 0.8328
Learning rate: 0.0032768000000000007
Epoch[6] Iteration[50/322] Loss: 0.4433
Epoch[6] Iteration[100/322] Loss: 0.4764
Epoch[6] Iteration[150/322] Loss: 0.5578
Epoch[6] Iteration[200/322] Loss: 0.3684
Epoch[6] Iteration[250/322] Loss: 0.4847
Epoch[6] Iteration[300/322] Loss: 0.3811
Compute train metrics...
Training Results - Epoch: 6  Average Loss: 0.4383 | Accuracy: 0.8474 | Precision: 0.8618 | Recall: 0.8495
Compute validation metrics...
Validation Results - Epoch: 6  Average Loss: 0.4419 | Accuracy: 0.8446 | Precision: 0.8532 | Recall: 0.8442
Learning rate: 0.002621440000000001
Epoch[7] Iteration[50/322] Loss: 0.4447
Epoch[7] Iteration[100/322] Loss: 0.4602
Epoch[7] Iteration[150/322] Loss: 0.5345
Epoch[7] Iteration[200/322] Loss: 0.3973
Epoch[7] Iteration[250/322] Loss: 0.5023
Epoch[7] Iteration[300/322] Loss: 0.5303
Compute train metrics...
Training Results - Epoch: 7  Average Loss: 0.4305 | Accuracy: 0.8579 | Precision: 0.8691 | Recall: 0.8596
Compute validation metrics...
Validation Results - Epoch: 7  Average Loss: 0.4262 | Accuracy: 0.8590 | Precision: 0.8685 | Recall: 0.8606
Learning rate: 0.002097152000000001
Epoch[8] Iteration[50/322] Loss: 0.4867
Epoch[8] Iteration[100/322] Loss: 0.3090
Epoch[8] Iteration[150/322] Loss: 0.3721
Epoch[8] Iteration[200/322] Loss: 0.4559
Epoch[8] Iteration[250/322] Loss: 0.3958
Epoch[8] Iteration[300/322] Loss: 0.4222
Compute train metrics...
Training Results - Epoch: 8  Average Loss: 0.3432 | Accuracy: 0.8818 | Precision: 0.8895 | Recall: 0.8817
Compute validation metrics...
Validation Results - Epoch: 8  Average Loss: 0.3644 | Accuracy: 0.8713 | Precision: 0.8784 | Recall: 0.8707
Learning rate: 0.001677721600000001
Epoch[9] Iteration[50/322] Loss: 0.3557
Epoch[9] Iteration[100/322] Loss: 0.3692
Epoch[9] Iteration[150/322] Loss: 0.3510
Epoch[9] Iteration[200/322] Loss: 0.3446
Epoch[9] Iteration[250/322] Loss: 0.3966
Epoch[9] Iteration[300/322] Loss: 0.3451
Compute train metrics...
Training Results - Epoch: 9  Average Loss: 0.3315 | Accuracy: 0.8954 | Precision: 0.9001 | Recall: 0.8982
Compute validation metrics...
Validation Results - Epoch: 9  Average Loss: 0.3559 | Accuracy: 0.8818 | Precision: 0.8876 | Recall: 0.8847
Learning rate: 0.0013421772800000006
Epoch[10] Iteration[50/322] Loss: 0.3340
Epoch[10] Iteration[100/322] Loss: 0.3370
Epoch[10] Iteration[150/322] Loss: 0.3694
Epoch[10] Iteration[200/322] Loss: 0.3409
Epoch[10] Iteration[250/322] Loss: 0.4420
Epoch[10] Iteration[300/322] Loss: 0.2770
Compute train metrics...
Training Results - Epoch: 10  Average Loss: 0.3246 | Accuracy: 0.8921 | Precision: 0.8988 | Recall: 0.8925
Compute validation metrics...
Validation Results - Epoch: 10  Average Loss: 0.3536 | Accuracy: 0.8731 | Precision: 0.8785 | Recall: 0.8722

Теперь проверим модели и параметры, сохраненные на диск:

ls best_models/

model_best_model_10_val_accuracy=0.8730994.pth
model_best_model_8_val_accuracy=0.8712978.pth
model_best_model_9_val_accuracy=0.8818188.pth

и

ls checkpoint/

checkpoint_lr_scheduler_3000.pth  
checkpoint_optimizer_3000.pth
checkpoint_model_3000.pth


Предсказания обученной моделью

Для начала создадим загрузчик тестовых данных (для примера возьмем валидационную выборку) так, чтобы батч данных состоял из изображений и их индексов:

class TestDataset(Dataset):

    def __init__(self, ds):
        self.ds = ds

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, index):
        return self.ds[index][0], index

test_dataset = TestDataset(val_dataset)

test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                         num_workers=num_workers, 
                         drop_last=False, pin_memory="cuda" in device)

С помощью ignite создадим новый движок для предсказания на тестовых данных. Для этого определим функцию inference_update, которая выдает результат предсказания и индекс изображения. Для повышения точности, мы также будем использовать всем известный трюк «test time augmentation» (TTA).

import torch.nn.functional as F
from ignite._utils import convert_tensor

def _prepare_batch(batch):
    x, index = batch
    x = convert_tensor(x, device=device)
    return x, index

def inference_update(engine, batch):
    x, indices = _prepare_batch(batch)
    y_pred = model(x)
    y_pred = F.softmax(y_pred, dim=1)
    return {"y_pred": convert_tensor(y_pred, device='cpu'), "indices": indices}

model.eval()
inferencer = Engine(inference_update)    

Далее создадим обработчики событий, которые будут оповещать об этапе предсказаний и сохранять предсказания в выделенный массив:

@inferencer.on(Events.EPOCH_COMPLETED)
def log_tta(engine):
    print("TTA {} / {}".format(engine.state.epoch, n_tta))

n_tta = 3
num_classes = 81
n_samples = len(val_dataset)

# Массив для хранения предсказаний
y_probas_tta = np.zeros((n_samples, num_classes, n_tta), dtype=np.float32)

@inferencer.on(Events.ITERATION_COMPLETED)
def save_results(engine):
    output = engine.state.output
    tta_index = engine.state.epoch - 1
    start_index = ((engine.state.iteration - 1) % len(test_loader)) * batch_size
    end_index = min(start_index + batch_size, n_samples)
    batch_y_probas = output['y_pred'].detach().numpy()
    y_probas_tta[start_index:end_index, :, tta_index] = batch_y_probas

Прежде чем запустить процесс, давайте загрузим лучшую модель:

model = squeezenet1_1(pretrained=False, num_classes=64)
model.classifier[-1] = nn.AdaptiveAvgPool2d(1)
model = model.to(device)

model_state_dict = torch.load("best_models/model_best_model_10_val_accuracy=0.8730994.pth")
model.load_state_dict(model_state_dict)

Запускаем:

inferencer.run(test_loader, max_epochs=n_tta)

> TTA 1 / 3
> TTA 2 / 3
> TTA 3 / 3

Далее, стандартным образом, возьмем среднее от предсказаний TTA и вычислим индекс класса с наибольшей вероятностью:

y_probas = np.mean(y_probas_tta, axis=-1)
y_preds = np.argmax(y_probas, axis=-1)

И теперь можем посчитать еще раз точность модели по полученным предсказаниям:

from sklearn.metrics import accuracy_score

y_test_true = [y for _, y in val_dataset]
accuracy_score(y_test_true, y_preds)

> 0.9310369676443035

Итак, в этой части мы показали, как посчитать предсказания с помощью обученной модели на валидационной выборке. На самом деле, пример очень простой, но из него должно быть понятно, каким образом изпользовать ignite для других и более сложных ситуаций.


Другие примеры с ignite

Полный код данного примера можно найти здесь.

В github репозитории библиотеки можно найти и другие примеры обучения сетей для таких задач как


  • fast neural transfer
  • reinforcement learning
  • dcgan


Заключение

В заключении хочу сказать, что библиотека ignite не является официальным продуктом от Facebook и в её разработке принимают участие программисты на добровольной основе (напр. автор этой статьи). На текущий момент она находится в версии 0.1.0, но основной API (Engine, State, Events, Metric, …) будет по мере возможного оставаться без изменений и в последующих версиях. Поскольку библиотека находится в стадии активной разработки, в том числе и дополнительных модулей, то разработчики будут рады отзывам, сообщениях об ошибках и pull request-ам в репозитории github.

Спасибо за внимание!

© Habrahabr.ru