aboutsummaryrefslogtreecommitdiff
path: root/vendor/google.golang.org/grpc/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/google.golang.org/grpc/server.go')
-rw-r--r--vendor/google.golang.org/grpc/server.go100
1 files changed, 80 insertions, 20 deletions
diff --git a/vendor/google.golang.org/grpc/server.go b/vendor/google.golang.org/grpc/server.go
index a2b2b94..b2a825a 100644
--- a/vendor/google.golang.org/grpc/server.go
+++ b/vendor/google.golang.org/grpc/server.go
@@ -89,9 +89,13 @@ type service struct {
type Server struct {
opts options
- mu sync.Mutex // guards following
- lis map[net.Listener]bool
- conns map[io.Closer]bool
+ mu sync.Mutex // guards following
+ lis map[net.Listener]bool
+ conns map[io.Closer]bool
+ drain bool
+ // A CondVar to let GracefulStop() blocks until all the pending RPCs are finished
+ // and all the transport goes away.
+ cv *sync.Cond
m map[string]*service // service name -> service info
events trace.EventLog
}
@@ -101,12 +105,15 @@ type options struct {
codec Codec
cp Compressor
dc Decompressor
+ maxMsgSize int
unaryInt UnaryServerInterceptor
streamInt StreamServerInterceptor
maxConcurrentStreams uint32
useHandlerImpl bool // use http.Handler-based server
}
+var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size limit
+
// A ServerOption sets options.
type ServerOption func(*options)
@@ -117,20 +124,28 @@ func CustomCodec(codec Codec) ServerOption {
}
}
-// RPCCompressor returns a ServerOption that sets a compressor for outbound message.
+// RPCCompressor returns a ServerOption that sets a compressor for outbound messages.
func RPCCompressor(cp Compressor) ServerOption {
return func(o *options) {
o.cp = cp
}
}
-// RPCDecompressor returns a ServerOption that sets a decompressor for inbound message.
+// RPCDecompressor returns a ServerOption that sets a decompressor for inbound messages.
func RPCDecompressor(dc Decompressor) ServerOption {
return func(o *options) {
o.dc = dc
}
}
+// MaxMsgSize returns a ServerOption to set the max message size in bytes for inbound mesages.
+// If this is not set, gRPC uses the default 4MB.
+func MaxMsgSize(m int) ServerOption {
+ return func(o *options) {
+ o.maxMsgSize = m
+ }
+}
+
// MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
// of concurrent streams to each ServerTransport.
func MaxConcurrentStreams(n uint32) ServerOption {
@@ -173,6 +188,7 @@ func StreamInterceptor(i StreamServerInterceptor) ServerOption {
// started to accept requests yet.
func NewServer(opt ...ServerOption) *Server {
var opts options
+ opts.maxMsgSize = defaultMaxMsgSize
for _, o := range opt {
o(&opts)
}
@@ -186,6 +202,7 @@ func NewServer(opt ...ServerOption) *Server {
conns: make(map[io.Closer]bool),
m: make(map[string]*service),
}
+ s.cv = sync.NewCond(&s.mu)
if EnableTracing {
_, file, line, _ := runtime.Caller(1)
s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line))
@@ -264,8 +281,8 @@ type ServiceInfo struct {
// GetServiceInfo returns a map from service names to ServiceInfo.
// Service names include the package names, in the form of <package>.<service>.
-func (s *Server) GetServiceInfo() map[string]*ServiceInfo {
- ret := make(map[string]*ServiceInfo)
+func (s *Server) GetServiceInfo() map[string]ServiceInfo {
+ ret := make(map[string]ServiceInfo)
for n, srv := range s.m {
methods := make([]MethodInfo, 0, len(srv.md)+len(srv.sd))
for m := range srv.md {
@@ -283,7 +300,7 @@ func (s *Server) GetServiceInfo() map[string]*ServiceInfo {
})
}
- ret[n] = &ServiceInfo{
+ ret[n] = ServiceInfo{
Methods: methods,
Metadata: srv.mdata,
}
@@ -350,7 +367,10 @@ func (s *Server) handleRawConn(rawConn net.Conn) {
s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err)
s.mu.Unlock()
grpclog.Printf("grpc: Server.Serve failed to complete security handshake from %q: %v", rawConn.RemoteAddr(), err)
- rawConn.Close()
+ // If serverHandShake returns ErrConnDispatched, keep rawConn open.
+ if err != credentials.ErrConnDispatched {
+ rawConn.Close()
+ }
return
}
@@ -468,7 +488,7 @@ func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Strea
func (s *Server) addConn(c io.Closer) bool {
s.mu.Lock()
defer s.mu.Unlock()
- if s.conns == nil {
+ if s.conns == nil || s.drain {
return false
}
s.conns[c] = true
@@ -480,6 +500,7 @@ func (s *Server) removeConn(c io.Closer) {
defer s.mu.Unlock()
if s.conns != nil {
delete(s.conns, c)
+ s.cv.Signal()
}
}
@@ -520,7 +541,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
}
p := &parser{r: stream}
for {
- pf, req, err := p.recvMsg()
+ pf, req, err := p.recvMsg(s.opts.maxMsgSize)
if err == io.EOF {
// The entire stream is done (for unary RPC only).
return err
@@ -530,6 +551,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
}
if err != nil {
switch err := err.(type) {
+ case *rpcError:
+ if err := t.WriteStatus(stream, err.code, err.desc); err != nil {
+ grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
+ }
case transport.ConnectionError:
// Nothing to do here.
case transport.StreamError:
@@ -569,6 +594,12 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
return err
}
}
+ if len(req) > s.opts.maxMsgSize {
+ // TODO: Revisit the error code. Currently keep it consistent with
+ // java implementation.
+ statusCode = codes.Internal
+ statusDesc = fmt.Sprintf("grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxMsgSize)
+ }
if err := s.opts.codec.Unmarshal(req, v); err != nil {
return err
}
@@ -628,13 +659,14 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
stream.SetSendCompress(s.opts.cp.Type())
}
ss := &serverStream{
- t: t,
- s: stream,
- p: &parser{r: stream},
- codec: s.opts.codec,
- cp: s.opts.cp,
- dc: s.opts.dc,
- trInfo: trInfo,
+ t: t,
+ s: stream,
+ p: &parser{r: stream},
+ codec: s.opts.codec,
+ cp: s.opts.cp,
+ dc: s.opts.dc,
+ maxMsgSize: s.opts.maxMsgSize,
+ trInfo: trInfo,
}
if ss.cp != nil {
ss.cbuf = new(bytes.Buffer)
@@ -766,14 +798,16 @@ func (s *Server) Stop() {
s.mu.Lock()
listeners := s.lis
s.lis = nil
- cs := s.conns
+ st := s.conns
s.conns = nil
+ // interrupt GracefulStop if Stop and GracefulStop are called concurrently.
+ s.cv.Signal()
s.mu.Unlock()
for lis := range listeners {
lis.Close()
}
- for c := range cs {
+ for c := range st {
c.Close()
}
@@ -785,6 +819,32 @@ func (s *Server) Stop() {
s.mu.Unlock()
}
+// GracefulStop stops the gRPC server gracefully. It stops the server to accept new
+// connections and RPCs and blocks until all the pending RPCs are finished.
+func (s *Server) GracefulStop() {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.drain == true || s.conns == nil {
+ return
+ }
+ s.drain = true
+ for lis := range s.lis {
+ lis.Close()
+ }
+ s.lis = nil
+ for c := range s.conns {
+ c.(transport.ServerTransport).Drain()
+ }
+ for len(s.conns) != 0 {
+ s.cv.Wait()
+ }
+ s.conns = nil
+ if s.events != nil {
+ s.events.Finish()
+ s.events = nil
+ }
+}
+
func init() {
internal.TestingCloseConns = func(arg interface{}) {
arg.(*Server).testingCloseConns()