Переобучение моделей: гайд и советы для начинающих

551cffc4f1de1f6752109e5cee46321c.pngАлександр Рыжков

Руководитель команды LightAutoML

С развитием нейросетей появляются новые вызовы. Один из них — переобучение моделей. Совместно с Александром Рыжковым, ментором Skillfactory, руководителем команды LightAutoML и 4х Kaggle Grandmaster, разбираемся, что такое переобучение, хорошо ли это и как его избежать.

Что такое переобучение моделей

В мире машинного обучения и нейронных сетей термин «переобучение» встречается часто и является ключевой проблемой, с которой сталкиваются разработчики. Переобучение (или overfitting) происходит, когда модель обучается настолько хорошо на тренировочных данных, что запоминает неважные детали и «шумы» вместо того, чтобы обобщать закономерности. Это приводит к тому, что при работе с новыми данными модель демонстрирует плохую производительность.

Александр Рыжков, ментор Skillfactory, руководитель команды LightAutoML и 4х Kaggle Grandmaster:

Представьте, что вы готовитесь к экзамену, заучивая наизусть ответы на конкретные вопросы из учебника. На экзамене вам попадаются именно эти вопросы, и вы блестяще отвечаете. Но если вопросы немного изменятся, вы окажетесь в затруднительном положении, так как не понимаете сути предмета, а лишь запомнили ответы.Переобучение модели в машинном обучении — это аналогичная ситуация. Модель настолько хорошо «запоминает» тренировочные данные, что теряет способность делать точные предсказания на новых данных. 

Переобученная модель отлично работает на тренировочном наборе, но демонстрирует низкую производительность на тестовом или реальном наборе данных.

Переобучение модели — это хорошо или плохо

Переобучение — это однозначно плохо. Модель, которая хорошо работает только на тренировочных данных, бесполезна на практике и противоречит целям машинного обучения. Цель машинного обучения — делать точные прогнозы на новых данных.

Существует несколько причин, которые приводят к переобучению:

  • Слишком сложная модель: если у модели слишком много параметров по сравнению с размером тренировочного набора, она может запомнить все особенности тренировочных данных, вместо того чтобы выявить общие закономерности.

  • Недостаточный размер тренировочного набора: если данных для обучения мало, модель может «подстроиться» под конкретные примеры, не улавливая общую картину.

  • Наличие шума в данных: шум — это случайные или нерелевантные данные, которые могут ввести модель в заблуждение и привести к переобучению.

  • Длительное обучение: если модель обучается слишком долго, она может начать «запоминать» шум и особенности тренировочных данных.

Как понять, что произошло переобучение модели

Если модель демонстрирует почти идеальную точность на обучающих данных, но не может достичь аналогичных результатов на тестовых или валидационных наборах, это первый признак возможного переобучения.

Серьезный разрыв между ошибкой (loss) на тренировочных и тестовых данных — еще один показатель того, что модель запомнила данные, вместо того чтобы научиться их обобщать.

Александр Рыжков, ментор Skillfactory, руководитель команды LightAutoML и 4х Kaggle Grandmaster:

Расскажу о переобучении на трех практических примерах.

 Пример 1: Модель регрессии. Вы обучаете модель для предсказания цены дома на основе его характеристик. На тренировочных данных модель показывает очень низкую ошибку (например, среднюю абсолютную ошибку), но на тестовых данных ошибка значительно выше. Это говорит о том, что модель переобучилась на тренировочных данных и не может обобщать новые данные.

Пример 2: Модель классификации. Вы обучаете модель для классификации изображений кошек и собак. На тренировочном наборе точность модели составляет 99%, а на тестовом — всего 70%. Это явный признак переобучения. Модель «запомнила» конкретные изображения из тренировочного набора, но не научилась распознавать кошек и собак в целом.

Пример 3: Модель — анализ временных рядов. Вы обучаете модель для прогнозирования курса акций. На исторических данных модель показывает отличные результаты, но при прогнозировании дает неточные прогнозы. Это может быть связано с тем, что модель переобучилась на исторических данных и не учитывает изменения в рыночных условиях. 

Как предотвратить переобучение модели

  • Регуляризация: используются методы регуляризации (например, L1 и L2), чтобы уменьшить сложность модели. Регуляризация основана на штрафах. Штраф — дополнительные ограничения к условию задачи обучения.

Александр Рыжков, ментор Skillfactory, руководитель команды LightAutoML и 4х Kaggle Grandmaster:

L1-регуляризация (Lasso): стремится обнулить веса некоторых признаков, что приводит к отбору признаков. L2-регуляризация (Ridge): стремится уменьшить веса признаков, но не обнуляет их полностью. 

  • Увеличение объема данных: сбор большего количества качественных данных помогает улучшить обобщающую способность модели.

  • Кросс-валидация: деление данных на несколько частей и обучение модели на разных комбинациях этих частей. Это позволяет оценить производительность модели более объективно и избежать переобучения на конкретном разбиении данных.

  • Ранняя остановка: остановка обучения модели, когда ошибка на валидационном наборе данных начинает расти, даже если ошибка на тренировочном наборе продолжает уменьшаться.

  • Снижение сложности модели: уменьшение количества признаков или слоев в нейронной сети.

  • Дропаут (Dropout): случайное «выключение» части нейронов во время обучения. Это заставляет сеть обучаться более устойчивым признакам и предотвращает переобучение.

  • Аугментация данных: создание новых тренировочных данных путем изменения существующих (например, повороты, масштабирование, отражение изображений). Это помогает увеличить размер тренировочного набора и сделать модель более устойчивой к изменениям.

Когда необходимо организовывать переобучение

Александр Рыжков, ментор Skillfactory, руководитель команды LightAutoML и 4х Kaggle Grandmaster:

Может показаться странным, но иногда переобучение может быть полезным. Это касается ситуаций, когда необходимо создать модель, которая идеально работает на конкретном, ограниченном наборе данных. 

Например, при создании генеративно-состязательных моделей (GAN) переобучение может помочь модели генерировать данные, которые максимально похожи на тренировочные.

Первую генеративно-состязательную модель придумал Ян Гудфеллоу, чтобы обучить ее рисовать только котов. Переобучившись под конкретную задачу, она успешно с ней справилась. Однако важно понимать, что переобученная на котах модель будет бесполезна для работы с новыми данными, например для генерации изображения собак.

Организовать переобучение, если это требуется, можно следующими способами:

  • Использовать очень сложную модель с большим количеством параметров.

  • Обучать модель очень долго на одном и том же наборе данных.

  • Не использовать регуляризацию и другие методы борьбы с переобучением.

Изучить фундаментальную математику и Python с любого уровня до продвинутого, попробовать создать свою первую ML-модель можно на магистратуре «Прикладной анализ данных и машинное обучение» от Skillfactory и МИФИ.

© Habrahabr.ru