Reference: Foundations of Statistical Natural Language Processing, Chapter 5.
Two methods are implemented: mutual information and the chi-square test. Intuitively, the chi-square test performs much better.
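For reference, the two scores the script below computes, with N the number of documents, c(w) the number of documents containing w, and c(w_1, w_2) the number of documents containing both:

\mathrm{PMI}(w_1, w_2) = \log_2 \frac{N \cdot c(w_1, w_2)}{c(w_1)\, c(w_2)}

\chi^2 = \frac{N\,(O_{11} O_{22} - O_{12} O_{21})^2}{(O_{11}+O_{12})\,(O_{11}+O_{21})\,(O_{12}+O_{22})\,(O_{21}+O_{22})}

where O_{11} = c(w_1, w_2), O_{12} = c(w_1) - O_{11}, O_{21} = c(w_2) - O_{11}, and O_{22} = N - c(w_1) - c(w_2) + O_{11} are the cells of the 2x2 contingency table (e11 through e22 in the code).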
#!/usr/bin/env python
# encoding: utf-8
#
# Extract collocations (document-level co-occurrence) and score them
# with pointwise mutual information or a chi-square test.
#
import sys
import os
import math
from collections import defaultdict
g_word_map = {}          # word -> integer index
g_word_list = []         # integer index -> word
g_core_set = set()       # short words kept despite the length filter
g_exclusion_set = set()  # pairs with both words in this set are skipped
def load_words(infile):
    '''Load the first tab-separated field of each line into a set.'''
    word_set = set()
    if os.path.isfile(infile):
        with open(infile) as f:
            for line in f:
                word = line.strip().split('\t')[0]
                word_set.add(word)
    return word_set
def get_word_index(word):
    '''Return the integer index for word, assigning a new one if unseen.'''
    if word not in g_word_map:
        g_word_list.append(word)
        g_word_map[word] = len(g_word_list) - 1
    return g_word_map[word]
def calc_mi(word_freq, bigram_freq, doc_num, least_score, least_heat):
    '''
    Calculate the pointwise mutual information of two words over documents:
    PMI(w1, w2) = log2(N * c(w1, w2) / (c(w1) * c(w2))), N = number of docs.
    '''
    for k, freq in bigram_freq.items():
        if freq < least_heat:  # skip rare pairs
            continue
        (i, j) = k
        n1 = word_freq[i]
        n2 = word_freq[j]
        mi = math.log(float(freq * doc_num) / n1 / n2, 2)
        if mi < least_score:
            continue
        w1 = g_word_list[i]
        w2 = g_word_list[j]
        print("%s %s\t%.5f\t%d\t%d\t%d" % (w1, w2, mi, freq, n1, n2))
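# Note: PMI tends to overrate rare pairs (a pair seen only a handful of
# times can get a very high score), which is one reason the chi-square
# test below usually ranks collocations better; the least_heat frequency
# cutoff mitigates this somewhat.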
def calc_chi_square(word_freq, bigram_freq, doc_num, least_score, least_heat):
    '''
    Calculate the chi-square score of two words from a 2x2 contingency
    table of document counts (e11..e22 are the observed cells).
    '''
    for k, freq in bigram_freq.items():
        if freq < least_heat:  # skip rare pairs
            continue
        (i, j) = k
        w1 = g_word_list[i]
        w2 = g_word_list[j]
        n1 = word_freq[i]
        n2 = word_freq[j]
        e11 = freq                      # docs containing both w1 and w2
        e12 = n1 - freq                 # docs with w1 but not w2
        e21 = n2 - freq                 # docs with w2 but not w1
        e22 = doc_num - n1 - n2 + freq  # docs with neither
        # the chi-square approximation is only reliable when every cell is large enough
        valid = 0
        if e11 > 5 and e12 > 5 and e21 > 5 and e22 > 5:
            valid = 1
        # shortcut formula: the denominator is the product of the four marginals;
        # it is 0 when a word occurs in every document (e.g. counts 28 24 0 0)
        denominator = n1 * n2 * (e22 + e21) * (e22 + e12)
        chi = 1000  # sentinel for the degenerate case
        if denominator > 0:
            chi = doc_num * (e11 * e22 - e12 * e21) ** 2 / float(denominator)
        if chi < least_score:
            continue
        print("%s %s\t%.5f\t%d\t%d\t%d\t%d" % (w1, w2, chi, freq, n1, n2, valid))
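# Worked example for calc_chi_square (made-up counts): in N = 1000 documents,
# suppose c(w1) = 30, c(w2) = 20, and the pair co-occurs in 10. Then
#   e11 = 10, e12 = 20, e21 = 10, e22 = 960
#   chi = 1000 * (10*960 - 20*10)**2 / float(30 * 20 * 970 * 980) ~= 154.9
# which is far above the 3.841 cutoff, so the pair would be printed.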
def process_file(infile, seg_field, stop_file, method_type):
    '''
    Extract collocations: count document frequencies of single words and
    of word pairs, then score the pairs with the chosen method.
    '''
    stop_set = load_words(stop_file)
    # part-of-speech tags to drop (conjunctions, pronouns, adverbs, ...)
    bad_pos = set(["c", "t", "r", "ad", "d", "f", "w", "o", "y", "p", "u", "q", "m"])
    doc_num = 0
    word_freq = defaultdict(int)
    bigram_freq = defaultdict(int)
    with open(infile) as fin:
        for line in fin:
            doc_num += 1
            s = line.rstrip('\n').split('\t')
            if len(s) <= seg_field:
                continue
            word_str = s[seg_field].lower()
            words = word_str.split("||")  # format: hello@nx||world@nx
            word_set = set()
            for x in words:
                i = x.rfind('@')
                if i == -1:
                    continue
                word = x[:i]
                pos = x[i+1:]
                # filter short non-adjective words unless they are in the core set
                if len(word) < 4 and pos != "a":
                    if word not in g_core_set:
                        continue
                if pos in bad_pos:
                    continue
                if word.find("{") != -1:  # segmentation artifact
                    continue
                if pos == "nx":
                    if not word.isalpha():  # or word.isalnum()
                        continue
                if word in stop_set:
                    continue
                index = get_word_index(word)
                if index in word_set:  # count each word once per document
                    continue
                word_freq[index] += 1
                word_set.add(index)
            # count co-occurring pairs within the document
            n = len(word_set)
            if n < 2:
                continue
            # sort so each pair gets one canonical (i, j) key;
            # note this! a Python set is a hash table with no stable order
            word_list = sorted(word_set)
            for i in range(n - 1):
                for j in range(i + 1, n):
                    key = (word_list[i], word_list[j])
                    # filter bad collocations
                    w1 = g_word_list[key[0]]
                    w2 = g_word_list[key[1]]
                    if w1 in g_exclusion_set and w2 in g_exclusion_set:
                        continue
                    bigram_freq[key] += 1
    if method_type == 'chi_square':
        least_score = 3.841  # chi-square critical value at p = 0.05, 1 d.o.f.
        least_heat = 5
        calc_chi_square(word_freq, bigram_freq, doc_num, least_score, least_heat)
    else:
        least_score = 1
        least_heat = 5
        calc_mi(word_freq, bigram_freq, doc_num, least_score, least_heat)
if __name__ == "__main__":
    if len(sys.argv) != 4:
        print("usage: %s <infile> <seg_field> <stop_file>" % sys.argv[0])
        sys.exit(-1)
    g_core_set = load_words("hero.1")  # can be replaced by another file
    g_exclusion_set = load_words("hero.1")
    method_type = 'chi_square'  # otherwise mutual information is used
    process_file(sys.argv[1], int(sys.argv[2]), sys.argv[3], method_type)
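A minimal usage sketch, assuming a tab-separated corpus with one document per line where column 1 holds the segmented text (extract_collocation.py, corpus.txt, and stop.txt are placeholder names; hero.1 must exist in the working directory):

    # input format, field 1: word@pos pairs joined by "||"
    # doc_0<TAB>machine@n||learning@n||hello@nx
    python extract_collocation.py corpus.txt 1 stop.txt > pairs.txt

Each output line holds the pair, its score, the pair's document frequency, each word's document frequency, and (for chi-square) a flag marking whether every contingency cell exceeded 5.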