338 lines
14 KiB
Python
338 lines
14 KiB
Python
from typing import Tuple
|
||
|
||
import sqlalchemy as sa
|
||
import json
|
||
|
||
from fastapi import Depends
|
||
|
||
import pandas as pd
|
||
|
||
from sqlalchemy import func, or_, and_, not_
|
||
|
||
import schemas
|
||
from core.config import settings
|
||
from db.redisdb import get_redis_pool, RedisDrive
|
||
|
||
|
||
class BehaviorAnalysis:
    """Build the SQL statements used by the behaviour-analysis models of a game.

    The instance is configured from a ``schemas.CkQuery`` payload (its
    ``eventView`` and ``events`` attributes).  Construction is cheap; the
    table metadata is loaded from redis by the coroutine :meth:`init`, which
    must be awaited before any ``*_model_sql`` method is called.

    Every ``*_model_sql`` method compiles its query with literal binds,
    prints the SQL text and returns it together with the metadata the caller
    needs to post-process the result set.
    """

    def __init__(self, game: str, data_in: schemas.CkQuery, rdb: RedisDrive = Depends(get_redis_pool)):
        """Store the request payload; derived settings are filled in by ``init``.

        :param game: game identifier; used as the redis key prefix and as the
            SQL schema name of the generated tables.
        :param data_in: query payload providing ``eventView`` and ``events``.
        :param rdb: redis driver used to read the table column definitions.
        """
        self.game = game
        self.rdb = rdb
        # Table objects are created by ``_init_table`` because that needs ``await``.
        self.user_tbl = None
        self.event_tbl = None
        self.event_view = data_in.eventView
        self.events = data_in.events

        # Derived query settings, populated by ``init``.
        self.zone_time: int = 0
        self.start_date = None
        self.end_date = None
        self.global_filters = None
        self.groupby = None
        self.time_particle = None
        self.date_range = None
        self.unit_num = None

    async def init(self):
        """Load the table metadata and derive all query settings.

        The order matters: ``_get_date_range`` reads ``self.time_particle``
        and ``_get_group_by`` reads ``self.event_tbl``.
        """
        await self._init_table()
        self.zone_time = self._get_zone_time()
        self.time_particle = self._get_time_particle_size()
        self.start_date, self.end_date, self.date_range = self._get_date_range()
        self.global_filters = self._get_global_filters()
        self.groupby = self._get_group_by()
        self.unit_num = self._get_unit_num()

    def _get_time_particle_size(self):
        """Return ``eventView['timeParticleSize']``, defaulting to ``'P1D'``."""
        return self.event_view.get('timeParticleSize') or 'P1D'

    def _get_unit_num(self):
        """Return ``eventView['unitNum']`` (``None`` when absent)."""
        return self.event_view.get('unitNum')

    def _get_group_by(self):
        """Return the event-table columns named by ``eventView['groupBy']``.

        A missing or empty ``groupBy`` yields an empty list instead of
        failing while iterating ``None``.
        """
        group_items = self.event_view.get('groupBy') or []
        return [getattr(self.event_tbl.c, item['columnName']) for item in group_items]

    def _get_zone_time(self):
        """Return the hour offset ``eventView['zone_time']`` as int (default 8)."""
        return int(self.event_view.get('zone_time', 8))

    def _get_date_range(self) -> Tuple[str, str, list]:
        """Return ``(startTime, endTime, date_range)`` from the event view.

        ``date_range`` is the list of UTC timestamps between both bounds at
        the frequency ``settings.PROPHET_TIME_GRAIN_MAP`` maps the current
        time particle to.  Requires ``self.time_particle`` to be set.
        """
        start_date: str = self.event_view.get('startTime')
        end_date: str = self.event_view.get('endTime')
        freq = settings.PROPHET_TIME_GRAIN_MAP[self.time_particle]
        date_range = pd.date_range(start_date, end_date, freq=freq, tz='UTC').tolist()

        return start_date, end_date, date_range

    def _get_global_filters(self):
        """Return ``eventView['filts']`` or an empty list."""
        return self.event_view.get('filts') or []

    async def _init_table(self):
        """Fetch the table columns from redis and build the table structures.

        Reads the JSON documents stored under ``<game>_user`` and
        ``<game>_event``; the keys of each document become the column names
        of ``user_view`` and ``event`` respectively, both in schema ``game``.

        :return: None
        """
        res_json = await self.rdb.get(f'{self.game}_user')
        columns = json.loads(res_json).keys()
        metadata = sa.MetaData(schema=self.game)
        self.user_tbl = sa.Table('user_view', metadata, *[sa.Column(column) for column in columns])

        res_json = await self.rdb.get(f'{self.game}_event')
        columns = json.loads(res_json).keys()
        metadata = sa.MetaData(schema=self.game)
        self.event_tbl = sa.Table('event', metadata, *[sa.Column(column) for column in columns])

    def handler_filts(self, *ext_filters, g_f=True):
        """Translate filter dicts into SQLAlchemy conditions.

        Each filter is a dict with the keys ``tableType`` (``'user'`` or
        ``'event'``), ``columnName``, ``comparator`` and ``ftv`` (list of
        comparison values).  Filters with another ``tableType`` or an
        unknown ``comparator`` are ignored.

        :param ext_filters: extra filter dicts.  A list/tuple of filter dicts
            passed as a single argument is flattened, so both
            ``handler_filts(*filts)`` and ``handler_filts(filts)`` work.
        :param g_f: when true, the global filters of the event view are
            applied before ``ext_filters``.
        :return: tuple ``(event_filter, user_filter)`` of condition lists.
        """
        user_filter = []
        event_filter = []

        candidates = (*(self.global_filters or ()), *ext_filters) if g_f else ext_filters
        # Flatten one level: callers historically passed a whole list as a
        # single positional argument, which made ``item['tableType']`` fail.
        filters = []
        for entry in candidates:
            if isinstance(entry, (list, tuple)):
                filters.extend(entry)
            else:
                filters.append(entry)

        for item in filters:
            if item['tableType'] == 'user':
                where = user_filter
            elif item['tableType'] == 'event':
                where = event_filter
            else:
                continue

            tbl = getattr(self, f'{item["tableType"]}_tbl')
            col = getattr(tbl.c, item['columnName'])

            comparator = item['comparator']
            ftv = item['ftv']
            if comparator == '==':
                # Several values mean "equals any of them".
                if len(ftv) > 1:
                    where.append(or_(*[col == v for v in ftv]))
                else:
                    where.append(col == ftv[0])
            elif comparator == '>=':
                where.append(col >= ftv[0])
            elif comparator == '<=':
                where.append(col <= ftv[0])
            elif comparator == '>':
                where.append(col > ftv[0])
            elif comparator == '<':
                where.append(col < ftv[0])
            elif comparator == 'is not null':
                where.append(col.isnot(None))
            elif comparator == 'is null':
                where.append(col.is_(None))
            elif comparator == '!=':
                where.append(col != ftv[0])

        return event_filter, user_filter

    def retention_model_sql(self):
        """Build the retention query for the first two configured events.

        Per day (shifted by ``zone_time`` hours), event name and group-by
        column the query collects the distinct account ids (``values``) and
        their number (``amount``).  The user table is joined only when a
        global filter targets it.

        :return: dict with the keys ``sql``, ``groupby``, ``date_range``,
            ``event_name`` and ``unit_num``.
        """
        event_name_a = self.events[0]['eventName']
        event_name_b = self.events[1]['eventName']
        event_time_col = getattr(self.event_tbl.c, '#event_time')
        event_name_col = getattr(self.event_tbl.c, '#event_name')
        e_account_id_col = getattr(self.event_tbl.c, '#account_id')
        u_account_id_col = getattr(self.user_tbl.c, '#account_id')
        date_col = sa.Column('date')

        local_time = func.addHours(event_time_col, self.zone_time)
        selected = [
            func.toStartOfDay(local_time).label('date'),
            event_name_col.label('event_name'),
            *self.groupby,
            func.arrayDistinct(func.groupArray(e_account_id_col)).label('values'),
            # Refers to the ``values`` alias defined on the previous line.
            func.length(sa.Column('values')).label('amount'),
        ]
        base_where = [
            func.addHours(event_time_col, self.zone_time) >= self.start_date,
            func.addHours(event_time_col, self.zone_time) <= self.end_date,
            event_name_col.in_([event_name_a, event_name_b]),
        ]

        event_filter, user_filter = self.handler_filts()

        groupby = [date_col, event_name_col] + self.groupby
        orderby = [date_col]
        if user_filter:
            # User-level filters need the user view joined on the account id.
            qry = sa.select(selected).select_from(
                self.event_tbl.join(self.user_tbl, u_account_id_col == e_account_id_col)).where(
                and_(*user_filter, *event_filter, *base_where)).group_by(*groupby).order_by(
                *orderby).limit(10000)
        else:
            qry = sa.select(selected).where(and_(*base_where, *event_filter)).group_by(*groupby).order_by(
                *orderby).limit(10000)
        sql = str(qry.compile(compile_kwargs={"literal_binds": True}))
        print(sql)
        return {'sql': sql,
                'groupby': ['date', 'event_name'] + [i.key for i in self.groupby],
                'date_range': self.date_range,
                'event_name': [event_name_a, event_name_b],
                'unit_num': self.unit_num
                }

    def event_model_sql(self):
        """Build one aggregate query per configured event.

        The ``analysis`` value of an event selects the aggregate:
        ``total_count``, ``touch_user_count``, ``touch_user_avg`` and
        ``distinct_count`` are handled explicitly; any other value is used as
        the name of an SQL function applied to the ``event_attr_id`` column
        and rounded to two decimals.

        NOTE(review): this method reads ``event['event_name']`` while the
        other models read ``event['eventName']`` - confirm the payload key.

        :return: list of dicts with the keys ``sql``, ``groupby``,
            ``date_range`` and ``event_name`` (one per event).
        """
        sqls = []
        event_time_col = getattr(self.event_tbl.c, '#event_time')
        event_name_col = getattr(self.event_tbl.c, '#event_name')
        # Account id columns: join key and per-account aggregation.
        u_account_id_col = getattr(self.user_tbl.c, '#account_id')
        e_account_id_col = getattr(self.event_tbl.c, '#account_id')

        select_exprs = [
            settings.TIME_GRAIN_EXPRESSIONS[self.time_particle](event_time_col, self.zone_time),
            *self.groupby,
        ]

        for event in self.events:
            event_name = event['event_name']
            base_where = [
                func.addHours(event_time_col, self.zone_time) >= self.start_date,
                func.addHours(event_time_col, self.zone_time) <= self.end_date,
                event_name_col == event_name
            ]
            analysis = event['analysis']
            event_filter, user_filter = self.handler_filts(*event['filts'])

            # Aggregation mode.
            if analysis == 'total_count':
                values_expr = func.count()
            elif analysis == 'touch_user_count':
                values_expr = func.count(sa.distinct(e_account_id_col))
            elif analysis == 'touch_user_avg':
                values_expr = func.round((func.count() / func.count(sa.distinct(e_account_id_col))), 2)
            elif analysis == 'distinct_count':
                values_expr = func.count(sa.distinct(getattr(self.event_tbl.c, event['event_attr_id'])))
            else:
                values_expr = func.round(
                    getattr(func, analysis)(getattr(self.event_tbl.c, event['event_attr_id'])), 2)
            selected = select_exprs + [values_expr.label('values')]

            if user_filter:
                # User-level filters need the user view joined on the account id.
                qry = sa.select(selected).select_from(
                    self.event_tbl.join(self.user_tbl, u_account_id_col == e_account_id_col)).where(
                    and_(*user_filter, *event_filter, *base_where))
            else:
                qry = sa.select(selected).where(and_(*event_filter, *base_where))

            qry = qry.group_by(*select_exprs)
            # Presumably the time-grain expression is labelled ``date`` - confirm in settings.
            qry = qry.order_by(sa.Column('date'))
            qry = qry.limit(1000)

            sql = str(qry.compile(compile_kwargs={"literal_binds": True}))
            print(sql)
            sqls.append({'sql': sql,
                         'groupby': [i.key for i in self.groupby],
                         'date_range': self.date_range,
                         'event_name': event_name
                         })

        return sqls

    def funnel_model_sql(self):
        """Build the funnel query over the configured event sequence.

        Shape of the generated statement (without group-by columns)::

            SELECT level, count(*) AS values
            FROM (SELECT windowFunnel(86400)(shjy.event."#event_time", shjy.event."#event_name" = 'create_role',
                                             shjy.event."#event_name" = 'login') AS level
                  FROM shjy.event
                  WHERE addHours(shjy.event."#event_time", 8) >= '2021-05-16 00:00:00'
                    AND addHours(shjy.event."#event_time", 8) <= '2021-06-14 23:59:59'
                  GROUP BY shjy.event."#account_id") AS anon_1
            GROUP BY level
            ORDER BY level

        The window is ``eventView['windows_gap']`` days expressed in seconds.
        Per-event filters become part of the funnel step conditions, global
        filters go into the ``WHERE`` clause of the inner query.

        :return: dict with the keys ``sql``, ``groupby``, ``date_range`` and
            ``cond_level`` (event names in funnel order).
        """
        windows_gap = self.event_view['windows_gap'] * 86400
        event_time_col = getattr(self.event_tbl.c, '#event_time')
        event_name_col = getattr(self.event_tbl.c, '#event_name')
        date_col = func.toStartOfDay(func.addHours(event_time_col, self.zone_time)).label('date')
        e_account_id_col = getattr(self.event_tbl.c, '#account_id')

        sub_group = [date_col, *self.groupby, e_account_id_col]
        conds = []
        cond_level = []
        for item in self.events:
            # Only the step's own event filters apply here (no global filters).
            event_filter, _ = self.handler_filts(*item['filts'], g_f=False)
            conds.append(
                and_(event_name_col == item['eventName'], *event_filter)
            )
            cond_level.append(item['eventName'])

        # The parametric call ``windowFunnel(<gap>)(...)`` cannot be expressed
        # through ``func`` directly, so a ``_windows_gap_`` placeholder is
        # rendered into the function name and substituted in the SQL text
        # below.  ``func`` drops one trailing underscore from the attribute
        # name, hence the doubled underscore.
        group_cols = [sa.Column(i.key) for i in self.groupby]
        subq = sa.select(*group_cols, date_col,
                         func.windowFunnel_windows_gap__(event_time_col, *conds).label('level')).select_from(
            self.event_tbl)

        g_event_filter, _ = self.handler_filts()
        where = [
            func.addHours(event_time_col, self.zone_time) >= self.start_date,
            func.addHours(event_time_col, self.zone_time) <= self.end_date,
            *g_event_filter
        ]
        subq = subq.where(and_(*where)).group_by(*sub_group)
        subq = subq.subquery()

        qry = sa.select(sa.Column('date'), *[sa.Column(i.key) for i in self.groupby], sa.Column('level'),
                        func.count().label('values')).select_from(subq) \
            .where(sa.Column('level') > 0) \
            .group_by(sa.Column('date'), *[sa.Column(i.key) for i in self.groupby], sa.Column('level')) \
            .order_by(sa.Column('date'), *[sa.Column(i.key) for i in self.groupby], sa.Column('level'))
        sql = str(qry.compile(compile_kwargs={"literal_binds": True}))
        # NOTE: a stricter variant would substitute f"({windows_gap},'strict_increase')".
        sql = sql.replace('_windows_gap_', f"({windows_gap})")
        print(sql)
        return {'sql': sql,
                'groupby': [i.key for i in self.groupby],
                'date_range': self.date_range,
                'cond_level': cond_level
                }

    def scatter_model_sql(self):
        """Build the distribution (scatter) query for the first configured event.

        For the analyses ``times``, ``number_of_days`` and
        ``number_of_hours`` the query counts per time bucket, group-by column
        and account.  Otherwise, when the event names a ``quota`` column,
        ``settings.CK_FUNC[analysis]`` is applied to that column.

        :return: dict with the keys ``sql``, ``interval_type``, ``analysis``,
            ``quota_interval_arr`` and ``groupby``; ``None`` when neither
            case applies.
        """
        event = self.events[0]
        event_name = event['eventName']
        analysis = event['analysis']
        e_account_id_col = getattr(self.event_tbl.c, '#account_id')
        event_name_col = getattr(self.event_tbl.c, '#event_name')
        event_time_col = getattr(self.event_tbl.c, '#event_time').label('date')
        event_date_col = settings.TIME_GRAIN_EXPRESSIONS[self.time_particle](event_time_col, self.zone_time)

        quota_interval_arr = event.get('quotaIntervalArr')

        where = [
            event_date_col >= self.start_date,
            event_date_col <= self.end_date,
            event_name_col == event_name
        ]
        # Unpack the filter list: passing it as a single argument made every
        # non-empty filter list fail inside ``handler_filts``.
        event_filter, _ = self.handler_filts(*event['filts'])
        where.extend(event_filter)

        if analysis in ['times', 'number_of_days', 'number_of_hours']:
            if analysis in ['number_of_days', 'number_of_hours']:
                values_col = func.count(func.distinct(e_account_id_col)).label('values')
            else:
                values_col = func.count().label('values')

            qry = sa.select(event_date_col, *self.groupby, values_col) \
                .where(and_(*where)) \
                .group_by(event_date_col, *self.groupby, e_account_id_col)
        elif event.get('quota'):
            event_attr_col = getattr(self.event_tbl.c, event['quota'])

            qry = sa.select(event_date_col, settings.CK_FUNC[analysis](event_attr_col).label('values')) \
                .where(and_(*where)) \
                .group_by(event_date_col, *self.groupby, e_account_id_col)
        else:
            # Unsupported analysis without a quota column: nothing to build.
            return None

        sql = str(qry.compile(compile_kwargs={"literal_binds": True}))
        print(sql)
        return {
            'sql': sql,
            'interval_type': event['intervalType'],
            'analysis': analysis,
            'quota_interval_arr': quota_interval_arr,
            'groupby': [i.key for i in self.groupby]
        }
|