// go_dreamfactory/modules/gateway/agent.go

package gateway
import (
"context"
"encoding/base64"
"fmt"
"go_dreamfactory/comm"
"go_dreamfactory/pb"
"go_dreamfactory/utils"
"sync"
"sync/atomic"
"time"
"go_dreamfactory/lego/sys/log"
"go_dreamfactory/lego/utils/container/id"
"github.com/gorilla/websocket"
"github.com/tidwall/gjson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
)
// newAgent creates the user agent for a freshly accepted websocket connection.
// The agent wraps the user's socket and handles reading, writing and closing
// for that user's messages.
//
// Every agent receives a new session id and starts in the running state (1)
// with no bound user. Outbound messages are queued on writeChan (capacity 2).
// Both I/O goroutines, readLoop and writeLoop, are registered on wg before
// they are started so that Close can wait for them.
func newAgent(gateway IGateway, conn *websocket.Conn) *Agent {
	a := new(Agent)
	a.gateway, a.wsConn = gateway, conn
	a.sessionId = id.NewUUId()
	a.writeChan = make(chan *pb.UserMessage, 2)
	a.closeSignal = make(chan bool)
	a.state = 1 // running

	const ioLoops = 2 // readLoop + writeLoop
	a.wg.Add(ioLoops)
	go a.readLoop()
	go a.writeLoop()
	return a
}
// Agent is the gateway-side proxy for one connected user (user agent).
//
// NOTE(review): uId and wId are plain strings that readLoop and
// messageDistribution read while Bind/UnBind write them; if those are called
// from another goroutine this is a data race — confirm the callers.
type Agent struct {
	gateway     IGateway             // owning gateway: routing rules, RPC service and DisConnect
	wsConn      *websocket.Conn      // the user's websocket connection
	sessionId   string               // per-connection id generated in newAgent
	uId         string               // bound user id; "" until Bind, cleared by UnBind
	wId         string               // bound worker id; when non-empty and no routing rule matches, messages go to comm.Service_Worker/<wId>
	writeChan   chan *pb.UserMessage // outbound queue drained by writeLoop
	closeSignal chan bool            // used by Close to tell writeLoop to stop
	state       int32                // state: 0 closed, 1 running, 2 closing
	wg          sync.WaitGroup       // tracks readLoop and writeLoop so Close can wait for them
}
// readLoop pumps inbound frames from the websocket until the connection fails.
//
// Each frame is unmarshalled into a pb.UserMessage and authenticated with
// secAuth. Authenticated messages are routed by messageDistribution. Rejected
// ones are logged and answered with an ErrorNotify (code SecKeyInvalid), and
// the loop keeps reading.
//
// The loop ends on any of these failures:
//   - a read error,
//   - an unmarshal error,
//   - a routing error,
//   - failure to queue the rejection notice.
//
// In each case Close is started in a new goroutine. Close waits on wg, which
// this loop only releases through its deferred Done, so calling it
// synchronously here would deadlock.
func (this *Agent) readLoop() {
	defer this.wg.Done()
	for {
		_, data, err := this.wsConn.ReadMessage()
		if err != nil {
			log.Errorf("agent:%s uId:%s ReadMessage err:%v", this.sessionId, this.uId, err)
			go this.Close()
			break
		}
		// Use a fresh message per frame so nothing handed downstream can alias
		// the contents of the next frame.
		msg := &pb.UserMessage{}
		if err = proto.Unmarshal(data, msg); err != nil {
			log.Errorf("agent:%s uId:%s Unmarshal err:%v", this.sessionId, this.uId, err)
			go this.Close()
			break
		}
		if err = this.secAuth(msg); err != nil {
			// Authentication failures do not drop the connection. Record why,
			// then tell the client its secret key was rejected.
			log.Errorf("agent:%s uId:%s secAuth %s.%s err:%v", this.sessionId, this.uId, msg.MainType, msg.SubType, err)
			notify, perr := anypb.New(&pb.ErrorNotify{ReqMainType: msg.MainType, ReqSubType: msg.SubType, Code: pb.ErrorCode_SecKeyInvalid})
			if perr != nil {
				log.Errorf("agent:%s uId:%s pack ErrorNotify err:%v", this.sessionId, this.uId, perr)
			}
			if err = this.WriteMsg(&pb.UserMessage{
				MainType: comm.MainTypeNotify,
				SubType:  comm.SubTypeErrorNotify,
				Data:     notify,
			}); err != nil {
				go this.Close()
				break
			}
			continue
		}
		if err = this.messageDistribution(msg); err != nil {
			go this.Close()
			break
		}
	}
	log.Debugf("agent:%s uId:%s readLoop end!", this.sessionId, this.uId)
}
// writeLoop drains writeChan, marshals each message and writes it to the
// websocket as a binary frame. It runs until closeSignal fires.
//
// Error handling:
//   - A message that fails to marshal is logged and skipped, rather than
//     being sent as an empty frame.
//   - A failed write starts Close in a new goroutine. Close waits on wg, so it
//     cannot be called synchronously here. The loop then keeps running until
//     Close signals it.
func (this *Agent) writeLoop() {
	defer this.wg.Done()
	out := this.writeChan
locp:
	for {
		select {
		case <-this.closeSignal:
			break locp
		case msg, ok := <-out:
			if !ok {
				// The queue was closed. A closed channel is always ready and
				// would spin this loop, so stop selecting on it and just wait
				// for closeSignal.
				out = nil
				go this.Close()
				continue
			}
			data, err := proto.Marshal(msg)
			if err != nil {
				log.Errorf("agent:%s uId:%s Marshal err:%v", this.sessionId, this.uId, err)
				continue
			}
			if err = this.wsConn.WriteMessage(websocket.BinaryMessage, data); err != nil {
				log.Errorf("agent:%s uId:%s WriteMessage err:%v", this.sessionId, this.uId, err)
				go this.Close()
			}
		}
	}
	log.Debugf("agent:%s uId:%s writeLoop end!", this.sessionId, this.uId)
}
// secAuth authenticates an inbound message; it is applied to every protocol.
//
// msg.Sec must first pass utils.ValidSecretKey. Its embedded payload is then
// checked by decodeUserData, which may also fill msg.Data for login requests.
// A nil result means the message may be routed.
func (this *Agent) secAuth(msg *pb.UserMessage) error {
	if ok := utils.ValidSecretKey(msg.Sec); ok {
		return decodeUserData(msg)
	}
	return fmt.Errorf("key invalid")
}
// decodeUserData decodes the payload carried inside msg.Sec and enforces the
// key's lifetime.
//
// The payload is read as standard base64 starting at byte offset 35 of msg.Sec
// and parsed as JSON. Its "timestamp" field is compared with
// time.Now().Unix(), so it is treated as Unix seconds. The key must be at most
// 30 seconds old.
//
// For a user.login request, the payload's "account" and "serverId" fields are
// packed into a pb.UserLoginReq that replaces msg.Data. Other messages are left
// untouched.
//
// It returns an error when:
//   - msg.Sec is shorter than the payload offset,
//   - the payload is not valid base64 (such a key must not be accepted,
//     otherwise the expiry check is skipped),
//   - the key has expired,
//   - the login request cannot be packed.
func decodeUserData(msg *pb.UserMessage) error {
	const (
		// NOTE(review): assumed to be the length of the prefix validated by
		// utils.ValidSecretKey — confirm against that helper.
		payloadOffset = 35
		// Lifetime of a secret key, in seconds.
		maxKeyAgeSec = 30
	)

	sec := msg.Sec
	if len(sec) < payloadOffset {
		return fmt.Errorf("sec key too short: %d bytes", len(sec))
	}
	dec, err := base64.StdEncoding.DecodeString(sec[payloadOffset:])
	if err != nil {
		return fmt.Errorf("sec key payload decode: %w", err)
	}

	payload := gjson.Parse(string(dec))
	timestamp := payload.Get("timestamp").Int()
	if time.Now().Unix()-timestamp > maxKeyAgeSec {
		return fmt.Errorf("sec key expire")
	}

	// Only the login request has its Data populated from the key payload.
	if msg.MainType == "user" && msg.SubType == "login" {
		req := &pb.UserLoginReq{
			Account: payload.Get("account").String(),
			Sid:     int32(payload.Get("serverId").Int()),
		}
		ad, err := anypb.New(req)
		if err != nil {
			return err
		}
		msg.Data = ad
	}
	return nil
}
// SessionId returns the per-connection session id assigned in newAgent.
func (this *Agent) SessionId() string {
	return this.sessionId
}

