go_dreamfactory/cmd/v2/lib/assistant.go
2022-12-06 09:28:37 +08:00

256 lines
5.3 KiB
Go

package lib
import (
"context"
"errors"
"fmt"
"go_dreamfactory/pb"
"math"
"sync/atomic"
"time"
"github.com/sirupsen/logrus"
"google.golang.org/protobuf/proto"
)
// assistant drives repeated calls against a Handler at a configured rate for
// a configured duration and publishes each outcome on resultCh.
type assistant struct {
	// timeout is the per-call processing timeout.
	timeout time.Duration

	// lps is the number of requests issued per second.
	lps uint32

	// duration is how long a run lasts.
	duration time.Duration

	// concurrency is the number of calls allowed in flight; computed in init.
	concurrency uint32

	// ctx is the context of the current run; created in Start.
	ctx context.Context

	// callCount is the number of calls made; reset on every start and
	// incremented atomically by callOne.
	callCount int64

	// goPool is the goroutine pool that bounds in-flight calls.
	goPool GoPool

	// cancelFunc cancels ctx.
	cancelFunc context.CancelFunc

	// caller is the handler that builds, performs and checks calls.
	caller Handler

	// status is the current lifecycle state (one of the STATUS_* values).
	status uint32

	// resultCh receives the result of every call.
	resultCh chan *CallResult
}
// NewAssistant validates pm, builds an assistant from it and initializes the
// assistant's goroutine pool.
//
// It returns the error reported by pm.Check or by the initialization step;
// on success the returned error is nil.
func NewAssistant(pm ParamMgr) (Aiassistant, error) {
	err := pm.Check()
	if err != nil {
		return nil, err
	}

	ast := new(assistant)
	ast.timeout = pm.Timeout
	ast.lps = pm.Lps
	ast.duration = pm.Duration
	ast.caller = pm.Caller
	ast.status = STATUS_ORIGINAL
	ast.resultCh = pm.ResultCh

	if err = ast.init(); err != nil {
		return nil, err
	}
	return ast, nil
}
// init computes the concurrency level from the timeout and the request rate
// and creates the goroutine pool with that size.
//
// Concurrency is approximately timeout / send interval, plus one, clamped to
// the range [1, math.MaxInt32].
//
// It returns an error when lps is zero (the send interval would be undefined)
// or when the goroutine pool cannot be created.
func (a *assistant) init() error {
	logrus.Info("AI助手初始化")

	if a.lps == 0 {
		// Dividing by a zero rate would panic; report it instead.
		return errors.New("lps must be greater than zero")
	}

	// Send interval in nanoseconds. For rates above 1e9 the integer division
	// truncates to 0, which would make the next division panic, so use 1ns
	// as the smallest interval.
	interval := int64(1e9 / a.lps)
	if interval < 1 {
		interval = 1
	}

	total := int64(a.timeout)/interval + 1
	if total > math.MaxInt32 {
		total = math.MaxInt32
	}
	if total < 1 {
		// A negative timeout would otherwise wrap around in the uint32
		// conversion below.
		total = 1
	}
	a.concurrency = uint32(total)

	gp, err := NewGoPool(a.concurrency)
	if err != nil {
		return err
	}
	a.goPool = gp

	logrus.WithField("并发量", a.concurrency).Info("AI助手初始化完成 并发量 ")
	return nil
}
// callOne performs a single call through the handler and reports its outcome.
//
// The call counter is incremented atomically for every invocation, including
// those rejected because req is nil. For a nil req the response has ID -1
// and a non-nil Err. Otherwise the response carries the request ID, the time
// spent in the handler's Call, and either the response payload or an error
// that wraps the one returned by the handler.
func (a *assistant) callOne(req *RawReq) *RawResp {
	atomic.AddInt64(&a.callCount, 1)
	if req == nil {
		return &RawResp{ID: -1, Err: errors.New("无效的请求")}
	}

	// time.Since uses the monotonic clock reading, so the measurement is not
	// skewed by wall-clock adjustments during the call.
	start := time.Now()
	resp, err := a.caller.Call(req.Req, a.timeout)
	elapsed := time.Since(start)

	if err != nil {
		return &RawResp{
			ID: req.ID,
			// %w keeps the original error reachable via errors.Is/As; the
			// rendered message is unchanged.
			Err:    fmt.Errorf("调用失败: %w", err),
			Elapse: elapsed,
		}
	}
	return &RawResp{
		ID:     req.ID,
		Resp:   resp,
		Elapse: elapsed,
	}
}
// asyncCall takes a slot from the goroutine pool and performs one call on a
// new goroutine, returning the slot when that goroutine finishes.
//
// A timer armed with the configured timeout races against the call. Whichever
// side first moves the call state away from "pending" reports the result:
// the timer reports a timeout result, the call reports either an error
// result or the result produced by the handler's Check.
func (a *assistant) asyncCall() {
	a.goPool.Take()
	go func() {
		defer a.goPool.Return()

		// Call state: pending -> finished (call returned first) or
		// pending -> expired (timer fired first).
		const (
			pending  uint32 = 0
			finished uint32 = 1
			expired  uint32 = 2
		)
		var state uint32

		rawReq := a.caller.BuildReq()

		// Timeout handling: only report when the call has not finished yet.
		onTimeout := func() {
			if atomic.CompareAndSwapUint32(&state, pending, expired) {
				a.sendResult(&CallResult{
					Id:      rawReq.ID,
					Req:     rawReq,
					Code:    RES_CODE_CALL_TIMEOUT,
					Message: fmt.Sprintf("超时,期望< %v", a.timeout),
					Elapse:  a.timeout,
				})
			}
		}
		watchdog := time.AfterFunc(a.timeout, onTimeout)

		rawResp := a.callOne(&rawReq)
		logrus.WithField("耗时", rawResp.Elapse).Debug("实际耗时")

		// The timer already reported a timeout; drop this response.
		if !atomic.CompareAndSwapUint32(&state, pending, finished) {
			return
		}
		watchdog.Stop()

		if rawResp.Err == nil {
			checked := a.caller.Check(rawReq, *rawResp)
			checked.Elapse = rawResp.Elapse
			a.sendResult(checked)
			return
		}

		a.sendResult(&CallResult{
			Id:      rawReq.ID,
			Req:     rawReq,
			Code:    RES_CODE_ERROR_CALL,
			Message: rawResp.Err.Error(),
			Elapse:  rawResp.Elapse,
		})
	}()
}
// prepareStop shuts down result delivery: it moves the status from started
// to stopping (when it is currently started), closes the result channel and
// finally marks the assistant as stopped. ctxErr is only logged as the cause.
func (a *assistant) prepareStop(ctxErr error) {
	logrus.WithFields(logrus.Fields{"cause": ctxErr}).Info("准备停止")

	// The swap result is deliberately ignored: the channel is closed and the
	// status set to stopped regardless of the previous state.
	_ = atomic.CompareAndSwapUint32(&a.status, STATUS_STARTED, STATUS_STOPPING)

	logrus.Info("关闭结果通道")
	close(a.resultCh)

	atomic.StoreUint32(&a.status, STATUS_STOPPED)
}
// handleReq is the send loop: it issues calls until the run context is done.
//
// Each iteration first checks the context, then starts one asynchronous call.
// When a request rate is configured (lps > 0) it waits for the next value
// from tick, or for the context to finish, before the next iteration. When
// the context is done, prepareStop is invoked with the context error and the
// loop ends.
func (a *assistant) handleReq(tick <-chan time.Time) {
	for {
		// Non-blocking check: Err is non-nil exactly when Done is closed.
		if err := a.ctx.Err(); err != nil {
			a.prepareStop(err)
			return
		}

		a.asyncCall()

		if a.lps == 0 {
			// No throttling configured; issue the next call immediately.
			continue
		}

		select {
		case <-a.ctx.Done():
			a.prepareStop(a.ctx.Err())
			return
		case <-tick:
		}
	}
}
// sendResult tries to publish result on the result channel without blocking.
//
// It returns true only when the result was handed to the channel. It returns
// false when the assistant is not in the started state, when the channel
// cannot accept the value right now, or when the channel was closed while
// the send was attempted.
func (a *assistant) sendResult(result *CallResult) (sent bool) {
	if atomic.LoadUint32(&a.status) != STATUS_STARTED {
		return false
	}

	// The status check above and the send below are not atomic: prepareStop
	// may close resultCh in between, and in-flight calls or timeout timers
	// still reach this point. Sending on a closed channel panics, so turn
	// that case into a dropped result instead of crashing the process.
	defer func() {
		if r := recover(); r != nil {
			logrus.WithField("panic", r).Debug("result dropped: result channel closed")
			sent = false
		}
	}()

	select {
	case a.resultCh <- result:
		return true
	default:
		return false
	}
}
// registUser is the placeholder for account registration; it currently has
// an empty body and does nothing.
func (a *assistant) registUser() {
}
// login is the placeholder for account login; it currently has an empty body
// and does nothing.
func (a *assistant) login() {
}
// Start launches the assistant: it sets up throttling, creates the run
// context bounded by the configured duration, resets the call counter, marks
// the assistant as started and runs the send loop on a new goroutine.
//
// Start does not wait for the run to finish and always returns true.
func (a *assistant) Start() bool {
	logrus.Infoln("AI助手启动")

	// Throttling: send to the target periodically.
	var (
		ticker *time.Ticker
		tick   <-chan time.Time
	)
	if a.lps > 0 {
		// Interval between two sends. For rates above 1e9 the integer
		// division truncates to 0, which NewTicker rejects, so fall back to
		// the smallest positive interval.
		interval := time.Duration(1e9 / a.lps)
		if interval <= 0 {
			interval = time.Nanosecond
		}
		logrus.Infof("启动节流控制 间隔: %v", interval)
		ticker = time.NewTicker(interval)
		tick = ticker.C
	}

	// Initialize the context and keep its cancel function.
	ctx, cancel := context.WithTimeout(context.Background(), a.duration)
	a.ctx, a.cancelFunc = ctx, cancel

	// Reset the call counter; callOne updates it atomically from other
	// goroutines, so access it atomically here as well.
	atomic.StoreInt64(&a.callCount, 0)

	// Mark the assistant as started.
	atomic.StoreUint32(&a.status, STATUS_STARTED)

	go func() {
		// Release the context's resources and the ticker once the send loop
		// has ended.
		defer cancel()
		if ticker != nil {
			defer ticker.Stop()
		}

		logrus.Infoln("请求处理...")
		a.handleReq(tick)
		logrus.Infof("停止 调用次数:%d", atomic.LoadInt64(&a.callCount))
	}()
	return true
}
// Stop manually ends a running assistant by cancelling the run context; the
// send loop observes the cancellation and performs the actual shutdown.
//
// Stop does not wait for the shutdown to complete. Calling it when the
// assistant is not in the started state is a no-op. It always returns nil.
func (a *assistant) Stop() error {
	// Only a started assistant has a live context to cancel. Moving the
	// status to stopping also keeps late results from being published.
	if !atomic.CompareAndSwapUint32(&a.status, STATUS_STARTED, STATUS_STOPPING) {
		return nil
	}
	if a.cancelFunc != nil {
		a.cancelFunc()
	}
	return nil
}
// ShowResult consumes the result channel until it is closed and logs every
// result.
//
// Results whose code is not RES_CODE_SUCCESS are logged at debug level with
// their code. Successful results are decoded as a pb.UserMessage and their
// main and sub types are logged; a decoding failure is logged together with
// the underlying error and the result is skipped.
func (a *assistant) ShowResult() {
	for r := range a.resultCh {
		if r.Code != RES_CODE_SUCCESS {
			logrus.WithField("result", r.Code).Debug("失败的结果")
			continue
		}

		msg := &pb.UserMessage{}
		if err := proto.Unmarshal(r.Resp.Resp, msg); err != nil {
			// Attach the cause so a decoding failure can be diagnosed.
			logrus.WithError(err).Error("结果解析失败")
			continue
		}

		logrus.WithFields(logrus.Fields{
			"mainType": msg.MainType,
			"subType":  msg.SubType,
		}).Debug("读取结果")
	}
}