606 lines
17 KiB
Go
606 lines
17 KiB
Go
package rpcx
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"reflect"
|
|
"runtime"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
"unicode"
|
|
"unicode/utf8"
|
|
|
|
"github.com/smallnest/rpcx/client"
|
|
"github.com/smallnest/rpcx/protocol"
|
|
"github.com/smallnest/rpcx/share"
|
|
)
|
|
|
|
func newClient(options *Options) (sys *Client, err error) {
|
|
sys = &Client{
|
|
options: options,
|
|
metadata: fmt.Sprintf("stag=%s&stype=%s&sid=%s&version=%s&addr=%s", options.ServiceTag, options.ServiceType, options.ServiceId, options.ServiceVersion, "tcp@"+options.ServiceAddr),
|
|
clusterClients: make(map[string]*clusterClients),
|
|
conns: make(map[string]net.Conn),
|
|
serviceMap: make(map[string]*service),
|
|
msgChan: make(chan *protocol.Message, 1000),
|
|
}
|
|
return
|
|
}
|
|
|
|
type clusterClients struct {
|
|
Mu sync.RWMutex
|
|
clients map[string]client.XClient //其他集群客户端
|
|
}
|
|
|
|
type Client struct {
|
|
options *Options
|
|
metadata string
|
|
writeTimeout time.Duration
|
|
AsyncWrite bool
|
|
clusterMu sync.RWMutex
|
|
clusterClients map[string]*clusterClients //其他集群客户端
|
|
connsMapMu sync.RWMutex
|
|
conns map[string]net.Conn
|
|
// connectMapMu sync.RWMutex
|
|
// connecting map[string]struct{}
|
|
serviceMapMu sync.RWMutex
|
|
serviceMap map[string]*service
|
|
msgChan chan *protocol.Message // 接收rpcXServer推送消息
|
|
}
|
|
|
|
// DoMessage 服务端消息处理
|
|
func (this *Client) DoMessage() {
|
|
for msg := range this.msgChan {
|
|
go func(req *protocol.Message) {
|
|
if req.ServicePath != "" && req.ServiceMethod != "" {
|
|
this.options.Log.Debugf("DoMessage :%v", req)
|
|
addr, ok := req.Metadata[ServiceAddrKey]
|
|
if !ok {
|
|
this.options.Log.Errorf("Metadata no found ServiceAddrKey!")
|
|
return
|
|
}
|
|
conn, ok := this.conns[addr]
|
|
if !ok {
|
|
this.options.Log.Errorf("no found conn addr:%s", addr)
|
|
return
|
|
}
|
|
res, _ := this.handleRequest(context.Background(), req)
|
|
this.sendResponse(conn, req, res)
|
|
}
|
|
}(msg)
|
|
}
|
|
}
|
|
|
|
//启动RPC 服务 接收消息处理
|
|
func (this *Client) Start() (err error) {
|
|
go this.DoMessage()
|
|
return
|
|
}
|
|
|
|
//停止RPC 服务
|
|
func (this *Client) Stop() (err error) {
|
|
this.clusterMu.Lock()
|
|
for _, v := range this.clusterClients {
|
|
v.Mu.Lock()
|
|
for _, v1 := range v.clients {
|
|
v1.Close()
|
|
}
|
|
v.Mu.Unlock()
|
|
}
|
|
this.clusterMu.RUnlock()
|
|
close(this.msgChan) //关闭消息处理
|
|
return
|
|
}
|
|
|
|
//获取服务集群列表
|
|
func (this *Client) GetServiceTags() []string {
|
|
this.clusterMu.RLock()
|
|
tags := make([]string, len(this.clusterClients))
|
|
n := 0
|
|
for k, _ := range this.clusterClients {
|
|
tags[n] = k
|
|
n++
|
|
}
|
|
this.clusterMu.RUnlock()
|
|
return tags
|
|
}
|
|
|
|
//注册Rpc 服务
|
|
func (this *Client) RegisterFunction(fn interface{}) (err error) {
|
|
_, err = this.registerFunction(this.options.ServiceType, fn, "", false)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return
|
|
}
|
|
|
|
//注册Rpc 服务
|
|
func (this *Client) RegisterFunctionName(name string, fn interface{}) (err error) {
|
|
_, err = this.registerFunction(this.options.ServiceType, fn, name, true)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return
|
|
}
|
|
|
|
//注销 暂不处理
|
|
func (this *Client) UnregisterAll() (err error) {
|
|
return nil
|
|
}
|
|
|
|
//同步调用
|
|
func (this *Client) Call(ctx context.Context, servicePath string, serviceMethod string, args interface{}, reply interface{}) (err error) {
|
|
var (
|
|
_client client.XClient
|
|
)
|
|
if _client, err = this.getclient(&ctx, this.options.ServiceTag, servicePath); err != nil {
|
|
return
|
|
}
|
|
err = _client.Call(ctx, serviceMethod, args, reply)
|
|
return
|
|
}
|
|
|
|
//异步调用
|
|
func (this *Client) Go(ctx context.Context, servicePath string, serviceMethod string, args interface{}, reply interface{}, done chan *client.Call) (call *client.Call, err error) {
|
|
var (
|
|
_client client.XClient
|
|
)
|
|
if _client, err = this.getclient(&ctx, this.options.ServiceTag, servicePath); err != nil {
|
|
return
|
|
}
|
|
return _client.Go(ctx, string(serviceMethod), args, reply, done)
|
|
}
|
|
|
|
//异步调用
|
|
func (this *Client) Broadcast(ctx context.Context, servicePath string, serviceMethod string, args interface{}, reply interface{}) (err error) {
|
|
var (
|
|
_client client.XClient
|
|
)
|
|
if _client, err = this.getclient(&ctx, this.options.ServiceTag, servicePath); err != nil {
|
|
return
|
|
}
|
|
err = _client.Broadcast(ctx, serviceMethod, args, reply)
|
|
return
|
|
}
|
|
|
|
//跨集群 同步调用
|
|
func (this *Client) AcrossClusterCall(ctx context.Context, clusterTag string, servicePath string, serviceMethod string, args interface{}, reply interface{}) (err error) {
|
|
var (
|
|
_client client.XClient
|
|
)
|
|
if _client, err = this.getclient(&ctx, clusterTag, servicePath); err != nil {
|
|
return
|
|
}
|
|
err = _client.Call(ctx, serviceMethod, args, reply)
|
|
return
|
|
}
|
|
|
|
//跨集群 异步调用
|
|
func (this *Client) AcrossClusterGo(ctx context.Context, clusterTag string, servicePath string, serviceMethod string, args interface{}, reply interface{}, done chan *client.Call) (call *client.Call, err error) {
|
|
var (
|
|
_client client.XClient
|
|
)
|
|
if _client, err = this.getclient(&ctx, clusterTag, servicePath); err != nil {
|
|
return
|
|
}
|
|
|
|
return _client.Go(ctx, string(serviceMethod), args, reply, done)
|
|
}
|
|
|
|
//跨集群 广播
|
|
func (this *Client) AcrossClusterBroadcast(ctx context.Context, clusterTag string, servicePath string, serviceMethod string, args interface{}, reply interface{}) (err error) {
|
|
var (
|
|
_client client.XClient
|
|
)
|
|
if _client, err = this.getclient(&ctx, clusterTag, servicePath); err != nil {
|
|
return
|
|
}
|
|
return _client.Broadcast(ctx, serviceMethod, args, reply)
|
|
}
|
|
|
|
func (this *Client) ClusterBroadcast(ctx context.Context, servicePath string, serviceMethod string, args interface{}, reply interface{}) (err error) {
|
|
if servicePath == "" {
|
|
err = errors.New("servicePath no cant null")
|
|
return
|
|
}
|
|
var (
|
|
spath []string
|
|
clients []client.XClient
|
|
)
|
|
spath = strings.Split(servicePath, "/")
|
|
ctx = context.WithValue(ctx, share.ReqMetaDataKey, map[string]string{
|
|
ServiceClusterTag: this.options.ServiceTag,
|
|
CallRoutRulesKey: servicePath,
|
|
ServiceAddrKey: "tcp@" + this.options.ServiceAddr,
|
|
ServiceMetaKey: this.metadata,
|
|
})
|
|
clients = make([]client.XClient, 0)
|
|
this.clusterMu.RLock()
|
|
for _, v := range this.clusterClients {
|
|
v.Mu.RLock()
|
|
if _client, ok := v.clients[spath[0]]; ok {
|
|
clients = append(clients, _client)
|
|
}
|
|
v.Mu.RUnlock()
|
|
}
|
|
this.clusterMu.RUnlock()
|
|
l := len(clients)
|
|
if l > 0 {
|
|
done := make(chan error, l)
|
|
for _, v := range clients {
|
|
go func(c client.XClient) {
|
|
done <- c.Broadcast(ctx, serviceMethod, args, reply)
|
|
}(v)
|
|
}
|
|
timeout := time.NewTimer(time.Minute)
|
|
check:
|
|
for {
|
|
select {
|
|
case err = <-done:
|
|
l--
|
|
if l == 0 || err != nil { // all returns or some one returns an error
|
|
break check
|
|
}
|
|
case <-timeout.C:
|
|
err = errors.New(("timeout"))
|
|
break check
|
|
}
|
|
}
|
|
timeout.Stop()
|
|
} else {
|
|
err = errors.New("on found any service")
|
|
}
|
|
return
|
|
}
|
|
|
|
//监控服务发现,发现没有连接上的额服务端 就连接上去
|
|
func (this *Client) UpdateServer(servers map[string]*ServiceNode) {
|
|
for _, v := range servers {
|
|
this.clusterMu.RLock()
|
|
cluster, ok := this.clusterClients[v.ServiceTag]
|
|
this.clusterMu.RUnlock()
|
|
if ok {
|
|
cluster.Mu.RLock()
|
|
_, ok = cluster.clients[v.ServiceType]
|
|
cluster.Mu.RUnlock()
|
|
if ok {
|
|
continue
|
|
}
|
|
}
|
|
//没有建立客户端 主动发起握手
|
|
if err := this.Call(context.Background(), fmt.Sprintf("%s/%s", v.ServiceType, v.ServiceId), RpcX_ShakeHands, &ServiceNode{
|
|
ServiceTag: this.options.ServiceTag,
|
|
ServiceId: this.options.ServiceId,
|
|
ServiceType: this.options.ServiceType,
|
|
ServiceAddr: this.options.ServiceAddr},
|
|
&ServiceNode{}); err != nil {
|
|
this.options.Log.Errorf("ShakeHands new node addr:%s err:%v", v.ServiceAddr, err)
|
|
} else {
|
|
this.options.Log.Debugf("UpdateServer addr:%s ", v.ServiceAddr)
|
|
}
|
|
}
|
|
}
|
|
|
|
//监控连接建立
|
|
func (this *Client) ClientConnected(conn net.Conn) (net.Conn, error) {
|
|
addr := "tcp@" + conn.RemoteAddr().String()
|
|
this.connsMapMu.Lock()
|
|
this.conns[addr] = conn
|
|
this.connsMapMu.Unlock()
|
|
this.options.Log.Debugf("ClientConnected addr:%v", addr)
|
|
return conn, nil
|
|
}
|
|
|
|
//监听连接关闭
|
|
func (this *Client) ClientConnectionClose(conn net.Conn) error {
|
|
addr := "tcp@" + conn.RemoteAddr().String()
|
|
this.connsMapMu.Lock()
|
|
delete(this.conns, addr)
|
|
this.connsMapMu.Unlock()
|
|
this.options.Log.Debugf("ClientConnectionClose addr:%v", addr)
|
|
return nil
|
|
}
|
|
|
|
//获取目标客户端
|
|
func (this *Client) getclient(ctx *context.Context, clusterTag string, servicePath string) (c client.XClient, err error) {
|
|
if servicePath == "" {
|
|
err = errors.New("servicePath no cant null")
|
|
return
|
|
}
|
|
var (
|
|
spath []string
|
|
cluster *clusterClients
|
|
d *client.ConsulDiscovery
|
|
ok bool
|
|
)
|
|
spath = strings.Split(servicePath, "/")
|
|
this.clusterMu.RLock()
|
|
cluster, ok = this.clusterClients[clusterTag]
|
|
this.clusterMu.RUnlock()
|
|
if !ok {
|
|
cluster = &clusterClients{clients: make(map[string]client.XClient)}
|
|
this.clusterMu.Lock()
|
|
this.clusterClients[clusterTag] = cluster
|
|
this.clusterMu.Unlock()
|
|
}
|
|
cluster.Mu.RLock()
|
|
c, ok = cluster.clients[spath[0]]
|
|
cluster.Mu.RUnlock()
|
|
if !ok {
|
|
if d, err = client.NewConsulDiscovery(clusterTag, spath[0], this.options.ConsulServers, nil); err != nil {
|
|
return
|
|
}
|
|
c = client.NewBidirectionalXClient(spath[0], client.Failfast, client.RandomSelect, d, client.DefaultOption, this.msgChan)
|
|
cluster.Mu.Lock()
|
|
cluster.clients[spath[0]] = c
|
|
cluster.Mu.Unlock()
|
|
c.GetPlugins().Add(this)
|
|
if this.options.RpcxStartType == RpcxStartByClient && this.options.AutoConnect {
|
|
c.SetSelector(newSelector(this.options.Log, clusterTag, this.UpdateServer))
|
|
} else {
|
|
c.SetSelector(newSelector(this.options.Log, clusterTag, nil))
|
|
}
|
|
}
|
|
|
|
*ctx = context.WithValue(*ctx, share.ReqMetaDataKey, map[string]string{
|
|
ServiceClusterTag: this.options.ServiceTag,
|
|
CallRoutRulesKey: servicePath,
|
|
ServiceAddrKey: "tcp@" + this.options.ServiceAddr,
|
|
ServiceMetaKey: this.metadata,
|
|
})
|
|
return
|
|
}
|
|
|
|
//注册服务方法
|
|
func (this *Client) registerFunction(servicePath string, fn interface{}, name string, useName bool) (string, error) {
|
|
this.serviceMapMu.Lock()
|
|
defer this.serviceMapMu.Unlock()
|
|
|
|
ss := this.serviceMap[servicePath]
|
|
if ss == nil {
|
|
ss = new(service)
|
|
ss.name = servicePath
|
|
ss.function = make(map[string]*functionType)
|
|
}
|
|
|
|
f, ok := fn.(reflect.Value)
|
|
if !ok {
|
|
f = reflect.ValueOf(fn)
|
|
}
|
|
if f.Kind() != reflect.Func {
|
|
return "", errors.New("function must be func or bound method")
|
|
}
|
|
|
|
fname := runtime.FuncForPC(reflect.Indirect(f).Pointer()).Name()
|
|
if fname != "" {
|
|
i := strings.LastIndex(fname, ".")
|
|
if i >= 0 {
|
|
fname = fname[i+1:]
|
|
}
|
|
}
|
|
if useName {
|
|
fname = name
|
|
}
|
|
if fname == "" {
|
|
errorStr := "rpcx.registerFunction: no func name for type " + f.Type().String()
|
|
this.options.Log.Errorf(errorStr)
|
|
return fname, errors.New(errorStr)
|
|
}
|
|
|
|
t := f.Type()
|
|
if t.NumIn() != 3 {
|
|
return fname, fmt.Errorf("rpcx.registerFunction: has wrong number of ins: %s", f.Type().String())
|
|
}
|
|
if t.NumOut() != 1 {
|
|
return fname, fmt.Errorf("rpcx.registerFunction: has wrong number of outs: %s", f.Type().String())
|
|
}
|
|
|
|
// First arg must be context.Context
|
|
ctxType := t.In(0)
|
|
if !ctxType.Implements(typeOfContext) {
|
|
return fname, fmt.Errorf("function %s must use context as the first parameter", f.Type().String())
|
|
}
|
|
|
|
argType := t.In(1)
|
|
if !isExportedOrBuiltinType(argType) {
|
|
return fname, fmt.Errorf("function %s parameter type not exported: %v", f.Type().String(), argType)
|
|
}
|
|
|
|
replyType := t.In(2)
|
|
if replyType.Kind() != reflect.Ptr {
|
|
return fname, fmt.Errorf("function %s reply type not a pointer: %s", f.Type().String(), replyType)
|
|
}
|
|
if !isExportedOrBuiltinType(replyType) {
|
|
return fname, fmt.Errorf("function %s reply type not exported: %v", f.Type().String(), replyType)
|
|
}
|
|
|
|
// The return type of the method must be error.
|
|
if returnType := t.Out(0); returnType != typeOfError {
|
|
return fname, fmt.Errorf("function %s returns %s, not error", f.Type().String(), returnType.String())
|
|
}
|
|
|
|
// Install the methods
|
|
ss.function[fname] = &functionType{fn: f, ArgType: argType, ReplyType: replyType}
|
|
this.serviceMap[servicePath] = ss
|
|
|
|
// init pool for reflect.Type of args and reply
|
|
reflectTypePools.Init(argType)
|
|
reflectTypePools.Init(replyType)
|
|
return fname, nil
|
|
}
|
|
|
|
//处理远程服务请求
|
|
func (this *Client) handleRequest(ctx context.Context, req *protocol.Message) (res *protocol.Message, err error) {
|
|
serviceName := req.ServicePath
|
|
methodName := req.ServiceMethod
|
|
|
|
res = req.Clone()
|
|
res.SetMessageType(protocol.Response)
|
|
this.serviceMapMu.RLock()
|
|
service := this.serviceMap[serviceName]
|
|
this.serviceMapMu.RUnlock()
|
|
if service == nil {
|
|
err = errors.New("rpcx: can't find service " + serviceName)
|
|
return handleError(res, err)
|
|
}
|
|
mtype := service.method[methodName]
|
|
if mtype == nil {
|
|
if service.function[methodName] != nil { // check raw functions
|
|
return this.handleRequestForFunction(ctx, req)
|
|
}
|
|
err = errors.New("rpcx: can't find method " + methodName)
|
|
return handleError(res, err)
|
|
}
|
|
// get a argv object from object pool
|
|
argv := reflectTypePools.Get(mtype.ArgType)
|
|
|
|
codec := share.Codecs[req.SerializeType()]
|
|
if codec == nil {
|
|
err = fmt.Errorf("can not find codec for %d", req.SerializeType())
|
|
return handleError(res, err)
|
|
}
|
|
err = codec.Decode(req.Payload, argv)
|
|
if err != nil {
|
|
return handleError(res, err)
|
|
}
|
|
|
|
// and get a reply object from object pool
|
|
replyv := reflectTypePools.Get(mtype.ReplyType)
|
|
if mtype.ArgType.Kind() != reflect.Ptr {
|
|
err = service.call(ctx, mtype, reflect.ValueOf(argv).Elem(), reflect.ValueOf(replyv))
|
|
} else {
|
|
err = service.call(ctx, mtype, reflect.ValueOf(argv), reflect.ValueOf(replyv))
|
|
}
|
|
reflectTypePools.Put(mtype.ArgType, argv)
|
|
if err != nil {
|
|
if replyv != nil {
|
|
data, err := codec.Encode(replyv)
|
|
// return reply to object pool
|
|
reflectTypePools.Put(mtype.ReplyType, replyv)
|
|
if err != nil {
|
|
return handleError(res, err)
|
|
}
|
|
res.Payload = data
|
|
}
|
|
return handleError(res, err)
|
|
}
|
|
|
|
if !req.IsOneway() {
|
|
data, err := codec.Encode(replyv)
|
|
// return reply to object pool
|
|
reflectTypePools.Put(mtype.ReplyType, replyv)
|
|
if err != nil {
|
|
return handleError(res, err)
|
|
}
|
|
res.Payload = data
|
|
} else if replyv != nil {
|
|
reflectTypePools.Put(mtype.ReplyType, replyv)
|
|
}
|
|
this.options.Log.Debugf("server called service %+v for an request %+v", service, req)
|
|
return res, nil
|
|
}
|
|
|
|
//处理远程服务请求 for 方法
|
|
func (this *Client) handleRequestForFunction(ctx context.Context, req *protocol.Message) (res *protocol.Message, err error) {
|
|
res = req.Clone()
|
|
|
|
res.SetMessageType(protocol.Response)
|
|
|
|
serviceName := req.ServicePath
|
|
methodName := req.ServiceMethod
|
|
this.serviceMapMu.RLock()
|
|
service := this.serviceMap[serviceName]
|
|
this.serviceMapMu.RUnlock()
|
|
if service == nil {
|
|
err = errors.New("rpcx: can't find service for func raw function")
|
|
return handleError(res, err)
|
|
}
|
|
mtype := service.function[methodName]
|
|
if mtype == nil {
|
|
err = errors.New("rpcx: can't find method " + methodName)
|
|
return handleError(res, err)
|
|
}
|
|
|
|
argv := reflectTypePools.Get(mtype.ArgType)
|
|
|
|
codec := share.Codecs[req.SerializeType()]
|
|
if codec == nil {
|
|
err = fmt.Errorf("can not find codec for %d", req.SerializeType())
|
|
return handleError(res, err)
|
|
}
|
|
|
|
err = codec.Decode(req.Payload, argv)
|
|
if err != nil {
|
|
return handleError(res, err)
|
|
}
|
|
|
|
replyv := reflectTypePools.Get(mtype.ReplyType)
|
|
|
|
if mtype.ArgType.Kind() != reflect.Ptr {
|
|
err = service.callForFunction(ctx, mtype, reflect.ValueOf(argv).Elem(), reflect.ValueOf(replyv))
|
|
} else {
|
|
err = service.callForFunction(ctx, mtype, reflect.ValueOf(argv), reflect.ValueOf(replyv))
|
|
}
|
|
|
|
reflectTypePools.Put(mtype.ArgType, argv)
|
|
|
|
if err != nil {
|
|
reflectTypePools.Put(mtype.ReplyType, replyv)
|
|
return handleError(res, err)
|
|
}
|
|
|
|
if !req.IsOneway() {
|
|
data, err := codec.Encode(replyv)
|
|
reflectTypePools.Put(mtype.ReplyType, replyv)
|
|
if err != nil {
|
|
return handleError(res, err)
|
|
}
|
|
res.Payload = data
|
|
} else if replyv != nil {
|
|
reflectTypePools.Put(mtype.ReplyType, replyv)
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
//发送远程服务调用 回应消息
|
|
func (this *Client) sendResponse(conn net.Conn, req, res *protocol.Message) {
|
|
if len(res.Payload) > 1024 && req.CompressType() != protocol.None {
|
|
res.SetCompressType(req.CompressType())
|
|
}
|
|
data := res.EncodeSlicePointer()
|
|
if this.writeTimeout != 0 {
|
|
conn.SetWriteDeadline(time.Now().Add(this.writeTimeout))
|
|
}
|
|
conn.Write(*data)
|
|
protocol.PutData(data)
|
|
}
|
|
|
|
//请求错误 封装回应消息
|
|
func handleError(res *protocol.Message, err error) (*protocol.Message, error) {
|
|
res.SetMessageStatusType(protocol.Error)
|
|
if res.Metadata == nil {
|
|
res.Metadata = make(map[string]string)
|
|
}
|
|
res.Metadata[protocol.ServiceError] = err.Error()
|
|
return res, err
|
|
}
|
|
|
|
//服务注册 类型判断
|
|
func isExportedOrBuiltinType(t reflect.Type) bool {
|
|
for t.Kind() == reflect.Ptr {
|
|
t = t.Elem()
|
|
}
|
|
// PkgPath will be non-empty even for an exported type,
|
|
// so we need to check the type name as well.
|
|
return isExported(t.Name()) || t.PkgPath() == ""
|
|
}
|
|
|
|
func isExported(name string) bool {
|
|
rune, _ := utf8.DecodeRuneInString(name)
|
|
return unicode.IsUpper(rune)
|
|
}
|