Умножение троичных матриц для нейросетей
В статье «Как исследователи нарушают привычные подходы в ИИ, исключая матричное умножение» упоминалось, в частности, что перспективным кажется хранение в нейросетевых матрицах лишь троичных значений: (-1, 0, 1), иначе говоря — тритов. Такие матрицы умножать друг на друга проще. И в моей статье я расскажу, как собственно, матрицы из тритов хранить и умножать.
Как известно, при умножении матриц, мы строку левой матрицы умножаем на столбец правой, и результат записываем в соответствующую ячейку результирующей матрицы. Чтобы было быстрее, мы правую матрицу предварительно транспонируем: тогда строку левой будем умножать на строку правой. Иначе говоря, мы скалярно перемножаем два вектора из тритов, причём оба занимают непрерывную область памяти.
Предположим, что процессор ориентирован на 32-битную арифметику. Тогда разобьём строку матрицы на векторы по 32 трита. Каждый из этих векторов будем хранить в виде двух 32-битных целых чисел, назовём их «плюс-вектором» и «минус-вектором». Трит с номером N равен разности бита с номером N плюс-вектора, и бита с номером N минус-вектора. При этом нулевое значение трита кодируется двумя способами: когда оба бита равны 0, или когда они оба равны 1.
трит | -1 | 0 | 0 | 1 |
плюс-вектор | 0 | 0 | 1 | 1 |
минус-вектор | 1 | 0 | 1 | 0 |
Такой способ хранения позволяет быстро подсчитать сумму всех тритов вектора. Для этого мы из суммы битов плюс-вектора вычитаем сумму битов минус-вектора. Различные алгоритмы нахождения суммы битов описаны в статье «Обстоятельно о подсчёте единичных битов», к тому же процессор может поддерживать инструкцию POPCNT, которая эту сумму подсчитывает.
Например, можно использовать следующий алгоритм из вышеупомянутой статьи:
// Количество единичных битов
unsigned __int32 popcnt(unsigned __int32 value)
{
// Суммируем чётные и нечётные биты
// 00 01 10 11 value
// 0 0 1 1 (value >> 1) & 0x55555555
// -- -- -- -- --
// 00 01 01 10 =
value -= (value >> 1) & 0x55555555;
// Повторяем уже для пар битов
value = ((value >> 2) & 0x33333333) + (value & 0x33333333);
// Умножение на 0x01010101 эквивалентно сумированию значений 4 байт числа,
// при условии, что в младших байтах не будет переполнения
// (а его в нашем случае не будет, так как там содержатся суммы битов).
// Результат сложения будет в старшем байте произведения.
return ((((value >> 4) + value) & 0x0F0F0F0F) * 0x01010101) >> 24;
}
Остаётся разобраться с потритовым умножением векторов. Для наглядности, будем считать трит структурой из двух битовых полей с именами p и m, причём p содержит значение из плюс-вектора, а m — из минус-вектора. Назовём «r» результат умножения тритов a и b. Тогда:
r.p = (a.p | b.m) & (a.m | b.p)
r.m = (a.p | b.p) & (a.m | b.m)
Итак, мы сначала по этим формулам потритово умножаем два троичных вектора, затем у результата этого умножения находим суммы битов плюс-вектора и минус-вектора, затем находим разность этих сумм. Это и будет результатом скалярного произведения двух троичных векторов друг на друга.
Проиллюстрируем этот алгоритм программой на языке C++. В этой программе мы создадим две троичные матрицы размерами 32×32, заполним их случайными значениями, и перемножим друг на друга двумя способами: классическим алгоритмом, и оптимизированным для троичных вычислений.
#include
#include
// Неупакованные троичные матрицы,
// которые будем перемножать.
// Каждый элемент принимает значения (-1, 0, 1)
int A[32][32];
int B[32][32];
// Результат умножения неупакованных матриц
int C[32][32];
typedef unsigned __int32 u32;
// Вектор из 32 тритов
struct TritVector32
{
u32 p; // плюс-вектор
u32 m; // минус-вектор
TritVector32() {p = 0; m = 0;}
// Получить значение трита с указанным номером
int getTrit(int index)
{
int mask = 1 << index;
return ((p & mask) != 0) - ((m & mask) != 0);
}
void setTrit(int index, int trit)
{
int mask = 1 << index;
p |= mask;
p ^= mask & ((trit - 1) >> 1); // p ^= mask & -(trit <= 0);
m |= mask;
m ^= mask & ((~trit) >> 1); // p ^= mask & -(trit >= 0);
}
};
// Упакованные троичные матрицы,
// содержащие те же значения, что и неупакованные A и B
TritVector32 A3[32];
TritVector32 B3t[32]; // эта матрица транспонирована
// Результат умножения упакованных матриц
int C3[32][32];
// Возвращает случайное значение (-1, 0, 1)
int rand3()
{
return (int)((u32) rand() * 3 / (RAND_MAX + 1)) - 1;
}
// Количество единичных битов
int popcnt(u32 value)
{
// Для наглядности используется простой, но медленный алгоритм.
// Существуют гораздо более быстрые.
int result = 0;
while(value)
{
result += value & 1;
value >>= 1;
}
return result;
}
int main()
{
// Заполняем матрицы сомножителей случайными значениями
for(int i = 0; i < 32; i++)
for(int j = 0; j < 32; j++)
{
A[i][j] = rand3();
B[i][j] = rand3();
}
// Умножаем матрицы классическим способом
for(int i = 0; i < 32; i++)
for(int j = 0; j < 32; j++)
{
int c = 0;
for(int k = 0; k < 32; k++)
c += A[i][k] * B[k][j];
C[i][j] = c;
}
// Заполняем упакованные матрицы
for(int i = 0; i < 32; i++)
for(int j = 0; j < 32; j++)
{
A3[i].setTrit(j, A[i][j]);
B3t[j].setTrit(i, B[i][j]);
}
// Умножаем оптимизированным способом
for(int i = 0; i < 32; i++)
{
TritVector32 *a = A3 + i;
for(int j = 0; j < 32; j++)
{
TritVector32 *b = B3t + j;
TritVector32 r;
// Потритовое умножение
r.p = (a->p | b->m) & (a->m | b->p);
r.m = (a->p | b->p) & (a->m | b->m);
C3[i][j] = popcnt(r.p) - popcnt(r.m);
}
}
// Выводим результаты
FILE *fp = fopen("classic.txt", "wt");
fputs("A =\n", fp);
for(int i = 0; i < 32; i++)
{
for(int j = 0; j < 32; j++)
fprintf(fp, "%2d ", A[i][j]);
fputs("\n", fp);
}
fputs("\nB =\n", fp);
for(int i = 0; i < 32; i++)
{
for(int j = 0; j < 32; j++)
fprintf(fp, "%2d ", B[i][j]);
fputs("\n", fp);
}
fputs("\nC =\n", fp);
for(int i = 0; i < 32; i++)
{
for(int j = 0; j < 32; j++)
fprintf(fp, "%3d ", C[i][j]);
fputs("\n", fp);
}
fclose(fp);
fp = fopen("trit.txt", "wt");
fputs("A =\n", fp);
for(int i = 0; i < 32; i++)
{
for(int j = 0; j < 32; j++)
fprintf(fp, "%2d ", A3[i].getTrit(j));
fputs("\n", fp);
}
fputs("\nB =\n", fp);
for(int i = 0; i < 32; i++)
{
for(int j = 0; j < 32; j++)
fprintf(fp, "%2d ", B3t[j].getTrit(i));
fputs("\n", fp);
}
fputs("\nC =\n", fp);
for(int i = 0; i < 32; i++)
{
for(int j = 0; j < 32; j++)
fprintf(fp, "%3d ", C3[i][j]);
fputs("\n", fp);
}
fclose(fp);
return 0;
}
Сравнение файлов результатов показывает, что они одинаковы. Всё работает.
Теперь возникает вопрос:, а что, если у нас больше двух матриц? Ведь матрица произведения двух троичных матриц содержит числа, вообще говоря, не троичные. Как её умножить на троичную матрицу?
Но и здесь можно обойтись без умножения.
Пусть a — обычное целое число, которое может быть отрицательным. Пусть b — трит, имеющий два поля p и m, соответствующие плюс- и минус-векторам. Тогда произведение a на b можно записать так:
a•b = a•(b.p — b.m) = (a & -b.p) — (a & -b.m) =
= ((a — b.m) ^ -b.m) & -(b.m ^ b.p) =
= ((a — b.m) ^ -b.m) & ((-b.m) ^ -b.p) =
= ((a ^ -b.m) + b.m) & ((-b.m) ^ -b.p)
Можно выбрать любую из этих формул, или им аналогичных.