go_dreamfactory/modules/gateway/agent.go
2023-06-06 13:55:33 +08:00

425 lines
12 KiB
Go

package gateway
import (
"context"
"encoding/base64"
"fmt"
"go_dreamfactory/comm"
"go_dreamfactory/pb"
"go_dreamfactory/sys/configure"
"go_dreamfactory/sys/db"
"go_dreamfactory/utils"
"strings"
"sync"
"sync/atomic"
"time"
"go_dreamfactory/lego/sys/log"
"go_dreamfactory/lego/utils/container/id"
"github.com/gorilla/websocket"
"github.com/tidwall/gjson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
)
// newAgent wraps an accepted websocket connection in a user Agent.
//
// The agent owns the socket: it handles reading, writing and closing for this
// user. A fresh session id is generated, the agent starts in the running
// state (1), and its readLoop and writeLoop goroutines are started before the
// agent is returned. Both loops are registered on wg so Close can wait for them.
func newAgent(gateway IGateway, conn *websocket.Conn) *Agent {
	a := new(Agent)
	a.gateway = gateway
	a.wsConn = conn
	a.sessionId = id.NewUUId()
	a.uId = ""
	a.writeChan = make(chan []byte, 2)
	a.closeSignal = make(chan bool)
	a.state = 1
	a.protoMsg = make(map[string]int64)

	// Register both loops before launching them so that a concurrent
	// Close cannot observe a zero WaitGroup and return early.
	a.wg.Add(2)
	go a.readLoop()
	go a.writeLoop()
	return a
}
// Agent is the gateway-side proxy for a single user's websocket connection.
type Agent struct {
	gateway IGateway        // owning gateway: logging, routing rules, RPC, DisConnect
	wsConn  *websocket.Conn // the client's websocket connection

	// sessionId is generated per connection by newAgent; uId and wId are the
	// user and worker this connection is bound to (see Bind / UnBind).
	sessionId, uId, wId string

	writeChan   chan []byte      // marshalled frames waiting to be written by writeLoop
	closeSignal chan bool        // Close sends one value here to stop writeLoop
	state       int32            // 0 closed, 1 running, 2 closing; mostly accessed via sync/atomic
	wg          sync.WaitGroup   // tracks readLoop and writeLoop
	protoMsg    map[string]int64 // NOTE(review): only referenced by commented-out code in this file
}
// readLoop reads frames from the websocket until an error occurs, decodes
// each one into a pb.UserMessage and dispatches it:
//
//   - gateway heartbeats (MainType == comm.ModuleGate) are answered directly,
//     without checking the secret key;
//   - every other message must pass secAuth; on failure an error-notify push
//     is queued for the client and the connection stays open unless that
//     WriteMsg call returns an error;
//   - authenticated messages go to messageDistribution; if it fails, an
//     error-notify push is queued and the agent is closed.
//
// A read or unmarshal error also closes the agent. Close is always started in
// a new goroutine because it waits on wg, which this loop is part of.
func (this *Agent) readLoop() {
	defer this.wg.Done()
	var (
		data []byte
		msg  *pb.UserMessage = &pb.UserMessage{}
		err  error
	)
locp:
	for {
		if _, data, err = this.wsConn.ReadMessage(); err != nil {
			this.gateway.Errorf("agent:%s uId:%s ReadMessage err:%v", this.sessionId, this.uId, err)
			go this.Close()
			break locp
		}
		if err = proto.Unmarshal(data, msg); err != nil {
			this.gateway.Errorf("agent:%s uId:%s Unmarshal err:%v", this.sessionId, this.uId, err)
			go this.Close()
			break locp
		}
		// A decodable frame counts as activity: push the read deadline out 60s.
		this.wsConn.SetReadDeadline(time.Now().Add(time.Second * 60))

		// Heartbeats are answered here and skip secret-key validation.
		if msg.MainType == string(comm.ModuleGate) {
			payload, _ := anypb.New(&pb.GatewayHeartbeatResp{
				Timestamp: configure.Now().Unix(),
			})
			this.WriteMsg(&pb.UserMessage{
				MsgId:    msg.MsgId,
				MainType: string(comm.ModuleGate),
				SubType:  "heartbeat",
				Data:     payload,
			})
			continue
		}

		if errdata := this.secAuth(msg); errdata != nil {
			// err is nil on this path, so log the auth failure itself.
			this.gateway.Errorf("agent:%s uId:%s 密钥无效 err:%v", this.sessionId, this.uId, errdata)
			payload, _ := anypb.New(&pb.NotifyErrorNotifyPush{
				MsgId:       msg.MsgId,
				ReqMainType: msg.MainType,
				ReqSubType:  msg.SubType,
				Code:        errdata.Code,
				Err:         errdata,
			})
			if err = this.WriteMsg(&pb.UserMessage{
				MsgId:    msg.MsgId,
				MainType: comm.MainTypeNotify,
				SubType:  comm.SubTypeErrorNotify,
				Data:     payload,
			}); err != nil {
				go this.Close()
				break locp
			}
			continue
		}

		if err = this.messageDistribution(msg); err != nil {
			this.gateway.Errorf("messageDistribution err:%v", err)
			payload, _ := anypb.New(&pb.NotifyErrorNotifyPush{
				MsgId:       msg.MsgId,
				ReqMainType: msg.MainType,
				ReqSubType:  msg.SubType,
				Arg:         msg.Data,
				Code:        pb.ErrorCode_GatewayException,
				Err:         &pb.ErrorData{Title: "用户消息处理失败!", Datastring: err.Error()},
			})
			// Best effort: the agent is closed right after regardless of the result.
			_ = this.WriteMsg(&pb.UserMessage{
				MsgId:    msg.MsgId,
				MainType: comm.MainTypeNotify,
				SubType:  comm.SubTypeErrorNotify,
				Data:     payload,
			})
			go this.Close()
			break locp
		}
	}
	this.gateway.Debugf("agent:%s uId:%s readLoop end!", this.sessionId, this.uId)
}
// writeLoop writes queued frames from writeChan to the websocket until Close
// sends on closeSignal.
//
// A write failure triggers Close in a new goroutine (Close waits on wg, which
// this loop is part of). The loop keeps draining writeChan so that senders in
// WriteMsg / WriteBytes are not left blocked before the close signal arrives.
func (this *Agent) writeLoop() {
	defer this.wg.Done()
	// in is set to nil if writeChan is ever closed, so the select stops
	// spinning on a closed channel and just waits for closeSignal.
	in := this.writeChan
	for {
		select {
		case <-this.closeSignal:
			this.gateway.Debugf("agent:%s uId:%s writeLoop end!", this.sessionId, this.uId)
			return
		case frame, ok := <-in:
			if !ok {
				in = nil
				go this.Close()
				continue
			}
			if err := this.wsConn.WriteMessage(websocket.BinaryMessage, frame); err != nil {
				// uId is a string: %s, not %d.
				this.gateway.Errorf("agent:%s uId:%s WriteMessage err:%v", this.sessionId, this.uId, err)
				go this.Close()
			}
		}
	}
}
// secAuth performs the security check applied to every non-heartbeat message.
//
// If utils.ValidSecretKey rejects msg.Sec, a SignError is returned.
// Otherwise the result of decodeUserData is returned, which is nil on success.
func (this *Agent) secAuth(msg *pb.UserMessage) (errdata *pb.ErrorData) {
	if utils.ValidSecretKey(msg.Sec) {
		return this.decodeUserData(msg)
	}
	// NOTE(review): this logs the raw secret key; consider redacting it.
	this.gateway.Errorf("%v", msg.Sec)
	code := pb.ErrorCode_SignError
	return &pb.ErrorData{
		Code:    code,
		Title:   code.ToString(),
		Message: "key invalid",
	}
}
// decodeUserData decodes and validates the payload carried in msg.Sec.
//
// Everything after the first 35 characters of msg.Sec is decoded as standard
// base64 and parsed as JSON. The checks, in order:
//   - Sec shorter than 35 characters, or invalid base64, yields DecodeError;
//   - a "timestamp" more than 30 seconds older than configure.Now() yields
//     TimestampTimeout;
//   - for user/login, msg.Data is replaced with a UserLoginReq built from
//     the "account" and "serverId" fields (a packing failure yields PbError);
//   - for any other module except notify and gateway, the agent must already
//     be bound to a user, otherwise AgentUidEmpty is returned.
//
// It returns nil when the message may be dispatched.
func (this *Agent) decodeUserData(msg *pb.UserMessage) (errdata *pb.ErrorData) {
	// secPrefixLen is the number of leading characters of Sec that precede
	// the base64 payload.
	const secPrefixLen = 35
	base64Str := msg.Sec
	// Sec comes from the client. utils.ValidSecretKey is not visible here, so
	// don't rely on it to enforce the length: slicing a shorter string would
	// panic in the read goroutine.
	if len(base64Str) < secPrefixLen {
		this.gateway.Errorf("base64 decode err sec too short len:%d", len(base64Str))
		return &pb.ErrorData{
			Code:  pb.ErrorCode_DecodeError,
			Title: pb.ErrorCode_DecodeError.ToString(),
		}
	}
	dec, err := base64.StdEncoding.DecodeString(base64Str[secPrefixLen:])
	if err != nil {
		this.gateway.Errorf("base64 decode err %v", err)
		return &pb.ErrorData{
			Code:  pb.ErrorCode_DecodeError,
			Title: pb.ErrorCode_DecodeError.ToString(),
		}
	}
	now := configure.Now().Unix()
	jsonRet := gjson.Parse(string(dec))
	timestamp := jsonRet.Get("timestamp").Int()
	// The key expires 30 seconds after its timestamp.
	// NOTE(review): only stale timestamps are rejected; a timestamp far in the
	// future keeps passing until it is more than 30s in the past. Confirm
	// whether an upper bound (with clock-skew allowance) is wanted.
	if now-timestamp > 30 {
		this.gateway.Errorf("now:%v last timestamp:%v more than 30s", now, timestamp)
		return &pb.ErrorData{
			Code:    pb.ErrorCode_TimestampTimeout,
			Title:   pb.ErrorCode_TimestampTimeout.ToString(),
			Message: "sec key expire",
		}
	}
	// Only login carries its request inside Sec: rebuild msg.Data from it.
	if msg.MainType == string(comm.ModuleUser) && msg.SubType == "login" {
		req := &pb.UserLoginReq{
			Account: jsonRet.Get("account").String(),
			Sid:     jsonRet.Get("serverId").String(),
		}
		ad, err := anypb.New(req)
		if err != nil {
			this.gateway.Errorf("decodeUserData pb err:%v", err)
			return &pb.ErrorData{
				Code:    pb.ErrorCode_PbError,
				Title:   pb.ErrorCode_PbError.ToString(),
				Message: err.Error(),
			}
		}
		msg.Data = ad
		return nil
	}
	switch msg.MainType {
	case string(comm.ModuleNotify), string(comm.ModuleGate):
		// No bound user required.
	default:
		if this.UserId() == "" {
			this.gateway.Errorf("[%v.%v] Agent UId empty", msg.MainType, msg.SubType)
			return &pb.ErrorData{
				Code:    pb.ErrorCode_AgentUidEmpty,
				Title:   pb.ErrorCode_AgentUidEmpty.ToString(),
				Message: "no login",
			}
		}
	}
	return nil
}
// SessionId returns the per-connection id generated when the agent was created.
func (this *Agent) SessionId() (sid string) {
	sid = this.sessionId
	return
}
// IP returns the string form of the websocket's remote address.
// NOTE(review): this is RemoteAddr().String(), which may include a port.
func (this *Agent) IP() string {
	remote := this.wsConn.RemoteAddr()
	return remote.String()
}
// UserId returns the bound user id, or "" when the agent is not bound.
func (this *Agent) UserId() (uid string) {
	uid = this.uId
	return
}
// WorkerId returns the worker id recorded by Bind ("" if Bind was never called).
func (this *Agent) WorkerId() (wid string) {
	wid = this.wId
	return
}
// Bind associates this connection with user uId and worker wId.
func (this *Agent) Bind(uId string, wId string) {
	this.uId, this.wId = uId, wId
}
// UnBind clears the bound user id.
// Note that the worker id set by Bind is left in place.
func (this *Agent) UnBind() {
	var none string
	this.uId = none
}
// WriteMsg marshals msg and queues the bytes on writeChan for writeLoop.
//
// When the agent is not running (state != 1), the message is silently dropped
// and nil is returned. A marshalling error is returned without queueing
// anything. The channel send blocks while writeChan is full.
func (this *Agent) WriteMsg(msg *pb.UserMessage) (err error) {
	if atomic.LoadInt32(&this.state) == 1 {
		var frame []byte
		frame, err = proto.Marshal(msg)
		if err == nil {
			this.writeChan <- frame
		}
	}
	return err
}
// WriteBytes queues an already-marshalled frame on writeChan for writeLoop.
//
// It returns an error, without queueing, when the agent is not running
// (state != 1). The channel send blocks while writeChan is full.
func (this *Agent) WriteBytes(data []byte) (err error) {
	// Report the atomically loaded value. Re-reading this.state directly
	// would race with Close's atomic updates.
	if state := atomic.LoadInt32(&this.state); state != 1 {
		return fmt.Errorf("uid:%s state:%d", this.uId, state)
	}
	this.writeChan <- data
	return nil
}
// Close shuts the agent down. It is meant to be called from outside as well;
// only a call that finds the agent running (state 1) does anything.
//
// Sequence:
//  1. mark closing (2);
//  2. close the websocket, which should make readLoop's pending read fail;
//  3. signal writeLoop;
//  4. wait for both loops;
//  5. mark closed (0);
//  6. notify the gateway via DisConnect.
//
// It must not be called synchronously from readLoop or writeLoop, since it
// waits on wg; those loops start it with `go this.Close()`.
func (this *Agent) Close() {
	if atomic.CompareAndSwapInt32(&this.state, 1, 2) {
		this.wsConn.Close()
		this.closeSignal <- true
		this.wg.Wait()
		atomic.StoreInt32(&this.state, 0)
		this.gateway.DisConnect(this)
	}
}
// messageDistribution forwards an authenticated user message to a backend
// service over RPC and writes the reply back to the client.
//
// Routing:
//   - If the gateway has a distribution rule for MainType/SubType, the rule
//     is split on "/":
//     "tag/a/b" → service tag "tag" (db.CrossTag() when tag is "~"), path "a/b";
//     "~/p"     → service tag db.CrossTag(), path "p";
//     "a/b"     → no tag, path "a/b" (the rule itself).
//     Any other shape is logged and the message is dropped with a nil error.
//   - Without a rule, the path is comm.Service_Worker, or
//     comm.Service_Worker/<wId> once the agent is bound to a worker.
//
// With no service tag the call goes through RpcCall. With a tag it goes
// through AcrossClusterRpcCall, and a non-empty msg.ServicePath from the
// client overrides the path.
//
// An RPC failure is logged and returned. If the reply carries ErrorData, a
// NotifyErrorNotifyPush is written to the client. Otherwise every reply frame
// is written, with MsgId copied onto the frame whose MainType/SubType match
// the request. Write errors are returned.
func (this *Agent) messageDistribution(msg *pb.UserMessage) (err error) {
	var (
		req   *pb.AgentMessage    = getmsg()
		reply *pb.RPCMessageReply = getmsgreply()
	)
	defer func() {
		putmsg(req)
		putmsgreply(reply)
	}()

	req.Ip = this.IP()
	req.UserSessionId = this.sessionId
	req.UserId = this.uId
	req.ServiceTag = this.gateway.Service().GetTag()
	req.GatewayServiceId = this.gateway.Service().GetId()
	req.MainType = msg.MainType
	req.SubType = msg.SubType
	req.Message = msg.Data

	var serviceTag, servicePath string = "", comm.Service_Worker
	if rule, found := this.gateway.GetMsgDistribute(req.MainType, req.SubType); !found {
		if this.wId != "" { // already bound to a worker: target it directly
			servicePath = fmt.Sprintf("%s/%s", comm.Service_Worker, this.wId)
		}
	} else {
		parts := strings.Split(rule, "/")
		switch len(parts) {
		case 3:
			serviceTag = parts[0]
			if serviceTag == "~" {
				serviceTag = db.CrossTag()
			}
			servicePath = fmt.Sprintf("%s/%s", parts[1], parts[2])
		case 2:
			if parts[0] == "~" {
				serviceTag, servicePath = db.CrossTag(), parts[1]
			} else {
				servicePath = rule
			}
		default:
			this.gateway.Errorf("messageDistribution mtype:%s stype:%s rule:%s is empty!", req.MainType, req.SubType, rule)
			return
		}
	}

	// logRPCFailure reports a failed call; it reads servicePath at call time,
	// so a client-supplied override is what gets logged.
	logRPCFailure := func(rpcErr error) {
		this.gateway.Error("[UserResponse]",
			log.Field{Key: "uid", Value: this.uId},
			log.Field{Key: "serviceTag", Value: serviceTag},
			log.Field{Key: "servicePath", Value: servicePath},
			log.Field{Key: "req", Value: fmt.Sprintf("%s:%s %v", req.MainType, req.SubType, req.Message.String())},
			log.Field{Key: "err", Value: rpcErr.Error()},
		)
	}

	started := time.Now()
	ctx := context.Background()
	if serviceTag == "" {
		err = this.gateway.Service().RpcCall(ctx, servicePath, string(comm.Rpc_GatewayRoute), req, reply)
	} else { // cross-cluster call
		if msg.ServicePath != "" { // client pinned a target server, e.g. /worker/worker0
			servicePath = msg.ServicePath
		}
		err = this.gateway.Service().AcrossClusterRpcCall(ctx, serviceTag, servicePath, string(comm.Rpc_GatewayRoute), req, reply)
	}
	if err != nil {
		logRPCFailure(err)
		return err
	}

	this.gateway.Debug("[UserResponse]",
		log.Field{Key: "t", Value: time.Since(started).Milliseconds()},
		log.Field{Key: "uid", Value: this.uId},
		log.Field{Key: "req", Value: fmt.Sprintf("%s:%s %v", req.MainType, req.SubType, req.Message.String())},
		log.Field{Key: "reply", Value: reply.String()},
	)

	if reply.ErrorData != nil {
		payload, _ := anypb.New(&pb.NotifyErrorNotifyPush{
			MsgId:       msg.MsgId,
			ReqMainType: msg.MainType,
			ReqSubType:  msg.SubType,
			Arg:         msg.Data,
			Code:        reply.ErrorData.Code,
			Err:         reply.ErrorData,
		})
		return this.WriteMsg(&pb.UserMessage{
			MsgId:    msg.MsgId,
			MainType: comm.MainTypeNotify,
			SubType:  comm.SubTypeErrorNotify,
			Data:     payload,
		})
	}

	for _, frame := range reply.Reply {
		// Echo the request's MsgId on the direct response so the client can correlate it.
		if frame.MainType == msg.MainType && frame.SubType == msg.SubType {
			frame.MsgId = msg.MsgId
		}
		if err = this.WriteMsg(frame); err != nil {
			return err
		}
	}
	return nil
}