xbackend/models/behavior_analysis.py
2021-08-25 19:26:27 +08:00

649 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import re
from typing import Tuple
import arrow
import sqlalchemy as sa
import json
from fastapi import Depends
import pandas as pd
from sqlalchemy import func, or_, and_, not_
import crud
import schemas
from core.config import settings
from db import get_database
from db.redisdb import get_redis_pool, RedisDrive
class CustomEvent:
    """Build a ClickHouse aggregate expression from a user-defined formula string.

    The formula is a sequence of dot-separated factors joined by the arithmetic
    operators ``+ - * /``, e.g. ``"pay.money.sum/login.total_count"``.
    Each factor is one of:

    * ``event.attr.agg``  -> ``agg(if(#event_name == event, attr, 0))``
    * ``event.aggregate`` -> a predefined aggregate (``total_count``,
      ``touch_user_count``, ``touch_user_avg``)
    * ``literal``         -> an integer constant

    ``parse()`` returns the combined, rounded expression labelled ``values``
    together with the list of event names the formula references.
    """

    def __init__(self, tbl, string, format):
        # tbl:    SQLAlchemy Table for the event table (provides .c columns)
        # string: raw formula text, e.g. "a.b.sum+c.total_count"
        # format: output format hint: 'percent', 'integer' or 'float'
        self.tbl = tbl
        self.string = string
        self.pattern = re.compile(r'[+\-*/]')
        self.format = format
        self.events_name = []

    def _parse(self, s):
        """Translate one factor into a SQLAlchemy expression (or an int literal)."""
        m = s.split('.')
        if len(m) == 3:
            # event.attr.aggregator -> aggregator(if(event matched, attr, 0))
            event_name, attr, comp = m
            self.events_name.append(event_name)
            return getattr(func, comp)(
                getattr(func, 'if')(getattr(self.tbl.c, '#event_name') == event_name,
                                    getattr(self.tbl.c, attr), 0))
        elif len(m) == 2:
            event_name, comp = m
            self.events_name.append(event_name)
            if comp == 'total_count':
                # Total number of matching events.
                return func.sum(getattr(func, 'if')(getattr(self.tbl.c, '#event_name') == event_name, 1, 0))
            elif comp == 'touch_user_count':
                # Distinct accounts that triggered the event.
                return func.uniqCombined(
                    getattr(func, 'if')(getattr(self.tbl.c, '#event_name') == event_name,
                                        getattr(self.tbl.c, '#account_id'), None))
            elif comp == 'touch_user_avg':
                # Events per distinct account.
                return func.divide(
                    func.sum(getattr(func, 'if')(getattr(self.tbl.c, '#event_name') == event_name, 1, 0)),
                    func.uniqCombined(
                        getattr(func, 'if')(getattr(self.tbl.c, '#event_name') == event_name,
                                            getattr(self.tbl.c, '#account_id'), None)))
        elif len(m) == 1:
            # Bare integer literal.
            return int(m[0])
        # Factors with four or more dots (or an unknown 2-part aggregate)
        # fall through to None, matching the original behaviour.
        return None

    def str2obj(self, factors, opts):
        """Fold the factor expressions together with the parsed operators."""
        sel = None
        for i, factor in enumerate(factors):
            if i == 0:
                sel = self._parse(factor)
            else:
                tmp = self._parse(factor)
                sel = settings.ARITHMETIC[opts[i - 1]](sel, tmp)
        return sel

    def parse(self):
        """Parse the whole formula.

        :return: dict with 'event_name' (list of referenced event names) and
                 'select' (rounded expression labelled 'values').
        """
        factors = self.pattern.split(self.string)
        opts = self.pattern.findall(self.string)
        sel = self.str2obj(factors, opts)
        decimal = 2
        # BUG FIX: the original compared the *builtin* `format` instead of
        # `self.format`, so the 'integer'/'float' hints were silently ignored.
        if self.format == 'percent':
            sel = sel * 100
        elif self.format == 'integer':
            decimal = 0
        elif self.format == 'float':
            decimal = 2
        sel = func.round(sel, decimal).label('values')
        res = {
            'event_name': self.events_name,
            'select': sel
        }
        return res
