import json
import operator
from typing import Tuple

import pandas as pd
import sqlalchemy as sa
from fastapi import Depends
from sqlalchemy import func, or_, and_, not_

import schemas
from core.config import settings
from db.redisdb import get_redis_pool, RedisDrive


class BehaviorAnalysis:
    """Build ClickHouse-dialect SQL strings for behaviour-analysis reports of one game.

    Table layouts are not declared statically: ``init()`` reads the column names
    of the user and event tables from Redis and builds bare ``sa.Table`` objects
    in the schema named after the game. The ``*_sql`` methods compile SQLAlchemy
    expressions with literal binds into plain SQL strings.

    ``await init()`` must be called before any of the ``*_sql`` methods.
    """

    # Comparators that take exactly one value (``ftv[0]``) on the right-hand side.
    _BINARY_COMPARATORS = {
        '>=': operator.ge,
        '<=': operator.le,
        '>': operator.gt,
        '<': operator.lt,
        '!=': operator.ne,
    }

    def __init__(self, game: str, data_in: schemas.CkQuery, rdb: RedisDrive = Depends(get_redis_pool)):
        self.game = game
        self.rdb = rdb
        self.user_tbl = None
        self.event_tbl = None

        self.event_view = data_in.eventView
        self.events = data_in.events

        # The attributes below are populated by init().
        self.zone_time: int = 0
        self.start_date = None
        self.end_date = None
        self.global_filters = None
        self.groupby = None
        self.time_particle = None
        self.date_range = None
        self.unit_num = None

    async def init(self):
        """Load the table structures and parse the query parameters.

        Order matters: the tables must exist before ``_get_group_by`` resolves
        columns, and the time granularity must be known before the date range
        is generated.
        """
        await self._init_table()
        self.zone_time = self._get_zone_time()
        self.time_particle = self._get_time_particle_size()
        self.start_date, self.end_date, self.date_range = self._get_date_range()
        self.global_filters = self._get_global_filters()
        self.groupby = self._get_group_by()
        self.unit_num = self._get_unit_num()

    def _get_time_particle_size(self):
        """Return the time granularity of the view, defaulting to ``'P1D'``."""
        return self.event_view.get('timeParticleSize') or 'P1D'

    def _get_unit_num(self):
        """Return the ``unitNum`` value of the view (``None`` when absent)."""
        return self.event_view.get('unitNum')

    def _get_group_by(self):
        """Resolve the view's ``groupBy`` entries to event-table columns.

        A missing or empty ``groupBy`` yields no grouping columns instead of
        failing while iterating ``None``.
        """
        return [getattr(self.event_tbl.c, item['columnName'])
                for item in self.event_view.get('groupBy') or []]

    def _get_zone_time(self):
        """Return the hour offset added to event times (default 8)."""
        return int(self.event_view.get('zone_time', 8))

    def _get_date_range(self) -> Tuple[str, str, list]:
        """Return ``(start, end, timestamps)`` for the requested period.

        ``timestamps`` are generated in UTC between start and end, using the
        pandas frequency mapped from the current time granularity.
        """
        start_date: str = self.event_view.get('startTime')
        end_date: str = self.event_view.get('endTime')
        date_range = pd.date_range(start_date, end_date, freq=settings.PROPHET_TIME_GRAIN_MAP[self.time_particle],
                                   tz='UTC').tolist()
        return start_date, end_date, date_range

    def _get_global_filters(self):
        """Return the view-wide filters (``filts``), or an empty list."""
        return self.event_view.get('filts') or []

    async def _load_table(self, redis_key: str, table_name: str) -> sa.Table:
        """Build a column-only ``sa.Table`` from the JSON object stored at ``redis_key``.

        Only the keys of the stored JSON object are used, as column names.
        """
        res_json = await self.rdb.get(redis_key)
        columns = json.loads(res_json).keys()
        metadata = sa.MetaData(schema=self.game)
        return sa.Table(table_name, metadata, *[sa.Column(column) for column in columns])

    async def _init_table(self):
        """Fetch the table columns from Redis and build the table structures.

        :return: None; sets ``self.user_tbl`` and ``self.event_tbl``.
        """
        self.user_tbl = await self._load_table(f'{self.game}_user', 'user_view')
        self.event_tbl = await self._load_table(f'{self.game}_event', 'event')

    def handler_filts(self, *ext_filters, g_f=True):
        """Translate filter dicts into SQLAlchemy conditions.

        Each filter dict has ``tableType`` ('user' or 'event'; anything else is
        skipped), ``columnName``, ``comparator`` and, for value comparators,
        ``ftv`` (a list of values).

        :param ext_filters: extra filters applied in addition to the global ones.
        :param g_f: when True, the view's global filters are included first.
        :return: ``(event_filter, user_filter)`` lists of conditions.
        """
        user_filter = []
        event_filter = []
        filters = (*self.global_filters, *ext_filters) if g_f else (*ext_filters,)
        for item in filters:
            table_type = item['tableType']
            if table_type == 'user':
                where = user_filter
            elif table_type == 'event':
                where = event_filter
            else:
                continue
            tbl = getattr(self, f'{table_type}_tbl')
            col = getattr(tbl.c, item['columnName'])
            comparator = item['comparator']
            # null checks carry no value, so 'ftv' may legitimately be absent
            ftv = item.get('ftv') or []

            if comparator == '==':
                if len(ftv) > 1:
                    where.append(or_(*[col == v for v in ftv]))
                else:
                    where.append(col == ftv[0])
            elif comparator in self._BINARY_COMPARATORS:
                where.append(self._BINARY_COMPARATORS[comparator](col, ftv[0]))
            elif comparator == 'is not null':
                where.append(col.isnot(None))
            elif comparator == 'is null':
                where.append(col.is_(None))
            # NOTE(review): unknown comparators are silently ignored (unchanged behaviour).
        return event_filter, user_filter

    def retention_model_sql(self):
        """Build the retention query for the first two events of the request.

        Per day (event time shifted by ``zone_time``), event name and group-by
        column, it collects the distinct account ids (``values``) and their
        count (``amount``).

        :return: dict with the SQL and the metadata needed to shape the result.
        """
        event_name_a = self.events[0]['eventName']
        event_name_b = self.events[1]['eventName']

        event_time_col = getattr(self.event_tbl.c, '#event_time')
        event_name_col = getattr(self.event_tbl.c, '#event_name')
        e_account_id_col = getattr(self.event_tbl.c, '#account_id')
        u_account_id_col = getattr(self.user_tbl.c, '#account_id')
        date_col = sa.Column('date')

        selectd = [
            func.toStartOfDay(func.addHours(event_time_col, self.zone_time)).label('date'),
            event_name_col.label('event_name'),
            *self.groupby,
            func.arrayDistinct(func.groupArray(e_account_id_col)).label('values'),
            # refers to the 'values' alias of the same SELECT (resolved by the database)
            func.length(sa.Column('values')).label('amount'),
        ]

        base_where = [
            func.addHours(event_time_col, self.zone_time) >= self.start_date,
            func.addHours(event_time_col, self.zone_time) <= self.end_date,
            event_name_col.in_([event_name_a, event_name_b]),
        ]

        event_filter, user_filter = self.handler_filts()

        groupby = [date_col, event_name_col] + self.groupby
        orderby = [date_col]

        qry = sa.select(*selectd)
        if user_filter:
            # user-table filters require joining the user table on account id
            qry = qry.select_from(
                self.event_tbl.join(self.user_tbl, u_account_id_col == e_account_id_col)).where(
                and_(*user_filter, *event_filter, *base_where))
        else:
            qry = qry.where(and_(*base_where, *event_filter))
        qry = qry.group_by(*groupby).order_by(*orderby).limit(10000)

        sql = str(qry.compile(compile_kwargs={"literal_binds": True}))
        print(sql)
        return {'sql': sql,
                'groupby': ['date', 'event_name'] + [i.key for i in self.groupby],
                'date_range': self.date_range,
                'event_name': [event_name_a, event_name_b],
                'unit_num': self.unit_num
                }

    def _event_value_expr(self, analysis, event, e_account_id_col):
        """Return the aggregate expression, labelled ``values``, for one event metric.

        :param analysis: metric name; unrecognised names are used directly as the
            SQL function applied to ``event['event_attr_id']`` (rounded to 2 places).
        :param event: the event dict (``event_attr_id`` is read for attribute metrics).
        :param e_account_id_col: the event table's account-id column.
        """
        if analysis == 'total_count':
            expr = func.count()
        elif analysis == 'touch_user_count':
            expr = func.count(sa.distinct(e_account_id_col))
        elif analysis == 'touch_user_avg':
            expr = func.round(func.count() / func.count(sa.distinct(e_account_id_col)), 2)
        elif analysis == 'distinct_count':
            expr = func.count(sa.distinct(getattr(self.event_tbl.c, event['event_attr_id'])))
        else:
            # NOTE(review): `analysis` comes from the request and becomes an SQL
            # function name; consider whitelisting the allowed aggregates.
            expr = func.round(getattr(func, analysis)(getattr(self.event_tbl.c, event['event_attr_id'])), 2)
        return expr.label('values')

    def event_model_sql(self):
        """Build one aggregation query per requested event.

        Each query groups by the time-grain expression and the group-by columns,
        and applies the event's own filters plus the global ones.

        :return: list of dicts, one per event, with the SQL and result metadata.
        """
        sqls = []
        event_time_col = getattr(self.event_tbl.c, '#event_time')
        event_name_col = getattr(self.event_tbl.c, '#event_name')
        e_account_id_col = getattr(self.event_tbl.c, '#account_id')
        u_account_id_col = getattr(self.user_tbl.c, '#account_id')

        select_exprs = [
            settings.TIME_GRAIN_EXPRESSIONS[self.time_particle](event_time_col, self.zone_time),
            *self.groupby,
        ]

        for event in self.events:
            event_name = event['event_name']
            base_where = [
                func.addHours(event_time_col, self.zone_time) >= self.start_date,
                func.addHours(event_time_col, self.zone_time) <= self.end_date,
                event_name_col == event_name
            ]
            analysis = event['analysis']
            event_filter, user_filter = self.handler_filts(*event['filts'])

            selectd = select_exprs + [self._event_value_expr(analysis, event, e_account_id_col)]

            qry = sa.select(*selectd)
            if user_filter:
                # user-table filters require joining the user table on account id
                qry = qry.select_from(
                    self.event_tbl.join(self.user_tbl, u_account_id_col == e_account_id_col)).where(
                    and_(*user_filter, *event_filter, *base_where))
            else:
                qry = qry.where(and_(*event_filter, *base_where))

            qry = qry.group_by(*select_exprs).order_by(sa.Column('date')).limit(1000)

            sql = str(qry.compile(compile_kwargs={"literal_binds": True}))
            print(sql)
            sqls.append({'sql': sql,
                         'groupby': [i.key for i in self.groupby],
                         'date_range': self.date_range,
                         'event_name': event_name
                         })

        return sqls

    def funnel_model_sql(self):
        """Build a funnel query using ClickHouse ``windowFunnel``.

        Example of the generated SQL::

            SELECT level, count(*) AS values
            FROM (SELECT windowFunnel(86400)(shjy.event."#event_time", shjy.event."#event_name" = 'create_role',
                                             shjy.event."#event_name" = 'login') AS level
                  FROM shjy.event
                  WHERE addHours(shjy.event."#event_time", 8) >= '2021-05-16 00:00:00'
                    AND addHours(shjy.event."#event_time", 8) <= '2021-06-14 23:59:59'
                  GROUP BY shjy.event."#account_id") AS anon_1
            GROUP BY level
            ORDER BY level

        ``windows_gap`` in the view is given in days and converted to seconds.

        :return: dict with the SQL and result metadata.
        """
        # The window is spliced into the SQL text after compilation, so force it
        # to an integer number of seconds (a str or float must not leak through).
        windows_gap = round(float(self.event_view['windows_gap']) * 86400)

        event_time_col = getattr(self.event_tbl.c, '#event_time')
        event_name_col = getattr(self.event_tbl.c, '#event_name')
        e_account_id_col = getattr(self.event_tbl.c, '#account_id')
        sub_group = [*self.groupby, e_account_id_col]

        # one funnel step per event: name match plus that step's own event filters
        conds = []
        for item in self.events:
            event_filter, _ = self.handler_filts(*item['filts'], g_f=False)
            conds.append(and_(event_name_col == item['eventName'], *event_filter))

        def group_cols():
            # unbound columns render as bare names, resolving against the subquery
            return [sa.Column(i.key) for i in self.groupby]

        # SQLAlchemy strips one trailing '_' from func attribute names, giving
        # `windowFunnel_windows_gap_(...)`; the placeholder is replaced below with
        # the parametric argument, producing `windowFunnel(<seconds>)(...)`.
        subq = sa.select(*group_cols(),
                         func.windowFunnel_windows_gap__(event_time_col, *conds).label('level')).select_from(
            self.event_tbl)

        # NOTE(review): only the event-table part of the global filters is applied here.
        g_event_filter, _ = self.handler_filts()
        where = [
            func.addHours(event_time_col, self.zone_time) >= self.start_date,
            func.addHours(event_time_col, self.zone_time) <= self.end_date,
            *g_event_filter
        ]
        subq = subq.where(and_(*where)).group_by(*sub_group).subquery()

        level = sa.Column('level')
        qry = (sa.select(*group_cols(), level, func.count().label('values'))
               .select_from(subq)
               .where(level > 0)
               .group_by(*group_cols(), level)
               .order_by(*group_cols(), level))

        sql = str(qry.compile(compile_kwargs={"literal_binds": True}))
        sql = sql.replace('_windows_gap_', f'({windows_gap})')
        print(sql)
        return {'sql': sql,
                'groupby': [i.key for i in self.groupby],
                'date_range': self.date_range,
                }