B-tree构建轮排索引 python实现

检索系统设计

检索系统整体上利用btree构建轮排索引:利用项目二中已经构建好的倒排索引文件,将每个词项转化成多个轮排词项,如对于tree,将tree\$, tre\$e, tr\$ee, t\$ree, \$tree插入btree,在$O(n\log n)$时间复杂度内构建轮排索引。

对于每个通配符查询,先将其转换成前缀查询,再在btree上进行区间查询,找到所有符合要求的轮排词项(如tr\$ee),将轮排词项还原成原词项后进行排序,输出所有符合要求的原词项以及对应的文档列表。

构建、测试Btree

由于轮排索引是在btree上构建的,所以要先实现btree类并测试其正确性。需要实现的功能包括btree的插入和区间查询;此外还要统计构建时间和分裂次数,并实现一个遍历函数,用于统计树高、单个节点的最大关键字个数和关键字总数。

此处设置btree阶数_M为11,即每个节点最多保存10个关键字(关键字个数达到11时节点分裂)。

import random
import time

# Module-level statistics shared with BNode; all of them start at zero.
tot_split_time = 0  # incremented by BNode.add once for every node split
count_key = 0  # BNode.DFSTree adds each visited node's key count to this
mxheight = 0  # largest ``height`` value seen by BNode.DFSTree
mxnode = 0  # largest per-node key count seen by BNode.DFSTree
class BNodeType(object):
    """A single B-tree entry: a comparable key plus the payload stored with it.

    The tree only ever compares entries through ``getKey()``; ``data`` is
    carried along untouched. Callers in this file use both ``int`` keys
    (benchmark) and ``str`` keys (permuterm index), so the accessors are
    deliberately not annotated with a concrete type.
    """

    def __init__(self, key, data) -> None:
        self.key = key
        self.data = data

    def __repr__(self) -> str:
        # Debugging aid only; nothing in the tree depends on this text.
        return '%s(key=%r, data=%r)' % (type(self).__name__, self.key, self.data)

    def getKey(self):
        """Return the key this entry is ordered by."""
        return self.key

    def getData(self):
        """Return the payload that was passed to the constructor."""
        return self.data

