305 lines
7.1 KiB
Go
305 lines
7.1 KiB
Go
package lib
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
|
|
"go_dreamfactory/cmd/v2/lib/common"
|
|
"go_dreamfactory/cmd/v2/service/observer"
|
|
"go_dreamfactory/pb"
|
|
"math"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/sirupsen/logrus"
|
|
"google.golang.org/protobuf/proto"
|
|
)
|
|
|
|
type assistant struct {
|
|
timeout time.Duration //处理超时时间
|
|
lps uint32 //每秒请求量
|
|
duration time.Duration //持续时间
|
|
concurrency uint32 //并发量
|
|
|
|
ctx context.Context
|
|
callCount int64 //调用次数,每次启动时重置
|
|
|
|
goPool GoPool //携程池
|
|
cancelFunc context.CancelFunc //取消
|
|
caller Handler //处理器
|
|
status uint32 //状态
|
|
resultCh chan *CallResult //调用结果
|
|
obs observer.Observer
|
|
}
|
|
|
|
func NewAssistant(obs observer.Observer, pm ParamMgr) (Aiassistant, error) {
|
|
|
|
if err := pm.Check(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
a := &assistant{
|
|
timeout: pm.Timeout,
|
|
lps: pm.Lps,
|
|
duration: pm.Duration,
|
|
caller: pm.Caller,
|
|
status: STATUS_ORIGINAL,
|
|
resultCh: pm.ResultCh,
|
|
obs: obs,
|
|
}
|
|
if err := a.init(); err != nil {
|
|
return nil, err
|
|
}
|
|
return a, nil
|
|
}
|
|
|
|
func (a *assistant) init() error {
|
|
logrus.Debug("AI助手初始化")
|
|
//并发量的计算
|
|
//并发量 ≈ 超时时间 / 发送的间隔时间
|
|
var total = int64(a.timeout)/int64(1e9/a.lps) + 1
|
|
if total > math.MaxInt32 {
|
|
total = math.MaxInt32
|
|
}
|
|
|
|
a.concurrency = uint32(total)
|
|
gp, err := NewGoPool(a.concurrency)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
a.goPool = gp
|
|
logrus.WithField("并发量", a.concurrency).Debug("AI助手初始化完成 ")
|
|
return nil
|
|
}
|
|
|
|
func (a *assistant) callOne(req *RawReq) *RawResp {
|
|
atomic.AddInt64(&a.callCount, 1)
|
|
logrus.WithField("count", &a.callCount).WithField("Len", len(req.Req)).Debug("调用协议")
|
|
if req == nil {
|
|
return &RawResp{ID: -1, Err: errors.New("无效的请求")}
|
|
}
|
|
var rawResp RawResp
|
|
start := time.Now().UnixNano()
|
|
resp, err := a.caller.Call(req.Req)
|
|
end := time.Now().UnixNano()
|
|
elapsedTime := time.Duration(end - start)
|
|
if err != nil {
|
|
errMsg := fmt.Sprintf("调用失败: %v", err)
|
|
rawResp = RawResp{
|
|
ID: req.ID,
|
|
Err: errors.New(errMsg),
|
|
Elapse: elapsedTime,
|
|
}
|
|
} else {
|
|
rawResp = RawResp{
|
|
ID: req.ID,
|
|
Resp: resp,
|
|
Elapse: elapsedTime,
|
|
}
|
|
}
|
|
return &rawResp
|
|
}
|
|
|
|
// 异步调用接口
|
|
func (a *assistant) asyncCall() {
|
|
a.goPool.Take()
|
|
go func() {
|
|
defer func() {
|
|
if p := recover(); p != nil {
|
|
err, ok := interface{}(p).(error)
|
|
var errMsg string
|
|
if ok {
|
|
errMsg = fmt.Sprintf("Async Call Panic! (error: %s)", err)
|
|
} else {
|
|
errMsg = fmt.Sprintf("Async Call Panic! (clue: %#v)", p)
|
|
}
|
|
logrus.Errorln(errMsg)
|
|
result := &CallResult{
|
|
Id: -1,
|
|
Code: RES_CODE_FATAL_CALL,
|
|
Message: errMsg,
|
|
}
|
|
a.sendResult(result)
|
|
}
|
|
a.goPool.Return()
|
|
}()
|
|
|
|
req := a.caller.BuildReq()
|
|
if req.ID == 0 {
|
|
return
|
|
}
|
|
//调用状态 0未调用 1调用结束 2调用超时
|
|
var callStatus uint32
|
|
|
|
// 超时处理
|
|
timer := time.AfterFunc(a.timeout, func() {
|
|
if !atomic.CompareAndSwapUint32(&callStatus, 0, 2) {
|
|
return
|
|
}
|
|
|
|
result := &CallResult{
|
|
Id: req.ID,
|
|
Req: req,
|
|
Code: RES_CODE_CALL_TIMEOUT,
|
|
Message: fmt.Sprintf("超时,期望小于 %v", a.timeout),
|
|
Elapse: a.timeout,
|
|
}
|
|
a.sendResult(result)
|
|
})
|
|
resp := a.callOne(&req)
|
|
|
|
logrus.WithField("耗时", resp.Elapse).Debug("实际耗时")
|
|
if !atomic.CompareAndSwapUint32(&callStatus, 0, 1) {
|
|
return
|
|
}
|
|
|
|
timer.Stop()
|
|
|
|
var result *CallResult
|
|
if resp.Err != nil {
|
|
result = &CallResult{
|
|
Id: req.ID,
|
|
Req: req,
|
|
Code: RES_CODE_ERROR_CALL,
|
|
Message: resp.Err.Error(),
|
|
Elapse: resp.Elapse,
|
|
}
|
|
} else {
|
|
result = a.caller.Check(req, *resp)
|
|
result.Elapse = resp.Elapse
|
|
}
|
|
a.sendResult(result)
|
|
}()
|
|
}
|
|
|
|
// 停止发送
|
|
func (a *assistant) prepareStop(ctxErr error) {
|
|
logrus.WithField("cause", ctxErr).Info("准备停止")
|
|
atomic.CompareAndSwapUint32(&a.status, STATUS_STARTED, STATUS_STOPPING)
|
|
logrus.Debug("关闭结果通道")
|
|
close(a.resultCh)
|
|
atomic.StoreUint32(&a.status, STATUS_STOPPED)
|
|
}
|
|
|
|
// 发送请求即调用接口
|
|
func (a *assistant) handleReq(throttle <-chan time.Time) {
|
|
for {
|
|
select {
|
|
case <-a.ctx.Done():
|
|
a.prepareStop(a.ctx.Err())
|
|
return
|
|
default:
|
|
}
|
|
a.asyncCall()
|
|
if a.lps > 0 {
|
|
select {
|
|
case <-throttle:
|
|
case <-a.ctx.Done():
|
|
a.prepareStop(a.ctx.Err())
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (a *assistant) sendResult(result *CallResult) bool {
|
|
if atomic.LoadUint32(&a.status) != STATUS_STARTED {
|
|
a.printResult(result, "已停止")
|
|
return false
|
|
}
|
|
select {
|
|
case a.resultCh <- result:
|
|
return true
|
|
default:
|
|
a.printResult(result, "结果通道已满")
|
|
return false
|
|
}
|
|
}
|
|
|
|
func (a *assistant) printResult(result *CallResult, cause string) {
|
|
resultMsg := fmt.Sprintf(
|
|
"Id:%d,Code=%d,Msg=%s,Elapse=%v",
|
|
result.Id, result.Code, result.Message, result.Elapse)
|
|
logrus.Warnf("result:%s (cause:%s)", resultMsg, cause)
|
|
}
|
|
|
|
// 启动AI助手
|
|
func (a *assistant) Start() bool {
|
|
logrus.Debug("AI助手启动")
|
|
|
|
// 节流 周期性向目标发送
|
|
var throttle <-chan time.Time
|
|
if a.lps > 0 {
|
|
//间隔时间
|
|
interval := time.Duration(1e9 / a.lps)
|
|
logrus.Debugf("启动节流控制 间隔: %v", interval)
|
|
throttle = time.Tick(interval)
|
|
}
|
|
|
|
// 初始化上下文和设置取消函数
|
|
a.ctx, a.cancelFunc = context.WithTimeout(context.Background(), a.duration)
|
|
|
|
// 重置调用次数
|
|
a.callCount = 0
|
|
|
|
// 设置状态为已启动
|
|
atomic.StoreUint32(&a.status, STATUS_STARTED)
|
|
|
|
go func() {
|
|
logrus.Debug("请求处理中...")
|
|
a.handleReq(throttle)
|
|
logrus.Infof("停止 调用次数:%d", a.callCount)
|
|
}()
|
|
|
|
return true
|
|
}
|
|
|
|
// 手动停止
|
|
func (a *assistant) Stop() error {
|
|
return nil
|
|
}
|
|
|
|
func (a *assistant) ShowResult() {
|
|
statistics := &Statistics{}
|
|
max := statistics.MaxElapse
|
|
min := statistics.MinElapse
|
|
|
|
for r := range a.resultCh {
|
|
if r.Code != RES_CODE_SUCCESS {
|
|
logrus.WithField("result", r.Code).Debug("失败的结果")
|
|
continue
|
|
}
|
|
msg := &pb.UserMessage{}
|
|
// logrus.Debugf("结果字节长度 %d", len(r.Resp.Resp))
|
|
if err := proto.Unmarshal(r.Resp.Resp, msg); err != nil {
|
|
logrus.Error("结果解析失败")
|
|
continue
|
|
}
|
|
// 协议名
|
|
statistics.Route = fmt.Sprintf("%s.%s", msg.MainType, msg.SubType)
|
|
// 总耗时
|
|
statistics.ElapseTotal += float64(r.Elapse.Nanoseconds())
|
|
|
|
if float64(r.Elapse.Nanoseconds()) > max {
|
|
max = float64(r.Elapse.Nanoseconds())
|
|
} else {
|
|
min = float64(r.Elapse.Nanoseconds())
|
|
}
|
|
logrus.WithFields(logrus.Fields{"mainType": msg.MainType, "subType": msg.SubType, "耗时": r.Elapse}).Info("结果")
|
|
}
|
|
if a.callCount == 1 {
|
|
min = max
|
|
}
|
|
// 调用次数
|
|
statistics.CallCount = a.callCount
|
|
statistics.ElapseTotal = common.FormatFloatCommon(statistics.ElapseTotal / 1e6)
|
|
statistics.MaxElapse = common.FormatFloatCommon(max / 1e6)
|
|
statistics.MinElapse = common.FormatFloatCommon(min / 1e6)
|
|
//平均耗时=总耗时/调用次数
|
|
statistics.AvgElapse = common.FormatFloatCommon(statistics.ElapseTotal / float64(statistics.CallCount))
|
|
|
|
a.obs.Notify(observer.EVENT_RESULT, statistics)
|
|
}
|