aboutsummaryrefslogtreecommitdiff
path: root/vendor/google.golang.org/grpc/credentials/credentials.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/google.golang.org/grpc/credentials/credentials.go')
-rw-r--r--vendor/google.golang.org/grpc/credentials/credentials.go60
1 files changed, 26 insertions, 34 deletions
diff --git a/vendor/google.golang.org/grpc/credentials/credentials.go b/vendor/google.golang.org/grpc/credentials/credentials.go
index 8d4c57c..13be457 100644
--- a/vendor/google.golang.org/grpc/credentials/credentials.go
+++ b/vendor/google.golang.org/grpc/credentials/credentials.go
@@ -40,11 +40,11 @@ package credentials // import "google.golang.org/grpc/credentials"
import (
"crypto/tls"
"crypto/x509"
+ "errors"
"fmt"
"io/ioutil"
"net"
"strings"
- "time"
"golang.org/x/net/context"
)
@@ -87,17 +87,24 @@ type AuthInfo interface {
AuthType() string
}
+var (
+ // ErrConnDispatched indicates that rawConn has been dispatched out of gRPC
+ // and the caller should not close rawConn.
+ ErrConnDispatched = errors.New("credentials: rawConn is dispatched out of gRPC")
+)
+
// TransportCredentials defines the common interface for all the live gRPC wire
// protocols and supported transport security protocols (e.g., TLS, SSL).
type TransportCredentials interface {
// ClientHandshake does the authentication handshake specified by the corresponding
// authentication protocol on rawConn for clients. It returns the authenticated
// connection and the corresponding auth information about the connection.
- ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, AuthInfo, error)
+ // Implementations must use the provided context to implement timely cancellation.
+ ClientHandshake(context.Context, string, net.Conn) (net.Conn, AuthInfo, error)
// ServerHandshake does the authentication handshake for servers. It returns
// the authenticated connection and the corresponding auth information about
// the connection.
- ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
+ ServerHandshake(net.Conn) (net.Conn, AuthInfo, error)
// Info provides the ProtocolInfo of this TransportCredentials.
Info() ProtocolInfo
}
@@ -136,42 +143,28 @@ func (c *tlsCreds) RequireTransportSecurity() bool {
return true
}
-type timeoutError struct{}
-
-func (timeoutError) Error() string { return "credentials: Dial timed out" }
-func (timeoutError) Timeout() bool { return true }
-func (timeoutError) Temporary() bool { return true }
-
-func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, _ AuthInfo, err error) {
- // borrow some code from tls.DialWithDialer
- var errChannel chan error
- if timeout != 0 {
- errChannel = make(chan error, 2)
- time.AfterFunc(timeout, func() {
- errChannel <- timeoutError{}
- })
- }
+func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ AuthInfo, err error) {
// use local cfg to avoid clobbering ServerName if using multiple endpoints
- cfg := *c.config
- if c.config.ServerName == "" {
+ cfg := cloneTLSConfig(c.config)
+ if cfg.ServerName == "" {
colonPos := strings.LastIndex(addr, ":")
if colonPos == -1 {
colonPos = len(addr)
}
cfg.ServerName = addr[:colonPos]
}
- conn := tls.Client(rawConn, &cfg)
- if timeout == 0 {
- err = conn.Handshake()
- } else {
- go func() {
- errChannel <- conn.Handshake()
- }()
- err = <-errChannel
- }
- if err != nil {
- rawConn.Close()
- return nil, nil, err
+ conn := tls.Client(rawConn, cfg)
+ errChannel := make(chan error, 1)
+ go func() {
+ errChannel <- conn.Handshake()
+ }()
+ select {
+ case err := <-errChannel:
+ if err != nil {
+ return nil, nil, err
+ }
+ case <-ctx.Done():
+ return nil, nil, ctx.Err()
}
// TODO(zhaoq): Omit the auth info for client now. It is more for
// information than anything else.
@@ -181,7 +174,6 @@ func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.D
func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) {
conn := tls.Server(rawConn, c.config)
if err := conn.Handshake(); err != nil {
- rawConn.Close()
return nil, nil, err
}
return conn, TLSInfo{conn.ConnectionState()}, nil
@@ -189,7 +181,7 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
// NewTLS uses c to construct a TransportCredentials based on TLS.
func NewTLS(c *tls.Config) TransportCredentials {
- tc := &tlsCreds{c}
+ tc := &tlsCreds{cloneTLSConfig(c)}
tc.config.NextProtos = alpnProtoStr
return tc
}