#python #algorithm #performance #math #collatz
#python #алгоритм #Производительность #математика #collatz
Вопрос:
Мой вопрос очень прост.
Я написал эту программу для чистого развлечения. Он принимает числовой ввод и находит длину каждой последовательности Collatz вплоть до этого числа включительно.
Я хочу сделать это быстрее алгоритмически или математически (т. Е. Я знаю, что мог бы сделать это быстрее, запустив несколько версий параллельно или написав их на C , но где в этом удовольствие?).
Любая помощь приветствуется, спасибо!
РЕДАКТИРОВАТЬ: код дополнительно оптимизирован с помощью dankal444
from matplotlib import pyplot as plt
import numpy as np
import numba as nb
# Get Range to Check
top_range = int(input('Top Range: '))
@nb.njit('int64[:](int_)')
def collatz(top_range):
# Initialize mem
mem = np.zeros(top_range 1, dtype = np.int64)
for start in range(2, top_range 1):
# If mod4 == 1: (3x 1)/4
if start % 4 == 1:
mem[start] = mem[(start (start >> 1) 1) // 2] 3
# If 4mod == 3: 3(3x 1) 1 and continue
elif start % 4 == 3:
num = start (start >> 1) 1
num = (num >> 1) 1
count = 4
while num >= start:
if num % 2:
num = (num >> 1) 1
count = 2
else:
num //= 2
count = 1
mem[start] = mem[num] count
# If 4mod == 2 or 0: x/2
else:
mem[start] = mem[(start // 2)] 1
return mem
mem = collatz(top_range)
# Plot each starting number with the length of it's sequence
plt.scatter([*range(1, len(mem) 1)], mem, color = 'black', s = 1)
plt.show()
Комментарии:
1. Возможно, это было бы лучше подходит для Code Review (который является еще одним сообществом Stack Exchange).
Ответ №1:
Применение numba к вашему коду действительно очень помогает.
Я удалил tqdm, поскольку он не помогает с производительностью.
import time
from matplotlib import pyplot as plt
from tqdm import tqdm
import numpy as np
import numba as nb
@nb.njit('int64[:](int_)')
def collatz2(top_range):
mem = np.zeros(top_range 1, dtype=np.int64)
for start in range(2, top_range 1):
# If mod(4) == 1: Value 2 or 3 Cached
if start % 4 == 1:
mem[start] = mem[(start (start >> 1) 1) // 2] 3
# If mod(4) == 3: Use Algorithm
elif start % 4 == 3:
num = start
count = 0
while num >= start:
if num % 2:
num = (num >> 1) 1
count = 2
else:
num //= 2
count = 1
mem[start] = mem[num] count
# If mod(4) == 2 or 4: Value 1 Cached
else:
mem[start] = mem[(start // 2)] 1
return mem
def collatz(top_range):
mem = [0] * (top_range 1)
for start in range(2, top_range 1):
# If mod(4) == 1: Value 2 or 3 Cached
if start % 4 == 1:
mem[start] = mem[(start (start >> 1) 1) // 2] 3
# If mod(4) == 3: Use Algorithm
elif start % 4 == 3:
num = start
count = 0
while num >= start:
if num % 2:
num = (num >> 1) 1
count = 2
else:
num //= 2
count = 1
mem[start] = mem[num] count
# If mod(4) == 2 or 4: Value 1 Cached
else:
mem[start] = mem[(start // 2)] 1
return mem
# profiling here
def main():
top_range = 1_000_000
mem = collatz(top_range)
mem2 = collatz2(top_range)
assert np.allclose(np.array(mem), mem2)
Для top_range = 1_000 оптимизированная функция работает в ~ 100 раз быстрее. Для top_range = 1_000_000 оптимизированная функция примерно в 600 раз быстрее:
79 def main():
81 1 3.0 3.0 0.0 top_range = 1_000_000
83 1 24633045.0 24633045.0 98.7 mem = collatz(top_range)
85 1 39311.0 39311.0 0.2 mem2 = collatz2(top_range)
Комментарии:
1. Это действительно здорово. Я не знаком с Numba, что это делает? Преобразование из int32 в int64?
2. Numba — это jit-компилятор (just-in-time), который, короче говоря, компилирует заданную функцию в оптимизированный машинный код. Если ответ вас устраивает, пожалуйста, примите его.