class BNode(object):
    """One node of an in-memory B-tree; the tree operations live on the node.

    ``keyList`` holds entries (objects exposing ``getKey()``) in ascending key
    order. ``childList`` is empty for a leaf; otherwise child ``i`` sits to the
    left of key ``i`` and the last child sits to the right of the last key.

    When the root splits, a brand-new ``BNode`` becomes its parent. Callers
    that keep a reference to the root must therefore re-read it through
    ``getParent()`` after every ``add``.

    Statistics are written to the module-level counters ``tot_split_time``,
    ``count_key``, ``mxheight`` and ``mxnode``.
    """

    # Order of the tree: a node is split as soon as it holds ``_M`` keys, so a
    # settled node never keeps more than ``_M - 1`` keys.
    _M = 11

    def __init__(self) -> None:
        self.keyList = []  # entries of this node, ascending by getKey()
        self.childList = []  # child nodes; empty for a leaf
        self.parent = None  # parent node, None for the root

    def getParent(self):
        """Return the parent node, or ``None`` for the root."""
        return self.parent

    def setParent(self, parent):
        """Record ``parent`` as this node's parent."""
        self.parent = parent

    def getIndex(self, key_data: "BNodeType") -> int:
        """Return the index of the first stored key greater than ``key_data``'s key.

        The result is both the insertion position inside ``keyList`` and the
        index of the child to descend into. ``len(keyList)`` is returned when
        no stored key is greater, so equal keys are placed to the right.
        """
        wanted = key_data.getKey()
        for index, entry in enumerate(self.keyList):
            if entry.getKey() > wanted:
                return index
        return len(self.keyList)

    def blindInsert(self, key_data):
        """Insert ``key_data`` at its sorted position without any overflow check."""
        self.keyList.insert(self.getIndex(key_data), key_data)

    def split(self):
        """Replace this full node by its two halves and push the median key up.

        The halves take over the slot this node occupied in its parent. The
        slot is located by identity rather than by comparing keys, so the
        child order stays correct even when the parent already holds keys
        equal to the median.
        """
        parent, center, left_half, right_half = self.splitToPieces()
        siblings = parent.childList
        slot = None
        for position, child in enumerate(siblings):
            if child is self:
                slot = position
                break
        if slot is None:
            # Freshly created parent (root split): nothing to replace.
            slot = parent.getIndex(center)
            siblings[slot:slot] = [left_half, right_half]
        else:
            siblings[slot:slot + 1] = [left_half, right_half]
        left_half.setParent(parent)
        right_half.setParent(parent)
        # ``modify=True`` forces the key into the parent itself, even though
        # the parent is an internal node; the parent may split in turn.
        parent.add(center, modify=True)

    def splitToPieces(self):
        """Cut this node around its median key.

        Returns:
            ``(parent, center, left, right)`` where ``center`` is the entry at
            index ``_M // 2``, ``left``/``right`` are new nodes holding the
            keys (and children, if any) on either side of it, and ``parent``
            is this node's parent. A new empty ``BNode`` is created and stored
            as the parent when this node had none.
        """
        middle = self._M // 2
        center = self.keyList[middle]
        left_half = BNode()
        right_half = BNode()
        left_half.keyList = self.keyList[:middle]
        right_half.keyList = self.keyList[middle + 1:]

        if self.childList:
            left_half.childList = self.childList[:middle + 1]
            right_half.childList = self.childList[middle + 1:]
            for half in (left_half, right_half):
                for node in half.childList:
                    node.setParent(half)

        if self.parent is None:
            self.parent = BNode()
        return self.parent, center, left_half, right_half

    def add(self, key_data: "BNodeType", modify=False):
        """Insert ``key_data`` into the subtree rooted at this node.

        Args:
            key_data: entry to insert.
            modify: when true, insert into this node directly even if it has
                children. ``split`` uses this to push a median key upward.
        """
        global tot_split_time
        if self.childList and not modify:
            # Internal node: hand the entry down to the responsible child.
            self.childList[self.getIndex(key_data)].add(key_data)
            return
        self.blindInsert(key_data)
        # The key count reached the order of the tree: split the node.
        if len(self.keyList) == self._M:
            self.split()
            tot_split_time += 1

    def DFSTree(self, height):
        """Walk the subtree and fold its statistics into the module counters.

        ``mxnode`` becomes the largest key count of any visited node,
        ``count_key`` grows by every visited node's key count, and
        ``mxheight`` becomes the largest level reached, where ``height`` is
        the level of this node.
        """
        global count_key, mxheight, mxnode
        mxnode = max(mxnode, len(self.keyList))
        count_key += len(self.keyList)
        mxheight = max(mxheight, height)
        for child in self.childList:
            child.DFSTree(height + 1)

    def dfs_by_interval(self, left, right, result):
        """Append every key ``k`` with ``left <= k < right`` to ``result``.

        Keys are appended in ascending order. Subtrees that can only contain
        keys outside the interval are not visited.
        """
        children = self.childList
        for position, entry in enumerate(self.keyList):
            key = entry.getKey()
            if key < left:
                # Everything in the child left of this key is smaller still.
                continue
            if children:
                children[position].dfs_by_interval(left, right, result)
            if key >= right:
                # This key and everything to its right is out of range.
                return
            result.append(key)
        if children:
            children[len(self.keyList)].dfs_by_interval(left, right, result)

    def search_by_interval(self, left, right):
        """Return the keys ``k`` with ``left <= k < right`` in ascending order."""
        result = []
        self.dfs_by_interval(left, right, result)
        return result

