package lib

import (
	"context"
	"errors"
	"fmt"
	"math"
	"sync"
	"sync/atomic"
	"time"

	"go_dreamfactory/cmd/v2/lib/common"
	"go_dreamfactory/cmd/v2/service/observer"
	"go_dreamfactory/pb"

	"github.com/sirupsen/logrus"
	"google.golang.org/protobuf/proto"
)

// assistant issues calls through a Handler at a target rate (lps) for a
// fixed duration, enforces a per-call timeout, and publishes each call's
// outcome on resultCh.
type assistant struct {
	timeout     time.Duration      // per-call timeout
	lps         uint32             // target requests per second
	duration    time.Duration      // how long one run lasts
	concurrency uint32             // goroutine pool size, computed in init
	ctx         context.Context    // run context, created by Start
	callCount   int64              // calls issued; reset by Start; accessed atomically
	goPool      GoPool             // goroutine pool (Take before a call, Return after)
	cancelFunc  context.CancelFunc // cancels ctx
	caller      Handler            // builds, sends and checks requests
	status      uint32             // lifecycle status (STATUS_*); accessed atomically
	resultCh    chan *CallResult   // call results
	obs         observer.Observer  // receives the aggregated statistics

	// resultMu orders closing resultCh against sends to it: senders hold the
	// read lock while checking status and sending, and prepareStop holds the
	// write lock while closing, so nothing is ever sent on a closed channel.
	resultMu sync.RWMutex
}

// NewAssistant validates pm, builds an assistant from it and sizes its
// goroutine pool. obs is notified with the statistics produced by ShowResult.
func NewAssistant(obs observer.Observer, pm ParamMgr) (Aiassistant, error) {
	if err := pm.Check(); err != nil {
		return nil, err
	}

	a := &assistant{
		timeout:  pm.Timeout,
		lps:      pm.Lps,
		duration: pm.Duration,
		caller:   pm.Caller,
		status:   STATUS_ORIGINAL,
		resultCh: pm.ResultCh,
		obs:      obs,
	}
	if err := a.init(); err != nil {
		return nil, err
	}
	return a, nil
}

// init computes the concurrency and creates the goroutine pool.
// Concurrency ≈ timeout / send interval + 1: the number of calls that can be
// in flight at once when each may run up to the timeout. The result is
// clamped to [1, math.MaxInt32].
func (a *assistant) init() error {
	logrus.Debug("AI助手初始化")

	if a.lps == 0 {
		// The send interval is 1s/lps; reject instead of dividing by zero.
		return errors.New("lps must be greater than 0")
	}
	interval := int64(time.Second) / int64(a.lps)
	if interval == 0 {
		// lps above 1e9 rounds the interval down to 0ns; clamp to 1ns.
		interval = 1
	}

	total := int64(a.timeout)/interval + 1
	if total > math.MaxInt32 {
		total = math.MaxInt32
	}
	if total < 1 {
		total = 1
	}
	a.concurrency = uint32(total)

	gp, err := NewGoPool(a.concurrency)
	if err != nil {
		return err
	}
	a.goPool = gp

	logrus.WithField("并发量", a.concurrency).Debug("AI助手初始化完成 ")
	return nil
}

// callOne performs one synchronous call through a.caller, increments the
// call counter and measures the elapsed time. A nil req yields a RawResp
// with ID -1 and an error. A call error is wrapped into RawResp.Err.
func (a *assistant) callOne(req *RawReq) *RawResp {
	// Check for nil before touching req: the debug log below reads req.Req.
	if req == nil {
		return &RawResp{ID: -1, Err: errors.New("无效的请求")}
	}

	count := atomic.AddInt64(&a.callCount, 1)
	logrus.WithField("count", count).WithField("Len", len(req.Req)).Debug("调用协议")

	start := time.Now()
	resp, err := a.caller.Call(req.Req)
	elapsed := time.Since(start)

	if err != nil {
		return &RawResp{
			ID:     req.ID,
			Err:    fmt.Errorf("调用失败: %w", err),
			Elapse: elapsed,
		}
	}
	return &RawResp{
		ID:     req.ID,
		Resp:   resp,
		Elapse: elapsed,
	}
}

