xbackend/models/behavior_analysis.py
2021-07-28 21:20:08 +08:00

507 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import Tuple
import arrow
import sqlalchemy as sa
import json
from fastapi import Depends
import pandas as pd
from sqlalchemy import func, or_, and_, not_
import crud
import schemas
from core.config import settings
from db import get_database
from db.redisdb import get_redis_pool, RedisDrive
class BehaviorAnalysis:
    """Translates one behaviour-analysis request (event / retention / funnel /
    scatter / trace models) into ClickHouse SQL scoped to a single game's
    schema.

    Construction only stores the raw inputs; the async ``init()`` must be
    awaited afterwards to load table metadata and resolve the query state.
    """

    def __init__(self, game: str, data_in: schemas.CkQuery, rdb: RedisDrive = Depends(get_redis_pool)):
        """Keep the raw request; all derived state is filled in by init()."""
        self.game = game                # game id; doubles as the ClickHouse schema name
        self.rdb = rdb                  # redis handle with cached table column lists
        self.user_tbl = None            # sa.Table for <game>.user_view, built in init()
        self.event_tbl = None           # sa.Table for <game>.event, built in init()
        self.data_in = data_in          # the incoming query payload
        self.event_view = dict()        # view settings: dates, grouping, filters, ...
        self.events = [dict()]          # per-event configuration for the chosen model
        self.zone_time: int = 0         # timezone offset in hours (default resolved later)
        self.start_date = None          # 'YYYY-MM-DD HH:MM:SS' window start
        self.end_date = None            # 'YYYY-MM-DD HH:MM:SS' window end
        self.global_filters = None      # filters applied to every event
        self.groupby = None             # event-table columns to group by
        self.time_particle = None       # time bucket granularity (ISO-8601 duration)
        self.date_range = None          # materialised list of bucket timestamps
        self.unit_num = None            # retention window length (view 'unitNum')
async def init(self):
    """Resolve all query state for this request.

    Either loads a saved report's query definition (when ``report_id`` is
    given) and rebuilds its absolute time window from the relative
    ``recentDay`` setting, or takes the view/events straight from the
    request body. Then builds the table objects and derives the remaining
    attributes. Note the ordering: ``_get_date_range`` reads
    ``self.time_particle``, so it must run after ``_get_time_particle_size``.
    """
    if self.data_in.report_id:
        # Saved report: fetch the stored query definition from the DB.
        db = get_database()
        report = await crud.report.get(db, id=self.data_in.report_id)
        self.event_view = report['query']['eventView']
        self.events = report['query']['events']
        # 'recentDay' is "<a>-<b>" in days-ago form.
        # NOTE(review): the first field feeds endTime and the second feeds
        # startTime — confirm the stored format really is "end-start".
        e_days, s_days = self.event_view['recentDay'].split('-')
        self.event_view['endTime'] = arrow.get().shift(days=-int(e_days)+1).strftime('%Y-%m-%d 23:59:59')
        self.event_view['startTime'] = arrow.get().shift(days=-int(s_days)).strftime('%Y-%m-%d 00:00:00')
    else:
        # Ad-hoc query: use the payload as-is.
        self.event_view = self.data_in.eventView
        self.events = self.data_in.events
    await self._init_table()
    self.zone_time = self._get_zone_time()
    self.time_particle = self._get_time_particle_size()
    self.start_date, self.end_date, self.date_range = self._get_date_range()
    self.global_filters = self._get_global_filters()
    self.groupby = self._get_group_by()
    self.unit_num = self._get_unit_num()
def _get_time_particle_size(self):
return self.event_view.get('timeParticleSize') or 'P1D'
def _get_unit_num(self):
return self.event_view.get('unitNum')
def _get_group_by(self):
return [getattr(self.event_tbl.c, item['columnName']) for item in self.event_view.get('groupBy', [])]
def _get_zone_time(self):
return int(self.event_view.get('zone_time', 8))
def _get_date_range(self) -> Tuple[str, str, list]:
    """Return the query window and its materialised bucket timestamps.

    Reads 'startTime'/'endTime' from the view and expands them into a list
    of UTC timestamps at the frequency mapped from the current time grain.
    """
    view = self.event_view
    start_date: str = view.get('startTime')
    end_date: str = view.get('endTime')
    freq = settings.PROPHET_TIME_GRAIN_MAP[self.time_particle]
    date_range = list(pd.date_range(start_date, end_date, freq=freq, tz='UTC'))
    return start_date, end_date, date_range
def _get_global_filters(self):
return self.event_view.get('filts') or []
async def _init_table(self):
    """Build table structures from the table columns cached in redis.

    Column names for each game are cached under '<game>_user' and
    '<game>_event'; lightweight SQLAlchemy Table objects are built from
    them so the model methods can reference columns symbolically.
    """
    # DRY: both tables are built the same way from a redis-cached JSON dict.
    self.user_tbl = await self._load_table('user_view', f'{self.game}_user')
    self.event_tbl = await self._load_table('event', f'{self.game}_event')

async def _load_table(self, table_name, redis_key):
    """Fetch the cached column map for *redis_key* and build an sa.Table.

    Only column NAMES matter here (the keys of the cached JSON object);
    types are left unspecified since these tables are used purely for SQL
    text generation.
    """
    res_json = await self.rdb.get(redis_key)
    columns = json.loads(res_json).keys()
    metadata = sa.MetaData(schema=self.game)
    return sa.Table(table_name, metadata, *[sa.Column(column) for column in columns])
def handler_filts(self, *ext_filters, g_f=True):
    """Translate filter dicts into SQLAlchemy boolean clauses.

    Each filter dict carries 'tableType' ('user' or 'event'), 'columnName',
    'comparator' and a value list 'ftv'. Global filters are included when
    ``g_f`` is true. Filters with an unrecognised tableType or comparator
    are silently skipped. Returns ``(event_filter, user_filter)``.
    """
    user_filter = []
    event_filter = []
    # Route each filter to the list matching its tableType.
    buckets = {'user': user_filter, 'event': event_filter}
    if g_f:
        filters = (*self.global_filters, *ext_filters)
    else:
        filters = (*ext_filters,)
    # A single literal empty list means "no filters".
    if filters == ([],):
        filters = []
    for item in filters:
        bucket = buckets.get(item['tableType'])
        if bucket is None:
            continue
        tbl = getattr(self, f'{item["tableType"]}_tbl')
        col = getattr(tbl.c, item['columnName'])
        comparator = item['comparator']
        ftv = item['ftv']
        if comparator == '==':
            # Several values are OR-ed together.
            bucket.append(or_(*[col == v for v in ftv]) if len(ftv) > 1 else col == ftv[0])
        elif comparator == '!=':
            bucket.append(col != ftv[0])
        elif comparator == 'is not null':
            bucket.append(col.isnot(None))
        elif comparator == 'is null':
            bucket.append(col.is_(None))
        elif comparator == '>=':
            bucket.append(col >= ftv[0])
        elif comparator == '<=':
            bucket.append(col <= ftv[0])
        elif comparator == '>':
            bucket.append(col > ftv[0])
        elif comparator == '<':
            bucket.append(col < ftv[0])
        # any other comparator is ignored, matching the original behaviour
    return event_filter, user_filter
def retention_model_sql(self):
    """Build the SQL for retention analysis.

    The first configured event is the initial event (A), the second the
    return-visit event (B). Per day and event name, the distinct account
    ids are collected and counted; the caller pivots the two event rows
    into a retention table. Returns a dict with the compiled SQL and the
    metadata the caller needs ('groupby', 'date_range', 'event_name',
    'unit_num').
    """
    event_name_a = self.events[0]['eventName']
    event_name_b = self.events[1]['eventName']
    event_time_col = getattr(self.event_tbl.c, '#event_time')
    event_name_col = getattr(self.event_tbl.c, '#event_name')
    e_account_id_col = getattr(self.event_tbl.c, '#account_id')
    u_account_id_col = getattr(self.user_tbl.c, '#account_id')
    date_col = sa.Column('date')
    # ClickHouse functions: shift event time to the local zone, bucket by
    # day, gather the distinct account ids per bucket and count them
    # (length() references the 'values' alias defined one column earlier).
    selectd = [func.toStartOfDay(func.addHours(event_time_col, self.zone_time)).label('date'),
               event_name_col.label('event_name'),
               *self.groupby,
               func.arrayDistinct(func.groupArray(e_account_id_col)).label('values'),
               func.length(sa.Column('values')).label('amount')
               ]
    base_where = [
        func.addHours(event_time_col, self.zone_time) >= self.start_date,
        func.addHours(event_time_col, self.zone_time) <= self.end_date,
        event_name_col.in_([event_name_a, event_name_b]),
    ]
    event_filter, user_filter = self.handler_filts()
    groupby = [date_col, event_name_col] + self.groupby
    oredrby = [date_col]
    # Join the user table only when user-level filters are present.
    if user_filter:
        qry = sa.select(selectd).select_from(
            self.event_tbl.join(self.user_tbl, u_account_id_col == e_account_id_col)).where(
            and_(*user_filter, *event_filter, *base_where)).group_by(*groupby).order_by(
            *oredrby).limit(10000)
    else:
        qry = sa.select(selectd).where(and_(*base_where, *event_filter)).group_by(*groupby).order_by(
            *oredrby).limit(10000)
    # Inline literal values so the SQL string can be executed as-is.
    sql = str(qry.compile(compile_kwargs={"literal_binds": True}))
    print(sql)
    return {'sql': sql,
            'groupby': ['date', 'event_name'] + [i.key for i in self.groupby],
            'date_range': self.date_range,
            'event_name': [event_name_a, event_name_b],
            'unit_num': self.unit_num
            }
def custom_event(self, s):
    """Parse a custom-metric formula into a SQLAlchemy select expression.

    ``s`` combines exactly two operands with one of ``+ - * /``; each
    operand is either "<event>.<aggregation>" (total_count /
    touch_user_count / touch_user_avg) or "<event>.<attr>.<agg func>".
    Returns the two event names involved plus the rounded expression,
    labelled 'values'.
    """
    def f(m):
        # m is one operand, already split on '.'.
        if len(m) == 3:
            # "<event>.<attr>.<agg>": aggregate the attribute only for rows
            # of that event — if() zeroes out other events' rows.
            event_name, attr, comp = m
            return getattr(func, comp)(getattr(func, 'if')(getattr(self.event_tbl.c, '#event_name') == event_name,
                                                           getattr(self.event_tbl.c, attr), 0))
        elif len(m) == 2:
            event_name, comp = m
            # total number of occurrences
            if comp == 'total_count':
                return func.sum(getattr(func, 'if')(getattr(self.event_tbl.c, '#event_name') == event_name, 1, 0))
            elif comp == 'touch_user_count':
                # distinct users who triggered the event (approximate uniq)
                return func.uniqCombined(getattr(func, 'if')(getattr(self.event_tbl.c, '#event_name') == event_name,
                                                             getattr(self.event_tbl.c, 'binduid'), ''))
            elif comp == 'touch_user_avg':
                # occurrences per distinct user
                return func.divide(
                    func.sum(getattr(func, 'if')(getattr(self.event_tbl.c, '#event_name') == event_name, 1, 0)),
                    func.uniqCombined(getattr(func, 'if')(getattr(self.event_tbl.c, '#event_name') == event_name,
                                                          getattr(self.event_tbl.c, 'binduid'), '')))
        # NOTE(review): any other operand shape falls through returning
        # None, which would surface later as a broken SQL expression.
    # Find the arithmetic operator present in the formula.
    # NOTE(review): set-intersection + pop() assumes exactly one operator
    # character in the whole string; event/attr names containing + - * /
    # or formulas with multiple operators will mis-parse — confirm
    # upstream validation guarantees this.
    opt = ({'+', '-', '*', '/'} & set(s)).pop()
    a, b = s.split(opt)
    r1 = a.split('.')
    r2 = b.split('.')
    return {'event_name': [r1[0], r2[0]],
            'select': func.round(settings.ARITHMETIC[opt](f(r1), f(r2)), 2).label('values')
            }
def event_model_sql(self):
    """Build one SQL statement per configured event for event analysis.

    Every event yields a dict with the compiled SQL, group-by keys, the
    date range and a display name; the caller executes each separately.
    Custom-formula events ('customEvent') take a different select path
    from plain events.
    """
    sqls = []
    event_time_col = getattr(self.event_tbl.c, '#event_time')
    for event in self.events:
        event_name_display = event.get('eventNameDisplay')
        # First select column: the time-bucket expression for the grain.
        select_exprs = [
            settings.TIME_GRAIN_EXPRESSIONS[self.time_particle](event_time_col, self.zone_time)]
        base_where = [
            func.addHours(event_time_col, self.zone_time) >= self.start_date,
            func.addHours(event_time_col, self.zone_time) <= self.end_date,
        ]
        event_name_col = getattr(self.event_tbl.c, '#event_name')
        if event.get('customEvent'):
            # Custom formula metric, e.g. "a.total_count/b.total_count".
            formula = event.get('customEvent')
            custom = self.custom_event(formula)
            event_name = custom['event_name']
            where = [event_name_col.in_(event_name)]
            event_filter, _ = self.handler_filts(*event['filts'])
            qry = sa.select(
                *select_exprs,
                custom['select']
            ).where(*base_where, *where, *event_filter)
        else:
            # Plain event metric.
            # NOTE(review): this branch reads event['event_name'] while
            # other models use 'eventName' — confirm the request schema
            # really differs for this model.
            event_name = event['event_name']
            select_exprs += self.groupby
            base_where.append(event_name_col == event_name)
            analysis = event['analysis']
            event_filter, user_filter = self.handler_filts(*event['filts'])
            u_account_id_col = getattr(self.user_tbl.c, '#account_id')
            # aggregated by account id
            e_account_id_col = getattr(self.event_tbl.c, '#account_id')
            # aggregation method
            if analysis == 'total_count':
                selectd = select_exprs + [func.count().label('values')]
            elif analysis == 'touch_user_count':
                selectd = select_exprs + [func.count(sa.distinct(e_account_id_col)).label('values')]
            elif analysis == 'touch_user_avg':
                # events per distinct account, rounded to 2 decimals
                selectd = select_exprs + [
                    func.round((func.count() / func.count(sa.distinct(e_account_id_col))), 2).label(
                        'values')]
            else:
                # any other value is passed through as a ClickHouse
                # aggregate function name applied to 'event_attr_id'
                selectd = select_exprs + [
                    func.round(getattr(func, analysis)(getattr(self.event_tbl.c, event['event_attr_id'])), 2).label(
                        'values')]
            # Join the user table only when user-level filters exist.
            if user_filter:
                qry = sa.select(selectd).select_from(
                    self.event_tbl.join(self.user_tbl, u_account_id_col == e_account_id_col)).where(
                    and_(*user_filter, *event_filter, *base_where))
            else:
                qry = sa.select(selectd).where(and_(*event_filter, *base_where))
        qry = qry.group_by(*select_exprs)
        qry = qry.order_by(sa.Column('date'))
        qry = qry.limit(1000)
        sql = str(qry.compile(compile_kwargs={"literal_binds": True}))
        print(sql)
        sqls.append({'sql': sql,
                     'groupby': [i.key for i in self.groupby],
                     'date_range': self.date_range,
                     'event_name': event_name_display or event_name,
                     })
    return sqls
def funnel_model_sql(self):
    """Build the SQL for funnel analysis via ClickHouse windowFunnel().

    Shape of the generated query:

    SELECT level, count(*) AS values
    FROM (SELECT windowFunnel(86400)(shjy.event."#event_time", shjy.event."#event_name" = 'create_role',
          shjy.event."#event_name" = 'login') AS level
          FROM shjy.event
          WHERE addHours(shjy.event."#event_time", 8) >= '2021-05-16 00:00:00'
          AND addHours(shjy.event."#event_time", 8) <= '2021-06-14 23:59:59'
          GROUP BY shjy.event."#account_id") AS anon_1
    GROUP BY level
    ORDER BY level
    :return: dict with the compiled SQL, group-by keys, the date range and
             the ordered funnel step names ('cond_level').
    """
    # Window is configured in days; ClickHouse expects seconds.
    windows_gap = self.event_view['windows_gap'] * 86400
    event_time_col = getattr(self.event_tbl.c, '#event_time')
    event_name_col = getattr(self.event_tbl.c, '#event_name')
    date_col = func.toStartOfDay(func.addHours(event_time_col, self.zone_time)).label('date')
    e_account_id_col = getattr(self.event_tbl.c, '#account_id')
    sub_group = [date_col, *self.groupby, e_account_id_col]
    conds = []
    cond_level = []
    for item in self.events:
        # Per-step condition: the step's event name plus its own filters.
        # Global filters are excluded here (g_f=False) and applied once in
        # the outer WHERE below.
        event_filter, _ = self.handler_filts(*item['filts'], g_f=False)
        conds.append(
            and_(event_name_col == item['eventName'], *event_filter)
        )
        cond_level.append(item['eventName'])
    # windowFunnel's parenthesised parameter cannot be expressed through
    # the func API, so a placeholder function name is emitted and the
    # '_windows_gap_' marker is textually replaced after compilation.
    # NOTE(review): the trailing double underscore appears to compensate
    # for SQLAlchemy's func stripping one trailing '_' — confirm against
    # the SQLAlchemy version in use.
    subq = sa.select(*[sa.Column(i.key) for i in self.groupby], date_col,
                     func.windowFunnel_windows_gap__(event_time_col, *conds).label('level')).select_from(
        self.event_tbl)
    g_event_filter, _ = self.handler_filts()
    where = [
        func.addHours(event_time_col, self.zone_time) >= self.start_date,
        func.addHours(event_time_col, self.zone_time) <= self.end_date,
        *g_event_filter
    ]
    subq = subq.where(and_(*where)).group_by(*sub_group)
    subq = subq.subquery()
    # Outer query: count accounts per (date, groupby..., level), level > 0
    # meaning the account completed at least the first funnel step.
    qry = sa.select(sa.Column('date'), *[sa.Column(i.key) for i in self.groupby], sa.Column('level'),
                    func.count().label('values')).select_from(subq) \
        .where(sa.Column('level') > 0) \
        .group_by(sa.Column('date'), *[sa.Column(i.key) for i in self.groupby], sa.Column('level')) \
        .order_by(sa.Column('date'), *[sa.Column(i.key) for i in self.groupby], sa.Column('level'))
    sql = str(qry.compile(compile_kwargs={"literal_binds": True}))
    # Alternative strict mode, kept for reference:
    # sql = sql.replace('_windows_gap_', f"({windows_gap},'strict_increase')")
    sql = sql.replace('_windows_gap_', f"({windows_gap})")
    print(sql)
    return {'sql': sql,
            'groupby': [i.key for i in self.groupby],
            'date_range': self.date_range,
            'cond_level': cond_level
            }
def scatter_model_sql(self):
    """Build the SQL for distribution ("scatter") analysis of one event.

    For count-style analyses ('times', 'number_of_days', 'number_of_hours')
    the values are counted per account so the caller can bucket them into
    intervals; otherwise, when a 'quota' attribute is configured, the
    mapped aggregate from settings.CK_FUNC is applied to it. Returns a
    dict with the compiled SQL and bucketing metadata.
    NOTE(review): if analysis matches neither branch and no 'quota' is set,
    the method implicitly returns None — callers appear to rely on one of
    the branches always matching; confirm.
    """
    event = self.events[0]
    event_name = event['eventName']
    analysis = event['analysis']
    e_account_id_col = getattr(self.event_tbl.c, '#account_id')
    event_name_col = getattr(self.event_tbl.c, '#event_name')
    event_time_col = getattr(self.event_tbl.c, '#event_time').label('date')
    event_date_col = settings.TIME_GRAIN_EXPRESSIONS[self.time_particle](event_time_col, self.zone_time)
    quota_interval_arr = event.get('quotaIntervalArr')
    where = [
        event_date_col >= self.start_date,
        event_date_col <= self.end_date,
        event_name_col == event_name
    ]
    # BUG FIX: the event's filters must be UNPACKED into handler_filts()
    # (as every other model method does). Previously the whole list was
    # passed as a single positional filter, which made handler_filts index
    # a list with a string key and raise TypeError whenever 'filts' was
    # non-empty (it only worked because an empty list hit the ([],) guard).
    event_filter, _ = self.handler_filts(*event['filts'])
    where.extend(event_filter)
    values_col = func.count().label('values')
    if analysis in ['number_of_days', 'number_of_hours']:
        # day/hour distributions count distinct accounts per bucket
        values_col = func.count(func.distinct(e_account_id_col)).label('values')
    if analysis in ['times', 'number_of_days', 'number_of_hours']:
        # Per-account counts; the caller buckets them via quota_interval_arr.
        qry = sa.select(event_date_col, *self.groupby, values_col) \
            .where(and_(*where)) \
            .group_by(event_date_col, *self.groupby, e_account_id_col)
        sql = str(qry.compile(compile_kwargs={"literal_binds": True}))
        print(sql)
        return {
            'sql': sql,
            'interval_type': event['intervalType'],
            'analysis': analysis,
            'quota_interval_arr': quota_interval_arr,
            'groupby': [i.key for i in self.groupby]
        }
    elif event.get('quota'):
        # Aggregate a numeric event attribute with the configured function.
        event_attr_col = getattr(self.event_tbl.c, event['quota'])
        qry = sa.select(event_date_col, settings.CK_FUNC[analysis](event_attr_col).label('values')) \
            .where(and_(*where)) \
            .group_by(event_date_col, *self.groupby, e_account_id_col)
        sql = str(qry.compile(compile_kwargs={"literal_binds": True}))
        print(sql)
        return {
            'sql': sql,
            'interval_type': event['intervalType'],
            'analysis': analysis,
            'quota_interval_arr': quota_interval_arr,
            'groupby': [i.key for i in self.groupby]
        }
def trace_model_sql(self):
    """Build the SQL for path/flow ("trace") analysis.

    Splits each account's event stream into sessions separated by more
    than the configured gap, then counts the occurrences of each event
    chain that starts with (sql_a, 'initial_event') or ends with (sql_b)
    the chosen source event.
    NOTE(review): self.events is treated as a DICT here ('event_names',
    'source_event'), unlike the other models where it is a list — confirm
    the request schema for this model.
    NOTE(review): values are interpolated straight into the SQL text (no
    parameter binding), so inputs are assumed trusted; also note that
    tuple(event_names) renders a trailing comma for a single name — a
    minimum of two events seems to be assumed.
    """
    session_interval = self.event_view.get('session_interval')
    session_type = self.event_view.get('session_type')
    session_type_map = {
        'minute': 60,
        'second': 1,
        'hour': 3600
    }
    # Session gap in seconds; unknown units fall back to minutes.
    interval_ts = session_interval * session_type_map.get(session_type, 60)
    event_names = self.events.get('event_names')
    source_event = self.events.get('source_event', {}).get('eventName')
    source_type = self.events.get('source_event', {}).get('source_type')
    # Forward variant: chains starting at the source event.
    sql_a = f"""with
'{source_event}' as start_event,
{tuple(event_names)} as evnet_all,
'{self.start_date}' as start_data,
'{self.end_date}' as end_data
select event_chain,
count() as values
from (with
toUInt32(minIf(`#event_time`, `#event_name` = start_event)) AS start_event_ts,
arraySort(
x ->
x.1,
arrayFilter(
x -> x.1 >= start_event_ts,
groupArray((toUInt32(`#event_time`), `#event_name`))
)
) AS sorted_events,
arrayEnumerate(sorted_events) AS event_idxs,
arrayFilter(
(x, y, z) -> z.1 >= start_event_ts and ((z.2 = start_event and y > {interval_ts}) or y > {interval_ts}) ,
event_idxs,
arrayDifference(sorted_events.1),
sorted_events
) AS gap_idxs,
arrayMap(x -> x, gap_idxs) AS gap_idxs_,
arrayMap(x -> if(has(gap_idxs_, x), 1, 0), event_idxs) AS gap_masks,
arraySplit((x, y) -> y, sorted_events, gap_masks) AS split_events
select `#account_id`,
arrayJoin(split_events) AS event_chain_,
arrayMap(x ->
x.2, event_chain_) AS event_chain,
has(event_chain, start_event) AS has_midway_hit
from (select `#event_time`, `#event_name`, `#account_id`
from {self.game}.event
where addHours(`#event_time`, {self.zone_time}) >= start_data
and addHours(`#event_time`, {self.zone_time}) <= end_data
and `#event_name` in evnet_all)
group by `#account_id`
HAVING has_midway_hit = 1
)
where arrayElement(event_chain, 1) = start_event
GROUP BY event_chain
ORDER BY values desc
"""
    # Backward variant: chains ending at the source event.
    sql_b = f"""with
'{source_event}' as end_event,
{tuple(event_names)} as evnet_all,
'{self.start_date}' as start_data,
'{self.end_date}' as end_data
select event_chain,
count() as values
from (with
toUInt32(maxIf(`#event_time`, `#event_name` = end_event)) AS end_event_ts,
arraySort(
x ->
x.1,
arrayFilter(
x -> x.1 <= end_event_ts,
groupArray((toUInt32(`#event_time`), `#event_name`))
)
) AS sorted_events,
arrayEnumerate(sorted_events) AS event_idxs,
arrayFilter(
(x, y, z) -> z.1 <= end_event_ts and (z.2 = end_event and y>{interval_ts}) OR y > {interval_ts},
event_idxs,
arrayDifference(sorted_events.1),
sorted_events
) AS gap_idxs,
arrayMap(x -> x+1, gap_idxs) AS gap_idxs_,
arrayMap(x -> if(has(gap_idxs_, x), 1,0), event_idxs) AS gap_masks,
arraySplit((x, y) -> y, sorted_events, gap_masks) AS split_events
select `#account_id`,
arrayJoin(split_events) AS event_chain_,
arrayMap(x ->
x.2, event_chain_) AS event_chain,
has(event_chain, end_event) AS has_midway_hit
from (select `#event_time`, `#event_name`, `#account_id`
from {self.game}.event
where addHours(`#event_time`, {self.zone_time}) >= start_data
and addHours(`#event_time`, {self.zone_time}) <= end_data
and `#event_name` in evnet_all)
group by `#account_id`
HAVING has_midway_hit = 1
)
where arrayElement(event_chain, -1) = end_event
GROUP BY event_chain
ORDER BY values desc"""
    # Direction is chosen by the source event's configured role.
    sql = sql_a if source_type == 'initial_event' else sql_b
    print(sql)
    return {
        'sql': sql,
    }