I have the following problem, and I'm having a hard time writing an optimized program for it.
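I am given a 2D binary array (every element is 0 or 1, shape (n, m)), and I need to remove as many 1s as possible from it (i.e. set them to 0) while keeping every row sum at least min_x and every column sum at least min_y.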
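One way to write the task down precisely, using notation introduced only for this statement (A is the input array, X the 1s that are kept, and r_i, c_j the row and column sums of A). The second form, in terms of the removed entries R = A − X, is equivalent because each kept row sum equals r_i minus what was removed from that row:

```latex
\begin{aligned}
\min_{X \in \{0,1\}^{n \times m}} \quad & \sum_{i,j} X_{ij}
  \quad \text{s.t.} \quad X_{ij} \le A_{ij},\;
  \sum_{j} X_{ij} \ge \mathrm{min\_x} \;\; \forall i,\;
  \sum_{i} X_{ij} \ge \mathrm{min\_y} \;\; \forall j \\[4pt]
\Longleftrightarrow \quad
\max_{R \in \{0,1\}^{n \times m}} \quad & \sum_{i,j} R_{ij}
  \quad \text{s.t.} \quad R_{ij} \le A_{ij},\;
  \sum_{j} R_{ij} \le r_i - \mathrm{min\_x} \;\; \forall i,\;
  \sum_{i} R_{ij} \le c_j - \mathrm{min\_y} \;\; \forall j
\end{aligned}
```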
I've written the following unoptimized code, which works fine for small arrays:
```python
import numpy as np
from typing import Optional, Tuple
from itertools import product
from tqdm import tqdm


def min_dataset(data, min_x, min_y):
    '''Data is a row-major numpy array of shape (n, m), consisting of only 0s and 1s (a mask).
    min_x and min_y are the minimum number of 1s in each row and column, respectively.
    This function returns a new data array, where each row and column has at least min_x and min_y 1s, respectively, and there are as few 1s as possible.
    '''
    best_solution = data
    score = np.sum(best_solution)
    # format: {matrix.tobytes(): (best_matrix, score)}
    memo = {}
    if np.any(np.sum(data, axis=0) < min_y) or np.any(np.sum(data, axis=1) < min_x):
        return None

    def get_best(data, main=False) -> Tuple[np.ndarray, int]:
        '''Gets the matrix with the lowest score, given the constraints.
        Uses memoization.
        '''
        nonlocal memo
        if data.tobytes() in memo:
            return memo[data.tobytes()]
        # Find the rows with more than min_x 1s
        can_remove_rows = np.sum(data, axis=1) > min_x
        # Find the columns with more than min_y 1s
        can_remove_columns = np.sum(data, axis=0) > min_y
        if (not np.any(can_remove_rows)) or (not np.any(can_remove_columns)):
            # If there's no row or column where we can remove a 1, we're done
            ans = (data, np.sum(data))
            memo[data.tobytes()] = (data, np.sum(data))
            return ans
        # Try removing each combination of rows and columns where we can remove a 1
        best = data, np.sum(data)
        # print(np.nonzero(can_remove_rows))
        # print(np.nonzero(can_remove_columns))
        iterator = product(np.nonzero(can_remove_rows)[0], np.nonzero(can_remove_columns)[0])
        if main:
            iterator = tqdm(list(iterator))
        for row, col in iterator:
            # print(row, col)
            if data[row, col] == 0:
                continue
            # Remove the 1 at (row, col)
            new_data = data.copy()
            new_data[row, col] = 0
            # Check if the new matrix is valid
            if np.any(np.sum(new_data, axis=0) < min_y) or np.any(np.sum(new_data, axis=1) < min_x):
                continue
            # Recurse
            new_best = get_best(new_data)
            if new_best[1] < best[1]:
                best = new_best
        memo[data.tobytes()] = best
        return best

    get_best(data, main=True)
    best = memo[data.tobytes()][0]
    return best


if __name__ == "__main__":
    # Create a random 5x5 binary matrix
    data = np.random.randint(0, 2, (5, 5))
    print(data)
    print("computing")
    ans = min_dataset(data, 1, 1)
    print("Solution:")
    print(ans)
    print("Score (lower is better; None means no solution was found.):")
    print(np.sum(ans))
```
Example output:
```
[[0 1 0 0 0]
 [1 1 1 1 1]
 [1 1 0 1 1]
 [0 1 1 0 1]
 [0 1 1 0 1]]
computing
100%|█████████████████████████████████████████████████████████████████████████████████| 20/20 [00:02<00:00, 8.27it/s]
Solution:
[[0 1 0 0 0]
 [0 0 0 1 0]
 [1 0 0 0 0]
 [0 0 0 0 1]
 [0 0 1 0 0]]
Score (lower is better; None means no solution was found.):
5
```
However, as soon as I move to 10x10 matrices the code becomes extremely slow, and I need to be able to handle 20000x500 matrices.
How can I speed up my code?
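One possible direction, offered as a sketch under its own assumptions rather than as a benchmarked answer: the "removal" form of the problem (remove as many 1s as possible, where row i may lose at most r_i − min_x of its 1s and column j at most c_j − min_y) is a bipartite b-matching, so it can be solved exactly as a maximum flow (source → row nodes → column nodes → sink) instead of by exhaustive search with memoization. The name `min_dataset_flow` is hypothetical, and the sketch uses plain Edmonds–Karp in pure Python only so that it is self-contained:

```python
from collections import deque
from typing import List, Optional, Sequence


def min_dataset_flow(data: Sequence[Sequence[int]], min_x: int, min_y: int) -> Optional[List[List[int]]]:
    """Hypothetical max-flow variant of min_dataset: keep as few 1s as possible
    so that every row keeps >= min_x 1s and every column keeps >= min_y 1s.
    Returns None when the input already violates the constraints."""
    grid = [[int(v) for v in row] for row in data]
    n, m = len(grid), (len(grid[0]) if grid else 0)
    row_sums = [sum(row) for row in grid]
    col_sums = [sum(grid[i][j] for i in range(n)) for j in range(m)]
    if any(r < min_x for r in row_sums) or any(c < min_y for c in col_sums):
        return None

    # Nodes: source = 0, rows = 1..n, columns = n+1..n+m, sink = n+m+1.
    source, sink = 0, n + m + 1
    graph: List[List[int]] = [[] for _ in range(n + m + 2)]  # node -> edge ids
    to: List[int] = []   # edge id -> head node
    cap: List[int] = []  # edge id -> residual capacity; edge e ^ 1 is its reverse

    def add_edge(u: int, v: int, c: int) -> int:
        edge_id = len(to)
        graph[u].append(edge_id)
        to.append(v)
        cap.append(c)
        graph[v].append(edge_id + 1)
        to.append(u)
        cap.append(0)
        return edge_id  # id of the forward edge

    for i in range(n):
        add_edge(source, 1 + i, row_sums[i] - min_x)  # row i may lose this many 1s
    cell_edge = {}
    for i in range(n):
        for j in range(m):
            if grid[i][j]:
                cell_edge[(i, j)] = add_edge(1 + i, 1 + n + j, 1)  # each 1 removable once
    for j in range(m):
        add_edge(1 + n + j, sink, col_sums[j] - min_y)  # column j may lose this many 1s

    # Edmonds-Karp: augment along shortest residual paths until none is left.
    while True:
        parent = [-1] * len(graph)  # edge used to reach each node, -1 = unvisited
        parent[source] = -2
        queue = deque([source])
        while queue and parent[sink] == -1:
            u = queue.popleft()
            for e in graph[u]:
                if cap[e] > 0 and parent[to[e]] == -1:
                    parent[to[e]] = e
                    queue.append(to[e])
        if parent[sink] == -1:
            break  # no augmenting path: the flow (= number of removed 1s) is maximum
        push, v = float("inf"), sink
        while v != source:  # bottleneck capacity along the path
            push = min(push, cap[parent[v]])
            v = to[parent[v] ^ 1]
        v = sink
        while v != source:  # push flow forward, return it on the reverse edges
            cap[parent[v]] -= push
            cap[parent[v] ^ 1] += push
            v = to[parent[v] ^ 1]

    # A row->column edge with residual capacity 0 carries flow, i.e. that 1 was removed.
    result = [[0] * m for _ in range(n)]
    for (i, j), e in cell_edge.items():
        result[i][j] = cap[e]
    return result


if __name__ == "__main__":
    example = [[0, 1, 0, 0, 0], [1, 1, 1, 1, 1], [1, 1, 0, 1, 1], [0, 1, 1, 0, 1], [0, 1, 1, 0, 1]]
    solution = min_dataset_flow(example, 1, 1)
    print(solution)
    print(sum(map(sum, solution)))  # → 5
```

On the 5x5 example above this keeps 5 ones, matching the brute-force score, although the particular optimal matrix may differ when several optima exist. Pure-Python Edmonds–Karp will still be too slow for 20000x500; the assumption is that the same graph could be handed to a compiled max-flow routine such as `scipy.sparse.csgraph.maximum_flow` (which supports Dinic's algorithm), but that has not been measured here.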