Кратко про библиотеку Rumale для машинного обучения на Ruby
Привет, Хабр!
Библиотека Rumale создана для того, чтобы сделать машинное обучение доступным и удобным для разрабов на Ruby. Она имеет большой выбор алгоритмов и инструментов, аналогичных тем, что можно найти в Scikit-learn для Python.
Краткий формат статьи выбран из-за сходств с Sckit learn.
Установим
Открываем Gemfile и добавляем строку:
gem 'rumale'
После этого юзаем bundle install
для установки библиотеки:
$ bundle install
Если хочется установить Rumale без Bundler, можно сделать это напрямую через команду gem install
:
$ gem install rumale
После установки библиотеки, подключаем в проект:
require 'rumale'
Построение и обучение моделей в Rumale
Загружать данные будем с библиотеками Daru и RDatasets.
Линейная регрессия
Линейная регрессия — это база для предсказания числовых значений. В Rumale для этой цели используется класс Rumale::LinearModel::LinearRegression
:
require 'daru'
require 'rumale'
# создание набора данных
data = Daru::DataFrame.from_csv('housing_prices.csv')
x = data['size'].to_a
y = data['price'].to_a
# преобразование данных в формат, подходящий для Rumale
x = Numo::DFloat[x].reshape(x.size, 1)
y = Numo::DFloat[y]
# построение и обучение модели линейной регрессии
model = Rumale::LinearModel::LinearRegression.new
model.fit(x, y)
# предсказание на новых данных
predicted = model.predict(x)
puts "Предсказанные значения: #{predicted.to_a}"
Данные о размерах домов и их ценах загружаются из CSV-файла, преобразуются в массивы, а затем используются для обучения модели линейной регрессии.
Метод опорных векторов (SVM)
Метод опорных векторов — это алгоритм для задач классификации. В Rumale он представлен классом Rumale::LinearModel::SVC
:
require 'daru'
require 'rumale'
require 'rdatasets'
# загрузка набора данных Iris
iris = RDatasets.load(:datasets, :iris)
x = iris[0..3].to_matrix
y = iris['Species'].map { |species| species == 'setosa' ? 0 : 1 }
# преобразование данных в формат Numo::NArray
x = Numo::DFloat[*x.to_a]
y = Numo::Int32[*y]
# построение и обучение модели SVM
model = Rumale::LinearModel::SVC.new(kernel: 'linear', reg_param: 1.0)
model.fit(x, y)
# предсказание на новых данных
predicted = model.predict(x)
puts "Предсказанные значения: #{predicted.to_a}"
SVM моделька классифицирует цветы как setosa
или нет.
Кластеризация с использованием K-Means
K-Means — это алгоритм кластеризации, который группирует данные на основе их схожести. В Rumale используется класс Rumale::Clustering::KMeans
:
require 'daru'
require 'rumale'
require 'rdatasets'
# загрузка набора данных Iris
iris = RDatasets.load(:datasets, :iris)
x = iris[0..3].to_matrix
# преобразование данных в формат Numo::NArray
x = Numo::DFloat[*x.to_a]
# построение и обучение модели K-Means
model = Rumale::Clustering::KMeans.new(n_clusters: 3, max_iter: 300)
model.fit(x)
# предсказание кластеров
labels = model.predict(x)
puts "Кластеры: #{labels.to_a}"
Используем данные Iris для кластеризации их на три группы с помощью K-Means.
Прочие алгоритмы
Random Forest:
require 'daru'
require 'rumale'
require 'rdatasets'
# загрузка набора данных Iris
iris = RDatasets.load(:datasets, :iris)
x = iris[0..3].to_matrix
y = iris['Species'].map { |species| species == 'setosa' ? 0 : 1 }
# преобразование данных в формат Numo::NArray
x = Numo::DFloat[*x.to_a]
y = Numo::Int32[*y]
# построение и обучение модели Random Forest
model = Rumale::Ensemble::RandomForestClassifier.new(n_estimators: 10, max_depth: 3)
model.fit(x, y)
# предсказание на новых данных
predicted = model.predict(x)
puts "Предсказанные значения: #{predicted.to_a}"
Gradient Boosting:
require 'daru'
require 'rumale'
require 'rdatasets'
# загрузка набора данных Iris
iris = RDatasets.load(:datasets, :iris)
x = iris[0..3].to_matrix
y = iris['Species'].map { |species| species == 'setosa' ? 0 : 1 }
# преобразование данных в формат Numo::NArray
x = Numo::DFloat[*x.to_a]
y = Numo::Int32[*y]
# построение и обучение модели Gradient Boosting
model = Rumale::Ensemble::GradientBoostingClassifier.new(n_estimators: 100, learning_rate: 0.1, max_depth: 3)
model.fit(x, y)
# предсказание на новых данных
predicted = model.predict(x)
puts "Предсказанные значения: #{predicted.to_a}"
Оценка и валидация моделей
Метрики оценки качества моделей
Среднеквадратичная ошибка (MSE): измеряет среднее значение квадратов ошибок, т.е разницу между предсказанными и фактическими значениями:
require 'numo/narray'
require 'rumale'
# пример данных
y_true = Numo::DFloat[3.0, -0.5, 2.0, 7.0]
y_pred = Numo::DFloat[2.5, 0.0, 2.0, 8.0]
# расчет MSE
mse = Rumale::EvaluationMeasure::MeanSquaredError.new
mse_value = mse.score(y_true, y_pred)
puts "MSE: #{mse_value}"
Коэффициент детерминации (R²): измеряет долю дисперсии, объясненную моделью. Значение R² варьируется от 0 до 1, где 1 означает идеальное соответствие:
# расчет R²
r2 = Rumale::EvaluationMeasure::RSquared.new
r2_value = r2.score(y_true, y_pred)
puts "R²: #{r2_value}"
Кросс-валидации
Кросс-валидация позволяет оценить обобщающую способность модели. Одним из самых частых методов — K-Fold кросс-валидация.
K-Fold кросс-валидация:
require 'rumale'
require 'daru'
require 'rdatasets'
# загрузка данных Iris
iris = RDatasets.load(:datasets, :iris)
x = iris[0..3].to_matrix
y = iris['Species'].map { |species| species == 'setosa' ? 0 : 1 }
x = Numo::DFloat[*x.to_a]
y = Numo::Int32[*y]
# определение модели
model = Rumale::LinearModel::LogisticRegression.new
# определение метрики оценки
mse = Rumale::EvaluationMeasure::MeanSquaredError.new
# настройка K-Fold кросс-валидации
kf = Rumale::ModelSelection::KFold.new(n_splits: 5, shuffle: true, random_seed: 1)
# проведение кросс-валидации
cv = Rumale::ModelSelection::CrossValidation.new(estimator: model, splitter: kf, evaluator: mse)
report = cv.perform(x, y)
# вывод результатов
mean_score = report[:test_score].sum / kf.n_splits
puts "5-CV MSE: #{mean_score}"
После выполнения кросс-валидации или других методов оценки, очень важно не забывать о том, что нужно еще и правильно интерпретировать полученные результаты.
Среднее значение и стандартное отклонение: эти показатели дают представление о стабильности и надежности модели. Например, низкое ср. значение ошибки и низкое стандартное отклонение указывают на стабильную и точную модель:
mean_score = report[:test_score].mean
std_score = report[:test_score].std
puts "Mean MSE: #{mean_score}, Standard Deviation: #{std_score}"
Можно еще подключить gnuplot, чтобы визуализировать и помогает понять производительность модельки на различных наборах данных:
require 'gnuplot'
Gnuplot.open do |gp|
Gnuplot::Plot.new(gp) do |plot|
plot.title "K-Fold Cross Validation Scores"
plot.ylabel "MSE"
plot.xlabel "Fold"
plot.data << Gnuplot::DataSet.new(report[:test_score]) do |ds|
ds.with = "linespoints"
ds.title = "Fold MSE"
end
end
end
Подробнее с этой замечательной библиотекой можно ознакомиться здесь.
А с другими инструментами и библиотеками вы всегда можете познакомиться в рамках практических онлайн-курсов от моих коллег из OTUS.