pgvector или как хранить и обрабатывать фичи в базе данных
На Хабре было много упоминаний pgvector в обзорах Postgresso. И каждый раз новость была про место которое где-то за границей и далеко. Многие коммерческие решения для хранения и поиска векторов в базе данных нынче не доступны, а pgvector доступен любому, тем более в самой популярной базе в России.
В этой статье покажу на практическом примере как хранить, можно кластеризрвать вектора.
Прежде всего надо установить pgvector в PostgreSQL, он доступен в виде расширения. Поскольку я работаю с базой данных из Docker, то могу просто добавить в Dockerfile строчки и пересобрать образ:
RUN git clone --branch v0.5.1 https://github.com/pgvector/pgvector.git
RUN cd pgvector && make && make install
А в самой базе данных, нужно загрузить расширение:
osmworld=# CREATE EXTENSION vector;
CREATE EXTENSION
Time: 32,606 ms
Данные для векторов можно получить, например, из модели машинного обучения в python скрипте или ML модели в spark и вставить в таблицу с колонкой типа vector. А можно создать в SQL как гистограмму определенных категорий. В этом случае можно значения в массивах float[], integer[] или double precision[], numeric[] привести к типу :: vector
Данными для примера послужат гистограммы числа объектов детской инфраструктуры в окрестностях жилых домов в Москве. Про то как рассчитать эти данные я рассказывал здесь раньше, но в этой публикации я просто возьму готовые данные и создам из них таблицу с колонкой типа одинадцатимерный vector:
create table infrastructure_for_children_features2 as
select (row_number() over ())::integer id, null::integer cluster,
district, street, housenumber,
ARRAY[kindergarten::integer, school::integer,college::integer, university::integer, language_school::integer, music_school::integer,training::integer,sports_centre::integer,community_centre::integer,playground::integer,clinic::integer]
::vector(11) feature
from infrastructure_for_children;
Так в базе создал таблицу на 30237 записей со структурой:
osmworld=# \d infrastructure_for_children_features2
Table "public.infrastructure_for_children_features2"
Column | Type |
-------------+------------|
id | integer |
cluster | integer |
district | text |
street | text |
housenumber | text |
feature | vector(11) |
Теперь хотелось бы объединить их в группы по близости векторов. Опять же можно использовать нейросети, а можно использовать классические алгоритмы кластеризации — метод k-средних (k-means) или основанную на плотности пространственную кластеризацию для приложений с шумами (DBSCAN). Для метрики близости использую Евклидово расстояние. Поскольку число кластеров мне не известно, то я выберу DBSCAN и прогоню этот крошечный набор данных через него чтобы посмотреть зависимость от epsilon числа групп и число элементов не попавших в группы:
eps|clusters|not_in_cluster
0.0 75 29667
0.5 75 29667
1.0 202 28648
1.5 475 26904
2.0 928 22630
2.5 1227 17620
3.0 1173 11778
3.5 856 7605
4.0 601 4562
4.5 364 2760
5.0 232 1574
5.5 138 972
6.0 77 604
6.5 51 377
7.0 29 265
7.5 14 168
8.0 10 96
8.5 4 54
9.0 4 37
9.5 2 30
На свой субъективный взгляд выберу eps=5.5 и запущу Java программу, которая заполнит колонку cluster значениями алгоритма DBSCAN для minPoints=3 и eps=5.5:
package com.github.isuhorukov;
import com.pgvector.PGvector;
import org.apache.commons.math3.ml.clustering.Cluster;
import org.apache.commons.math3.ml.clustering.Clusterable;
import org.apache.commons.math3.ml.clustering.DBSCANClusterer;
import org.apache.commons.math3.ml.distance.EuclideanDistance;
import java.sql.*;
import java.util.ArrayList;
import java.util.List;
public class Main {
public static void main(String[] args) throws Exception {
try (Connection connection = DriverManager.getConnection(
System.getenv("jdbc_url"), System.getenv("user"), System.getenv("password"))) {
connection.setAutoCommit(false);
PGvector.addVectorType(connection);
float eps = Float.parseFloat(System.getenv("eps"));
int minPoints = Integer.parseInt(System.getenv("minPoints"));
DBSCANClusterer dbscanClusterer = new DBSCANClusterer<>(eps,minPoints,new EuclideanDistance());
List features = fetchFeatures(connection,
"select id,feature from infrastructure_for_children_features");
List> cluster = dbscanClusterer.cluster(features);
saveClusters(connection, cluster);
}
}
private static void saveClusters(Connection connection, List> cluster) throws SQLException {
try (PreparedStatement clusterPs = connection.prepareStatement(
"update infrastructure_for_children_features set cluster = ? where id = ?")){
for (int idx = 0; idx < cluster.size(); idx++) {
List featureCluster = cluster.get(idx).getPoints();
for (Feature feature : featureCluster) {
clusterPs.setInt(1, idx);
clusterPs.setInt(2, feature.id);
clusterPs.addBatch();
}
clusterPs.executeBatch();
}
connection.commit();
} catch (Exception e) {
connection.rollback();
throw new RuntimeException(e);
}
}
private static List fetchFeatures(Connection connection, String query) {
List features = new ArrayList<>();
try (Statement statement = connection.createStatement();
ResultSet resultSet = statement.executeQuery(query))
{
while (resultSet.next()) {
int id = resultSet.getInt(1);
float[] feature = ((PGvector) resultSet.getObject(2)).toArray();
features.add(new Feature(id, feature));
}
} catch (Exception e) {
throw new RuntimeException(e);
}
return features;
}
static class Feature implements Clusterable {
public int id;
public double[] feature;
public Feature(int id, float[] feature) {
this.id = id;
this.feature = new double[feature.length];
for (int i = 0; i < feature.length; i++) {
this.feature[i] = feature[i];
}
}
@Override
public double[] getPoint() {
return feature;
}
}
}
Для компиляции которого нужен pom.xml для maven:
4.0.0
com.github.igor-suhorukov
vectors
1.0-SNAPSHOT
11
11
UTF-8
org.apache.commons
commons-math3
3.6.1
com.pgvector
pgvector
0.1.3
org.postgresql
postgresql
42.6.0
Чем же похожи эти районы, еще стоит выяснить или попробовать другие epsilon
Вывод
C расширение pgvector PostgreSQL оказалось простым в использовании и с ним можно работать не только алгоритмами машинного обучения, но и классическими алгоритмами кластеризации из Java программы.