diff options
Diffstat (limited to 'vendor/google.golang.org/grpc/server.go')
-rw-r--r-- | vendor/google.golang.org/grpc/server.go | 575 |
1 files changed, 442 insertions, 133 deletions
diff --git a/vendor/google.golang.org/grpc/server.go b/vendor/google.golang.org/grpc/server.go index 7c882db..014c72b 100644 --- a/vendor/google.golang.org/grpc/server.go +++ b/vendor/google.golang.org/grpc/server.go @@ -32,13 +32,19 @@ import ( "sync" "time" + "io/ioutil" + "golang.org/x/net/context" "golang.org/x/net/http2" "golang.org/x/net/trace" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/encoding" + "google.golang.org/grpc/encoding/proto" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" "google.golang.org/grpc/stats" @@ -89,18 +95,28 @@ type Server struct { conns map[io.Closer]bool serve bool drain bool - ctx context.Context - cancel context.CancelFunc - // A CondVar to let GracefulStop() blocks until all the pending RPCs are finished - // and all the transport goes away. - cv *sync.Cond + cv *sync.Cond // signaled when connections close for GracefulStop m map[string]*service // service name -> service info events trace.EventLog + + quit chan struct{} + done chan struct{} + quitOnce sync.Once + doneOnce sync.Once + channelzRemoveOnce sync.Once + serveWG sync.WaitGroup // counts active Serve goroutines for GracefulStop + + channelzID int64 // channelz unique identification number + czmu sync.RWMutex + callsStarted int64 + callsFailed int64 + callsSucceeded int64 + lastCallStartedTime time.Time } type options struct { creds credentials.TransportCredentials - codec Codec + codec baseCodec cp Compressor dc Decompressor unaryInt UnaryServerInterceptor @@ -118,11 +134,13 @@ type options struct { initialConnWindowSize int32 writeBufferSize int readBufferSize int + connectionTimeout time.Duration } var defaultServerOptions = options{ maxReceiveMessageSize: defaultServerMaxReceiveMessageSize, maxSendMessageSize: defaultServerMaxSendMessageSize, + connectionTimeout: 120 * time.Second, } // A ServerOption sets options such as credentials, codec and keepalive parameters, etc. @@ -175,20 +193,32 @@ func KeepaliveEnforcementPolicy(kep keepalive.EnforcementPolicy) ServerOption { } // CustomCodec returns a ServerOption that sets a codec for message marshaling and unmarshaling. +// +// This will override any lookups by content-subtype for Codecs registered with RegisterCodec. func CustomCodec(codec Codec) ServerOption { return func(o *options) { o.codec = codec } } -// RPCCompressor returns a ServerOption that sets a compressor for outbound messages. +// RPCCompressor returns a ServerOption that sets a compressor for outbound +// messages. For backward compatibility, all outbound messages will be sent +// using this compressor, regardless of incoming message compression. By +// default, server messages will be sent using the same compressor with which +// request messages were sent. +// +// Deprecated: use encoding.RegisterCompressor instead. func RPCCompressor(cp Compressor) ServerOption { return func(o *options) { o.cp = cp } } -// RPCDecompressor returns a ServerOption that sets a decompressor for inbound messages. +// RPCDecompressor returns a ServerOption that sets a decompressor for inbound +// messages. It has higher priority than decompressors registered via +// encoding.RegisterCompressor. +// +// Deprecated: use encoding.RegisterCompressor instead. func RPCDecompressor(dc Decompressor) ServerOption { return func(o *options) { o.dc = dc @@ -196,7 +226,9 @@ func RPCDecompressor(dc Decompressor) ServerOption { } // MaxMsgSize returns a ServerOption to set the max message size in bytes the server can receive. -// If this is not set, gRPC uses the default limit. Deprecated: use MaxRecvMsgSize instead. +// If this is not set, gRPC uses the default limit. +// +// Deprecated: use MaxRecvMsgSize instead. func MaxMsgSize(m int) ServerOption { return MaxRecvMsgSize(m) } @@ -291,6 +323,18 @@ func UnknownServiceHandler(streamHandler StreamHandler) ServerOption { } } +// ConnectionTimeout returns a ServerOption that sets the timeout for +// connection establishment (up to and including HTTP/2 handshaking) for all +// new connections. If this is not set, the default is 120 seconds. A zero or +// negative value will result in an immediate timeout. +// +// This API is EXPERIMENTAL. +func ConnectionTimeout(d time.Duration) ServerOption { + return func(o *options) { + o.connectionTimeout = d + } +} + // NewServer creates a gRPC server which has no service registered and has not // started to accept requests yet. func NewServer(opt ...ServerOption) *Server { @@ -298,22 +342,23 @@ func NewServer(opt ...ServerOption) *Server { for _, o := range opt { o(&opts) } - if opts.codec == nil { - // Set the default codec. - opts.codec = protoCodec{} - } s := &Server{ lis: make(map[net.Listener]bool), opts: opts, conns: make(map[io.Closer]bool), m: make(map[string]*service), + quit: make(chan struct{}), + done: make(chan struct{}), } s.cv = sync.NewCond(&s.mu) - s.ctx, s.cancel = context.WithCancel(context.Background()) if EnableTracing { _, file, line, _ := runtime.Caller(1) s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line)) } + + if channelz.IsOn() { + s.channelzID = channelz.RegisterServer(s, "") + } return s } @@ -418,11 +463,9 @@ func (s *Server) GetServiceInfo() map[string]ServiceInfo { return ret } -var ( - // ErrServerStopped indicates that the operation is now illegal because of - // the server being stopped. - ErrServerStopped = errors.New("grpc: the server has been stopped") -) +// ErrServerStopped indicates that the operation is now illegal because of +// the server being stopped. +var ErrServerStopped = errors.New("grpc: the server has been stopped") func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { if s.opts.creds == nil { @@ -431,28 +474,66 @@ func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credenti return s.opts.creds.ServerHandshake(rawConn) } +type listenSocket struct { + net.Listener + channelzID int64 +} + +func (l *listenSocket) ChannelzMetric() *channelz.SocketInternalMetric { + return &channelz.SocketInternalMetric{ + LocalAddr: l.Listener.Addr(), + } +} + +func (l *listenSocket) Close() error { + err := l.Listener.Close() + if channelz.IsOn() { + channelz.RemoveEntry(l.channelzID) + } + return err +} + // Serve accepts incoming connections on the listener lis, creating a new // ServerTransport and service goroutine for each. The service goroutines // read gRPC requests and then call the registered handlers to reply to them. // Serve returns when lis.Accept fails with fatal errors. lis will be closed when // this method returns. -// Serve always returns non-nil error. +// Serve will return a non-nil error unless Stop or GracefulStop is called. func (s *Server) Serve(lis net.Listener) error { s.mu.Lock() s.printf("serving") s.serve = true if s.lis == nil { + // Serve called after Stop or GracefulStop. s.mu.Unlock() lis.Close() return ErrServerStopped } - s.lis[lis] = true + + s.serveWG.Add(1) + defer func() { + s.serveWG.Done() + select { + // Stop or GracefulStop called; block until done and return nil. + case <-s.quit: + <-s.done + default: + } + }() + + ls := &listenSocket{Listener: lis} + s.lis[ls] = true + + if channelz.IsOn() { + ls.channelzID = channelz.RegisterListenSocket(ls, s.channelzID, "") + } s.mu.Unlock() + defer func() { s.mu.Lock() - if s.lis != nil && s.lis[lis] { - lis.Close() - delete(s.lis, lis) + if s.lis != nil && s.lis[ls] { + ls.Close() + delete(s.lis, ls) } s.mu.Unlock() }() @@ -479,36 +560,52 @@ func (s *Server) Serve(lis net.Listener) error { timer := time.NewTimer(tempDelay) select { case <-timer.C: - case <-s.ctx.Done(): + case <-s.quit: + timer.Stop() + return nil } - timer.Stop() continue } s.mu.Lock() s.printf("done serving; Accept = %v", err) s.mu.Unlock() + + select { + case <-s.quit: + return nil + default: + } return err } tempDelay = 0 - // Start a new goroutine to deal with rawConn - // so we don't stall this Accept loop goroutine. - go s.handleRawConn(rawConn) + // Start a new goroutine to deal with rawConn so we don't stall this Accept + // loop goroutine. + // + // Make sure we account for the goroutine so GracefulStop doesn't nil out + // s.conns before this conn can be added. + s.serveWG.Add(1) + go func() { + s.handleRawConn(rawConn) + s.serveWG.Done() + }() } } -// handleRawConn is run in its own goroutine and handles a just-accepted -// connection that has not had any I/O performed on it yet. +// handleRawConn forks a goroutine to handle a just-accepted connection that +// has not had any I/O performed on it yet. func (s *Server) handleRawConn(rawConn net.Conn) { + rawConn.SetDeadline(time.Now().Add(s.opts.connectionTimeout)) conn, authInfo, err := s.useTransportAuthenticator(rawConn) if err != nil { s.mu.Lock() s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err) s.mu.Unlock() grpclog.Warningf("grpc: Server.Serve failed to complete security handshake from %q: %v", rawConn.RemoteAddr(), err) - // If serverHandShake returns ErrConnDispatched, keep rawConn open. + // If serverHandshake returns ErrConnDispatched, keep rawConn open. if err != credentials.ErrConnDispatched { rawConn.Close() } + rawConn.SetDeadline(time.Time{}) return } @@ -520,19 +617,33 @@ func (s *Server) handleRawConn(rawConn net.Conn) { } s.mu.Unlock() + var serve func() + c := conn.(io.Closer) if s.opts.useHandlerImpl { - s.serveUsingHandler(conn) + serve = func() { s.serveUsingHandler(conn) } } else { - s.serveHTTP2Transport(conn, authInfo) + // Finish handshaking (HTTP2) + st := s.newHTTP2Transport(conn, authInfo) + if st == nil { + return + } + c = st + serve = func() { s.serveStreams(st) } } + + rawConn.SetDeadline(time.Time{}) + if !s.addConn(c) { + return + } + go func() { + serve() + s.removeConn(c) + }() } -// serveHTTP2Transport sets up a http/2 transport (using the -// gRPC http2 server transport in transport/http2_server.go) and -// serves streams on it. -// This is run in its own goroutine (it does network I/O in -// transport.NewServerTransport). -func (s *Server) serveHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) { +// newHTTP2Transport sets up a http/2 transport (using the +// gRPC http2 server transport in transport/http2_server.go). +func (s *Server) newHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) transport.ServerTransport { config := &transport.ServerConfig{ MaxStreams: s.opts.maxConcurrentStreams, AuthInfo: authInfo, @@ -544,6 +655,7 @@ func (s *Server) serveHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) InitialConnWindowSize: s.opts.initialConnWindowSize, WriteBufferSize: s.opts.writeBufferSize, ReadBufferSize: s.opts.readBufferSize, + ChannelzParentID: s.channelzID, } st, err := transport.NewServerTransport("http2", c, config) if err != nil { @@ -552,17 +664,13 @@ func (s *Server) serveHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) s.mu.Unlock() c.Close() grpclog.Warningln("grpc: Server.Serve failed to create ServerTransport: ", err) - return - } - if !s.addConn(st) { - st.Close() - return + return nil } - s.serveStreams(st) + + return st } func (s *Server) serveStreams(st transport.ServerTransport) { - defer s.removeConn(st) defer st.Close() var wg sync.WaitGroup st.HandleStreams(func(stream *transport.Stream) { @@ -596,11 +704,6 @@ var _ http.Handler = (*Server)(nil) // // conn is the *tls.Conn that's already been authenticated. func (s *Server) serveUsingHandler(conn net.Conn) { - if !s.addConn(conn) { - conn.Close() - return - } - defer s.removeConn(conn) h2s := &http2.Server{ MaxConcurrentStreams: s.opts.maxConcurrentStreams, } @@ -634,13 +737,12 @@ func (s *Server) serveUsingHandler(conn net.Conn) { // available through grpc-go's HTTP/2 server, and it is currently EXPERIMENTAL // and subject to change. func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - st, err := transport.NewServerHandlerTransport(w, r) + st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandler) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } if !s.addConn(st) { - st.Close() return } defer s.removeConn(st) @@ -670,9 +772,15 @@ func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Strea func (s *Server) addConn(c io.Closer) bool { s.mu.Lock() defer s.mu.Unlock() - if s.conns == nil || s.drain { + if s.conns == nil { + c.Close() return false } + if s.drain { + // Transport added after we drained our existing conns: drain it + // immediately. + c.(transport.ServerTransport).Drain() + } s.conns[c] = true return true } @@ -686,43 +794,83 @@ func (s *Server) removeConn(c io.Closer) { } } -func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options) error { - var ( - cbuf *bytes.Buffer - outPayload *stats.OutPayload - ) - if cp != nil { - cbuf = new(bytes.Buffer) - } - if s.opts.statsHandler != nil { - outPayload = &stats.OutPayload{} +// ChannelzMetric returns ServerInternalMetric of current server. +// This is an EXPERIMENTAL API. +func (s *Server) ChannelzMetric() *channelz.ServerInternalMetric { + s.czmu.RLock() + defer s.czmu.RUnlock() + return &channelz.ServerInternalMetric{ + CallsStarted: s.callsStarted, + CallsSucceeded: s.callsSucceeded, + CallsFailed: s.callsFailed, + LastCallStartedTimestamp: s.lastCallStartedTime, } - hdr, data, err := encode(s.opts.codec, msg, cp, cbuf, outPayload) +} + +func (s *Server) incrCallsStarted() { + s.czmu.Lock() + s.callsStarted++ + s.lastCallStartedTime = time.Now() + s.czmu.Unlock() +} + +func (s *Server) incrCallsSucceeded() { + s.czmu.Lock() + s.callsSucceeded++ + s.czmu.Unlock() +} + +func (s *Server) incrCallsFailed() { + s.czmu.Lock() + s.callsFailed++ + s.czmu.Unlock() +} + +func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options, comp encoding.Compressor) error { + data, err := encode(s.getCodec(stream.ContentSubtype()), msg) if err != nil { grpclog.Errorln("grpc: server failed to encode response: ", err) return err } - if len(data) > s.opts.maxSendMessageSize { - return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), s.opts.maxSendMessageSize) + compData, err := compress(data, cp, comp) + if err != nil { + grpclog.Errorln("grpc: server failed to compress response: ", err) + return err } - err = t.Write(stream, hdr, data, opts) - if err == nil && outPayload != nil { - outPayload.SentTime = time.Now() - s.opts.statsHandler.HandleRPC(stream.Context(), outPayload) + hdr, payload := msgHeader(data, compData) + // TODO(dfawley): should we be checking len(data) instead? + if len(payload) > s.opts.maxSendMessageSize { + return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(payload), s.opts.maxSendMessageSize) + } + err = t.Write(stream, hdr, payload, opts) + if err == nil && s.opts.statsHandler != nil { + s.opts.statsHandler.HandleRPC(stream.Context(), outPayload(false, msg, data, payload, time.Now())) } return err } func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, md *MethodDesc, trInfo *traceInfo) (err error) { + if channelz.IsOn() { + s.incrCallsStarted() + defer func() { + if err != nil && err != io.EOF { + s.incrCallsFailed() + } else { + s.incrCallsSucceeded() + } + }() + } sh := s.opts.statsHandler if sh != nil { + beginTime := time.Now() begin := &stats.Begin{ - BeginTime: time.Now(), + BeginTime: beginTime, } sh.HandleRPC(stream.Context(), begin) defer func() { end := &stats.End{ - EndTime: time.Now(), + BeginTime: beginTime, + EndTime: time.Now(), } if err != nil && err != io.EOF { end.Error = toRPCErr(err) @@ -741,10 +889,43 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } }() } + + // comp and cp are used for compression. decomp and dc are used for + // decompression. If comp and decomp are both set, they are the same; + // however they are kept separate to ensure that at most one of the + // compressor/decompressor variable pairs are set for use later. + var comp, decomp encoding.Compressor + var cp Compressor + var dc Decompressor + + // If dc is set and matches the stream's compression, use it. Otherwise, try + // to find a matching registered compressor for decomp. + if rc := stream.RecvCompress(); s.opts.dc != nil && s.opts.dc.Type() == rc { + dc = s.opts.dc + } else if rc != "" && rc != encoding.Identity { + decomp = encoding.GetCompressor(rc) + if decomp == nil { + st := status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", rc) + t.WriteStatus(stream, st) + return st.Err() + } + } + + // If cp is set, use it. Otherwise, attempt to compress the response using + // the incoming message compression method. + // + // NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686. if s.opts.cp != nil { - // NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686. - stream.SetSendCompress(s.opts.cp.Type()) + cp = s.opts.cp + stream.SetSendCompress(cp.Type()) + } else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity { + // Legacy compressor not specified; attempt to respond with same encoding. + comp = encoding.GetCompressor(rc) + if comp != nil { + stream.SetSendCompress(rc) + } } + p := &parser{r: stream} pf, req, err := p.recvMsg(s.opts.maxReceiveMessageSize) if err == io.EOF { @@ -752,7 +933,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. return err } if err == io.ErrUnexpectedEOF { - err = Errorf(codes.Internal, io.ErrUnexpectedEOF.Error()) + err = status.Errorf(codes.Internal, io.ErrUnexpectedEOF.Error()) } if err != nil { if st, ok := status.FromError(err); ok { @@ -773,19 +954,14 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } return err } - - if err := checkRecvPayload(pf, stream.RecvCompress(), s.opts.dc); err != nil { - if st, ok := status.FromError(err); ok { - if e := t.WriteStatus(stream, st); e != nil { - grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status %v", e) - } - return err - } - if e := t.WriteStatus(stream, status.New(codes.Internal, err.Error())); e != nil { + if channelz.IsOn() { + t.IncrMsgRecv() + } + if st := checkRecvPayload(pf, stream.RecvCompress(), dc != nil || decomp != nil); st != nil { + if e := t.WriteStatus(stream, st); e != nil { grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status %v", e) } - - // TODO checkRecvPayload always return RPC error. Add a return here if necessary. + return st.Err() } var inPayload *stats.InPayload if sh != nil { @@ -799,9 +975,17 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } if pf == compressionMade { var err error - req, err = s.opts.dc.Do(bytes.NewReader(req)) - if err != nil { - return Errorf(codes.Internal, err.Error()) + if dc != nil { + req, err = dc.Do(bytes.NewReader(req)) + if err != nil { + return status.Errorf(codes.Internal, err.Error()) + } + } else { + tmp, _ := decomp.Decompress(bytes.NewReader(req)) + req, err = ioutil.ReadAll(tmp) + if err != nil { + return status.Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) + } } } if len(req) > s.opts.maxReceiveMessageSize { @@ -809,7 +993,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. // java implementation. return status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(req), s.opts.maxReceiveMessageSize) } - if err := s.opts.codec.Unmarshal(req, v); err != nil { + if err := s.getCodec(stream.ContentSubtype()).Unmarshal(req, v); err != nil { return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err) } if inPayload != nil { @@ -823,12 +1007,13 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } return nil } - reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt) + ctx := NewContextWithServerTransportStream(stream.Context(), stream) + reply, appErr := md.Handler(srv.server, ctx, df, s.opts.unaryInt) if appErr != nil { appStatus, ok := status.FromError(appErr) if !ok { // Convert appErr if it is not a grpc status error. - appErr = status.Error(convertCode(appErr), appErr.Error()) + appErr = status.Error(codes.Unknown, appErr.Error()) appStatus, _ = status.FromError(appErr) } if trInfo != nil { @@ -847,7 +1032,8 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. Last: true, Delay: false, } - if err := s.sendResponse(t, stream, reply, s.opts.cp, opts); err != nil { + + if err := s.sendResponse(t, stream, reply, cp, opts, comp); err != nil { if err == io.EOF { // The entire stream is done (for unary RPC only). return err @@ -870,6 +1056,9 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } return err } + if channelz.IsOn() { + t.IncrMsgSent() + } if trInfo != nil { trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true) } @@ -880,15 +1069,27 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc, trInfo *traceInfo) (err error) { + if channelz.IsOn() { + s.incrCallsStarted() + defer func() { + if err != nil && err != io.EOF { + s.incrCallsFailed() + } else { + s.incrCallsSucceeded() + } + }() + } sh := s.opts.statsHandler if sh != nil { + beginTime := time.Now() begin := &stats.Begin{ - BeginTime: time.Now(), + BeginTime: beginTime, } sh.HandleRPC(stream.Context(), begin) defer func() { end := &stats.End{ - EndTime: time.Now(), + BeginTime: beginTime, + EndTime: time.Now(), } if err != nil && err != io.EOF { end.Error = toRPCErr(err) @@ -896,21 +1097,47 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp sh.HandleRPC(stream.Context(), end) }() } - if s.opts.cp != nil { - stream.SetSendCompress(s.opts.cp.Type()) - } + ctx := NewContextWithServerTransportStream(stream.Context(), stream) ss := &serverStream{ + ctx: ctx, t: t, s: stream, p: &parser{r: stream}, - codec: s.opts.codec, - cp: s.opts.cp, - dc: s.opts.dc, + codec: s.getCodec(stream.ContentSubtype()), maxReceiveMessageSize: s.opts.maxReceiveMessageSize, maxSendMessageSize: s.opts.maxSendMessageSize, trInfo: trInfo, statsHandler: sh, } + + // If dc is set and matches the stream's compression, use it. Otherwise, try + // to find a matching registered compressor for decomp. + if rc := stream.RecvCompress(); s.opts.dc != nil && s.opts.dc.Type() == rc { + ss.dc = s.opts.dc + } else if rc != "" && rc != encoding.Identity { + ss.decomp = encoding.GetCompressor(rc) + if ss.decomp == nil { + st := status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", rc) + t.WriteStatus(ss.s, st) + return st.Err() + } + } + + // If cp is set, use it. Otherwise, attempt to compress the response using + // the incoming message compression method. + // + // NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686. + if s.opts.cp != nil { + ss.cp = s.opts.cp + stream.SetSendCompress(s.opts.cp.Type()) + } else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity { + // Legacy compressor not specified; attempt to respond with same encoding. + ss.comp = encoding.GetCompressor(rc) + if ss.comp != nil { + stream.SetSendCompress(rc) + } + } + if trInfo != nil { trInfo.tr.LazyLog(&trInfo.firstLine, false) defer func() { @@ -946,7 +1173,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp case transport.StreamError: appStatus = status.New(err.Code, err.Desc) default: - appStatus = status.New(convertCode(appErr), appErr.Error()) + appStatus = status.New(codes.Unknown, appErr.Error()) } appErr = appStatus.Err() } @@ -966,7 +1193,6 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp ss.mu.Unlock() } return t.WriteStatus(ss.s, status.New(codes.OK, "")) - } func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream, trInfo *traceInfo) { @@ -1048,12 +1274,65 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str } } +// The key to save ServerTransportStream in the context. +type streamKey struct{} + +// NewContextWithServerTransportStream creates a new context from ctx and +// attaches stream to it. +// +// This API is EXPERIMENTAL. +func NewContextWithServerTransportStream(ctx context.Context, stream ServerTransportStream) context.Context { + return context.WithValue(ctx, streamKey{}, stream) +} + +// ServerTransportStream is a minimal interface that a transport stream must +// implement. This can be used to mock an actual transport stream for tests of +// handler code that use, for example, grpc.SetHeader (which requires some +// stream to be in context). +// +// See also NewContextWithServerTransportStream. +// +// This API is EXPERIMENTAL. +type ServerTransportStream interface { + Method() string + SetHeader(md metadata.MD) error + SendHeader(md metadata.MD) error + SetTrailer(md metadata.MD) error +} + +// ServerTransportStreamFromContext returns the ServerTransportStream saved in +// ctx. Returns nil if the given context has no stream associated with it +// (which implies it is not an RPC invocation context). +// +// This API is EXPERIMENTAL. +func ServerTransportStreamFromContext(ctx context.Context) ServerTransportStream { + s, _ := ctx.Value(streamKey{}).(ServerTransportStream) + return s +} + // Stop stops the gRPC server. It immediately closes all open // connections and listeners. // It cancels all active RPCs on the server side and the corresponding // pending RPCs on the client side will get notified by connection // errors. func (s *Server) Stop() { + s.quitOnce.Do(func() { + close(s.quit) + }) + + defer func() { + s.serveWG.Wait() + s.doneOnce.Do(func() { + close(s.done) + }) + }() + + s.channelzRemoveOnce.Do(func() { + if channelz.IsOn() { + channelz.RemoveEntry(s.channelzID) + } + }) + s.mu.Lock() listeners := s.lis s.lis = nil @@ -1071,7 +1350,6 @@ func (s *Server) Stop() { } s.mu.Lock() - s.cancel() if s.events != nil { s.events.Finish() s.events = nil @@ -1083,22 +1361,44 @@ func (s *Server) Stop() { // accepting new connections and RPCs and blocks until all the pending RPCs are // finished. func (s *Server) GracefulStop() { + s.quitOnce.Do(func() { + close(s.quit) + }) + + defer func() { + s.doneOnce.Do(func() { + close(s.done) + }) + }() + + s.channelzRemoveOnce.Do(func() { + if channelz.IsOn() { + channelz.RemoveEntry(s.channelzID) + } + }) s.mu.Lock() - defer s.mu.Unlock() if s.conns == nil { + s.mu.Unlock() return } + for lis := range s.lis { lis.Close() } s.lis = nil - s.cancel() if !s.drain { for c := range s.conns { c.(transport.ServerTransport).Drain() } s.drain = true } + + // Wait for serving threads to be ready to exit. Only then can we be sure no + // new conns will be created. + s.mu.Unlock() + s.serveWG.Wait() + s.mu.Lock() + for len(s.conns) != 0 { s.cv.Wait() } @@ -1107,26 +1407,29 @@ func (s *Server) GracefulStop() { s.events.Finish() s.events = nil } + s.mu.Unlock() } func init() { - internal.TestingCloseConns = func(arg interface{}) { - arg.(*Server).testingCloseConns() - } internal.TestingUseHandlerImpl = func(arg interface{}) { arg.(*Server).opts.useHandlerImpl = true } } -// testingCloseConns closes all existing transports but keeps s.lis -// accepting new connections. -func (s *Server) testingCloseConns() { - s.mu.Lock() - for c := range s.conns { - c.Close() - delete(s.conns, c) +// contentSubtype must be lowercase +// cannot return nil +func (s *Server) getCodec(contentSubtype string) baseCodec { + if s.opts.codec != nil { + return s.opts.codec } - s.mu.Unlock() + if contentSubtype == "" { + return encoding.GetCodec(proto.Name) + } + codec := encoding.GetCodec(contentSubtype) + if codec == nil { + return encoding.GetCodec(proto.Name) + } + return codec } // SetHeader sets the header metadata. @@ -1139,9 +1442,9 @@ func SetHeader(ctx context.Context, md metadata.MD) error { if md.Len() == 0 { return nil } - stream, ok := transport.StreamFromContext(ctx) - if !ok { - return Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) + stream := ServerTransportStreamFromContext(ctx) + if stream == nil { + return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) } return stream.SetHeader(md) } @@ -1149,15 +1452,11 @@ func SetHeader(ctx context.Context, md metadata.MD) error { // SendHeader sends header metadata. It may be called at most once. // The provided md and headers set by SetHeader() will be sent. func SendHeader(ctx context.Context, md metadata.MD) error { - stream, ok := transport.StreamFromContext(ctx) - if !ok { - return Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) + stream := ServerTransportStreamFromContext(ctx) + if stream == nil { + return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) } - t := stream.ServerTransport() - if t == nil { - grpclog.Fatalf("grpc: SendHeader: %v has no ServerTransport to send header metadata.", stream) - } - if err := t.WriteHeader(stream, md); err != nil { + if err := stream.SendHeader(md); err != nil { return toRPCErr(err) } return nil @@ -1169,9 +1468,19 @@ func SetTrailer(ctx context.Context, md metadata.MD) error { if md.Len() == 0 { return nil } - stream, ok := transport.StreamFromContext(ctx) - if !ok { - return Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) + stream := ServerTransportStreamFromContext(ctx) + if stream == nil { + return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) } return stream.SetTrailer(md) } + +// Method returns the method string for the server context. The returned +// string is in the format of "/service/method". +func Method(ctx context.Context) (string, bool) { + s := ServerTransportStreamFromContext(ctx) + if s == nil { + return "", false + } + return s.Method(), true +} |