[Из песочницы] Divide&Conquer над алгоритмом Штрассена

image

Привет друзья! Будучи студентами одного небезызвестного образовательного проекта, мы с bo_0m, после вводной лекции по курсу Углубленное программирование на Java, получили свое первое домашнее задание. Необходимо было реализовать программу, которая бы перемножала матрицы. И всё бы ничего, да так совпало, что на следующей неделе должна была состояться конференция Joker, и наш преподаватель решил отменить по такому случаю занятие, подарив нам несколько часов свободного пятничного вечера. Не пропадать же времени зря! Раз никто не торопит, то можно подойти к делу творчески.

Welcome, under the hood ↓

Первое, что приходит в голову
Наверно каждому студенту технического вуза приходилось перемножать матрицы. Алгоритм был всегда один, а именно, простенький кубический метод перемножения. Да и как бы это ни звучало, но данный способ не так-то уж и плох (для размерностей матриц меньше 100).

Все мы с этого начинали:

for (int i = 0; i < A.rows(); i++) {
    for (int j = 0; j < A.columns(); j++) {
	for (int k = 0; k < B.columns(); k++) {
	    C[i][j] += A[i][k] * B[k][j];
	}
    }
}

Забегая вперед, скажу, что мы будем использовать модифицированный вариант с применением транспонирования. Про такую модификацию хорошо написано здесь, да и не только про неё.

Окей, поехали дальше!

Алгоритм Штрассена
Возможно, не все знают, но автор алгоритма Фолькер Штрассен не только жив, но и активно преподает, так же являясь почетным профессором кафедры математики и статистики Констанцского университета. Обязательно почитайте про этого человека хотя бы на вики.
Немножко теории из Википедии:
Пусть A и B — две (n*n)-матрицы, причём n — степень числа 2. Тогда можно разбить каждую матрицу A и B на четыре ((n/2)*(n/2))-матрицы и через них выразить произведение матриц A и B:

image

Определим новые элементы:

image

Таким образом, нам нужно всего 7 умножений на каждом этапе рекурсии. Элементы матрицы C выражаются из Pk по формулам:

image

Рекурсивный процесс продолжается n раз, до тех пор пока размер матриц Ci, j не станет достаточно малым, далее используют обычный метод умножения матриц. Это делают из-за того, что алгоритм Штрассена теряет эффективность по сравнению с обычным на малых матрицах в силу большего числа сложений.

let’s go to practice!

Для реализации алгоритма Штрассена нам понадобятся дополнительные функции. Как было сказано выше, алгоритм работает только с квадратными матрицами, размерность которых равна степени 2, поэтому приведем исходные матрицы к такому виду.

Для этого была реализована функция, которая определяет новую размерность:

private static int log2(int x) {
    int result = 1;
    while ((x >>= 1) != 0) result++;
    return result;
}

//******************************************************************************************
   
private static int getNewDimension(int[][] a, int[][] b) {
    return 1 << log2(Collections.max(Arrays.asList(a.length, a[0].length, b[0].length)));
    // Л - Лаконично
}

И функция, которая расширяет матрицу до нужного размера:
private static int[][] addition2SquareMatrix(int[][] a, int n) {
    int[][] result = new int[n][n];

    for (int i = 0; i < a.length; i++) {
        for (int j = 0; j < a[i].length; j++) {
            result[i][j] = a[i][j];
        }
    }
    return result;
}

Теперь исходные матрицы удовлетворяют требованиям для реализации алгоритма Штрассена. Также нам понадобится функция, которая позволит разбить матрицу размером n*n на четыре матрицы (n/2)*(n/2) и обратная для восстановления матрицы:
private static void splitMatrix(int[][] a, int[][] a11, int[][] a12, int[][] a21, int[][] a22) {
    int n = a.length >> 1;

    for (int i = 0; i < n; i++) {
        System.arraycopy(a[i], 0, a11[i], 0, n);
        System.arraycopy(a[i], n, a12[i], 0, n);
        System.arraycopy(a[i + n], 0, a21[i], 0, n);
        System.arraycopy(a[i + n], n, a22[i], 0, n);
    }
}

 //******************************************************************************************

private static int[][] collectMatrix(int[][] a11, int[][] a12, int[][] a21, int[][] a22) {
    int n = a11.length;
    int[][] a = new int[n << 1][n << 1];

    for (int i = 0; i < n; i++) {
        System.arraycopy(a11[i], 0, a[i], 0, n);
        System.arraycopy(a12[i], 0, a[i], n, n);
        System.arraycopy(a22[i], 0, a[i + n], n, n);
    }
    return a;
}

