Умножение троичных матриц для нейросетей

cd04a27800ddf42767f433076dd2ce82

В статье «Как исследователи нарушают привычные подходы в ИИ, исключая матричное умножение» упоминалось, в частности, что перспективным кажется хранение в нейросетевых матрицах лишь троичных значений: (-1, 0, 1), иначе говоря — тритов. Такие матрицы умножать друг на друга проще. И в моей статье я расскажу, как собственно, матрицы из тритов хранить и умножать.

Как известно, при умножении матриц, мы строку левой матрицы умножаем на столбец правой, и результат записываем в соответствующую ячейку результирующей матрицы. Чтобы было быстрее, мы правую матрицу предварительно транспонируем: тогда строку левой будем умножать на строку правой. Иначе говоря, мы скалярно перемножаем два вектора из тритов, причём оба занимают непрерывную область памяти.

Предположим, что процессор ориентирован на 32-битную арифметику. Тогда разобьём строку матрицы на векторы по 32 трита. Каждый из этих векторов будем хранить в виде двух 32-битных целых чисел, назовём их «плюс-вектором» и «минус-вектором». Трит с номером N равен разности бита с номером N плюс-вектора, и бита с номером N минус-вектора. При этом нулевое значение трита кодируется двумя способами: когда оба бита равны 0, или когда они оба равны 1.

трит

-1

0

0

1

плюс-вектор

0

0

1

1

минус-вектор

1

0

1

0

Такой способ хранения позволяет быстро подсчитать сумму всех тритов вектора. Для этого мы из суммы битов плюс-вектора вычитаем сумму битов минус-вектора. Различные алгоритмы нахождения суммы битов описаны в статье «Обстоятельно о подсчёте единичных битов», к тому же процессор может поддерживать инструкцию POPCNT, которая эту сумму подсчитывает.
Например, можно использовать следующий алгоритм из вышеупомянутой статьи:

// Количество единичных битов
unsigned __int32 popcnt(unsigned __int32 value)
{
    // Суммируем чётные и нечётные биты
    //  00  01  10  11  value
    //   0   0   1   1  (value >> 1) & 0x55555555
    //  --  --  --  --  --
    //  00  01  01  10  =
    value -= (value >> 1) & 0x55555555;

    // Повторяем уже для пар битов
    value = ((value >> 2) & 0x33333333) + (value & 0x33333333);

    // Умножение на 0x01010101 эквивалентно сумированию значений 4 байт числа,
    // при условии, что в младших байтах не будет переполнения
    // (а его в нашем случае не будет, так как там содержатся суммы битов).
    // Результат сложения будет в старшем байте произведения.
    return ((((value >> 4) + value) & 0x0F0F0F0F) * 0x01010101) >> 24;
}

Остаётся разобраться с потритовым умножением векторов. Для наглядности, будем считать трит структурой из двух битовых полей с именами p и m, причём p содержит значение из плюс-вектора, а m — из минус-вектора. Назовём «r» результат умножения тритов a и b. Тогда:
r.p = (a.p | b.m) & (a.m | b.p)
r.m = (a.p | b.p) & (a.m | b.m)
Итак, мы сначала по этим формулам потритово умножаем два троичных вектора, затем у результата этого умножения находим суммы битов плюс-вектора и минус-вектора, затем находим разность этих сумм. Это и будет результатом скалярного произведения двух троичных векторов друг на друга.

Проиллюстрируем этот алгоритм программой на языке C++. В этой программе мы создадим две троичные матрицы размерами 32×32, заполним их случайными значениями, и перемножим друг на друга двумя способами: классическим алгоритмом, и оптимизированным для троичных вычислений.

#include 
#include 

// Неупакованные троичные матрицы,
// которые будем перемножать.
// Каждый элемент принимает значения (-1, 0, 1)
int A[32][32];
int B[32][32];

// Результат умножения неупакованных матриц
int C[32][32];

typedef unsigned __int32 u32;

// Вектор из 32 тритов
struct TritVector32
{
    u32 p; // плюс-вектор
    u32 m; // минус-вектор

    TritVector32() {p = 0; m = 0;}

    // Получить значение трита с указанным номером
    int getTrit(int index)
    {
        int mask = 1 << index;
        return ((p & mask) != 0) - ((m & mask) != 0);
    }
    void setTrit(int index, int trit)
    {
        int mask = 1 << index;
        p |= mask;
        p ^= mask & ((trit - 1) >> 1); // p ^= mask & -(trit <= 0);
        m |= mask;
        m ^= mask & ((~trit) >> 1);    // p ^= mask & -(trit >= 0);
    }
};

