Максимизация эффективности программы гипотезы Collatz на Python

#python #algorithm #performance #math #collatz

#python #алгоритм #Производительность #математика #collatz

Вопрос:

Мой вопрос очень прост.

Я написал эту программу для чистого развлечения. Он принимает числовой ввод и находит длину каждой последовательности Collatz вплоть до этого числа включительно.

Я хочу сделать это быстрее алгоритмически или математически (т. Е. Я знаю, что мог бы сделать это быстрее, запустив несколько версий параллельно или написав их на C , но где в этом удовольствие?).

Любая помощь приветствуется, спасибо!

РЕДАКТИРОВАТЬ: код дополнительно оптимизирован с помощью dankal444

 from matplotlib import pyplot as plt
import numpy as np
import numba as nb

# Get Range to Check
top_range = int(input('Top Range: '))

@nb.njit('int64[:](int_)')
def collatz(top_range):
    # Initialize mem
    mem = np.zeros(top_range   1, dtype = np.int64)
    for start in range(2, top_range   1):
        # If mod4 == 1: (3x   1)/4
        if start % 4 == 1:
            mem[start] = mem[(start   (start >> 1)   1) // 2]   3
        
        # If 4mod == 3: 3(3x   1)   1 and continue
        elif start % 4 == 3:
            num = start   (start >> 1)   1
            num  = (num >> 1)   1
            count = 4

            while num >= start:
                if num % 2:
                    num  = (num >> 1)   1
                    count  = 2
                else:
                    num //= 2
                    count  = 1
            mem[start] = mem[num]   count

        # If 4mod == 2 or 0: x/2
        else:
            mem[start] = mem[(start // 2)]   1

    return mem

mem = collatz(top_range)

# Plot each starting number with the length of it's sequence
plt.scatter([*range(1, len(mem)   1)], mem, color = 'black', s = 1)
plt.show()
 

Комментарии:

1. Возможно, это было бы лучше подходит для Code Review (который является еще одним сообществом Stack Exchange).

Ответ №1:

Применение numba к вашему коду действительно очень помогает.

Я удалил tqdm, поскольку он не помогает с производительностью.

 import time
from matplotlib import pyplot as plt
from tqdm import tqdm

import numpy as np
import numba as nb
@nb.njit('int64[:](int_)')
def collatz2(top_range):
    mem = np.zeros(top_range   1, dtype=np.int64)
    for start in range(2, top_range   1):
        # If mod(4) == 1: Value 2 or 3 Cached
        if start % 4 == 1:
            mem[start] = mem[(start   (start >> 1)   1) // 2]   3
        # If mod(4) == 3: Use Algorithm
        elif start % 4 == 3:
            num = start
            count = 0
            while num >= start:
                if num % 2:
                    num  = (num >> 1)   1
                    count  = 2
                else:
                    num //= 2
                    count  = 1
            mem[start] = mem[num]   count
        # If mod(4) == 2 or 4: Value 1 Cached
        else:
            mem[start] = mem[(start // 2)]   1
    return mem


def collatz(top_range):
    mem = [0] * (top_range   1)
    for start in range(2, top_range   1):
        # If mod(4) == 1: Value 2 or 3 Cached
        if start % 4 == 1:
            mem[start] = mem[(start   (start >> 1)   1) // 2]   3
        # If mod(4) == 3: Use Algorithm
        elif start % 4 == 3:
            num = start
            count = 0
            while num >= start:
                if num % 2:
                    num  = (num >> 1)   1
                    count  = 2
                else:
                    num //= 2
                    count  = 1
            mem[start] = mem[num]   count
        # If mod(4) == 2 or 4: Value 1 Cached
        else:
            mem[start] = mem[(start // 2)]   1
    return mem

# profiling here
def main():

    top_range = 1_000_000
    mem = collatz(top_range)
    mem2 = collatz2(top_range)
    assert np.allclose(np.array(mem), mem2)


 

Для top_range = 1_000 оптимизированная функция работает в ~ 100 раз быстрее. Для top_range = 1_000_000 оптимизированная функция примерно в 600 раз быстрее:

     79                                           def main():
    81         1          3.0      3.0      0.0      top_range = 1_000_000
    83         1   24633045.0 24633045.0     98.7      mem = collatz(top_range)
    85         1      39311.0  39311.0      0.2      mem2 = collatz2(top_range)

 

Комментарии:

1. Это действительно здорово. Я не знаком с Numba, что это делает? Преобразование из int32 в int64?

2. Numba — это jit-компилятор (just-in-time), который, короче говоря, компилирует заданную функцию в оптимизированный машинный код. Если ответ вас устраивает, пожалуйста, примите его.