if __name__ == '__main__':
    # Benchmark / smoke test: insert 10000 shuffled integer keys, then print
    # the build time and the statistics gathered by BNode.DFSTree.
    root = BNode()
    data_list = [BNodeType(i, 'hasd' + str(i)) for i in range(10000)]
    random.shuffle(data_list)
    print('Insertion started.')
    # perf_counter is meant for measuring intervals; time.time can jump when
    # the system clock is adjusted.
    t = time.perf_counter()
    for data in data_list:
        root.add(data)
        # A root split hangs the old root under a newly created parent, so
        # climb to it to keep ``root`` pointing at the real root.
        if root.getParent() is not None:
            root = root.getParent()
    print('runtime:%.2fs' % (time.perf_counter() - t))
    root.DFSTree(1)
    print('tot_split_time:', tot_split_time)
    print('count_key:', count_key)
    print('mxheight:', mxheight)
    print('mxnode:', mxnode)

测试:插入10000个点,并运行测试函数

image-20220319174553838

可见构建含10000个点的btree需要0.83s;由于是用python实现,运行效率比较低。分裂次数、关键字总数、树高和单个节点的最大关键字个数均符合预期。

轮排索引构建

在btree的基础上构建轮排查询,对于每一条词项,插入多个轮排词项(如对于tree,将tree\$, tre\$e, tr\$ee,t\$ree,\$tree插入btree),并用字典记录每个原词项对应的文档列表。

由于轮排词项较多(约160万个),构建轮排索引所需要的时间较长(约3分钟)。

def getDict(path='dict.index2.txt'):
    """Load the index file and expand every term into its rotations.

    Each non-blank line must consist of three tab-separated fields, unpacked
    here as ``word``, ``df`` and ``docs``; ``df`` is not used. For a word of
    length ``n`` the ``n + 1`` rotations of ``word + '$'`` are generated, e.g.
    ``ab`` gives ``ab$``, ``b$a`` and ``$ab``.

    Args:
        path: file to read; defaults to the file name used by the project.

    Returns:
        ``(wordDict, docsList)``. ``wordDict`` maps every rotation to a
        running counter value (the last one wins if a rotation repeats).
        ``docsList`` maps each original word to its ``docs`` field, verbatim.

    Raises:
        ValueError: if a non-blank line does not have exactly three fields.
    """
    wordDict = {}
    docsList = {}
    cnt = 0
    # ``with`` guarantees the handle is closed; iterating the file avoids
    # holding all lines in memory at once.
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines instead of failing on the unpack below.
                continue
            word, _df, docs = line.split('\t')
            for i in range(len(word) + 1):
                cnt += 1
                wordDict[word[i:] + '$' + word[:i]] = cnt
            docsList[word] = docs
    print('total words:%d' % cnt)
    return wordDict, docsList

if __name__ == '__main__':
    # Build the permuterm B-tree: every rotation returned by getDict becomes
    # one entry whose key and payload are both the rotation itself.
    wordDict, docsList = getDict()
    t = time.time()
    root = BNode()
    data_list = [BNodeType(term, term) for term in wordDict]
    tot = len(data_list)
    random.shuffle(data_list)
    cnt = 0
    for data in data_list:
        root.add(data)
        # After a root split the old root has a parent; follow it upward.
        new_root = root.getParent()
        if new_root is not None:
            root = new_root
        cnt += 1
        if cnt % 10000:
            continue
        # Progress report every 10000 insertions.
        print('loading dictionary: %.2f%%' % (cnt * 100 / tot))
    print('loading completed. time : %.2fs' % (time.time() - t))

通配符查询处理

