Как поделить не деля или оптимизация деления компиляторам(и)

1f830a894c518089a199ffc3de9f1685.png

Если вы никогда не пробовали смотреть как код на C++ разворачивается компилятором в код Assembly — вас ждёт много сюрпризов, причём, не нужно смотреть какой-то замудренный исходный код полный templates или других сложных конструкций: рассмотрите следущий snippet:

uint8_t div10(uint8_t x)
{
    return x/10;
}

Конечно, я это уже сделал, и приведу результаты прямо здесь, хотя, советую и самим сходить на замечательный ресурс https://godbolt.org/ — выставить там, например, x86–64 gcc 14.1 и убедиться в том, что результаты крайне интересные:

div10(unsigned char):
        push    rbp
        mov     rbp, rsp
        mov     eax, edi
        mov     BYTE PTR [rbp-4], al
        movzx   eax, BYTE PTR [rbp-4]
        mov     edx, -51
        mul     dl
        shr     ax, 8
        shr     al, 3
        pop     rbp
        ret

Действительно, чего совсем не видно в этом куске кода так это инструкции div, которая несомненно существует и собственно осуществляет деление на x86 архитектуре. Зато — откуда та взялась магическая константа, ещё и отрицательная! Ну ладно, на самом деле она положительная и равна 205 (т.к. 205+51=256 = mod 256), так что этот вопрос закрыт, но как же всё таки это всё работает?

Базовый принцип

Работает это следущим образом после умножения на 205 мы делаем суммарный сдвиг вправо на 11 разрядов, что эквивалентно делению на 2^11 = 2048 с отбрасыванием остатка. Последнее — очень важно. Заметим, что 205/2048 = 0.10009765625, иначе говоря чуть-чуть больше чем 0.1, если вы умножите последнее число на калькуляторе (или в Python) на 255 вы получите 25.524, иначе говоря, после отбрасывания дробной части — это правильный ответ.

А можно было взять число чуть-чуть меньше чем 0.1? Нет — тривиально умножив на 10 мы бы получили «чуть-чуть меньше чем 1», и просто отбрасывая дробную часть получили бы 0 — немного ни тот результат которого ожидаешь деля 10 на 10. А насколько чуть-чуть больше должно быть число, что б трюк сработал? Для uint8_t максимальное число — 255, соответсвенно, число должно отличаться от 1/10 не больше чем на 1/256. А с любым ли числом (хотя бы из базового типа uint8_t это возможно) — да, с любым.

Тут, я напомню про такие математические функции как floor, ceil и trunc: мы работаем только с положительными числами, поэтому, без лишних сложностей floor=trunc и просто отбрасывает дробную часть, a ceil всегда округляет вверх. Пусть d — наш делитель, N — наша разрядность (у нас 8), мы хотим получить выражение вида: m / 2 ^ (N + k) такое что оно чуть больше 1/d (тут обо всём думаем в вещественных числах), а если точнее то:

m / (2 ^ (N + k)) — (1 / d) >= 1 / 2^N (просто обобщение предыдущего абзаца).

Как найти подходящие числа

Утверждение: m = ceil (2^(ceil (log (d)) + N) / d), k = ceil (log (d)) — как это всё понять, что тут такое написано? … Это можно понять, например, таким образом: число внизу по условию это степень 2ки, и это степень точно не меньше чем 2^N, далее преположим, я как-то нашёл k — как теперь по заданному как подобрать m? Я уже знаю знаменатель — 2 ^ (N + k), я хочу подобрать целое число, что б оно было чуть больше чем 1/d, что если я возьму просто trunc (2 ^ (N + k) / d) ?

Давайте в числах из примера: я хочу делить uint8_t на 10: т.е. N=8, d = 10, а k пока возьмем равным 2, например, тогда trunc (2 ^ (N + k) / d) = trunk (2^10 / 10) = trunk (1024 /10) = 102, а всё выражение m / (2 ^ (N + k)) = 102/ 1024 = 0,099609375, в общем, близко, но чуть меньше чем нам надо. А вот ceil — всегда будет больше потому что: ceil (x) >= x для положительных чисел. Я ещё не сделал эту оговорку, но сделаю, что d > 1 и d не является точной степенью двойки, то есть точных делений у нас тут не будет, вторая оговорка trunc (x/y) в C++ это просто обычное целочисленное деление.