class BehaviorAnalysis:
    """Translate a front-end analysis query (schemas.CkQuery) into ClickHouse SQL."""

    def __init__(self, game: str, data_in: schemas.CkQuery, rdb: RedisDrive = Depends(get_redis_pool)):
        # Identifiers and external dependencies.
        self.game = game
        self.rdb = rdb
        self.data_in = data_in
        # Table metadata, rebuilt from redis by _init_table().
        self.user_tbl = None
        self.event_tbl = None
        # Query state, populated by init().
        self.event_view = {}
        self.events = [{}]
        self.zone_time: int = 0
        self.start_date = None
        self.end_date = None
        self.global_filters = None
        self.groupby = None
        self.time_particle = None
        self.date_range = None
        self.unit_num = None
async def init(self):
    """Populate the query state, either from a saved report or the inbound payload.

    Must be awaited once after construction: loads the cached table metadata,
    then derives zone offset, time granularity, date range, global filters,
    group-by columns and unit count.
    """
    if self.data_in.report_id:
        # Saved report: reload the stored query definition by id.
        db = get_database()
        report = await crud.report.get(db, id=self.data_in.report_id)
        self.event_view = report['query']['eventView']
        self.events = report['query']['events']
        try:
            e_days = self.event_view['e_days']
            s_days = self.event_view['s_days']
        # FIX: was a bare `except:` which also swallowed KeyboardInterrupt /
        # SystemExit; only the dict lookups above can fail here.
        except KeyError:
            # Backward compatibility with old reports that stored a single
            # "recentDay" string like "7-1".
            e_days, s_days = self.event_view['recentDay'].split('-')
        self.event_view['endTime'] = arrow.get().shift(days=-int(e_days)).strftime('%Y-%m-%d 23:59:59')
        self.event_view['startTime'] = arrow.get().shift(days=-int(s_days)).strftime('%Y-%m-%d 00:00:00')
    else:
        self.event_view = self.data_in.eventView
        self.events = self.data_in.events

    await self._init_table()
    self.zone_time = self._get_zone_time()
    self.time_particle = self._get_time_particle_size()
    self.start_date, self.end_date, self.date_range = self._get_date_range()
    self.global_filters = self._get_global_filters()
    self.groupby = self._get_group_by()
    self.unit_num = self._get_unit_num()
def _get_time_particle_size(self):
return self.event_view.get('timeParticleSize') or 'P1D'
def _get_unit_num(self):
return self.event_view.get('unitNum')
def _get_group_by(self):
return [getattr(self.event_tbl.c, item['columnName']) for item in self.event_view.get('groupBy', [])]
def _get_zone_time(self):
return int(self.event_view.get('zone_time', 8))
def _get_date_range(self) -> Tuple[str, str, list]:
    """Compute (startTime, endTime, bucket list) for the report window.

    The bucket list is generated by pandas at the report's time granularity;
    for day/week/month granularity the timestamps are reduced to dates.
    """
    start_date: str = self.event_view.get('startTime')
    end_date: str = self.event_view.get('endTime')
    freq = settings.PROPHET_TIME_GRAIN_MAP[self.time_particle]
    buckets = pd.date_range(start_date, end_date, freq=freq, tz='UTC').tolist()
    if self.time_particle in ('P1D', 'P1W', 'P1M'):
        buckets = [ts.date() for ts in buckets]
    return start_date, end_date, buckets