将通配符查询转换为前缀查询,根据通配符的个数分类讨论:

  • 若没有通配符,则只需要查询单个词项,若原查询为A,则对应的查询区间为[\$A, \$A\$)(左闭右开)

  • 若通配符个数为1,则只需要一次前缀查询即可找出所有符合要求的词项,若原查询为A*B,则对应的查询区间左边界为B\$A,右边界为B\$A将最后一个字符ASCII码+1后的结果

  • 若通配符个数为2,则查询方法比较复杂,以A*B*C为例,首先转化成通配符个数为1的查询A*C,此时查询的结果是真实结果的超集,可能有一些不符合要求的结果也在查询结果里(如AC),因此需要利用正则表达式对查询结果做二次筛选,剔除那些中间不包含B的结果。

# 开始处理通配符查询
    # 允许以下查询模式
    # 1. A
    # 2. A*
    # 3. *B
    # 4. A*B
    # 5. A*B*C (先找出所有A*C,再对查找结果用正则表达式二次判断)
    while(1):
        input_str = input('Please input a string: ')
        input_str = input_str.replace(' ','')
        pos = find_last(input_str, '*')
        first_pos = input_str.find('*')
        if (pos == -1):
            # 没有通配符
            left = '$' + input_str
            right = '$' + input_str + '$'
        elif pos == first_pos:
            # 只有一个通配符
            left = input_str[pos + 1:] + '$' + input_str[:pos]
            temp = chr(ord(left[-1]) + 1)
            right = left[:-1] + temp
        else:
            # 有两个或更多通配符
            left = input_str[pos + 1:] + '$' + input_str[:first_pos]
            temp = chr(ord(left[-1]) + 1)
            right = left[:-1] + temp
        results = root.search_by_interval(left, right)
        real_results = []
        regex = input_str
        regex = regex.replace('*', '\w*')
        rule = re.compile(regex)
        for result in results:
            pos = result.find('$')
            real = result[pos + 1:] + result[: pos]
            # 正则表达式二次验证
            if (rule.match(real)):
                real_results.append(real)
        real_results.sort()
        if (len(real_results) == 0):
            print('No result.')
        else:
            result1 = getDocList(docsList[real_results[0]])
            if (len(real_results) == 1):
                print(result1)
            else:
                result2 = getDocList(docsList[real_results[1]])
                result = list(set(result1).union(set(result2)))
                result.sort()
                print(result)
            print(real_results)

结果

部分结果如下:

image-20220319174834413

image-20220319174837558

image-20220319174843370

image-20220319174846846

image-20220319174854566

总结

  1. 处理存在多个通配符的查询

    如上所说,存在多个通配符的查询处理起来比较复杂,不仅需要用轮排索引找到对应的前后缀,还要利用正则表达式等工具进行二次检查。由于二次检查需要遍历每个查询结果,所以若查询结果数量较多,则运行效率会受到影响。但目前看来没有其他更好的方法解决该问题。

  2. Btree效率测试和阶数选择

    btree的阶数会很大程度上影响btree索引的构建和查询效率,阶数过大或过小都会导致查询效率低下。经过一些测试,认为阶数在7-15范围内效率较高,因此项目中使用11作为btree的阶数。

代码

mybtree.py

import random
import time

# Module-level statistics shared with BNode; all of them start at zero.
tot_split_time = 0  # incremented by BNode.add once for every node split
count_key = 0  # BNode.DFSTree adds each visited node's key count to this
mxheight = 0  # largest ``height`` value seen by BNode.DFSTree
mxnode = 0  # largest per-node key count seen by BNode.DFSTree
class BNodeType(object):
    """A single B-tree entry: a comparable key plus the payload stored with it.

    The tree only ever compares entries through ``getKey()``; ``data`` is
    carried along untouched. Callers in this file use both ``int`` keys
    (benchmark) and ``str`` keys (permuterm index), so the accessors are
    deliberately not annotated with a concrete type.
    """

    def __init__(self, key, data) -> None:
        self.key = key
        self.data = data

    def __repr__(self) -> str:
        # Debugging aid only; nothing in the tree depends on this text.
        return '%s(key=%r, data=%r)' % (type(self).__name__, self.key, self.data)

    def getKey(self):
        """Return the key this entry is ordered by."""
        return self.key

    def getData(self):
        """Return the payload that was passed to the constructor."""
        return self.data