Итак, я надеюсь, к этому моменту стало понятно, что m в том виде как я ищу действительно апроксимирует 1/d сверху. Теперь посмотрим почему я выбрал такое k: ceil (2^(ceil (log (d)) + N) / d, вот здесь делая внешний ceil я прибаляю к числителю число не более чем d — потому что остаток не может быть ни больше ни равен d, и понятно что 2^(ceil (log (d)) > d.

Думаю, время показать код:

template
std::pair getDisionMultiplier(InputInteger divisor)
{
    if (!divisor)
    {
        throw std::invalid_argument("Division by zero is impossible");
    }

    if (divisor == 1)
    {
        return {1,0};
    }

    constexpr uint8_t n = sizeof(InputInteger) * CHAR_BIT;

    const double log_d_temp = std::log2(static_cast(divisor));
    const uint8_t log_d = std::ceil(log_d_temp);

    if (log_d == std::floor(log_d_temp))
    {
        return {1, log_d};
    }

    OutputInteger res = std::ceil(static_cast(static_cast(1) << (log_d + n)) / double(divisor));

    return {res, n + log_d};
}

// somewhere in the main function

    for(uint8_t divisor = 1; divisor > 0; divisor++)
    {
        auto [multiplier, shift] = getDisionMultiplier(divisor);

        for(uint8_t numenator = 1; numenator > 0; numenator++)
        {
            uint32_t res = static_cast(numenator * multiplier) >> shift;

            if (res != numenator / divisor)
            {
                std::cout << "panic: did something went wrong?" << std::endl;
            }
        }
    }

Наверное, очевидно, что фразу про панику мы никогда не увидим — значит всё? Работает и статью пора заканчивать? — Нет.

Что делает компилятор

Коспилятор, точно делает по-другому, действительно, выведем, что код, приведенный выше дает для d=10:

    auto p = getDisionMultiplier(static_cast(10));
    std::cout << p.first << " " << (uint16_t)(p.second) << std::endl;

410 12

А из куска Assemly из начала статьих понятно, что должно было быть 205 и 11… Дело в том, что gcc хочет получить константу m того же размера, что и d и использует для этого более хитрый алгоритм (если вы приглядитесь к моему коду выше — я предусмотрительно использовал тип uint16_t).

Алгоритм gcc основан на подсчёте пары констант m_low = trunc (2^(ceil (log (d)) + N) / d) и m_high = trunc (2^(ceil (log (d)) + N) + 2^(ceil (log (d)) / d), тут важно понимать что m_low — не может быть настоящим m (показывал это выше), а m_high — может (извините, я не буду это тоже пытаться тут расписать как и все оставшиеся выкладки), но m_high, вообще говоря, даже больше m из моего наивного алгоритма. Да, но, и m_low и m_high — точно меньше чем 2^(N + 1), то есть для нашего случая это было бы 9 бит, и ещё точно в целых числах m_high > m_low (без равенства). Что же делает этот алгоритм дальше, чтобы сделать из 9-ти битного числа 8-ми битное число? Правильно, он просто сдвинет m_high на разряд вправо: тут важно понимать, что таким образом он делает trunc (m_high/2), и конечно паралельно, он уменьшит k (число сдвигов направо в конечном коде), но …если число нечетное, это же не тоже самое? Да, в этом случае есть риск что мы начнём апроксимировать снизу…поэтому, компилятор так делает и trunc (m_low / 2) и сравнивает эти два числа, потому что m_low уже изначально слишком мала — если m_high деградировала до неё — то нельзя дальше уменьшать m_high и соответсвуюющий сдвиг. А вот и код делающий это:

template
std::tuple getDisionMultiplier(InputInteger divisor)
{
    if (!divisor)
    {
        throw std::invalid_argument("Division by zero is impossible");
    }

    if (divisor == 1)
    {
        return {1, 0, false};
    }

    constexpr uint8_t n = sizeof(InputInteger) * CHAR_BIT;

    const double log_d_temp = std::log2(static_cast(divisor));
    const uint8_t log_d = std::ceil(log_d_temp);

    if (log_d == std::floor(log_d_temp))
    {
        return {1, log_d, false};
    }

    uint64_t temp_low = (1UL << (log_d + n));
    uint64_t temp_hight = (1UL << log_d) | (1UL << (log_d + n));

    temp_hight /= divisor;
    temp_low /= divisor;

    uint8_t additionla_shift = log_d;

    while (additionla_shift)
    {

        if (temp_low /2 >= temp_hight/2)
        {
            break;
        }

        temp_low /= 2;
        temp_hight /= 2;

        --additionla_shift;
    }

    return {temp_hight, n + additionla_shift, temp_hight > std::numeric_limits::max()};
}

// somewhere in the main function
    auto [coeff, shift, _] = getDisionMultiplier(static_cast(10));
    std::cout << (uint16_t)coeff << " " << (uint16_t)(shift) << std::endl;

205 11

Успех! Коеффициенты совпали! Всё ли на этом? Увы: нет гарантии, что цикл, который редуцирует temp_hight, не выйдет сразу после первой итерации. То есть нет гарантии получить 8-ми битное число, но тогда у нас возможен срез по модулю в return. Gcc умеет успешно использовать это срезанное значение –, но это явно отдельная тема.

А если очень интересно или просто не терпится, то я оставлю ссылки, на которые я опирался

© Habrahabr.ru