package gateway import ( "context" "encoding/base64" "fmt" "go_dreamfactory/comm" "go_dreamfactory/pb" "go_dreamfactory/sys/configure" "go_dreamfactory/utils" "strings" "sync" "sync/atomic" "time" "go_dreamfactory/lego/sys/log" "go_dreamfactory/lego/utils/container/id" "github.com/gorilla/websocket" "github.com/tidwall/gjson" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" ) /* 用户代理对象 封装用户socket 对象 处理用户的消息读取 写入 关闭等操作 */ func newAgent(gateway IGateway, conn *websocket.Conn) *Agent { agent := &Agent{ gateway: gateway, wsConn: conn, sessionId: id.NewUUId(), uId: "", writeChan: make(chan []byte, 2), closeSignal: make(chan bool), state: 1, } agent.wg.Add(2) go agent.readLoop() go agent.writeLoop() return agent } //用户代理 type Agent struct { gateway IGateway wsConn *websocket.Conn sessionId string uId string wId string writeChan chan []byte closeSignal chan bool state int32 //状态 0 关闭 1 运行 2 关闭中 wg sync.WaitGroup } func (this *Agent) readLoop() { defer this.wg.Done() var ( data []byte msg *pb.UserMessage = &pb.UserMessage{} err error ) locp: for { if _, data, err = this.wsConn.ReadMessage(); err != nil { this.gateway.Errorf("agent:%s uId:%s ReadMessage err:%v", this.sessionId, this.uId, err) go this.Close() break locp } if err = proto.Unmarshal(data, msg); err != nil { this.gateway.Errorf("agent:%s uId:%s Unmarshal err:%v", this.sessionId, this.uId, err) go this.Close() break locp } else { var code pb.ErrorCode code, err = this.secAuth(msg) if err == nil { if msg.MainType == string(comm.ModuleGate) { //心跳消息 data, _ := anypb.New(&pb.GatewayHeartbeatResp{ Timestamp: configure.Now().Unix(), }) this.WriteMsg(&pb.UserMessage{ MainType: string(comm.ModuleGate), SubType: "heartbeat", Data: data, }) this.wsConn.SetReadDeadline(time.Now().Add(time.Second * 30)) continue } if err := this.messageDistribution(msg); err != nil { this.gateway.Errorf("messageDistribution err:%v", err) go this.Close() break locp } } else { this.gateway.Errorf("agent:%s uId:%s 密钥无效 err:%v", this.sessionId, this.uId, err) data, _ := anypb.New(&pb.NotifyErrorNotifyPush{ReqMainType: msg.MainType, ReqSubType: msg.SubType, Code: code, Message: err.Error()}) if err = this.WriteMsg(&pb.UserMessage{ MainType: comm.MainTypeNotify, SubType: comm.SubTypeErrorNotify, Data: data, }); err != nil { go this.Close() break locp } } } } this.gateway.Debugf("agent:%s uId:%s readLoop end!", this.sessionId, this.uId) } func (this *Agent) writeLoop() { defer this.wg.Done() var ( // data []byte err error ) locp: for { select { case <-this.closeSignal: break locp case msg, ok := <-this.writeChan: if ok { // data, err = proto.Marshal(msg) if err = this.wsConn.WriteMessage(websocket.BinaryMessage, msg); err != nil { this.gateway.Errorf("agent:%s uId:%d WriteMessage err:%v", this.sessionId, this.uId, err) go this.Close() } } else { go this.Close() } } } this.gateway.Debugf("agent:%s uId:%s writeLoop end!", this.sessionId, this.uId) } //安全认证 所有协议 func (this *Agent) secAuth(msg *pb.UserMessage) (code pb.ErrorCode, err error) { if !utils.ValidSecretKey(msg.Sec) { //验证失败 this.gateway.Errorf("%v", msg.Sec) return pb.ErrorCode_SignError, fmt.Errorf("key invalid") } return this.decodeUserData(msg) } //解码 func (this *Agent) decodeUserData(msg *pb.UserMessage) (code pb.ErrorCode, err error) { base64Str := msg.Sec dec, err := base64.StdEncoding.DecodeString(base64Str[35:]) if err != nil { this.gateway.Errorf("base64 decode err %v", err) return pb.ErrorCode_DecodeError, nil } now := configure.Now().Unix() jsonRet := gjson.Parse(string(dec)) timestamp := jsonRet.Get("timestamp").Int() //秘钥30秒失效 if now-time.Unix(timestamp, 0).Unix() > 30 { this.gateway.Errorf("last timestamp:%v more than 30s", timestamp) return pb.ErrorCode_TimestampTimeout, fmt.Errorf("sec key expire") } //只有login的时候才需要设置Data if msg.MainType == string(comm.ModuleUser) && msg.SubType == "login" { serverId := jsonRet.Get("serverId").String() account := jsonRet.Get("account").String() req := &pb.UserLoginReq{ Account: account, Sid: serverId, } ad, err := anypb.New(req) if err != nil { this.gateway.Errorf("decodeUserData pb err:%v", err) return pb.ErrorCode_PbError, err } msg.Data = ad } else { switch msg.MainType { case string(comm.ModuleNotify), string(comm.ModuleGate): return pb.ErrorCode_Success, nil default: if this.UserId() == "" { this.gateway.Errorf("[%v.%v] Agent UId empty", msg.MainType, msg.SubType) return pb.ErrorCode_AgentUidEmpty, fmt.Errorf("no login") } } } return pb.ErrorCode_Success, nil } func (this *Agent) SessionId() string { return this.sessionId } func (this *Agent) IP() string { return this.wsConn.RemoteAddr().String() } func (this *Agent) UserId() string { return this.uId } func (this *Agent) WorkerId() string { return this.wId } func (this *Agent) Bind(uId string, wId string) { this.uId = uId this.wId = wId } func (this *Agent) UnBind() { this.uId = "" } func (this *Agent) WriteMsg(msg *pb.UserMessage) (err error) { if atomic.LoadInt32(&this.state) != 1 { return } var ( data []byte ) if data, err = proto.Marshal(msg); err == nil { this.writeChan <- data } return } func (this *Agent) WriteBytes(data []byte) (err error) { if atomic.LoadInt32(&this.state) != 1 { return } this.writeChan <- data return } //外部代用关闭 func (this *Agent) Close() { if !atomic.CompareAndSwapInt32(&this.state, 1, 2) { return } this.wsConn.Close() this.closeSignal <- true this.wg.Wait() atomic.StoreInt32(&this.state, 0) this.gateway.DisConnect(this) } //分发用户消息 func (this *Agent) messageDistribution(msg *pb.UserMessage) (err error) { var ( req *pb.AgentMessage = getmsg() reply *pb.RPCMessageReply = getmsgreply() serviceTag string = "" servicePath string = comm.Service_Worker rule string ok bool ) defer func() { putmsg(req) putmsgreply(reply) }() req.Ip = this.IP() req.UserSessionId = this.sessionId req.UserId = this.uId req.ServiceTag = this.gateway.Service().GetTag() req.GatewayServiceId = this.gateway.Service().GetId() req.MainType = msg.MainType req.SubType = msg.SubType req.Message = msg.Data if rule, ok = this.gateway.GetMsgDistribute(req.MainType, req.SubType); ok { paths := strings.Split(rule, "/") if len(paths) == 3 { if paths[0] == "~" { serviceTag = this.gateway.CrossServiceTag() } else { serviceTag = paths[0] } servicePath = fmt.Sprintf("%s/%s", paths[1], paths[2]) } else if len(paths) == 2 { if paths[0] == "~" { serviceTag = this.gateway.CrossServiceTag() servicePath = paths[1] } else { servicePath = rule } } else { this.gateway.Errorf("messageDistribution mtype:%s stype:%s rule:%s is empty!", req.MainType, req.SubType, rule) return } } else { if len(this.wId) > 0 { //已经绑定worker 服务器 servicePath = fmt.Sprintf("%s/%s", comm.Service_Worker, this.wId) } } stime := time.Now() if len(serviceTag) == 0 { if err = this.gateway.Service().RpcCall(context.Background(), servicePath, string(comm.Rpc_GatewayRoute), req, reply); err != nil { this.gateway.Error("[UserResponse]", log.Field{Key: "uid", Value: this.uId}, log.Field{Key: "serviceTag", Value: serviceTag}, log.Field{Key: "servicePath", Value: servicePath}, log.Field{Key: "servicePath", Value: servicePath}, log.Field{Key: "req", Value: fmt.Sprintf("%s:%s %v", req.MainType, req.SubType, req.Message.String())}, log.Field{Key: "err", Value: err.Error()}, ) return } } else { //跨集群调用 if err = this.gateway.Service().AcrossClusterRpcCall(context.Background(), serviceTag, servicePath, string(comm.Rpc_GatewayRoute), req, reply); err != nil { this.gateway.Error("[UserResponse]", log.Field{Key: "uid", Value: this.uId}, log.Field{Key: "serviceTag", Value: serviceTag}, log.Field{Key: "servicePath", Value: servicePath}, log.Field{Key: "req", Value: fmt.Sprintf("%s:%s %v", req.MainType, req.SubType, req.Message.String())}, log.Field{Key: "err", Value: err.Error()}, ) return } } this.gateway.Debug("[UserResponse]", log.Field{Key: "t", Value: time.Since(stime).Milliseconds()}, log.Field{Key: "uid", Value: this.uId}, log.Field{Key: "req", Value: fmt.Sprintf("%s:%s %v", req.MainType, req.SubType, req.Message.String())}, log.Field{Key: "reply", Value: reply.String()}, ) if reply.Code != pb.ErrorCode_Success { data, _ := anypb.New(&pb.NotifyErrorNotifyPush{ ReqMainType: msg.MainType, ReqSubType: msg.SubType, Arg: msg.Data, Code: pb.ErrorCode(reply.Code.Number())}) err = this.WriteMsg(&pb.UserMessage{ MainType: comm.MainTypeNotify, SubType: comm.SubTypeErrorNotify, Data: data, }) return } else { for _, v := range reply.Reply { if err = this.WriteMsg(v); err != nil { return } } } return nil }