class BNode(object):
    """One node of an in-memory B-tree; the tree operations live on the node.

    ``keyList`` holds entries (objects exposing ``getKey()``) in ascending key
    order. ``childList`` is empty for a leaf; otherwise child ``i`` sits to the
    left of key ``i`` and the last child sits to the right of the last key.

    When the root splits, a brand-new ``BNode`` becomes its parent. Callers
    that keep a reference to the root must therefore re-read it through
    ``getParent()`` after every ``add``.

    Statistics are written to the module-level counters ``tot_split_time``,
    ``count_key``, ``mxheight`` and ``mxnode``.
    """

    # Order of the tree: a node is split as soon as it holds ``_M`` keys, so a
    # settled node never keeps more than ``_M - 1`` keys.
    _M = 11

    def __init__(self) -> None:
        self.keyList = []  # entries of this node, ascending by getKey()
        self.childList = []  # child nodes; empty for a leaf
        self.parent = None  # parent node, None for the root

    def getParent(self):
        """Return the parent node, or ``None`` for the root."""
        return self.parent

    def setParent(self, parent):
        """Record ``parent`` as this node's parent."""
        self.parent = parent

    def getIndex(self, key_data: "BNodeType") -> int:
        """Return the index of the first stored key greater than ``key_data``'s key.

        The result is both the insertion position inside ``keyList`` and the
        index of the child to descend into. ``len(keyList)`` is returned when
        no stored key is greater, so equal keys are placed to the right.
        """
        wanted = key_data.getKey()
        for index, entry in enumerate(self.keyList):
            if entry.getKey() > wanted:
                return index
        return len(self.keyList)

    def blindInsert(self, key_data):
        """Insert ``key_data`` at its sorted position without any overflow check."""
        self.keyList.insert(self.getIndex(key_data), key_data)

    def split(self):
        """Replace this full node by its two halves and push the median key up.

        The halves take over the slot this node occupied in its parent. The
        slot is located by identity rather than by comparing keys, so the
        child order stays correct even when the parent already holds keys
        equal to the median.
        """
        parent, center, left_half, right_half = self.splitToPieces()
        siblings = parent.childList
        slot = None
        for position, child in enumerate(siblings):
            if child is self:
                slot = position
                break
        if slot is None:
            # Freshly created parent (root split): nothing to replace.
            slot = parent.getIndex(center)
            siblings[slot:slot] = [left_half, right_half]
        else:
            siblings[slot:slot + 1] = [left_half, right_half]
        left_half.setParent(parent)
        right_half.setParent(parent)
        # ``modify=True`` forces the key into the parent itself, even though
        # the parent is an internal node; the parent may split in turn.
        parent.add(center, modify=True)

    def splitToPieces(self):
        """Cut this node around its median key.

        Returns:
            ``(parent, center, left, right)`` where ``center`` is the entry at
            index ``_M // 2``, ``left``/``right`` are new nodes holding the
            keys (and children, if any) on either side of it, and ``parent``
            is this node's parent. A new empty ``BNode`` is created and stored
            as the parent when this node had none.
        """
        middle = self._M // 2
        center = self.keyList[middle]
        left_half = BNode()
        right_half = BNode()
        left_half.keyList = self.keyList[:middle]
        right_half.keyList = self.keyList[middle + 1:]

        if self.childList:
            left_half.childList = self.childList[:middle + 1]
            right_half.childList = self.childList[middle + 1:]
            for half in (left_half, right_half):
                for node in half.childList:
                    node.setParent(half)

        if self.parent is None:
            self.parent = BNode()
        return self.parent, center, left_half, right_half

    def add(self, key_data: "BNodeType", modify=False):
        """Insert ``key_data`` into the subtree rooted at this node.

        Args:
            key_data: entry to insert.
            modify: when true, insert into this node directly even if it has
                children. ``split`` uses this to push a median key upward.
        """
        global tot_split_time
        if self.childList and not modify:
            # Internal node: hand the entry down to the responsible child.
            self.childList[self.getIndex(key_data)].add(key_data)
            return
        self.blindInsert(key_data)
        # The key count reached the order of the tree: split the node.
        if len(self.keyList) == self._M:
            self.split()
            tot_split_time += 1

    def DFSTree(self, height):
        """Walk the subtree and fold its statistics into the module counters.

        ``mxnode`` becomes the largest key count of any visited node,
        ``count_key`` grows by every visited node's key count, and
        ``mxheight`` becomes the largest level reached, where ``height`` is
        the level of this node.
        """
        global count_key, mxheight, mxnode
        mxnode = max(mxnode, len(self.keyList))
        count_key += len(self.keyList)
        mxheight = max(mxheight, height)
        for child in self.childList:
            child.DFSTree(height + 1)

    def dfs_by_interval(self, left, right, result):
        """Append every key ``k`` with ``left <= k < right`` to ``result``.

        Keys are appended in ascending order. Subtrees that can only contain
        keys outside the interval are not visited.
        """
        children = self.childList
        for position, entry in enumerate(self.keyList):
            key = entry.getKey()
            if key < left:
                # Everything in the child left of this key is smaller still.
                continue
            if children:
                children[position].dfs_by_interval(left, right, result)
            if key >= right:
                # This key and everything to its right is out of range.
                return
            result.append(key)
        if children:
            children[len(self.keyList)].dfs_by_interval(left, right, result)

    def search_by_interval(self, left, right):
        """Return the keys ``k`` with ``left <= k < right`` in ascending order."""
        result = []
        self.dfs_by_interval(left, right, result)
        return result