// asyncCall takes a slot from the goroutine pool and runs one call in a new
// goroutine, returning the slot when that goroutine ends. A request whose ID
// is 0 is skipped. The timeout timer and the normal completion path race on
// callStatus, so only one of them reports the call. A panic is recovered and
// reported as RES_CODE_FATAL_CALL.
func (a *assistant) asyncCall() {
	a.goPool.Take()
	go func() {
		defer func() {
			if p := recover(); p != nil {
				var errMsg string
				if err, ok := p.(error); ok {
					errMsg = fmt.Sprintf("Async Call Panic! (error: %s)", err)
				} else {
					errMsg = fmt.Sprintf("Async Call Panic! (clue: %#v)", p)
				}
				logrus.Errorln(errMsg)
				a.sendResult(&CallResult{
					Id:      -1,
					Code:    RES_CODE_FATAL_CALL,
					Message: errMsg,
				})
			}
			a.goPool.Return()
		}()

		req := a.caller.BuildReq()
		if req.ID == 0 {
			return
		}

		// Call status: 0 = not finished, 1 = finished, 2 = timed out.
		var callStatus uint32

		// Report a timeout unless the call has already finished.
		timer := time.AfterFunc(a.timeout, func() {
			if !atomic.CompareAndSwapUint32(&callStatus, 0, 2) {
				return
			}
			a.sendResult(&CallResult{
				Id:      req.ID,
				Req:     req,
				Code:    RES_CODE_CALL_TIMEOUT,
				Message: fmt.Sprintf("超时,期望小于 %v", a.timeout),
				Elapse:  a.timeout,
			})
		})

		resp := a.callOne(&req)
		logrus.WithField("耗时", resp.Elapse).Debug("实际耗时")

		// The timeout already reported this call; drop the late response.
		if !atomic.CompareAndSwapUint32(&callStatus, 0, 1) {
			return
		}
		timer.Stop()

		var result *CallResult
		if resp.Err != nil {
			result = &CallResult{
				Id:      req.ID,
				Req:     req,
				Code:    RES_CODE_ERROR_CALL,
				Message: resp.Err.Error(),
				Elapse:  resp.Elapse,
			}
		} else {
			result = a.caller.Check(req, *resp)
			result.Elapse = resp.Elapse
		}
		a.sendResult(result)
	}()
}

// prepareStop ends a run. It moves the status to STATUS_STOPPED and closes
// resultCh, which terminates the range loop in ShowResult. It holds the write
// lock so no sendResult is mid-send during the close. A repeated call is a
// no-op instead of a double-close panic.
func (a *assistant) prepareStop(ctxErr error) {
	logrus.WithField("cause", ctxErr).Info("准备停止")

	a.resultMu.Lock()
	defer a.resultMu.Unlock()

	if atomic.LoadUint32(&a.status) == STATUS_STOPPED {
		return // resultCh is already closed
	}
	atomic.StoreUint32(&a.status, STATUS_STOPPING)
	logrus.Debug("关闭结果通道")
	close(a.resultCh)
	atomic.StoreUint32(&a.status, STATUS_STOPPED)
}

// handleReq is the request loop. Until a.ctx is done, it issues one
// asynchronous call per iteration. When lps > 0 it waits for a tick on
// throttle between calls. When the context ends it calls prepareStop with the
// context's error.
func (a *assistant) handleReq(throttle <-chan time.Time) {
	for {
		select {
		case <-a.ctx.Done():
			a.prepareStop(a.ctx.Err())
			return
		default:
		}

		a.asyncCall()

		if a.lps > 0 {
			select {
			case <-throttle:
			case <-a.ctx.Done():
				a.prepareStop(a.ctx.Err())
				return
			}
		}
	}
}

// sendResult delivers result on resultCh without blocking. It returns false,
// and logs the result instead, when the assistant is no longer running or the
// channel is full. The read lock keeps prepareStop from closing the channel
// between the status check and the send. The send never blocks, so the lock
// is held only briefly.
func (a *assistant) sendResult(result *CallResult) bool {
	a.resultMu.RLock()
	defer a.resultMu.RUnlock()

	if atomic.LoadUint32(&a.status) != STATUS_STARTED {
		a.printResult(result, "已停止")
		return false
	}

	select {
	case a.resultCh <- result:
		return true
	default:
		a.printResult(result, "结果通道已满")
		return false
	}
}

