[Перевод] Начинаем работу с PyTorch 2.0 и Hugging Face Transformers

753e0fe88c935759153b582e9fd85bb5.png

В этом посте разберем работу с PyTorch 2.0 и Hugging Face Transformers на примере fine-tune модели BERT для классификации текста.

PyTorch 2.0 лучше по производительности, скорости работы, более удобный для Python, но при этом остается таким же динамическим, как и ранее.

Разберем следующие шаги:

  1. Настройка окружения и установка PyTorch 2.0.

  2. Загрузка и подготовка датасета.

  3. Fine-tune и оценка модели BERT с помощью Hugging Face Trainer.

  4. Запуск инференса и тестирование модели.

Краткое введение: PyTorch 2.0

PyTorch 2.0, или, точнее, 1.14, полностью обратно совместим с предыдущими версиями. Он не потребует каких‑либо изменений в существующем коде PyTorch, но может оптимизировать код, если добавить model = torch.compile(model). Команда PyTorch так объясняет появление новой версии в своем FAQ: «Мы выпустили значительные новые функции, которые, на наш взгляд, меняют то, как вы используете PyTorch, поэтому мы назвали это 2.0 вместо 1.14.»

Среди этих новых функций: полная поддержка TorchDynamo, AOTAutograd, PrimTorch и TorchInductor. Это позволяет PyTorch 2.0 достигнуть ускорения времени обучения в 1,3–2 раза на более 40 архитектурах моделей от HuggingFace Transformers. Подробнее о PyTorch 2.0 можно узнать на официальном «GET STARTED».

Примечание: Этот туториал был создан и запущен на инстансе AWS EC2 g5.xlarge, включая GPU NVIDIA A10G.

1. Настройка окружения и установка PyTorch 2.0

Первый шаг — установить PyTorch 2.0 и библиотеки от Hugging Face, transformers и datasets.

# Установка PyTorch 2.0 с cuda 11.7
!pip install "torch>=2.0" --extra-index-url https://download.pytorch.org/whl/cu117 --upgrade --quiet 

Также ставим последнюю версию transformers, которая включает нативную интеграцию PyTorch 2.0 в Trainer.

# Установка transformers и dataset
!pip install "transformers==4.27.1" "datasets==2.9.0" "accelerate==0.17.1" "evaluate==0.4.0" tensorboard scikit-learn
# Установка git-lfs для загрузки модели и логов в hugging face hub
!sudo apt-get install git-lfs --yes

В этом примере для версионирования моделей мы будем использовать Hugging Face Hub. Чтобы загрузить модель на Hub, вначале необходимо зарегистрироваться на Hugging Face. Для входа в свою учетную запись и сохранения токена (ключа доступа) на диске используем login из пакета huggingface_hub.

from huggingface_hub import login

login(
  token="", # ADD YOUR TOKEN HERE
  add_to_git_credential=True
)

2. Загрузка и подготовка датасета

Будем обучать модель классификации текста на датасете BANKING77. Датасет BANKING77 содержит текстовые обращения от клиентов из области банковского/финансового сектора. Он состоит из 13 083 обращений, размеченных на 77 интентов (классов).

Для загрузки BANKING77 мы будем использовать метод load_dataset() из библиотеки

© Habrahabr.ru