Краткий гайд по квантованию нейросетей
Мы достаточно написали статей про оптимизацию ваших нейросетей, сегодня пора перейти к дроблению, уменьшению и прямому урезанию, иначе квантованию данных. Сам по себе процесс этот несложный с точки зрения всего, но подводные камни у операции есть.
Мы буквально уменьшаем битность данных, что позволяет сократить вычислительные ресурсы и уменьшить объем памяти, необходимой для хранения моделей.
Наша карточка от Nvidia пользуется дешевыми, например, 8-битными ядрами для вычисления операций свертки/умножения матриц — мы получаем дешевую модель. Конечно, такое подрезание чисел с плавающей запятой может приводить и к снижению точности. Катастрофической.
Придумали разные методы квантования, каждый из которых имеет свои особенности, подходы и применения.
Их делят по трем критериям: униформное и неуниформное квантование, симметричное и асимметричное квантование, а также статическое и динамическое квантование. Углубляться мы не будем. Главное, что квантование может приводиться не только к 8-битности, но и к 16…
Там, где данные с высокой массой распределения от -1 до 1 — там вероятно значение входит в диапазон. Самое главное — квантование — это всегда приближение, которое может обходиться вам дороговато. Если вы решили сократить объем память в несколько раз и буквально перевести float32 к int8, особенно… Т.е от плавающей запятой к целочисленным значениям.
Если все плохо, то приводится, например, статическое квантование. Иначе, модель сразу обучается на «квантованных» данных.
Два принципа квантования
Практическая реализация идет, например, через Post-training quantization (PTQ) и базируется на постобучающем преобразовании модели, уже завершившей тренировку на данных с высокой точностью чисел, обычно 32-битных с плавающей запятой. Тот случай, когда мы рассчитываем, что наша супермодель выживет от такого удешевления данных.
Поэтому основная цель PTQ — минимизировать потребление памяти и вычислительных ресурсов на этапе инференса без необходимости повторного обучения модели.
В PTQ веса модели и, в некоторых случаях, активации преобразуются в целочисленные значения меньшей разрядности, чаще всего это 8-битные целые числа (int8), что позволяет существенно сократить размер модели и ускорить вычисления за счёт использования инструкций SIMD (Single Instruction, Multiple Data) на уровне аппаратного обеспечения.
SIMD-операции обрабатывают несколько данных с помощью одной инструкции. В этом их отличие от традиционных/скалярных операций.
В PTQ не происходит изменений в архитектуре нейронной сети, и алгоритм квантования выполняется отдельно от процесса обучения.
Основные этапы включают квантование весов с плавающей запятой в int8 посредством вычисления масштабов и нулевых сдвигов, что позволяет сохранить диапазон значений. Это осуществляется с помощью статистической информации, собранной на небольшом объёме тренировочных данных.
Важно, что через PTQ активации и веса могут квантоваться по-разному: активации могут квантоваться динамически на этапе инференса в зависимости от входных данных, тогда как веса — статически на основе априорной (изначальной) статистики.
Если вы действительно работаете над глубокой-глубокой нейросетью, то могут возникнуть проблемы с точностью, то же относится к задачам с высокой чувствительностью в данных.
Quantization-Aware Training (QAT) сложнее — тут квантование учитывается уже на этапе обучения модели.
В отличие от PTQ, в QAT веса и активации модели представляются в формате низкой разрядности (int8 или int16) в ходе всего процесса тренировки, что позволяет модели адаптироваться к ограниченной точности чисел.
Архитектура QAT предполагает, что квантованные версии весов используются не напрямую, а через эмуляцию процесса квантования во время прямого и обратного прохода модели.
В прямом проходе веса и активации моделируются как квантованные целые числа, что позволяет эффективно эмулировать процесс инференса в квантованной среде. Обратный проход, однако, используется с весами с плавающей запятой, что сохраняет точность градиентного спуска и позволяет корректировать модель в условиях ограниченной точности.
В процессе обучения модель «учится» компенсировать ошибки, вызванные квантованием, что позволяет снизить потери точности, наблюдаемые при PTQ. QAT требует значительно большего объёма вычислительных ресурсов на этапе обучения, так как тренировка должна учитывать квантование всех промежуточных активаций и весов.
При этом необходимо выполнять симуляцию квантования не только для весов, но и для входных данных на каждом слое, что увеличивает вычислительную сложность модели на этапе тренировки.
Для реализации QAT требуется модификация стандартных слоёв нейронной сети таким образом, чтобы они поддерживали низкоразрядные вычисления, а также корректная настройка механизмов обратного распространения ошибки.
Применение QAT часто связано с такими задачами, как развёртывание на мобильных устройствах, где вычислительные ресурсы и память ограничены. Поэтому частенько такая архитектура квантования применяется для задач CV, когда мы ставим камеру с микропроцессором и ждем чуда детекции…
Как все это работает на практике?
В TensorFlow квантование реализовано через TensorFlow Lite — облегчённую версию TensorFlow, специально предназначенную для развертывания моделей на устройствах с ограниченными ресурсами, ну там raspberry pi)))
PTQ в TensorFlow Lite можно выполнить через использование метода post-training quantization, где модель, обученная с плавающей запятой, автоматически конвертируется в квантованную версию с помощью функции converter.optimizations = [tf.lite.Optimize.DEFAULT].
Этот процесс включает статическое квантование весов с помощью небольшого набора калибровочных данных для расчёта масштабов и сдвигов нулевого уровня. Пример кода в TensorFlow Lite для PTQ может выглядеть следующим образом:
import tensorflow as tf
# Загрузка модели
model = tf.keras.models.load_model('model.h5')
# Конвертация модели с использованием PTQ
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
# Сохранение квантованной модели
with open('model_quantized.tflite', 'wb') as f:
f.write(tflite_model)
Для сложного сценария квантования QAT в TensorFlow применяется функционал встроенного квантования на этапе тренировки, который позволяет учитывать квантованные представления весов и активаций в процессе обучения.
Все идет через tf.quantization.fake_quant_with_min_max_vars, где происходит симуляция квантования в прямом и обратном проходах.
Однако это требует более детальной настройки сети и специфических изменений в процессе тренировки.
В PyTorch квантование поддерживается через пакет torch.quantization, который позволяет как пост-тренировочное квантование, так и QAT.
В PyTorch модульны подход, позволяющий кодерам выбирать между симметричным, асимметричным квантованием, а также реализовывать как динамическое квантование активаций, так и статическое квантование весов.
Перед тем как провести квантование по PTQ в PyTorch включают подготовку модели через функции torch.quantization.prepare () и преобразование через torch.quantization.convert ().
Пример кода для PTQ в PyTorch:
import torch
import torch.quantization
# Загрузка обученной модели
model = MyPretrainedModel()
# Подготовка модели к квантованию
model.eval()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)
# Применение PTQ
torch.quantization.convert(model, inplace=True)
# Тестирование квантованной модели
output = model(input_data)
Для QAT в PyTorch используется аналогичная структура, но с добавлением процедуры обучения с учётом квантования.
Модель сначала подготовлена через torch.quantization.prepare_qat (), и затем продолжается обучение, при котором происходит симуляция квантования.
Важно отметить, что во время обучения модель работает с весами с плавающей запятой, однако на этапе инференса она конвертируется в int8.
ONNX (Open Neural Network Exchange) — это открытый формат для представления моделей глубокого обучения, который обеспечивает переносимость моделей между разными фреймворками.
Для квантования в ONNX используется onnxruntime, поддерживающий как статическое, так и динамическое квантование.
Статическое квантование в ONNX работает через калибровочные данные — они используются для вычисления квантованных значений до начала инференса, в то время как динамическое квантование применяется только к весам, что упрощает процесс и снижает нагрузку на инференс.
Пример квантования модели в ONNX может выглядеть следующим образом:
import onnx from onnxruntime.quantization import quantize_dynamic, QuantType
model_fp32 = 'model.onnx' model_quant = 'model_quant.onnx'
Квантование с дистилляцией и прунингом
Часто процесс квантования проводится совместно с другими методами оптимизация — тем же прунингом или дистилляцией.
Pruning — это метод, при котором ненужные или малоактивные нейроны и связи удаляются из сети без существенного ущерба для её производительности.
Прореживание может быть выполнено на основе различных критериев, таких как величина весов (weights magnitude pruning), где удаляются веса, чьи значения минимальны, или на основе анализа чувствительности (sensitivity analysis pruning), где оценивается вклад каждого нейрона в общую ошибку модели.
В сочетании с квантованием pruning может значительно сократить количество вычислений, так как после прореживания остаются лишь активные нейроны, на которых квантование может быть применено.
Например, после выполнения pruning модель может быть преобразована в квантованную версию с меньшим числом параметров, что ещё больше снижает её вычислительную сложность.
Практическая реализация pruning с последующим квантованием в TensorFlow может выглядеть следующим образом:
import tensorflow_model_optimization as tfmot
model = tf.keras.models.load_model('model.h5')
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
pruned_model = prune_low_magnitude(model)
pruned_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
pruned_model.fit(train_data, train_labels, epochs=2)
converter = tf.lite.TFLiteConverter.from_keras_model(pruned_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quantized_model = converter.convert()
with open('model_pruned_quantized.tflite', 'wb') as f:
f.write(tflite_quantized_model)
По порядку.
Сначала мы загружаем модель с помощью tf.keras.models.load_model ('model.h5').
Эта модель может быть предварительно обученной нейронной сетью, например, для задачи классификации изображений или распознавания речи.
Для прореживания модели используется метод prune_low_magnitude, который является частью библиотеки tensorflow_model_optimization.
Метод удаляет связи (нейронные веса), значения которых близки к нулю, тем самым уменьшая размер модели и снижая объём вычислений, которые она требует.
В результате создаётся версия модели pruned_model, в которой часть параметров обнуляется. Это помогает сократить сложность модели и ускорить её выполнение без значительной потери точности.
После применения метода прореживания модель компилируется и дообучается на тренировочных данных с помощью метода fit (). Это необходимо для того, чтобы модель адаптировалась к изменённой структуре, где часть нейронных связей была удалена.
После завершения обучения модель подвергается пост-тренировочному квантованию с помощью TFLiteConverter.
Этот процесс заключается в преобразовании весов модели из 32-битного представления (FP32) в 8-битное целочисленное (int8), что значительно сокращает объём памяти, занимаемой моделью, и ускоряет инференс.
При этом используются оптимизации, заданные через converter.optimizations = [tf.lite.Optimize.DEFAULT]. После этого модель сохраняется в формате TFLite, что позволяет её легко развернуть на устройствах с ограниченными вычислительными ресурсами, таких как микроконтроллеры и мобильные устройства.
Принцип же дистилляции заключается в том, что «большая» модель (teacher model) обучает «меньшую» модель (student model), передавая ей свои знания в форме предсказаний.
В процессе дистилляции модель-учитель генерирует вероятностные распределения классов, которые затем используются для обучения модели-ученика.
Эти распределения, также называемые «мягкими метками» (soft labels), содержат более полную информацию, чем жёсткие метки (hard labels), так как отражают уверенность модели-учителя в каждом классе.
Сразу приводим пример дистилляции с квантованием.
Пример процесса дистилляции в PyTorch может выглядеть следующим образом:
import torch
import torch.nn.functional as F
def distillation_loss(student_output, teacher_output, labels, T, alpha):
soft_loss = F.kl_div(F.log_softmax(student_output / T, dim=1),
F.softmax(teacher_output / T, dim=1), reduction='batchmean') * (T * T)
hard_loss = F.cross_entropy(student_output, labels)
return soft_loss * alpha + hard_loss * (1. - alpha)
for data, labels in train_loader:
student_output = student_model(data)
with torch.no_grad():
teacher_output = teacher_model(data)
loss = distillation_loss(student_output, teacher_output, labels, T=4.0, alpha=0.7)
loss.backward()
optimizer.step()
Функция distillation_loss объединяет два компонента:
soft loss, который вычисляется с использованием распределений предсказаний модели-учителя и модели-ученика, нормализованных через температуру.
Это помогает передать модели-ученику более детализированную информацию о вероятностях классов, а не только о правильном классе (жёсткие метки), что делает процесс обучения более информативным.
hard loss — стандартная функция кросс-энтропии, которая измеряет расстояние между предсказаниями модели-ученика и реальными метками классов.
Комбинация этих двух составляющих (в зависимости от значения параметра α) позволяет модели-ученику лучше учиться на основе предсказаний модели-учителя.
Процесс обучения модели-ученика:
В этом коде цикл обучения модели-ученика выглядит следующим образом:
Для каждой порции данных (мини-батч) выполняется инференс как на модели-ученике, так и на модели-учителе.
Предсказания модели-учителя передаются в функцию потерь, где они используются для вычисления soft loss.
Функция потерь комбинирует это с традиционной кросс-энтропией (hard loss), обучая модель-ученика более эффективно.
Затем выполняется шаг оптимизации, и модель-ученик обновляет свои веса на основе полученной комбинированной функции потерь.
После обучения модели-ученика с использованием метода дистилляции, она может быть квантована с применением стандартных методов, таких как динамическое или статическое квантование, что ещё больше уменьшит её размеры и потребление ресурсов.
Интеграция всех трёх методов — квантования, pruning и distillation — представляет собой мощный подход к сжатию моделей.
В реальных сценариях, таких как мобильные устройства или встроенные системы, это позволяет достигать значительных улучшений в скорости выполнения модели и её энергоэффективности.
Например, если модель сначала подвергается прореживанию для удаления малозначимых связей, затем проходит процесс дистилляции для создания облегчённой версии, и, наконец, подвергается квантованию, можно достичь значительного уменьшения вычислительных затрат, не теряя при этом критически важной точности.
Это был краткий гайд по квантованию. Надеемся, для некоторых, особенно новичков, он был полезен.