// printResult logs a result that could not be delivered, along with the
// reason.
func (a *assistant) printResult(result *CallResult, cause string) {
	resultMsg := fmt.Sprintf(
		"Id:%d,Code=%d,Msg=%s,Elapse=%v",
		result.Id, result.Code, result.Message, result.Elapse)
	logrus.Warnf("result:%s (cause:%s)", resultMsg, cause)
}

// Start begins a run. It creates a context that expires after a.duration and
// issues calls from a background goroutine, throttled to lps calls per
// second. It returns false unless the assistant is still in its initial
// state: resultCh is closed when a run ends, so a run cannot be repeated or
// started twice.
func (a *assistant) Start() bool {
	if !atomic.CompareAndSwapUint32(&a.status, STATUS_ORIGINAL, STATUS_STARTED) {
		return false
	}
	logrus.Debug("AI助手启动")

	// Throttle: emit one tick per send interval.
	var (
		throttle <-chan time.Time
		ticker   *time.Ticker
	)
	if a.lps > 0 {
		interval := time.Second / time.Duration(a.lps)
		if interval <= 0 {
			interval = 1 // NewTicker panics on non-positive intervals
		}
		logrus.Debugf("启动节流控制 间隔: %v", interval)
		ticker = time.NewTicker(interval)
		throttle = ticker.C
	}

	ctx, cancel := context.WithTimeout(context.Background(), a.duration)
	a.ctx, a.cancelFunc = ctx, cancel

	atomic.StoreInt64(&a.callCount, 0)

	go func() {
		// Release the context's timer and the ticker once the loop exits.
		defer cancel()
		if ticker != nil {
			defer ticker.Stop()
		}

		logrus.Debug("请求处理中...")
		a.handleReq(throttle)
		logrus.Infof("停止 调用次数:%d", atomic.LoadInt64(&a.callCount))
	}()
	return true
}

// Stop cancels the current run. The request loop then stops issuing calls and
// closes resultCh. It is a no-op if Start has not been called.
func (a *assistant) Stop() error {
	if a.cancelFunc != nil {
		a.cancelFunc()
	}
	return nil
}

// ShowResult drains resultCh until it is closed at the end of a run. It logs
// each successful result and aggregates latency over the successful results
// that decode as pb.UserMessage. It then notifies the observer with
// EVENT_RESULT. Elapse figures are reported in milliseconds (ns / 1e6).
// CallCount is the total number of calls issued, including failed ones.
func (a *assistant) ShowResult() {
	statistics := &Statistics{}

	var (
		totalNs float64 // sum of successful elapse times, in nanoseconds
		maxNs   float64
		minNs   float64
		okCount int64 // number of samples folded into the figures above
	)

	for r := range a.resultCh {
		if r.Code != RES_CODE_SUCCESS {
			logrus.WithField("result", r.Code).Debug("失败的结果")
			continue
		}

		msg := &pb.UserMessage{}
		if err := proto.Unmarshal(r.Resp.Resp, msg); err != nil {
			logrus.Error("结果解析失败")
			continue
		}

		// Protocol route.
		statistics.Route = fmt.Sprintf("%s.%s", msg.MainType, msg.SubType)

		elapsed := float64(r.Elapse.Nanoseconds())
		totalNs += elapsed
		// The first sample seeds both min and max.
		if okCount == 0 || elapsed > maxNs {
			maxNs = elapsed
		}
		if okCount == 0 || elapsed < minNs {
			minNs = elapsed
		}
		okCount++

		logrus.WithFields(logrus.Fields{"mainType": msg.MainType, "subType": msg.SubType, "耗时": r.Elapse}).Info("结果")
	}

	statistics.CallCount = atomic.LoadInt64(&a.callCount)
	statistics.ElapseTotal = common.FormatFloatCommon(totalNs / 1e6)
	statistics.MaxElapse = common.FormatFloatCommon(maxNs / 1e6)
	statistics.MinElapse = common.FormatFloatCommon(minNs / 1e6)
	// Average over the samples actually summed. This avoids NaN when there
	// were none, and avoids deflating the mean with failed calls.
	if okCount > 0 {
		statistics.AvgElapse = common.FormatFloatCommon(totalNs / float64(okCount) / 1e6)
	}

	a.obs.Notify(observer.EVENT_RESULT, statistics)
}