aboutsummaryrefslogtreecommitdiff
path: root/vendor/google.golang.org/grpc/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/google.golang.org/grpc/server.go')
-rw-r--r--vendor/google.golang.org/grpc/server.go154
1 files changed, 126 insertions, 28 deletions
diff --git a/vendor/google.golang.org/grpc/server.go b/vendor/google.golang.org/grpc/server.go
index e0bb187..b52a563 100644
--- a/vendor/google.golang.org/grpc/server.go
+++ b/vendor/google.golang.org/grpc/server.go
@@ -54,6 +54,8 @@ import (
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/metadata"
+ "google.golang.org/grpc/stats"
+ "google.golang.org/grpc/tap"
"google.golang.org/grpc/transport"
)
@@ -110,6 +112,7 @@ type options struct {
maxMsgSize int
unaryInt UnaryServerInterceptor
streamInt StreamServerInterceptor
+ inTapHandle tap.ServerInHandle
maxConcurrentStreams uint32
useHandlerImpl bool // use http.Handler-based server
}
@@ -186,6 +189,17 @@ func StreamInterceptor(i StreamServerInterceptor) ServerOption {
}
}
+// InTapHandle returns a ServerOption that sets the tap handle for all the server
+// transport to be created. Only one can be installed.
+func InTapHandle(h tap.ServerInHandle) ServerOption {
+ return func(o *options) {
+ if o.inTapHandle != nil {
+ panic("The tap handle has been set.")
+ }
+ o.inTapHandle = h
+ }
+}
+
// NewServer creates a gRPC server which has no service registered and has not
// started to accept requests yet.
func NewServer(opt ...ServerOption) *Server {
@@ -329,6 +343,7 @@ func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credenti
// read gRPC requests and then call the registered handlers to reply to them.
// Serve returns when lis.Accept fails with fatal errors. lis will be closed when
// this method returns.
+// Serve always returns non-nil error.
func (s *Server) Serve(lis net.Listener) error {
s.mu.Lock()
s.printf("serving")
@@ -412,17 +427,22 @@ func (s *Server) handleRawConn(rawConn net.Conn) {
if s.opts.useHandlerImpl {
s.serveUsingHandler(conn)
} else {
- s.serveNewHTTP2Transport(conn, authInfo)
+ s.serveHTTP2Transport(conn, authInfo)
}
}
-// serveNewHTTP2Transport sets up a new http/2 transport (using the
+// serveHTTP2Transport sets up a http/2 transport (using the
// gRPC http2 server transport in transport/http2_server.go) and
// serves streams on it.
// This is run in its own goroutine (it does network I/O in
// transport.NewServerTransport).
-func (s *Server) serveNewHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) {
- st, err := transport.NewServerTransport("http2", c, s.opts.maxConcurrentStreams, authInfo)
+func (s *Server) serveHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) {
+ config := &transport.ServerConfig{
+ MaxStreams: s.opts.maxConcurrentStreams,
+ AuthInfo: authInfo,
+ InTapHandle: s.opts.inTapHandle,
+ }
+ st, err := transport.NewServerTransport("http2", c, config)
if err != nil {
s.mu.Lock()
s.errorf("NewServerTransport(%q) failed: %v", c.RemoteAddr(), err)
@@ -448,6 +468,12 @@ func (s *Server) serveStreams(st transport.ServerTransport) {
defer wg.Done()
s.handleStream(st, stream, s.traceInfo(st, stream))
}()
+ }, func(ctx context.Context, method string) context.Context {
+ if !EnableTracing {
+ return ctx
+ }
+ tr := trace.New("grpc.Recv."+methodFamily(method), method)
+ return trace.NewContext(ctx, tr)
})
wg.Wait()
}
@@ -497,15 +523,17 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// traceInfo returns a traceInfo and associates it with stream, if tracing is enabled.
// If tracing is not enabled, it returns nil.
func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Stream) (trInfo *traceInfo) {
- if !EnableTracing {
+ tr, ok := trace.FromContext(stream.Context())
+ if !ok {
return nil
}
+
trInfo = &traceInfo{
- tr: trace.New("grpc.Recv."+methodFamily(stream.Method()), stream.Method()),
+ tr: tr,
}
trInfo.firstLine.client = false
trInfo.firstLine.remoteAddr = st.RemoteAddr()
- stream.TraceContext(trInfo.tr)
+
if dl, ok := stream.Context().Deadline(); ok {
trInfo.firstLine.deadline = dl.Sub(time.Now())
}
@@ -532,11 +560,17 @@ func (s *Server) removeConn(c io.Closer) {
}
func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options) error {
- var cbuf *bytes.Buffer
+ var (
+ cbuf *bytes.Buffer
+ outPayload *stats.OutPayload
+ )
if cp != nil {
cbuf = new(bytes.Buffer)
}
- p, err := encode(s.opts.codec, msg, cp, cbuf)
+ if stats.On() {
+ outPayload = &stats.OutPayload{}
+ }
+ p, err := encode(s.opts.codec, msg, cp, cbuf, outPayload)
if err != nil {
// This typically indicates a fatal issue (e.g., memory
// corruption or hardware faults) the application program
@@ -547,10 +581,32 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str
// the optimal option.
grpclog.Fatalf("grpc: Server failed to encode response %v", err)
}
- return t.Write(stream, p, opts)
+ err = t.Write(stream, p, opts)
+ if err == nil && outPayload != nil {
+ outPayload.SentTime = time.Now()
+ stats.HandleRPC(stream.Context(), outPayload)
+ }
+ return err
}
func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, md *MethodDesc, trInfo *traceInfo) (err error) {
+ if stats.On() {
+ begin := &stats.Begin{
+ BeginTime: time.Now(),
+ }
+ stats.HandleRPC(stream.Context(), begin)
+ }
+ defer func() {
+ if stats.On() {
+ end := &stats.End{
+ EndTime: time.Now(),
+ }
+ if err != nil && err != io.EOF {
+ end.Error = toRPCErr(err)
+ }
+ stats.HandleRPC(stream.Context(), end)
+ }
+ }()
if trInfo != nil {
defer trInfo.tr.Finish()
trInfo.firstLine.client = false
@@ -579,14 +635,14 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
if err != nil {
switch err := err.(type) {
case *rpcError:
- if err := t.WriteStatus(stream, err.code, err.desc); err != nil {
- grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
+ if e := t.WriteStatus(stream, err.code, err.desc); e != nil {
+ grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
}
case transport.ConnectionError:
// Nothing to do here.
case transport.StreamError:
- if err := t.WriteStatus(stream, err.Code, err.Desc); err != nil {
- grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
+ if e := t.WriteStatus(stream, err.Code, err.Desc); e != nil {
+ grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
}
default:
panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", err, err))
@@ -597,20 +653,29 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
if err := checkRecvPayload(pf, stream.RecvCompress(), s.opts.dc); err != nil {
switch err := err.(type) {
case *rpcError:
- if err := t.WriteStatus(stream, err.code, err.desc); err != nil {
- grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
+ if e := t.WriteStatus(stream, err.code, err.desc); e != nil {
+ grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
}
+ return err
default:
- if err := t.WriteStatus(stream, codes.Internal, err.Error()); err != nil {
- grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
+ if e := t.WriteStatus(stream, codes.Internal, err.Error()); e != nil {
+ grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
}
-
+ // TODO checkRecvPayload always return RPC error. Add a return here if necessary.
+ }
+ }
+ var inPayload *stats.InPayload
+ if stats.On() {
+ inPayload = &stats.InPayload{
+ RecvTime: time.Now(),
}
- return err
}
statusCode := codes.OK
statusDesc := ""
df := func(v interface{}) error {
+ if inPayload != nil {
+ inPayload.WireLength = len(req)
+ }
if pf == compressionMade {
var err error
req, err = s.opts.dc.Do(bytes.NewReader(req))
@@ -618,7 +683,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
if err := t.WriteStatus(stream, codes.Internal, err.Error()); err != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
}
- return err
+ return Errorf(codes.Internal, err.Error())
}
}
if len(req) > s.opts.maxMsgSize {
@@ -630,6 +695,12 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
if err := s.opts.codec.Unmarshal(req, v); err != nil {
return err
}
+ if inPayload != nil {
+ inPayload.Payload = v
+ inPayload.Data = req
+ inPayload.Length = len(req)
+ stats.HandleRPC(stream.Context(), inPayload)
+ }
if trInfo != nil {
trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true)
}
@@ -650,9 +721,8 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
}
if err := t.WriteStatus(stream, statusCode, statusDesc); err != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", err)
- return err
}
- return nil
+ return Errorf(statusCode, statusDesc)
}
if trInfo != nil {
trInfo.tr.LazyLog(stringer("OK"), false)
@@ -677,11 +747,32 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
if trInfo != nil {
trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true)
}
- return t.WriteStatus(stream, statusCode, statusDesc)
+ errWrite := t.WriteStatus(stream, statusCode, statusDesc)
+ if statusCode != codes.OK {
+ return Errorf(statusCode, statusDesc)
+ }
+ return errWrite
}
}
func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc, trInfo *traceInfo) (err error) {
+ if stats.On() {
+ begin := &stats.Begin{
+ BeginTime: time.Now(),
+ }
+ stats.HandleRPC(stream.Context(), begin)
+ }
+ defer func() {
+ if stats.On() {
+ end := &stats.End{
+ EndTime: time.Now(),
+ }
+ if err != nil && err != io.EOF {
+ end.Error = toRPCErr(err)
+ }
+ stats.HandleRPC(stream.Context(), end)
+ }
+ }()
if s.opts.cp != nil {
stream.SetSendCompress(s.opts.cp.Type())
}
@@ -744,7 +835,11 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
}
ss.mu.Unlock()
}
- return t.WriteStatus(ss.s, ss.statusCode, ss.statusDesc)
+ errWrite := t.WriteStatus(ss.s, ss.statusCode, ss.statusDesc)
+ if ss.statusCode != codes.OK {
+ return Errorf(ss.statusCode, ss.statusDesc)
+ }
+ return errWrite
}
@@ -759,7 +854,8 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
trInfo.tr.LazyLog(&fmtStringer{"Malformed method name %q", []interface{}{sm}}, true)
trInfo.tr.SetError()
}
- if err := t.WriteStatus(stream, codes.InvalidArgument, fmt.Sprintf("malformed method name: %q", stream.Method())); err != nil {
+ errDesc := fmt.Sprintf("malformed method name: %q", stream.Method())
+ if err := t.WriteStatus(stream, codes.InvalidArgument, errDesc); err != nil {
if trInfo != nil {
trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
trInfo.tr.SetError()
@@ -779,7 +875,8 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
trInfo.tr.LazyLog(&fmtStringer{"Unknown service %v", []interface{}{service}}, true)
trInfo.tr.SetError()
}
- if err := t.WriteStatus(stream, codes.Unimplemented, fmt.Sprintf("unknown service %v", service)); err != nil {
+ errDesc := fmt.Sprintf("unknown service %v", service)
+ if err := t.WriteStatus(stream, codes.Unimplemented, errDesc); err != nil {
if trInfo != nil {
trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
trInfo.tr.SetError()
@@ -804,7 +901,8 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
trInfo.tr.LazyLog(&fmtStringer{"Unknown method %v", []interface{}{method}}, true)
trInfo.tr.SetError()
}
- if err := t.WriteStatus(stream, codes.Unimplemented, fmt.Sprintf("unknown method %v", method)); err != nil {
+ errDesc := fmt.Sprintf("unknown method %v", method)
+ if err := t.WriteStatus(stream, codes.Unimplemented, errDesc); err != nil {
if trInfo != nil {
trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
trInfo.tr.SetError()