// IP returns the String form of the websocket connection's remote address.
func (this *Agent) IP() string {
	remote := this.wsConn.RemoteAddr()
	return remote.String()
}

// UserId returns the user id set by Bind, or "" when no user is bound.
func (this *Agent) UserId() string {
	return this.uId
}

// WorkerId returns the worker id set by Bind.
func (this *Agent) WorkerId() string {
	return this.wId
}

// Bind associates the agent with a user id and a worker id.
//
// Subsequent routed messages carry uId as AgentMessage.UserId. When no
// distribution rule matches, a non-empty wId pins routing to that worker.
//
// NOTE(review): these fields are written without synchronization while
// readLoop reads them; if Bind/UnBind run on another goroutine this is a data
// race — confirm the callers.
func (this *Agent) Bind(uId string, wId string) {
	this.uId, this.wId = uId, wId
}

// UnBind clears the bound user id. The worker id is intentionally left as-is.
func (this *Agent) UnBind() {
	this.uId = ""
}
// WriteMsg queues msg for delivery by writeLoop.
//
// If the agent is no longer running it returns nil immediately and drops the
// message. Otherwise it blocks until writeLoop accepts the message or the agent
// begins shutting down; in the latter case the message is also dropped.
//
// Selecting on closeSignal matters. Without it, a caller such as readLoop could
// block forever on a full writeChan after writeLoop has exited. That would
// leave Close stuck in wg.Wait, so the disconnect would never be reported.
func (this *Agent) WriteMsg(msg *pb.UserMessage) (err error) {
	if atomic.LoadInt32(&this.state) != 1 {
		return
	}
	select {
	case this.writeChan <- msg:
	case <-this.closeSignal:
		// Shutting down: nobody will drain writeChan any more.
	}
	return
}

