Как классифицировать данные без разметки
Пользователи iFunny ежедневно загружают в приложение около 100 000 единиц контента, среди которого не только мемы, но и расизм, насилие, порнография и другие недопустимые вещи.
Раньше мы отсматривали это вручную, а сейчас разрабатываем автоматическую модерацию на основе свёрточных нейросетей. Систему уже обучили на разделение контента по трём классам: она распознает, что пропустить в ленты пользователей, что удалить, а что скрыть из общей ленты. Чтобы сделать алгоритмы точнее, решили добавить конкретизацию причины удаления контента, у которого до этого не было подобной разметки.
Как мы это в итоге сделали — расскажу под катом на наглядном примере. Статья рассчитана на тех, кто знаком с Python (при этом необязательно разбираться в Data Science и Machine Learning).
Классификация без разметки
Задача: сделать классификацию объектов.
Дано: множество данных без разметки и каких-либо подробностей.
Решение:
Для начала загрузим данные и проведём их первичный анализ:
from sklearn.datasets import load_digits
dataset = load_digits()
dataset['data'].shape
У нас в наличии датасет размера (1797, 64). Это сравнительно небольшой набор данных — он меньше 2000, но и этого может быть достаточно, если выборка репрезентативная (отражает особенности всего исследуемого множества). При этом у каждого объекта 64 признака — если они все бинарные (принимают значение 0 и 1), то нам потребуется 2^64 примеров, чтобы покрыть все возможные варианты. Для признаков, которые принимают 3 и больше значений, размер всеобъемлющей выборки будет ещё больше. На практике лишь небольшое число признаков несёт основную информацию об объекте и принимает гораздо меньше значений из допустимого множества.
Для начала выведем несколько строк из набора на экран:
dataset.data[10:15]
Полезно бывает смотреть на сырые данные без дополнительных агрегаций информации. Например, сейчас видно, что массив сохранен в формате float, но не видно ни одного элемента с числом после точки, будто бы все они целочисленные.
Перед работой с любыми данными стоит смотреть на статистику по разным признакам (столбцам). Взглянем на несколько случайных столбцов — возьмем с 30-го по 35-й и выведем статистику с помощью библиотеки pandas.
Метод describe позволяет посмотреть набор самых часто используемых статистик из таблицы ниже. Значения признаков группируются около нуля, на что указывают их средний показатель. Также есть признаки с нулевым значением у всех объектов выборки, значит они неинформативны и их можно не использовать при дальнейшем анализе.
dataset_df = pd.DataFrame(dataset.data[:, 30:35])
dataset_df.describe()
Есть большое количество методов для анализа данных, многие из которых связаны с графическим отображением. Один из любимых способов Data Science инженеров — график попарных корреляций. Он позволяет обнаружить зависимость между признаками, которая может вести к уменьшению признакового пространства. Также с его помощью можно обнаружить корреляцию между признаком и таргетом (искомой величиной), но у нас нет разметки, поэтому данный сценарий нереализуем.
import seaborn as sns
sns.pairplot(dataset_df);
В нашем случае видно лишь то, что все признаки принимают целочисленные значения. Отсутствие парных корреляций не исключает наличия зависимости между большим числом признаков одновременно. Но увидеть такие особенности данных невозможно — у нас 64-мерное признаковое пространство. Даже если в нём есть области, где объекты группируются, то обнаружить это каким-либо графическим методом будет крайне сложно (а может и совсем невозможно).
В такой ситуации нужно уменьшить размерность пространства признаков, отобразив его в двух- или трёхмерном, с которыми наше сознание в состоянии справиться.
Понижение размерности
Для начала избавимся от константных признаков. Выше мы отметили наличие признаков со значением 0 у всех объектов, поэтому спокойно их удаляем во всей выборке. Наша цель — разделить объекты, а значит, основная информация заключается в отличии их друг от друга.
Есть множество способов уменьшить размерность признакового пространства, сохранив его информативность. Для статьи возьмём алгоритм UMap, так как уже используем его в своих задачах. Одно из его преимуществ перед другими алгоритмами нелинейного снижения размерности — возможность обучать на одном наборе данных, а затем использовать его в дальнейшем на новых данных, применяя одно и то же преобразование.
Используем уже готовую библиотеку.Здесь самый важный параметр — количество компонент, которое нужно получить на выходе (до какой размерности сжать текущее пространство признаков). Выбираем два, потому что 2D-плоскость можно наглядно отобразить на рисунке:
import umap
reducer = umap.UMAP(n_components=2, random_state=47)
Делаем обучение командой fit. Данных не так много, поэтому обучаем на всём наборе, но, как было сказано ранее, он может быть меньше итогового:
reducer.fit(dataset.data)
Далее преобразуем все данные:
embeddings = reducer.transform(dataset.data)
И на выходе получаем уменьшенную размерность — количество образцов то же самое, но признаков всего два: (1797, 2).
Коротко о том, как это работает: UMap строит взвешенный граф, соединяя ребрами ближайших соседей в n-мерном пространстве, затем создает другой граф в низкоразмерном пространстве и приближает его к исходному так, чтобы сохранить относительное положение объектов. То есть близкие объекты оставляет ближе, дальние — дальше, но уже в уменьшенной размерности.
Построим график полученных 2D-векторов :
plt.scatter(embeddings[:, 0], embeddings[:, 1], s=5)
На графике видно 10 больших групп точек и ещё несколько поменьше. Проведём кластеризацию — разобьём на области, основываясь на каком-либо параметре или правиле.
Кластеризация
Воспользуемся алгоритмом k-средних (KMeans), который основывается на минимизации суммарного квадратичного отклонения точек кластеров от центров этих кластеров.
Задаем поиск 10 кластеров (на предыдущем графике видно 10), делаем обучение и предсказание итоговых классов:
clustering = KMeans(n_clusters=10)
classes = clustering.fit_predict(embeddings)
Раскрасим картинку с кластерами. Алгоритм очень хорошо их разделил:
plt.scatter(embeddings[:, 0], embeddings[:, 1], c=classes, cmap='Spectral', s=5)
plt.colorbar(boundaries=np.arange(11)-0.5).set_ticks(np.arange(10))
Полученные порядковые номера кластеров можно считать классами неразмеченной выборки. Для классификации новых данных нужно последовательно применить к ним уже обученные алгоритмы UMap и KMeans и получить номер кластера для этих объектов.
А теперь открою небольшую тайну — это были не просто данные.
Исходные данные
В тренировочном примере данные являются картинками 8×8 пикселей с рукописными числами. Если значения интенсивности всех пикселей слева-направо и сверху-вниз выложить в одну строку, то получится вектор длины 64 — именно тот, с которым работали до этого. Интенсивность в пикселе записана в формате uint8 и принимает только целочисленные значения от 0 до 255, а значит наши наблюдения в самом начале были верны.
Всего в датасете представлены цифры от 0 до 9, то есть как раз 10 классов (столько же кластеров нам удалось выделить):
Теперь у нас есть истинный класс и известно, какой строке соответствует какая метка. Если изобразить истинное распределение классов в пространстве меньшей размерности с помощью найденного преобразования, то получится следующее:
Точность классификации
На картинке выше видно, что в большинстве случаев отличаются только цвета, которые отвечают за номер кластера. Это связано с тем, что метод k-средних расставлял метки случайным образом, не вкладывая в 0-й класс смысл наличия нулей в его изображениях. Если поменять нумерацию, то станет видно, какое число примеров было выделено правильно.
Есть много метрик, которые одним числом указывают, насколько способ хорош. Самая известная — точность (accuracy), которая является отношением верных ответов ко всем примерам в тестовом наборе. У такого подхода есть большой недостаток — он не говорит, в чем именно ошибка. Использование этой и других интегральных метрик будет особенно неудобным в случае многоклассовой классификации, где по одному числу непонятно, какие классы путаются между собой.
Именно в такой ситуации мы сейчас находимся, поэтому стоит обратиться к матрице ошибок. Для ее построения используем библиотеку pycm:
from matplotlib.pyplot import cm
from pycm import ConfusionMatrix
y_true = dataset.target
conf_matrix = ConfusionMatrix(actual_vector=y_true, predict_vector=y_pred)
conf_matrix.plot(cmap=cm.Greens, number_label=True);
В этом коде y_pred — перенумерованные значения кластеров, найденных нами ранее. В качестве нового значения использовался наиболее встречающийся в нём истинный класс. Ниже показана полученная матрица ошибок:
По горизонтали — классы, предсказанные нашим методом.
По вертикали — истинные классы.
В клетках пересечения — количество объектов, удовлетворяющих двум условиям.
27 образцов из истинного класса единиц почему-то определились как шестерки. Разберёмся, почему так вышло и посмотрим на картинки из датасета.
Единицы, классифицированные как шестерки
На первый взгляд эти объекты не выглядят, как шестерки. Вернёмся к истинной разметке классов и увидим, что есть небольшая группа единиц, которая очень далеко от остальных и находится как раз ближе к шестеркам.
А реальные шестёрки и правда порой похожи на единицы, поэтому тут вопросы не к нашей модели, а к тому, кто так пишет:
Шестерки, классифицированные верно
Вместо заключения
Похожим образом мы работаем над тем, чтобы спорный контент вроде пейзажей, оружия и девушек в купальниках попадался не всем пользователям, а только тем, кто не против. Только вместо значений пикселей, как это было в нашем примере, берутся определённые паттерны.
Эти паттерны выделяет нейросеть, предобученная на большом наборе данных. Но для основной задачи удаления нежелательного контента в своём исходном состоянии она нам не подходит, потому что не знает три наших класса:
approved — картинки идут в раздел приложения collective;
not suitable — не попадают в общую ленту, но остаются в ленте пользователя (девушки в купальниках и мужчины в плавках, селфи и всё, что не является мемами);
risked — такой контент получает бан и перестает быть доступным для всех пользователей iFunny (расизм, порнография, расчленёнка и всё, что попадает под определение «противоправный контент»).
Нам предстояло дообучить сеть на эти классы. Но об этом подробно поговорим уже в следующей статье.