def _get_global_filters(self):
return self.event_view.get('filts') or []
async def _init_table(self):
    """Rebuild lightweight SQLAlchemy table objects from column lists cached in redis.

    The redis keys ``{game}_user`` / ``{game}_event`` hold JSON objects whose
    keys are the column names of the corresponding ClickHouse tables.
    """
    user_json = await self.rdb.get(f'{self.game}_user')
    user_cols = json.loads(user_json).keys()
    user_meta = sa.MetaData(schema=self.game)
    self.user_tbl = sa.Table('user_view', user_meta, *(sa.Column(name) for name in user_cols))

    event_json = await self.rdb.get(f'{self.game}_event')
    event_cols = json.loads(event_json).keys()
    event_meta = sa.MetaData(schema=self.game)
    # Queries read from the raw `event` table (not the `event_view`).
    self.event_tbl = sa.Table('event', event_meta, *(sa.Column(name) for name in event_cols))
def handler_filts(self, *ext_filters, g_f=True):
    """Convert filter dicts into SQLAlchemy conditions, split per table.

    :param ext_filters: extra filter dicts (tableType/columnName/comparator/ftv)
    :param g_f: when True, also apply the view-level global filters
    :return: (event_conditions, user_conditions)
    """
    user_where = []
    event_where = []
    pending = list(ext_filters)
    if g_f:
        pending.extend(self.global_filters)

    # Comparators that translate 1:1 into a single column condition.
    simple_ops = {
        '>=': lambda col, vals: col >= vals[0],
        '<=': lambda col, vals: col <= vals[0],
        '>': lambda col, vals: col > vals[0],
        '<': lambda col, vals: col < vals[0],
        '!=': lambda col, vals: col != vals[0],
        'is not null': lambda col, vals: col.isnot(None),
        'is null': lambda col, vals: col.is_(None),
    }

    for item in pending:
        table_type = item['tableType']
        if table_type == 'user':
            target = user_where
        elif table_type == 'event':
            target = event_where
        else:
            # Unknown table type: silently skipped (original behaviour).
            continue
        tbl = getattr(self, f'{table_type}_tbl')
        col = getattr(tbl.c, item['columnName'])
        comparator = item['comparator']
        ftv = item['ftv']
        if comparator == '==':
            # Multiple candidate values become an OR of equalities.
            if len(ftv) > 1:
                target.append(or_(*[col == v for v in ftv]))
            else:
                target.append(col == ftv[0])
        elif comparator in simple_ops:
            target.append(simple_ops[comparator](col, ftv))
        # Unknown comparators are ignored, as before.
    return event_where, user_where