Вот мы и добрались до самого интересного, основная функция перемножения матриц алгоритмом Штрассена выглядит следующим образом:
Алгоритм Штрассена
private static int[][] multiStrassen(int[][] a, int[][] b, int n) {
    if (n <= 64) {
        return multiply(a, b);
    }

    n = n >> 1;

    int[][] a11 = new int[n][n];
    int[][] a12 = new int[n][n];
    int[][] a21 = new int[n][n];
    int[][] a22 = new int[n][n];

    int[][] b11 = new int[n][n];
    int[][] b12 = new int[n][n];
    int[][] b21 = new int[n][n];
    int[][] b22 = new int[n][n];

    splitMatrix(a, a11, a12, a21, a22);
    splitMatrix(b, b11, b12, b21, b22);

    int[][] p1 = multiStrassen(summation(a11, a22), summation(b11, b22), n);
    int[][] p2 = multiStrassen(summation(a21, a22), b11, n);
    int[][] p3 = multiStrassen(a11, subtraction(b12, b22), n);
    int[][] p4 = multiStrassen(a22, subtraction(b21, b11), n);
    int[][] p5 = multiStrassen(summation(a11, a12), b22, n);
    int[][] p6 = multiStrassen(subtraction(a21, a11), summation(b11, b12), n);
    int[][] p7 = multiStrassen(subtraction(a12, a22), summation(b21, b22), n);

    int[][] c11 = summation(summation(p1, p4), subtraction(p7, p5));
    int[][] c12 = summation(p3, p5);
    int[][] c21 = summation(p2, p4);
    int[][] c22 = summation(subtraction(p1, p2), summation(p3, p6));

    return collectMatrix(c11, c12, c21, c22);
}


На этом можно было бы и закончить. Реализованный алгоритм работает домашка выполнена, но пытливые умы жаждут взрослый perfomance. Да пребудет с нами Java 7. Пора распараллелить
Java 7 предоставляет прекрасный API для распараллеливания рекурсивных задач. С её выходом появилось одно из дополнений к пакетам java.util.concurrent — реализация парадигмы Divide and Conquer — Fork-Join. Идея заключается в следующем: рекурсивно разбиваем задачу на подзадачи, решаем, а потом объединяем результаты. Более подробно с данной технологией можно ознакомиться в документации.

Посмотрим как легко и эффективно можно применить эту парадигму к нашему алгоритму Штрассена.

Реализация алгоритма с Fork/Join
private static class myRecursiveTask extends RecursiveTask {
    private static final long serialVersionUID = -433764214304695286L;
    int n;
    int[][] a;
    int[][] b;

    public myRecursiveTask(int[][] a, int[][] b, int n) {
        this.a = a;
        this.b = b;
        this.n = n;
    }

    @Override
    protected int[][] compute() {
        if (n <= 64) {
            return multiply(a, b);
        }

        n = n >> 1;

        int[][] a11 = new int[n][n];
        int[][] a12 = new int[n][n];
        int[][] a21 = new int[n][n];
        int[][] a22 = new int[n][n];

        int[][] b11 = new int[n][n];
        int[][] b12 = new int[n][n];
        int[][] b21 = new int[n][n];
        int[][] b22 = new int[n][n];

        splitMatrix(a, a11, a12, a21, a22);
        splitMatrix(b, b11, b12, b21, b22);

        myRecursiveTask task_p1 = new myRecursiveTask(summation(a11,a22),summation(b11,b22),n);
        myRecursiveTask task_p2 = new myRecursiveTask(summation(a21,a22),b11,n);
        myRecursiveTask task_p3 = new myRecursiveTask(a11,subtraction(b12,b22),n);
        myRecursiveTask task_p4 = new myRecursiveTask(a22,subtraction(b21,b11),n);
        myRecursiveTask task_p5 = new myRecursiveTask(summation(a11,a12),b22,n);
        myRecursiveTask task_p6 = new myRecursiveTask(subtraction(a21,a11),summation(b11,b12),n);
        myRecursiveTask task_p7 = new myRecursiveTask(subtraction(a12,a22),summation(b21,b22),n);

        task_p1.fork();
        task_p2.fork();
        task_p3.fork();
        task_p4.fork();
        task_p5.fork();
        task_p6.fork();
        task_p7.fork();

        int[][] p1 = task_p1.join();
        int[][] p2 = task_p2.join();
        int[][] p3 = task_p3.join();
        int[][] p4 = task_p4.join();
        int[][] p5 = task_p5.join();
        int[][] p6 = task_p6.join();
        int[][] p7 = task_p7.join();

        int[][] c11 = summation(summation(p1, p4), subtraction(p7, p5));
        int[][] c12 = summation(p3, p5);
        int[][] c21 = summation(p2, p4);
        int[][] c22 = summation(subtraction(p1, p2), summation(p3, p6));

        return collectMatrix(c11, c12, c21, c22);
    }
}


Кульминация
Вам, наверно, уже не терпится посмотреть на сравнение производительности работы алгоритмов на реальном железе. Сразу оговорим, что тестирование будем проводить на квадратных матрицах. Итак, мы имеем:
  1. Традиционный (Кубический) метод умножения матриц
  2. Традиционный с применением транспонирования
  3. Алгоритм Штрассена
  4. Распараллеленный алгоритм Штрассена

Размерность матриц будем задавать в интервале [100…4000] и с шагом в 100.

image

Как и ожидалось, наш первый алгоритм сразу выпал из тройки лидеров. Но вот с его модернизированным братом (вариант с транспонированием) не все так просто. Даже на довольно больших размерностях данный алгоритм не только не уступает, но и зачастую превосходит однопоточный алгоритм Штрассена. Особенности чтения многомерных массивов в Java дают о себе знать! И всё же, имея в рукаве козырь в виде Fork-Join Framework’а, нам удалось получить весомый прирост производительности. Распараллеливание алгоритма Штрассена позволило сократить время перемножения почти в 3 раза, а также возглавить наш итоговый тотал.

» Исходный код размещен здесь.

Будем рады отзывам и замечаниям к нашей работе. Спасибо за внимание!

Комментарии (0)

© Habrahabr.ru