Darts: тестируем временные ряды с нуля
Привет, Хабр!
Если вы когда-либо имели дело с временными рядами, то, вероятно, слышали о Darts. А для тех, кто ещё в танке: Darts — это мощный инструмент, который поддерживает мультиварибельные временные ряды и легко интегрируется с PyTorch и TensorFlow.
Зачем же тестировать временные ряды, когда в классическом машинном обучении всё так просто с кросс-валидацией? Временные ряды обладают своей изюминкой: они подвержены временным зависимостям, сезонности, трендам и другим радостям жизни. Так что, если вы хотите, чтобы ваши модели не провалились на тестах, время разобраться с их особенностями!
Основы тестирования временных рядов
Перед тем как приступать к работе с данными, нужно убедиться, что они в правильном формате.
Формат данных
Первым делом нужно проверить временные метки на корректность. Важно, чтобы не было пропусков! С помощью pandas
это делается легко:
import pandas as pd
# Загружаем данные
data = pd.read_csv('your_data.csv', parse_dates=['timestamp'], index_col='timestamp')
# Проверка на пропуски
missing_values = data.isnull().sum()
print(f"Пропуски в данных:\n{missing_values}")
# Проверка на равномерность временных меток
time_diff = data.index.to_series().diff()
print(f"Минимальный интервал между временными метками: {time_diff.min()}")
print(f"Максимальный интервал между временными метками: {time_diff.max()}")
Распределение данных
Важно убедиться, что данные равномерно распределены. Если выйдет так, что обнаружатся большие интервалы без наблюдений, может возникнуть необходимость в ресэмплинге или интерполяции. Помните, что данные должны быть в форме, удобной для анализа.
Перед тем как запустить свою модель, выполним несколько предобработок, которые могут существенно повлиять на точность прогнозов.
Скейлинг: Приведение данных к общему масштабу поможет вашему алгоритму быстрее схватывать закономерности. Используйте StandardScaler
или MinMaxScaler
из scikit-learn
:
from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()
data_scaled = scaler.fit_transform(data.values.reshape(-1, 1))
Логарифмирование: Если ваши данные демонстрируют экспоненциальный рост, логарифмирование станет вашим лучшим другом:
import numpy as np
data_log = np.log(data + 1) # Добавляем 1, чтобы избежать log(0)
Иммутация пропусков: Обработка пропущенных значений — ещё один важный шаг:
data_filled = data.interpolate(method='linear')
Метрики для оценки моделей
Когда дело доходит до оценки моделей временных рядов, стандартные метрики могут вас разочаровать. Временные ряды требуют специфических подходов:
RMSE: Это показатель отклонения предсказаний от реальных значений.
from sklearn.metrics import mean_squared_error
rmse = np.sqrt(mean_squared_error(y_true, y_pred))
print(f"RMSE: {rmse}")
MAE: Более устойчивый к выбросам, MAE дает ясное представление о точности модели.
mae = np.mean(np.abs(y_true - y_pred))
print(f"MAE: {mae}")
MASE: Эта метрика помогает сравнивать качество модели с наивным подходом.
mase = np.mean(np.abs(y_true - y_pred)) / np.mean(np.abs(y_true - np.roll(y_true, 1)[1:]))
print(f"MASE: {mase}")
Как протестировать данные и прогнозы с помощью Darts
Итак, начнем с загрузки данных. Допустим, есть данные о продажах, хранящиеся в CSV-файле.
import pandas as pd
from darts import TimeSeries
# Загружаем данные
data = pd.read_csv('sales_data.csv', parse_dates=['date'], index_col='date')
# Создаем временной ряд
series = TimeSeries.from_dataframe(data, 'date', 'sales')
print("Данные загружены и преобразованы в временной ряд:")
print(series)
Перед тем как двигаться дальше, убедимся, что с временным рядом всё в порядке. Проверим на наличие пропусков и аномалий:
# Проверка на пропуски
if series.isnull().any():
print("Обнаружены пропуски в данных!")
else:
print("Пропусков нет!")
# Визуализация данных
series.plot(title='График продаж', xlabel='Дата', ylabel='Количество продаж')
Теперь, когда мы уверены в целостности данных, проведём их предобработку.
Логарифмирование:
import numpy as np
# Логарифмирование
series_log = TimeSeries.from_dataframe(np.log1p(data.set_index('date')['sales']))
series_log.plot(title='Логарифмированные данные')
Скейлинг:
from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()
scaled_data = scaler.fit_transform(data['sales'].values.reshape(-1, 1))
scaled_series = TimeSeries.from_dataframe(pd.DataFrame(scaled_data, index=data['date']))
scaled_series.plot(title='Масштабированные данные')
Теперь всё готово к построению модели. В Darts есть множество моделей, и мы будем использовать N‑BEATS.
from darts.models import NBEATSModel
# Определяем модель
model = NBEATSModel(input_chunk_length=30, output_chunk_length=10, n_epochs=100)
# Обучаем модель
model.fit(series_log)
print("Модель N-BEATS успешно обучена.")
Теперь протестируем модель на различных горизонтах прогнозирования:
# Прогноз на 10 шагов вперед
forecast_horizon = 10
predictions = model.predict(forecast_horizon)
# Визуализация прогноза
predictions.plot(label='Прогноз N-BEATS', title='Прогнозирование на 10 шагов вперед')
series_log.plot(label='Исторические данные')
После получения прогнозов важно оценить их точность с помощью метрик:
from darts.utils.statistics import mean_absolute_error, mean_squared_error
# Оценка точности
mae = mean_absolute_error(series_log[-forecast_horizon:], predictions)
rmse = np.sqrt(mean_squared_error(series_log[-forecast_horizon:], predictions))
print(f"MAE: {mae:.4f}")
print(f"RMSE: {rmse:.4f}")
Также стоит визуализировать остатки:
# Остатки
residuals = series_log[-forecast_horizon:] - predictions
residuals.plot(title='Остатки прогноза')
Если остатки показывают закономерности, это сигнализирует о проблемах в модели.
Как улучшить прогноз
Если N-BEATS не даёт желаемых результатов, можно попробовать другие подходы:
Изменение модели: Попробуйте, например, модель Prophet.
from darts.models import Prophet
# Обучаем модель Prophet
prophet_model = Prophet()
prophet_model.fit(series)
# Прогнозируем
prophet_predictions = prophet_model.predict(forecast_horizon)
# Визуализация
prophet_predictions.plot(label='Прогноз Prophet')
series_log.plot(label='Исторические данные')
Тонкая настройка гиперпараметров: Используйте кросс-валидацию для подбора гиперпараметров.
from darts.models import NBEATSModel
from sklearn.model_selection import GridSearchCV
# Определяем параметры для поиска
param_grid = {
'input_chunk_length': [10, 30],
'output_chunk_length': [5, 10],
'n_epochs': [100, 200],
}
# Настраиваем GridSearchCV
grid_search = GridSearchCV(estimator=NBEATSModel(), param_grid=param_grid, scoring='neg_mean_absolute_error')
grid_search.fit(series_log)
print(f"Лучшие параметры: {grid_search.best_params_}")
Анализ различных горизонтов: Прогнозируйтена разных интервалах и проверяйте точность.
for horizon in [1, 5, 10]:
pred = model.predict(horizon)
error = mean_absolute_error(series_log[-horizon:], pred)
print(f"MAE для горизонта {horizon}: {error:.4f}")
Вот и всё! Darts предоставляет мощные инструменты для анализа временных рядов, так что не бойтесь экспериментировать с разными моделями и гиперпараметрами. Удачи вам и пусть ваши прогнозы всегда будут точными!
Подробнее с библиотекой можно ознакомиться здесь.
А в ближайшие дни пройдут открытые уроки по ML и CV, которые можно посетить бесплатно:
7 октября: «Word2Vec — классика векторных представлений слов для решения задач текстовой обработки». Узнать подробнее
10 октября: «OpenCV: Как Начать Работать с Компьютерным Зрением». Узнать подробнее