一、问题背景
在使用机器学习模型预测的各类场景中,对象间的序关系是否预测准确,是考量模型效果的指标之一。
正序数、逆序数可作为模型效果的衡量指标。当样本的之间存在序关系时,由样本间两两组成的,若模型预测结果的序关系与之间的序关系相同,称为正序;若模型预测结果的序关系与之间的序关系相反,称为逆序。当正序数量越多、逆序数量越少时,表明模型对序关系的刻画越准确,模型效果越好。正逆序即为正序数量与逆序数量的比值。
计算正序数量、逆序数量时,一种直观的方法是暴力构造所有的对并一一验证,在个样本下时间复杂度为,当的量级在5万时普通的服务器上已需要耗费近10分钟(Python),在更大量级时计算时间无法忍受。
进一步,如果在模型训练过程中,希望在训练的每一轮迭代时都查看正逆序的值(据此做终止条件);或是在多组参数间使用正逆序做验证调参的标准时,正逆序的计算速度问题则更为突出。我们需要复杂度更低的算法来快速计算出正逆序。
二、数学表述
给定个样本,现有人工打好的每个样本的,记为以及模型预测出的每个样本的分值,记为
以 表示集合的大小, 表示逻辑与、表示逻辑或
以下 均在 中取值、且均满足 ,不再特别写出
构造出的所有pair集合为(注意我们只对不同的样本构造)
正序集合为
逆序集合为
严格正序集合为
严格逆序集合为
由以上的定义容易看出,
三、算法思路
注:以下排序均指升序排列。
MergeSort
熟悉排序算法的同学应已看出:这个问题像极了经典的找逆序对问题。
给定数组,找出满足 的 对数量。使用可以在 时间计算出结果。
MergeSort计算逆序的思路较为直接,使用的是Divide-and-Conquer的思想:
- 将数组二分
- 左半边数组排好序并计算出其内部的逆序数 (可通过递归调用实现)
- 右半边数组排好序并计算出其内部的逆序数 (可通过递归调用实现)
- 左半边数组与右半边数组合并,得到整体排好序的数组、以及左半边与右半边形成的逆序对数量
- 返回整体排好序的数组、总逆序数(左半边内部 + 右半边内部 + 左右半边联合形成)
严格逆序 StrictWrong
我们从计算严格逆序集合入手。
根据上一点关于的讨论,一个自然的想法出现了:
先按排序得到数组,接着计算数组中的关于的逆序对数量。
但这与我们的原始需求仍有细微的差别:在按排序后,
我们对于下标只能得到
然而
因此直接计算逆序对的话,会把相等的情形也算进来,得不到正确答案。
为了消除这一影响,我们可以使用一个小trick:
在按排序时,排序的key不仅仅使用,而是按照二元组 排序。
即:先按排序,当值相等时,按排序
于是
因此在这种情形下我们不会统计到任何label相等时的逆序对。可在时间内计算得到
严格正序 StrictRight
思路与严格逆序完全一致,只是不等号方向变反。为了程序复用,可将正负号变反后、直接调用严格逆序计算的程序得到结果
正序Right, 逆序Wrong, Pair
因为
结合前面的结果, 只需计算出 , 即可获得 和
的计算较为简单,将排好序后,依次遍历处理即可,总复杂度为
结论及实验
根据以上讨论,各个集合的大小均可在 时间计算得出。
随机构造的样本在普通服务器上的实际运行时间统计如下(Python),可以看出优化后的算法执行时间大幅提升。
样本数量N | 基于 MergeSort 计算时间(秒) | 基于 暴力法 计算时间(秒) |
---|---|---|
500 | 0.017 | 0.051 |
5000 | 0.164 | 5.048 |
50000 | 2.017 | 512.371 |
四、总结与展望
总结
- 将原始问题转为经典的求解数组逆序对问题,并使用经典的MergeSort进行求解。在问题转化建模的过程中使用多字段排序等trick,使得逆序对问题与原始问题完全等价。
- 最终将计算正逆序的时间复杂度由 优化至
展望
- 在模型训练的迭代过程中,保持不变,且每一轮参数变化带来的预测变化较小。是否有可能依据的变化计算逆序对的变化、加速计算(不必每次都从头开始),使得多轮迭代时的逆序对计算整体时间缩短?
- 是否存在分布式的解决方案,能够应对海量样本数量的正逆序计算?
注
- 本文中使用的各类名词及其对应英文表达,均为随手自创,仅为上下文叙述方便(除了mergeSort等大众熟知的术语)
附、代码片段
(代码中的变量true即为上文中的,取"groundtruth"之意;pred即为上文中的)
from itertools import groupby
class InversionCounter(object):
@classmethod
def merge_sort_count_sub(cls, vals):
if len(vals) <= 1:
return vals, 0
n = len(vals)
left_vals, left_cnt = cls.merge_sort_count_sub(vals[:n/2])
right_vals, right_cnt = cls.merge_sort_count_sub(vals[n/2:])
left_i = 0
right_i = 0
mid_cnt = 0
new_vals = []
while True:
if left_vals[left_i][1] <= right_vals[right_i][1]:
new_vals.append(left_vals[left_i])
left_i += 1
elif left_vals[left_i][1] > right_vals[right_i][1]:
mid_cnt += (len(left_vals) - left_i)
new_vals.append(right_vals[right_i])
right_i += 1
if left_i == len(left_vals):
new_vals.extend(right_vals[right_i:])
break
if right_i == len(right_vals):
new_vals.extend(left_vals[left_i:])
break
return new_vals, left_cnt + mid_cnt + right_cnt
@classmethod
def merge_sort_count_strict_right(cls, trues, preds):
neg_preds = (-p for p in preds)
vals = zip(trues, neg_preds)
vals.sort()
return cls.merge_sort_count_sub(vals)[1]
@classmethod
def merge_sort_count_strict_wrong(cls, trues, preds):
vals = zip(trues, preds)
vals.sort()
return cls.merge_sort_count_sub(vals)[1]
@classmethod
def merge_sort_count_right(cls, trues, preds):
return cls.merge_sort_count_pair(trues) - cls.merge_sort_count_strict_wrong(trues, preds)
@classmethod
def merge_sort_count_wrong(cls, trues, preds):
return cls.merge_sort_count_pair(trues) - cls.merge_sort_count_strict_right(trues, preds)
@classmethod
def merge_sort_count_pair(cls, trues, preds=None):
'''
preds: dummpy variable, no need inside function
'''
trues = sorted(trues)
acc_num = 0
pair = 0
for k, ks in groupby(trues):
current_num = sum(1 for _ in ks)
acc_num += current_num
pair += (len(trues) - acc_num) * current_num
return pair