#c #algorithm #bit-manipulation #modulo #integer-overflow
Вопрос:
Я наткнулся на эту функцию умножения по модулю в коде для теста на примитивность Миллера-Рабина. Предполагается, что это устранит переполнение целых чисел, возникающее при вычислении ( a * b ) % m
.
Мне нужна помощь в понимании того, что здесь происходит. Почему это работает? и каково значение этого буквального числа 0x8000000000000000ULL
?
unsigned long long mul_mod(unsigned long long a, unsigned long long b, unsigned long long m) {
unsigned long long d = 0, mp2 = m >> 1;
if (a >= m) a %= m;
if (b >= m) b %= m;
for (int i = 0; i < 64; i )
{
d = (d > mp2) ? (d << 1) - m : d << 1;
if (a amp; 0x8000000000000000ULL)
d = b;
if (d >= m) d -= m;
a <<= 1;
}
return d;
}
Комментарии:
1. Незначительное примечание:
LL
вход0x8000000000000000ULL
не требуется.2. Мы сдвигаем биты влево до тех пор, пока старший бит не станет единицей — тогда мы знаем, что больше не можем сдвигаться, иначе мы сместим то, что нам нужно.
3. Классический пример того, где несколько комментариев автора очень помогли бы.
4. Хотя это интересно и портативно, если ваша цепочка инструментов предлагает больший целочисленный тип , вы можете использовать его вместо этого.
Ответ №1:
Этот код, который в настоящее время отображается на странице Википедии «Модульная арифметика«, работает только для аргументов длиной до 63 бит-см. Внизу.
Обзор
Один из способов вычисления обычного умножения a * b
-добавить сдвинутые влево копии b
-по одной на каждый 1-бит a
. Это подобно тому, как большинство из нас сделали давно умножения в школе, но упрощается: так как мы только когда-либо нужно «умножить» каждый экземпляр b
на 1 или 0, все что нам нужно сделать, это добавить сдвинутой копией b
(если соответствующий бит a
равен 1) или ничего не делать (когда он равен 0).
Этот код делает нечто подобное. Однако, чтобы избежать переполнения (в основном; см. Ниже), вместо того, чтобы перемещать каждую копию b
, а затем добавлять ее в общую сумму, он добавляет нешифрованную копию b
в общую сумму и полагается на последующие сдвиги влево, выполненные для общей суммы, чтобы переместить ее в нужное место. Вы можете думать об этих сдвигах, «действующих» на все слагаемые, добавленные к общему числу до сих пор. Например, первая итерация цикла проверяет , равен ли старший бит a
, а именно бит 63, 1 (вот что a amp; 0x8000000000000000ULL
делает), и если это так, добавляет b
к общей сумме нешифрованную копию; к моменту завершения цикла предыдущая строка кода еще 63 раза сдвинет общую сумму d
влево на 1 бит.
Основное преимущество такого способа заключается в том, что мы всегда сложения двух чисел (а именно: b
А d
) что мы уже знаем меньше m
, поэтому регулировать по модулю запахом это недорого: мы знаем, что b d < 2 * m
, таким образом, чтобы гарантировать, что наше общее до сих пор остается меньше m
, достаточно проверить, является ли b d < m
, и если нет, то вычитаем m
. Если бы мы вместо этого использовали подход «сдвиг, а затем сложение», нам потребовалась %
бы операция по модулю на бит, которая стоит так же дорого, как деление, и обычно намного дороже, чем вычитание.
Одно из свойств арифметики по модулю состоит в том, что всякий раз , когда мы хотим выполнить последовательность арифметических операций по модулю некоторого числа m
, выполнение их всех в обычной арифметике и взятие остатка по модулю m
в конце всегда дает тот же результат, что и взятие остатков по модулю m
для каждого промежуточного результата (при условии, что не происходит переполнения).
Код
Перед первой строкой тела цикла у нас есть инварианты d < m
и b < m
.
Линия
d = (d > mp2) ? (d << 1) - m : d << 1;
это осторожный способ сдвинуть общее d
значение влево на 1 бит, сохраняя его в пределах диапазона 0 .. m
и избегая переполнения. Вместо того, чтобы сначала сдвинуть его, а затем проверить, является ли результат m
или больше , мы проверяем, находится ли он в настоящее время строго выше RoundDown(m/2)
-потому что если это так, то после удвоения он, безусловно, будет строго выше 2 * RoundDown(m/2) >= m - 1
, и поэтому потребуется вычитание m
, чтобы вернуться в диапазон. Обратите внимание , что даже если (d << 1)
in (d << 1) - m
может переполниться и потерять верхний бит d
, это не повредит, так как не влияет на самые низкие 64 бита результата вычитания, которые являются единственными, которые нас интересуют. (Также обратите внимание, что если d == m/2
точно, мы заканчиваем с d == m
«потом», что немного выходит за рамки допустимого , но изменение теста с d > mp2
d >= mp2
«на», чтобы исправить это, нарушило бы случай, когда m
это странно и d == RoundDown(m/2)
, поэтому мы должны жить с этим. Это не имеет значения, потому что это будет исправлено ниже.)
Почему бы просто не написать d <<= 1; if (d >= m) d -= m;
вместо этого? Предположим, что в арифметике с бесконечной точностью d << 1 >= m
, поэтому мы должны выполнить вычитание-но старший бит d
включен , а остальная часть d << 1
меньше m
: в этом случае начальный сдвиг потеряет старший бит и if
не сможет выполняться.
Ограничение на входы размером 63 бита или менее
Приведенный выше крайний случай может возникнуть только при d
включенном старшем бите, что может произойти только при m
включенном старшем бите (поскольку мы сохраняем инвариант d < m
). Таким образом, похоже, что код прилагает все усилия, чтобы работать правильно даже при очень высоких значениях m
. К сожалению, оказывается, что он все еще может переполняться в другом месте, что приводит к неправильным ответам на некоторые входные данные, которые устанавливают верхний бит. Например, когда a = 3
b = 0x7FFFFFFFFFFFFFFFULL
и m = 0xFFFFFFFFFFFFFFFFULL
, правильный ответ должен быть 0x7FFFFFFFFFFFFFFEULL
, но код вернется 0x7FFFFFFFFFFFFFFDULL
(простой способ увидеть правильный ответ-повторить запуск со значениями a
и b
поменять местами). В частности, такое поведение возникает всякий раз , когда строка d = b
переполняется и оставляет усеченное d
значение меньше m
, что приводит к ошибочному пропуску вычитания.
При условии, что такое поведение задокументировано (как на странице Википедии), это просто ограничение, а не ошибка.
Снятие ограничения
Если мы заменим строки
if (a amp; 0x8000000000000000ULL)
d = b;
if (d >= m) d -= m;
с
unsigned long long x = -(a >> 63) amp; b;
if (d >= m - x) d -= m;
d = x;
код будет работать для всех входных данных, включая те, у которых установлены верхние биты. Загадочная первая строка-это просто условно-свободный (и, следовательно, обычно более быстрый) способ написания
unsigned long long x = (a amp; 0x8000000000000000ULL) ? b : 0;
Тест d >= m - x
работает d
до того, как он был изменен-это похоже на старый d >= m
тест, но b
(когда верхний бит a
включен) или 0 (в противном случае) вычитается с обеих сторон. Это проверяет, будет ли он d
больше m
или больше после x
добавления к нему. Мы знаем, что RHS m - x
никогда не переполняется, потому что самый большой x
может быть b
, и мы установили это b < m
в верхней части функции.
Комментарии:
1. Википедия утверждает, что алгоритм работает «с целыми числами без знака, не превышающими 63 бита», и в этом случае старший бит не будет установлен, и ошибка не возникнет.
2. @interjay: Спасибо, что заметили это! Я обновил свой ответ.