def retention_model_sql(self):
    """Build the retention-analysis SQL.

    ``events[0]`` is the initial event, ``events[1]`` the return event
    (``'*'`` means any event counts as a return). Produces, per day and
    group, the distinct visitor arrays for both events plus their sizes;
    the caller computes retention rates from the array overlaps.

    :return: dict with sql / groupby / date_range / event_name / unit_num
    """
    event_name_a = self.events[0]['eventName']
    event_name_b = self.events[1]['eventName']
    visit_name = self.events[0].get('event_attr_id')

    event_time_col = getattr(self.event_tbl.c, '#event_time')
    event_name_col = getattr(self.event_tbl.c, '#event_name')
    e_account_id_col = getattr(self.event_tbl.c, '#account_id')
    u_account_id_col = getattr(self.user_tbl.c, '#account_id')
    date_col = sa.Column('date')
    # Count by a custom attribute when configured, else by account id.
    who_visit = getattr(self.event_tbl.c, visit_name) if visit_name else e_account_id_col

    # FIX: `filts` may be absent/None on the first event; the original
    # unpacked it directly (`*self.events[0].get('filts')`) and raised
    # TypeError in that case. Sibling methods use `.get('filts', [])`.
    filters, _ = self.handler_filts(*(self.events[0].get('filts') or []), g_f=False)
    filters = filters or [1]

    selectd = [func.toStartOfDay(func.addHours(event_time_col, self.zone_time)).label('date'),
               *self.groupby,
               func.arrayDistinct(
                   (func.groupArray(
                       func.if_(func.and_(event_name_col == event_name_a, *filters), who_visit, None)))).label(
                   'val_a'),
               func.length(sa.Column('val_a')).label('amount_a'),
               func.length(sa.Column('val_b')).label('amount_b'),
               ]
    # '*' -> any event counts as a return visit.
    if event_name_b == '*':
        val_b = func.arrayDistinct(
            (func.groupArray(func.if_(1, who_visit, None)))).label('val_b')
    else:
        val_b = func.arrayDistinct(
            (func.groupArray(func.if_(event_name_col == event_name_b, who_visit, None)))).label('val_b')
    # Keep val_b right after val_a, before the two length columns.
    selectd.insert(-2, val_b)

    base_where = [
        func.addHours(event_time_col, self.zone_time) >= self.start_date,
        func.addHours(event_time_col, self.zone_time) <= self.end_date,
    ]
    event_filter, user_filter = self.handler_filts()
    groupby = [date_col] + self.groupby
    orderby = [date_col]
    if user_filter:
        # User-property filters require joining the user table.
        qry = sa.select(selectd).select_from(
            self.event_tbl.join(self.user_tbl, u_account_id_col == e_account_id_col)).where(
            and_(*user_filter, *event_filter, *base_where)).group_by(*groupby).order_by(
            *orderby).limit(10000)
    else:
        qry = sa.select(selectd).where(and_(*base_where, *event_filter)).group_by(*groupby).order_by(
            *orderby).limit(10000)
    sql = str(qry.compile(compile_kwargs={"literal_binds": True}))
    print(sql)
    return {'sql': sql,
            'groupby': ['date'] + [i.key for i in self.groupby],
            'date_range': self.date_range,
            'event_name': [event_name_a, event_name_b],
            'unit_num': self.unit_num
            }
def event_model_sql(self):
    """Build one SQL statement per configured event for the event-analysis model.

    Supports two kinds of entries: custom formula events (``customEvent``)
    compiled through :class:`CustomEvent`, and plain events aggregated by
    ``analysis`` (total_count / touch_user_count / touch_user_avg, or any
    ClickHouse aggregate applied to ``event_attr_id``).

    :return: list of dicts (sql, groupby, date_range, event_name, format,
             time_particle), one per event.
    """
    sqls = []
    event_time_col = getattr(self.event_tbl.c, '#event_time')
    for event in self.events:
        event_name_display = event.get('eventNameDisplay')
        select_exprs = []
        if self.time_particle != 'total':
            # Bucket timestamps at the report granularity (labelled 'date').
            select_exprs.append(
                settings.TIME_GRAIN_EXPRESSIONS[self.time_particle](event_time_col, self.zone_time))
        base_where = [
            func.addHours(event_time_col, self.zone_time) >= self.start_date,
            func.addHours(event_time_col, self.zone_time) <= self.end_date,
        ]
        event_name_col = getattr(self.event_tbl.c, '#event_name')
        # Renamed from `format` to stop shadowing the builtin.
        fmt = event.get('format') or 'float'
        if event.get('customEvent'):
            # Formula event: delegate expression building to CustomEvent.
            formula = event.get('customEvent')
            custom = CustomEvent(self.event_tbl, formula, fmt).parse()
            event_name = custom['event_name']
            where = [event_name_col.in_(event_name)]
            # FIX: tolerate a missing/None 'filts' key (was event['filts']).
            event_filter, _ = self.handler_filts(*(event.get('filts') or []))
            select_exprs.extend(self.groupby)
            qry = sa.select(
                *select_exprs,
                custom['select']
            ).where(*base_where, *where, *event_filter)
        else:
            event_name = event['event_name']
            select_exprs += self.groupby
            if event_name != '*':
                base_where.append(event_name_col == event_name)
            analysis = event['analysis']
            # FIX: tolerate a missing/None 'filts' key (was event['filts']).
            event_filter, user_filter = self.handler_filts(*(event.get('filts') or []))
            u_account_id_col = getattr(self.user_tbl.c, '#account_id')
            # Aggregate by account id.
            e_account_id_col = getattr(self.event_tbl.c, '#account_id')
            # Aggregation mode.
            if analysis == 'total_count':
                selectd = select_exprs + [func.count().label('values')]
            elif analysis == 'touch_user_count':
                selectd = select_exprs + [func.count(sa.distinct(e_account_id_col)).label('values')]
            elif analysis == 'touch_user_avg':
                # Events per distinct account, rounded to 2 decimals.
                selectd = select_exprs + [
                    func.round((func.count() / func.count(sa.distinct(e_account_id_col))), 2).label(
                        'values')]
            else:
                # Any ClickHouse aggregate applied to the chosen attribute.
                selectd = select_exprs + [
                    func.round(getattr(func, analysis)(getattr(self.event_tbl.c, event['event_attr_id'])),
                               2).label('values')]
            if user_filter:
                # User-property filters require joining the user table.
                qry = sa.select(selectd).select_from(
                    self.event_tbl.join(self.user_tbl, u_account_id_col == e_account_id_col)).where(
                    and_(*user_filter, *event_filter, *base_where))
            else:
                qry = sa.select(selectd).where(and_(*event_filter, *base_where))
        qry = qry.group_by(*select_exprs)
        if self.time_particle != 'total':
            qry = qry.order_by(sa.Column('date'))
        else:
            qry = qry.order_by(sa.Column('values').desc())
        qry = qry.limit(1000)
        sql = str(qry.compile(compile_kwargs={"literal_binds": True}))
        print(sql)
        sqls.append({'sql': sql,
                     'groupby': [i.key for i in self.groupby],
                     'date_range': self.date_range,
                     'event_name': event_name_display or event_name,
                     'format': fmt,
                     'time_particle': self.time_particle
                     })
    return sqls
