Глубокое обучение на Kotlin: вышла альфа-версия KotlinDL

5zn88xlaujmrgteumx4i2fe9zno.png?v=1

Всем привет!

На днях мы выпустили первую альфа-версию KotlinDL, фреймворка для глубокого обучения нейросетей, API которого мы старались сделать максимально похожим на Keras (фреймворк на Python поверх TensorFlow).
В KotlinDL вы найдете простые API как для описания, так и для тренировки нейронных сетей. За счет высокоуровневого API и аккуратно подобранных значений по умолчанию для множества параметров мы надеемся снизить порог входа в глубокое обучение на JVM. Вот так, например, выглядит тренировка и сохранение простой нейросети, написанной при помощи KotlinDL:

private val model = Sequential.of(
   Input(28,28,1),
   Flatten(),
   Dense(300),
   Dense(100),
   Dense(10)
)

fun main() {
   val (train, test) = Dataset.createTrainAndTestDatasets(
       trainFeaturesPath = "datasets/mnist/train-images-idx3-ubyte.gz",
       trainLabelsPath = "datasets/mnist/train-labels-idx1-ubyte.gz",
       testFeaturesPath = "datasets/mnist/t10k-images-idx3-ubyte.gz",
       testLabelsPath = "datasets/mnist/t10k-labels-idx1-ubyte.gz",
       numClasses = 10,
       ::extractImages,
       ::extractLabels
   )
   val (newTrain, validation) = train.split(splitRatio = 0.95)

   model.use{
       it.compile(optimizer = Adam(),
               loss = Losses.SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS,
               metric = Metrics.ACCURACY)

       it.summary()

       it.fit(dataset = newTrain,
               epochs = 10,
               batchSize = 100,
               verbose = false)

       val accuracy = it.evaluate(dataset = validation,
               batchSize = 100).metrics[Metrics.ACCURACY]

       println("Accuracy: $accuracy")
       it.save(File("src/model/my_model"))
   }
}


Поддержка GPU

Тренировка моделей на центральном процессоре может занимать значительное время. Распространенной практикой является запуск вычислений на GPU. Для этого вам понадобится установленная CUDA от NVIDIA. Для запуска тренировки модели на GPU достаточно добавить всего одну зависимость.


Что вошло в API

В этой ранней версии вы найдете все необходимые методы для описания многослойных перцептронов и сверточных сетей. Для большинства гиперпараметров проставлены разумные значения по умолчанию, но в то же время у вас есть широкий выбор оптимизаторов, инициализаторов, функций активации и прочих настроек. Полученную в процессе тренировки модель можно сохранить и использовать в backend-приложении, написанном на Kotlin или Java.


Загрузка моделей, тренированных на Keras

KotlinDL не только умеет загружать модели, тренированные этим же фреймворком, но и предоставляет возможность загрузить и использовать модель, натренированную с помощью Keras на языке Python (поддерживаются версии Keras 2.*).

При загрузке моделей вы также можете использовать технику Transfer Learning, которая позволяет вам не тренировать огромные нейронные сети с нуля, а воспользоваться уже имеющейся моделью и просто подогнать ее под вашу задачу.


Текущие ограничения

В этой, самой ранней версии доступно ограниченное количество слоев: Input (), Flatten (), Dense (), Dropout (), Conv2D (), MaxPool2D () и AvgPool2D ().
Это ограничение распространяется и на то, какие Keras модели можно загружать. Это означает, что архитектуры VGG-16 и VGG-19 поддерживаются уже сейчас, а, например, ResNet50 пока что не поддерживается. В ближайшие месяцы мы планируем выпустить следующую минорную версию, в которой увеличится количество поддерживаемых архитектур.

Второе временное ограничение заключается в том, что поддержка Android-устройств пока не вошла в данную версию. Но над этим мы тоже будем работать.


А что под капотом?

В качестве движка KotlinDL использует TensorFlow Java API. Все вычисления выполняются в TensorFlow, в нативной памяти, причем во время тренировки все данные остаются в нативе.


Попробуйте и поделитесь впечатлениями!

В документации к проекту вы найдете статьи (на английском), которые, надеемся, помогут вам попробовать наш фреймворк:

Будем рады вашим замечаниям, пожеланиям, баг-репортам и другим отзывам в GitHub Issues. Особенно приветствуются пул-реквесты. Присоединяйтесь к каналу #deeplearning в Kotlin Slack.

© Habrahabr.ru