if __name__ == '__main__':
    # Benchmark / smoke test: insert 10000 shuffled integer keys, then print
    # the build time and the statistics gathered by BNode.DFSTree.
    root = BNode()
    data_list = [BNodeType(i, 'hasd' + str(i)) for i in range(10000)]
    random.shuffle(data_list)
    print('Insertion started.')
    # perf_counter is meant for measuring intervals; time.time can jump when
    # the system clock is adjusted.
    t = time.perf_counter()
    for data in data_list:
        root.add(data)
        # A root split hangs the old root under a newly created parent, so
        # climb to it to keep ``root`` pointing at the real root.
        if root.getParent() is not None:
            root = root.getParent()
    print('runtime:%.2fs' % (time.perf_counter() - t))
    root.DFSTree(1)
    print('tot_split_time:', tot_split_time)
    print('count_key:', count_key)
    print('mxheight:', mxheight)
    print('mxnode:', mxnode)

main.py

from random import shuffle
import random
# from BTree import *
from mybtree import *
import re
import time
def getDict(path='dict.index2.txt'):
    """Load the index file and expand every term into its rotations.

    Each non-blank line must consist of three tab-separated fields, unpacked
    here as ``word``, ``df`` and ``docs``; ``df`` is not used. For a word of
    length ``n`` the ``n + 1`` rotations of ``word + '$'`` are generated, e.g.
    ``ab`` gives ``ab$``, ``b$a`` and ``$ab``.

    Args:
        path: file to read; defaults to the file name used by the project.

    Returns:
        ``(wordDict, docsList)``. ``wordDict`` maps every rotation to a
        running counter value (the last one wins if a rotation repeats).
        ``docsList`` maps each original word to its ``docs`` field, verbatim.

    Raises:
        ValueError: if a non-blank line does not have exactly three fields.
    """
    wordDict = {}
    docsList = {}
    cnt = 0
    # ``with`` guarantees the handle is closed; iterating the file avoids
    # holding all lines in memory at once.
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines instead of failing on the unpack below.
                continue
            word, _df, docs = line.split('\t')
            for i in range(len(word) + 1):
                cnt += 1
                wordDict[word[i:] + '$' + word[:i]] = cnt
            docsList[word] = docs
    print('total words:%d' % cnt)
    return wordDict, docsList

