405 lines
11 KiB
Go
405 lines
11 KiB
Go
package rpcx
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/rcrowley/go-metrics"
|
|
"github.com/smallnest/rpcx/client"
|
|
"github.com/smallnest/rpcx/protocol"
|
|
"github.com/smallnest/rpcx/server"
|
|
"github.com/smallnest/rpcx/serverplugin"
|
|
"github.com/smallnest/rpcx/share"
|
|
)
|
|
|
|
func newService(options *Options) (sys *Service, err error) {
|
|
sys = &Service{
|
|
options: options,
|
|
metadata: fmt.Sprintf("stag=%s&stype=%s&sid=%s&version=%s&addr=%s", options.ServiceTag, options.ServiceType, options.ServiceId, options.ServiceVersion, "tcp@"+options.ServiceAddr),
|
|
server: server.NewServer(),
|
|
selectors: make(map[string]client.Selector),
|
|
clients: make(map[string]net.Conn),
|
|
clientmeta: make(map[string]string),
|
|
pending: make(map[uint64]*client.Call),
|
|
}
|
|
|
|
r := &serverplugin.ConsulRegisterPlugin{
|
|
ServiceAddress: "tcp@" + options.ServiceAddr,
|
|
ConsulServers: options.ConsulServers,
|
|
BasePath: options.ServiceTag,
|
|
Metrics: metrics.NewRegistry(),
|
|
UpdateInterval: time.Minute,
|
|
}
|
|
if err = r.Start(); err != nil {
|
|
return
|
|
}
|
|
sys.server.Plugins.Add(r)
|
|
sys.server.Plugins.Add(sys)
|
|
sys.RegisterFunctionName(RpcX_ShakeHands, sys.RpcxShakeHands) //注册握手函数
|
|
return
|
|
}
|
|
|
|
type Service struct {
|
|
options *Options
|
|
metadata string
|
|
server *server.Server
|
|
selectors map[string]client.Selector
|
|
clientmutex sync.Mutex
|
|
clients map[string]net.Conn
|
|
clientmeta map[string]string
|
|
mutex sync.Mutex // protects following
|
|
seq uint64
|
|
pending map[uint64]*client.Call
|
|
}
|
|
|
|
//RPC 服务启动
|
|
func (this *Service) Start() (err error) {
|
|
go func() {
|
|
if err = this.server.Serve("tcp", this.options.ServiceAddr); err != nil {
|
|
this.Warnf("rpcx server exit:%v", err)
|
|
}
|
|
}()
|
|
return
|
|
}
|
|
|
|
//服务停止
|
|
func (this *Service) Stop() (err error) {
|
|
err = this.server.Close()
|
|
return
|
|
}
|
|
|
|
//注册RPC 服务
|
|
func (this *Service) RegisterFunction(fn interface{}) (err error) {
|
|
err = this.server.RegisterFunction(this.options.ServiceType, fn, this.metadata)
|
|
return
|
|
}
|
|
|
|
//注册RPC 服务
|
|
func (this *Service) RegisterFunctionName(name string, fn interface{}) (err error) {
|
|
err = this.server.RegisterFunctionName(this.options.ServiceType, name, fn, this.metadata)
|
|
return
|
|
}
|
|
|
|
//注销 暂时不处理
|
|
func (this *Service) UnregisterAll() (err error) {
|
|
// err = this.server.UnregisterAll()
|
|
return
|
|
}
|
|
|
|
//同步调用远程服务
|
|
func (this *Service) Call(ctx context.Context, servicePath string, serviceMethod string, args interface{}, reply interface{}) (err error) {
|
|
var (
|
|
done *client.Call
|
|
conn net.Conn
|
|
)
|
|
seq := new(uint64)
|
|
ctx = context.WithValue(ctx, seqKey{}, seq)
|
|
if conn, done, err = this.call(ctx, this.options.ServiceTag, servicePath, serviceMethod, args, reply, make(chan *client.Call, 1)); err != nil {
|
|
return
|
|
}
|
|
select {
|
|
case <-ctx.Done(): // cancel by context
|
|
this.mutex.Lock()
|
|
call := this.pending[*seq]
|
|
delete(this.pending, *seq)
|
|
this.mutex.Unlock()
|
|
if call != nil {
|
|
call.Error = ctx.Err()
|
|
call.Done <- call
|
|
}
|
|
return ctx.Err()
|
|
case call := <-done.Done:
|
|
err = call.Error
|
|
meta := ctx.Value(share.ResMetaDataKey)
|
|
if meta != nil && len(call.ResMetadata) > 0 {
|
|
resMeta := meta.(map[string]string)
|
|
for k, v := range call.ResMetadata {
|
|
resMeta[k] = v
|
|
}
|
|
resMeta[share.ServerAddress] = conn.RemoteAddr().String()
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
//异步调用 远程服务
|
|
func (this *Service) Go(ctx context.Context, servicePath string, serviceMethod string, args interface{}, reply interface{}, done chan *client.Call) (_call *client.Call, err error) {
|
|
_, _call, err = this.call(ctx, this.options.ServiceTag, servicePath, serviceMethod, args, reply, done)
|
|
return
|
|
}
|
|
|
|
//跨服 同步调用 远程服务
|
|
func (this *Service) AcrossClusterCall(ctx context.Context, clusterTag string, servicePath string, serviceMethod string, args interface{}, reply interface{}) (err error) {
|
|
var (
|
|
done *client.Call
|
|
conn net.Conn
|
|
)
|
|
seq := new(uint64)
|
|
ctx = context.WithValue(ctx, seqKey{}, seq)
|
|
if conn, done, err = this.call(ctx, clusterTag, servicePath, serviceMethod, args, reply, make(chan *client.Call, 1)); err != nil {
|
|
return
|
|
}
|
|
select {
|
|
case <-ctx.Done(): // cancel by context
|
|
this.mutex.Lock()
|
|
call := this.pending[*seq]
|
|
delete(this.pending, *seq)
|
|
this.mutex.Unlock()
|
|
if call != nil {
|
|
call.Error = ctx.Err()
|
|
call.Done <- call
|
|
}
|
|
return ctx.Err()
|
|
case call := <-done.Done:
|
|
err = call.Error
|
|
meta := ctx.Value(share.ResMetaDataKey)
|
|
if meta != nil && len(call.ResMetadata) > 0 {
|
|
resMeta := meta.(map[string]string)
|
|
for k, v := range call.ResMetadata {
|
|
resMeta[k] = v
|
|
}
|
|
resMeta[share.ServerAddress] = conn.RemoteAddr().String()
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
//跨服 异步调用 远程服务
|
|
func (this *Service) AcrossClusterGo(ctx context.Context, clusterTag, servicePath string, serviceMethod string, args interface{}, reply interface{}, done chan *client.Call) (_call *client.Call, err error) {
|
|
_, _call, err = this.call(ctx, clusterTag, servicePath, serviceMethod, args, reply, done)
|
|
return
|
|
}
|
|
|
|
//监听客户端链接到服务上 保存客户端的连接对象
|
|
func (this *Service) PreHandleRequest(ctx context.Context, r *protocol.Message) error {
|
|
var (
|
|
stag string
|
|
selector client.Selector
|
|
ok bool
|
|
)
|
|
req_metadata := ctx.Value(share.ReqMetaDataKey).(map[string]string)
|
|
if stag, ok = req_metadata[ServiceClusterTag]; ok {
|
|
if selector, ok = this.selectors[stag]; !ok {
|
|
this.selectors[stag] = newSelector(nil)
|
|
selector = this.selectors[stag]
|
|
}
|
|
if addr, ok := req_metadata[ServiceAddrKey]; ok {
|
|
if _, ok = this.clientmeta[addr]; !ok {
|
|
if smeta, ok := req_metadata[ServiceMetaKey]; ok {
|
|
servers := make(map[string]string)
|
|
this.clientmutex.Lock()
|
|
this.clientmeta[addr] = smeta
|
|
this.clients[addr] = ctx.Value(server.RemoteConnContextKey).(net.Conn)
|
|
for k, v := range this.clientmeta {
|
|
servers[k] = v
|
|
}
|
|
this.clientmutex.Unlock()
|
|
selector.UpdateServer(servers)
|
|
this.Debugf("fond new node addr:%s smeta:%s \n", addr, smeta)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
//监控rpc连接收到的请求消息 处理消息回调请求
|
|
func (this *Service) PostReadRequest(ctx context.Context, r *protocol.Message, e error) error {
|
|
if isCallMessage := (r.MessageType() == protocol.Request); isCallMessage {
|
|
return nil
|
|
}
|
|
e = errors.New("is callMessage")
|
|
seq := r.Seq()
|
|
this.mutex.Lock()
|
|
call := this.pending[seq]
|
|
delete(this.pending, seq)
|
|
this.mutex.Unlock()
|
|
switch {
|
|
case call == nil:
|
|
this.Errorf("callmessage no found call:%d", seq)
|
|
case r.MessageStatusType() == protocol.Error:
|
|
if len(r.Metadata) > 0 {
|
|
call.ResMetadata = r.Metadata
|
|
call.Error = errors.New(r.Metadata[protocol.ServiceError])
|
|
}
|
|
if len(r.Payload) > 0 {
|
|
data := r.Payload
|
|
codec := share.Codecs[r.SerializeType()]
|
|
if codec != nil {
|
|
_ = codec.Decode(data, call.Reply)
|
|
}
|
|
}
|
|
call.Done <- call
|
|
default:
|
|
data := r.Payload
|
|
if len(data) > 0 {
|
|
codec := share.Codecs[r.SerializeType()]
|
|
if codec == nil {
|
|
call.Error = errors.New(client.ErrUnsupportedCodec.Error())
|
|
} else {
|
|
err := codec.Decode(data, call.Reply)
|
|
if err != nil {
|
|
call.Error = err
|
|
}
|
|
}
|
|
}
|
|
if len(r.Metadata) > 0 {
|
|
call.ResMetadata = r.Metadata
|
|
}
|
|
call.Done <- call
|
|
}
|
|
return nil
|
|
}
|
|
|
|
//客户端配置 AutoConnect 的默认连接函数
|
|
func (this *Service) RpcxShakeHands(ctx context.Context, args *ServiceNode, reply *ServiceNode) error {
|
|
// this.Debugf("RpcxShakeHands:%+v", ctx.Value(share.ReqMetaDataKey).(map[string]string))
|
|
return nil
|
|
}
|
|
|
|
// Logging ***********************************************************************
|
|
func (this *Service) Debug() bool {
|
|
return this.options.Debug
|
|
}
|
|
|
|
func (this *Service) Debugf(format string, a ...interface{}) {
|
|
if this.options.Debug {
|
|
this.options.Log.Debugf("[SYS RPCX] "+format, a...)
|
|
}
|
|
}
|
|
func (this *Service) Infof(format string, a ...interface{}) {
|
|
if this.options.Debug {
|
|
this.options.Log.Infof("[SYS RPCX] "+format, a...)
|
|
}
|
|
}
|
|
func (this *Service) Warnf(format string, a ...interface{}) {
|
|
if this.options.Debug {
|
|
this.options.Log.Warnf("[SYS RPCX] "+format, a...)
|
|
}
|
|
}
|
|
func (this *Service) Errorf(format string, a ...interface{}) {
|
|
if this.options.Debug {
|
|
this.options.Log.Errorf("[SYS RPCX] "+format, a...)
|
|
}
|
|
}
|
|
func (this *Service) Panicf(format string, a ...interface{}) {
|
|
if this.options.Debug {
|
|
this.options.Log.Panicf("[SYS RPCX] "+format, a...)
|
|
}
|
|
}
|
|
func (this *Service) Fatalf(format string, a ...interface{}) {
|
|
if this.options.Debug {
|
|
this.options.Log.Fatalf("[SYS RPCX] "+format, a...)
|
|
}
|
|
}
|
|
|
|
//执行远程调用
|
|
func (this *Service) call(ctx context.Context, clusterTag string, servicePath string, serviceMethod string, args interface{}, reply interface{}, done chan *client.Call) (conn net.Conn, _call *client.Call, err error) {
|
|
var (
|
|
spath []string
|
|
clientaddr string
|
|
metadata map[string]string
|
|
selector client.Selector
|
|
ok bool
|
|
)
|
|
if servicePath == "" {
|
|
err = errors.New("servicePath no cant null")
|
|
return
|
|
}
|
|
metadata = map[string]string{
|
|
ServiceClusterTag: clusterTag,
|
|
CallRoutRulesKey: servicePath,
|
|
ServiceAddrKey: "tcp@" + this.options.ServiceAddr,
|
|
ServiceMetaKey: this.metadata,
|
|
}
|
|
spath = strings.Split(servicePath, "/")
|
|
ctx = context.WithValue(ctx, share.ReqMetaDataKey, map[string]string{
|
|
CallRoutRulesKey: servicePath,
|
|
ServiceAddrKey: "tcp@" + this.options.ServiceAddr,
|
|
ServiceMetaKey: this.metadata,
|
|
})
|
|
if selector, ok = this.selectors[clusterTag]; !ok {
|
|
err = fmt.Errorf("on found serviceTag:%s", clusterTag)
|
|
}
|
|
if clientaddr = selector.Select(ctx, spath[0], serviceMethod, args); clientaddr == "" {
|
|
err = fmt.Errorf("on found servicePath:%s", servicePath)
|
|
return
|
|
}
|
|
if conn, ok = this.clients[clientaddr]; !ok {
|
|
err = fmt.Errorf("on found clientaddr:%s", clientaddr)
|
|
return
|
|
}
|
|
|
|
_call = new(client.Call)
|
|
_call.ServicePath = servicePath
|
|
_call.ServiceMethod = serviceMethod
|
|
_call.Args = args
|
|
_call.Reply = reply
|
|
if done == nil {
|
|
done = make(chan *client.Call, 10) // buffered.
|
|
} else {
|
|
if cap(done) == 0 {
|
|
log.Panic("rpc: done channel is unbuffered")
|
|
}
|
|
}
|
|
_call.Done = done
|
|
this.send(ctx, conn, spath[0], serviceMethod, metadata, _call)
|
|
return
|
|
}
|
|
|
|
//发送远程调用请求
|
|
func (this *Service) send(ctx context.Context, conn net.Conn, servicePath string, serviceMethod string, metadata map[string]string, call *client.Call) {
|
|
defer func() {
|
|
if call.Error != nil {
|
|
call.Done <- call
|
|
}
|
|
}()
|
|
serializeType := this.options.SerializeType
|
|
codec := share.Codecs[serializeType]
|
|
if codec == nil {
|
|
call.Error = client.ErrUnsupportedCodec
|
|
return
|
|
}
|
|
data, err := codec.Encode(call.Args)
|
|
if err != nil {
|
|
call.Error = err
|
|
return
|
|
}
|
|
|
|
this.mutex.Lock()
|
|
seq := this.seq
|
|
this.seq++
|
|
this.pending[seq] = call
|
|
this.mutex.Unlock()
|
|
if cseq, ok := ctx.Value(seqKey{}).(*uint64); ok {
|
|
*cseq = seq
|
|
}
|
|
req := protocol.GetPooledMsg()
|
|
req.SetMessageType(protocol.Request)
|
|
req.SetSeq(seq)
|
|
req.SetOneway(true)
|
|
req.SetSerializeType(this.options.SerializeType)
|
|
req.ServicePath = servicePath
|
|
req.ServiceMethod = serviceMethod
|
|
req.Metadata = metadata
|
|
req.Payload = data
|
|
|
|
b := req.EncodeSlicePointer()
|
|
if _, err = conn.Write(*b); err != nil {
|
|
call.Error = err
|
|
this.mutex.Lock()
|
|
delete(this.pending, seq)
|
|
this.mutex.Unlock()
|
|
return
|
|
}
|
|
protocol.PutData(b)
|
|
protocol.FreeMsg(req)
|
|
return
|
|
}
|