def funnel_model_sql(self):
    """Build the funnel-analysis SQL via ClickHouse ``windowFunnel``.

    Example of the generated statement:

    SELECT level, count(*) AS values
    FROM (SELECT windowFunnel(86400)(shjy.event."#event_time", shjy.event."#event_name" = 'create_role',
    shjy.event."#event_name" = 'login') AS level
    FROM shjy.event
    WHERE addHours(shjy.event."#event_time", 8) >= '2021-05-16 00:00:00'
    AND addHours(shjy.event."#event_time", 8) <= '2021-06-14 23:59:59'
    GROUP BY shjy.event."#account_id") AS anon_1
    GROUP BY level
    ORDER BY level
    :return: dict with sql / groupby / date_range / cond_level
    """
    # Funnel window in seconds (windows_gap is configured in days).
    windows_gap = self.event_view['windows_gap'] * 86400
    event_time_col = getattr(self.event_tbl.c, '#event_time')
    event_name_col = getattr(self.event_tbl.c, '#event_name')
    date_col = func.toStartOfDay(func.addHours(event_time_col, self.zone_time)).label('date')
    e_account_id_col = getattr(self.event_tbl.c, '#account_id')
    # One funnel per (day, group-by columns, account).
    sub_group = [date_col, *self.groupby, e_account_id_col]
    conds = []
    cond_level = []
    for item in self.events:
        # Per-step filters are local to the step (g_f=False).
        event_filter, _ = self.handler_filts(*item['filts'], g_f=False)
        conds.append(
            and_(event_name_col == item['eventName'], *event_filter)
        )
        cond_level.append(item['eventName'])
    # TODO: replace the `_windows_gap_` placeholder — the fake function name
    # below compiles to text that is patched by string substitution afterwards.
    subq = sa.select(*[sa.Column(i.key) for i in self.groupby], date_col,
                     func.windowFunnel_windows_gap__(event_time_col, *conds).label('level')).select_from(
        self.event_tbl)
    g_event_filter, _ = self.handler_filts()
    where = [
        func.addHours(event_time_col, self.zone_time) >= self.start_date,
        func.addHours(event_time_col, self.zone_time) <= self.end_date,
        *g_event_filter
    ]
    subq = subq.where(and_(*where)).group_by(*sub_group)
    subq = subq.subquery()
    # Outer query: count accounts that reached each funnel level (> 0).
    qry = sa.select(sa.Column('date'), *[sa.Column(i.key) for i in self.groupby], sa.Column('level'),
                    func.count().label('values')).select_from(subq) \
        .where(sa.Column('level') > 0) \
        .group_by(sa.Column('date'), *[sa.Column(i.key) for i in self.groupby], sa.Column('level')) \
        .order_by(sa.Column('date'), *[sa.Column(i.key) for i in self.groupby], sa.Column('level'))
    sql = str(qry.compile(compile_kwargs={"literal_binds": True}))
    # Alternative strict-increase mode, kept for reference:
    # sql = sql.replace('_windows_gap_', f"({windows_gap},'strict_increase')")
    sql = sql.replace('_windows_gap_', f"({windows_gap})")
    print(sql)
    return {'sql': sql,
            'groupby': [i.key for i in self.groupby],
            'date_range': self.date_range,
            'cond_level': cond_level
            }
