From c89ad0d2e4be62d498ee9508e22765202184e49e Mon Sep 17 00:00:00 2001 From: liwei1dao Date: Mon, 19 Dec 2022 10:30:30 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E5=B9=BF=E6=92=AD=E6=B6=88?= =?UTF-8?q?=E6=81=AF=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- modules/gateway/agent.go | 37 +++++++++++++++++++++----------- modules/gateway/agentmgr_comp.go | 24 ++++++++++++++++----- modules/gateway/core.go | 1 + 3 files changed, 45 insertions(+), 17 deletions(-) diff --git a/modules/gateway/agent.go b/modules/gateway/agent.go index 59181cb3c..5fc7b4bd0 100644 --- a/modules/gateway/agent.go +++ b/modules/gateway/agent.go @@ -33,7 +33,7 @@ func newAgent(gateway IGateway, conn *websocket.Conn) *Agent { wsConn: conn, sessionId: id.NewUUId(), uId: "", - writeChan: make(chan *pb.UserMessage, 2), + writeChan: make(chan []byte, 2), closeSignal: make(chan bool), state: 1, } @@ -50,7 +50,7 @@ type Agent struct { sessionId string uId string wId string - writeChan chan *pb.UserMessage + writeChan chan []byte closeSignal chan bool state int32 //状态 0 关闭 1 运行 2 关闭中 wg sync.WaitGroup @@ -102,8 +102,8 @@ locp: func (this *Agent) writeLoop() { defer this.wg.Done() var ( - data []byte - err error + // data []byte + err error ) locp: for { @@ -112,8 +112,8 @@ locp: break locp case msg, ok := <-this.writeChan: if ok { - data, err = proto.Marshal(msg) - if err = this.wsConn.WriteMessage(websocket.BinaryMessage, data); err != nil { + // data, err = proto.Marshal(msg) + if err = this.wsConn.WriteMessage(websocket.BinaryMessage, msg); err != nil { this.gateway.Errorf("agent:%s uId:%d WriteMessage err:%v", this.sessionId, this.uId, err) go this.Close() } @@ -128,7 +128,7 @@ locp: //安全认证 所有协议 func (this *Agent) secAuth(msg *pb.UserMessage) (code pb.ErrorCode, err error) { if !utils.ValidSecretKey(msg.Sec) { //验证失败 - log.Errorf("%v", msg.Sec) + this.gateway.Errorf("%v", msg.Sec) return pb.ErrorCode_SignError, fmt.Errorf("key invalid") } return this.decodeUserData(msg) @@ -139,7 +139,7 @@ func (this *Agent) decodeUserData(msg *pb.UserMessage) (code pb.ErrorCode, err e base64Str := msg.Sec dec, err := base64.StdEncoding.DecodeString(base64Str[35:]) if err != nil { - log.Errorf("base64 decode err %v", err) + this.gateway.Errorf("base64 decode err %v", err) return pb.ErrorCode_DecodeError, nil } now := configure.Now().Unix() @@ -147,7 +147,7 @@ func (this *Agent) decodeUserData(msg *pb.UserMessage) (code pb.ErrorCode, err e timestamp := jsonRet.Get("timestamp").Int() //秘钥30秒失效 if now-time.Unix(timestamp, 0).Unix() > 30 { - log.Errorf("last timestamp:%v more than 30s", timestamp) + this.gateway.Errorf("last timestamp:%v more than 30s", timestamp) return pb.ErrorCode_TimestampTimeout, fmt.Errorf("sec key expire") } @@ -161,13 +161,13 @@ func (this *Agent) decodeUserData(msg *pb.UserMessage) (code pb.ErrorCode, err e } ad, err := anypb.New(req) if err != nil { - log.Errorf("decodeUserData pb err:%v", err) + this.gateway.Errorf("decodeUserData pb err:%v", err) return pb.ErrorCode_PbError, err } msg.Data = ad } else { if msg.MainType != string(comm.ModuleNotify) && this.UserId() == "" { - log.Errorf("[%v.%v] Agent UId empty", msg.MainType, msg.SubType) + this.gateway.Errorf("[%v.%v] Agent UId empty", msg.MainType, msg.SubType) return pb.ErrorCode_AgentUidEmpty, fmt.Errorf("no login") } } @@ -202,7 +202,20 @@ func (this *Agent) WriteMsg(msg *pb.UserMessage) (err error) { if atomic.LoadInt32(&this.state) != 1 { return } - this.writeChan <- msg + var ( + data []byte + ) + if data, err = proto.Marshal(msg); err == nil { + this.writeChan <- data + } + return +} + +func (this *Agent) WriteBytes(data []byte) (err error) { + if atomic.LoadInt32(&this.state) != 1 { + return + } + this.writeChan <- data return } diff --git a/modules/gateway/agentmgr_comp.go b/modules/gateway/agentmgr_comp.go index cf2a15027..a0caf7137 100644 --- a/modules/gateway/agentmgr_comp.go +++ b/modules/gateway/agentmgr_comp.go @@ -11,6 +11,8 @@ import ( "go_dreamfactory/lego/core" "go_dreamfactory/lego/core/cbase" "go_dreamfactory/lego/sys/log" + + "google.golang.org/protobuf/proto" ) /* @@ -128,18 +130,24 @@ func (this *AgentMgrComp) SendMsgToAgent(ctx context.Context, args *pb.AgentSend } // SendMsgToAgents 向多个户发送消息 -func (this *AgentMgrComp) SendMsgToAgents(ctx context.Context, args *pb.BatchMessageReq, reply *pb.RPCMessageReply) error { +func (this *AgentMgrComp) SendMsgToAgents(ctx context.Context, args *pb.BatchMessageReq, reply *pb.RPCMessageReply) (err error) { + var ( + data []byte + ) msg := &pb.UserMessage{ MainType: args.MainType, SubType: args.SubType, Data: args.Data, } this.module.Debugf("SendMsgToAgents: agents:%v msg:%v", args.UserSessionIds, msg) + if data, err = proto.Marshal(msg); err != nil { + return + } for _, v := range args.UserSessionIds { if a, ok := this.agents.Load(v); ok { agent := a.(IAgent) if agent.UserId() != "" { //自发送登录用户 - agent.WriteMsg(msg) + agent.WriteBytes(data) } } } @@ -147,21 +155,27 @@ func (this *AgentMgrComp) SendMsgToAgents(ctx context.Context, args *pb.BatchMes } // SendMsgToAllAgent 向所有户发送消息 -func (this *AgentMgrComp) SendMsgToAllAgent(ctx context.Context, args *pb.BroadCastMessageReq, reply *pb.RPCMessageReply) error { +func (this *AgentMgrComp) SendMsgToAllAgent(ctx context.Context, args *pb.BroadCastMessageReq, reply *pb.RPCMessageReply) (err error) { + var ( + data []byte + ) msg := &pb.UserMessage{ MainType: args.MainType, SubType: args.SubType, Data: args.Data, } this.module.Debugf("SendMsgToAllAgent: msg:%v", msg) + if data, err = proto.Marshal(msg); err != nil { + return + } this.agents.Range(func(key, value any) bool { agent := value.(IAgent) if agent.UserId() != "" { //只发送登录用户 - agent.WriteMsg(msg) + agent.WriteBytes(data) } return true }) - return nil + return } // CloseAgent 关闭某个用户 diff --git a/modules/gateway/core.go b/modules/gateway/core.go index 33a5f3575..fc511e6e8 100644 --- a/modules/gateway/core.go +++ b/modules/gateway/core.go @@ -19,6 +19,7 @@ type ( Bind(uId string, wId string) UnBind() WriteMsg(msg *pb.UserMessage) (err error) + WriteBytes(data []byte) (err error) Close() //主动关闭接口 } // IGateway 网关模块 接口定义