package gateway import ( "context" "encoding/base64" "fmt" "go_dreamfactory/comm" "go_dreamfactory/pb" "go_dreamfactory/utils" "sync" "sync/atomic" "time" "go_dreamfactory/lego/sys/log" "go_dreamfactory/lego/utils/container/id" "github.com/gorilla/websocket" "github.com/tidwall/gjson" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" ) /* 用户代理对象 封装用户socket 对象 处理用户的消息读取 写入 关闭等操作 */ func newAgent(gateway IGateway, conn *websocket.Conn) *Agent { agent := &Agent{ gateway: gateway, wsConn: conn, sessionId: id.NewUUId(), uId: "", writeChan: make(chan *pb.UserMessage, 2), closeSignal: make(chan bool), state: 1, } agent.wg.Add(2) go agent.readLoop() go agent.writeLoop() return agent } //用户代理 type Agent struct { gateway IGateway wsConn *websocket.Conn sessionId string uId string wId string writeChan chan *pb.UserMessage closeSignal chan bool state int32 //状态 0 关闭 1 运行 2 关闭中 wg sync.WaitGroup } func (this *Agent) readLoop() { defer this.wg.Done() var ( data []byte msg *pb.UserMessage = &pb.UserMessage{} err error ) locp: for { if _, data, err = this.wsConn.ReadMessage(); err != nil { log.Errorf("agent:%s uId:%s ReadMessage err:%v", this.sessionId, this.uId, err) go this.Close() break locp } if err = proto.Unmarshal(data, msg); err != nil { log.Errorf("agent:%s uId:%s Unmarshal err:%v", this.sessionId, this.uId, err) go this.Close() break locp } else { err = this.secAuth(msg) if err == nil { if err := this.messageDistribution(msg); err != nil { go this.Close() break locp } } else { data, _ := anypb.New(&pb.ErrorNotify{ReqMainType: msg.MainType, ReqSubType: msg.SubType, Code: pb.ErrorCode_SecKeyInvalid}) if err = this.WriteMsg(&pb.UserMessage{ MainType: comm.MainTypeNotify, SubType: comm.SubTypeErrorNotify, Data: data, }); err != nil { go this.Close() break locp } } } } log.Debugf("agent:%s uId:%s readLoop end!", this.sessionId, this.uId) } func (this *Agent) writeLoop() { defer this.wg.Done() var ( data []byte err error ) locp: for { select { case <-this.closeSignal: break locp case msg, ok := <-this.writeChan: if ok { data, err = proto.Marshal(msg) if err = this.wsConn.WriteMessage(websocket.BinaryMessage, data); err != nil { log.Errorf("agent:%s uId:%d WriteMessage err:%v", this.sessionId, this.uId, err) go this.Close() } } else { go this.Close() } } } log.Debugf("agent:%s uId:%s writeLoop end!", this.sessionId, this.uId) } //安全认证 所有协议 func (this *Agent) secAuth(msg *pb.UserMessage) error { if !utils.ValidSecretKey(msg.Sec) { //验证失败 return fmt.Errorf("key invalid") } return decodeUserData(msg) } //解码 func decodeUserData(msg *pb.UserMessage) error { base64Str := msg.Sec dec, err := base64.StdEncoding.DecodeString(base64Str[35:]) if err != nil { log.Errorf("base64 decode err %v", err) return nil } now := time.Now().Unix() jsonRet := gjson.Parse(string(dec)) timestamp := jsonRet.Get("timestamp").Int() //秘钥30秒失效 if now-time.Unix(timestamp, 0).Unix() > 30 { return fmt.Errorf("sec key expire") } //只有login的时候才需要设置Data if msg.MainType == "user" && msg.SubType == "login" { serverId := jsonRet.Get("serverId").Int() account := jsonRet.Get("account").String() req := &pb.UserLoginReq{ Account: account, Sid: int32(serverId), } ad, err := anypb.New(req) if err != nil { return err } msg.Data = ad } return nil } func (this *Agent) SessionId() string { return this.sessionId } func (this *Agent) IP() string { return this.wsConn.RemoteAddr().String() } func (this *Agent) UserId() string { return this.uId } func (this *Agent) WorkerId() string { return this.wId } func (this *Agent) Bind(uId string, wId string) { this.uId = uId this.wId = wId } func (this *Agent) UnBind() { this.uId = "" } func (this *Agent) WriteMsg(msg *pb.UserMessage) (err error) { if atomic.LoadInt32(&this.state) != 1 { return } this.writeChan <- msg return } //外部代用关闭 func (this *Agent) Close() { if !atomic.CompareAndSwapInt32(&this.state, 1, 2) { return } this.wsConn.Close() this.closeSignal <- true this.wg.Wait() atomic.StoreInt32(&this.state, 0) this.gateway.DisConnect(this) } //分发用户消息 func (this *Agent) messageDistribution(msg *pb.UserMessage) (err error) { reply := &pb.RPCMessageReply{} log.Debugf("agent:%s uId:%s MessageDistribution msg:%s.%s", this.sessionId, this.uId, msg.MainType, msg.SubType) servicePath := comm.Service_Worker if rule, ok := this.gateway.GetMsgDistribute(msg.MainType, msg.SubType); ok { servicePath = rule } else { if len(this.wId) > 0 { servicePath = fmt.Sprintf("%s/%s", comm.Service_Worker, this.wId) } } if err = this.gateway.Service().RpcCall(context.Background(), servicePath, string(comm.Rpc_GatewayRoute), &pb.AgentMessage{ Ip: this.IP(), UserSessionId: this.sessionId, UserId: this.uId, GatewayServiceId: this.gateway.Service().GetId(), MainType: msg.MainType, SubType: msg.SubType, Message: msg.Data, }, reply); err != nil { log.Errorf("agent:%s uId:%s MessageDistribution err:%v", this.sessionId, this.uId, err) return } if reply.Code != pb.ErrorCode_Success { data, _ := anypb.New(&pb.ErrorNotify{ReqMainType: msg.MainType, ReqSubType: msg.SubType, Code: pb.ErrorCode(reply.Code.Number())}) err = this.WriteMsg(&pb.UserMessage{ MainType: comm.MainTypeNotify, SubType: comm.SubTypeErrorNotify, Data: data, }) return } else { for _, v := range reply.Reply { if err = this.WriteMsg(v); err != nil { return } } } return nil }