B-tree构建轮排索引 python实现
检索系统设计
检索系统整体上利用btree构建轮排索引:读取项目二中已经构建好的倒排索引文件,将每个词项转化成多个轮排词项(即在词项末尾加上\$后得到的所有循环移位),如对于tree,将tree\$, ree\$t, ee\$tr, e\$tre, \$tree插入btree,在$O(n\log n)$时间复杂度内构建轮排索引。
对于每个通配符查询,先将其转换成前缀查询,然后在btree上进行区间查询,找到所有符合要求的轮排词项(如ee\$tr),将轮排词项还原成原词项后进行排序,输出所有符合要求的原词项以及对应的文档列表。
构建、测试Btree
由于轮排索引是在btree上构建的,所以要先实现btree类并测试其正确性。需要实现的功能包括btree的插入和区间查询,并统计构建时间和分裂次数,以及一个遍历整棵树、统计树高、单个节点最大关键字数和关键字总数的函数。
此处设置btree阶数_M为11,即每个节点最多保存10个关键字(关键字个数达到11时节点分裂)。
import random
import time
# Module-level statistics updated by BNode.add and BNode.DFSTree.
tot_split_time = 0  # number of node splits; incremented by BNode.add after each split
count_key = 0  # running total of keys counted by BNode.DFSTree
mxheight = 0  # largest height value reached during BNode.DFSTree
mxnode = 0  # largest number of keys seen in a single node by BNode.DFSTree
class BNodeType(object):
    """A key/payload pair stored inside a B-tree node.

    The tree orders entries by ``key`` only; ``data`` is carried along
    untouched. Keys only need to be mutually comparable: this file uses
    ints in the B-tree self-test and rotated term strings in the permuterm
    index, so the accessors are deliberately not annotated with one type.
    """

    def __init__(self, key, data) -> None:
        self.key = key
        self.data = data

    def __repr__(self) -> str:
        return '%s(%r, %r)' % (type(self).__name__, self.key, self.data)

    def getKey(self):
        """Return the key this entry is ordered by."""
        return self.key

    def getData(self):
        """Return the payload stored alongside the key."""
        return self.data