def find_last(string,str):
    """Return the highest index at which ``str`` occurs in ``string``, or -1.

    Overlapping occurrences count, and an empty ``str`` yields
    ``len(string)``.
    """
    # The built-in reverse search returns exactly this value in one call.
    return string.rfind(str)

def getDocList(docs):
    """Turn a comma-separated string of integers into their running totals.

    ``'1,2,3'`` becomes ``[1, 3, 6]``. The caller passes the ``docs`` field of
    an index line; presumably it stores document ids as gaps - confirm
    against the index writer.

    Args:
        docs: comma-separated integers; an empty or blank string is allowed.

    Returns:
        The list of prefix sums, or ``[]`` for an empty or blank string.

    Raises:
        ValueError: if an item is not an integer literal.
    """
    # Local import: the file's import block does not bring itertools in.
    from itertools import accumulate

    if not docs.strip():
        return []
    return list(accumulate(int(gap) for gap in docs.split(',')))
if __name__ == '__main__':
    # Phase 1: build the permuterm B-tree. Every rotation returned by getDict
    # becomes one entry whose key and payload are both the rotation itself.
    wordDict, docsList = getDict()
    t = time.time()
    root = BNode()
    data_list = [BNodeType(key, key) for key in wordDict]
    tot = len(data_list)
    random.shuffle(data_list)
    cnt = 0
    for data in data_list:
        root.add(data)
        # A root split hangs the old root under a newly created parent, so
        # climb to it to keep ``root`` pointing at the real root.
        if root.getParent() is not None:
            root = root.getParent()
        cnt += 1
        if cnt % 10000 == 0:
            print('loading dictionary: %.2f%%' % (cnt * 100 / tot))
    print('loading completed. time : %.2fs' % (time.time() - t))
    # Phase 2: wildcard queries.
    # Supported query shapes:
    # 1. A
    # 2. A*
    # 3. *B
    # 4. A*B
    # 5. A*B*C (look up A*C first, then verify the candidates with a regex)
    while True:
        try:
            input_str = input('Please input a string: ')
        except (EOFError, KeyboardInterrupt):
            # End of input / Ctrl-C: leave the loop instead of crashing.
            break
        input_str = input_str.replace(' ', '')
        if not input_str:
            continue
        pos = find_last(input_str, '*')
        first_pos = input_str.find('*')
        if pos == -1:
            # No wildcard: the rotation '$' + term identifies the term.
            left = '$' + input_str
            right = '$' + input_str + '$'
        else:
            # One or more wildcards: prefix query on <tail>$<head>, where head
            # precedes the first '*' and tail follows the last '*'. The upper
            # bound is the prefix with its last character bumped by one.
            left = input_str[pos + 1:] + '$' + input_str[:first_pos]
            right = left[:-1] + chr(ord(left[-1]) + 1)
        results = root.search_by_interval(left, right)
        # Escape the literal pieces so characters such as '.' or '(' in the
        # query are not interpreted as regex syntax; '*' matches any run of
        # characters, not only word characters.
        regex = '.*'.join(re.escape(piece) for piece in input_str.split('*'))
        rule = re.compile(regex, re.DOTALL)
        real_results = []
        for result in results:
            pos = result.find('$')
            real = result[pos + 1:] + result[: pos]
            # Second-pass check against the whole term (fullmatch), so a term
            # that merely starts with a match is rejected.
            if rule.fullmatch(real) and real in docsList:
                real_results.append(real)
        real_results.sort()
        if not real_results:
            print('No result.')
        else:
            # Union of the postings of every matching term, not just the
            # first two.
            doc_ids = set()
            for term in real_results:
                doc_ids.update(getDocList(docsList[term]))
            print(sorted(doc_ids))
            print(real_results)