package rpcx import ( "context" "errors" "fmt" "log" "net" "strings" "sync" "time" "github.com/rcrowley/go-metrics" "github.com/smallnest/rpcx/client" "github.com/smallnest/rpcx/protocol" "github.com/smallnest/rpcx/server" "github.com/smallnest/rpcx/serverplugin" "github.com/smallnest/rpcx/share" ) func newService(options Options) (sys *Service, err error) { sys = &Service{ options: options, metadata: fmt.Sprintf("stype=%s&sid=%s&version=%s&addr=%s", options.ServiceType, options.ServiceId, options.ServiceVersion, "tcp@"+options.ServiceAddr), server: server.NewServer(), selector: newSelector(), clients: make(map[string]net.Conn), clientmeta: make(map[string]string), pending: make(map[uint64]*client.Call), } r := &serverplugin.ConsulRegisterPlugin{ ServiceAddress: "tcp@" + options.ServiceAddr, ConsulServers: options.ConsulServers, BasePath: options.ServiceTag, Metrics: metrics.NewRegistry(), UpdateInterval: time.Minute, } if err = r.Start(); err != nil { return } sys.server.Plugins.Add(r) sys.server.Plugins.Add(sys) return } type Service struct { options Options metadata string server *server.Server selector client.Selector clientmutex sync.Mutex clients map[string]net.Conn clientmeta map[string]string mutex sync.Mutex // protects following seq uint64 pending map[uint64]*client.Call } //RPC 服务启动 func (this *Service) Start() (err error) { go func() { if err = this.server.Serve("tcp", this.options.ServiceAddr); err != nil { this.Errorf("rpcx server exit!") } }() return } //服务停止 func (this *Service) Stop() (err error) { err = this.server.Close() return } //注册RPC 服务 func (this *Service) RegisterFunction(fn interface{}) (err error) { err = this.server.RegisterFunction(this.options.ServiceType, fn, this.metadata) return } //注册RPC 服务 func (this *Service) RegisterFunctionName(name string, fn interface{}) (err error) { err = this.server.RegisterFunctionName(this.options.ServiceType, name, fn, this.metadata) return } //注销 暂时不处理 func (this *Service) UnregisterAll() (err error) { // err = this.server.UnregisterAll() return } //同步调用远程服务 func (this *Service) Call(ctx context.Context, servicePath string, serviceMethod string, args interface{}, reply interface{}) (err error) { var ( done *client.Call conn net.Conn ) seq := new(uint64) ctx = context.WithValue(ctx, seqKey{}, seq) if conn, done, err = this.call(ctx, servicePath, serviceMethod, args, reply, make(chan *client.Call, 1)); err != nil { return } select { case <-ctx.Done(): // cancel by context this.mutex.Lock() call := this.pending[*seq] delete(this.pending, *seq) this.mutex.Unlock() if call != nil { call.Error = ctx.Err() call.Done <- call } return ctx.Err() case call := <-done.Done: err = call.Error meta := ctx.Value(share.ResMetaDataKey) if meta != nil && len(call.ResMetadata) > 0 { resMeta := meta.(map[string]string) for k, v := range call.ResMetadata { resMeta[k] = v } resMeta[share.ServerAddress] = conn.RemoteAddr().String() } } return } //异步调用 远程服务 func (this *Service) Go(ctx context.Context, servicePath string, serviceMethod string, args interface{}, reply interface{}, done chan *client.Call) (_call *client.Call, err error) { _, _call, err = this.call(ctx, servicePath, serviceMethod, args, reply, done) return } //监听客户端链接到服务上 保存客户端的连接对象 func (this *Service) PreHandleRequest(ctx context.Context, r *protocol.Message) error { req_metadata := ctx.Value(share.ReqMetaDataKey).(map[string]string) if addr, ok := req_metadata[ServiceAddrKey]; ok { if _, ok = this.clientmeta[addr]; !ok { if smeta, ok := req_metadata[ServiceMetaKey]; ok { servers := make(map[string]string) this.clientmutex.Lock() this.clientmeta[addr] = smeta this.clients[addr] = ctx.Value(server.RemoteConnContextKey).(net.Conn) for k, v := range this.clientmeta { servers[k] = v } this.clientmutex.Unlock() this.selector.UpdateServer(servers) this.Debugf("PreReadRequest addr:%s smeta:%s \n", addr, smeta) } } } return nil } //监控rpc连接收到的请求消息 处理消息回调请求 func (this *Service) PostReadRequest(ctx context.Context, r *protocol.Message, e error) error { if isCallMessage := (r.MessageType() == protocol.Request); isCallMessage { return nil } e = errors.New("is callMessage") seq := r.Seq() this.mutex.Lock() call := this.pending[seq] delete(this.pending, seq) this.mutex.Unlock() switch { case call == nil: this.Errorf("callmessage no found call:%d", seq) case r.MessageStatusType() == protocol.Error: if len(r.Metadata) > 0 { call.ResMetadata = r.Metadata call.Error = errors.New(r.Metadata[protocol.ServiceError]) } if len(r.Payload) > 0 { data := r.Payload codec := share.Codecs[r.SerializeType()] if codec != nil { _ = codec.Decode(data, call.Reply) } } call.Done <- call default: data := r.Payload if len(data) > 0 { codec := share.Codecs[r.SerializeType()] if codec == nil { call.Error = errors.New(client.ErrUnsupportedCodec.Error()) } else { err := codec.Decode(data, call.Reply) if err != nil { call.Error = err } } } if len(r.Metadata) > 0 { call.ResMetadata = r.Metadata } call.Done <- call } return nil } ///日志*********************************************************************** func (this *Service) Debug() bool { return this.options.Debug } func (this *Service) Debugf(format string, a ...interface{}) { if this.options.Debug { this.options.Log.Debugf("[SYS RPCX] "+format, a...) } } func (this *Service) Infof(format string, a ...interface{}) { if this.options.Debug { this.options.Log.Infof("[SYS RPCX] "+format, a...) } } func (this *Service) Warnf(format string, a ...interface{}) { if this.options.Debug { this.options.Log.Warnf("[SYS RPCX] "+format, a...) } } func (this *Service) Errorf(format string, a ...interface{}) { if this.options.Debug { this.options.Log.Errorf("[SYS RPCX] "+format, a...) } } func (this *Service) Panicf(format string, a ...interface{}) { if this.options.Debug { this.options.Log.Panicf("[SYS RPCX] "+format, a...) } } func (this *Service) Fatalf(format string, a ...interface{}) { if this.options.Debug { this.options.Log.Fatalf("[SYS RPCX] "+format, a...) } } //执行远程调用 func (this *Service) call(ctx context.Context, servicePath string, serviceMethod string, args interface{}, reply interface{}, done chan *client.Call) (conn net.Conn, _call *client.Call, err error) { var ( spath []string clientaddr string metadata map[string]string ok bool ) if servicePath == "" { err = errors.New("servicePath no cant null") return } metadata = map[string]string{ CallRoutRulesKey: servicePath, ServiceAddrKey: "tcp@" + this.options.ServiceAddr, ServiceMetaKey: this.metadata, } spath = strings.Split(servicePath, "/") ctx = context.WithValue(ctx, share.ReqMetaDataKey, map[string]string{ CallRoutRulesKey: servicePath, ServiceAddrKey: "tcp@" + this.options.ServiceAddr, ServiceMetaKey: this.metadata, }) if clientaddr = this.selector.Select(ctx, spath[0], serviceMethod, args); clientaddr == "" { err = fmt.Errorf("on found servicePath:%s", servicePath) return } if conn, ok = this.clients[clientaddr]; !ok { err = fmt.Errorf("on found clientaddr:%s", clientaddr) return } _call = new(client.Call) _call.ServicePath = servicePath _call.ServiceMethod = serviceMethod _call.Args = args _call.Reply = reply if done == nil { done = make(chan *client.Call, 10) // buffered. } else { if cap(done) == 0 { log.Panic("rpc: done channel is unbuffered") } } _call.Done = done this.send(ctx, conn, spath[0], serviceMethod, metadata, _call) return } //发送远程调用请求 func (this *Service) send(ctx context.Context, conn net.Conn, servicePath string, serviceMethod string, metadata map[string]string, call *client.Call) { defer func() { if call.Error != nil { call.Done <- call } }() serializeType := this.options.SerializeType codec := share.Codecs[serializeType] if codec == nil { call.Error = client.ErrUnsupportedCodec return } data, err := codec.Encode(call.Args) if err != nil { call.Error = err return } this.mutex.Lock() seq := this.seq this.seq++ this.pending[seq] = call this.mutex.Unlock() if cseq, ok := ctx.Value(seqKey{}).(*uint64); ok { *cseq = seq } req := protocol.GetPooledMsg() req.SetMessageType(protocol.Request) req.SetSeq(seq) req.SetOneway(true) req.SetSerializeType(this.options.SerializeType) req.ServicePath = servicePath req.ServiceMethod = serviceMethod req.Metadata = metadata req.Payload = data b := req.EncodeSlicePointer() if _, err = conn.Write(*b); err != nil { call.Error = err this.mutex.Lock() delete(this.pending, seq) this.mutex.Unlock() return } protocol.PutData(b) protocol.FreeMsg(req) return }