class BNode(object):
    """One node of an in-memory B-tree.

    Entries are objects exposing ``getKey()``; they are kept in ascending
    key order in ``keyList``. A leaf has an empty ``childList``; an inner
    node has one more child than it has keys.

    ``add`` and ``DFSTree`` update the module-level counters
    ``tot_split_time``, ``count_key``, ``mxheight`` and ``mxnode``, which
    must exist in the defining module.
    """

    # Order of the tree: a node is split as soon as it holds _M keys, so it
    # holds at most _M - 1 keys between insertions.
    _M = 11

    def __init__(self) -> None:
        self.keyList = []    # entries, ascending by getKey()
        self.childList = []  # child nodes; empty for a leaf
        self.parent = None   # parent node, None for the root

    def getParent(self):
        """Return the parent node, or None for the root."""
        return self.parent

    def setParent(self, parent):
        """Set the parent node."""
        self.parent = parent

    def getIndex(self, key_data: "BNodeType") -> int:
        """Return the index of the first entry whose key is greater than
        ``key_data``'s key, or ``len(keyList)`` when there is none.

        This is both the insertion position for ``key_data`` in this node
        and the index of the child to descend into.
        """
        target = key_data.getKey()
        for index, entry in enumerate(self.keyList):
            if entry.getKey() > target:
                return index
        return len(self.keyList)

    def blindInsert(self, key_data):
        """Insert ``key_data`` into this node's keys without any overflow check."""
        self.keyList.insert(self.getIndex(key_data), key_data)

    def split(self):
        """Replace this (full) node by two halves and push the median key
        into the parent, which may split in turn."""
        parent, center, LNode, RNode = self.splitToPieces()
        index = parent.getIndex(center)
        # Insert RNode first so that LNode ends up directly in front of it.
        parent.childList.insert(index, RNode)
        parent.childList.insert(index, LNode)
        LNode.setParent(parent)
        RNode.setParent(parent)
        # A freshly created root does not contain this node yet.
        if self in parent.childList:
            parent.childList.remove(self)
        parent.add(center, modify=True)

    def splitToPieces(self):
        """Cut this node around its median key.

        Returns ``(parent, center, LNode, RNode)``. If this node is the
        root, a new empty node is created and recorded as its parent; the
        caller is responsible for wiring the halves into that parent.
        """
        mid = self._M // 2
        center = self.keyList[mid]
        LNode = BNode()
        RNode = BNode()
        LNode.keyList = self.keyList[:mid]
        RNode.keyList = self.keyList[mid + 1:]
        if self.childList:
            LNode.childList = self.childList[:mid + 1]
            RNode.childList = self.childList[mid + 1:]
            for half in (LNode, RNode):
                for node in half.childList:
                    node.setParent(half)
        if self.getParent() is None:
            self.setParent(BNode())
        return self.getParent(), center, LNode, RNode

    def add(self, key_data: "BNodeType", modify=False):
        """Insert ``key_data`` into the subtree rooted at this node.

        With ``modify=True`` the entry is stored in this node even if it is
        an inner node; ``split`` uses this to push a median key upwards.
        After a root split the new root is reachable via ``getParent()`` of
        the old root.
        """
        global tot_split_time
        if not self.childList or modify:
            self.blindInsert(key_data)
            # The node reached the order of the tree: split it.
            if len(self.keyList) == self._M:
                self.split()
                tot_split_time += 1
        else:
            self.childList[self.getIndex(key_data)].add(key_data)

    def DFSTree(self, height):
        """Walk the subtree and accumulate statistics in the module globals.

        ``height`` is the level of this node (the self-test passes 1 for
        the root). Every child is visited exactly once.
        """
        global count_key, mxheight, mxnode
        mxnode = max(mxnode, len(self.keyList))
        count_key += len(self.keyList)
        mxheight = max(mxheight, height)
        for child in self.childList:
            child.DFSTree(height + 1)

    def dfs_by_interval(self, left, right, result):
        """Append every key ``k`` with ``left <= k < right`` in this subtree
        to ``result``, in ascending order."""
        isLeaf = not self.childList
        for index, entry in enumerate(self.keyList):
            key = entry.getKey()
            if key < left:
                # This key and everything in the child to its left are too small.
                continue
            if not isLeaf:
                self.childList[index].dfs_by_interval(left, right, result)
            if key >= right:
                # Everything from here on is at or beyond the upper bound.
                return
            result.append(key)
        if not isLeaf:
            # Keys larger than the last key of this node live in the last child.
            self.childList[len(self.keyList)].dfs_by_interval(left, right, result)

    def search_by_interval(self, left, right):
        """Return the sorted list of keys in the half-open range
        ``[left, right)`` stored in this subtree."""
        result = []
        self.dfs_by_interval(left, right, result)
        return result
if __name__ == '__main__':
    # Self-test: insert 10000 integer keys in random order, then report the
    # build time and the statistics gathered by DFSTree.
    root = BNode()
    data_list = [BNodeType(i, 'hasd' + str(i)) for i in range(10000)]
    random.shuffle(data_list)
    print('Insertion started.')
    # perf_counter is a monotonic clock intended for measuring intervals.
    t = time.perf_counter()
    for data in data_list:
        root.add(data)
        # A root split hangs the old root under a newly created parent.
        if root.getParent() is not None:
            root = root.getParent()
    print('runtime:%.2fs' % (time.perf_counter() - t))
    root.DFSTree(1)
    print('tot_split_time:', tot_split_time)
    print('count_key:', count_key)
    print('mxheight:', mxheight)
    print('mxnode:', mxnode)
