From 921818bca208f0c70e85ec670074cb3905cbbc82 Mon Sep 17 00:00:00 2001 From: Niall Sheridan Date: Sat, 27 Aug 2016 01:32:30 +0100 Subject: Update dependencies --- vendor/google.golang.org/grpc/server.go | 100 +++++++++++++++++++++++++------- 1 file changed, 80 insertions(+), 20 deletions(-) (limited to 'vendor/google.golang.org/grpc/server.go') diff --git a/vendor/google.golang.org/grpc/server.go b/vendor/google.golang.org/grpc/server.go index a2b2b94..b2a825a 100644 --- a/vendor/google.golang.org/grpc/server.go +++ b/vendor/google.golang.org/grpc/server.go @@ -89,9 +89,13 @@ type service struct { type Server struct { opts options - mu sync.Mutex // guards following - lis map[net.Listener]bool - conns map[io.Closer]bool + mu sync.Mutex // guards following + lis map[net.Listener]bool + conns map[io.Closer]bool + drain bool + // A CondVar to let GracefulStop() blocks until all the pending RPCs are finished + // and all the transport goes away. + cv *sync.Cond m map[string]*service // service name -> service info events trace.EventLog } @@ -101,12 +105,15 @@ type options struct { codec Codec cp Compressor dc Decompressor + maxMsgSize int unaryInt UnaryServerInterceptor streamInt StreamServerInterceptor maxConcurrentStreams uint32 useHandlerImpl bool // use http.Handler-based server } +var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size limit + // A ServerOption sets options. type ServerOption func(*options) @@ -117,20 +124,28 @@ func CustomCodec(codec Codec) ServerOption { } } -// RPCCompressor returns a ServerOption that sets a compressor for outbound message. +// RPCCompressor returns a ServerOption that sets a compressor for outbound messages. func RPCCompressor(cp Compressor) ServerOption { return func(o *options) { o.cp = cp } } -// RPCDecompressor returns a ServerOption that sets a decompressor for inbound message. +// RPCDecompressor returns a ServerOption that sets a decompressor for inbound messages. func RPCDecompressor(dc Decompressor) ServerOption { return func(o *options) { o.dc = dc } } +// MaxMsgSize returns a ServerOption to set the max message size in bytes for inbound mesages. +// If this is not set, gRPC uses the default 4MB. +func MaxMsgSize(m int) ServerOption { + return func(o *options) { + o.maxMsgSize = m + } +} + // MaxConcurrentStreams returns a ServerOption that will apply a limit on the number // of concurrent streams to each ServerTransport. func MaxConcurrentStreams(n uint32) ServerOption { @@ -173,6 +188,7 @@ func StreamInterceptor(i StreamServerInterceptor) ServerOption { // started to accept requests yet. func NewServer(opt ...ServerOption) *Server { var opts options + opts.maxMsgSize = defaultMaxMsgSize for _, o := range opt { o(&opts) } @@ -186,6 +202,7 @@ func NewServer(opt ...ServerOption) *Server { conns: make(map[io.Closer]bool), m: make(map[string]*service), } + s.cv = sync.NewCond(&s.mu) if EnableTracing { _, file, line, _ := runtime.Caller(1) s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line)) @@ -264,8 +281,8 @@ type ServiceInfo struct { // GetServiceInfo returns a map from service names to ServiceInfo. // Service names include the package names, in the form of .. -func (s *Server) GetServiceInfo() map[string]*ServiceInfo { - ret := make(map[string]*ServiceInfo) +func (s *Server) GetServiceInfo() map[string]ServiceInfo { + ret := make(map[string]ServiceInfo) for n, srv := range s.m { methods := make([]MethodInfo, 0, len(srv.md)+len(srv.sd)) for m := range srv.md { @@ -283,7 +300,7 @@ func (s *Server) GetServiceInfo() map[string]*ServiceInfo { }) } - ret[n] = &ServiceInfo{ + ret[n] = ServiceInfo{ Methods: methods, Metadata: srv.mdata, } @@ -350,7 +367,10 @@ func (s *Server) handleRawConn(rawConn net.Conn) { s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err) s.mu.Unlock() grpclog.Printf("grpc: Server.Serve failed to complete security handshake from %q: %v", rawConn.RemoteAddr(), err) - rawConn.Close() + // If serverHandShake returns ErrConnDispatched, keep rawConn open. + if err != credentials.ErrConnDispatched { + rawConn.Close() + } return } @@ -468,7 +488,7 @@ func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Strea func (s *Server) addConn(c io.Closer) bool { s.mu.Lock() defer s.mu.Unlock() - if s.conns == nil { + if s.conns == nil || s.drain { return false } s.conns[c] = true @@ -480,6 +500,7 @@ func (s *Server) removeConn(c io.Closer) { defer s.mu.Unlock() if s.conns != nil { delete(s.conns, c) + s.cv.Signal() } } @@ -520,7 +541,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } p := &parser{r: stream} for { - pf, req, err := p.recvMsg() + pf, req, err := p.recvMsg(s.opts.maxMsgSize) if err == io.EOF { // The entire stream is done (for unary RPC only). return err @@ -530,6 +551,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } if err != nil { switch err := err.(type) { + case *rpcError: + if err := t.WriteStatus(stream, err.code, err.desc); err != nil { + grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err) + } case transport.ConnectionError: // Nothing to do here. case transport.StreamError: @@ -569,6 +594,12 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. return err } } + if len(req) > s.opts.maxMsgSize { + // TODO: Revisit the error code. Currently keep it consistent with + // java implementation. + statusCode = codes.Internal + statusDesc = fmt.Sprintf("grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxMsgSize) + } if err := s.opts.codec.Unmarshal(req, v); err != nil { return err } @@ -628,13 +659,14 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp stream.SetSendCompress(s.opts.cp.Type()) } ss := &serverStream{ - t: t, - s: stream, - p: &parser{r: stream}, - codec: s.opts.codec, - cp: s.opts.cp, - dc: s.opts.dc, - trInfo: trInfo, + t: t, + s: stream, + p: &parser{r: stream}, + codec: s.opts.codec, + cp: s.opts.cp, + dc: s.opts.dc, + maxMsgSize: s.opts.maxMsgSize, + trInfo: trInfo, } if ss.cp != nil { ss.cbuf = new(bytes.Buffer) @@ -766,14 +798,16 @@ func (s *Server) Stop() { s.mu.Lock() listeners := s.lis s.lis = nil - cs := s.conns + st := s.conns s.conns = nil + // interrupt GracefulStop if Stop and GracefulStop are called concurrently. + s.cv.Signal() s.mu.Unlock() for lis := range listeners { lis.Close() } - for c := range cs { + for c := range st { c.Close() } @@ -785,6 +819,32 @@ func (s *Server) Stop() { s.mu.Unlock() } +// GracefulStop stops the gRPC server gracefully. It stops the server to accept new +// connections and RPCs and blocks until all the pending RPCs are finished. +func (s *Server) GracefulStop() { + s.mu.Lock() + defer s.mu.Unlock() + if s.drain == true || s.conns == nil { + return + } + s.drain = true + for lis := range s.lis { + lis.Close() + } + s.lis = nil + for c := range s.conns { + c.(transport.ServerTransport).Drain() + } + for len(s.conns) != 0 { + s.cv.Wait() + } + s.conns = nil + if s.events != nil { + s.events.Finish() + s.events = nil + } +} + func init() { internal.TestingCloseConns = func(arg interface{}) { arg.(*Server).testingCloseConns() -- cgit v1.2.3