go_dreamfactory/lego/sys/rpcx/client.go
2022-11-10 20:05:35 +08:00

606 lines
17 KiB
Go

package rpcx
import (
"context"
"errors"
"fmt"
"net"
"reflect"
"runtime"
"strings"
"sync"
"time"
"unicode"
"unicode/utf8"
"github.com/smallnest/rpcx/client"
"github.com/smallnest/rpcx/protocol"
"github.com/smallnest/rpcx/share"
)
// newClient builds a Client from options.
//
// It precomputes the metadata query string advertised on every outgoing request
// (tag, type, id, version and the "tcp@"-prefixed callback address). It allocates
// the cluster/connection/service tables and the buffered channel (capacity 1000)
// that receives requests pushed back by remote servers. err is always nil.
func newClient(options *Options) (sys *Client, err error) {
	meta := fmt.Sprintf("stag=%s&stype=%s&sid=%s&version=%s&addr=%s",
		options.ServiceTag,
		options.ServiceType,
		options.ServiceId,
		options.ServiceVersion,
		"tcp@"+options.ServiceAddr,
	)

	sys = new(Client)
	sys.options = options
	sys.metadata = meta
	sys.clusterClients = make(map[string]*clusterClients)
	sys.conns = make(map[string]net.Conn)
	sys.serviceMap = make(map[string]*service)
	sys.msgChan = make(chan *protocol.Message, 1000)
	return sys, nil
}
// clusterClients is the per-cluster table of rpcx XClients.
type clusterClients struct {
// Mu guards clients.
Mu sync.RWMutex
// clients holds the XClients of one cluster, keyed by service type
// (the first segment of a service path; see getclient).
clients map[string]client.XClient
}
// Client is an rpcx client that can also serve requests pushed back by remote
// servers over the bidirectional connections it opens.
type Client struct {
// options is the configuration passed to newClient.
options *Options
// metadata is the query string built in newClient and sent as ServiceMetaKey on every request.
metadata string
// writeTimeout, when non-zero, sets a write deadline in sendResponse.
// NOTE(review): not assigned anywhere in this file, so it is presumably always 0 — confirm.
writeTimeout time.Duration
// AsyncWrite is not read anywhere in this file — confirm whether it is still used.
AsyncWrite bool
// clusterMu guards clusterClients.
clusterMu sync.RWMutex
// clusterClients maps a cluster tag to that cluster's XClient table.
clusterClients map[string]*clusterClients
// connsMapMu guards conns.
connsMapMu sync.RWMutex
// conns maps "tcp@"+remote address to the live connection (maintained by ClientConnected/ClientConnectionClose).
conns map[string]net.Conn
// connectMapMu sync.RWMutex
// connecting map[string]struct{}
// serviceMapMu guards serviceMap.
serviceMapMu sync.RWMutex
// serviceMap maps a service path to the functions registered for it.
serviceMap map[string]*service
// msgChan receives messages pushed by rpcx servers; consumed by DoMessage.
msgChan chan *protocol.Message
}
// DoMessage consumes the requests that remote rpcx servers push through msgChan
// and answers each one on its own goroutine. It returns once msgChan is closed
// (see Stop).
//
// A message is served only when it names both a service path and a method and
// carries ServiceAddrKey metadata that matches a live connection in conns.
func (this *Client) DoMessage() {
	for msg := range this.msgChan {
		go func(req *protocol.Message) {
			if req.ServicePath == "" || req.ServiceMethod == "" {
				return
			}
			this.options.Log.Debugf("DoMessage :%v", req)
			addr, ok := req.Metadata[ServiceAddrKey]
			if !ok {
				this.options.Log.Errorf("Metadata no found ServiceAddrKey!")
				return
			}
			// conns is mutated by ClientConnected/ClientConnectionClose on other
			// goroutines, so the lookup must hold the read lock.
			this.connsMapMu.RLock()
			conn, ok := this.conns[addr]
			this.connsMapMu.RUnlock()
			if !ok {
				this.options.Log.Errorf("no found conn addr:%s", addr)
				return
			}
			// Every failure path of handleRequest encodes the error into res
			// (status + metadata) via handleError, so the caller still gets a reply.
			res, _ := this.handleRequest(context.Background(), req)
			this.sendResponse(conn, req, res)
		}(msg)
	}
}
// Start launches the DoMessage loop in the background so that requests pushed
// by remote servers get handled. It never fails and always returns nil.
func (this *Client) Start() error {
	go this.DoMessage()
	return nil
}
// Stop closes every cached XClient in every cluster, then closes msgChan, which
// ends the DoMessage loop.
//
// Close failures are logged, not returned; Stop itself always returns nil.
// Stop must be called at most once: a second call would close msgChan again,
// which panics.
func (this *Client) Stop() (err error) {
	this.clusterMu.Lock()
	for tag, cluster := range this.clusterClients {
		cluster.Mu.Lock()
		for stype, c := range cluster.clients {
			if cerr := c.Close(); cerr != nil {
				this.options.Log.Errorf("Stop close client cluster:%s service:%s err:%v", tag, stype, cerr)
			}
		}
		cluster.Mu.Unlock()
	}
	// Must pair with Lock above: calling RUnlock on a write-locked RWMutex is a
	// run-time error that kills the process.
	this.clusterMu.Unlock()
	close(this.msgChan) // stop the DoMessage loop
	return
}
// GetServiceTags returns the tags of every cluster this client currently holds a
// client table for, in map iteration (unspecified) order.
func (this *Client) GetServiceTags() []string {
	this.clusterMu.RLock()
	defer this.clusterMu.RUnlock()

	tags := make([]string, 0, len(this.clusterClients))
	for tag := range this.clusterClients {
		tags = append(tags, tag)
	}
	return tags
}
// RegisterFunction registers fn under this node's ServiceType, using the
// function's own (unqualified) name as the method name.
func (this *Client) RegisterFunction(fn interface{}) (err error) {
	_, err = this.registerFunction(this.options.ServiceType, fn, "", false)
	return err
}
// RegisterFunctionName registers fn under this node's ServiceType with the
// explicit method name name.
func (this *Client) RegisterFunctionName(name string, fn interface{}) (err error) {
	_, err = this.registerFunction(this.options.ServiceType, fn, name, true)
	return err
}
// UnregisterAll is currently a no-op (unregistration is not implemented); it
// always returns nil.
func (this *Client) UnregisterAll() error {
	return nil
}
// Call performs a synchronous call of serviceMethod on servicePath within this
// node's own cluster (options.ServiceTag), decoding the result into reply.
func (this *Client) Call(ctx context.Context, servicePath string, serviceMethod string, args interface{}, reply interface{}) (err error) {
	xc, err := this.getclient(&ctx, this.options.ServiceTag, servicePath)
	if err != nil {
		return err
	}
	return xc.Call(ctx, serviceMethod, args, reply)
}
// Go starts an asynchronous call of serviceMethod on servicePath within this
// node's own cluster; completion is signalled on done via the returned Call.
func (this *Client) Go(ctx context.Context, servicePath string, serviceMethod string, args interface{}, reply interface{}, done chan *client.Call) (call *client.Call, err error) {
	xc, err := this.getclient(&ctx, this.options.ServiceTag, servicePath)
	if err != nil {
		return nil, err
	}
	return xc.Go(ctx, serviceMethod, args, reply, done)
}
// Broadcast sends serviceMethod to every server of servicePath's service type
// within this node's own cluster, via the XClient's Broadcast.
func (this *Client) Broadcast(ctx context.Context, servicePath string, serviceMethod string, args interface{}, reply interface{}) (err error) {
	xc, err := this.getclient(&ctx, this.options.ServiceTag, servicePath)
	if err != nil {
		return err
	}
	return xc.Broadcast(ctx, serviceMethod, args, reply)
}
// AcrossClusterCall performs a synchronous call of serviceMethod on servicePath
// inside the cluster identified by clusterTag.
func (this *Client) AcrossClusterCall(ctx context.Context, clusterTag string, servicePath string, serviceMethod string, args interface{}, reply interface{}) (err error) {
	xc, err := this.getclient(&ctx, clusterTag, servicePath)
	if err != nil {
		return err
	}
	return xc.Call(ctx, serviceMethod, args, reply)
}
// AcrossClusterGo starts an asynchronous call of serviceMethod on servicePath
// inside the cluster identified by clusterTag.
func (this *Client) AcrossClusterGo(ctx context.Context, clusterTag string, servicePath string, serviceMethod string, args interface{}, reply interface{}, done chan *client.Call) (call *client.Call, err error) {
	xc, err := this.getclient(&ctx, clusterTag, servicePath)
	if err != nil {
		return nil, err
	}
	return xc.Go(ctx, serviceMethod, args, reply, done)
}
// AcrossClusterBroadcast broadcasts serviceMethod to every server of
// servicePath's service type inside the cluster identified by clusterTag.
func (this *Client) AcrossClusterBroadcast(ctx context.Context, clusterTag string, servicePath string, serviceMethod string, args interface{}, reply interface{}) (err error) {
	xc, err := this.getclient(&ctx, clusterTag, servicePath)
	if err != nil {
		return err
	}
	return xc.Broadcast(ctx, serviceMethod, args, reply)
}
// ClusterBroadcast broadcasts serviceMethod to the service type named by the
// first segment of servicePath in every cluster for which this client already
// holds an XClient of that type.
//
// It returns nil only when every cluster broadcast succeeds. Otherwise it returns:
//   - the first error observed;
//   - an error when no cluster has a client for the service type;
//   - an error after one minute;
//   - ctx.Err() if ctx is done first.
//
// When reply is a non-nil pointer, each cluster broadcast decodes into its own
// freshly allocated copy, so the concurrent calls never write the same memory.
// The first successful copy is stored into reply before this function returns.
// Results that arrive after it has returned (timeout or early error) are
// discarded, so the caller can safely read reply.
func (this *Client) ClusterBroadcast(ctx context.Context, servicePath string, serviceMethod string, args interface{}, reply interface{}) (err error) {
	if servicePath == "" {
		err = errors.New("servicePath no cant null")
		return
	}
	spath := strings.Split(servicePath, "/")
	ctx = context.WithValue(ctx, share.ReqMetaDataKey, map[string]string{
		ServiceClusterTag: this.options.ServiceTag,
		CallRoutRulesKey:  servicePath,
		ServiceAddrKey:    "tcp@" + this.options.ServiceAddr,
		ServiceMetaKey:    this.metadata,
	})

	// Snapshot the matching clients so no lock is held during network calls.
	clients := make([]client.XClient, 0)
	this.clusterMu.RLock()
	for _, v := range this.clusterClients {
		v.Mu.RLock()
		if _client, ok := v.clients[spath[0]]; ok {
			clients = append(clients, _client)
		}
		v.Mu.RUnlock()
	}
	this.clusterMu.RUnlock()
	if len(clients) == 0 {
		err = errors.New("on found any service")
		return
	}

	// reply is only ever written while replyMu is held and finished is false.
	replyValue := reflect.ValueOf(reply)
	cloneReply := replyValue.Kind() == reflect.Ptr && !replyValue.IsNil()
	var (
		replyMu  sync.Mutex
		replySet bool
		finished bool
	)
	defer func() {
		replyMu.Lock()
		finished = true
		replyMu.Unlock()
	}()

	// Buffered to len(clients) so stragglers never block after an early return.
	done := make(chan error, len(clients))
	for _, c := range clients {
		go func(c client.XClient) {
			target := reply
			if cloneReply {
				target = reflect.New(replyValue.Elem().Type()).Interface()
			}
			e := c.Broadcast(ctx, serviceMethod, args, target)
			if e == nil && cloneReply {
				replyMu.Lock()
				if !replySet && !finished {
					replyValue.Elem().Set(reflect.ValueOf(target).Elem())
					replySet = true
				}
				replyMu.Unlock()
			}
			done <- e
		}(c)
	}

	timeout := time.NewTimer(time.Minute)
	defer timeout.Stop()
	for remaining := len(clients); remaining > 0; remaining-- {
		select {
		case err = <-done:
			if err != nil { // some cluster failed: report it immediately
				return
			}
		case <-timeout.C:
			return errors.New(("timeout"))
		case <-ctx.Done():
			return ctx.Err()
		}
	}
	return nil
}
// UpdateServer is the discovery callback. For every discovered node whose
// (ServiceTag, ServiceType) has no cached XClient yet, it sends an
// RpcX_ShakeHands call describing this node. That call creates the client and
// opens the connection. Handshake failures are logged, not returned.
//
// NOTE(review): the skip check looks in the cluster named by node.ServiceTag,
// but the handshake goes through Call, which always routes via
// this.options.ServiceTag. For nodes of another cluster these disagree —
// confirm whether AcrossClusterCall(node.ServiceTag, ...) was intended.
// The skip is also per service type: once a client exists for a type, further
// nodes of that type are not handshaken here.
func (this *Client) UpdateServer(servers map[string]*ServiceNode) {
	for _, node := range servers {
		this.clusterMu.RLock()
		cluster, known := this.clusterClients[node.ServiceTag]
		this.clusterMu.RUnlock()
		if known {
			cluster.Mu.RLock()
			_, known = cluster.clients[node.ServiceType]
			cluster.Mu.RUnlock()
			if known {
				continue
			}
		}

		// No client yet: introduce ourselves to the node.
		self := &ServiceNode{
			ServiceTag:  this.options.ServiceTag,
			ServiceId:   this.options.ServiceId,
			ServiceType: this.options.ServiceType,
			ServiceAddr: this.options.ServiceAddr,
		}
		path := fmt.Sprintf("%s/%s", node.ServiceType, node.ServiceId)
		switch err := this.Call(context.Background(), path, RpcX_ShakeHands, self, &ServiceNode{}); {
		case err != nil:
			this.options.Log.Errorf("ShakeHands new node addr:%s err:%v", node.ServiceAddr, err)
		default:
			this.options.Log.Debugf("UpdateServer addr:%s ", node.ServiceAddr)
		}
	}
}
// ClientConnected is the rpcx plugin hook for a newly established connection.
// It records conn in conns under "tcp@"+remote address and returns it unchanged.
func (this *Client) ClientConnected(conn net.Conn) (net.Conn, error) {
	key := "tcp@" + conn.RemoteAddr().String()
	func() {
		this.connsMapMu.Lock()
		defer this.connsMapMu.Unlock()
		this.conns[key] = conn
	}()
	this.options.Log.Debugf("ClientConnected addr:%v", key)
	return conn, nil
}
// ClientConnectionClose is the rpcx plugin hook for a closed connection. It
// forgets the entry stored under "tcp@"+remote address. It always returns nil.
func (this *Client) ClientConnectionClose(conn net.Conn) error {
	key := "tcp@" + conn.RemoteAddr().String()
	func() {
		this.connsMapMu.Lock()
		defer this.connsMapMu.Unlock()
		delete(this.conns, key)
	}()
	this.options.Log.Debugf("ClientConnectionClose addr:%v", key)
	return nil
}
// getclient returns the XClient for the service type named by the first segment
// of servicePath inside cluster clusterTag.
//
// On first use it creates the cluster table and a bidirectional, Consul-backed
// XClient whose pushed requests flow into msgChan. On success it also replaces
// *ctx with a context carrying this node's request metadata: cluster tag,
// routing path, "tcp@" callback address and service metadata.
//
// Creation is re-checked under the write locks, so concurrent first calls for
// the same cluster or service share a single table and XClient. Without the
// re-check they would overwrite each other and leak the losing client.
func (this *Client) getclient(ctx *context.Context, clusterTag string, servicePath string) (c client.XClient, err error) {
	if servicePath == "" {
		err = errors.New("servicePath no cant null")
		return
	}
	var (
		spath   []string
		cluster *clusterClients
		d       *client.ConsulDiscovery
		ok      bool
		created bool
	)
	spath = strings.Split(servicePath, "/")

	this.clusterMu.RLock()
	cluster, ok = this.clusterClients[clusterTag]
	this.clusterMu.RUnlock()
	if !ok {
		this.clusterMu.Lock()
		// Re-check: another goroutine may have created it between RUnlock and Lock.
		if cluster, ok = this.clusterClients[clusterTag]; !ok {
			cluster = &clusterClients{clients: make(map[string]client.XClient)}
			this.clusterClients[clusterTag] = cluster
		}
		this.clusterMu.Unlock()
	}

	cluster.Mu.RLock()
	c, ok = cluster.clients[spath[0]]
	cluster.Mu.RUnlock()
	if !ok {
		cluster.Mu.Lock()
		if c, ok = cluster.clients[spath[0]]; !ok {
			if d, err = client.NewConsulDiscovery(clusterTag, spath[0], this.options.ConsulServers, nil); err != nil {
				cluster.Mu.Unlock()
				return
			}
			c = client.NewBidirectionalXClient(spath[0], client.Failfast, client.RandomSelect, d, client.DefaultOption, this.msgChan)
			cluster.clients[spath[0]] = c
			created = true
		}
		cluster.Mu.Unlock()
		// Configure only in the goroutine that created the client, and outside
		// cluster.Mu. The selector callback (UpdateServer) calls Call, which
		// re-enters getclient and takes cluster.Mu. Installing the selector may
		// invoke that callback, so holding the lock here could deadlock.
		if created {
			c.GetPlugins().Add(this)
			if this.options.RpcxStartType == RpcxStartByClient && this.options.AutoConnect {
				c.SetSelector(newSelector(this.options.Log, clusterTag, this.UpdateServer))
			} else {
				c.SetSelector(newSelector(this.options.Log, clusterTag, nil))
			}
		}
	}

	*ctx = context.WithValue(*ctx, share.ReqMetaDataKey, map[string]string{
		ServiceClusterTag: this.options.ServiceTag,
		CallRoutRulesKey:  servicePath,
		ServiceAddrKey:    "tcp@" + this.options.ServiceAddr,
		ServiceMetaKey:    this.metadata,
	})
	return
}
// registerFunction validates fn and stores it under servicePath in serviceMap.
//
// fn may be a func value or a reflect.Value wrapping one. It must have the shape
// func(context.Context, Arg, *Reply) error, where Arg and Reply are exported or
// builtin types.
//
// The method name is name when useName is true. Otherwise it is the last
// dot-separated segment of the runtime function name.
//
// It returns the method name and a non-nil error if validation fails; nothing
// is stored in that case.
func (this *Client) registerFunction(servicePath string, fn interface{}, name string, useName bool) (string, error) {
	this.serviceMapMu.Lock()
	defer this.serviceMapMu.Unlock()
	ss := this.serviceMap[servicePath]
	if ss == nil {
		ss = new(service)
		ss.name = servicePath
	}
	// Guard against a service entry created elsewhere without a function table.
	if ss.function == nil {
		ss.function = make(map[string]*functionType)
	}
	f, ok := fn.(reflect.Value)
	if !ok {
		f = reflect.ValueOf(fn)
	}
	if f.Kind() != reflect.Func {
		return "", errors.New("function must be func or bound method")
	}
	// A nil func would register fine but panic on the first remote call.
	if f.IsNil() {
		return "", errors.New("rpcx.registerFunction: function is nil")
	}
	var fname string
	if useName {
		fname = name
	} else {
		fname = runtime.FuncForPC(f.Pointer()).Name()
		if i := strings.LastIndex(fname, "."); i >= 0 {
			fname = fname[i+1:]
		}
	}
	if fname == "" {
		errorStr := "rpcx.registerFunction: no func name for type " + f.Type().String()
		// Constant format string: errorStr is data, not a format.
		this.options.Log.Errorf("%s", errorStr)
		return fname, errors.New(errorStr)
	}
	t := f.Type()
	if t.NumIn() != 3 {
		return fname, fmt.Errorf("rpcx.registerFunction: has wrong number of ins: %s", f.Type().String())
	}
	if t.NumOut() != 1 {
		return fname, fmt.Errorf("rpcx.registerFunction: has wrong number of outs: %s", f.Type().String())
	}
	// First arg must be context.Context
	ctxType := t.In(0)
	if !ctxType.Implements(typeOfContext) {
		return fname, fmt.Errorf("function %s must use context as the first parameter", f.Type().String())
	}
	argType := t.In(1)
	if !isExportedOrBuiltinType(argType) {
		return fname, fmt.Errorf("function %s parameter type not exported: %v", f.Type().String(), argType)
	}
	replyType := t.In(2)
	if replyType.Kind() != reflect.Ptr {
		return fname, fmt.Errorf("function %s reply type not a pointer: %s", f.Type().String(), replyType)
	}
	if !isExportedOrBuiltinType(replyType) {
		return fname, fmt.Errorf("function %s reply type not exported: %v", f.Type().String(), replyType)
	}
	// The return type of the method must be error.
	if returnType := t.Out(0); returnType != typeOfError {
		return fname, fmt.Errorf("function %s returns %s, not error", f.Type().String(), returnType.String())
	}
	// Install the function and publish the (possibly new) service entry.
	ss.function[fname] = &functionType{fn: f, ArgType: argType, ReplyType: replyType}
	this.serviceMap[servicePath] = ss
	// init pool for reflect.Type of args and reply
	reflectTypePools.Init(argType)
	reflectTypePools.Init(replyType)
	return fname, nil
}
// handleRequest serves a request pushed by a remote server.
//
// It first looks up ServiceMethod in the service's method table. If that has no
// entry but the raw function table does, the work is delegated to
// handleRequestForFunction. The returned res is always a Response cloned from
// req. On any failure, res is marked with an error status by handleError and the
// same error is returned.
func (this *Client) handleRequest(ctx context.Context, req *protocol.Message) (res *protocol.Message, err error) {
serviceName := req.ServicePath
methodName := req.ServiceMethod
res = req.Clone()
res.SetMessageType(protocol.Response)
this.serviceMapMu.RLock()
service := this.serviceMap[serviceName]
this.serviceMapMu.RUnlock()
if service == nil {
err = errors.New("rpcx: can't find service " + serviceName)
return handleError(res, err)
}
mtype := service.method[methodName]
if mtype == nil {
if service.function[methodName] != nil { // check raw functions
return this.handleRequestForFunction(ctx, req)
}
err = errors.New("rpcx: can't find method " + methodName)
return handleError(res, err)
}
// get a argv object from object pool
// NOTE(review): on the codec-lookup and decode failures below, argv is not
// returned to reflectTypePools — confirm whether the pool relies on Put.
argv := reflectTypePools.Get(mtype.ArgType)
codec := share.Codecs[req.SerializeType()]
if codec == nil {
err = fmt.Errorf("can not find codec for %d", req.SerializeType())
return handleError(res, err)
}
err = codec.Decode(req.Payload, argv)
if err != nil {
return handleError(res, err)
}
// and get a reply object from object pool
replyv := reflectTypePools.Get(mtype.ReplyType)
// Non-pointer argument types are passed by value, so dereference the pooled object.
if mtype.ArgType.Kind() != reflect.Ptr {
err = service.call(ctx, mtype, reflect.ValueOf(argv).Elem(), reflect.ValueOf(replyv))
} else {
err = service.call(ctx, mtype, reflect.ValueOf(argv), reflect.ValueOf(replyv))
}
reflectTypePools.Put(mtype.ArgType, argv)
if err != nil {
// The call failed, but whatever it wrote into the reply is still sent back.
if replyv != nil {
// data/err here shadow the outer err: an encode failure replaces the call
// error; otherwise the outer call error is returned below.
data, err := codec.Encode(replyv)
// return reply to object pool
reflectTypePools.Put(mtype.ReplyType, replyv)
if err != nil {
return handleError(res, err)
}
res.Payload = data
}
return handleError(res, err)
}
// One-way requests get no payload; the reply object is only recycled.
if !req.IsOneway() {
data, err := codec.Encode(replyv)
// return reply to object pool
reflectTypePools.Put(mtype.ReplyType, replyv)
if err != nil {
return handleError(res, err)
}
res.Payload = data
} else if replyv != nil {
reflectTypePools.Put(mtype.ReplyType, replyv)
}
this.options.Log.Debugf("server called service %+v for an request %+v", service, req)
return res, nil
}
// handleRequestForFunction serves req with a plain function registered through
// RegisterFunction/RegisterFunctionName (the service's function table).
//
// The returned res is always a Response cloned from req. Any failure — unknown
// service or method, missing codec, decode, call or encode error — is recorded
// in res by handleError and returned as err. One-way requests get no payload.
func (this *Client) handleRequestForFunction(ctx context.Context, req *protocol.Message) (res *protocol.Message, err error) {
	res = req.Clone()
	res.SetMessageType(protocol.Response)

	this.serviceMapMu.RLock()
	svc := this.serviceMap[req.ServicePath]
	this.serviceMapMu.RUnlock()
	if svc == nil {
		return handleError(res, errors.New("rpcx: can't find service for func raw function"))
	}

	methodName := req.ServiceMethod
	fnType := svc.function[methodName]
	if fnType == nil {
		return handleError(res, errors.New("rpcx: can't find method "+methodName))
	}

	argv := reflectTypePools.Get(fnType.ArgType)
	codec := share.Codecs[req.SerializeType()]
	if codec == nil {
		return handleError(res, fmt.Errorf("can not find codec for %d", req.SerializeType()))
	}
	if err = codec.Decode(req.Payload, argv); err != nil {
		return handleError(res, err)
	}

	replyv := reflectTypePools.Get(fnType.ReplyType)
	// Pooled args are pointers; by-value parameters need the pointee.
	argValue := reflect.ValueOf(argv)
	if fnType.ArgType.Kind() != reflect.Ptr {
		argValue = argValue.Elem()
	}
	err = svc.callForFunction(ctx, fnType, argValue, reflect.ValueOf(replyv))
	reflectTypePools.Put(fnType.ArgType, argv)
	if err != nil {
		reflectTypePools.Put(fnType.ReplyType, replyv)
		return handleError(res, err)
	}

	switch {
	case !req.IsOneway():
		payload, encErr := codec.Encode(replyv)
		reflectTypePools.Put(fnType.ReplyType, replyv)
		if encErr != nil {
			return handleError(res, encErr)
		}
		res.Payload = payload
	case replyv != nil:
		reflectTypePools.Put(fnType.ReplyType, replyv)
	}
	return res, nil
}
// sendResponse encodes res and writes it to conn.
//
// Payloads over 1024 bytes reuse the request's compression type (when it has
// one). A write deadline is set when writeTimeout is non-zero. Failures cannot
// be reported to the remote caller, so they are logged instead of being
// silently dropped.
func (this *Client) sendResponse(conn net.Conn, req, res *protocol.Message) {
	if len(res.Payload) > 1024 && req.CompressType() != protocol.None {
		res.SetCompressType(req.CompressType())
	}
	data := res.EncodeSlicePointer()
	if this.writeTimeout != 0 {
		if err := conn.SetWriteDeadline(time.Now().Add(this.writeTimeout)); err != nil {
			this.options.Log.Errorf("sendResponse set write deadline addr:%v err:%v", conn.RemoteAddr(), err)
		}
	}
	if _, err := conn.Write(*data); err != nil {
		this.options.Log.Errorf("sendResponse write addr:%v err:%v", conn.RemoteAddr(), err)
	}
	protocol.PutData(data) // return the encode buffer to the protocol pool
}
// handleError marks res as an error response: it sets the error status and
// stores err's text under protocol.ServiceError in res.Metadata (allocating the
// map if needed). It returns res and err unchanged so callers can `return` it.
func handleError(res *protocol.Message, err error) (*protocol.Message, error) {
	res.SetMessageStatusType(protocol.Error)
	md := res.Metadata
	if md == nil {
		md = make(map[string]string)
		res.Metadata = md
	}
	md[protocol.ServiceError] = err.Error()
	return res, err
}
// isExportedOrBuiltinType reports whether t, after stripping any number of
// pointer indirections, is usable as an RPC argument/reply type. That means it
// either has no package path (predeclared or unnamed types) or has an exported
// name.
func isExportedOrBuiltinType(t reflect.Type) bool {
	base := t
	for base.Kind() == reflect.Ptr {
		base = base.Elem()
	}
	// Named types carry a PkgPath whether exported or not, so the name itself
	// decides; types without a PkgPath (int, []T, ...) are always accepted.
	if base.PkgPath() == "" {
		return true
	}
	return isExported(base.Name())
}

// isExported reports whether name starts with an upper-case letter.
func isExported(name string) bool {
	if name == "" {
		return false
	}
	first, _ := utf8.DecodeRuneInString(name)
	return unicode.IsUpper(first)
}