xbackend/models/behavior_analysis.py
2021-06-11 17:02:42 +08:00

271 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import Tuple
import sqlalchemy as sa
import json
from fastapi import Depends
import pandas as pd
from sqlalchemy import func, or_, and_, not_
import schemas
from core.config import settings
from db.redisdb import get_redis_pool, RedisDrive
class BehaviorAnalysis:
    """Translate a behaviour-analysis request into SQL text.

    The table definitions are rebuilt from column metadata cached in Redis
    (see ``_init_table``), so ``init`` must be awaited before any of the
    ``*_model_sql`` methods is called.  The generated SQL relies on
    ClickHouse functions such as ``addHours``, ``groupArray`` and
    ``windowFunnel``.
    """

    def __init__(self, game: str, data_in: schemas.CkQuery, rdb: RedisDrive = Depends(get_redis_pool)):
        """Store the request; derived settings stay unset until ``init`` runs.

        :param game: game identifier; used as the table schema and as the
            prefix of the Redis keys that hold the column metadata.
        :param data_in: query payload; its ``eventView`` and ``events`` are read.
        :param rdb: Redis client used to load the column metadata.
        """
        self.game = game
        self.rdb = rdb
        self.user_tbl = None
        self.event_tbl = None
        self.event_view = data_in.eventView
        self.events = data_in.events
        self.zone_time: int = 0
        self.start_date = None
        self.end_date = None
        self.global_filters = None
        self.groupby = None
        self.time_particle = None
        self.date_range = None
        self.unit_num = None

    async def init(self):
        """Load the table definitions and derive every setting from ``event_view``."""
        await self._init_table()
        self.zone_time = self._get_zone_time()
        # The time grain has to be resolved first: the date range depends on it.
        self.time_particle = self._get_time_particle_size()
        self.start_date, self.end_date, self.date_range = self._get_date_range()
        self.global_filters = self._get_global_filters()
        self.groupby = self._get_group_by()
        self.unit_num = self._get_unit_num()

    def _get_time_particle_size(self):
        """Return ``timeParticleSize`` from the view, defaulting to ``'P1D'``."""
        return self.event_view.get('timeParticleSize') or 'P1D'

    def _get_unit_num(self):
        """Return ``unitNum`` from the view (``None`` when absent)."""
        return self.event_view.get('unitNum')

    def _get_group_by(self):
        """Return the event-table columns named by ``groupBy``.

        A missing or empty ``groupBy`` yields an empty list instead of
        raising ``TypeError``.
        """
        return [getattr(self.event_tbl.c, item['columnName']) for item in self.event_view.get('groupBy') or []]

    def _get_zone_time(self):
        """Return the hour offset applied to event times (``zone_time``, default 8)."""
        return int(self.event_view.get('zone_time', 8))

    def _get_date_range(self) -> Tuple[str, str, list]:
        """Return ``(startTime, endTime, date_range)``.

        ``date_range`` is the list of UTC timestamps between the two bounds
        at the frequency mapped from the current time grain.
        """
        start_date: str = self.event_view.get('startTime')
        end_date: str = self.event_view.get('endTime')
        freq = settings.PROPHET_TIME_GRAIN_MAP[self.time_particle]
        date_range = pd.date_range(start_date, end_date, freq=freq, tz='UTC').tolist()
        return start_date, end_date, date_range

    def _get_global_filters(self):
        """Return the view-level filters (``filts``), or an empty list."""
        return self.event_view.get('filts') or []

    async def _load_table(self, key_suffix: str, table_name: str) -> sa.Table:
        """Build one table from the JSON object stored at ``<game>_<key_suffix>``.

        Only the keys of the JSON object are used; each becomes an untyped column.
        """
        res_json = await self.rdb.get(f'{self.game}_{key_suffix}')
        columns = json.loads(res_json).keys()
        metadata = sa.MetaData(schema=self.game)
        return sa.Table(table_name, metadata, *[sa.Column(column) for column in columns])

    async def _init_table(self):
        """
        Build the table structures from the column fields stored in Redis.
        :return:
        """
        self.user_tbl = await self._load_table('user', 'user_view')
        self.event_tbl = await self._load_table('event', 'event')

    def handler_filts(self, *ext_filters, g_f=True):
        """Convert filter dicts into SQLAlchemy conditions.

        Each filter is a mapping with ``tableType`` (``'user'`` or ``'event'``;
        anything else is skipped), ``columnName``, ``comparator`` and ``ftv``
        (the list of comparison values).  ``==`` with several values becomes
        an OR over all of them; the other binary comparators use only the
        first value.  Unknown comparators are ignored.

        :param ext_filters: extra filters appended after the global ones.
        :param g_f: when false, the global filters are left out.
        :return: ``(event_filter, user_filter)`` - note the order.
        """
        user_filter = []
        event_filter = []
        targets = {'user': user_filter, 'event': event_filter}
        filters = (*self.global_filters, *ext_filters) if g_f else ext_filters

        for item in filters:
            where = targets.get(item['tableType'])
            if where is None:
                continue

            tbl = getattr(self, f'{item["tableType"]}_tbl')
            col = getattr(tbl.c, item['columnName'])
            comparator = item['comparator']
            # The null checks carry no values, so ``ftv`` is optional.
            ftv = item.get('ftv')

            if comparator == '==':
                if len(ftv) > 1:
                    where.append(or_(*[col == v for v in ftv]))
                else:
                    where.append(col == ftv[0])
            elif comparator == '>=':
                where.append(col >= ftv[0])
            elif comparator == '<=':
                where.append(col <= ftv[0])
            elif comparator == '>':
                where.append(col > ftv[0])
            elif comparator == '<':
                where.append(col < ftv[0])
            elif comparator == 'is not null':
                where.append(col.isnot(None))
            elif comparator == 'is null':
                where.append(col.is_(None))
            elif comparator == '!=':
                where.append(col != ftv[0])

        return event_filter, user_filter

    def retention_model_sql(self):
        """Build the retention query for the first two entries of ``events``.

        Per day (shifted by ``zone_time`` hours), event name and group-by
        column, the query collects the distinct account ids (``values``) and
        their number (``amount``).

        :return: dict with the SQL text and the metadata needed to read the
            result (``groupby``, ``date_range``, ``event_name``, ``unit_num``).
        """
        event_name_a = self.events[0]['eventName']
        event_name_b = self.events[1]['eventName']
        event_time_col = getattr(self.event_tbl.c, '#event_time')
        event_name_col = getattr(self.event_tbl.c, '#event_name')
        e_account_id_col = getattr(self.event_tbl.c, '#account_id')
        u_account_id_col = getattr(self.user_tbl.c, '#account_id')
        date_col = sa.Column('date')
        shifted_time = func.addHours(event_time_col, self.zone_time)

        select_cols = [
            func.toStartOfDay(shifted_time).label('date'),
            event_name_col.label('event_name'),
            *self.groupby,
            func.arrayDistinct(func.groupArray(e_account_id_col)).label('values'),
            # Refers to the ``values`` alias defined on the previous line.
            func.length(sa.Column('values')).label('amount'),
        ]
        base_where = [
            shifted_time >= self.start_date,
            shifted_time <= self.end_date,
            event_name_col.in_([event_name_a, event_name_b]),
        ]
        event_filter, user_filter = self.handler_filts()

        qry = sa.select(select_cols)
        if user_filter:
            # User attributes are only reachable through a join on the account id.
            joined = self.event_tbl.join(self.user_tbl, u_account_id_col == e_account_id_col)
            qry = qry.select_from(joined).where(and_(*user_filter, *event_filter, *base_where))
        else:
            qry = qry.where(and_(*base_where, *event_filter))
        qry = qry.group_by(date_col, event_name_col, *self.groupby).order_by(date_col).limit(10000)

        sql = str(qry.compile(compile_kwargs={"literal_binds": True}))
        print(sql)
        return {'sql': sql,
                'groupby': ['date', 'event_name'] + [i.key for i in self.groupby],
                'date_range': self.date_range,
                'event_name': [event_name_a, event_name_b],
                'unit_num': self.unit_num
                }

    def event_model_sql(self):
        """Build one aggregation query per entry of ``events``.

        ``analysis`` selects the aggregate: ``total_count``,
        ``touch_user_count``, ``touch_user_avg`` and ``distinct_count`` are
        handled explicitly; any other value is used as the name of the SQL
        function applied to the ``event_attr_id`` column.

        :return: list of dicts with ``sql``, ``groupby``, ``date_range`` and
            ``event_name``.
        """
        sqls = []
        event_time_col = getattr(self.event_tbl.c, '#event_time')
        event_name_col = getattr(self.event_tbl.c, '#event_name')
        u_account_id_col = getattr(self.user_tbl.c, '#account_id')
        # Aggregate by account.
        e_account_id_col = getattr(self.event_tbl.c, '#account_id')
        shifted_time = func.addHours(event_time_col, self.zone_time)

        select_exprs = [
            settings.TIME_GRAIN_EXPRESSIONS[self.time_particle](event_time_col, self.zone_time),
            *self.groupby,
        ]

        for event in self.events:
            # NOTE(review): this model reads ``event_name`` while the retention
            # and funnel models read ``eventName`` - confirm the payload key.
            event_name = event['event_name']
            base_where = [
                shifted_time >= self.start_date,
                shifted_time <= self.end_date,
                event_name_col == event_name
            ]
            analysis = event['analysis']
            event_filter, user_filter = self.handler_filts(*(event.get('filts') or []))

            # Aggregation mode.
            if analysis == 'total_count':
                value_expr = func.count()
            elif analysis == 'touch_user_count':
                value_expr = func.count(sa.distinct(e_account_id_col))
            elif analysis == 'touch_user_avg':
                value_expr = func.round((func.count() / func.count(sa.distinct(e_account_id_col))), 2)
            elif analysis == 'distinct_count':
                value_expr = func.count(sa.distinct(getattr(self.event_tbl.c, event['event_attr_id'])))
            else:
                # NOTE(review): ``analysis`` comes from the request and becomes
                # the SQL function name verbatim - consider a whitelist.
                value_expr = func.round(
                    getattr(func, analysis)(getattr(self.event_tbl.c, event['event_attr_id'])), 2)
            selectd = select_exprs + [value_expr.label('values')]

            if user_filter:
                joined = self.event_tbl.join(self.user_tbl, u_account_id_col == e_account_id_col)
                qry = sa.select(selectd).select_from(joined).where(
                    and_(*user_filter, *event_filter, *base_where))
            else:
                qry = sa.select(selectd).where(and_(*event_filter, *base_where))

            qry = qry.group_by(*select_exprs)
            qry = qry.order_by(sa.Column('date'))
            qry = qry.limit(1000)

            sql = str(qry.compile(compile_kwargs={"literal_binds": True}))
            print(sql)
            sqls.append({'sql': sql,
                         'groupby': [i.key for i in self.groupby],
                         'date_range': self.date_range,
                         'event_name': event_name
                         })

        return sqls

    def funnel_model_sql(self):
        """
        Build the funnel query: how many accounts reached each funnel level.

        Target shape of the generated SQL:

        SELECT
            level,
            count() AS values
        FROM
            (SELECT `#account_id`,
                    windowFunnel(864000)(`#event_time`, `#event_name` = 'create_role',`#event_name` = 'login') AS level
             FROM event
             WHERE (`#event_time` >= '2021-06-01 00:00:00')
               AND (`#event_time` <= '2021-06-05 00:00:00')
             GROUP BY `#account_id`)
        GROUP BY level
        ORDER BY level
        :return: the SQL text of the complete (outer) funnel query.
        """
        # ``windows_gap`` is multiplied by the number of seconds in a day.
        windows_gap = self.event_view['windows_gap'] * 86400
        event_time_col = getattr(self.event_tbl.c, '#event_time')
        event_name_col = getattr(self.event_tbl.c, '#event_name')
        e_account_id_col = getattr(self.event_tbl.c, '#account_id')

        # One funnel step per event: the event name plus that event's own filters.
        conds = []
        for item in self.events:
            event_filter, _ = self.handler_filts(*(item.get('filts') or []), g_f=False)
            conds.append(
                and_(event_name_col == item['eventName'], *event_filter)
            )

        # todo replace _windows_gap_
        # ``windowFunnel`` is a parametric function, ``windowFunnel(N)(...)``,
        # which ``func`` cannot spell directly.  A placeholder is embedded in
        # the function name and patched in the compiled SQL below.  ``func``
        # strips one trailing underscore, hence the doubled one here.
        # ``label`` (not ``alias``) is what renders the column alias ``AS level``.
        level_expr = func.windowFunnel_windows_gap__(event_time_col, *conds).label('level')

        g_event_filter, _ = self.handler_filts()
        shifted_time = func.addHours(event_time_col, self.zone_time)
        where = [
            shifted_time >= self.start_date,
            shifted_time <= self.end_date,
            *g_event_filter
        ]
        subq = sa.select(e_account_id_col, level_expr).where(and_(*where)).group_by(e_account_id_col)
        subq = subq.subquery()

        level_col = sa.Column('level')
        qry = sa.select(level_col, func.count().label('values')).select_from(subq)
        qry = qry.group_by(level_col).order_by(level_col)

        # Compile the outer query; compiling ``subq`` would drop the per-level counts.
        sql = str(qry.compile(compile_kwargs={"literal_binds": True}))
        sql = sql.replace('_windows_gap_', f'({windows_gap})')
        print(sql)
        return sql