- numpy 没有直接取top-k的函数。需要手动实现。
argpartition
这个操作指:根据一个数值x,把数组中的元素划分成两半,使得index前面的元素都不大于x,index后面的元素都不小于x。(快排)
numpy中的argpartition()函数就是起的这个作用。对于传入的数组a,先用O(n)复杂度求出第k大的数字,然后利用这个第k大的数字将数组a划分成两半。
此函数不对原数组进行操作,它只返回分区之后的下标
>>> x = np.array([3, 5, 6, 4, 2, 7, 1])
>>> x[np.argpartition(x, 3)]
array([2, 1, 3, 4, 5, 7, 6]) #3
def partition_arg_topK(matrix, K, axis=0):
"""
perform topK based on np.argpartition
:param matrix: to be sorted
:param K: select and sort the top K items
:param axis: 0 or 1. dimension to be sorted.
:return:
"""
a_part = np.argpartition(matrix, K, axis=axis)
if axis == 0:
row_index = np.arange(matrix.shape[1 - axis])
a_sec_argsort_K = np.argsort(matrix[a_part[0:K, :], row_index], axis=axis)
return a_part[0:K, :][a_sec_argsort_K, row_index]
else:
column_index = np.arange(matrix.shape[1 - axis])[:, None]
a_sec_argsort_K = np.argsort(matrix[column_index, a_part[:, 0:K]], axis=axis)
return a_part[:, 0:K][column_index, a_sec_argsort_K]