package gateway import ( "context" "encoding/base64" "fmt" "go_dreamfactory/comm" "go_dreamfactory/pb" "go_dreamfactory/sys/configure" "go_dreamfactory/sys/db" "go_dreamfactory/utils" "strings" "sync" "sync/atomic" "time" "go_dreamfactory/lego/sys/log" "go_dreamfactory/lego/utils/container/id" "github.com/gorilla/websocket" "github.com/tidwall/gjson" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" ) /* 用户代理对象 封装用户socket 对象 处理用户的消息读取 写入 关闭等操作 */ func newAgent(gateway IGateway, conn *websocket.Conn) *Agent { agent := &Agent{ gateway: gateway, wsConn: conn, sessionId: id.NewUUId(), uId: "", writeChan: make(chan [][]byte, 2), closeSignal: make(chan bool), state: 1, } agent.wg.Add(2) go agent.readLoop() go agent.writeLoop() return agent } // 用户代理 type Agent struct { gateway IGateway wsConn *websocket.Conn sessionId string uId string wId string cwid string group int32 writeChan chan [][]byte closeSignal chan bool state int32 //状态 0 关闭 1 运行 2 关闭中 wg sync.WaitGroup queueIndex int32 //排队编号 lastpushtime time.Time //上次推送时间 } func (this *Agent) readLoop() { defer this.wg.Done() var ( data []byte msg *pb.UserMessage = &pb.UserMessage{} errdata *pb.ErrorData err error ) locp: for { if _, data, err = this.wsConn.ReadMessage(); err != nil { this.gateway.Errorf("agent:%s uId:%s ReadMessage err:%v", this.sessionId, this.uId, err) go this.Close() break locp } if err = proto.Unmarshal(data, msg); err != nil { this.gateway.Errorf("agent:%s uId:%s Unmarshal err:%v", this.sessionId, this.uId, err) go this.Close() break locp } else { this.wsConn.SetReadDeadline(time.Now().Add(time.Second * 60)) // this.gateway.Debugf("----------1 agent:%s uId:%s MainType:%s SubType:%s ", this.sessionId, this.uId, msg.MainType, msg.SubType) if msg.MainType == string(comm.ModuleGate) { //心跳消息 无需校验秘钥 data, _ := anypb.New(&pb.GatewayHeartbeatResp{ Timestamp: configure.Now().Unix(), }) this.WriteMsg(&pb.UserMessage{ MsgId: msg.MsgId, MainType: string(comm.ModuleGate), SubType: "heartbeat", Data: data, }) continue } errdata = this.secAuth(msg) if errdata == nil { // this.gateway.Debugf("----------2 agent:%s uId:%s MainType:%s SubType:%s ", this.sessionId, this.uId, msg.MainType, msg.SubType) if this.gateway.IsOpenLoginQueue() && msg.MainType == string(comm.ModuleUser) && msg.SubType == "login" { //登录排队 if this.uId == "" { if this.queueIndex, err = this.gateway.InLoginQueue(this.sessionId, msg); err != nil { this.gateway.Errorf("messageDistribution err:%v", err) data, _ := anypb.New(&pb.NotifyErrorNotifyPush{ MsgId: msg.MsgId, ReqMainType: msg.MainType, ReqSubType: msg.SubType, Arg: msg.Data, Code: pb.ErrorCode_GatewayException, Err: &pb.ErrorData{Code: pb.ErrorCode_GatewayException, Title: pb.ErrorCode_GatewayException.String(), Datastring: err.Error()}, }) err = this.WriteMsg(&pb.UserMessage{ MsgId: msg.MsgId, MainType: comm.MainTypeNotify, SubType: comm.SubTypeErrorNotify, Data: data, }) go this.Close() break locp } else { if this.queueIndex > 0 { this.lastpushtime = time.Now() data, _ := anypb.New(&pb.UserLoginQueueChangePush{ Index: this.queueIndex, }) err = this.WriteMsg(&pb.UserMessage{ MainType: string(comm.ModuleUser), SubType: "loginqueuechange", Data: data, }) } } } else { data, _ := anypb.New(&pb.NotifyErrorNotifyPush{ MsgId: msg.MsgId, ReqMainType: msg.MainType, ReqSubType: msg.SubType, Arg: msg.Data, Code: pb.ErrorCode_ReqParameterError, Err: &pb.ErrorData{Code: pb.ErrorCode_ReqParameterError, Title: "Repeat login!"}, }) err = this.WriteMsg(&pb.UserMessage{ MsgId: msg.MsgId, MainType: comm.MainTypeNotify, SubType: comm.SubTypeErrorNotify, Data: data, }) } continue } if err = this.messageDistribution(msg); err != nil { this.gateway.Errorf("messageDistribution err:%v", err) data, _ := anypb.New(&pb.NotifyErrorNotifyPush{ MsgId: msg.MsgId, ReqMainType: msg.MainType, ReqSubType: msg.SubType, Arg: msg.Data, Code: pb.ErrorCode_GatewayException, Err: &pb.ErrorData{Code: pb.ErrorCode_GatewayException, Title: pb.ErrorCode_GatewayException.String(), Message: err.Error()}, }) err = this.WriteMsg(&pb.UserMessage{ MsgId: msg.MsgId, MainType: comm.MainTypeNotify, SubType: comm.SubTypeErrorNotify, Data: data, }) go this.Close() break locp } } else { this.gateway.Errorf("agent:%s uId:%s 密钥无效 err:%v", this.sessionId, this.uId, err) data, _ := anypb.New(&pb.NotifyErrorNotifyPush{ MsgId: msg.MsgId, ReqMainType: msg.MainType, ReqSubType: msg.SubType, Code: errdata.Code, Err: errdata, }) if err = this.WriteMsg(&pb.UserMessage{ MsgId: msg.MsgId, MainType: comm.MainTypeNotify, SubType: comm.SubTypeErrorNotify, Data: data, }); err != nil { go this.Close() break locp } } } } this.gateway.Debugf("agent:%s uId:%s readLoop end!", this.sessionId, this.uId) } func (this *Agent) writeLoop() { defer this.wg.Done() var ( // data []byte err error ) locp: for { select { case <-this.closeSignal: break locp case msgs, ok := <-this.writeChan: if ok { for _, msg := range msgs { //data, err = proto.Marshal(msg) if err = this.wsConn.WriteMessage(websocket.BinaryMessage, msg); err != nil { this.gateway.Errorf("agent:%s uId:%s WriteMessage err:%v", this.sessionId, this.uId, err) go this.Close() } } } else { go this.Close() } } } this.gateway.Debugf("agent:%s uId:%s writeLoop end!", this.sessionId, this.uId) } // 安全认证 所有协议 func (this *Agent) secAuth(msg *pb.UserMessage) (errdata *pb.ErrorData) { if !utils.ValidSecretKey(msg.Sec) { //验证失败 this.gateway.Errorf("%v", msg.Sec) errdata = &pb.ErrorData{ Code: pb.ErrorCode_SignError, Title: pb.ErrorCode_SignError.ToString(), Message: "key invalid", } return } return this.decodeUserData(msg) } // 解码 func (this *Agent) decodeUserData(msg *pb.UserMessage) (errdata *pb.ErrorData) { base64Str := msg.Sec dec, err := base64.StdEncoding.DecodeString(base64Str[35:]) if err != nil { this.gateway.Errorf("base64 decode err %v", err) errdata = &pb.ErrorData{ Code: pb.ErrorCode_DecodeError, Title: pb.ErrorCode_DecodeError.ToString(), } return } now := configure.Now().Unix() jsonRet := gjson.Parse(string(dec)) timestamp := jsonRet.Get("timestamp").Int() //秘钥30秒失效 if now-time.Unix(timestamp, 0).Unix() > 30 { this.gateway.Errorf("now:%v last timestamp:%v more than 30s", now, timestamp) errdata = &pb.ErrorData{ Code: pb.ErrorCode_TimestampTimeout, Title: pb.ErrorCode_TimestampTimeout.ToString(), Message: "sec key expire", } return } //只有login的时候才需要设置Data if msg.MainType == string(comm.ModuleUser) && msg.SubType == "login" { serverId := jsonRet.Get("serverId").String() account := jsonRet.Get("account").String() req := &pb.UserLoginReq{ Account: account, Sid: serverId, } ad, err := anypb.New(req) if err != nil { this.gateway.Errorf("decodeUserData pb err:%v", err) errdata = &pb.ErrorData{ Code: pb.ErrorCode_PbError, Title: pb.ErrorCode_PbError.ToString(), Message: err.Error(), } return } msg.Data = ad } else { switch msg.MainType { case string(comm.ModuleNotify), string(comm.ModuleGate): return default: if this.UserId() == "" { this.gateway.Errorf("[%v.%v] Agent UId empty", msg.MainType, msg.SubType) errdata = &pb.ErrorData{ Code: pb.ErrorCode_AgentUidEmpty, Title: pb.ErrorCode_AgentUidEmpty.ToString(), Message: "no login", } return } } } return } func (this *Agent) SessionId() string { return this.sessionId } func (this *Agent) IP() string { return this.wsConn.RemoteAddr().String() } func (this *Agent) UserId() string { return this.uId } func (this *Agent) Group() int32 { return this.group } func (this *Agent) WorkerId() string { return this.wId } func (this *Agent) Bind(uId string, wId string) { this.uId = uId this.wId = wId } func (this *Agent) UnBind() { this.uId = "" } func (this *Agent) SetCrosssId(wId string) { this.cwid = wId } func (this *Agent) CrosssWorkerId() string { return this.cwid } func (this *Agent) WriteMsg(msg *pb.UserMessage) (err error) { if atomic.LoadInt32(&this.state) != 1 { return } var ( data []byte ) if data, err = proto.Marshal(msg); err == nil { this.writeChan <- [][]byte{data} } return } func (this *Agent) WriteMsgs(msgs []*pb.UserMessage) (err error) { if atomic.LoadInt32(&this.state) != 1 { return } var ( datas [][]byte = make([][]byte, 0) data []byte ) for _, msg := range msgs { if data, err = proto.Marshal(msg); err == nil { datas = append(datas, data) } } this.writeChan <- datas return } func (this *Agent) WriteBytes(data []byte) (err error) { if atomic.LoadInt32(&this.state) != 1 { err = fmt.Errorf("Uid%s Staet:%d", this.uId, this.state) return } this.writeChan <- [][]byte{data} return } // 外部代用关闭 func (this *Agent) Close() { if !atomic.CompareAndSwapInt32(&this.state, 1, 2) { return } this.wsConn.Close() this.closeSignal <- true this.wg.Wait() atomic.StoreInt32(&this.state, 0) this.gateway.DisConnect(this) } //处理用户消息 提供给外部使用 比如 登录等待逻辑 func (this *Agent) HandleMessage(msg *pb.UserMessage) (err error) { if err = this.messageDistribution(msg); err != nil { this.gateway.Errorf("messageDistribution err:%v", err) data, _ := anypb.New(&pb.NotifyErrorNotifyPush{ MsgId: msg.MsgId, ReqMainType: msg.MainType, ReqSubType: msg.SubType, Arg: msg.Data, Code: pb.ErrorCode_GatewayException, Err: &pb.ErrorData{Code: pb.ErrorCode_GatewayException, Title: pb.ErrorCode_GatewayException.String(), Message: err.Error()}, }) err = this.WriteMsg(&pb.UserMessage{ MsgId: msg.MsgId, MainType: comm.MainTypeNotify, SubType: comm.SubTypeErrorNotify, Data: data, }) } return } // 分发用户消息 func (this *Agent) messageDistribution(msg *pb.UserMessage) (err error) { var ( req *pb.AgentMessage = getmsg() reply *pb.RPCMessageReply = getmsgreply() serviceTag string = "" servicePath string = comm.Service_Worker rule string ok bool ) defer func() { putmsg(req) putmsgreply(reply) }() req.Ip = this.IP() req.UserSessionId = this.sessionId req.UserId = this.uId req.Group = this.group req.ServiceTag = this.gateway.Service().GetTag() req.GatewayServiceId = this.gateway.Service().GetId() req.MainType = msg.MainType req.SubType = msg.SubType req.Message = msg.Data if rule, ok = this.gateway.GetMsgDistribute(req.MainType, req.SubType); ok { paths := strings.Split(rule, "/") if len(paths) == 3 { if paths[0] == "~" { serviceTag = db.CrossTag() } else { serviceTag = paths[0] } servicePath = fmt.Sprintf("%s/%s", paths[1], paths[2]) } else if len(paths) == 2 { if paths[0] == "~" { serviceTag = db.CrossTag() if len(this.cwid) > 0 { //已经绑定 跨服服务器 servicePath = fmt.Sprintf("%s/%s", paths[1], this.cwid) } else { servicePath = paths[1] } } else { servicePath = rule } } else { this.gateway.Errorf("messageDistribution mtype:%s stype:%s rule:%s is empty!", req.MainType, req.SubType, rule) return } } else { if len(this.wId) > 0 { //已经绑定worker 服务器 servicePath = fmt.Sprintf("%s/%s", comm.Service_Worker, this.wId) } } stime := time.Now() // this.gateway.Debugf("----------3 agent:%s uId:%s MainType:%s SubType:%s ", this.sessionId, this.uId, msg.MainType, msg.SubType) if len(serviceTag) == 0 { // this.gateway.Debugf("----------4 agent:%s uId:%s MainType:%s SubType:%s ", this.sessionId, this.uId, msg.MainType, msg.SubType) if err = this.gateway.Service().RpcCall(context.Background(), servicePath, string(comm.Rpc_GatewayRoute), req, reply); err != nil { this.gateway.Error("[UserResponse]", log.Field{Key: "uid", Value: this.uId}, log.Field{Key: "serviceTag", Value: serviceTag}, log.Field{Key: "servicePath", Value: servicePath}, log.Field{Key: "req", Value: fmt.Sprintf("%s:%s %v", req.MainType, req.SubType, req.Message.String())}, log.Field{Key: "err", Value: err.Error()}, ) return } } else { //跨集群调用 // this.gateway.Debugf("----------5 agent:%s uId:%s servicePath:%s MainType:%s SubType:%s ", this.sessionId, this.uId, msg.ServicePath, msg.MainType, msg.SubType) if msg.ServicePath != "" { //客户端是否制定目标服务器 /wroker/woker0 servicePath = msg.ServicePath } if err = this.gateway.Service().AcrossClusterRpcCall(context.Background(), serviceTag, servicePath, string(comm.Rpc_GatewayRoute), req, reply); err != nil { this.gateway.Error("[UserResponse]", log.Field{Key: "uid", Value: this.uId}, log.Field{Key: "serviceTag", Value: serviceTag}, log.Field{Key: "servicePath", Value: servicePath}, log.Field{Key: "req", Value: fmt.Sprintf("%s:%s %v", req.MainType, req.SubType, req.Message.String())}, log.Field{Key: "err", Value: err.Error()}, ) return } } this.gateway.Debug("[UserResponse]", log.Field{Key: "t", Value: time.Since(stime).Milliseconds()}, log.Field{Key: "uid", Value: this.uId}, log.Field{Key: "req", Value: fmt.Sprintf("%s:%s %v", req.MainType, req.SubType, req.Message.String())}, log.Field{Key: "reply", Value: reply.String()}, ) // key := msg.MainType + msg.SubType // if v, ok := this.protoMsg[key]; ok && v != 0 { // 发送消息 协议解锁 // v = 0 // } if reply.ErrorData != nil { data, _ := anypb.New(&pb.NotifyErrorNotifyPush{ MsgId: msg.MsgId, ReqMainType: msg.MainType, ReqSubType: msg.SubType, Arg: msg.Data, Code: reply.ErrorData.Code, Err: reply.ErrorData}) err = this.WriteMsg(&pb.UserMessage{ MsgId: msg.MsgId, MainType: comm.MainTypeNotify, SubType: comm.SubTypeErrorNotify, Data: data, }) return } else { for _, v := range reply.Reply { if v.MainType == msg.MainType && v.SubType == msg.SubType { v.MsgId = msg.MsgId if msg.MainType == string(comm.ModuleUser) && msg.SubType == "login" { //登录回包 var ( resp proto.Message loginresp *pb.UserLoginResp ) if resp, err = v.Data.UnmarshalNew(); err != nil { return } loginresp = resp.(*pb.UserLoginResp) this.uId = loginresp.Data.Uid this.wId = reply.ServiceId this.group = loginresp.Data.Group this.gateway.LoginNotice(this) } } } if err = this.WriteMsgs(reply.Reply); err != nil { return } } return nil } //推送排队变化消息 func (this *Agent) PushQueueChange() { this.queueIndex-- if time.Now().Sub(this.lastpushtime).Seconds() < 1 { //间隔少于1秒 不发送 避免io爆炸 return } data, _ := anypb.New(&pb.UserLoginQueueChangePush{ Index: this.queueIndex, }) if err := this.WriteMsg(&pb.UserMessage{ MainType: string(comm.ModuleUser), SubType: "loginqueuechange", Data: data, }); err != nil { this.gateway.Errorf("pushQueueChange err:%v", err) } this.lastpushtime = time.Now() return }