ускорьте простое решение с помощью динамического программирования

#python #algorithm #dynamic #memoization

Вопрос:

недавно я столкнулся со следующей проблемой: мы рассматриваем следующий массив:

 A = [2, 3, 6, 1, 6, 4, 12, 24]
 

нам нужно подсчитать, сколько раз эти два условия выполняются в массиве:

 A[i] * A[j] * A[k] = A[l] so that 0 <= i < j < k < l < len(A)
 

для этого примера результат должен быть равен 8.
пример удовлетворенных условий в массиве:

 2 * 3 * 1 = 6
2 * 6 * 1 = 12
6 * 1 * 4 = 24
3 * 1 * 4 = 12
 

простое решение, которое я создал с помощью python:

 A = [2, 3, 6, 1, 6, 4, 12, 24]
result = 0
for i in range(len(A)):
    for j in range(i   1, len(A)):
        for k in range(j   1, len(A)):
            for l in range(k   1, len(A)):
                if A[i] * A[j] * A[k] == A[l]:
                    result  = 1
print(result)
 

мне нужно найти способ ускорить программу с помощью динамического, возможно, запоминания или предварительных вычислений.

 result = 0
for i in range(A):
    for j in range(i   1, A):
        for k in range(j   1, A):
             #TODO
print(result)
 

я думал о создании словаря, который содержит набор словарей для каждого числа, чтобы указать число и его положение, пример:

 memo = {
    6: {
        2: True
        4: True
    }, 
    1: {
       3: True
    },
    ...
}
 

затем мы проверяем, как показано ниже:

 result = 0
for i in range(A):
    for j in range(i   1, A):
        for k in range(j   1, A):
            x = A[i] * A[j] * A[k]
            if x in memo:
                result  = len([z[0] for z in memo[x].items() if z[0] > k])
 

Таким образом, мы будем считать все вхождения одного и того же числа после индекса k сразу, нам не придется проходить через весь массив.

Пожалуйста, дайте мне знать, если моя оптимизация ошибочна, и если есть лучшая оптимизация с использованием методов динамического программирования, таких как запоминание или предварительные вычисления.

Ответ №1:

Я не являюсь экспертом в области оптимизации или динамического программирования. Но то, что вы могли бы сделать, чтобы ускорить свой код, — это использовать itertools.combinations для определения всех комбинаций индексов заранее. Таким образом, после этого вам нужно будет только выполнить цикл по всем комбинациям.

Приведенный ниже код должен быть быстрее:

 import itertools

# Determine combinations of indices
comb = itertools.combinations(range(len(A)),4)
result = 0
for idx in comb:
    if A[idx[0]]*A[idx[1]]*A[idx[2]] == A[idx[3]]:
        result  = 1
result
 

Очевидно, что это не будет быстрым решением, если A оно будет большого размера. Но для вашего массива это ~в 2 раза быстрее

Ответ №2:

Во-первых, кажется, что мы могли бы получить O(n^2 log n), хэшируя все кратные пары и все парные деления, где деление представляет собой целое число. Для кратных чисел сохраните список только с более высоким индексом каждой такой пары; а для деления сохраните список только с более низким индексом каждой такой пары. Затем повторите каждую пару делений и используйте двоичный поиск, чтобы определить, сколько пар кратных совпадают, у кого высокий индекс ниже.

 A[i] * A[j] = A[l] / A[k]