def scatter_model_sql(self):
    """Build SQL for the distribution (scatter) model.

    Two branches: count-style analyses (``times`` / ``number_of_days`` /
    ``number_of_hours``) aggregate row or distinct-account counts per time
    bucket; otherwise, when a ``quota`` attribute is configured, the chosen
    ClickHouse aggregate is applied per account.

    :return: dict with sql / interval_type / analysis / quota_interval_arr /
             groupby, or implicitly None when neither branch applies
             (original behaviour preserved).
    """
    event = self.events[0]
    event_name = event['eventName']
    analysis = event['analysis']
    e_account_id_col = getattr(self.event_tbl.c, '#account_id')
    event_name_col = getattr(self.event_tbl.c, '#event_name')
    event_time_col = getattr(self.event_tbl.c, '#event_time').label('date')
    event_date_col = settings.TIME_GRAIN_EXPRESSIONS[self.time_particle](event_time_col, self.zone_time)
    quota_interval_arr = event.get('quotaIntervalArr')
    where = [
        event_date_col >= self.start_date,
        event_date_col <= self.end_date,
        event_name_col == event_name
    ]
    # FIX: tolerate a missing/None 'filts' key (was self.events[0]['filts']).
    event_filter, _ = self.handler_filts(*(event.get('filts') or []))
    where.extend(event_filter)

    def _result(qry):
        # Shared result assembly for both branches (was duplicated inline).
        sql = str(qry.compile(compile_kwargs={"literal_binds": True}))
        print(sql)
        return {
            'sql': sql,
            'interval_type': event['intervalType'],
            'analysis': analysis,
            'quota_interval_arr': quota_interval_arr,
            'groupby': [i.key for i in self.groupby]
        }

    values_col = func.count().label('values')
    if analysis in ('number_of_days', 'number_of_hours'):
        # Count distinct accounts instead of raw rows.
        values_col = func.count(func.distinct(e_account_id_col)).label('values')
    if analysis in ('times', 'number_of_days', 'number_of_hours'):
        qry = sa.select(event_date_col, *self.groupby, values_col) \
            .where(and_(*where)) \
            .group_by(event_date_col, *self.groupby)
        return _result(qry)
    elif event.get('quota'):
        # Aggregate a numeric attribute per account.
        event_attr_col = getattr(self.event_tbl.c, event['quota'])
        qry = sa.select(event_date_col, settings.CK_FUNC[analysis](event_attr_col).label('values')) \
            .where(and_(*where)) \
            .group_by(event_date_col, *self.groupby, e_account_id_col)
        return _result(qry)
    # NOTE(review): falls through to an implicit None when analysis is not a
    # count mode and no 'quota' is configured — callers should handle that.
