diff options
author | Niall Sheridan <nsheridan@gmail.com> | 2017-04-10 21:18:42 +0100 |
---|---|---|
committer | Niall Sheridan <nsheridan@gmail.com> | 2017-04-10 21:38:33 +0100 |
commit | 30802e07b2d84fbc213b490d3402707dffe60096 (patch) | |
tree | 934aecb8f3582325dfd1aa6652193adac87d00db /vendor/google.golang.org/grpc/server.go | |
parent | da7638dc112c4c106e8929601b642d2ca4596cba (diff) |
update dependencies
Diffstat (limited to 'vendor/google.golang.org/grpc/server.go')
-rw-r--r-- | vendor/google.golang.org/grpc/server.go | 216 |
1 files changed, 135 insertions, 81 deletions
diff --git a/vendor/google.golang.org/grpc/server.go b/vendor/google.golang.org/grpc/server.go index 985226d..b15f71c 100644 --- a/vendor/google.golang.org/grpc/server.go +++ b/vendor/google.golang.org/grpc/server.go @@ -53,8 +53,10 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/internal" + "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" "google.golang.org/grpc/stats" + "google.golang.org/grpc/status" "google.golang.org/grpc/tap" "google.golang.org/grpc/transport" ) @@ -116,6 +118,9 @@ type options struct { statsHandler stats.Handler maxConcurrentStreams uint32 useHandlerImpl bool // use http.Handler-based server + unknownStreamDesc *StreamDesc + keepaliveParams keepalive.ServerParameters + keepalivePolicy keepalive.EnforcementPolicy } var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size limit @@ -123,6 +128,20 @@ var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size l // A ServerOption sets options. type ServerOption func(*options) +// KeepaliveParams returns a ServerOption that sets keepalive and max-age parameters for the server. +func KeepaliveParams(kp keepalive.ServerParameters) ServerOption { + return func(o *options) { + o.keepaliveParams = kp + } +} + +// KeepaliveEnforcementPolicy returns a ServerOption that sets keepalive enforcement policy for the server. +func KeepaliveEnforcementPolicy(kep keepalive.EnforcementPolicy) ServerOption { + return func(o *options) { + o.keepalivePolicy = kep + } +} + // CustomCodec returns a ServerOption that sets a codec for message marshaling and unmarshaling. func CustomCodec(codec Codec) ServerOption { return func(o *options) { @@ -208,6 +227,24 @@ func StatsHandler(h stats.Handler) ServerOption { } } +// UnknownServiceHandler returns a ServerOption that allows for adding a custom +// unknown service handler. The provided method is a bidi-streaming RPC service +// handler that will be invoked instead of returning the the "unimplemented" gRPC +// error whenever a request is received for an unregistered service or method. +// The handling function has full access to the Context of the request and the +// stream, and the invocation passes through interceptors. +func UnknownServiceHandler(streamHandler StreamHandler) ServerOption { + return func(o *options) { + o.unknownStreamDesc = &StreamDesc{ + StreamName: "unknown_service_handler", + Handler: streamHandler, + // We need to assume that the users of the streamHandler will want to use both. + ClientStreams: true, + ServerStreams: true, + } + } +} + // NewServer creates a gRPC server which has no service registered and has not // started to accept requests yet. func NewServer(opt ...ServerOption) *Server { @@ -446,10 +483,12 @@ func (s *Server) handleRawConn(rawConn net.Conn) { // transport.NewServerTransport). func (s *Server) serveHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) { config := &transport.ServerConfig{ - MaxStreams: s.opts.maxConcurrentStreams, - AuthInfo: authInfo, - InTapHandle: s.opts.inTapHandle, - StatsHandler: s.opts.statsHandler, + MaxStreams: s.opts.maxConcurrentStreams, + AuthInfo: authInfo, + InTapHandle: s.opts.inTapHandle, + StatsHandler: s.opts.statsHandler, + KeepaliveParams: s.opts.keepaliveParams, + KeepalivePolicy: s.opts.keepalivePolicy, } st, err := transport.NewServerTransport("http2", c, config) if err != nil { @@ -633,7 +672,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. stream.SetSendCompress(s.opts.cp.Type()) } p := &parser{r: stream} - for { + for { // TODO: delete pf, req, err := p.recvMsg(s.opts.maxMsgSize) if err == io.EOF { // The entire stream is done (for unary RPC only). @@ -643,36 +682,37 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. err = Errorf(codes.Internal, io.ErrUnexpectedEOF.Error()) } if err != nil { - switch err := err.(type) { - case *rpcError: - if e := t.WriteStatus(stream, err.code, err.desc); e != nil { + if st, ok := status.FromError(err); ok { + if e := t.WriteStatus(stream, st); e != nil { grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) } - case transport.ConnectionError: - // Nothing to do here. - case transport.StreamError: - if e := t.WriteStatus(stream, err.Code, err.Desc); e != nil { - grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) + } else { + switch st := err.(type) { + case transport.ConnectionError: + // Nothing to do here. + case transport.StreamError: + if e := t.WriteStatus(stream, status.New(st.Code, st.Desc)); e != nil { + grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) + } + default: + panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", st, st)) } - default: - panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", err, err)) } return err } if err := checkRecvPayload(pf, stream.RecvCompress(), s.opts.dc); err != nil { - switch err := err.(type) { - case *rpcError: - if e := t.WriteStatus(stream, err.code, err.desc); e != nil { + if st, ok := status.FromError(err); ok { + if e := t.WriteStatus(stream, st); e != nil { grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) } return err - default: - if e := t.WriteStatus(stream, codes.Internal, err.Error()); e != nil { - grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) - } - // TODO checkRecvPayload always return RPC error. Add a return here if necessary. } + if e := t.WriteStatus(stream, status.New(codes.Internal, err.Error())); e != nil { + grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) + } + + // TODO checkRecvPayload always return RPC error. Add a return here if necessary. } var inPayload *stats.InPayload if sh != nil { @@ -680,8 +720,6 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. RecvTime: time.Now(), } } - statusCode := codes.OK - statusDesc := "" df := func(v interface{}) error { if inPayload != nil { inPayload.WireLength = len(req) @@ -690,20 +728,16 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. var err error req, err = s.opts.dc.Do(bytes.NewReader(req)) if err != nil { - if err := t.WriteStatus(stream, codes.Internal, err.Error()); err != nil { - grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err) - } return Errorf(codes.Internal, err.Error()) } } if len(req) > s.opts.maxMsgSize { // TODO: Revisit the error code. Currently keep it consistent with // java implementation. - statusCode = codes.Internal - statusDesc = fmt.Sprintf("grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxMsgSize) + return status.Errorf(codes.Internal, "grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxMsgSize) } if err := s.opts.codec.Unmarshal(req, v); err != nil { - return err + return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err) } if inPayload != nil { inPayload.Payload = v @@ -718,21 +752,20 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt) if appErr != nil { - if err, ok := appErr.(*rpcError); ok { - statusCode = err.code - statusDesc = err.desc - } else { - statusCode = convertCode(appErr) - statusDesc = appErr.Error() + appStatus, ok := status.FromError(appErr) + if !ok { + // Convert appErr if it is not a grpc status error. + appErr = status.Error(convertCode(appErr), appErr.Error()) + appStatus, _ = status.FromError(appErr) } - if trInfo != nil && statusCode != codes.OK { - trInfo.tr.LazyLog(stringer(statusDesc), true) + if trInfo != nil { + trInfo.tr.LazyLog(stringer(appStatus.Message()), true) trInfo.tr.SetError() } - if err := t.WriteStatus(stream, statusCode, statusDesc); err != nil { - grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", err) + if e := t.WriteStatus(stream, appStatus); e != nil { + grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", e) } - return Errorf(statusCode, statusDesc) + return appErr } if trInfo != nil { trInfo.tr.LazyLog(stringer("OK"), false) @@ -742,26 +775,35 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. Delay: false, } if err := s.sendResponse(t, stream, reply, s.opts.cp, opts); err != nil { - switch err := err.(type) { - case transport.ConnectionError: - // Nothing to do here. - case transport.StreamError: - statusCode = err.Code - statusDesc = err.Desc - default: - statusCode = codes.Unknown - statusDesc = err.Error() + if err == io.EOF { + // The entire stream is done (for unary RPC only). + return err + } + if s, ok := status.FromError(err); ok { + if e := t.WriteStatus(stream, s); e != nil { + grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", e) + } + } else { + switch st := err.(type) { + case transport.ConnectionError: + // Nothing to do here. + case transport.StreamError: + if e := t.WriteStatus(stream, status.New(st.Code, st.Desc)); e != nil { + grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) + } + default: + panic(fmt.Sprintf("grpc: Unexpected error (%T) from sendResponse: %v", st, st)) + } } return err } if trInfo != nil { trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true) } - errWrite := t.WriteStatus(stream, statusCode, statusDesc) - if statusCode != codes.OK { - return Errorf(statusCode, statusDesc) - } - return errWrite + // TODO: Should we be logging if writing status failed here, like above? + // Should the logging be in WriteStatus? Should we ignore the WriteStatus + // error or allow the stats handler to see it? + return t.WriteStatus(stream, status.New(codes.OK, "")) } } @@ -815,43 +857,47 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp }() } var appErr error + var server interface{} + if srv != nil { + server = srv.server + } if s.opts.streamInt == nil { - appErr = sd.Handler(srv.server, ss) + appErr = sd.Handler(server, ss) } else { info := &StreamServerInfo{ FullMethod: stream.Method(), IsClientStream: sd.ClientStreams, IsServerStream: sd.ServerStreams, } - appErr = s.opts.streamInt(srv.server, ss, info, sd.Handler) + appErr = s.opts.streamInt(server, ss, info, sd.Handler) } if appErr != nil { - if err, ok := appErr.(*rpcError); ok { - ss.statusCode = err.code - ss.statusDesc = err.desc - } else if err, ok := appErr.(transport.StreamError); ok { - ss.statusCode = err.Code - ss.statusDesc = err.Desc - } else { - ss.statusCode = convertCode(appErr) - ss.statusDesc = appErr.Error() + appStatus, ok := status.FromError(appErr) + if !ok { + switch err := appErr.(type) { + case transport.StreamError: + appStatus = status.New(err.Code, err.Desc) + default: + appStatus = status.New(convertCode(appErr), appErr.Error()) + } + appErr = appStatus.Err() + } + if trInfo != nil { + ss.mu.Lock() + ss.trInfo.tr.LazyLog(stringer(appStatus.Message()), true) + ss.trInfo.tr.SetError() + ss.mu.Unlock() } + t.WriteStatus(ss.s, appStatus) + // TODO: Should we log an error from WriteStatus here and below? + return appErr } if trInfo != nil { ss.mu.Lock() - if ss.statusCode != codes.OK { - ss.trInfo.tr.LazyLog(stringer(ss.statusDesc), true) - ss.trInfo.tr.SetError() - } else { - ss.trInfo.tr.LazyLog(stringer("OK"), false) - } + ss.trInfo.tr.LazyLog(stringer("OK"), false) ss.mu.Unlock() } - errWrite := t.WriteStatus(ss.s, ss.statusCode, ss.statusDesc) - if ss.statusCode != codes.OK { - return Errorf(ss.statusCode, ss.statusDesc) - } - return errWrite + return t.WriteStatus(ss.s, status.New(codes.OK, "")) } @@ -867,7 +913,7 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str trInfo.tr.SetError() } errDesc := fmt.Sprintf("malformed method name: %q", stream.Method()) - if err := t.WriteStatus(stream, codes.InvalidArgument, errDesc); err != nil { + if err := t.WriteStatus(stream, status.New(codes.InvalidArgument, errDesc)); err != nil { if trInfo != nil { trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) trInfo.tr.SetError() @@ -883,12 +929,16 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str method := sm[pos+1:] srv, ok := s.m[service] if !ok { + if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil { + s.processStreamingRPC(t, stream, nil, unknownDesc, trInfo) + return + } if trInfo != nil { trInfo.tr.LazyLog(&fmtStringer{"Unknown service %v", []interface{}{service}}, true) trInfo.tr.SetError() } errDesc := fmt.Sprintf("unknown service %v", service) - if err := t.WriteStatus(stream, codes.Unimplemented, errDesc); err != nil { + if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil { if trInfo != nil { trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) trInfo.tr.SetError() @@ -913,8 +963,12 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str trInfo.tr.LazyLog(&fmtStringer{"Unknown method %v", []interface{}{method}}, true) trInfo.tr.SetError() } + if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil { + s.processStreamingRPC(t, stream, nil, unknownDesc, trInfo) + return + } errDesc := fmt.Sprintf("unknown method %v", method) - if err := t.WriteStatus(stream, codes.Unimplemented, errDesc); err != nil { + if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil { if trInfo != nil { trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) trInfo.tr.SetError() |