From c9849d667ab55c23d343332a11afb3eb8ede3f2d Mon Sep 17 00:00:00 2001 From: Niall Sheridan Date: Sun, 17 Jul 2016 17:16:14 +0100 Subject: Update vendor libs --- vendor/google.golang.org/grpc/Makefile | 5 +- vendor/google.golang.org/grpc/balancer.go | 175 +++++++++++++-------- vendor/google.golang.org/grpc/call.go | 18 +-- vendor/google.golang.org/grpc/clientconn.go | 158 +++++++++---------- .../grpc/credentials/credentials.go | 40 ++--- .../grpc/credentials/oauth/oauth.go | 33 ++-- vendor/google.golang.org/grpc/rpc_util.go | 56 +++++-- vendor/google.golang.org/grpc/server.go | 68 ++++++-- vendor/google.golang.org/grpc/stream.go | 70 +++++++-- .../grpc/transport/handler_server.go | 6 +- .../grpc/transport/http2_client.go | 28 ++-- .../grpc/transport/http2_server.go | 9 +- .../google.golang.org/grpc/transport/http_util.go | 15 +- .../google.golang.org/grpc/transport/transport.go | 10 +- 14 files changed, 439 insertions(+), 252 deletions(-) (limited to 'vendor/google.golang.org/grpc') diff --git a/vendor/google.golang.org/grpc/Makefile b/vendor/google.golang.org/grpc/Makefile index d26eb90..03bb01f 100644 --- a/vendor/google.golang.org/grpc/Makefile +++ b/vendor/google.golang.org/grpc/Makefile @@ -21,8 +21,9 @@ proto: exit 1; \ fi go get -u -v github.com/golang/protobuf/protoc-gen-go - for file in $$(git ls-files '*.proto'); do \ - protoc -I $$(dirname $$file) --go_out=plugins=grpc:$$(dirname $$file) $$file; \ + # use $$dir as the root for all proto files in the same directory + for dir in $$(git ls-files '*.proto' | xargs -n1 dirname | uniq); do \ + protoc -I $$dir --go_out=plugins=grpc:$$dir $$dir/*.proto; \ done test: testdeps diff --git a/vendor/google.golang.org/grpc/balancer.go b/vendor/google.golang.org/grpc/balancer.go index 348bf97..419e214 100644 --- a/vendor/google.golang.org/grpc/balancer.go +++ b/vendor/google.golang.org/grpc/balancer.go @@ -40,7 +40,6 @@ import ( "golang.org/x/net/context" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/naming" - "google.golang.org/grpc/transport" ) // Address represents a server the client connects to. @@ -94,10 +93,10 @@ type Balancer interface { // instead of blocking. // // The function returns put which is called once the rpc has completed or failed. - // put can collect and report RPC stats to a remote load balancer. gRPC internals - // will try to call this again if err is non-nil (unless err is ErrClientConnClosing). + // put can collect and report RPC stats to a remote load balancer. // - // TODO: Add other non-recoverable errors? + // This function should only return the errors Balancer cannot recover by itself. + // gRPC internals will fail the RPC if an error is returned. Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error) // Notify returns a channel that is used by gRPC internals to watch the addresses // gRPC needs to connect. The addresses might be from a name resolver or remote @@ -139,35 +138,40 @@ func RoundRobin(r naming.Resolver) Balancer { return &roundRobin{r: r} } +type addrInfo struct { + addr Address + connected bool +} + type roundRobin struct { - r naming.Resolver - w naming.Watcher - open []Address // all the addresses the client should potentially connect - mu sync.Mutex - addrCh chan []Address // the channel to notify gRPC internals the list of addresses the client should connect to. - connected []Address // all the connected addresses - next int // index of the next address to return for Get() - waitCh chan struct{} // the channel to block when there is no connected address available - done bool // The Balancer is closed. + r naming.Resolver + w naming.Watcher + addrs []*addrInfo // all the addresses the client should potentially connect + mu sync.Mutex + addrCh chan []Address // the channel to notify gRPC internals the list of addresses the client should connect to. + next int // index of the next address to return for Get() + waitCh chan struct{} // the channel to block when there is no connected address available + done bool // The Balancer is closed. } func (rr *roundRobin) watchAddrUpdates() error { updates, err := rr.w.Next() if err != nil { - grpclog.Println("grpc: the naming watcher stops working due to %v.", err) + grpclog.Printf("grpc: the naming watcher stops working due to %v.\n", err) return err } rr.mu.Lock() defer rr.mu.Unlock() for _, update := range updates { addr := Address{ - Addr: update.Addr, + Addr: update.Addr, + Metadata: update.Metadata, } switch update.Op { case naming.Add: var exist bool - for _, v := range rr.open { - if addr == v { + for _, v := range rr.addrs { + if addr == v.addr { exist = true grpclog.Println("grpc: The name resolver wanted to add an existing address: ", addr) break @@ -176,12 +180,12 @@ func (rr *roundRobin) watchAddrUpdates() error { if exist { continue } - rr.open = append(rr.open, addr) + rr.addrs = append(rr.addrs, &addrInfo{addr: addr}) case naming.Delete: - for i, v := range rr.open { - if v == addr { - copy(rr.open[i:], rr.open[i+1:]) - rr.open = rr.open[:len(rr.open)-1] + for i, v := range rr.addrs { + if addr == v.addr { + copy(rr.addrs[i:], rr.addrs[i+1:]) + rr.addrs = rr.addrs[:len(rr.addrs)-1] break } } @@ -189,9 +193,11 @@ func (rr *roundRobin) watchAddrUpdates() error { grpclog.Println("Unknown update.Op ", update.Op) } } - // Make a copy of rr.open and write it onto rr.addrCh so that gRPC internals gets notified. - open := make([]Address, len(rr.open), len(rr.open)) - copy(open, rr.open) + // Make a copy of rr.addrs and write it onto rr.addrCh so that gRPC internals gets notified. + open := make([]Address, len(rr.addrs)) + for i, v := range rr.addrs { + open[i] = v.addr + } if rr.done { return ErrClientConnClosing } @@ -202,7 +208,9 @@ func (rr *roundRobin) watchAddrUpdates() error { func (rr *roundRobin) Start(target string) error { if rr.r == nil { // If there is no name resolver installed, it is not needed to - // do name resolution. In this case, rr.addrCh stays nil. + // do name resolution. In this case, target is added into rr.addrs + // as the only address available and rr.addrCh stays nil. + rr.addrs = append(rr.addrs, &addrInfo{addr: Address{Addr: target}}) return nil } w, err := rr.r.Resolve(target) @@ -221,38 +229,41 @@ func (rr *roundRobin) Start(target string) error { return nil } -// Up appends addr to the end of rr.connected and sends notification if there -// are pending Get() calls. +// Up sets the connected state of addr and sends notification if there are pending +// Get() calls. func (rr *roundRobin) Up(addr Address) func(error) { rr.mu.Lock() defer rr.mu.Unlock() - for _, a := range rr.connected { - if a == addr { - return nil + var cnt int + for _, a := range rr.addrs { + if a.addr == addr { + if a.connected { + return nil + } + a.connected = true } - } - rr.connected = append(rr.connected, addr) - if len(rr.connected) == 1 { - // addr is only one available. Notify the Get() callers who are blocking. - if rr.waitCh != nil { - close(rr.waitCh) - rr.waitCh = nil + if a.connected { + cnt++ } } + // addr is only one which is connected. Notify the Get() callers who are blocking. + if cnt == 1 && rr.waitCh != nil { + close(rr.waitCh) + rr.waitCh = nil + } return func(err error) { rr.down(addr, err) } } -// down removes addr from rr.connected and moves the remaining addrs forward. +// down unsets the connected state of addr. func (rr *roundRobin) down(addr Address, err error) { rr.mu.Lock() defer rr.mu.Unlock() - for i, a := range rr.connected { - if a == addr { - copy(rr.connected[i:], rr.connected[i+1:]) - rr.connected = rr.connected[:len(rr.connected)-1] - return + for _, a := range rr.addrs { + if addr == a.addr { + a.connected = false + break } } } @@ -266,17 +277,40 @@ func (rr *roundRobin) Get(ctx context.Context, opts BalancerGetOptions) (addr Ad err = ErrClientConnClosing return } - if rr.next >= len(rr.connected) { - rr.next = 0 + + if len(rr.addrs) > 0 { + if rr.next >= len(rr.addrs) { + rr.next = 0 + } + next := rr.next + for { + a := rr.addrs[next] + next = (next + 1) % len(rr.addrs) + if a.connected { + addr = a.addr + rr.next = next + rr.mu.Unlock() + return + } + if next == rr.next { + // Has iterated all the possible address but none is connected. + break + } + } } - if len(rr.connected) > 0 { - addr = rr.connected[rr.next] + if !opts.BlockingWait { + if len(rr.addrs) == 0 { + rr.mu.Unlock() + err = fmt.Errorf("there is no address available") + return + } + // Returns the next addr on rr.addrs for failfast RPCs. + addr = rr.addrs[rr.next].addr rr.next++ rr.mu.Unlock() return } - // There is no address available. Wait on rr.waitCh. - // TODO(zhaoq): Handle the case when opts.BlockingWait is false. + // Wait on rr.waitCh for non-failfast RPCs. if rr.waitCh == nil { ch = make(chan struct{}) rr.waitCh = ch @@ -287,7 +321,7 @@ func (rr *roundRobin) Get(ctx context.Context, opts BalancerGetOptions) (addr Ad for { select { case <-ctx.Done(): - err = transport.ContextErr(ctx.Err()) + err = ctx.Err() return case <-ch: rr.mu.Lock() @@ -296,24 +330,35 @@ func (rr *roundRobin) Get(ctx context.Context, opts BalancerGetOptions) (addr Ad err = ErrClientConnClosing return } - if len(rr.connected) == 0 { - // The newly added addr got removed by Down() again. - if rr.waitCh == nil { - ch = make(chan struct{}) - rr.waitCh = ch - } else { - ch = rr.waitCh + + if len(rr.addrs) > 0 { + if rr.next >= len(rr.addrs) { + rr.next = 0 + } + next := rr.next + for { + a := rr.addrs[next] + next = (next + 1) % len(rr.addrs) + if a.connected { + addr = a.addr + rr.next = next + rr.mu.Unlock() + return + } + if next == rr.next { + // Has iterated all the possible address but none is connected. + break + } } - rr.mu.Unlock() - continue } - if rr.next >= len(rr.connected) { - rr.next = 0 + // The newly added addr got removed by Down() again. + if rr.waitCh == nil { + ch = make(chan struct{}) + rr.waitCh = ch + } else { + ch = rr.waitCh } - addr = rr.connected[rr.next] - rr.next++ rr.mu.Unlock() - return } } } diff --git a/vendor/google.golang.org/grpc/call.go b/vendor/google.golang.org/grpc/call.go index d6d993b..84ac178 100644 --- a/vendor/google.golang.org/grpc/call.go +++ b/vendor/google.golang.org/grpc/call.go @@ -101,7 +101,7 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd // Invoke is called by generated code. Also users can call Invoke directly when it // is really needed in their use cases. func Invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) { - var c callInfo + c := defaultCallInfo for _, o := range opts { if err := o.before(&c); err != nil { return toRPCErr(err) @@ -155,19 +155,17 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli t, put, err = cc.getTransport(ctx, gopts) if err != nil { // TODO(zhaoq): Probably revisit the error handling. - if err == ErrClientConnClosing { - return Errorf(codes.FailedPrecondition, "%v", err) + if _, ok := err.(*rpcError); ok { + return err } - if _, ok := err.(transport.StreamError); ok { - return toRPCErr(err) - } - if _, ok := err.(transport.ConnectionError); ok { + if err == errConnClosing { if c.failFast { - return toRPCErr(err) + return Errorf(codes.Unavailable, "%v", errConnClosing) } + continue } - // All the remaining cases are treated as retryable. - continue + // All the other errors are treated as Internal errors. + return Errorf(codes.Internal, "%v", err) } if c.traceInfo.tr != nil { c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true) diff --git a/vendor/google.golang.org/grpc/clientconn.go b/vendor/google.golang.org/grpc/clientconn.go index 53a1212..c3c7691 100644 --- a/vendor/google.golang.org/grpc/clientconn.go +++ b/vendor/google.golang.org/grpc/clientconn.go @@ -53,24 +53,28 @@ var ( // ErrClientConnClosing indicates that the operation is illegal because // the ClientConn is closing. ErrClientConnClosing = errors.New("grpc: the client connection is closing") + // ErrClientConnTimeout indicates that the ClientConn cannot establish the + // underlying connections within the specified timeout. + ErrClientConnTimeout = errors.New("grpc: timed out when dialing") // errNoTransportSecurity indicates that there is no transport security // being set for ClientConn. Users should either set one or explicitly // call WithInsecure DialOption to disable security. errNoTransportSecurity = errors.New("grpc: no transport security set (use grpc.WithInsecure() explicitly or set credentials)") - // errCredentialsMisuse indicates that users want to transmit security information - // (e.g., oauth2 token) which requires secure connection on an insecure + // errTransportCredentialsMissing indicates that users want to transmit security + // information (e.g., oauth2 token) which requires secure connection on an insecure // connection. - errCredentialsMisuse = errors.New("grpc: the credentials require transport level security (use grpc.WithTransportAuthenticator() to set)") - // errClientConnTimeout indicates that the connection could not be - // established or re-established within the specified timeout. - errClientConnTimeout = errors.New("grpc: timed out trying to connect") + errTransportCredentialsMissing = errors.New("grpc: the credentials require transport level security (use grpc.WithTransportCredentials() to set)") + // errCredentialsConflict indicates that grpc.WithTransportCredentials() + // and grpc.WithInsecure() are both called for a connection. + errCredentialsConflict = errors.New("grpc: transport credentials are set for an insecure connection (grpc.WithTransportCredentials() and grpc.WithInsecure() are both called)") // errNetworkIP indicates that the connection is down due to some network I/O error. errNetworkIO = errors.New("grpc: failed with network I/O error") // errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs. errConnDrain = errors.New("grpc: the connection is drained") // errConnClosing indicates that the connection is closing. errConnClosing = errors.New("grpc: the connection is closing") + errNoAddr = errors.New("grpc: there is no address available to dial") // minimum time to give a connection to complete minConnectTimeout = 20 * time.Second ) @@ -85,6 +89,7 @@ type dialOptions struct { balancer Balancer block bool insecure bool + timeout time.Duration copts transport.ConnectOptions } @@ -168,24 +173,25 @@ func WithInsecure() DialOption { // WithTransportCredentials returns a DialOption which configures a // connection level security credentials (e.g., TLS/SSL). -func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOption { +func WithTransportCredentials(creds credentials.TransportCredentials) DialOption { return func(o *dialOptions) { - o.copts.AuthOptions = append(o.copts.AuthOptions, creds) + o.copts.TransportCredentials = creds } } // WithPerRPCCredentials returns a DialOption which sets // credentials which will place auth state on each outbound RPC. -func WithPerRPCCredentials(creds credentials.Credentials) DialOption { +func WithPerRPCCredentials(creds credentials.PerRPCCredentials) DialOption { return func(o *dialOptions) { - o.copts.AuthOptions = append(o.copts.AuthOptions, creds) + o.copts.PerRPCCredentials = append(o.copts.PerRPCCredentials, creds) } } -// WithTimeout returns a DialOption that configures a timeout for dialing a client connection. +// WithTimeout returns a DialOption that configures a timeout for dialing a ClientConn +// initially. This is valid if and only if WithBlock() is present. func WithTimeout(d time.Duration) DialOption { return func(o *dialOptions) { - o.copts.Timeout = d + o.timeout = d } } @@ -212,42 +218,62 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) { for _, opt := range opts { opt(&cc.dopts) } + + // Set defaults. if cc.dopts.codec == nil { - // Set the default codec. cc.dopts.codec = protoCodec{} } - if cc.dopts.bs == nil { cc.dopts.bs = DefaultBackoffConfig } - - cc.balancer = cc.dopts.balancer - if cc.balancer == nil { - cc.balancer = RoundRobin(nil) + if cc.dopts.balancer == nil { + cc.dopts.balancer = RoundRobin(nil) } - if err := cc.balancer.Start(target); err != nil { + + if err := cc.dopts.balancer.Start(target); err != nil { return nil, err } - ch := cc.balancer.Notify() + var ( + ok bool + addrs []Address + ) + ch := cc.dopts.balancer.Notify() if ch == nil { // There is no name resolver installed. - addr := Address{Addr: target} - if err := cc.newAddrConn(addr, false); err != nil { - return nil, err - } + addrs = append(addrs, Address{Addr: target}) } else { - addrs, ok := <-ch + addrs, ok = <-ch if !ok || len(addrs) == 0 { - return nil, fmt.Errorf("grpc: there is no address available to dial") + return nil, errNoAddr } + } + waitC := make(chan error, 1) + go func() { for _, a := range addrs { if err := cc.newAddrConn(a, false); err != nil { - return nil, err + waitC <- err + return } } + close(waitC) + }() + var timeoutCh <-chan time.Time + if cc.dopts.timeout > 0 { + timeoutCh = time.After(cc.dopts.timeout) + } + select { + case err := <-waitC: + if err != nil { + cc.Close() + return nil, err + } + case <-timeoutCh: + cc.Close() + return nil, ErrClientConnTimeout + } + if ok { go cc.lbWatcher() } - colonPos := strings.LastIndex(target, ":") if colonPos == -1 { colonPos = len(target) @@ -292,7 +318,6 @@ func (s ConnectivityState) String() string { // ClientConn represents a client connection to an RPC server. type ClientConn struct { target string - balancer Balancer authority string dopts dialOptions @@ -301,7 +326,7 @@ type ClientConn struct { } func (cc *ClientConn) lbWatcher() { - for addrs := range cc.balancer.Notify() { + for addrs := range cc.dopts.balancer.Notify() { var ( add []Address // Addresses need to setup connections. del []*addrConn // Connections need to tear down. @@ -345,19 +370,16 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error { ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr) } if !ac.dopts.insecure { - var ok bool - for _, cd := range ac.dopts.copts.AuthOptions { - if _, ok = cd.(credentials.TransportAuthenticator); ok { - break - } - } - if !ok { + if ac.dopts.copts.TransportCredentials == nil { return errNoTransportSecurity } } else { - for _, cd := range ac.dopts.copts.AuthOptions { + if ac.dopts.copts.TransportCredentials != nil { + return errCredentialsConflict + } + for _, cd := range ac.dopts.copts.PerRPCCredentials { if cd.RequireTransportSecurity() { - return errCredentialsMisuse + return errTransportCredentialsMissing } } } @@ -400,15 +422,14 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error { } func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) (transport.ClientTransport, func(), error) { - // TODO(zhaoq): Implement fail-fast logic. - addr, put, err := cc.balancer.Get(ctx, opts) + addr, put, err := cc.dopts.balancer.Get(ctx, opts) if err != nil { - return nil, nil, err + return nil, nil, toRPCErr(err) } cc.mu.RLock() if cc.conns == nil { cc.mu.RUnlock() - return nil, nil, ErrClientConnClosing + return nil, nil, toRPCErr(ErrClientConnClosing) } ac, ok := cc.conns[addr] cc.mu.RUnlock() @@ -416,9 +437,9 @@ func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) if put != nil { put() } - return nil, nil, transport.StreamErrorf(codes.Internal, "grpc: failed to find the transport to send the rpc") + return nil, nil, Errorf(codes.Internal, "grpc: failed to find the transport to send the rpc") } - t, err := ac.wait(ctx) + t, err := ac.wait(ctx, !opts.BlockingWait) if err != nil { if put != nil { put() @@ -438,7 +459,7 @@ func (cc *ClientConn) Close() error { conns := cc.conns cc.conns = nil cc.mu.Unlock() - cc.balancer.Close() + cc.dopts.balancer.Close() for _, ac := range conns { ac.tearDown(ErrClientConnClosing) } @@ -517,7 +538,6 @@ func (ac *addrConn) waitForStateChange(ctx context.Context, sourceState Connecti func (ac *addrConn) resetTransport(closeTransport bool) error { var retries int - start := time.Now() for { ac.mu.Lock() ac.printf("connecting") @@ -537,29 +557,13 @@ func (ac *addrConn) resetTransport(closeTransport bool) error { if closeTransport && t != nil { t.Close() } - // Adjust timeout for the current try. - copts := ac.dopts.copts - if copts.Timeout < 0 { - ac.tearDown(errClientConnTimeout) - return errClientConnTimeout - } - if copts.Timeout > 0 { - copts.Timeout -= time.Since(start) - if copts.Timeout <= 0 { - ac.tearDown(errClientConnTimeout) - return errClientConnTimeout - } - } sleepTime := ac.dopts.bs.backoff(retries) - timeout := sleepTime - if timeout < minConnectTimeout { - timeout = minConnectTimeout - } - if copts.Timeout == 0 || copts.Timeout > timeout { - copts.Timeout = timeout + ac.dopts.copts.Timeout = sleepTime + if sleepTime < minConnectTimeout { + ac.dopts.copts.Timeout = minConnectTimeout } connectTime := time.Now() - newTransport, err := transport.NewClientTransport(ac.addr.Addr, &copts) + newTransport, err := transport.NewClientTransport(ac.addr.Addr, &ac.dopts.copts) if err != nil { ac.mu.Lock() if ac.state == Shutdown { @@ -579,14 +583,6 @@ func (ac *addrConn) resetTransport(closeTransport bool) error { if sleepTime < 0 { sleepTime = 0 } - // Fail early before falling into sleep. - if ac.dopts.copts.Timeout > 0 && ac.dopts.copts.Timeout < sleepTime+time.Since(start) { - ac.mu.Lock() - ac.errorf("connection timeout") - ac.mu.Unlock() - ac.tearDown(errClientConnTimeout) - return errClientConnTimeout - } closeTransport = false select { case <-time.After(sleepTime): @@ -611,7 +607,7 @@ func (ac *addrConn) resetTransport(closeTransport bool) error { close(ac.ready) ac.ready = nil } - ac.down = ac.cc.balancer.Up(ac.addr) + ac.down = ac.cc.dopts.balancer.Up(ac.addr) ac.mu.Unlock() return nil } @@ -650,8 +646,9 @@ func (ac *addrConn) transportMonitor() { } } -// wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed. -func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error) { +// wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed or +// iv) transport is in TransientFailure and the RPC is fail-fast. +func (ac *addrConn) wait(ctx context.Context, failFast bool) (transport.ClientTransport, error) { for { ac.mu.Lock() switch { @@ -662,6 +659,9 @@ func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error) ct := ac.transport ac.mu.Unlock() return ct, nil + case ac.state == TransientFailure && failFast: + ac.mu.Unlock() + return nil, Errorf(codes.Unavailable, "grpc: RPC failed fast due to transport failure") default: ready := ac.ready if ready == nil { @@ -671,7 +671,7 @@ func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error) ac.mu.Unlock() select { case <-ctx.Done(): - return nil, transport.ContextErr(ctx.Err()) + return nil, toRPCErr(ctx.Err()) // Wait until the new transport is ready or failed. case <-ready: } diff --git a/vendor/google.golang.org/grpc/credentials/credentials.go b/vendor/google.golang.org/grpc/credentials/credentials.go index 681f64e..8d4c57c 100644 --- a/vendor/google.golang.org/grpc/credentials/credentials.go +++ b/vendor/google.golang.org/grpc/credentials/credentials.go @@ -54,9 +54,9 @@ var ( alpnProtoStr = []string{"h2"} ) -// Credentials defines the common interface all supported credentials must -// implement. -type Credentials interface { +// PerRPCCredentials defines the common interface for the credentials which need to +// attach security information to every RPC (e.g., oauth2). +type PerRPCCredentials interface { // GetRequestMetadata gets the current request metadata, refreshing // tokens if required. This should be called by the transport layer on // each request, and the data should be populated in headers or other @@ -66,7 +66,7 @@ type Credentials interface { // TODO(zhaoq): Define the set of the qualified keys instead of leaving // it as an arbitrary string. GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) - // RequireTransportSecurity indicates whether the credentails requires + // RequireTransportSecurity indicates whether the credentials requires // transport security. RequireTransportSecurity() bool } @@ -87,9 +87,9 @@ type AuthInfo interface { AuthType() string } -// TransportAuthenticator defines the common interface for all the live gRPC wire +// TransportCredentials defines the common interface for all the live gRPC wire // protocols and supported transport security protocols (e.g., TLS, SSL). -type TransportAuthenticator interface { +type TransportCredentials interface { // ClientHandshake does the authentication handshake specified by the corresponding // authentication protocol on rawConn for clients. It returns the authenticated // connection and the corresponding auth information about the connection. @@ -98,9 +98,8 @@ type TransportAuthenticator interface { // the authenticated connection and the corresponding auth information about // the connection. ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) - // Info provides the ProtocolInfo of this TransportAuthenticator. + // Info provides the ProtocolInfo of this TransportCredentials. Info() ProtocolInfo - Credentials } // TLSInfo contains the auth information for a TLS authenticated connection. @@ -109,6 +108,7 @@ type TLSInfo struct { State tls.ConnectionState } +// AuthType returns the type of TLSInfo as a string. func (t TLSInfo) AuthType() string { return "tls" } @@ -116,7 +116,7 @@ func (t TLSInfo) AuthType() string { // tlsCreds is the credentials required for authenticating a connection using TLS. type tlsCreds struct { // TLS configuration - config tls.Config + config *tls.Config } func (c tlsCreds) Info() ProtocolInfo { @@ -151,14 +151,16 @@ func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.D errChannel <- timeoutError{} }) } + // use local cfg to avoid clobbering ServerName if using multiple endpoints + cfg := *c.config if c.config.ServerName == "" { colonPos := strings.LastIndex(addr, ":") if colonPos == -1 { colonPos = len(addr) } - c.config.ServerName = addr[:colonPos] + cfg.ServerName = addr[:colonPos] } - conn := tls.Client(rawConn, &c.config) + conn := tls.Client(rawConn, &cfg) if timeout == 0 { err = conn.Handshake() } else { @@ -177,7 +179,7 @@ func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.D } func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) { - conn := tls.Server(rawConn, &c.config) + conn := tls.Server(rawConn, c.config) if err := conn.Handshake(); err != nil { rawConn.Close() return nil, nil, err @@ -185,20 +187,20 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) return conn, TLSInfo{conn.ConnectionState()}, nil } -// NewTLS uses c to construct a TransportAuthenticator based on TLS. -func NewTLS(c *tls.Config) TransportAuthenticator { - tc := &tlsCreds{*c} +// NewTLS uses c to construct a TransportCredentials based on TLS. +func NewTLS(c *tls.Config) TransportCredentials { + tc := &tlsCreds{c} tc.config.NextProtos = alpnProtoStr return tc } // NewClientTLSFromCert constructs a TLS from the input certificate for client. -func NewClientTLSFromCert(cp *x509.CertPool, serverName string) TransportAuthenticator { +func NewClientTLSFromCert(cp *x509.CertPool, serverName string) TransportCredentials { return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp}) } // NewClientTLSFromFile constructs a TLS from the input certificate file for client. -func NewClientTLSFromFile(certFile, serverName string) (TransportAuthenticator, error) { +func NewClientTLSFromFile(certFile, serverName string) (TransportCredentials, error) { b, err := ioutil.ReadFile(certFile) if err != nil { return nil, err @@ -211,13 +213,13 @@ func NewClientTLSFromFile(certFile, serverName string) (TransportAuthenticator, } // NewServerTLSFromCert constructs a TLS from the input certificate for server. -func NewServerTLSFromCert(cert *tls.Certificate) TransportAuthenticator { +func NewServerTLSFromCert(cert *tls.Certificate) TransportCredentials { return NewTLS(&tls.Config{Certificates: []tls.Certificate{*cert}}) } // NewServerTLSFromFile constructs a TLS from the input certificate file and key // file for server. -func NewServerTLSFromFile(certFile, keyFile string) (TransportAuthenticator, error) { +func NewServerTLSFromFile(certFile, keyFile string) (TransportCredentials, error) { cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { return nil, err diff --git a/vendor/google.golang.org/grpc/credentials/oauth/oauth.go b/vendor/google.golang.org/grpc/credentials/oauth/oauth.go index 04943fd..8e68c4d 100644 --- a/vendor/google.golang.org/grpc/credentials/oauth/oauth.go +++ b/vendor/google.golang.org/grpc/credentials/oauth/oauth.go @@ -45,7 +45,7 @@ import ( "google.golang.org/grpc/credentials" ) -// TokenSource supplies credentials from an oauth2.TokenSource. +// TokenSource supplies PerRPCCredentials from an oauth2.TokenSource. type TokenSource struct { oauth2.TokenSource } @@ -57,10 +57,11 @@ func (ts TokenSource) GetRequestMetadata(ctx context.Context, uri ...string) (ma return nil, err } return map[string]string{ - "authorization": token.TokenType + " " + token.AccessToken, + "authorization": token.Type() + " " + token.AccessToken, }, nil } +// RequireTransportSecurity indicates whether the credentails requires transport security. func (ts TokenSource) RequireTransportSecurity() bool { return true } @@ -69,7 +70,8 @@ type jwtAccess struct { jsonKey []byte } -func NewJWTAccessFromFile(keyFile string) (credentials.Credentials, error) { +// NewJWTAccessFromFile creates PerRPCCredentials from the given keyFile. +func NewJWTAccessFromFile(keyFile string) (credentials.PerRPCCredentials, error) { jsonKey, err := ioutil.ReadFile(keyFile) if err != nil { return nil, fmt.Errorf("credentials: failed to read the service account key file: %v", err) @@ -77,7 +79,8 @@ func NewJWTAccessFromFile(keyFile string) (credentials.Credentials, error) { return NewJWTAccessFromKey(jsonKey) } -func NewJWTAccessFromKey(jsonKey []byte) (credentials.Credentials, error) { +// NewJWTAccessFromKey creates PerRPCCredentials from the given jsonKey. +func NewJWTAccessFromKey(jsonKey []byte) (credentials.PerRPCCredentials, error) { return jwtAccess{jsonKey}, nil } @@ -99,13 +102,13 @@ func (j jwtAccess) RequireTransportSecurity() bool { return true } -// oauthAccess supplies credentials from a given token. +// oauthAccess supplies PerRPCCredentials from a given token. type oauthAccess struct { token oauth2.Token } -// NewOauthAccess constructs the credentials using a given token. -func NewOauthAccess(token *oauth2.Token) credentials.Credentials { +// NewOauthAccess constructs the PerRPCCredentials using a given token. +func NewOauthAccess(token *oauth2.Token) credentials.PerRPCCredentials { return oauthAccess{token: *token} } @@ -119,15 +122,15 @@ func (oa oauthAccess) RequireTransportSecurity() bool { return true } -// NewComputeEngine constructs the credentials that fetches access tokens from +// NewComputeEngine constructs the PerRPCCredentials that fetches access tokens from // Google Compute Engine (GCE)'s metadata server. It is only valid to use this // if your program is running on a GCE instance. // TODO(dsymonds): Deprecate and remove this. -func NewComputeEngine() credentials.Credentials { +func NewComputeEngine() credentials.PerRPCCredentials { return TokenSource{google.ComputeTokenSource("")} } -// serviceAccount represents credentials via JWT signing key. +// serviceAccount represents PerRPCCredentials via JWT signing key. type serviceAccount struct { config *jwt.Config } @@ -146,9 +149,9 @@ func (s serviceAccount) RequireTransportSecurity() bool { return true } -// NewServiceAccountFromKey constructs the credentials using the JSON key slice +// NewServiceAccountFromKey constructs the PerRPCCredentials using the JSON key slice // from a Google Developers service account. -func NewServiceAccountFromKey(jsonKey []byte, scope ...string) (credentials.Credentials, error) { +func NewServiceAccountFromKey(jsonKey []byte, scope ...string) (credentials.PerRPCCredentials, error) { config, err := google.JWTConfigFromJSON(jsonKey, scope...) if err != nil { return nil, err @@ -156,9 +159,9 @@ func NewServiceAccountFromKey(jsonKey []byte, scope ...string) (credentials.Cred return serviceAccount{config: config}, nil } -// NewServiceAccountFromFile constructs the credentials using the JSON key file +// NewServiceAccountFromFile constructs the PerRPCCredentials using the JSON key file // of a Google Developers service account. -func NewServiceAccountFromFile(keyFile string, scope ...string) (credentials.Credentials, error) { +func NewServiceAccountFromFile(keyFile string, scope ...string) (credentials.PerRPCCredentials, error) { jsonKey, err := ioutil.ReadFile(keyFile) if err != nil { return nil, fmt.Errorf("credentials: failed to read the service account key file: %v", err) @@ -168,7 +171,7 @@ func NewServiceAccountFromFile(keyFile string, scope ...string) (credentials.Cre // NewApplicationDefault returns "Application Default Credentials". For more // detail, see https://developers.google.com/accounts/docs/application-default-credentials. -func NewApplicationDefault(ctx context.Context, scope ...string) (credentials.Credentials, error) { +func NewApplicationDefault(ctx context.Context, scope ...string) (credentials.PerRPCCredentials, error) { t, err := google.DefaultTokenSource(ctx, scope...) if err != nil { return nil, err diff --git a/vendor/google.golang.org/grpc/rpc_util.go b/vendor/google.golang.org/grpc/rpc_util.go index 06544ad..d628717 100644 --- a/vendor/google.golang.org/grpc/rpc_util.go +++ b/vendor/google.golang.org/grpc/rpc_util.go @@ -61,7 +61,7 @@ type Codec interface { String() string } -// protoCodec is a Codec implemetation with protobuf. It is the default codec for gRPC. +// protoCodec is a Codec implementation with protobuf. It is the default codec for gRPC. type protoCodec struct{} func (protoCodec) Marshal(v interface{}) ([]byte, error) { @@ -141,6 +141,8 @@ type callInfo struct { traceInfo traceInfo // in trace.go } +var defaultCallInfo = callInfo{failFast: true} + // CallOption configures a Call before it starts or extracts information from // a Call after it completes. type CallOption interface { @@ -179,6 +181,19 @@ func Trailer(md *metadata.MD) CallOption { }) } +// FailFast configures the action to take when an RPC is attempted on broken +// connections or unreachable servers. If failfast is true, the RPC will fail +// immediately. Otherwise, the RPC client will block the call until a +// connection is available (or the call is canceled or times out) and will retry +// the call if it fails due to a transient error. Please refer to +// https://github.com/grpc/grpc/blob/master/doc/fail_fast.md +func FailFast(failFast bool) CallOption { + return beforeCall(func(c *callInfo) error { + c.failFast = failFast + return nil + }) +} + // The format of the payload: compressed or not? type payloadFormat uint8 @@ -187,7 +202,7 @@ const ( compressionMade ) -// parser reads complelete gRPC messages from the underlying reader. +// parser reads complete gRPC messages from the underlying reader. type parser struct { // r is the underlying reader. // See the comment on recvMsg for the permissible @@ -319,7 +334,7 @@ type rpcError struct { desc string } -func (e rpcError) Error() string { +func (e *rpcError) Error() string { return fmt.Sprintf("rpc error: code = %d desc = %s", e.code, e.desc) } @@ -329,7 +344,7 @@ func Code(err error) codes.Code { if err == nil { return codes.OK } - if e, ok := err.(rpcError); ok { + if e, ok := err.(*rpcError); ok { return e.code } return codes.Unknown @@ -341,7 +356,7 @@ func ErrorDesc(err error) string { if err == nil { return "" } - if e, ok := err.(rpcError); ok { + if e, ok := err.(*rpcError); ok { return e.desc } return err.Error() @@ -353,7 +368,7 @@ func Errorf(c codes.Code, format string, a ...interface{}) error { if c == codes.OK { return nil } - return rpcError{ + return &rpcError{ code: c, desc: fmt.Sprintf(format, a...), } @@ -362,18 +377,37 @@ func Errorf(c codes.Code, format string, a ...interface{}) error { // toRPCErr converts an error into a rpcError. func toRPCErr(err error) error { switch e := err.(type) { - case rpcError: + case *rpcError: return err case transport.StreamError: - return rpcError{ + return &rpcError{ code: e.Code, desc: e.Desc, } case transport.ConnectionError: - return rpcError{ + return &rpcError{ code: codes.Internal, desc: e.Desc, } + default: + switch err { + case context.DeadlineExceeded: + return &rpcError{ + code: codes.DeadlineExceeded, + desc: err.Error(), + } + case context.Canceled: + return &rpcError{ + code: codes.Canceled, + desc: err.Error(), + } + case ErrClientConnClosing: + return &rpcError{ + code: codes.FailedPrecondition, + desc: err.Error(), + } + } + } return Errorf(codes.Unknown, "%v", err) } @@ -406,10 +440,10 @@ func convertCode(err error) codes.Code { return codes.Unknown } -// SupportPackageIsVersion2 is referenced from generated protocol buffer files +// SupportPackageIsVersion3 is referenced from generated protocol buffer files // to assert that that code is compatible with this version of the grpc package. // // This constant may be renamed in the future if a change in the generated code // requires a synchronised update of grpc-go and protoc-gen-go. This constant // should not be referenced from any other code. -const SupportPackageIsVersion2 = true +const SupportPackageIsVersion3 = true diff --git a/vendor/google.golang.org/grpc/server.go b/vendor/google.golang.org/grpc/server.go index bfb9c60..a2b2b94 100644 --- a/vendor/google.golang.org/grpc/server.go +++ b/vendor/google.golang.org/grpc/server.go @@ -73,6 +73,7 @@ type ServiceDesc struct { HandlerType interface{} Methods []MethodDesc Streams []StreamDesc + Metadata interface{} } // service consists of the information of the server serving this service and @@ -81,6 +82,7 @@ type service struct { server interface{} // the server for service methods md map[string]*MethodDesc sd map[string]*StreamDesc + mdata interface{} } // Server is a gRPC server to serve RPC requests. @@ -95,7 +97,7 @@ type Server struct { } type options struct { - creds credentials.Credentials + creds credentials.TransportCredentials codec Codec cp Compressor dc Decompressor @@ -138,7 +140,7 @@ func MaxConcurrentStreams(n uint32) ServerOption { } // Creds returns a ServerOption that sets credentials for server connections. -func Creds(c credentials.Credentials) ServerOption { +func Creds(c credentials.TransportCredentials) ServerOption { return func(o *options) { o.creds = c } @@ -230,6 +232,7 @@ func (s *Server) register(sd *ServiceDesc, ss interface{}) { server: ss, md: make(map[string]*MethodDesc), sd: make(map[string]*StreamDesc), + mdata: sd.Metadata, } for i := range sd.Methods { d := &sd.Methods[i] @@ -242,6 +245,52 @@ func (s *Server) register(sd *ServiceDesc, ss interface{}) { s.m[sd.ServiceName] = srv } +// MethodInfo contains the information of an RPC including its method name and type. +type MethodInfo struct { + // Name is the method name only, without the service name or package name. + Name string + // IsClientStream indicates whether the RPC is a client streaming RPC. + IsClientStream bool + // IsServerStream indicates whether the RPC is a server streaming RPC. + IsServerStream bool +} + +// ServiceInfo contains unary RPC method info, streaming RPC methid info and metadata for a service. +type ServiceInfo struct { + Methods []MethodInfo + // Metadata is the metadata specified in ServiceDesc when registering service. + Metadata interface{} +} + +// GetServiceInfo returns a map from service names to ServiceInfo. +// Service names include the package names, in the form of .. +func (s *Server) GetServiceInfo() map[string]*ServiceInfo { + ret := make(map[string]*ServiceInfo) + for n, srv := range s.m { + methods := make([]MethodInfo, 0, len(srv.md)+len(srv.sd)) + for m := range srv.md { + methods = append(methods, MethodInfo{ + Name: m, + IsClientStream: false, + IsServerStream: false, + }) + } + for m, d := range srv.sd { + methods = append(methods, MethodInfo{ + Name: m, + IsClientStream: d.ClientStreams, + IsServerStream: d.ServerStreams, + }) + } + + ret[n] = &ServiceInfo{ + Methods: methods, + Metadata: srv.mdata, + } + } + return ret +} + var ( // ErrServerStopped indicates that the operation is now illegal because of // the server being stopped. @@ -249,11 +298,10 @@ var ( ) func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { - creds, ok := s.opts.creds.(credentials.TransportAuthenticator) - if !ok { + if s.opts.creds == nil { return rawConn, nil, nil } - return creds.ServerHandshake(rawConn) + return s.opts.creds.ServerHandshake(rawConn) } // Serve accepts incoming connections on the listener lis, creating a new @@ -272,9 +320,11 @@ func (s *Server) Serve(lis net.Listener) error { s.lis[lis] = true s.mu.Unlock() defer func() { - lis.Close() s.mu.Lock() - delete(s.lis, lis) + if s.lis != nil && s.lis[lis] { + lis.Close() + delete(s.lis, lis) + } s.mu.Unlock() }() for { @@ -529,7 +579,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt) if appErr != nil { - if err, ok := appErr.(rpcError); ok { + if err, ok := appErr.(*rpcError); ok { statusCode = err.code statusDesc = err.desc } else { @@ -614,7 +664,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp appErr = s.opts.streamInt(srv.server, ss, info, sd.Handler) } if appErr != nil { - if err, ok := appErr.(rpcError); ok { + if err, ok := appErr.(*rpcError); ok { ss.statusCode = err.code ss.statusDesc = err.desc } else if err, ok := appErr.(transport.StreamError); ok { diff --git a/vendor/google.golang.org/grpc/stream.go b/vendor/google.golang.org/grpc/stream.go index de125d5..7a3bef5 100644 --- a/vendor/google.golang.org/grpc/stream.go +++ b/vendor/google.golang.org/grpc/stream.go @@ -79,7 +79,7 @@ type Stream interface { RecvMsg(m interface{}) error } -// ClientStream defines the interface a client stream has to satify. +// ClientStream defines the interface a client stream has to satisfy. type ClientStream interface { // Header returns the header metadata received from the server if there // is any. It blocks if the metadata is not ready to read. @@ -102,16 +102,15 @@ type ClientStream interface { func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) { var ( t transport.ClientTransport + s *transport.Stream err error put func() ) - // TODO(zhaoq): CallOption is omitted. Add support when it is needed. - gopts := BalancerGetOptions{ - BlockingWait: false, - } - t, put, err = cc.getTransport(ctx, gopts) - if err != nil { - return nil, toRPCErr(err) + c := defaultCallInfo + for _, o := range opts { + if err := o.before(&c); err != nil { + return nil, toRPCErr(err) + } } callHdr := &transport.CallHdr{ Host: cc.authority, @@ -122,8 +121,9 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth callHdr.SendCompress = cc.dopts.cp.Type() } cs := &clientStream{ + opts: opts, + c: c, desc: desc, - put: put, codec: cc.dopts.codec, cp: cc.dopts.cp, dc: cc.dopts.dc, @@ -142,11 +142,44 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth cs.trInfo.tr.LazyLog(&cs.trInfo.firstLine, false) ctx = trace.NewContext(ctx, cs.trInfo.tr) } - s, err := t.NewStream(ctx, callHdr) - if err != nil { - cs.finish(err) - return nil, toRPCErr(err) + gopts := BalancerGetOptions{ + BlockingWait: !c.failFast, } + for { + t, put, err = cc.getTransport(ctx, gopts) + if err != nil { + // TODO(zhaoq): Probably revisit the error handling. + if _, ok := err.(*rpcError); ok { + return nil, err + } + if err == errConnClosing { + if c.failFast { + return nil, Errorf(codes.Unavailable, "%v", errConnClosing) + } + continue + } + // All the other errors are treated as Internal errors. + return nil, Errorf(codes.Internal, "%v", err) + } + + s, err = t.NewStream(ctx, callHdr) + if err != nil { + if put != nil { + put() + put = nil + } + if _, ok := err.(transport.ConnectionError); ok { + if c.failFast { + cs.finish(err) + return nil, toRPCErr(err) + } + continue + } + return nil, toRPCErr(err) + } + break + } + cs.put = put cs.t = t cs.s = s cs.p = &parser{r: s} @@ -167,6 +200,8 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth // clientStream implements a client side Stream. type clientStream struct { + opts []CallOption + c callInfo t transport.ClientTransport s *transport.Stream p *parser @@ -312,15 +347,18 @@ func (cs *clientStream) closeTransportStream(err error) { } func (cs *clientStream) finish(err error) { - if !cs.tracing { - return - } cs.mu.Lock() defer cs.mu.Unlock() + for _, o := range cs.opts { + o.after(&cs.c) + } if cs.put != nil { cs.put() cs.put = nil } + if !cs.tracing { + return + } if cs.trInfo.tr != nil { if err == nil || err == io.EOF { cs.trInfo.tr.LazyPrintf("RPC: [OK]") diff --git a/vendor/google.golang.org/grpc/transport/handler_server.go b/vendor/google.golang.org/grpc/transport/handler_server.go index 7a4ae07..4b0d525 100644 --- a/vendor/google.golang.org/grpc/transport/handler_server.go +++ b/vendor/google.golang.org/grpc/transport/handler_server.go @@ -65,7 +65,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr if r.Method != "POST" { return nil, errors.New("invalid gRPC request method") } - if !strings.Contains(r.Header.Get("Content-Type"), "application/grpc") { + if !validContentType(r.Header.Get("Content-Type")) { return nil, errors.New("invalid gRPC request content-type") } if _, ok := w.(http.Flusher); !ok { @@ -97,7 +97,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr } for k, vv := range r.Header { k = strings.ToLower(k) - if isReservedHeader(k) && !isWhitelistedPseudoHeader(k){ + if isReservedHeader(k) && !isWhitelistedPseudoHeader(k) { continue } for _, v := range vv { @@ -312,7 +312,7 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream)) { Addr: ht.RemoteAddr(), } if req.TLS != nil { - pr.AuthInfo = credentials.TLSInfo{*req.TLS} + pr.AuthInfo = credentials.TLSInfo{State: *req.TLS} } ctx = metadata.NewContext(ctx, ht.headerMD) ctx = peer.NewContext(ctx, pr) diff --git a/vendor/google.golang.org/grpc/transport/http2_client.go b/vendor/google.golang.org/grpc/transport/http2_client.go index e624f8d..f66435f 100644 --- a/vendor/google.golang.org/grpc/transport/http2_client.go +++ b/vendor/google.golang.org/grpc/transport/http2_client.go @@ -88,7 +88,7 @@ type http2Client struct { // The scheme used: https if TLS is on, http otherwise. scheme string - authCreds []credentials.Credentials + creds []credentials.PerRPCCredentials mu sync.Mutex // guard the following variables state transportState // the state of underlying connection @@ -117,19 +117,12 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e return nil, ConnectionErrorf("transport: %v", connErr) } var authInfo credentials.AuthInfo - for _, c := range opts.AuthOptions { - if ccreds, ok := c.(credentials.TransportAuthenticator); ok { - scheme = "https" - // TODO(zhaoq): Now the first TransportAuthenticator is used if there are - // multiple ones provided. Revisit this if it is not appropriate. Probably - // place the ClientTransport construction into a separate function to make - // things clear. - if timeout > 0 { - timeout -= time.Since(startT) - } - conn, authInfo, connErr = ccreds.ClientHandshake(addr, conn, timeout) - break + if opts.TransportCredentials != nil { + scheme = "https" + if timeout > 0 { + timeout -= time.Since(startT) } + conn, authInfo, connErr = opts.TransportCredentials.ClientHandshake(addr, conn, timeout) } if connErr != nil { return nil, ConnectionErrorf("transport: %v", connErr) @@ -163,7 +156,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e scheme: scheme, state: reachable, activeStreams: make(map[uint32]*Stream), - authCreds: opts.AuthOptions, + creds: opts.PerRPCCredentials, maxStreams: math.MaxInt32, streamSendQuota: defaultWindowSize, } @@ -182,7 +175,10 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e return nil, ConnectionErrorf("transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface)) } if initialWindowSize != defaultWindowSize { - err = t.framer.writeSettings(true, http2.Setting{http2.SettingInitialWindowSize, uint32(initialWindowSize)}) + err = t.framer.writeSettings(true, http2.Setting{ + ID: http2.SettingInitialWindowSize, + Val: uint32(initialWindowSize), + }) } else { err = t.framer.writeSettings(true) } @@ -248,7 +244,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea } ctx = peer.NewContext(ctx, pr) authData := make(map[string]string) - for _, c := range t.authCreds { + for _, c := range t.creds { // Construct URI required to get auth request metadata. var port string if pos := strings.LastIndex(t.target, ":"); pos != -1 { diff --git a/vendor/google.golang.org/grpc/transport/http2_server.go b/vendor/google.golang.org/grpc/transport/http2_server.go index 1c4d585..cee1542 100644 --- a/vendor/google.golang.org/grpc/transport/http2_server.go +++ b/vendor/google.golang.org/grpc/transport/http2_server.go @@ -100,10 +100,15 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthI if maxStreams == 0 { maxStreams = math.MaxUint32 } else { - settings = append(settings, http2.Setting{http2.SettingMaxConcurrentStreams, maxStreams}) + settings = append(settings, http2.Setting{ + ID: http2.SettingMaxConcurrentStreams, + Val: maxStreams, + }) } if initialWindowSize != defaultWindowSize { - settings = append(settings, http2.Setting{http2.SettingInitialWindowSize, uint32(initialWindowSize)}) + settings = append(settings, http2.Setting{ + ID: http2.SettingInitialWindowSize, + Val: uint32(initialWindowSize)}) } if err := framer.writeSettings(true, settings...); err != nil { return nil, ConnectionErrorf("transport: %v", err) diff --git a/vendor/google.golang.org/grpc/transport/http_util.go b/vendor/google.golang.org/grpc/transport/http_util.go index a4b1b07..f2e23dc 100644 --- a/vendor/google.golang.org/grpc/transport/http_util.go +++ b/vendor/google.golang.org/grpc/transport/http_util.go @@ -144,10 +144,23 @@ func (d *decodeState) setErr(err error) { } } +func validContentType(t string) bool { + e := "application/grpc" + if !strings.HasPrefix(t, e) { + return false + } + // Support variations on the content-type + // (e.g. "application/grpc+blah", "application/grpc;blah"). + if len(t) > len(e) && t[len(e)] != '+' && t[len(e)] != ';' { + return false + } + return true +} + func (d *decodeState) processHeaderField(f hpack.HeaderField) { switch f.Name { case "content-type": - if !strings.Contains(f.Value, "application/grpc") { + if !validContentType(f.Value) { d.setErr(StreamErrorf(codes.FailedPrecondition, "transport: received the unexpected content-type %q", f.Value)) return } diff --git a/vendor/google.golang.org/grpc/transport/transport.go b/vendor/google.golang.org/grpc/transport/transport.go index 1c9af54..d4c220a 100644 --- a/vendor/google.golang.org/grpc/transport/transport.go +++ b/vendor/google.golang.org/grpc/transport/transport.go @@ -336,9 +336,11 @@ type ConnectOptions struct { UserAgent string // Dialer specifies how to dial a network address. Dialer func(string, time.Duration) (net.Conn, error) - // AuthOptions stores the credentials required to setup a client connection and/or issue RPCs. - AuthOptions []credentials.Credentials - // Timeout specifies the timeout for dialing a client connection. + // PerRPCCredentials stores the PerRPCCredentials required to issue RPCs. + PerRPCCredentials []credentials.PerRPCCredentials + // TransportCredentials stores the Authenticator required to setup a client connection. + TransportCredentials credentials.TransportCredentials + // Timeout specifies the timeout for dialing a ClientTransport. Timeout time.Duration } @@ -473,7 +475,7 @@ func (e ConnectionError) Error() string { return fmt.Sprintf("connection error: desc = %q", e.Desc) } -// Define some common ConnectionErrors. +// ErrConnClosing indicates that the transport is closing. var ErrConnClosing = ConnectionError{Desc: "transport is closing"} // StreamError is an error that only affects one stream within a connection. -- cgit v1.2.3