aboutsummaryrefslogtreecommitdiff
path: root/vendor/google.golang.org/grpc/transport/http2_client.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/google.golang.org/grpc/transport/http2_client.go')
-rw-r--r--vendor/google.golang.org/grpc/transport/http2_client.go237
1 files changed, 172 insertions, 65 deletions
diff --git a/vendor/google.golang.org/grpc/transport/http2_client.go b/vendor/google.golang.org/grpc/transport/http2_client.go
index f66435f..afbba45 100644
--- a/vendor/google.golang.org/grpc/transport/http2_client.go
+++ b/vendor/google.golang.org/grpc/transport/http2_client.go
@@ -35,6 +35,7 @@ package transport
import (
"bytes"
+ "fmt"
"io"
"math"
"net"
@@ -71,6 +72,9 @@ type http2Client struct {
shutdownChan chan struct{}
// errorChan is closed to notify the I/O error to the caller.
errorChan chan struct{}
+ // goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor)
+ // that the server sent GoAway on this transport.
+ goAway chan struct{}
framer *framer
hBuf *bytes.Buffer // the buffer for HPACK encoding
@@ -97,41 +101,73 @@ type http2Client struct {
maxStreams int
// the per-stream outbound flow control window size set by the peer.
streamSendQuota uint32
+ // goAwayID records the Last-Stream-ID in the GoAway frame from the server.
+ goAwayID uint32
+ // prevGoAway ID records the Last-Stream-ID in the previous GOAway frame.
+ prevGoAwayID uint32
+}
+
+func dial(fn func(context.Context, string) (net.Conn, error), ctx context.Context, addr string) (net.Conn, error) {
+ if fn != nil {
+ return fn(ctx, addr)
+ }
+ return dialContext(ctx, "tcp", addr)
+}
+
+func isTemporary(err error) bool {
+ switch err {
+ case io.EOF:
+ // Connection closures may be resolved upon retry, and are thus
+ // treated as temporary.
+ return true
+ case context.DeadlineExceeded:
+ // In Go 1.7, context.DeadlineExceeded implements Timeout(), and this
+ // special case is not needed. Until then, we need to keep this
+ // clause.
+ return true
+ }
+
+ switch err := err.(type) {
+ case interface {
+ Temporary() bool
+ }:
+ return err.Temporary()
+ case interface {
+ Timeout() bool
+ }:
+ // Timeouts may be resolved upon retry, and are thus treated as
+ // temporary.
+ return err.Timeout()
+ }
+ return false
}
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
// and starts to receive messages on it. Non-nil error returns if construction
// fails.
-func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err error) {
- if opts.Dialer == nil {
- // Set the default Dialer.
- opts.Dialer = func(addr string, timeout time.Duration) (net.Conn, error) {
- return net.DialTimeout("tcp", addr, timeout)
- }
- }
+func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ ClientTransport, err error) {
scheme := "http"
- startT := time.Now()
- timeout := opts.Timeout
- conn, connErr := opts.Dialer(addr, timeout)
- if connErr != nil {
- return nil, ConnectionErrorf("transport: %v", connErr)
+ conn, err := dial(opts.Dialer, ctx, addr)
+ if err != nil {
+ return nil, ConnectionErrorf(true, err, "transport: %v", err)
}
+ // Any further errors will close the underlying connection
+ defer func(conn net.Conn) {
+ if err != nil {
+ conn.Close()
+ }
+ }(conn)
var authInfo credentials.AuthInfo
- if opts.TransportCredentials != nil {
+ if creds := opts.TransportCredentials; creds != nil {
scheme = "https"
- if timeout > 0 {
- timeout -= time.Since(startT)
- }
- conn, authInfo, connErr = opts.TransportCredentials.ClientHandshake(addr, conn, timeout)
- }
- if connErr != nil {
- return nil, ConnectionErrorf("transport: %v", connErr)
- }
- defer func() {
+ conn, authInfo, err = creds.ClientHandshake(ctx, addr, conn)
if err != nil {
- conn.Close()
+ // Credentials handshake errors are typically considered permanent
+ // to avoid retrying on e.g. bad certificates.
+ temp := isTemporary(err)
+ return nil, ConnectionErrorf(temp, err, "transport: %v", err)
}
- }()
+ }
ua := primaryUA
if opts.UserAgent != "" {
ua = opts.UserAgent + " " + ua
@@ -147,6 +183,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
writableChan: make(chan int, 1),
shutdownChan: make(chan struct{}),
errorChan: make(chan struct{}),
+ goAway: make(chan struct{}),
framer: newFramer(conn),
hBuf: &buf,
hEnc: hpack.NewEncoder(&buf),
@@ -168,11 +205,11 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
n, err := t.conn.Write(clientPreface)
if err != nil {
t.Close()
- return nil, ConnectionErrorf("transport: %v", err)
+ return nil, ConnectionErrorf(true, err, "transport: %v", err)
}
if n != len(clientPreface) {
t.Close()
- return nil, ConnectionErrorf("transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface))
+ return nil, ConnectionErrorf(true, err, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface))
}
if initialWindowSize != defaultWindowSize {
err = t.framer.writeSettings(true, http2.Setting{
@@ -184,13 +221,13 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
}
if err != nil {
t.Close()
- return nil, ConnectionErrorf("transport: %v", err)
+ return nil, ConnectionErrorf(true, err, "transport: %v", err)
}
// Adjust the connection flow control window if needed.
if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil {
t.Close()
- return nil, ConnectionErrorf("transport: %v", err)
+ return nil, ConnectionErrorf(true, err, "transport: %v", err)
}
}
go t.controller()
@@ -202,6 +239,8 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
// TODO(zhaoq): Handle uint32 overflow of Stream.id.
s := &Stream{
id: t.nextID,
+ done: make(chan struct{}),
+ goAway: make(chan struct{}),
method: callHdr.Method,
sendCompress: callHdr.SendCompress,
buf: newRecvBuffer(),
@@ -216,8 +255,9 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
// Make a stream be able to cancel the pending operations by itself.
s.ctx, s.cancel = context.WithCancel(ctx)
s.dec = &recvBufferReader{
- ctx: s.ctx,
- recv: s.buf,
+ ctx: s.ctx,
+ goAway: s.goAway,
+ recv: s.buf,
}
return s
}
@@ -271,6 +311,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
t.mu.Unlock()
return nil, ErrConnClosing
}
+ if t.state == draining {
+ t.mu.Unlock()
+ return nil, ErrStreamDrain
+ }
if t.state != reachable {
t.mu.Unlock()
return nil, ErrConnClosing
@@ -278,7 +322,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
checkStreamsQuota := t.streamsQuota != nil
t.mu.Unlock()
if checkStreamsQuota {
- sq, err := wait(ctx, t.shutdownChan, t.streamsQuota.acquire())
+ sq, err := wait(ctx, nil, nil, t.shutdownChan, t.streamsQuota.acquire())
if err != nil {
return nil, err
}
@@ -287,7 +331,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
t.streamsQuota.add(sq - 1)
}
}
- if _, err := wait(ctx, t.shutdownChan, t.writableChan); err != nil {
+ if _, err := wait(ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
// Return the quota back now because there is no stream returned to the caller.
if _, ok := err.(StreamError); ok && checkStreamsQuota {
t.streamsQuota.add(1)
@@ -295,6 +339,15 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
return nil, err
}
t.mu.Lock()
+ if t.state == draining {
+ t.mu.Unlock()
+ if checkStreamsQuota {
+ t.streamsQuota.add(1)
+ }
+ // Need to make t writable again so that the rpc in flight can still proceed.
+ t.writableChan <- 0
+ return nil, ErrStreamDrain
+ }
if t.state != reachable {
t.mu.Unlock()
return nil, ErrConnClosing
@@ -329,7 +382,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
}
if timeout > 0 {
- t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: timeoutEncode(timeout)})
+ t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)})
}
for k, v := range authData {
// Capital header names are illegal in HTTP/2.
@@ -384,7 +437,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
}
if err != nil {
t.notifyError(err)
- return nil, ConnectionErrorf("transport: %v", err)
+ return nil, ConnectionErrorf(true, err, "transport: %v", err)
}
}
t.writableChan <- 0
@@ -403,22 +456,17 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
if t.streamsQuota != nil {
updateStreams = true
}
- if t.state == draining && len(t.activeStreams) == 1 {
+ delete(t.activeStreams, s.id)
+ if t.state == draining && len(t.activeStreams) == 0 {
// The transport is draining and s is the last live stream on t.
t.mu.Unlock()
t.Close()
return
}
- delete(t.activeStreams, s.id)
t.mu.Unlock()
if updateStreams {
t.streamsQuota.add(1)
}
- // In case stream sending and receiving are invoked in separate
- // goroutines (e.g., bi-directional streaming), the caller needs
- // to call cancel on the stream to interrupt the blocking on
- // other goroutines.
- s.cancel()
s.mu.Lock()
if q := s.fc.resetPendingData(); q > 0 {
if n := t.fc.onRead(q); n > 0 {
@@ -445,13 +493,13 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
// accessed any more.
func (t *http2Client) Close() (err error) {
t.mu.Lock()
- if t.state == reachable {
- close(t.errorChan)
- }
if t.state == closing {
t.mu.Unlock()
return
}
+ if t.state == reachable || t.state == draining {
+ close(t.errorChan)
+ }
t.state = closing
t.mu.Unlock()
close(t.shutdownChan)
@@ -475,10 +523,35 @@ func (t *http2Client) Close() (err error) {
func (t *http2Client) GracefulClose() error {
t.mu.Lock()
- if t.state == closing {
+ switch t.state {
+ case unreachable:
+ // The server may close the connection concurrently. t is not available for
+ // any streams. Close it now.
+ t.mu.Unlock()
+ t.Close()
+ return nil
+ case closing:
t.mu.Unlock()
return nil
}
+ // Notify the streams which were initiated after the server sent GOAWAY.
+ select {
+ case <-t.goAway:
+ n := t.prevGoAwayID
+ if n == 0 && t.nextID > 1 {
+ n = t.nextID - 2
+ }
+ m := t.goAwayID + 2
+ if m == 2 {
+ m = 1
+ }
+ for i := m; i <= n; i += 2 {
+ if s, ok := t.activeStreams[i]; ok {
+ close(s.goAway)
+ }
+ }
+ default:
+ }
if t.state == draining {
t.mu.Unlock()
return nil
@@ -504,15 +577,15 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
size := http2MaxFrameLen
s.sendQuotaPool.add(0)
// Wait until the stream has some quota to send the data.
- sq, err := wait(s.ctx, t.shutdownChan, s.sendQuotaPool.acquire())
+ sq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, s.sendQuotaPool.acquire())
if err != nil {
return err
}
t.sendQuotaPool.add(0)
// Wait until the transport has some quota to send the data.
- tq, err := wait(s.ctx, t.shutdownChan, t.sendQuotaPool.acquire())
+ tq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.sendQuotaPool.acquire())
if err != nil {
- if _, ok := err.(StreamError); ok {
+ if _, ok := err.(StreamError); ok || err == io.EOF {
t.sendQuotaPool.cancel()
}
return err
@@ -544,8 +617,8 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
// Indicate there is a writer who is about to write a data frame.
t.framer.adjustNumWriters(1)
// Got some quota. Try to acquire writing privilege on the transport.
- if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil {
- if _, ok := err.(StreamError); ok {
+ if _, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.writableChan); err != nil {
+ if _, ok := err.(StreamError); ok || err == io.EOF {
// Return the connection quota back.
t.sendQuotaPool.add(len(p))
}
@@ -578,7 +651,7 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
// invoked.
if err := t.framer.writeData(forceFlush, s.id, endStream, p); err != nil {
t.notifyError(err)
- return ConnectionErrorf("transport: %v", err)
+ return ConnectionErrorf(true, err, "transport: %v", err)
}
if t.framer.adjustNumWriters(-1) == 0 {
t.framer.flushWrite()
@@ -593,11 +666,7 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
}
s.mu.Lock()
if s.state != streamDone {
- if s.state == streamReadDone {
- s.state = streamDone
- } else {
- s.state = streamWriteDone
- }
+ s.state = streamWriteDone
}
s.mu.Unlock()
return nil
@@ -630,7 +699,7 @@ func (t *http2Client) updateWindow(s *Stream, n uint32) {
func (t *http2Client) handleData(f *http2.DataFrame) {
size := len(f.Data())
if err := t.fc.onData(uint32(size)); err != nil {
- t.notifyError(ConnectionErrorf("%v", err))
+ t.notifyError(ConnectionErrorf(true, err, "%v", err))
return
}
// Select the right stream to dispatch.
@@ -655,6 +724,7 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
s.state = streamDone
s.statusCode = codes.Internal
s.statusDesc = err.Error()
+ close(s.done)
s.mu.Unlock()
s.write(recvMsg{err: io.EOF})
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl})
@@ -672,13 +742,14 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
// the read direction is closed, and set the status appropriately.
if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) {
s.mu.Lock()
- if s.state == streamWriteDone {
- s.state = streamDone
- } else {
- s.state = streamReadDone
+ if s.state == streamDone {
+ s.mu.Unlock()
+ return
}
+ s.state = streamDone
s.statusCode = codes.Internal
s.statusDesc = "server closed the stream without sending trailers"
+ close(s.done)
s.mu.Unlock()
s.write(recvMsg{err: io.EOF})
}
@@ -704,6 +775,8 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) {
grpclog.Println("transport: http2Client.handleRSTStream found no mapped gRPC status for the received http2 error ", f.ErrCode)
s.statusCode = codes.Unknown
}
+ s.statusDesc = fmt.Sprintf("stream terminated by RST_STREAM with error code: %d", f.ErrCode)
+ close(s.done)
s.mu.Unlock()
s.write(recvMsg{err: io.EOF})
}
@@ -728,7 +801,32 @@ func (t *http2Client) handlePing(f *http2.PingFrame) {
}
func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
- // TODO(zhaoq): GoAwayFrame handler to be implemented
+ t.mu.Lock()
+ if t.state == reachable || t.state == draining {
+ if f.LastStreamID > 0 && f.LastStreamID%2 != 1 {
+ t.mu.Unlock()
+ t.notifyError(ConnectionErrorf(true, nil, "received illegal http2 GOAWAY frame: stream ID %d is even", f.LastStreamID))
+ return
+ }
+ select {
+ case <-t.goAway:
+ id := t.goAwayID
+ // t.goAway has been closed (i.e.,multiple GoAways).
+ if id < f.LastStreamID {
+ t.mu.Unlock()
+ t.notifyError(ConnectionErrorf(true, nil, "received illegal http2 GOAWAY frame: previously recv GOAWAY frame with LastStramID %d, currently recv %d", id, f.LastStreamID))
+ return
+ }
+ t.prevGoAwayID = id
+ t.goAwayID = f.LastStreamID
+ t.mu.Unlock()
+ return
+ default:
+ }
+ t.goAwayID = f.LastStreamID
+ close(t.goAway)
+ }
+ t.mu.Unlock()
}
func (t *http2Client) handleWindowUpdate(f *http2.WindowUpdateFrame) {
@@ -780,11 +878,11 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
if len(state.mdata) > 0 {
s.trailer = state.mdata
}
- s.state = streamDone
s.statusCode = state.statusCode
s.statusDesc = state.statusDesc
+ close(s.done)
+ s.state = streamDone
s.mu.Unlock()
-
s.write(recvMsg{err: io.EOF})
}
@@ -937,13 +1035,22 @@ func (t *http2Client) Error() <-chan struct{} {
return t.errorChan
}
+func (t *http2Client) GoAway() <-chan struct{} {
+ return t.goAway
+}
+
func (t *http2Client) notifyError(err error) {
t.mu.Lock()
- defer t.mu.Unlock()
// make sure t.errorChan is closed only once.
+ if t.state == draining {
+ t.mu.Unlock()
+ t.Close()
+ return
+ }
if t.state == reachable {
t.state = unreachable
close(t.errorChan)
grpclog.Printf("transport: http2Client.notifyError got notified that the client transport was broken %v.", err)
}
+ t.mu.Unlock()
}