def trace_model_sql(self):
    """Build the event-path (trace) SQL.

    Per account, events are sorted by time and split into sessions wherever
    consecutive events are further apart than the configured interval; the
    resulting chains are grouped and counted. ``sql_a`` anchors chains that
    START at the source event, ``sql_b`` chains that END at it, selected by
    ``source_type``.
    :return: dict with the raw sql string
    """
    session_interval = self.event_view.get('session_interval')
    session_type = self.event_view.get('session_type')
    session_type_map = {
        'minute': 60,
        'second': 1,
        'hour': 3600
    }
    # Session gap in seconds; unknown units fall back to minutes.
    interval_ts = session_interval * session_type_map.get(session_type, 60)
    # NOTE(review): for this model self.events is expected to be a dict
    # (event_names + source_event), unlike the list used by the other
    # models — confirm against the caller that supplies trace queries.
    event_names = self.events.get('event_names')
    source_event = self.events.get('source_event', {}).get('eventName')
    source_type = self.events.get('source_event', {}).get('source_type')
    # Chains starting at source_event.
    sql_a = f"""with
'{source_event}' as start_event,
{tuple(event_names)} as evnet_all,
'{self.start_date}' as start_data,
'{self.end_date}' as end_data
select event_chain,
count() as values
from (with
toUInt32(minIf(`#event_time`, `#event_name` = start_event)) AS start_event_ts,
arraySort(
x ->
x.1,
arrayFilter(
x -> x.1 >= start_event_ts,
groupArray((toUInt32(`#event_time`), `#event_name`))
)
) AS sorted_events,
arrayEnumerate(sorted_events) AS event_idxs,
arrayFilter(
(x, y, z) -> z.1 >= start_event_ts and ((z.2 = start_event and y > {interval_ts}) or y > {interval_ts}) ,
event_idxs,
arrayDifference(sorted_events.1),
sorted_events
) AS gap_idxs,
arrayMap(x -> x, gap_idxs) AS gap_idxs_,
arrayMap(x -> if(has(gap_idxs_, x), 1, 0), event_idxs) AS gap_masks,
arraySplit((x, y) -> y, sorted_events, gap_masks) AS split_events
select `#account_id`,
arrayJoin(split_events) AS event_chain_,
arrayMap(x ->
x.2, event_chain_) AS event_chain,
has(event_chain, start_event) AS has_midway_hit
from (select `#event_time`, `#event_name`, `#account_id`
from {self.game}.event
where addHours(`#event_time`, {self.zone_time}) >= start_data
and addHours(`#event_time`, {self.zone_time}) <= end_data
and `#event_name` in evnet_all)
group by `#account_id`
HAVING has_midway_hit = 1
)
where arrayElement(event_chain, 1) = start_event
GROUP BY event_chain
ORDER BY values desc
"""
    # Chains ending at source_event.
    sql_b = f"""with
'{source_event}' as end_event,
{tuple(event_names)} as evnet_all,
'{self.start_date}' as start_data,
'{self.end_date}' as end_data
select event_chain,
count() as values
from (with
toUInt32(maxIf(`#event_time`, `#event_name` = end_event)) AS end_event_ts,
arraySort(
x ->
x.1,
arrayFilter(
x -> x.1 <= end_event_ts,
groupArray((toUInt32(`#event_time`), `#event_name`))
)
) AS sorted_events,
arrayEnumerate(sorted_events) AS event_idxs,
arrayFilter(
(x, y, z) -> z.1 <= end_event_ts and (z.2 = end_event and y>{interval_ts}) OR y > {interval_ts},
event_idxs,
arrayDifference(sorted_events.1),
sorted_events
) AS gap_idxs,
arrayMap(x -> x+1, gap_idxs) AS gap_idxs_,
arrayMap(x -> if(has(gap_idxs_, x), 1,0), event_idxs) AS gap_masks,
arraySplit((x, y) -> y, sorted_events, gap_masks) AS split_events
select `#account_id`,
arrayJoin(split_events) AS event_chain_,
arrayMap(x ->
x.2, event_chain_) AS event_chain,
has(event_chain, end_event) AS has_midway_hit
from (select `#event_time`, `#event_name`, `#account_id`
from {self.game}.event
where addHours(`#event_time`, {self.zone_time}) >= start_data
and addHours(`#event_time`, {self.zone_time}) <= end_data
and `#event_name` in evnet_all)
group by `#account_id`
HAVING has_midway_hit = 1
)
where arrayElement(event_chain, -1) = end_event
GROUP BY event_chain
ORDER BY values desc"""
    sql = sql_a if source_type == 'initial_event' else sql_b
    print(sql)
    return {
        'sql': sql,
    }
def retention_model_sql2(self):
    """Build the raw-SQL retention query (day-level pivot).

    Produces one row per registration date with cnt0 (cohort size) and, for
    each day offset i, cnt{i+1} plus the retention percentage p{i+1}.
    :return: dict with sql / date_range / unit_num
    """
    event_name_a = self.events[0]['eventName']
    event_name_b = self.events[1]['eventName']
    # Attribute used to identify a visitor. NOTE(review): may be None, which
    # would render literally as `None` in the SQL — confirm callers set it.
    visit_name = self.events[0].get('event_attr_id')
    # Render each event's filters as a raw WHERE fragment via a dummy select.
    where, _ = self.handler_filts(*self.events[0].get('filts', []))
    where_a = '1'
    if where:
        qry = sa.select().where(*where)
        sql = str(qry.compile(compile_kwargs={"literal_binds": True}))
        where_a = sql.split('WHERE ')[1]
    where, _ = self.handler_filts(*self.events[1].get('filts', []))
    where_b = '1'
    if where:
        qry = sa.select().where(*where)
        sql = str(qry.compile(compile_kwargs={"literal_binds": True}))
        where_b = sql.split('WHERE ')[1]
    # '*' means any event counts as a return visit.
    event_name_b = 1 if event_name_b == '*' else f"`#event_name` = '{event_name_b}'"
    days = (arrow.get(self.end_date).date() - arrow.get(self.start_date).date()).days
    keep = []
    cnt = []
    # One cnt/percentage column pair per day offset in the window.
    for i in range(days+1):
        keep.append(
            f"""cnt{i + 1},round(cnt{i + 1} * 100 / cnt0, 2) as `p{i + 1}`""")
        cnt.append(f"""sum(if(dateDiff('day',a.reg_date,b.visit_date)={i},1,0)) as cnt{i + 1}""")
    keep_str = ','.join(keep)
    cnt_str = ','.join(cnt)
    sql = f"""
with '{event_name_a}' as start_event,
{event_name_b} as retuen_visit,
`{visit_name}` as visit,
'{self.start_date}' as start_data,
'{self.end_date}' as end_data,
toDate(addHours(`#event_time`, {self.zone_time})) as date
select reg_date,
cnt0 ,
{keep_str}
from(select date, uniqExact(visit) as cnt0 from {self.game}.event
where `#event_name` = start_event and addHours(`#event_time`, {self.zone_time}) >= start_data and addHours(`#event_time`, {self.zone_time}) <= end_data and {where_a}
group by date) reg left join
(select a.reg_date,
{cnt_str}
from (select date as reg_date, visit from {self.game}.event where `#event_name` = start_event and addHours(`#event_time`, {self.zone_time}) >= start_data and addHours(`#event_time`, {self.zone_time}) <= end_data and {where_a} group by reg_date, visit) a
left join (select date as visit_date, visit from {self.game}.event where retuen_visit and addHours(`#event_time`, {self.zone_time}) >= start_data and addHours(`#event_time`, {self.zone_time}) <= end_data and {where_b} group by visit_date, visit) b on
a.visit = b.visit
group by a.reg_date) log on reg.date=log.reg_date
"""
    print(sql)
    return {
        'sql': sql,
        'date_range': self.date_range,
        'unit_num': self.unit_num
    }