测试:插入10000个点,并运行测试函数
可见构建包含10000个点的btree需要0.83s;由于是用python实现,运行效率比较低。分裂次数、关键字总数、树高和单个节点的最大关键字数均符合预期。
轮排索引构建
在btree的基础上构建轮排索引:对于每一个词项,插入它的所有轮排词项(如对于tree,将tree\$, ree\$t, ee\$tr, e\$tre, \$tree插入btree),并用字典记录每个原词项对应的文档列表。
由于轮排词项较多(约160万个),构建轮排索引所需要的时间较长(约3分钟)。
def getDict():
    """Load ``dict.index2.txt`` and build the permuterm vocabulary.

    Each non-blank line is expected to hold three tab-separated fields:
    the term, a second field that is ignored here (named ``df`` in the
    original code, presumably the document frequency), and the postings
    string, which is stored verbatim.

    Returns:
        (wordDict, docsList): ``wordDict`` maps every rotation of
        ``term + '$'`` to a running counter; ``docsList`` maps each term to
        its postings string.

    Raises:
        ValueError: if a non-blank line does not have exactly three fields.
    """
    wordDict = {}
    docsList = {}
    cnt = 0
    # The context manager closes the file even if a line is malformed.
    with open('dict.index2.txt', 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            word, _df, docs = line.split('\t')
            # All len(word) + 1 rotations, from "word$" to "$word".
            for i in range(len(word) + 1):
                cnt += 1
                wordDict[word[i:] + '$' + word[:i]] = cnt
            docsList[word] = docs
    print('total words:%d' % cnt)
    return wordDict, docsList
if __name__ == '__main__':
    # Build phase: insert every rotated term into the B-tree in random order.
    wordDict, docsList = getDict()
    # perf_counter is a monotonic clock intended for measuring intervals.
    t = time.perf_counter()
    root = BNode()
    data_list = [BNodeType(key, key) for key in wordDict]
    tot = len(data_list)
    random.shuffle(data_list)
    for cnt, data in enumerate(data_list, 1):
        root.add(data)
        # A root split hangs the old root under a newly created parent.
        if root.getParent() is not None:
            root = root.getParent()
        if cnt % 10000 == 0:
            print('loading dictionary: %.2f%%' % (cnt * 100 / tot))
    print('loading completed. time : %.2fs' % (time.perf_counter() - t))
通配符查询处理
将通配符查询转换为前缀查询,根据通配符的个数分类讨论:
若没有通配符,则只需要查询单个词项,若原查询为A,则对应的查询区间为[\$A, \$A\$)(左闭右开)
若通配符个数为1,则只需要一次前缀查询即可找出所有符合要求的词项,若原查询为A*B,则对应的查询区间左边界为B\$A,右边界为B\$A将最后一个字符ASCII码+1后的结果
若通配符个数为2,则查询方法比较复杂,以A*B*C为例,首先转化成通配符个数为1的查询A*C,此时查询的结果是真实结果的超集,可能有一些不符合要求的结果也在查询结果里(如AC),因此需要利用正则表达式对查询结果做二次筛选,剔除那些中间不包含B的结果。
# Interactive wildcard-query loop. Supported query shapes:
#   1. A        exact term
#   2. A*       prefix
#   3. *B       suffix
#   4. A*B      prefix and suffix
#   5. A*B*C    looked up as A*C, then post-filtered with a regular expression
while True:
    try:
        input_str = input('Please input a string: ')
    except EOFError:
        # End of input: leave the loop instead of dying with a traceback.
        break
    input_str = input_str.replace(' ', '')
    pos = find_last(input_str, '*')
    first_pos = input_str.find('*')
    if pos == -1:
        # No wildcard: look up the rotation "$A" of the term A itself.
        left = '$' + input_str
        right = '$' + input_str + '$'
    else:
        # One or more wildcards: rotate the query so that the text after the
        # last '*' comes first and the text before the first '*' comes last,
        # then search for every rotated term starting with that prefix. The
        # upper bound is the prefix with its last character incremented.
        left = input_str[pos + 1:] + '$' + input_str[:first_pos]
        right = left[:-1] + chr(ord(left[-1]) + 1)
    results = root.search_by_interval(left, right)
    # Second-pass filter: every '*' matches any run of characters and all
    # other query characters are matched literally (escaped), against the
    # whole term.
    regex = '.*'.join(re.escape(part) for part in input_str.split('*'))
    rule = re.compile(regex, re.DOTALL)
    real_results = []
    for result in results:
        # Undo the rotation: "B$A" -> "AB".
        pos = result.find('$')
        real = result[pos + 1:] + result[:pos]
        # The membership test guards against terms that themselves contain
        # '$', for which the rotation cannot be undone unambiguously.
        if real in docsList and rule.fullmatch(real):
            real_results.append(real)
    real_results.sort()
    if not real_results:
        print('No result.')
    else:
        # Merge the postings of every matched term, not just the first two.
        doc_ids = set()
        for term in real_results:
            doc_ids.update(getDocList(docsList[term]))
        print(sorted(doc_ids))
        print(real_results)
结果
部分结果如下:
总结
处理存在多个通配符的查询
如上所说,存在多个通配符的查询处理起来比较复杂,不仅需要用轮排索引找到对应的前后缀,还要利用正则表达式等工具进行二次检查。由于二次检查需要遍历每个查询结果,所以若查询结果数量较多,则运行效率会受到影响。但目前看来没有其他更好的方法解决该问题。
Btree效率测试和阶数选择
btree的阶数会很大程度上影响btree索引的构建和查询效率,阶数过大或过小都会导致查询效率低下。经过一些测试,认为阶数在7-15范围内效率较高,因此项目中使用11作为btree的阶数。
代码
mybtree.py
import random
import time
# Module-level statistics updated by BNode.add and BNode.DFSTree.
tot_split_time = 0  # number of node splits; incremented by BNode.add after each split
count_key = 0  # running total of keys counted by BNode.DFSTree
mxheight = 0  # largest height value reached during BNode.DFSTree
mxnode = 0  # largest number of keys seen in a single node by BNode.DFSTree
class BNodeType(object):
    """A key/payload pair stored inside a B-tree node.

    The tree orders entries by ``key`` only; ``data`` is carried along
    untouched. Keys only need to be mutually comparable: this file uses
    ints in the B-tree self-test and rotated term strings in the permuterm
    index, so the accessors are deliberately not annotated with one type.
    """

    def __init__(self, key, data) -> None:
        self.key = key
        self.data = data

    def __repr__(self) -> str:
        return '%s(%r, %r)' % (type(self).__name__, self.key, self.data)

    def getKey(self):
        """Return the key this entry is ordered by."""
        return self.key

    def getData(self):
        """Return the payload stored alongside the key."""
        return self.data
class BNode(object):
    """One node of an in-memory B-tree.

    Entries are objects exposing ``getKey()``; they are kept in ascending
    key order in ``keyList``. A leaf has an empty ``childList``; an inner
    node has one more child than it has keys.

    ``add`` and ``DFSTree`` update the module-level counters
    ``tot_split_time``, ``count_key``, ``mxheight`` and ``mxnode``, which
    must exist in the defining module.
    """

    # Order of the tree: a node is split as soon as it holds _M keys, so it
    # holds at most _M - 1 keys between insertions.
    _M = 11

    def __init__(self) -> None:
        self.keyList = []    # entries, ascending by getKey()
        self.childList = []  # child nodes; empty for a leaf
        self.parent = None   # parent node, None for the root

    def getParent(self):
        """Return the parent node, or None for the root."""
        return self.parent

    def setParent(self, parent):
        """Set the parent node."""
        self.parent = parent

    def getIndex(self, key_data: "BNodeType") -> int:
        """Return the index of the first entry whose key is greater than
        ``key_data``'s key, or ``len(keyList)`` when there is none.

        This is both the insertion position for ``key_data`` in this node
        and the index of the child to descend into.
        """
        target = key_data.getKey()
        for index, entry in enumerate(self.keyList):
            if entry.getKey() > target:
                return index
        return len(self.keyList)

    def blindInsert(self, key_data):
        """Insert ``key_data`` into this node's keys without any overflow check."""
        self.keyList.insert(self.getIndex(key_data), key_data)

    def split(self):
        """Replace this (full) node by two halves and push the median key
        into the parent, which may split in turn."""
        parent, center, LNode, RNode = self.splitToPieces()
        index = parent.getIndex(center)
        # Insert RNode first so that LNode ends up directly in front of it.
        parent.childList.insert(index, RNode)
        parent.childList.insert(index, LNode)
        LNode.setParent(parent)
        RNode.setParent(parent)
        # A freshly created root does not contain this node yet.
        if self in parent.childList:
            parent.childList.remove(self)
        parent.add(center, modify=True)

    def splitToPieces(self):
        """Cut this node around its median key.

        Returns ``(parent, center, LNode, RNode)``. If this node is the
        root, a new empty node is created and recorded as its parent; the
        caller is responsible for wiring the halves into that parent.
        """
        mid = self._M // 2
        center = self.keyList[mid]
        LNode = BNode()
        RNode = BNode()
        LNode.keyList = self.keyList[:mid]
        RNode.keyList = self.keyList[mid + 1:]
        if self.childList:
            LNode.childList = self.childList[:mid + 1]
            RNode.childList = self.childList[mid + 1:]
            for half in (LNode, RNode):
                for node in half.childList:
                    node.setParent(half)
        if self.getParent() is None:
            self.setParent(BNode())
        return self.getParent(), center, LNode, RNode

    def add(self, key_data: "BNodeType", modify=False):
        """Insert ``key_data`` into the subtree rooted at this node.

        With ``modify=True`` the entry is stored in this node even if it is
        an inner node; ``split`` uses this to push a median key upwards.
        After a root split the new root is reachable via ``getParent()`` of
        the old root.
        """
        global tot_split_time
        if not self.childList or modify:
            self.blindInsert(key_data)
            # The node reached the order of the tree: split it.
            if len(self.keyList) == self._M:
                self.split()
                tot_split_time += 1
        else:
            self.childList[self.getIndex(key_data)].add(key_data)

    def DFSTree(self, height):
        """Walk the subtree and accumulate statistics in the module globals.

        ``height`` is the level of this node (the self-test passes 1 for
        the root). Every child is visited exactly once.
        """
        global count_key, mxheight, mxnode
        mxnode = max(mxnode, len(self.keyList))
        count_key += len(self.keyList)
        mxheight = max(mxheight, height)
        for child in self.childList:
            child.DFSTree(height + 1)

    def dfs_by_interval(self, left, right, result):
        """Append every key ``k`` with ``left <= k < right`` in this subtree
        to ``result``, in ascending order."""
        isLeaf = not self.childList
        for index, entry in enumerate(self.keyList):
            key = entry.getKey()
            if key < left:
                # This key and everything in the child to its left are too small.
                continue
            if not isLeaf:
                self.childList[index].dfs_by_interval(left, right, result)
            if key >= right:
                # Everything from here on is at or beyond the upper bound.
                return
            result.append(key)
        if not isLeaf:
            # Keys larger than the last key of this node live in the last child.
            self.childList[len(self.keyList)].dfs_by_interval(left, right, result)

    def search_by_interval(self, left, right):
        """Return the sorted list of keys in the half-open range
        ``[left, right)`` stored in this subtree."""
        result = []
        self.dfs_by_interval(left, right, result)
        return result
if __name__ == '__main__':
    # Self-test: insert 10000 integer keys in random order, then report the
    # build time and the statistics gathered by DFSTree.
    root = BNode()
    data_list = [BNodeType(i, 'hasd' + str(i)) for i in range(10000)]
    random.shuffle(data_list)
    print('Insertion started.')
    # perf_counter is a monotonic clock intended for measuring intervals.
    t = time.perf_counter()
    for data in data_list:
        root.add(data)
        # A root split hangs the old root under a newly created parent.
        if root.getParent() is not None:
            root = root.getParent()
    print('runtime:%.2fs' % (time.perf_counter() - t))
    root.DFSTree(1)
    print('tot_split_time:', tot_split_time)
    print('count_key:', count_key)
    print('mxheight:', mxheight)
    print('mxnode:', mxnode)
main.py
from random import shuffle
import random
# from BTree import *
from mybtree import *
import re
import time
def getDict():
    """Load ``dict.index2.txt`` and build the permuterm vocabulary.

    Each non-blank line is expected to hold three tab-separated fields:
    the term, a second field that is ignored here (named ``df`` in the
    original code, presumably the document frequency), and the postings
    string, which is stored verbatim.

    Returns:
        (wordDict, docsList): ``wordDict`` maps every rotation of
        ``term + '$'`` to a running counter; ``docsList`` maps each term to
        its postings string.

    Raises:
        ValueError: if a non-blank line does not have exactly three fields.
    """
    wordDict = {}
    docsList = {}
    cnt = 0
    # The context manager closes the file even if a line is malformed.
    with open('dict.index2.txt', 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            word, _df, docs = line.split('\t')
            # All len(word) + 1 rotations, from "word$" to "$word".
            for i in range(len(word) + 1):
                cnt += 1
                wordDict[word[i:] + '$' + word[:i]] = cnt
            docsList[word] = docs
    print('total words:%d' % cnt)
    return wordDict, docsList
def find_last(string, str):
    """Return the index of the last occurrence of ``str`` in ``string``,
    or -1 if it does not occur.

    The second parameter shadows the builtin ``str``; the name is kept so
    that existing keyword callers keep working.
    """
    # str.rfind gives the same answer as repeatedly calling find() from the
    # previous hit, including for overlapping matches and the empty needle.
    return string.rfind(str)
def getDocList(docs):
    """Turn a comma-separated list of integers into its running totals.

    ``"1,3,5"`` becomes ``[1, 4, 9]``. Presumably the input is a
    gap-encoded postings list and the result the absolute document ids;
    confirm against the index writer. An empty or blank string yields an
    empty list.

    Raises:
        ValueError: if a field is not an integer.
    """
    from itertools import accumulate

    if not docs.strip():
        return []
    return list(accumulate(int(gap) for gap in docs.split(',')))
if __name__ == '__main__':
    # Build phase: insert every rotated term into the B-tree in random order.
    wordDict, docsList = getDict()
    # perf_counter is a monotonic clock intended for measuring intervals.
    t = time.perf_counter()
    root = BNode()
    data_list = [BNodeType(key, key) for key in wordDict]
    tot = len(data_list)
    random.shuffle(data_list)
    for cnt, data in enumerate(data_list, 1):
        root.add(data)
        # A root split hangs the old root under a newly created parent.
        if root.getParent() is not None:
            root = root.getParent()
        if cnt % 10000 == 0:
            print('loading dictionary: %.2f%%' % (cnt * 100 / tot))
    print('loading completed. time : %.2fs' % (time.perf_counter() - t))

    # Interactive wildcard-query loop. Supported query shapes:
    #   1. A        exact term
    #   2. A*       prefix
    #   3. *B       suffix
    #   4. A*B      prefix and suffix
    #   5. A*B*C    looked up as A*C, then post-filtered with a regular expression
    while True:
        try:
            input_str = input('Please input a string: ')
        except EOFError:
            # End of input: leave the loop instead of dying with a traceback.
            break
        input_str = input_str.replace(' ', '')
        pos = find_last(input_str, '*')
        first_pos = input_str.find('*')
        if pos == -1:
            # No wildcard: look up the rotation "$A" of the term A itself.
            left = '$' + input_str
            right = '$' + input_str + '$'
        else:
            # One or more wildcards: rotate the query so that the text after
            # the last '*' comes first and the text before the first '*'
            # comes last, then search for every rotated term starting with
            # that prefix. The upper bound is the prefix with its last
            # character incremented.
            left = input_str[pos + 1:] + '$' + input_str[:first_pos]
            right = left[:-1] + chr(ord(left[-1]) + 1)
        results = root.search_by_interval(left, right)
        # Second-pass filter: every '*' matches any run of characters and
        # all other query characters are matched literally (escaped),
        # against the whole term.
        regex = '.*'.join(re.escape(part) for part in input_str.split('*'))
        rule = re.compile(regex, re.DOTALL)
        real_results = []
        for result in results:
            # Undo the rotation: "B$A" -> "AB".
            pos = result.find('$')
            real = result[pos + 1:] + result[:pos]
            # The membership test guards against terms that themselves
            # contain '$', for which the rotation cannot be undone
            # unambiguously.
            if real in docsList and rule.fullmatch(real):
                real_results.append(real)
        real_results.sort()
        if not real_results:
            print('No result.')
        else:
            # Merge the postings of every matched term, not just the first two.
            doc_ids = set()
            for term in real_results:
                doc_ids.update(getDocList(docsList[term]))
            print(sorted(doc_ids))
            print(real_results)