// Упакованные троичные матрицы,
// содержащие те же значения, что и неупакованные A и B
TritVector32 A3[32];
TritVector32 B3t[32]; // эта матрица транспонирована

// Результат умножения упакованных матриц
int C3[32][32];

// Возвращает случайное значение (-1, 0, 1)
int rand3()
{
    return (int)((u32) rand() * 3 / (RAND_MAX + 1)) - 1;
}

// Количество единичных битов
int popcnt(u32 value)
{
    // Для наглядности используется простой, но медленный алгоритм.
    // Существуют гораздо более быстрые.
    int result = 0;
    while(value)
    {
        result += value & 1;
        value >>= 1;
    }
    return result;
}

int main()
{
    // Заполняем матрицы сомножителей случайными значениями
    for(int i = 0; i < 32; i++)
        for(int j = 0; j < 32; j++)
        {
            A[i][j] = rand3();
            B[i][j] = rand3();
        }
    // Умножаем матрицы классическим способом
    for(int i = 0; i < 32; i++)
        for(int j = 0; j < 32; j++)
        {
            int c = 0;
            for(int k = 0; k < 32; k++)
                c += A[i][k] * B[k][j];
            C[i][j] = c;
        }

    // Заполняем упакованные матрицы
    for(int i = 0; i < 32; i++)
        for(int j = 0; j < 32; j++)
        {
            A3[i].setTrit(j, A[i][j]);
            B3t[j].setTrit(i, B[i][j]);
        }
    // Умножаем оптимизированным способом
    for(int i = 0; i < 32; i++)
    {
        TritVector32 *a = A3 + i;
        for(int j = 0; j < 32; j++)
        {
            TritVector32 *b = B3t + j;
            TritVector32 r;

            // Потритовое умножение
            r.p = (a->p | b->m) & (a->m | b->p);
            r.m = (a->p | b->p) & (a->m | b->m);

            C3[i][j] = popcnt(r.p) - popcnt(r.m);
        }
    }

    // Выводим результаты
    FILE *fp = fopen("classic.txt", "wt");
    fputs("A =\n", fp);
    for(int i = 0; i < 32; i++)
    {
        for(int j = 0; j < 32; j++)
            fprintf(fp, "%2d ", A[i][j]);
        fputs("\n", fp);
    }
    fputs("\nB =\n", fp);
    for(int i = 0; i < 32; i++)
    {
        for(int j = 0; j < 32; j++)
            fprintf(fp, "%2d ", B[i][j]);
        fputs("\n", fp);
    }
    fputs("\nC =\n", fp);
    for(int i = 0; i < 32; i++)
    {
        for(int j = 0; j < 32; j++)
            fprintf(fp, "%3d ", C[i][j]);
        fputs("\n", fp);
    }
    fclose(fp);

    fp = fopen("trit.txt", "wt");
    fputs("A =\n", fp);
    for(int i = 0; i < 32; i++)
    {
        for(int j = 0; j < 32; j++)
            fprintf(fp, "%2d ", A3[i].getTrit(j));
        fputs("\n", fp);
    }
    fputs("\nB =\n", fp);
    for(int i = 0; i < 32; i++)
    {
        for(int j = 0; j < 32; j++)
            fprintf(fp, "%2d ", B3t[j].getTrit(i));
        fputs("\n", fp);
    }
    fputs("\nC =\n", fp);
    for(int i = 0; i < 32; i++)
    {
        for(int j = 0; j < 32; j++)
            fprintf(fp, "%3d ", C3[i][j]);
        fputs("\n", fp);
    }
    fclose(fp);
    return 0;
}

Сравнение файлов результатов показывает, что они одинаковы. Всё работает.

Теперь возникает вопрос:, а что, если у нас больше двух матриц? Ведь матрица произведения двух троичных матриц содержит числа, вообще говоря, не троичные. Как её умножить на троичную матрицу?
Но и здесь можно обойтись без умножения.
Пусть a — обычное целое число, которое может быть отрицательным. Пусть b — трит, имеющий два поля p и m, соответствующие плюс- и минус-векторам. Тогда произведение a на b можно записать так:
a•b = a•(b.p — b.m) = (a & -b.p) — (a & -b.m) =
= ((a — b.m) ^ -b.m) & -(b.m ^ b.p) =
= ((a — b.m) ^ -b.m) & ((-b.m) ^ -b.p) =
= ((a ^ -b.m) + b.m) & ((-b.m) ^ -b.p)
Можно выбрать любую из этих формул, или им аналогичных.

© Habrahabr.ru