diff --git a/app.py b/app.py index c3f39da..dda5e6b 100644 --- a/app.py +++ b/app.py @@ -7,16 +7,16 @@ from v2 import * class XProcess(Process): - def __init__(self, partition, lock): + def __init__(self, partition, lock, ipsearch): super(XProcess, self).__init__() self.partition = partition self.lock = lock - + self.ipsearch = ipsearch def run(self): db_client = CK(**settings.CK_CONFIG) sketch = Sketch(db_client) - handler_event = HandlerEvent(db_client, settings.GAME) + handler_event = HandlerEvent(db_client, settings.GAME,ipsearch) handler_user = HandlerUser(db_client, settings.GAME) transmitter = Transmitter(db_client, settings.GAME, sketch, self.lock) transmitter.add_source(handler_event, 10000, 60) diff --git a/ip2region.db b/ip2region.db new file mode 100644 index 0000000..3b6a296 Binary files /dev/null and b/ip2region.db differ diff --git a/main.py b/main.py index a6ffcbf..5109952 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,10 @@ from app import XProcess from multiprocessing import Lock +from v2.ipregion import IpSearch, Ip2Region + if __name__ == '__main__': lock = Lock() + ipsearch = IpSearch(Ip2Region, "ip2region.db") for i in range(0, 16): - XProcess(i, lock).start() + XProcess(i, lock,ipsearch).start() diff --git a/v2/handler_event.py b/v2/handler_event.py index 2e0ce6d..8204759 100644 --- a/v2/handler_event.py +++ b/v2/handler_event.py @@ -6,10 +6,16 @@ __all__ = 'HandlerEvent', class HandlerEvent: tb = 'event' - def __init__(self, db_client, db_name): + def __init__(self, db_client, db_name, ipsearch): self.event = dict() self.db_client = db_client self.db_name = db_name + self.ipsearch = ipsearch + + def set_region(self, data): + ip = data.get('#ip') + if ip: + data['#country'], data['#province'], data['#city'] = self.ipsearch(ip) def merge_update(self, a: dict, b: dict): """ @@ -24,11 +30,13 @@ class HandlerEvent: a[k] = v def track(self, data): + self.set_region(data) self.event[len(self.event)] = data def track_update(self, data): if 
class Ip2Region(object):
    """Reader for an ip2region ``.db`` file.

    File layout, as consumed by the methods below:

    * bytes 0-7: super block -- pointer to the first index block and pointer
      to the last index block;
    * next ``8192`` bytes: header block -- ``(start ip, index pointer)``
      pairs of 8 bytes each, terminated by a pair whose pointer is 0;
    * index blocks of 12 bytes: start ip, end ip, data pointer.  The top
      byte of a data pointer is the record length, the low 24 bits are the
      record's file offset;
    * a data record is a 4-byte city id followed by the region string.

    All integers are unsigned 32-bit little-endian values.

    Three lookup strategies are offered: ``memorySearch`` (whole file cached
    in memory), ``binarySearch`` (one seek/read per probe) and
    ``btreeSearch`` (header block narrows the index range first).
    """

    def __init__(self, dbfile):
        """Open *dbfile*; nothing is parsed until the first search."""
        self.__INDEX_BLOCK_LENGTH = 12
        self.__TOTAL_HEADER_LENGTH = 8192
        self.__f = None
        self.__headerSip = []
        self.__headerPtr = []
        self.__headerLen = 0
        self.__indexSPtr = 0
        self.__indexLPtr = 0
        self.__indexCount = 0
        self.__dbBinStr = ''
        self.initDatabase(dbfile)

    def _toLong(self, ip):
        """Normalize *ip* (int, decimal string or dotted quad) to an int."""
        if isinstance(ip, int):
            return ip
        if ip.isdigit():
            # A decimal string must become an int, otherwise the range
            # comparisons below would compare str with int.
            return int(ip)
        return self.ip2long(ip)

    def _searchBlocks(self, buf, base, last, ip):
        """Binary-search index blocks ``0..last`` stored in *buf* at *base*.

        Returns the data pointer of the block whose [start ip, end ip] range
        contains *ip*, or 0 when no block matches.
        """
        lo, hi = 0, last
        while lo <= hi:
            mid = (lo + hi) >> 1
            p = base + mid * self.__INDEX_BLOCK_LENGTH
            if ip < self.getLong(buf, p):
                hi = mid - 1
            elif ip > self.getLong(buf, p + 4):
                lo = mid + 1
            else:
                return self.getLong(buf, p + 8)
        return 0

    def memorySearch(self, ip):
        """Look up *ip* using a copy of the whole database held in memory.

        param ip: int, decimal string or dotted-quad string
        returns:  dict with ``city_id`` (int) and ``region`` (bytes)
        raises:   Exception("Data pointer not found") when no block matches
        """
        ip = self._toLong(ip)

        if not self.__dbBinStr:
            # Rewind first: an earlier file-based search may have moved the
            # file position, and read() starts from the current offset.
            self.__f.seek(0)
            self.__dbBinStr = self.__f.read()
            self.__indexSPtr = self.getLong(self.__dbBinStr, 0)
            self.__indexLPtr = self.getLong(self.__dbBinStr, 4)
            self.__indexCount = (self.__indexLPtr - self.__indexSPtr) // self.__INDEX_BLOCK_LENGTH + 1

        # Valid block numbers are 0 .. indexCount - 1.
        dataPtr = self._searchBlocks(self.__dbBinStr, self.__indexSPtr, self.__indexCount - 1, ip)
        if dataPtr == 0:
            raise Exception("Data pointer not found")

        return self.returnData(dataPtr)

    def binarySearch(self, ip):
        """Look up *ip* by binary search over the index, reading from file.

        param ip: int, decimal string or dotted-quad string
        returns:  dict with ``city_id`` (int) and ``region`` (bytes)
        raises:   Exception("Data pointer not found") when no block matches
        """
        ip = self._toLong(ip)

        if self.__indexCount == 0:
            self.__f.seek(0)
            superBlock = self.__f.read(8)
            self.__indexSPtr = self.getLong(superBlock, 0)
            self.__indexLPtr = self.getLong(superBlock, 4)
            self.__indexCount = (self.__indexLPtr - self.__indexSPtr) // self.__INDEX_BLOCK_LENGTH + 1

        lo, hi, dataPtr = 0, self.__indexCount - 1, 0
        while lo <= hi:
            mid = (lo + hi) >> 1

            # One seek + read per probe; each block is parsed on its own.
            self.__f.seek(self.__indexSPtr + mid * self.__INDEX_BLOCK_LENGTH)
            block = self.__f.read(self.__INDEX_BLOCK_LENGTH)
            if ip < self.getLong(block, 0):
                hi = mid - 1
            elif ip > self.getLong(block, 4):
                lo = mid + 1
            else:
                dataPtr = self.getLong(block, 8)
                break

        if dataPtr == 0:
            raise Exception("Data pointer not found")

        return self.returnData(dataPtr)

    def btreeSearch(self, ip):
        """Look up *ip* via the header block, then a search of one index range.

        param ip: int, decimal string or dotted-quad string
        returns:  dict with ``city_id`` (int) and ``region`` (bytes)
        raises:   Exception("Index pointer not found") when the header yields
                  no index range, Exception("Data pointer not found") when no
                  block in that range matches
        """
        ip = self._toLong(ip)

        if len(self.__headerSip) < 1:
            # Skip the super block, then parse the header block.
            self.__f.seek(8)
            b = self.__f.read(self.__TOTAL_HEADER_LENGTH)
            for i in range(0, len(b), 8):
                sip = self.getLong(b, i)
                ptr = self.getLong(b, i + 4)
                if ptr == 0:
                    break
                self.__headerSip.append(sip)
                self.__headerPtr.append(ptr)
            self.__headerLen = len(self.__headerSip)

        if self.__headerLen == 0:
            # An empty header would otherwise fail with a bare IndexError.
            raise Exception("Index pointer not found")

        # Find the pair of adjacent header entries whose index range may
        # contain ip; sptr/eptr are the file offsets of that range.
        l, h, sptr, eptr = 0, self.__headerLen, 0, 0
        while l <= h:
            m = (l + h) >> 1

            if ip == self.__headerSip[m]:
                if m > 0:
                    sptr = self.__headerPtr[m - 1]
                    eptr = self.__headerPtr[m]
                else:
                    sptr = self.__headerPtr[m]
                    eptr = self.__headerPtr[m + 1]
                break

            if ip < self.__headerSip[m]:
                if m == 0:
                    sptr = self.__headerPtr[m]
                    eptr = self.__headerPtr[m + 1]
                    break
                elif ip > self.__headerSip[m - 1]:
                    sptr = self.__headerPtr[m - 1]
                    eptr = self.__headerPtr[m]
                    break
                h = m - 1
            else:
                if m == self.__headerLen - 1:
                    sptr = self.__headerPtr[m - 1]
                    eptr = self.__headerPtr[m]
                    break
                elif ip <= self.__headerSip[m + 1]:
                    sptr = self.__headerPtr[m]
                    eptr = self.__headerPtr[m + 1]
                    break
                l = m + 1

        if sptr == 0:
            raise Exception("Index pointer not found")

        # Read the selected range, including the block eptr points at.
        indexLen = eptr - sptr
        self.__f.seek(sptr)
        index = self.__f.read(indexLen + self.__INDEX_BLOCK_LENGTH)

        dataPtr = self._searchBlocks(index, 0, indexLen // self.__INDEX_BLOCK_LENGTH, ip)
        if dataPtr == 0:
            raise Exception("Data pointer not found")

        return self.returnData(dataPtr)

    def initDatabase(self, dbfile):
        """Open the database file for binary reading.

        param dbfile: path of the ip2region database
        raises:       IOError when the file cannot be opened.  The error is
                      printed and re-raised so the caller decides what to do,
                      instead of the whole process being terminated here.
        """
        try:
            self.__f = io.open(dbfile, "rb")
        except IOError as e:
            print("[Error]: %s" % e)
            raise

    def returnData(self, dataPtr):
        """Fetch the record addressed by *dataPtr*.

        param dataPtr: packed pointer -- top byte is the record length, low
                       24 bits are the record's file offset
        returns:       dict with ``city_id`` (int) and ``region`` (bytes)
        """
        dataLen = (dataPtr >> 24) & 0xFF
        dataPtr = dataPtr & 0x00FFFFFF

        if self.__dbBinStr:
            # The whole file is already cached: slice it, no file I/O needed.
            data = self.__dbBinStr[dataPtr:dataPtr + dataLen]
        else:
            self.__f.seek(dataPtr)
            data = self.__f.read(dataLen)

        return {
            "city_id": self.getLong(data, 0),
            "region": data[4:]
        }

    def ip2long(self, ip):
        """Convert a dotted-quad string to its unsigned 32-bit integer value."""
        _ip = socket.inet_aton(ip)
        return struct.unpack("!L", _ip)[0]

    def isip(self, ip):
        """Return True when *ip* is four dot-separated decimal parts, each <= 255."""
        p = ip.split(".")

        if len(p) != 4:
            return False
        for pp in p:
            if not pp.isdigit():
                return False
            if len(pp) > 3:
                return False
            if int(pp) > 255:
                return False

        return True

    def getLong(self, b, offset):
        """Read an unsigned 32-bit little-endian int from *b* at *offset*.

        Returns 0 when fewer than 4 bytes are available at that offset.
        """
        chunk = b[offset:offset + 4]
        if len(chunk) == 4:
            # Explicit '<' so the result does not depend on host byte order.
            return struct.unpack('<I', chunk)[0]
        return 0

    def close(self):
        """Close the database file and drop every cached structure."""
        if self.__f is not None:
            self.__f.close()

        # Reset to the pristine "nothing loaded" state rather than None, so a
        # search after close() fails on the closed file, not on a None value.
        self.__dbBinStr = ''
        self.__headerPtr = []
        self.__headerSip = []
        self.__headerLen = 0
        self.__indexCount = 0
class IpSearch:
    """Resolve an IP address to ``(country, province, city)``.

    The database reader is created lazily by calling ``ipregion(db_path)``.
    At most once every ``RELOAD_INTERVAL`` seconds the modification time of
    ``db_path`` is checked, and the reader is rebuilt when the file changed.

    Instances are callable: ``searcher(ip)`` is the same as
    ``searcher.search(ip)``.
    """

    # Minimum number of seconds between two mtime checks of the database file.
    RELOAD_INTERVAL = 600

    def __init__(self, ipregion, db_path):
        """
        param ipregion: factory called as ``ipregion(db_path)``; the returned
                        object must provide ``memorySearch(ip)``
        param db_path:  path of the database file
        """
        self.ipregion = ipregion
        self.db_path = db_path
        self.last_ts = 0      # time of the last mtime check
        self.searcher = None  # current reader; None until the first load
        self.mtime = 0        # mtime of the file the current reader was built from

    def __call__(self, ip):
        """Alias of :meth:`search`, so the instance can be used as a function."""
        return self.search(ip)

    def _reload_if_changed(self):
        """(Re)build the reader when none is loaded or the file has changed."""
        now = int(time.time())
        # The throttle only applies once a reader exists; a failed first load
        # is retried on the next lookup.
        if self.searcher is not None and self.last_ts + self.RELOAD_INTERVAL >= now:
            return
        # Record the check itself, not only a reload; otherwise every lookup
        # after the first interval would stat the file again.
        self.last_ts = now

        mtime = int(os.stat(self.db_path).st_mtime)
        if self.searcher is not None and mtime == self.mtime:
            return

        # Build the new reader before touching any state, so a failed load
        # keeps the previous reader and is retried at the next check.
        fresh = self.ipregion(self.db_path)
        stale, self.searcher = self.searcher, fresh
        self.mtime = mtime

        # Release the file handle and cached data held by the old reader.
        close = getattr(stale, 'close', None)
        if close is not None:
            close()

    def search(self, ip):
        """Return ``(country, province, city)`` for *ip*.

        Best effort: any error is printed and the fields that could not be
        resolved are returned as ``None``.
        """
        country, province, city = None, None, None
        try:
            self._reload_if_changed()

            data = self.searcher.memorySearch(ip)
            # The region string is '|'-separated; fields 0, 2 and 3 are
            # country, province and city.
            fields = data["region"].decode('utf-8').split('|')
            country = fields[0]
            province = fields[2]
            city = fields[3]
        except Exception as e:
            print(e)

        return country, province, city