// Close shuts the agent down. It may be called from outside the agent and
// more than once: only the call that moves state from running (1) to closing
// (2) does any work.
//
// Shutdown order:
//  1. Close the websocket, so readLoop's pending read fails and the loop ends.
//  2. Close closeSignal. This stops writeLoop and releases any WriteMsg that
//     is blocked on a full queue.
//  3. Wait for both loops to exit.
//  4. Mark the agent closed (0).
//  5. Report the disconnect to the gateway.
//
// Close must not be called synchronously from readLoop or writeLoop, because
// it waits for them through wg.
func (this *Agent) Close() {
	if !atomic.CompareAndSwapInt32(&this.state, 1, 2) {
		return
	}
	// Best effort: the error from closing the socket is irrelevant during
	// teardown.
	this.wsConn.Close()
	// close (rather than a single send) is a broadcast to every waiter. It is
	// safe because the CAS above guarantees this runs exactly once.
	close(this.closeSignal)
	this.wg.Wait()
	atomic.StoreInt32(&this.state, 0)
	this.gateway.DisConnect(this)
}
// messageDistribution forwards a user message to the backend over RPC and
// relays the reply back to the user.
//
// Target selection:
//   - the routing rule the gateway returns for (MainType, SubType), if any;
//   - otherwise comm.Service_Worker, pinned to comm.Service_Worker/<wId>
//     once Bind has set a worker id.
//
// If the reply code is not Success, an ErrorNotify carrying that code is sent
// to the user. Otherwise every message in reply.Reply is queued in order.
//
// A non-nil error means the RPC call or a WriteMsg failed; readLoop treats
// that as fatal for the connection.
func (this *Agent) messageDistribution(msg *pb.UserMessage) (err error) {
	log.Debugf("agent:%s uId:%s MessageDistribution msg:%s.%s", this.sessionId, this.uId, msg.MainType, msg.SubType)

	servicePath := comm.Service_Worker
	if rule, ok := this.gateway.GetMsgDistribute(msg.MainType, msg.SubType); ok {
		servicePath = rule
	} else if len(this.wId) > 0 {
		servicePath = fmt.Sprintf("%s/%s", comm.Service_Worker, this.wId)
	}

	reply := &pb.RPCMessageReply{}
	if err = this.gateway.Service().RpcCall(context.Background(), servicePath, string(comm.Rpc_GatewayRoute), &pb.AgentMessage{
		Ip:               this.IP(),
		UserSessionId:    this.sessionId,
		UserId:           this.uId,
		GatewayServiceId: this.gateway.Service().GetId(),
		MainType:         msg.MainType,
		SubType:          msg.SubType,
		Message:          msg.Data,
	}, reply); err != nil {
		log.Errorf("agent:%s uId:%s MessageDistribution err:%v", this.sessionId, this.uId, err)
		return err
	}

	if reply.Code != pb.ErrorCode_Success {
		data, perr := anypb.New(&pb.ErrorNotify{ReqMainType: msg.MainType, ReqSubType: msg.SubType, Code: reply.Code})
		if perr != nil {
			// Still notify the client (with empty Data) rather than drop the error.
			log.Errorf("agent:%s uId:%s pack ErrorNotify err:%v", this.sessionId, this.uId, perr)
		}
		return this.WriteMsg(&pb.UserMessage{
			MainType: comm.MainTypeNotify,
			SubType:  comm.SubTypeErrorNotify,
			Data:     data,
		})
	}

	for _, v := range reply.Reply {
		if err = this.WriteMsg(v); err != nil {
			return err
		}
	}
	return nil
}