path: root/vendor/google.golang.org/grpc
Diffstat (limited to 'vendor/google.golang.org/grpc')
-rw-r--r--  vendor/google.golang.org/grpc/README.md                                |   2
-rw-r--r--  vendor/google.golang.org/grpc/call.go                                  |  29
-rw-r--r--  vendor/google.golang.org/grpc/clientconn.go                            | 359
-rw-r--r--  vendor/google.golang.org/grpc/credentials/credentials.go               |  60
-rw-r--r--  vendor/google.golang.org/grpc/credentials/credentials_util_go17.go     |  76
-rw-r--r--  vendor/google.golang.org/grpc/credentials/credentials_util_pre_go17.go |  74
-rw-r--r--  vendor/google.golang.org/grpc/metadata/metadata.go                     |  14
-rw-r--r--  vendor/google.golang.org/grpc/rpc_util.go                              |  18
-rw-r--r--  vendor/google.golang.org/grpc/server.go                                | 100
-rw-r--r--  vendor/google.golang.org/grpc/stream.go                                | 115
-rw-r--r--  vendor/google.golang.org/grpc/transport/control.go                     |   5
-rw-r--r--  vendor/google.golang.org/grpc/transport/go16.go                        |  46
-rw-r--r--  vendor/google.golang.org/grpc/transport/go17.go                        |  46
-rw-r--r--  vendor/google.golang.org/grpc/transport/handler_server.go              |  10
-rw-r--r--  vendor/google.golang.org/grpc/transport/http2_client.go                | 237
-rw-r--r--  vendor/google.golang.org/grpc/transport/http2_server.go                |  76
-rw-r--r--  vendor/google.golang.org/grpc/transport/http_util.go                   |  85
-rw-r--r--  vendor/google.golang.org/grpc/transport/pre_go16.go                    |  51
-rw-r--r--  vendor/google.golang.org/grpc/transport/transport.go                   |  93
19 files changed, 1156 insertions, 340 deletions
diff --git a/vendor/google.golang.org/grpc/README.md b/vendor/google.golang.org/grpc/README.md
index 90e9453..660658b 100644
--- a/vendor/google.golang.org/grpc/README.md
+++ b/vendor/google.golang.org/grpc/README.md
@@ -28,5 +28,5 @@ See [API documentation](https://godoc.org/google.golang.org/grpc) for package an
Status
------
-Beta release
+GA
diff --git a/vendor/google.golang.org/grpc/call.go b/vendor/google.golang.org/grpc/call.go
index 84ac178..fea0799 100644
--- a/vendor/google.golang.org/grpc/call.go
+++ b/vendor/google.golang.org/grpc/call.go
@@ -36,6 +36,7 @@ package grpc
import (
"bytes"
"io"
+ "math"
"time"
"golang.org/x/net/context"
@@ -51,13 +52,20 @@ import (
func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error {
// Try to acquire header metadata from the server if there is any.
var err error
+ defer func() {
+ if err != nil {
+ if _, ok := err.(transport.ConnectionError); !ok {
+ t.CloseStream(stream, err)
+ }
+ }
+ }()
c.headerMD, err = stream.Header()
if err != nil {
return err
}
p := &parser{r: stream}
for {
- if err = recv(p, dopts.codec, stream, dopts.dc, reply); err != nil {
+ if err = recv(p, dopts.codec, stream, dopts.dc, reply, math.MaxInt32); err != nil {
if err == io.EOF {
break
}
@@ -76,6 +84,7 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd
}
defer func() {
if err != nil {
+ // If err is connection error, t will be closed, no need to close stream here.
if _, ok := err.(transport.ConnectionError); !ok {
t.CloseStream(stream, err)
}
@@ -90,7 +99,10 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd
return nil, transport.StreamErrorf(codes.Internal, "grpc: %v", err)
}
err = t.Write(stream, outBuf, opts)
- if err != nil {
+ // t.NewStream(...) could lead to an early rejection of the RPC (e.g., the service/method
+ // does not exist.) so that t.Write could get io.EOF from wait(...). Leave the following
+ // recvResponse to get the final status.
+ if err != nil && err != io.EOF {
return nil, err
}
// Sent successfully.
@@ -158,9 +170,9 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
if _, ok := err.(*rpcError); ok {
return err
}
- if err == errConnClosing {
+ if err == errConnClosing || err == errConnUnavailable {
if c.failFast {
- return Errorf(codes.Unavailable, "%v", errConnClosing)
+ return Errorf(codes.Unavailable, "%v", err)
}
continue
}
@@ -176,7 +188,10 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
put()
put = nil
}
- if _, ok := err.(transport.ConnectionError); ok {
+ // Retry a non-failfast RPC when
+ // i) there is a connection error; or
+ // ii) the server started to drain before this RPC was initiated.
+ if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
if c.failFast {
return toRPCErr(err)
}
@@ -184,20 +199,18 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
}
return toRPCErr(err)
}
- // Receive the response
err = recvResponse(cc.dopts, t, &c, stream, reply)
if err != nil {
if put != nil {
put()
put = nil
}
- if _, ok := err.(transport.ConnectionError); ok {
+ if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
if c.failFast {
return toRPCErr(err)
}
continue
}
- t.CloseStream(stream, err)
return toRPCErr(err)
}
if c.traceInfo.tr != nil {
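Note: an illustrative sketch of the retry behavior this hunk introduces, not part of the patch. With fail-fast disabled, Invoke itself retries on connection errors and on transport.ErrStreamDrain; with fail-fast enabled those conditions surface as codes.Unavailable and the caller decides. The method name and the FailFast call option are assumptions for the example.

package example

import (
	"log"
	"time"

	"golang.org/x/net/context"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
)

// invokeWithRetry wraps a fail-fast unary call and retries the Unavailable
// errors that this patch makes Invoke return for closing or unavailable
// connections. With FailFast(false), the retry loop inside Invoke applies instead.
func invokeWithRetry(ctx context.Context, cc *grpc.ClientConn, req, resp interface{}) error {
	const method = "/helloworld.Greeter/SayHello" // hypothetical method name
	var err error
	for i := 0; i < 3; i++ {
		err = grpc.Invoke(ctx, method, req, resp, cc, grpc.FailFast(true))
		if err == nil || grpc.Code(err) != codes.Unavailable {
			return err
		}
		log.Printf("attempt %d failed: %v; retrying", i+1, err)
		time.Sleep(100 * time.Millisecond)
	}
	return err
}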
diff --git a/vendor/google.golang.org/grpc/clientconn.go b/vendor/google.golang.org/grpc/clientconn.go
index c3c7691..1d3b46c 100644
--- a/vendor/google.golang.org/grpc/clientconn.go
+++ b/vendor/google.golang.org/grpc/clientconn.go
@@ -43,7 +43,6 @@ import (
"golang.org/x/net/context"
"golang.org/x/net/trace"
- "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/transport"
@@ -68,13 +67,15 @@ var (
// errCredentialsConflict indicates that grpc.WithTransportCredentials()
// and grpc.WithInsecure() are both called for a connection.
errCredentialsConflict = errors.New("grpc: transport credentials are set for an insecure connection (grpc.WithTransportCredentials() and grpc.WithInsecure() are both called)")
- // errNetworkIP indicates that the connection is down due to some network I/O error.
+ // errNetworkIO indicates that the connection is down due to some network I/O error.
errNetworkIO = errors.New("grpc: failed with network I/O error")
// errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs.
errConnDrain = errors.New("grpc: the connection is drained")
// errConnClosing indicates that the connection is closing.
errConnClosing = errors.New("grpc: the connection is closing")
- errNoAddr = errors.New("grpc: there is no address available to dial")
+ // errConnUnavailable indicates that the connection is unavailable.
+ errConnUnavailable = errors.New("grpc: the connection is unavailable")
+ errNoAddr = errors.New("grpc: there is no address available to dial")
// minimum time to give a connection to complete
minConnectTimeout = 20 * time.Second
)
@@ -196,9 +197,14 @@ func WithTimeout(d time.Duration) DialOption {
}
// WithDialer returns a DialOption that specifies a function to use for dialing network addresses.
-func WithDialer(f func(addr string, timeout time.Duration) (net.Conn, error)) DialOption {
+func WithDialer(f func(string, time.Duration) (net.Conn, error)) DialOption {
return func(o *dialOptions) {
- o.copts.Dialer = f
+ o.copts.Dialer = func(ctx context.Context, addr string) (net.Conn, error) {
+ if deadline, ok := ctx.Deadline(); ok {
+ return f(addr, deadline.Sub(time.Now()))
+ }
+ return f(addr, 0)
+ }
}
}
@@ -209,12 +215,34 @@ func WithUserAgent(s string) DialOption {
}
}
-// Dial creates a client connection the given target.
+// Dial creates a client connection to the given target.
func Dial(target string, opts ...DialOption) (*ClientConn, error) {
+ return DialContext(context.Background(), target, opts...)
+}
+
+// DialContext creates a client connection to the given target. ctx can be used to
+// cancel or expire the pending connecting. Once this function returns, the
+// cancellation and expiration of ctx will be noop. Users should call ClientConn.Close
+// to terminate all the pending operations after this function returns.
+// This is the EXPERIMENTAL API.
+func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *ClientConn, err error) {
cc := &ClientConn{
target: target,
conns: make(map[Address]*addrConn),
}
+ cc.ctx, cc.cancel = context.WithCancel(context.Background())
+ defer func() {
+ select {
+ case <-ctx.Done():
+ conn, err = nil, ctx.Err()
+ default:
+ }
+
+ if err != nil {
+ cc.Close()
+ }
+ }()
+
for _, opt := range opts {
opt(&cc.dopts)
}
@@ -226,31 +254,33 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
if cc.dopts.bs == nil {
cc.dopts.bs = DefaultBackoffConfig
}
- if cc.dopts.balancer == nil {
- cc.dopts.balancer = RoundRobin(nil)
- }
- if err := cc.dopts.balancer.Start(target); err != nil {
- return nil, err
- }
var (
ok bool
addrs []Address
)
- ch := cc.dopts.balancer.Notify()
- if ch == nil {
- // There is no name resolver installed.
+ if cc.dopts.balancer == nil {
+ // Connect to target directly if balancer is nil.
addrs = append(addrs, Address{Addr: target})
} else {
- addrs, ok = <-ch
- if !ok || len(addrs) == 0 {
- return nil, errNoAddr
+ if err := cc.dopts.balancer.Start(target); err != nil {
+ return nil, err
+ }
+ ch := cc.dopts.balancer.Notify()
+ if ch == nil {
+ // There is no name resolver installed.
+ addrs = append(addrs, Address{Addr: target})
+ } else {
+ addrs, ok = <-ch
+ if !ok || len(addrs) == 0 {
+ return nil, errNoAddr
+ }
}
}
waitC := make(chan error, 1)
go func() {
for _, a := range addrs {
- if err := cc.newAddrConn(a, false); err != nil {
+ if err := cc.resetAddrConn(a, false, nil); err != nil {
waitC <- err
return
}
@@ -262,15 +292,17 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
timeoutCh = time.After(cc.dopts.timeout)
}
select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
case err := <-waitC:
if err != nil {
- cc.Close()
return nil, err
}
case <-timeoutCh:
- cc.Close()
return nil, ErrClientConnTimeout
}
+ // If balancer is nil or balancer.Notify() is nil, ok will be false here.
+ // The lbWatcher goroutine will not be created.
if ok {
go cc.lbWatcher()
}
@@ -317,6 +349,9 @@ func (s ConnectivityState) String() string {
// ClientConn represents a client connection to an RPC server.
type ClientConn struct {
+ ctx context.Context
+ cancel context.CancelFunc
+
target string
authority string
dopts dialOptions
@@ -347,11 +382,12 @@ func (cc *ClientConn) lbWatcher() {
}
if !keep {
del = append(del, c)
+ delete(cc.conns, c.addr)
}
}
cc.mu.Unlock()
for _, a := range add {
- cc.newAddrConn(a, true)
+ cc.resetAddrConn(a, true, nil)
}
for _, c := range del {
c.tearDown(errConnDrain)
@@ -359,13 +395,17 @@ func (cc *ClientConn) lbWatcher() {
}
}
-func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
+// resetAddrConn creates an addrConn for addr and adds it to cc.conns.
+// If there is an old addrConn for addr, it will be torn down, using tearDownErr as the reason.
+// If tearDownErr is nil, errConnDrain will be used instead.
+func (cc *ClientConn) resetAddrConn(addr Address, skipWait bool, tearDownErr error) error {
ac := &addrConn{
- cc: cc,
- addr: addr,
- dopts: cc.dopts,
- shutdownChan: make(chan struct{}),
+ cc: cc,
+ addr: addr,
+ dopts: cc.dopts,
}
+ ac.ctx, ac.cancel = context.WithCancel(cc.ctx)
+ ac.stateCV = sync.NewCond(&ac.mu)
if EnableTracing {
ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr)
}
@@ -383,26 +423,44 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
}
}
}
- // Insert ac into ac.cc.conns. This needs to be done before any getTransport(...) is called.
- ac.cc.mu.Lock()
- if ac.cc.conns == nil {
- ac.cc.mu.Unlock()
+ // Track ac in cc. This needs to be done before any getTransport(...) is called.
+ cc.mu.Lock()
+ if cc.conns == nil {
+ cc.mu.Unlock()
return ErrClientConnClosing
}
- stale := ac.cc.conns[ac.addr]
- ac.cc.conns[ac.addr] = ac
- ac.cc.mu.Unlock()
+ stale := cc.conns[ac.addr]
+ cc.conns[ac.addr] = ac
+ cc.mu.Unlock()
if stale != nil {
// There is an addrConn alive on ac.addr already. This could be due to
- // i) stale's Close is undergoing;
- // ii) a buggy Balancer notifies duplicated Addresses.
- stale.tearDown(errConnDrain)
+ // 1) a buggy Balancer notifies duplicated Addresses;
+ // 2) goaway was received, a new ac will replace the old ac.
+ // The old ac should be deleted from cc.conns, but the
+ // underlying transport should drain rather than close.
+ if tearDownErr == nil {
+ // tearDownErr is nil if resetAddrConn is called by
+ // 1) Dial
+ // 2) lbWatcher
+ // In both cases, the stale ac should drain, not close.
+ stale.tearDown(errConnDrain)
+ } else {
+ stale.tearDown(tearDownErr)
+ }
}
- ac.stateCV = sync.NewCond(&ac.mu)
// skipWait may overwrite the decision in ac.dopts.block.
if ac.dopts.block && !skipWait {
if err := ac.resetTransport(false); err != nil {
- ac.tearDown(err)
+ if err != errConnClosing {
+ // Tear down ac and delete it from cc.conns.
+ cc.mu.Lock()
+ delete(cc.conns, ac.addr)
+ cc.mu.Unlock()
+ ac.tearDown(err)
+ }
+ if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() {
+ return e.Origin()
+ }
return err
}
// Start to monitor the error status of transport.
@@ -412,7 +470,10 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
go func() {
if err := ac.resetTransport(false); err != nil {
grpclog.Printf("Failed to dial %s: %v; please retry.", ac.addr.Addr, err)
- ac.tearDown(err)
+ if err != errConnClosing {
+ // Keep this ac in cc.conns, to get the reason it's torn down.
+ ac.tearDown(err)
+ }
return
}
ac.transportMonitor()
@@ -422,24 +483,48 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
}
func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) (transport.ClientTransport, func(), error) {
- addr, put, err := cc.dopts.balancer.Get(ctx, opts)
- if err != nil {
- return nil, nil, toRPCErr(err)
- }
- cc.mu.RLock()
- if cc.conns == nil {
+ var (
+ ac *addrConn
+ ok bool
+ put func()
+ )
+ if cc.dopts.balancer == nil {
+ // If balancer is nil, there should be only one addrConn available.
+ cc.mu.RLock()
+ if cc.conns == nil {
+ cc.mu.RUnlock()
+ return nil, nil, toRPCErr(ErrClientConnClosing)
+ }
+ for _, ac = range cc.conns {
+ // Break after the first iteration to get the first addrConn.
+ ok = true
+ break
+ }
+ cc.mu.RUnlock()
+ } else {
+ var (
+ addr Address
+ err error
+ )
+ addr, put, err = cc.dopts.balancer.Get(ctx, opts)
+ if err != nil {
+ return nil, nil, toRPCErr(err)
+ }
+ cc.mu.RLock()
+ if cc.conns == nil {
+ cc.mu.RUnlock()
+ return nil, nil, toRPCErr(ErrClientConnClosing)
+ }
+ ac, ok = cc.conns[addr]
cc.mu.RUnlock()
- return nil, nil, toRPCErr(ErrClientConnClosing)
}
- ac, ok := cc.conns[addr]
- cc.mu.RUnlock()
if !ok {
if put != nil {
put()
}
- return nil, nil, Errorf(codes.Internal, "grpc: failed to find the transport to send the rpc")
+ return nil, nil, errConnClosing
}
- t, err := ac.wait(ctx, !opts.BlockingWait)
+ t, err := ac.wait(ctx, cc.dopts.balancer != nil, !opts.BlockingWait)
if err != nil {
if put != nil {
put()
@@ -451,6 +536,8 @@ func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions)
// Close tears down the ClientConn and all underlying connections.
func (cc *ClientConn) Close() error {
+ cc.cancel()
+
cc.mu.Lock()
if cc.conns == nil {
cc.mu.Unlock()
@@ -459,7 +546,9 @@ func (cc *ClientConn) Close() error {
conns := cc.conns
cc.conns = nil
cc.mu.Unlock()
- cc.dopts.balancer.Close()
+ if cc.dopts.balancer != nil {
+ cc.dopts.balancer.Close()
+ }
for _, ac := range conns {
ac.tearDown(ErrClientConnClosing)
}
@@ -468,11 +557,13 @@ func (cc *ClientConn) Close() error {
// addrConn is a network connection to a given address.
type addrConn struct {
- cc *ClientConn
- addr Address
- dopts dialOptions
- shutdownChan chan struct{}
- events trace.EventLog
+ ctx context.Context
+ cancel context.CancelFunc
+
+ cc *ClientConn
+ addr Address
+ dopts dialOptions
+ events trace.EventLog
mu sync.Mutex
state ConnectivityState
@@ -482,6 +573,9 @@ type addrConn struct {
// due to timeout.
ready chan struct{}
transport transport.ClientTransport
+
+ // The reason this addrConn is torn down.
+ tearDownErr error
}
// printf records an event in ac's event log, unless ac has been closed.
@@ -537,8 +631,7 @@ func (ac *addrConn) waitForStateChange(ctx context.Context, sourceState Connecti
}
func (ac *addrConn) resetTransport(closeTransport bool) error {
- var retries int
- for {
+ for retries := 0; ; retries++ {
ac.mu.Lock()
ac.printf("connecting")
if ac.state == Shutdown {
@@ -558,13 +651,20 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
t.Close()
}
sleepTime := ac.dopts.bs.backoff(retries)
- ac.dopts.copts.Timeout = sleepTime
- if sleepTime < minConnectTimeout {
- ac.dopts.copts.Timeout = minConnectTimeout
+ timeout := minConnectTimeout
+ if timeout < sleepTime {
+ timeout = sleepTime
}
+ ctx, cancel := context.WithTimeout(ac.ctx, timeout)
connectTime := time.Now()
- newTransport, err := transport.NewClientTransport(ac.addr.Addr, &ac.dopts.copts)
+ newTransport, err := transport.NewClientTransport(ctx, ac.addr.Addr, ac.dopts.copts)
if err != nil {
+ cancel()
+
+ if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() {
+ return err
+ }
+ grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, ac.addr)
ac.mu.Lock()
if ac.state == Shutdown {
// ac.tearDown(...) has been invoked.
@@ -579,17 +679,12 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
ac.ready = nil
}
ac.mu.Unlock()
- sleepTime -= time.Since(connectTime)
- if sleepTime < 0 {
- sleepTime = 0
- }
closeTransport = false
select {
- case <-time.After(sleepTime):
- case <-ac.shutdownChan:
+ case <-time.After(sleepTime - time.Since(connectTime)):
+ case <-ac.ctx.Done():
+ return ac.ctx.Err()
}
- retries++
- grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, ac.addr)
continue
}
ac.mu.Lock()
@@ -607,7 +702,9 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
close(ac.ready)
ac.ready = nil
}
- ac.down = ac.cc.dopts.balancer.Up(ac.addr)
+ if ac.cc.dopts.balancer != nil {
+ ac.down = ac.cc.dopts.balancer.Up(ac.addr)
+ }
ac.mu.Unlock()
return nil
}
@@ -621,14 +718,42 @@ func (ac *addrConn) transportMonitor() {
t := ac.transport
ac.mu.Unlock()
select {
- // shutdownChan is needed to detect the teardown when
+ // This is needed to detect the teardown when
// the addrConn is idle (i.e., no RPC in flight).
- case <-ac.shutdownChan:
+ case <-ac.ctx.Done():
+ select {
+ case <-t.Error():
+ t.Close()
+ default:
+ }
+ return
+ case <-t.GoAway():
+ // If GoAway happens without any network I/O error, ac is closed without shutting down the
+ // underlying transport (the transport will be closed when all the pending RPCs finished or
+ // failed.).
+ // If GoAway and some network I/O error happen concurrently, ac and its underlying transport
+ // are closed.
+ // In both cases, a new ac is created.
+ select {
+ case <-t.Error():
+ ac.cc.resetAddrConn(ac.addr, true, errNetworkIO)
+ default:
+ ac.cc.resetAddrConn(ac.addr, true, errConnDrain)
+ }
return
case <-t.Error():
+ select {
+ case <-ac.ctx.Done():
+ t.Close()
+ return
+ case <-t.GoAway():
+ ac.cc.resetAddrConn(ac.addr, true, errNetworkIO)
+ return
+ default:
+ }
ac.mu.Lock()
if ac.state == Shutdown {
- // ac.tearDown(...) has been invoked.
+ // ac has been shutdown.
ac.mu.Unlock()
return
}
@@ -640,6 +765,10 @@ func (ac *addrConn) transportMonitor() {
ac.printf("transport exiting: %v", err)
ac.mu.Unlock()
grpclog.Printf("grpc: addrConn.transportMonitor exits due to: %v", err)
+ if err != errConnClosing {
+ // Keep this ac in cc.conns, to get the reason it's torn down.
+ ac.tearDown(err)
+ }
return
}
}
@@ -647,35 +776,42 @@ func (ac *addrConn) transportMonitor() {
}
// wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed or
-// iv) transport is in TransientFailure and the RPC is fail-fast.
-func (ac *addrConn) wait(ctx context.Context, failFast bool) (transport.ClientTransport, error) {
+// iv) transport is in TransientFailure and there's no balancer/failfast is true.
+func (ac *addrConn) wait(ctx context.Context, hasBalancer, failfast bool) (transport.ClientTransport, error) {
for {
ac.mu.Lock()
switch {
case ac.state == Shutdown:
+ if failfast || !hasBalancer {
+ // RPC is failfast or balancer is nil. This RPC should fail with ac.tearDownErr.
+ err := ac.tearDownErr
+ ac.mu.Unlock()
+ return nil, err
+ }
ac.mu.Unlock()
return nil, errConnClosing
case ac.state == Ready:
ct := ac.transport
ac.mu.Unlock()
return ct, nil
- case ac.state == TransientFailure && failFast:
- ac.mu.Unlock()
- return nil, Errorf(codes.Unavailable, "grpc: RPC failed fast due to transport failure")
- default:
- ready := ac.ready
- if ready == nil {
- ready = make(chan struct{})
- ac.ready = ready
- }
- ac.mu.Unlock()
- select {
- case <-ctx.Done():
- return nil, toRPCErr(ctx.Err())
- // Wait until the new transport is ready or failed.
- case <-ready:
+ case ac.state == TransientFailure:
+ if failfast || hasBalancer {
+ ac.mu.Unlock()
+ return nil, errConnUnavailable
}
}
+ ready := ac.ready
+ if ready == nil {
+ ready = make(chan struct{})
+ ac.ready = ready
+ }
+ ac.mu.Unlock()
+ select {
+ case <-ctx.Done():
+ return nil, toRPCErr(ctx.Err())
+ // Wait until the new transport is ready or failed.
+ case <-ready:
+ }
}
}
@@ -683,24 +819,28 @@ func (ac *addrConn) wait(ctx context.Context, failFast bool) (transport.ClientTr
// TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in
// some edge cases (e.g., the caller opens and closes many addrConn's in a
// tight loop.
+// tearDown doesn't remove ac from ac.cc.conns.
func (ac *addrConn) tearDown(err error) {
+ ac.cancel()
+
ac.mu.Lock()
- defer func() {
- ac.mu.Unlock()
- ac.cc.mu.Lock()
- if ac.cc.conns != nil {
- delete(ac.cc.conns, ac.addr)
- }
- ac.cc.mu.Unlock()
- }()
- if ac.state == Shutdown {
- return
- }
- ac.state = Shutdown
+ defer ac.mu.Unlock()
if ac.down != nil {
ac.down(downErrorf(false, false, "%v", err))
ac.down = nil
}
+ if err == errConnDrain && ac.transport != nil {
+ // GracefulClose(...) may be executed multiple times when
+ // i) receiving multiple GoAway frames from the server; or
+ // ii) there are concurrent name resolver/Balancer triggered
+ // address removal and GoAway.
+ ac.transport.GracefulClose()
+ }
+ if ac.state == Shutdown {
+ return
+ }
+ ac.state = Shutdown
+ ac.tearDownErr = err
ac.stateCV.Broadcast()
if ac.events != nil {
ac.events.Finish()
@@ -710,15 +850,8 @@ func (ac *addrConn) tearDown(err error) {
close(ac.ready)
ac.ready = nil
}
- if ac.transport != nil {
- if err == errConnDrain {
- ac.transport.GracefulClose()
- } else {
- ac.transport.Close()
- }
- }
- if ac.shutdownChan != nil {
- close(ac.shutdownChan)
+ if ac.transport != nil && err != errConnDrain {
+ ac.transport.Close()
}
return
}
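Note: a minimal usage sketch for the DialContext entry point added above, not part of the patch; the target address is hypothetical.

package example

import (
	"time"

	"golang.org/x/net/context"
	"google.golang.org/grpc"
)

// dialWithDeadline bounds only the initial connection attempt with ctx; as the
// doc comment on DialContext notes, once it returns, cancelling ctx has no
// further effect and the caller must use conn.Close to release resources.
func dialWithDeadline() (*grpc.ClientConn, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	return grpc.DialContext(ctx, "localhost:50051", grpc.WithInsecure(), grpc.WithBlock())
}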
diff --git a/vendor/google.golang.org/grpc/credentials/credentials.go b/vendor/google.golang.org/grpc/credentials/credentials.go
index 8d4c57c..13be457 100644
--- a/vendor/google.golang.org/grpc/credentials/credentials.go
+++ b/vendor/google.golang.org/grpc/credentials/credentials.go
@@ -40,11 +40,11 @@ package credentials // import "google.golang.org/grpc/credentials"
import (
"crypto/tls"
"crypto/x509"
+ "errors"
"fmt"
"io/ioutil"
"net"
"strings"
- "time"
"golang.org/x/net/context"
)
@@ -87,17 +87,24 @@ type AuthInfo interface {
AuthType() string
}
+var (
+ // ErrConnDispatched indicates that rawConn has been dispatched out of gRPC
+ // and the caller should not close rawConn.
+ ErrConnDispatched = errors.New("credentials: rawConn is dispatched out of gRPC")
+)
+
// TransportCredentials defines the common interface for all the live gRPC wire
// protocols and supported transport security protocols (e.g., TLS, SSL).
type TransportCredentials interface {
// ClientHandshake does the authentication handshake specified by the corresponding
// authentication protocol on rawConn for clients. It returns the authenticated
// connection and the corresponding auth information about the connection.
- ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, AuthInfo, error)
+ // Implementations must use the provided context to implement timely cancellation.
+ ClientHandshake(context.Context, string, net.Conn) (net.Conn, AuthInfo, error)
// ServerHandshake does the authentication handshake for servers. It returns
// the authenticated connection and the corresponding auth information about
// the connection.
- ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
+ ServerHandshake(net.Conn) (net.Conn, AuthInfo, error)
// Info provides the ProtocolInfo of this TransportCredentials.
Info() ProtocolInfo
}
@@ -136,42 +143,28 @@ func (c *tlsCreds) RequireTransportSecurity() bool {
return true
}
-type timeoutError struct{}
-
-func (timeoutError) Error() string { return "credentials: Dial timed out" }
-func (timeoutError) Timeout() bool { return true }
-func (timeoutError) Temporary() bool { return true }
-
-func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, _ AuthInfo, err error) {
- // borrow some code from tls.DialWithDialer
- var errChannel chan error
- if timeout != 0 {
- errChannel = make(chan error, 2)
- time.AfterFunc(timeout, func() {
- errChannel <- timeoutError{}
- })
- }
+func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ AuthInfo, err error) {
// use local cfg to avoid clobbering ServerName if using multiple endpoints
- cfg := *c.config
- if c.config.ServerName == "" {
+ cfg := cloneTLSConfig(c.config)
+ if cfg.ServerName == "" {
colonPos := strings.LastIndex(addr, ":")
if colonPos == -1 {
colonPos = len(addr)
}
cfg.ServerName = addr[:colonPos]
}
- conn := tls.Client(rawConn, &cfg)
- if timeout == 0 {
- err = conn.Handshake()
- } else {
- go func() {
- errChannel <- conn.Handshake()
- }()
- err = <-errChannel
- }
- if err != nil {
- rawConn.Close()
- return nil, nil, err
+ conn := tls.Client(rawConn, cfg)
+ errChannel := make(chan error, 1)
+ go func() {
+ errChannel <- conn.Handshake()
+ }()
+ select {
+ case err := <-errChannel:
+ if err != nil {
+ return nil, nil, err
+ }
+ case <-ctx.Done():
+ return nil, nil, ctx.Err()
}
// TODO(zhaoq): Omit the auth info for client now. It is more for
// information than anything else.
@@ -181,7 +174,6 @@ func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.D
func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) {
conn := tls.Server(rawConn, c.config)
if err := conn.Handshake(); err != nil {
- rawConn.Close()
return nil, nil, err
}
return conn, TLSInfo{conn.ConnectionState()}, nil
@@ -189,7 +181,7 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
// NewTLS uses c to construct a TransportCredentials based on TLS.
func NewTLS(c *tls.Config) TransportCredentials {
- tc := &tlsCreds{c}
+ tc := &tlsCreds{cloneTLSConfig(c)}
tc.config.NextProtos = alpnProtoStr
return tc
}
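Note: the context-aware ClientHandshake above follows a pattern worth showing in isolation; this is a sketch of that pattern, not the library code itself.

package example

import (
	"crypto/tls"
	"net"

	"golang.org/x/net/context"
)

// handshakeWithContext runs the blocking TLS handshake in a goroutine and
// races it against ctx, so a dial deadline or cancellation cuts it short.
// On ctx.Done the caller is expected to close rawConn (as newHTTP2Client does),
// which unblocks the handshake goroutine.
func handshakeWithContext(ctx context.Context, rawConn net.Conn, cfg *tls.Config) (net.Conn, error) {
	conn := tls.Client(rawConn, cfg)
	errCh := make(chan error, 1)
	go func() { errCh <- conn.Handshake() }()
	select {
	case err := <-errCh:
		if err != nil {
			return nil, err
		}
		return conn, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}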
diff --git a/vendor/google.golang.org/grpc/credentials/credentials_util_go17.go b/vendor/google.golang.org/grpc/credentials/credentials_util_go17.go
new file mode 100644
index 0000000..9647b9e
--- /dev/null
+++ b/vendor/google.golang.org/grpc/credentials/credentials_util_go17.go
@@ -0,0 +1,76 @@
+// +build go1.7
+
+/*
+ *
+ * Copyright 2016, Google Inc.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following disclaimer
+ * in the documentation and/or other materials provided with the
+ * distribution.
+ * * Neither the name of Google Inc. nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ */
+
+package credentials
+
+import (
+ "crypto/tls"
+)
+
+// cloneTLSConfig returns a shallow clone of the exported
+// fields of cfg, ignoring the unexported sync.Once, which
+// contains a mutex and must not be copied.
+//
+// If cfg is nil, a new zero tls.Config is returned.
+//
+// TODO replace this function with official clone function.
+func cloneTLSConfig(cfg *tls.Config) *tls.Config {
+ if cfg == nil {
+ return &tls.Config{}
+ }
+ return &tls.Config{
+ Rand: cfg.Rand,
+ Time: cfg.Time,
+ Certificates: cfg.Certificates,
+ NameToCertificate: cfg.NameToCertificate,
+ GetCertificate: cfg.GetCertificate,
+ RootCAs: cfg.RootCAs,
+ NextProtos: cfg.NextProtos,
+ ServerName: cfg.ServerName,
+ ClientAuth: cfg.ClientAuth,
+ ClientCAs: cfg.ClientCAs,
+ InsecureSkipVerify: cfg.InsecureSkipVerify,
+ CipherSuites: cfg.CipherSuites,
+ PreferServerCipherSuites: cfg.PreferServerCipherSuites,
+ SessionTicketsDisabled: cfg.SessionTicketsDisabled,
+ SessionTicketKey: cfg.SessionTicketKey,
+ ClientSessionCache: cfg.ClientSessionCache,
+ MinVersion: cfg.MinVersion,
+ MaxVersion: cfg.MaxVersion,
+ CurvePreferences: cfg.CurvePreferences,
+ DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled,
+ Renegotiation: cfg.Renegotiation,
+ }
+}
diff --git a/vendor/google.golang.org/grpc/credentials/credentials_util_pre_go17.go b/vendor/google.golang.org/grpc/credentials/credentials_util_pre_go17.go
new file mode 100644
index 0000000..09b8d12
--- /dev/null
+++ b/vendor/google.golang.org/grpc/credentials/credentials_util_pre_go17.go
@@ -0,0 +1,74 @@
+// +build !go1.7
+
+/*
+ *
+ * Copyright 2016, Google Inc.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following disclaimer
+ * in the documentation and/or other materials provided with the
+ * distribution.
+ * * Neither the name of Google Inc. nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ */
+
+package credentials
+
+import (
+ "crypto/tls"
+)
+
+// cloneTLSConfig returns a shallow clone of the exported
+// fields of cfg, ignoring the unexported sync.Once, which
+// contains a mutex and must not be copied.
+//
+// If cfg is nil, a new zero tls.Config is returned.
+//
+// TODO replace this function with official clone function.
+func cloneTLSConfig(cfg *tls.Config) *tls.Config {
+ if cfg == nil {
+ return &tls.Config{}
+ }
+ return &tls.Config{
+ Rand: cfg.Rand,
+ Time: cfg.Time,
+ Certificates: cfg.Certificates,
+ NameToCertificate: cfg.NameToCertificate,
+ GetCertificate: cfg.GetCertificate,
+ RootCAs: cfg.RootCAs,
+ NextProtos: cfg.NextProtos,
+ ServerName: cfg.ServerName,
+ ClientAuth: cfg.ClientAuth,
+ ClientCAs: cfg.ClientCAs,
+ InsecureSkipVerify: cfg.InsecureSkipVerify,
+ CipherSuites: cfg.CipherSuites,
+ PreferServerCipherSuites: cfg.PreferServerCipherSuites,
+ SessionTicketsDisabled: cfg.SessionTicketsDisabled,
+ SessionTicketKey: cfg.SessionTicketKey,
+ ClientSessionCache: cfg.ClientSessionCache,
+ MinVersion: cfg.MinVersion,
+ MaxVersion: cfg.MaxVersion,
+ CurvePreferences: cfg.CurvePreferences,
+ }
+}
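Note: a small sketch of why the clone matters, not part of the patch: NewTLS now copies the config before appending the ALPN protocols, so the caller's tls.Config is left untouched. The server name is hypothetical.

package example

import (
	"crypto/tls"

	"google.golang.org/grpc/credentials"
)

// newCredsLeavesConfigAlone: after this patch, NewTLS mutates only its private
// copy of the config, so cfg can still be reused elsewhere (e.g. for a plain
// HTTPS client) without inheriting gRPC's NextProtos entry.
func newCredsLeavesConfigAlone() (credentials.TransportCredentials, *tls.Config) {
	cfg := &tls.Config{ServerName: "example.com"} // hypothetical server name
	creds := credentials.NewTLS(cfg)
	// cfg.NextProtos is still nil here; only the internal copy carries the ALPN entry.
	return creds, cfg
}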
diff --git a/vendor/google.golang.org/grpc/metadata/metadata.go b/vendor/google.golang.org/grpc/metadata/metadata.go
index 52070db..954c0f7 100644
--- a/vendor/google.golang.org/grpc/metadata/metadata.go
+++ b/vendor/google.golang.org/grpc/metadata/metadata.go
@@ -60,15 +60,21 @@ func encodeKeyValue(k, v string) (string, string) {
// DecodeKeyValue returns the original key and value corresponding to the
// encoded data in k, v.
+// If k is a binary header and v contains comma, v is split on comma before decoded,
+// and the decoded v will be joined with comma before returned.
func DecodeKeyValue(k, v string) (string, string, error) {
if !strings.HasSuffix(k, binHdrSuffix) {
return k, v, nil
}
- val, err := base64.StdEncoding.DecodeString(v)
- if err != nil {
- return "", "", err
+ vvs := strings.Split(v, ",")
+ for i, vv := range vvs {
+ val, err := base64.StdEncoding.DecodeString(vv)
+ if err != nil {
+ return "", "", err
+ }
+ vvs[i] = string(val)
}
- return k, string(val), nil
+ return k, strings.Join(vvs, ","), nil
}
// MD is a mapping from metadata keys to values. Users should use the following
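Note: an illustrative round-trip for the new comma handling in DecodeKeyValue, not part of the patch.

package example

import (
	"encoding/base64"
	"fmt"

	"google.golang.org/grpc/metadata"
)

// decodeJoinedBinHeader decodes a "-bin" header whose value carries two
// base64 segments joined by a comma, as HTTP/2 header joining can produce;
// before this patch the combined value failed to base64-decode.
func decodeJoinedBinHeader() {
	v := base64.StdEncoding.EncodeToString([]byte("v1")) + "," +
		base64.StdEncoding.EncodeToString([]byte("v2"))
	k, dv, err := metadata.DecodeKeyValue("trace-bin", v)
	fmt.Println(k, dv, err) // trace-bin v1,v2 <nil>
}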
diff --git a/vendor/google.golang.org/grpc/rpc_util.go b/vendor/google.golang.org/grpc/rpc_util.go
index d628717..35ac9cc 100644
--- a/vendor/google.golang.org/grpc/rpc_util.go
+++ b/vendor/google.golang.org/grpc/rpc_util.go
@@ -227,7 +227,7 @@ type parser struct {
// No other error values or types must be returned, which also means
// that the underlying io.Reader must not return an incompatible
// error.
-func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) {
+func (p *parser) recvMsg(maxMsgSize int) (pf payloadFormat, msg []byte, err error) {
if _, err := io.ReadFull(p.r, p.header[:]); err != nil {
return 0, nil, err
}
@@ -238,6 +238,9 @@ func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) {
if length == 0 {
return pf, nil, nil
}
+ if length > uint32(maxMsgSize) {
+ return 0, nil, Errorf(codes.Internal, "grpc: received message length %d exceeding the max size %d", length, maxMsgSize)
+ }
// TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead
// of making it for each message:
msg = make([]byte, int(length))
@@ -308,8 +311,8 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er
return nil
}
-func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}) error {
- pf, d, err := p.recvMsg()
+func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxMsgSize int) error {
+ pf, d, err := p.recvMsg(maxMsgSize)
if err != nil {
return err
}
@@ -319,11 +322,16 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{
if pf == compressionMade {
d, err = dc.Do(bytes.NewReader(d))
if err != nil {
- return transport.StreamErrorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
+ return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
}
}
+ if len(d) > maxMsgSize {
+ // TODO: Revisit the error code. Currently keep it consistent with java
+ // implementation.
+ return Errorf(codes.Internal, "grpc: received a message of %d bytes exceeding %d limit", len(d), maxMsgSize)
+ }
if err := c.Unmarshal(d, m); err != nil {
- return transport.StreamErrorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err)
+ return Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err)
}
return nil
}
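Note: for reference, a sketch of the length-prefixed framing that recvMsg parses and that the new maxMsgSize argument guards; an assumed helper, not the actual parser.

package example

import (
	"encoding/binary"
	"fmt"
	"io"
)

// readFrame parses one gRPC message frame: a 1-byte compression flag followed
// by a 4-byte big-endian length, then the payload. Like the patched recvMsg,
// it rejects oversized frames before allocating the payload buffer.
func readFrame(r io.Reader, maxMsgSize int) (compressed bool, msg []byte, err error) {
	var hdr [5]byte
	if _, err := io.ReadFull(r, hdr[:]); err != nil {
		return false, nil, err
	}
	length := binary.BigEndian.Uint32(hdr[1:])
	if length > uint32(maxMsgSize) {
		return false, nil, fmt.Errorf("received message length %d exceeding the max size %d", length, maxMsgSize)
	}
	msg = make([]byte, length)
	if _, err := io.ReadFull(r, msg); err != nil {
		return false, nil, err
	}
	return hdr[0] == 1, msg, nil
}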
diff --git a/vendor/google.golang.org/grpc/server.go b/vendor/google.golang.org/grpc/server.go
index a2b2b94..b2a825a 100644
--- a/vendor/google.golang.org/grpc/server.go
+++ b/vendor/google.golang.org/grpc/server.go
@@ -89,9 +89,13 @@ type service struct {
type Server struct {
opts options
- mu sync.Mutex // guards following
- lis map[net.Listener]bool
- conns map[io.Closer]bool
+ mu sync.Mutex // guards following
+ lis map[net.Listener]bool
+ conns map[io.Closer]bool
+ drain bool
+ // A CondVar to let GracefulStop() blocks until all the pending RPCs are finished
+ // and all the transport goes away.
+ cv *sync.Cond
m map[string]*service // service name -> service info
events trace.EventLog
}
@@ -101,12 +105,15 @@ type options struct {
codec Codec
cp Compressor
dc Decompressor
+ maxMsgSize int
unaryInt UnaryServerInterceptor
streamInt StreamServerInterceptor
maxConcurrentStreams uint32
useHandlerImpl bool // use http.Handler-based server
}
+var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size limit
+
// A ServerOption sets options.
type ServerOption func(*options)
@@ -117,20 +124,28 @@ func CustomCodec(codec Codec) ServerOption {
}
}
-// RPCCompressor returns a ServerOption that sets a compressor for outbound message.
+// RPCCompressor returns a ServerOption that sets a compressor for outbound messages.
func RPCCompressor(cp Compressor) ServerOption {
return func(o *options) {
o.cp = cp
}
}
-// RPCDecompressor returns a ServerOption that sets a decompressor for inbound message.
+// RPCDecompressor returns a ServerOption that sets a decompressor for inbound messages.
func RPCDecompressor(dc Decompressor) ServerOption {
return func(o *options) {
o.dc = dc
}
}
+// MaxMsgSize returns a ServerOption to set the max message size in bytes for inbound mesages.
+// If this is not set, gRPC uses the default 4MB.
+func MaxMsgSize(m int) ServerOption {
+ return func(o *options) {
+ o.maxMsgSize = m
+ }
+}
+
// MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
// of concurrent streams to each ServerTransport.
func MaxConcurrentStreams(n uint32) ServerOption {
@@ -173,6 +188,7 @@ func StreamInterceptor(i StreamServerInterceptor) ServerOption {
// started to accept requests yet.
func NewServer(opt ...ServerOption) *Server {
var opts options
+ opts.maxMsgSize = defaultMaxMsgSize
for _, o := range opt {
o(&opts)
}
@@ -186,6 +202,7 @@ func NewServer(opt ...ServerOption) *Server {
conns: make(map[io.Closer]bool),
m: make(map[string]*service),
}
+ s.cv = sync.NewCond(&s.mu)
if EnableTracing {
_, file, line, _ := runtime.Caller(1)
s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line))
@@ -264,8 +281,8 @@ type ServiceInfo struct {
// GetServiceInfo returns a map from service names to ServiceInfo.
// Service names include the package names, in the form of <package>.<service>.
-func (s *Server) GetServiceInfo() map[string]*ServiceInfo {
- ret := make(map[string]*ServiceInfo)
+func (s *Server) GetServiceInfo() map[string]ServiceInfo {
+ ret := make(map[string]ServiceInfo)
for n, srv := range s.m {
methods := make([]MethodInfo, 0, len(srv.md)+len(srv.sd))
for m := range srv.md {
@@ -283,7 +300,7 @@ func (s *Server) GetServiceInfo() map[string]*ServiceInfo {
})
}
- ret[n] = &ServiceInfo{
+ ret[n] = ServiceInfo{
Methods: methods,
Metadata: srv.mdata,
}
@@ -350,7 +367,10 @@ func (s *Server) handleRawConn(rawConn net.Conn) {
s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err)
s.mu.Unlock()
grpclog.Printf("grpc: Server.Serve failed to complete security handshake from %q: %v", rawConn.RemoteAddr(), err)
- rawConn.Close()
+ // If serverHandShake returns ErrConnDispatched, keep rawConn open.
+ if err != credentials.ErrConnDispatched {
+ rawConn.Close()
+ }
return
}
@@ -468,7 +488,7 @@ func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Strea
func (s *Server) addConn(c io.Closer) bool {
s.mu.Lock()
defer s.mu.Unlock()
- if s.conns == nil {
+ if s.conns == nil || s.drain {
return false
}
s.conns[c] = true
@@ -480,6 +500,7 @@ func (s *Server) removeConn(c io.Closer) {
defer s.mu.Unlock()
if s.conns != nil {
delete(s.conns, c)
+ s.cv.Signal()
}
}
@@ -520,7 +541,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
}
p := &parser{r: stream}
for {
- pf, req, err := p.recvMsg()
+ pf, req, err := p.recvMsg(s.opts.maxMsgSize)
if err == io.EOF {
// The entire stream is done (for unary RPC only).
return err
@@ -530,6 +551,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
}
if err != nil {
switch err := err.(type) {
+ case *rpcError:
+ if err := t.WriteStatus(stream, err.code, err.desc); err != nil {
+ grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
+ }
case transport.ConnectionError:
// Nothing to do here.
case transport.StreamError:
@@ -569,6 +594,12 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
return err
}
}
+ if len(req) > s.opts.maxMsgSize {
+ // TODO: Revisit the error code. Currently keep it consistent with
+ // java implementation.
+ statusCode = codes.Internal
+ statusDesc = fmt.Sprintf("grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxMsgSize)
+ }
if err := s.opts.codec.Unmarshal(req, v); err != nil {
return err
}
@@ -628,13 +659,14 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
stream.SetSendCompress(s.opts.cp.Type())
}
ss := &serverStream{
- t: t,
- s: stream,
- p: &parser{r: stream},
- codec: s.opts.codec,
- cp: s.opts.cp,
- dc: s.opts.dc,
- trInfo: trInfo,
+ t: t,
+ s: stream,
+ p: &parser{r: stream},
+ codec: s.opts.codec,
+ cp: s.opts.cp,
+ dc: s.opts.dc,
+ maxMsgSize: s.opts.maxMsgSize,
+ trInfo: trInfo,
}
if ss.cp != nil {
ss.cbuf = new(bytes.Buffer)
@@ -766,14 +798,16 @@ func (s *Server) Stop() {
s.mu.Lock()
listeners := s.lis
s.lis = nil
- cs := s.conns
+ st := s.conns
s.conns = nil
+ // interrupt GracefulStop if Stop and GracefulStop are called concurrently.
+ s.cv.Signal()
s.mu.Unlock()
for lis := range listeners {
lis.Close()
}
- for c := range cs {
+ for c := range st {
c.Close()
}
@@ -785,6 +819,32 @@ func (s *Server) Stop() {
s.mu.Unlock()
}
+// GracefulStop stops the gRPC server gracefully. It stops the server to accept new
+// connections and RPCs and blocks until all the pending RPCs are finished.
+func (s *Server) GracefulStop() {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.drain == true || s.conns == nil {
+ return
+ }
+ s.drain = true
+ for lis := range s.lis {
+ lis.Close()
+ }
+ s.lis = nil
+ for c := range s.conns {
+ c.(transport.ServerTransport).Drain()
+ }
+ for len(s.conns) != 0 {
+ s.cv.Wait()
+ }
+ s.conns = nil
+ if s.events != nil {
+ s.events.Finish()
+ s.events = nil
+ }
+}
+
func init() {
internal.TestingCloseConns = func(arg interface{}) {
arg.(*Server).testingCloseConns()
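Note: a usage sketch combining the two server-side additions, MaxMsgSize and GracefulStop, not part of the patch; the service registration line is a hypothetical placeholder.

package example

import (
	"net"
	"os"
	"os/signal"

	"google.golang.org/grpc"
)

// serveWithGracefulStop raises the inbound message limit from the 4MB default
// and drains in-flight RPCs on interrupt instead of aborting them.
func serveWithGracefulStop(lis net.Listener) error {
	s := grpc.NewServer(grpc.MaxMsgSize(16 * 1024 * 1024))
	// Register services here, e.g. pb.RegisterGreeterServer(s, &greeter{}) (hypothetical).
	ch := make(chan os.Signal, 1)
	signal.Notify(ch, os.Interrupt)
	go func() {
		<-ch
		s.GracefulStop() // stop accepting new work, wait for pending RPCs to finish
	}()
	return s.Serve(lis)
}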
diff --git a/vendor/google.golang.org/grpc/stream.go b/vendor/google.golang.org/grpc/stream.go
index 7a3bef5..51df3f0 100644
--- a/vendor/google.golang.org/grpc/stream.go
+++ b/vendor/google.golang.org/grpc/stream.go
@@ -37,6 +37,7 @@ import (
"bytes"
"errors"
"io"
+ "math"
"sync"
"time"
@@ -84,12 +85,9 @@ type ClientStream interface {
// Header returns the header metadata received from the server if there
// is any. It blocks if the metadata is not ready to read.
Header() (metadata.MD, error)
- // Trailer returns the trailer metadata from the server. It must be called
- // after stream.Recv() returns non-nil error (including io.EOF) for
- // bi-directional streaming and server streaming or stream.CloseAndRecv()
- // returns for client streaming in order to receive trailer metadata if
- // present. Otherwise, it could returns an empty MD even though trailer
- // is present.
+ // Trailer returns the trailer metadata from the server, if there is any.
+ // It must only be called after stream.CloseAndRecv has returned, or
+ // stream.Recv has returned a non-nil error (including io.EOF).
Trailer() metadata.MD
// CloseSend closes the send direction of the stream. It closes the stream
// when non-nil error is met.
@@ -99,11 +97,10 @@ type ClientStream interface {
// NewClientStream creates a new Stream for the client side. This is called
// by generated code.
-func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) {
+func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
var (
t transport.ClientTransport
s *transport.Stream
- err error
put func()
)
c := defaultCallInfo
@@ -120,27 +117,24 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
if cc.dopts.cp != nil {
callHdr.SendCompress = cc.dopts.cp.Type()
}
- cs := &clientStream{
- opts: opts,
- c: c,
- desc: desc,
- codec: cc.dopts.codec,
- cp: cc.dopts.cp,
- dc: cc.dopts.dc,
- tracing: EnableTracing,
- }
- if cc.dopts.cp != nil {
- callHdr.SendCompress = cc.dopts.cp.Type()
- cs.cbuf = new(bytes.Buffer)
- }
- if cs.tracing {
- cs.trInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
- cs.trInfo.firstLine.client = true
+ var trInfo traceInfo
+ if EnableTracing {
+ trInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
+ trInfo.firstLine.client = true
if deadline, ok := ctx.Deadline(); ok {
- cs.trInfo.firstLine.deadline = deadline.Sub(time.Now())
+ trInfo.firstLine.deadline = deadline.Sub(time.Now())
}
- cs.trInfo.tr.LazyLog(&cs.trInfo.firstLine, false)
- ctx = trace.NewContext(ctx, cs.trInfo.tr)
+ trInfo.tr.LazyLog(&trInfo.firstLine, false)
+ ctx = trace.NewContext(ctx, trInfo.tr)
+ defer func() {
+ if err != nil {
+ // Need to call tr.finish() if error is returned.
+ // Because tr will not be returned to caller.
+ trInfo.tr.LazyPrintf("RPC: [%v]", err)
+ trInfo.tr.SetError()
+ trInfo.tr.Finish()
+ }
+ }()
}
gopts := BalancerGetOptions{
BlockingWait: !c.failFast,
@@ -152,9 +146,9 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
if _, ok := err.(*rpcError); ok {
return nil, err
}
- if err == errConnClosing {
+ if err == errConnClosing || err == errConnUnavailable {
if c.failFast {
- return nil, Errorf(codes.Unavailable, "%v", errConnClosing)
+ return nil, Errorf(codes.Unavailable, "%v", err)
}
continue
}
@@ -168,9 +162,8 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
put()
put = nil
}
- if _, ok := err.(transport.ConnectionError); ok {
+ if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
if c.failFast {
- cs.finish(err)
return nil, toRPCErr(err)
}
continue
@@ -179,16 +172,43 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
}
break
}
- cs.put = put
- cs.t = t
- cs.s = s
- cs.p = &parser{r: s}
- // Listen on ctx.Done() to detect cancellation when there is no pending
- // I/O operations on this stream.
+ cs := &clientStream{
+ opts: opts,
+ c: c,
+ desc: desc,
+ codec: cc.dopts.codec,
+ cp: cc.dopts.cp,
+ dc: cc.dopts.dc,
+
+ put: put,
+ t: t,
+ s: s,
+ p: &parser{r: s},
+
+ tracing: EnableTracing,
+ trInfo: trInfo,
+ }
+ if cc.dopts.cp != nil {
+ cs.cbuf = new(bytes.Buffer)
+ }
+ // Listen on ctx.Done() to detect cancellation and s.Done() to detect normal termination
+ // when there is no pending I/O operations on this stream.
go func() {
select {
case <-t.Error():
// Incur transport error, simply exit.
+ case <-s.Done():
+ // TODO: The trace of the RPC is terminated here when there is no pending
+ // I/O, which is probably not the optimal solution.
+ if s.StatusCode() == codes.OK {
+ cs.finish(nil)
+ } else {
+ cs.finish(Errorf(s.StatusCode(), "%s", s.StatusDesc()))
+ }
+ cs.closeTransportStream(nil)
+ case <-s.GoAway():
+ cs.finish(errConnDrain)
+ cs.closeTransportStream(errConnDrain)
case <-s.Context().Done():
err := s.Context().Err()
cs.finish(err)
@@ -251,7 +271,17 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
if err != nil {
cs.finish(err)
}
- if err == nil || err == io.EOF {
+ if err == nil {
+ return
+ }
+ if err == io.EOF {
+ // Specialize the process for server streaming. SendMesg is only called
+ // once when creating the stream object. io.EOF needs to be skipped when
+ // the rpc is early finished (before the stream object is created.).
+ // TODO: It is probably better to move this into the generated code.
+ if !cs.desc.ClientStreams && cs.desc.ServerStreams {
+ err = nil
+ }
return
}
if _, ok := err.(transport.ConnectionError); !ok {
@@ -272,7 +302,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
}
func (cs *clientStream) RecvMsg(m interface{}) (err error) {
- err = recv(cs.p, cs.codec, cs.s, cs.dc, m)
+ err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32)
defer func() {
// err != nil indicates the termination of the stream.
if err != nil {
@@ -291,7 +321,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
return
}
// Special handling for client streaming rpc.
- err = recv(cs.p, cs.codec, cs.s, cs.dc, m)
+ err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32)
cs.closeTransportStream(err)
if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
@@ -326,7 +356,7 @@ func (cs *clientStream) CloseSend() (err error) {
}
}()
if err == nil || err == io.EOF {
- return
+ return nil
}
if _, ok := err.(transport.ConnectionError); !ok {
cs.closeTransportStream(err)
@@ -392,6 +422,7 @@ type serverStream struct {
cp Compressor
dc Decompressor
cbuf *bytes.Buffer
+ maxMsgSize int
statusCode codes.Code
statusDesc string
trInfo *traceInfo
@@ -458,5 +489,5 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
ss.mu.Unlock()
}
}()
- return recv(ss.p, ss.codec, ss.s, ss.dc, m)
+ return recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize)
}
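Note: a sketch of a client consuming a server stream under the clarified Trailer contract, not part of the patch.

package example

import (
	"io"

	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"
)

// drainServerStream reads messages until RecvMsg returns a non-nil error
// (io.EOF on normal completion) and only then asks for trailer metadata,
// matching the contract documented on ClientStream.Trailer above.
func drainServerStream(cs grpc.ClientStream, newMsg func() interface{}) (metadata.MD, error) {
	for {
		m := newMsg()
		if err := cs.RecvMsg(m); err != nil {
			if err == io.EOF {
				return cs.Trailer(), nil
			}
			return cs.Trailer(), err
		}
		// handle m here
	}
}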
diff --git a/vendor/google.golang.org/grpc/transport/control.go b/vendor/google.golang.org/grpc/transport/control.go
index 7e9bdf3..4ef0830 100644
--- a/vendor/google.golang.org/grpc/transport/control.go
+++ b/vendor/google.golang.org/grpc/transport/control.go
@@ -72,6 +72,11 @@ type resetStream struct {
func (*resetStream) item() {}
+type goAway struct {
+}
+
+func (*goAway) item() {}
+
type flushIO struct {
}
diff --git a/vendor/google.golang.org/grpc/transport/go16.go b/vendor/google.golang.org/grpc/transport/go16.go
new file mode 100644
index 0000000..ee1c46b
--- /dev/null
+++ b/vendor/google.golang.org/grpc/transport/go16.go
@@ -0,0 +1,46 @@
+// +build go1.6,!go1.7
+
+/*
+ * Copyright 2016, Google Inc.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following disclaimer
+ * in the documentation and/or other materials provided with the
+ * distribution.
+ * * Neither the name of Google Inc. nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ */
+
+package transport
+
+import (
+ "net"
+
+ "golang.org/x/net/context"
+)
+
+// dialContext connects to the address on the named network.
+func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
+ return (&net.Dialer{Cancel: ctx.Done()}).Dial(network, address)
+}
diff --git a/vendor/google.golang.org/grpc/transport/go17.go b/vendor/google.golang.org/grpc/transport/go17.go
new file mode 100644
index 0000000..356f13f
--- /dev/null
+++ b/vendor/google.golang.org/grpc/transport/go17.go
@@ -0,0 +1,46 @@
+// +build go1.7
+
+/*
+ * Copyright 2016, Google Inc.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following disclaimer
+ * in the documentation and/or other materials provided with the
+ * distribution.
+ * * Neither the name of Google Inc. nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ */
+
+package transport
+
+import (
+ "net"
+
+ "golang.org/x/net/context"
+)
+
+// dialContext connects to the address on the named network.
+func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
+ return (&net.Dialer{}).DialContext(ctx, network, address)
+}
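Note: the go16.go and go17.go helpers above differ only in how cancellation is wired; a sketch of what either gives a caller, with an assumed address, not part of the patch.

package example

import (
	"net"
	"time"

	"golang.org/x/net/context"
)

// dialWithTimeout shows the effect both build-tagged helpers provide: a TCP
// dial that stops when the context is cancelled or its deadline passes. On
// Go 1.7+ this is net.Dialer.DialContext; on Go 1.6 the same effect comes from
// Dialer.Cancel, because ctx.Done() also fires when the deadline expires.
func dialWithTimeout(addr string) (net.Conn, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()
	return (&net.Dialer{Cancel: ctx.Done()}).Dial("tcp", addr)
}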
diff --git a/vendor/google.golang.org/grpc/transport/handler_server.go b/vendor/google.golang.org/grpc/transport/handler_server.go
index 4b0d525..f23b2da 100644
--- a/vendor/google.golang.org/grpc/transport/handler_server.go
+++ b/vendor/google.golang.org/grpc/transport/handler_server.go
@@ -83,7 +83,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr
}
if v := r.Header.Get("grpc-timeout"); v != "" {
- to, err := timeoutDecode(v)
+ to, err := decodeTimeout(v)
if err != nil {
return nil, StreamErrorf(codes.Internal, "malformed time-out: %v", err)
}
@@ -194,7 +194,7 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code,
h := ht.rw.Header()
h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode))
if statusDesc != "" {
- h.Set("Grpc-Message", statusDesc)
+ h.Set("Grpc-Message", encodeGrpcMessage(statusDesc))
}
if md := s.Trailer(); len(md) > 0 {
for k, vv := range md {
@@ -370,6 +370,10 @@ func (ht *serverHandlerTransport) runStream() {
}
}
+func (ht *serverHandlerTransport) Drain() {
+ panic("Drain() is not implemented")
+}
+
// mapRecvMsgError returns the non-nil err into the appropriate
// error value as expected by callers of *grpc.parser.recvMsg.
// In particular, in can only be:
@@ -389,5 +393,5 @@ func mapRecvMsgError(err error) error {
}
}
}
- return ConnectionError{Desc: err.Error()}
+ return ConnectionErrorf(true, err, err.Error())
}
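Note: for context on the decodeTimeout rename above, a sketch of the "grpc-timeout" header format it parses; an assumed helper, not the library's implementation in http_util.go.

package example

import (
	"fmt"
	"strconv"
	"time"
)

// parseGrpcTimeout handles the grpc-timeout wire format: an ASCII integer
// followed by a single unit letter (H hours, M minutes, S seconds,
// m milliseconds, u microseconds, n nanoseconds), e.g. "100m" is 100ms.
func parseGrpcTimeout(v string) (time.Duration, error) {
	if len(v) < 2 {
		return 0, fmt.Errorf("malformed grpc-timeout: %q", v)
	}
	n, err := strconv.ParseInt(v[:len(v)-1], 10, 64)
	if err != nil {
		return 0, err
	}
	units := map[byte]time.Duration{
		'H': time.Hour, 'M': time.Minute, 'S': time.Second,
		'm': time.Millisecond, 'u': time.Microsecond, 'n': time.Nanosecond,
	}
	u, ok := units[v[len(v)-1]]
	if !ok {
		return 0, fmt.Errorf("unknown unit in grpc-timeout: %q", v)
	}
	return time.Duration(n) * u, nil
}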
diff --git a/vendor/google.golang.org/grpc/transport/http2_client.go b/vendor/google.golang.org/grpc/transport/http2_client.go
index f66435f..afbba45 100644
--- a/vendor/google.golang.org/grpc/transport/http2_client.go
+++ b/vendor/google.golang.org/grpc/transport/http2_client.go
@@ -35,6 +35,7 @@ package transport
import (
"bytes"
+ "fmt"
"io"
"math"
"net"
@@ -71,6 +72,9 @@ type http2Client struct {
shutdownChan chan struct{}
// errorChan is closed to notify the I/O error to the caller.
errorChan chan struct{}
+ // goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor)
+ // that the server sent GoAway on this transport.
+ goAway chan struct{}
framer *framer
hBuf *bytes.Buffer // the buffer for HPACK encoding
@@ -97,41 +101,73 @@ type http2Client struct {
maxStreams int
// the per-stream outbound flow control window size set by the peer.
streamSendQuota uint32
+ // goAwayID records the Last-Stream-ID in the GoAway frame from the server.
+ goAwayID uint32
+ // prevGoAwayID records the Last-Stream-ID in the previous GoAway frame.
+ prevGoAwayID uint32
+}
+
+func dial(fn func(context.Context, string) (net.Conn, error), ctx context.Context, addr string) (net.Conn, error) {
+ if fn != nil {
+ return fn(ctx, addr)
+ }
+ return dialContext(ctx, "tcp", addr)
+}
+
+func isTemporary(err error) bool {
+ switch err {
+ case io.EOF:
+ // Connection closures may be resolved upon retry, and are thus
+ // treated as temporary.
+ return true
+ case context.DeadlineExceeded:
+ // In Go 1.7, context.DeadlineExceeded implements Timeout(), and this
+ // special case is not needed. Until then, we need to keep this
+ // clause.
+ return true
+ }
+
+ switch err := err.(type) {
+ case interface {
+ Temporary() bool
+ }:
+ return err.Temporary()
+ case interface {
+ Timeout() bool
+ }:
+ // Timeouts may be resolved upon retry, and are thus treated as
+ // temporary.
+ return err.Timeout()
+ }
+ return false
}
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
// and starts to receive messages on it. Non-nil error returns if construction
// fails.
-func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err error) {
- if opts.Dialer == nil {
- // Set the default Dialer.
- opts.Dialer = func(addr string, timeout time.Duration) (net.Conn, error) {
- return net.DialTimeout("tcp", addr, timeout)
- }
- }
+func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ ClientTransport, err error) {
scheme := "http"
- startT := time.Now()
- timeout := opts.Timeout
- conn, connErr := opts.Dialer(addr, timeout)
- if connErr != nil {
- return nil, ConnectionErrorf("transport: %v", connErr)
+ conn, err := dial(opts.Dialer, ctx, addr)
+ if err != nil {
+ return nil, ConnectionErrorf(true, err, "transport: %v", err)
}
+ // Any further errors will close the underlying connection
+ defer func(conn net.Conn) {
+ if err != nil {
+ conn.Close()
+ }
+ }(conn)
var authInfo credentials.AuthInfo
- if opts.TransportCredentials != nil {
+ if creds := opts.TransportCredentials; creds != nil {
scheme = "https"
- if timeout > 0 {
- timeout -= time.Since(startT)
- }
- conn, authInfo, connErr = opts.TransportCredentials.ClientHandshake(addr, conn, timeout)
- }
- if connErr != nil {
- return nil, ConnectionErrorf("transport: %v", connErr)
- }
- defer func() {
+ conn, authInfo, err = creds.ClientHandshake(ctx, addr, conn)
if err != nil {
- conn.Close()
+ // Credentials handshake errors are typically considered permanent
+ // to avoid retrying on e.g. bad certificates.
+ temp := isTemporary(err)
+ return nil, ConnectionErrorf(temp, err, "transport: %v", err)
}
- }()
+ }
ua := primaryUA
if opts.UserAgent != "" {
ua = opts.UserAgent + " " + ua
@@ -147,6 +183,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
writableChan: make(chan int, 1),
shutdownChan: make(chan struct{}),
errorChan: make(chan struct{}),
+ goAway: make(chan struct{}),
framer: newFramer(conn),
hBuf: &buf,
hEnc: hpack.NewEncoder(&buf),
@@ -168,11 +205,11 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
n, err := t.conn.Write(clientPreface)
if err != nil {
t.Close()
- return nil, ConnectionErrorf("transport: %v", err)
+ return nil, ConnectionErrorf(true, err, "transport: %v", err)
}
if n != len(clientPreface) {
t.Close()
- return nil, ConnectionErrorf("transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface))
+ return nil, ConnectionErrorf(true, err, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface))
}
if initialWindowSize != defaultWindowSize {
err = t.framer.writeSettings(true, http2.Setting{
@@ -184,13 +221,13 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
}
if err != nil {
t.Close()
- return nil, ConnectionErrorf("transport: %v", err)
+ return nil, ConnectionErrorf(true, err, "transport: %v", err)
}
// Adjust the connection flow control window if needed.
if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil {
t.Close()
- return nil, ConnectionErrorf("transport: %v", err)
+ return nil, ConnectionErrorf(true, err, "transport: %v", err)
}
}
go t.controller()
@@ -202,6 +239,8 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
// TODO(zhaoq): Handle uint32 overflow of Stream.id.
s := &Stream{
id: t.nextID,
+ done: make(chan struct{}),
+ goAway: make(chan struct{}),
method: callHdr.Method,
sendCompress: callHdr.SendCompress,
buf: newRecvBuffer(),
@@ -216,8 +255,9 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
// Make a stream be able to cancel the pending operations by itself.
s.ctx, s.cancel = context.WithCancel(ctx)
s.dec = &recvBufferReader{
- ctx: s.ctx,
- recv: s.buf,
+ ctx: s.ctx,
+ goAway: s.goAway,
+ recv: s.buf,
}
return s
}
@@ -271,6 +311,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
t.mu.Unlock()
return nil, ErrConnClosing
}
+ if t.state == draining {
+ t.mu.Unlock()
+ return nil, ErrStreamDrain
+ }
if t.state != reachable {
t.mu.Unlock()
return nil, ErrConnClosing
@@ -278,7 +322,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
checkStreamsQuota := t.streamsQuota != nil
t.mu.Unlock()
if checkStreamsQuota {
- sq, err := wait(ctx, t.shutdownChan, t.streamsQuota.acquire())
+ sq, err := wait(ctx, nil, nil, t.shutdownChan, t.streamsQuota.acquire())
if err != nil {
return nil, err
}
@@ -287,7 +331,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
t.streamsQuota.add(sq - 1)
}
}
- if _, err := wait(ctx, t.shutdownChan, t.writableChan); err != nil {
+ if _, err := wait(ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
// Return the quota back now because there is no stream returned to the caller.
if _, ok := err.(StreamError); ok && checkStreamsQuota {
t.streamsQuota.add(1)
@@ -295,6 +339,15 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
return nil, err
}
t.mu.Lock()
+ if t.state == draining {
+ t.mu.Unlock()
+ if checkStreamsQuota {
+ t.streamsQuota.add(1)
+ }
+ // Need to make t writable again so that the rpc in flight can still proceed.
+ t.writableChan <- 0
+ return nil, ErrStreamDrain
+ }
if t.state != reachable {
t.mu.Unlock()
return nil, ErrConnClosing
@@ -329,7 +382,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
}
if timeout > 0 {
- t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: timeoutEncode(timeout)})
+ t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)})
}
for k, v := range authData {
// Capital header names are illegal in HTTP/2.
@@ -384,7 +437,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
}
if err != nil {
t.notifyError(err)
- return nil, ConnectionErrorf("transport: %v", err)
+ return nil, ConnectionErrorf(true, err, "transport: %v", err)
}
}
t.writableChan <- 0
@@ -403,22 +456,17 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
if t.streamsQuota != nil {
updateStreams = true
}
- if t.state == draining && len(t.activeStreams) == 1 {
+ delete(t.activeStreams, s.id)
+ if t.state == draining && len(t.activeStreams) == 0 {
// The transport is draining and s is the last live stream on t.
t.mu.Unlock()
t.Close()
return
}
- delete(t.activeStreams, s.id)
t.mu.Unlock()
if updateStreams {
t.streamsQuota.add(1)
}
- // In case stream sending and receiving are invoked in separate
- // goroutines (e.g., bi-directional streaming), the caller needs
- // to call cancel on the stream to interrupt the blocking on
- // other goroutines.
- s.cancel()
s.mu.Lock()
if q := s.fc.resetPendingData(); q > 0 {
if n := t.fc.onRead(q); n > 0 {
@@ -445,13 +493,13 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
// accessed any more.
func (t *http2Client) Close() (err error) {
t.mu.Lock()
- if t.state == reachable {
- close(t.errorChan)
- }
if t.state == closing {
t.mu.Unlock()
return
}
+ if t.state == reachable || t.state == draining {
+ close(t.errorChan)
+ }
t.state = closing
t.mu.Unlock()
close(t.shutdownChan)
@@ -475,10 +523,35 @@ func (t *http2Client) Close() (err error) {
func (t *http2Client) GracefulClose() error {
t.mu.Lock()
- if t.state == closing {
+ switch t.state {
+ case unreachable:
+ // The server may close the connection concurrently. t is not available for
+ // any streams. Close it now.
+ t.mu.Unlock()
+ t.Close()
+ return nil
+ case closing:
t.mu.Unlock()
return nil
}
+ // Notify the streams which were initiated after the server sent GOAWAY.
+ select {
+ case <-t.goAway:
+ n := t.prevGoAwayID
+ if n == 0 && t.nextID > 1 {
+ n = t.nextID - 2
+ }
+ m := t.goAwayID + 2
+ if m == 2 {
+ m = 1
+ }
+ for i := m; i <= n; i += 2 {
+ if s, ok := t.activeStreams[i]; ok {
+ close(s.goAway)
+ }
+ }
+ default:
+ }
if t.state == draining {
t.mu.Unlock()
return nil
@@ -504,15 +577,15 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
size := http2MaxFrameLen
s.sendQuotaPool.add(0)
// Wait until the stream has some quota to send the data.
- sq, err := wait(s.ctx, t.shutdownChan, s.sendQuotaPool.acquire())
+ sq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, s.sendQuotaPool.acquire())
if err != nil {
return err
}
t.sendQuotaPool.add(0)
// Wait until the transport has some quota to send the data.
- tq, err := wait(s.ctx, t.shutdownChan, t.sendQuotaPool.acquire())
+ tq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.sendQuotaPool.acquire())
if err != nil {
- if _, ok := err.(StreamError); ok {
+ if _, ok := err.(StreamError); ok || err == io.EOF {
t.sendQuotaPool.cancel()
}
return err
@@ -544,8 +617,8 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
// Indicate there is a writer who is about to write a data frame.
t.framer.adjustNumWriters(1)
// Got some quota. Try to acquire writing privilege on the transport.
- if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil {
- if _, ok := err.(StreamError); ok {
+ if _, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.writableChan); err != nil {
+ if _, ok := err.(StreamError); ok || err == io.EOF {
// Return the connection quota back.
t.sendQuotaPool.add(len(p))
}
@@ -578,7 +651,7 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
// invoked.
if err := t.framer.writeData(forceFlush, s.id, endStream, p); err != nil {
t.notifyError(err)
- return ConnectionErrorf("transport: %v", err)
+ return ConnectionErrorf(true, err, "transport: %v", err)
}
if t.framer.adjustNumWriters(-1) == 0 {
t.framer.flushWrite()
@@ -593,11 +666,7 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
}
s.mu.Lock()
if s.state != streamDone {
- if s.state == streamReadDone {
- s.state = streamDone
- } else {
- s.state = streamWriteDone
- }
+ s.state = streamWriteDone
}
s.mu.Unlock()
return nil
@@ -630,7 +699,7 @@ func (t *http2Client) updateWindow(s *Stream, n uint32) {
func (t *http2Client) handleData(f *http2.DataFrame) {
size := len(f.Data())
if err := t.fc.onData(uint32(size)); err != nil {
- t.notifyError(ConnectionErrorf("%v", err))
+ t.notifyError(ConnectionErrorf(true, err, "%v", err))
return
}
// Select the right stream to dispatch.
@@ -655,6 +724,7 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
s.state = streamDone
s.statusCode = codes.Internal
s.statusDesc = err.Error()
+ close(s.done)
s.mu.Unlock()
s.write(recvMsg{err: io.EOF})
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl})
@@ -672,13 +742,14 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
// the read direction is closed, and set the status appropriately.
if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) {
s.mu.Lock()
- if s.state == streamWriteDone {
- s.state = streamDone
- } else {
- s.state = streamReadDone
+ if s.state == streamDone {
+ s.mu.Unlock()
+ return
}
+ s.state = streamDone
s.statusCode = codes.Internal
s.statusDesc = "server closed the stream without sending trailers"
+ close(s.done)
s.mu.Unlock()
s.write(recvMsg{err: io.EOF})
}
@@ -704,6 +775,8 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) {
grpclog.Println("transport: http2Client.handleRSTStream found no mapped gRPC status for the received http2 error ", f.ErrCode)
s.statusCode = codes.Unknown
}
+ s.statusDesc = fmt.Sprintf("stream terminated by RST_STREAM with error code: %d", f.ErrCode)
+ close(s.done)
s.mu.Unlock()
s.write(recvMsg{err: io.EOF})
}
@@ -728,7 +801,32 @@ func (t *http2Client) handlePing(f *http2.PingFrame) {
}
func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
- // TODO(zhaoq): GoAwayFrame handler to be implemented
+ t.mu.Lock()
+ if t.state == reachable || t.state == draining {
+ if f.LastStreamID > 0 && f.LastStreamID%2 != 1 {
+ t.mu.Unlock()
+ t.notifyError(ConnectionErrorf(true, nil, "received illegal http2 GOAWAY frame: stream ID %d is even", f.LastStreamID))
+ return
+ }
+ select {
+ case <-t.goAway:
+ id := t.goAwayID
+ // t.goAway has been closed (i.e., multiple GoAways).
+ if id < f.LastStreamID {
+ t.mu.Unlock()
+ t.notifyError(ConnectionErrorf(true, nil, "received illegal http2 GOAWAY frame: previously recv GOAWAY frame with LastStreamID %d, currently recv %d", id, f.LastStreamID))
+ return
+ }
+ t.prevGoAwayID = id
+ t.goAwayID = f.LastStreamID
+ t.mu.Unlock()
+ return
+ default:
+ }
+ t.goAwayID = f.LastStreamID
+ close(t.goAway)
+ }
+ t.mu.Unlock()
}
func (t *http2Client) handleWindowUpdate(f *http2.WindowUpdateFrame) {
@@ -780,11 +878,11 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
if len(state.mdata) > 0 {
s.trailer = state.mdata
}
- s.state = streamDone
s.statusCode = state.statusCode
s.statusDesc = state.statusDesc
+ close(s.done)
+ s.state = streamDone
s.mu.Unlock()
-
s.write(recvMsg{err: io.EOF})
}
@@ -937,13 +1035,22 @@ func (t *http2Client) Error() <-chan struct{} {
return t.errorChan
}
+func (t *http2Client) GoAway() <-chan struct{} {
+ return t.goAway
+}
+
func (t *http2Client) notifyError(err error) {
t.mu.Lock()
- defer t.mu.Unlock()
// make sure t.errorChan is closed only once.
+ if t.state == draining {
+ t.mu.Unlock()
+ t.Close()
+ return
+ }
if t.state == reachable {
t.state = unreachable
close(t.errorChan)
grpclog.Printf("transport: http2Client.notifyError got notified that the client transport was broken %v.", err)
}
+ t.mu.Unlock()
}
diff --git a/vendor/google.golang.org/grpc/transport/http2_server.go b/vendor/google.golang.org/grpc/transport/http2_server.go
index cee1542..16010d5 100644
--- a/vendor/google.golang.org/grpc/transport/http2_server.go
+++ b/vendor/google.golang.org/grpc/transport/http2_server.go
@@ -111,12 +111,12 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthI
Val: uint32(initialWindowSize)})
}
if err := framer.writeSettings(true, settings...); err != nil {
- return nil, ConnectionErrorf("transport: %v", err)
+ return nil, ConnectionErrorf(true, err, "transport: %v", err)
}
// Adjust the connection flow control window if needed.
if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
if err := framer.writeWindowUpdate(true, 0, delta); err != nil {
- return nil, ConnectionErrorf("transport: %v", err)
+ return nil, ConnectionErrorf(true, err, "transport: %v", err)
}
}
var buf bytes.Buffer
@@ -142,7 +142,7 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthI
}
// operateHeader takes action on the decoded headers.
-func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) {
+func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) (close bool) {
buf := newRecvBuffer()
s := &Stream{
id: frame.Header().StreamID,
@@ -205,6 +205,13 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream})
return
}
+ if s.id%2 != 1 || s.id <= t.maxStreamID {
+ t.mu.Unlock()
+ // illegal gRPC stream id.
+ grpclog.Println("transport: http2Server.HandleStreams received an illegal stream id: ", s.id)
+ return true
+ }
+ t.maxStreamID = s.id
s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota))
t.activeStreams[s.id] = s
t.mu.Unlock()
@@ -212,6 +219,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
t.updateWindow(s, uint32(n))
}
handle(s)
+ return
}
// HandleStreams receives incoming streams using the given handler. This is
@@ -231,6 +239,10 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
}
frame, err := t.framer.readFrame()
+ if err == io.EOF || err == io.ErrUnexpectedEOF {
+ t.Close()
+ return
+ }
if err != nil {
grpclog.Printf("transport: http2Server.HandleStreams failed to read frame: %v", err)
t.Close()
@@ -257,20 +269,20 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
t.controlBuf.put(&resetStream{se.StreamID, se.Code})
continue
}
+ if err == io.EOF || err == io.ErrUnexpectedEOF {
+ t.Close()
+ return
+ }
+ grpclog.Printf("transport: http2Server.HandleStreams failed to read frame: %v", err)
t.Close()
return
}
switch frame := frame.(type) {
case *http2.MetaHeadersFrame:
- id := frame.Header().StreamID
- if id%2 != 1 || id <= t.maxStreamID {
- // illegal gRPC stream id.
- grpclog.Println("transport: http2Server.HandleStreams received an illegal stream id: ", id)
+ if t.operateHeaders(frame, handle) {
t.Close()
break
}
- t.maxStreamID = id
- t.operateHeaders(frame, handle)
case *http2.DataFrame:
t.handleData(frame)
case *http2.RSTStreamFrame:
@@ -282,7 +294,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
case *http2.WindowUpdateFrame:
t.handleWindowUpdate(frame)
case *http2.GoAwayFrame:
- break
+ // TODO: Handle GoAway from the client appropriately.
default:
grpclog.Printf("transport: http2Server.HandleStreams found unhandled frame type %v.", frame)
}
@@ -364,11 +376,7 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
// Received the end of stream from the client.
s.mu.Lock()
if s.state != streamDone {
- if s.state == streamWriteDone {
- s.state = streamDone
- } else {
- s.state = streamReadDone
- }
+ s.state = streamReadDone
}
s.mu.Unlock()
s.write(recvMsg{err: io.EOF})
@@ -440,7 +448,7 @@ func (t *http2Server) writeHeaders(s *Stream, b *bytes.Buffer, endStream bool) e
}
if err != nil {
t.Close()
- return ConnectionErrorf("transport: %v", err)
+ return ConnectionErrorf(true, err, "transport: %v", err)
}
}
return nil
@@ -455,7 +463,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
}
s.headerOk = true
s.mu.Unlock()
- if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil {
+ if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
return err
}
t.hBuf.Reset()
@@ -495,7 +503,7 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s
headersSent = true
}
s.mu.Unlock()
- if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil {
+ if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
return err
}
t.hBuf.Reset()
@@ -508,7 +516,7 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s
Name: "grpc-status",
Value: strconv.Itoa(int(statusCode)),
})
- t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: statusDesc})
+ t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(statusDesc)})
// Attach the trailer metadata.
for k, v := range s.trailer {
// Clients don't tolerate reading restricted headers after some non restricted ones were sent.
@@ -544,7 +552,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
}
s.mu.Unlock()
if writeHeaderFrame {
- if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil {
+ if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
return err
}
t.hBuf.Reset()
@@ -560,7 +568,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
}
if err := t.framer.writeHeaders(false, p); err != nil {
t.Close()
- return ConnectionErrorf("transport: %v", err)
+ return ConnectionErrorf(true, err, "transport: %v", err)
}
t.writableChan <- 0
}
@@ -572,13 +580,13 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
size := http2MaxFrameLen
s.sendQuotaPool.add(0)
// Wait until the stream has some quota to send the data.
- sq, err := wait(s.ctx, t.shutdownChan, s.sendQuotaPool.acquire())
+ sq, err := wait(s.ctx, nil, nil, t.shutdownChan, s.sendQuotaPool.acquire())
if err != nil {
return err
}
t.sendQuotaPool.add(0)
// Wait until the transport has some quota to send the data.
- tq, err := wait(s.ctx, t.shutdownChan, t.sendQuotaPool.acquire())
+ tq, err := wait(s.ctx, nil, nil, t.shutdownChan, t.sendQuotaPool.acquire())
if err != nil {
if _, ok := err.(StreamError); ok {
t.sendQuotaPool.cancel()
@@ -604,7 +612,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
t.framer.adjustNumWriters(1)
// Got some quota. Try to acquire writing privilege on the
// transport.
- if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil {
+ if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
if _, ok := err.(StreamError); ok {
// Return the connection quota back.
t.sendQuotaPool.add(ps)
@@ -634,7 +642,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
}
if err := t.framer.writeData(forceFlush, s.id, false, p); err != nil {
t.Close()
- return ConnectionErrorf("transport: %v", err)
+ return ConnectionErrorf(true, err, "transport: %v", err)
}
if t.framer.adjustNumWriters(-1) == 0 {
t.framer.flushWrite()
@@ -679,6 +687,17 @@ func (t *http2Server) controller() {
}
case *resetStream:
t.framer.writeRSTStream(true, i.streamID, i.code)
+ case *goAway:
+ t.mu.Lock()
+ if t.state == closing {
+ t.mu.Unlock()
+ // The transport is closing.
+ return
+ }
+ sid := t.maxStreamID
+ t.state = draining
+ t.mu.Unlock()
+ t.framer.writeGoAway(true, sid, http2.ErrCodeNo, nil)
case *flushIO:
t.framer.flushWrite()
case *ping:
@@ -724,6 +743,9 @@ func (t *http2Server) Close() (err error) {
func (t *http2Server) closeStream(s *Stream) {
t.mu.Lock()
delete(t.activeStreams, s.id)
+ if t.state == draining && len(t.activeStreams) == 0 {
+ defer t.Close()
+ }
t.mu.Unlock()
// In case stream sending and receiving are invoked in separate
// goroutines (e.g., bi-directional streaming), cancel needs to be
@@ -746,3 +768,7 @@ func (t *http2Server) closeStream(s *Stream) {
func (t *http2Server) RemoteAddr() net.Addr {
return t.conn.RemoteAddr()
}
+
+func (t *http2Server) Drain() {
+ t.controlBuf.put(&goAway{})
+}
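
Drain above simply enqueues a goAway control item; a minimal sketch of that control-buffer pattern (illustrative types, not the vendored controlBuf/controller implementation), in which a single goroutine owns the framer and applies each queued item in order:

package main

import "fmt"

type controlItem interface{ isControl() }

type goAwayItem struct{}
type flushItem struct{}

func (goAwayItem) isControl() {}
func (flushItem) isControl()  {}

// controller stands in for the transport's writer goroutine: it is the only
// goroutine that touches the (imaginary) framer, so frame writes never race.
func controller(ch <-chan controlItem) {
	for item := range ch {
		switch item.(type) {
		case goAwayItem:
			fmt.Println("write GOAWAY with current max stream ID, switch to draining")
		case flushItem:
			fmt.Println("flush buffered frames")
		}
	}
}

func main() {
	ch := make(chan controlItem, 2)
	ch <- goAwayItem{} // what Drain() effectively does
	ch <- flushItem{}
	close(ch)
	controller(ch)
}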
diff --git a/vendor/google.golang.org/grpc/transport/http_util.go b/vendor/google.golang.org/grpc/transport/http_util.go
index f2e23dc..79da512 100644
--- a/vendor/google.golang.org/grpc/transport/http_util.go
+++ b/vendor/google.golang.org/grpc/transport/http_util.go
@@ -35,6 +35,7 @@ package transport
import (
"bufio"
+ "bytes"
"fmt"
"io"
"net"
@@ -52,7 +53,7 @@ import (
const (
// The primary user agent
- primaryUA = "grpc-go/0.11"
+ primaryUA = "grpc-go/1.0"
// http2MaxFrameLen specifies the max length of a HTTP2 frame.
http2MaxFrameLen = 16384 // 16KB frame
// http://http2.github.io/http2-spec/#SettingValues
@@ -174,11 +175,11 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) {
}
d.statusCode = codes.Code(code)
case "grpc-message":
- d.statusDesc = f.Value
+ d.statusDesc = decodeGrpcMessage(f.Value)
case "grpc-timeout":
d.timeoutSet = true
var err error
- d.timeout, err = timeoutDecode(f.Value)
+ d.timeout, err = decodeTimeout(f.Value)
if err != nil {
d.setErr(StreamErrorf(codes.Internal, "transport: malformed time-out: %v", err))
return
@@ -251,7 +252,7 @@ func div(d, r time.Duration) int64 {
}
// TODO(zhaoq): It is the simplistic and not bandwidth efficient. Improve it.
-func timeoutEncode(t time.Duration) string {
+func encodeTimeout(t time.Duration) string {
if d := div(t, time.Nanosecond); d <= maxTimeoutValue {
return strconv.FormatInt(d, 10) + "n"
}
@@ -271,7 +272,7 @@ func timeoutEncode(t time.Duration) string {
return strconv.FormatInt(div(t, time.Hour), 10) + "H"
}
-func timeoutDecode(s string) (time.Duration, error) {
+func decodeTimeout(s string) (time.Duration, error) {
size := len(s)
if size < 2 {
return 0, fmt.Errorf("transport: timeout string is too short: %q", s)
@@ -288,6 +289,80 @@ func timeoutDecode(s string) (time.Duration, error) {
return d * time.Duration(t), nil
}
+const (
+ spaceByte = ' '
+ tildaByte = '~'
+ percentByte = '%'
+)
+
+// encodeGrpcMessage is used to encode the status description in the
+// "grpc-message" header field.
+// It checks whether each individual byte in msg is an allowable byte,
+// and either passes it through or percent-encodes it.
+// When percent encoding, the byte is converted into hexadecimal notation
+// with a '%' prepended.
+func encodeGrpcMessage(msg string) string {
+ if msg == "" {
+ return ""
+ }
+ lenMsg := len(msg)
+ for i := 0; i < lenMsg; i++ {
+ c := msg[i]
+ if !(c >= spaceByte && c < tildaByte && c != percentByte) {
+ return encodeGrpcMessageUnchecked(msg)
+ }
+ }
+ return msg
+}
+
+func encodeGrpcMessageUnchecked(msg string) string {
+ var buf bytes.Buffer
+ lenMsg := len(msg)
+ for i := 0; i < lenMsg; i++ {
+ c := msg[i]
+ if c >= spaceByte && c < tildaByte && c != percentByte {
+ buf.WriteByte(c)
+ } else {
+ buf.WriteString(fmt.Sprintf("%%%02X", c))
+ }
+ }
+ return buf.String()
+}
+
+// decodeGrpcMessage decodes the msg encoded by encodeGrpcMessage.
+func decodeGrpcMessage(msg string) string {
+ if msg == "" {
+ return ""
+ }
+ lenMsg := len(msg)
+ for i := 0; i < lenMsg; i++ {
+ if msg[i] == percentByte && i+2 < lenMsg {
+ return decodeGrpcMessageUnchecked(msg)
+ }
+ }
+ return msg
+}
+
+func decodeGrpcMessageUnchecked(msg string) string {
+ var buf bytes.Buffer
+ lenMsg := len(msg)
+ for i := 0; i < lenMsg; i++ {
+ c := msg[i]
+ if c == percentByte && i+2 < lenMsg {
+ parsed, err := strconv.ParseInt(msg[i+1:i+3], 16, 8)
+ if err != nil {
+ buf.WriteByte(c)
+ } else {
+ buf.WriteByte(byte(parsed))
+ i += 2
+ }
+ } else {
+ buf.WriteByte(c)
+ }
+ }
+ return buf.String()
+}
+
type framer struct {
numWriters int32
reader io.Reader
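
A self-contained sketch of the grpc-message percent-encoding added above: printable ASCII other than '%' passes through, everything else becomes a %XX hex escape, and decoding reverses the transformation. It mirrors the scheme but is not the vendored encodeGrpcMessage/decodeGrpcMessage:

package main

import (
	"bytes"
	"fmt"
	"strconv"
)

func encodeMsg(msg string) string {
	var buf bytes.Buffer
	for i := 0; i < len(msg); i++ {
		c := msg[i]
		if c >= ' ' && c < '~' && c != '%' {
			buf.WriteByte(c)
		} else {
			fmt.Fprintf(&buf, "%%%02X", c)
		}
	}
	return buf.String()
}

func decodeMsg(msg string) string {
	var buf bytes.Buffer
	for i := 0; i < len(msg); i++ {
		c := msg[i]
		if c == '%' && i+2 < len(msg) {
			if v, err := strconv.ParseUint(msg[i+1:i+3], 16, 8); err == nil {
				buf.WriteByte(byte(v))
				i += 2
				continue
			}
		}
		buf.WriteByte(c)
	}
	return buf.String()
}

func main() {
	enc := encodeMsg("déjà vu\n")
	fmt.Println(enc)            // d%C3%A9j%C3%A0 vu%0A
	fmt.Println(decodeMsg(enc)) // round-trips back to the original message
}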
diff --git a/vendor/google.golang.org/grpc/transport/pre_go16.go b/vendor/google.golang.org/grpc/transport/pre_go16.go
new file mode 100644
index 0000000..33d91c1
--- /dev/null
+++ b/vendor/google.golang.org/grpc/transport/pre_go16.go
@@ -0,0 +1,51 @@
+// +build !go1.6
+
+/*
+ * Copyright 2016, Google Inc.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following disclaimer
+ * in the documentation and/or other materials provided with the
+ * distribution.
+ * * Neither the name of Google Inc. nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ */
+
+package transport
+
+import (
+ "net"
+ "time"
+
+ "golang.org/x/net/context"
+)
+
+// dialContext connects to the address on the named network.
+func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
+ var dialer net.Dialer
+ if deadline, ok := ctx.Deadline(); ok {
+ dialer.Timeout = deadline.Sub(time.Now())
+ }
+ return dialer.Dial(network, address)
+}
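
The pre-Go1.6 fallback above can only translate a context deadline into a Dialer.Timeout; plain cancellation without a deadline does not interrupt an in-flight dial. A small sketch of that behavior, using an illustrative unroutable documentation address:

package main

import (
	"fmt"
	"net"
	"time"

	"golang.org/x/net/context"
)

// dialWithDeadline mirrors the fallback: only ctx's deadline is honored.
func dialWithDeadline(ctx context.Context, network, address string) (net.Conn, error) {
	var d net.Dialer
	if deadline, ok := ctx.Deadline(); ok {
		d.Timeout = deadline.Sub(time.Now())
	}
	return d.Dial(network, address)
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	// 203.0.113.1 is a documentation address, so this dial is expected to
	// time out after roughly one second rather than connect.
	if _, err := dialWithDeadline(ctx, "tcp", "203.0.113.1:50051"); err != nil {
		fmt.Println("dial error:", err)
	}
}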
diff --git a/vendor/google.golang.org/grpc/transport/transport.go b/vendor/google.golang.org/grpc/transport/transport.go
index d4c220a..b31769a 100644
--- a/vendor/google.golang.org/grpc/transport/transport.go
+++ b/vendor/google.golang.org/grpc/transport/transport.go
@@ -44,7 +44,6 @@ import (
"io"
"net"
"sync"
- "time"
"golang.org/x/net/context"
"golang.org/x/net/trace"
@@ -120,10 +119,11 @@ func (b *recvBuffer) get() <-chan item {
// recvBufferReader implements io.Reader interface to read the data from
// recvBuffer.
type recvBufferReader struct {
- ctx context.Context
- recv *recvBuffer
- last *bytes.Reader // Stores the remaining data in the previous calls.
- err error
+ ctx context.Context
+ goAway chan struct{}
+ recv *recvBuffer
+ last *bytes.Reader // Stores the remaining data in the previous calls.
+ err error
}
// Read reads the next len(p) bytes from last. If last is drained, it tries to
@@ -141,6 +141,8 @@ func (r *recvBufferReader) Read(p []byte) (n int, err error) {
select {
case <-r.ctx.Done():
return 0, ContextErr(r.ctx.Err())
+ case <-r.goAway:
+ return 0, ErrStreamDrain
case i := <-r.recv.get():
r.recv.load()
m := i.(*recvMsg)
@@ -158,7 +160,7 @@ const (
streamActive streamState = iota
streamWriteDone // EndStream sent
streamReadDone // EndStream received
- streamDone // sendDone and recvDone or RSTStreamFrame is sent or received.
+ streamDone // the entire stream is finished.
)
// Stream represents an RPC in the transport layer.
@@ -169,6 +171,10 @@ type Stream struct {
// ctx is the associated context of the stream.
ctx context.Context
cancel context.CancelFunc
+ // done is closed when the final status arrives.
+ done chan struct{}
+ // goAway is closed when the server sends a GoAway signal before this stream was initiated.
+ goAway chan struct{}
// method records the associated RPC method of the stream.
method string
recvCompress string
@@ -214,6 +220,18 @@ func (s *Stream) SetSendCompress(str string) {
s.sendCompress = str
}
+// Done returns a channel which is closed when the stream receives the final
+// status from the server.
+func (s *Stream) Done() <-chan struct{} {
+ return s.done
+}
+
+// GoAway returns a channel which is closed when the server sends a GoAway
+// signal before this stream was initiated.
+func (s *Stream) GoAway() <-chan struct{} {
+ return s.goAway
+}
+
// Header acquires the key-value pairs of header metadata once it
// is available. It blocks until i) the metadata is ready or ii) there is no
// header metadata or iii) the stream is cancelled/expired.
@@ -221,6 +239,8 @@ func (s *Stream) Header() (metadata.MD, error) {
select {
case <-s.ctx.Done():
return nil, ContextErr(s.ctx.Err())
+ case <-s.goAway:
+ return nil, ErrStreamDrain
case <-s.headerChan:
return s.header.Copy(), nil
}
@@ -335,19 +355,17 @@ type ConnectOptions struct {
// UserAgent is the application user agent.
UserAgent string
// Dialer specifies how to dial a network address.
- Dialer func(string, time.Duration) (net.Conn, error)
+ Dialer func(context.Context, string) (net.Conn, error)
// PerRPCCredentials stores the PerRPCCredentials required to issue RPCs.
PerRPCCredentials []credentials.PerRPCCredentials
// TransportCredentials stores the Authenticator required to setup a client connection.
TransportCredentials credentials.TransportCredentials
- // Timeout specifies the timeout for dialing a ClientTransport.
- Timeout time.Duration
}
// NewClientTransport establishes the transport with the required ConnectOptions
// and returns it to the caller.
-func NewClientTransport(target string, opts *ConnectOptions) (ClientTransport, error) {
- return newHTTP2Client(target, opts)
+func NewClientTransport(ctx context.Context, target string, opts ConnectOptions) (ClientTransport, error) {
+ return newHTTP2Client(ctx, target, opts)
}
// Options provides additional hints and information for message
@@ -417,6 +435,11 @@ type ClientTransport interface {
// and create a new one) in error case. It should not return nil
// once the transport is initiated.
Error() <-chan struct{}
+
+ // GoAway returns a channel that is closed when ClientTransport
+ // receives the draining signal from the server (e.g., GOAWAY frame in
+ // HTTP/2).
+ GoAway() <-chan struct{}
}
// ServerTransport is the common interface for all gRPC server-side transport
@@ -448,6 +471,9 @@ type ServerTransport interface {
// RemoteAddr returns the remote network address.
RemoteAddr() net.Addr
+
+ // Drain notifies the client that this ServerTransport stops accepting new RPCs.
+ Drain()
}
// StreamErrorf creates an StreamError with the specified error code and description.
@@ -459,9 +485,11 @@ func StreamErrorf(c codes.Code, format string, a ...interface{}) StreamError {
}
// ConnectionErrorf creates an ConnectionError with the specified error description.
-func ConnectionErrorf(format string, a ...interface{}) ConnectionError {
+func ConnectionErrorf(temp bool, e error, format string, a ...interface{}) ConnectionError {
return ConnectionError{
Desc: fmt.Sprintf(format, a...),
+ temp: temp,
+ err: e,
}
}
@@ -469,14 +497,36 @@ func ConnectionErrorf(format string, a ...interface{}) ConnectionError {
// entire connection and the retry of all the active streams.
type ConnectionError struct {
Desc string
+ temp bool
+ err error
}
func (e ConnectionError) Error() string {
return fmt.Sprintf("connection error: desc = %q", e.Desc)
}
-// ErrConnClosing indicates that the transport is closing.
-var ErrConnClosing = ConnectionError{Desc: "transport is closing"}
+// Temporary indicates if this connection error is temporary or fatal.
+func (e ConnectionError) Temporary() bool {
+ return e.temp
+}
+
+// Origin returns the original error of this connection error.
+func (e ConnectionError) Origin() error {
+ // Never return nil error here.
+ // If the original error is nil, return itself.
+ if e.err == nil {
+ return e
+ }
+ return e.err
+}
+
+var (
+ // ErrConnClosing indicates that the transport is closing.
+ ErrConnClosing = ConnectionErrorf(true, nil, "transport is closing")
+ // ErrStreamDrain indicates that the stream is rejected by the server because
+ // the server stops accepting new RPCs.
+ ErrStreamDrain = StreamErrorf(codes.Unavailable, "the server stops accepting new RPCs")
+)
// StreamError is an error that only affects one stream within a connection.
type StreamError struct {
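
With the temporary flag now carried by ConnectionError, callers can distinguish retryable connection failures from permanent ones. A hedged sketch of that check with stand-in types; the retry policy is illustrative and not the grpc package's actual backoff logic:

package main

import "fmt"

// connError is a stand-in for transport.ConnectionError.
type connError struct {
	desc string
	temp bool
}

func (e connError) Error() string   { return "connection error: desc = " + e.desc }
func (e connError) Temporary() bool { return e.temp }

func connect(attempt int) error {
	if attempt < 2 {
		return connError{desc: "transport: connection reset by peer", temp: true}
	}
	return nil
}

func main() {
	for attempt := 0; ; attempt++ {
		err := connect(attempt)
		if err == nil {
			fmt.Println("connected after", attempt, "retries")
			return
		}
		if te, ok := err.(interface{ Temporary() bool }); ok && te.Temporary() {
			fmt.Println("temporary failure, retrying:", err)
			continue
		}
		fmt.Println("permanent failure, giving up:", err)
		return
	}
}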
@@ -501,12 +551,25 @@ func ContextErr(err error) StreamError {
// wait blocks until it can receive from ctx.Done, closing, or proceed.
// If it receives from ctx.Done, it returns 0, the StreamError for ctx.Err.
+// If it receives from done, it returns 0, io.EOF if ctx is not done; otherwise
+// it returns the StreamError for ctx.Err.
+// If it receives from goAway, it returns 0, ErrStreamDrain.
// If it receives from closing, it returns 0, ErrConnClosing.
// If it receives from proceed, it returns the received integer, nil.
-func wait(ctx context.Context, closing <-chan struct{}, proceed <-chan int) (int, error) {
+func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-chan int) (int, error) {
select {
case <-ctx.Done():
return 0, ContextErr(ctx.Err())
+ case <-done:
+ // User cancellation has precedence.
+ select {
+ case <-ctx.Done():
+ return 0, ContextErr(ctx.Err())
+ default:
+ }
+ return 0, io.EOF
+ case <-goAway:
+ return 0, ErrStreamDrain
case <-closing:
return 0, ErrConnClosing
case i := <-proceed: