From 921818bca208f0c70e85ec670074cb3905cbbc82 Mon Sep 17 00:00:00 2001 From: Niall Sheridan Date: Sat, 27 Aug 2016 01:32:30 +0100 Subject: Update dependencies --- .../grpc/transport/http2_client.go | 237 +++++++++++++++------ 1 file changed, 172 insertions(+), 65 deletions(-) (limited to 'vendor/google.golang.org/grpc/transport/http2_client.go') diff --git a/vendor/google.golang.org/grpc/transport/http2_client.go b/vendor/google.golang.org/grpc/transport/http2_client.go index f66435f..afbba45 100644 --- a/vendor/google.golang.org/grpc/transport/http2_client.go +++ b/vendor/google.golang.org/grpc/transport/http2_client.go @@ -35,6 +35,7 @@ package transport import ( "bytes" + "fmt" "io" "math" "net" @@ -71,6 +72,9 @@ type http2Client struct { shutdownChan chan struct{} // errorChan is closed to notify the I/O error to the caller. errorChan chan struct{} + // goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor) + // that the server sent GoAway on this transport. + goAway chan struct{} framer *framer hBuf *bytes.Buffer // the buffer for HPACK encoding @@ -97,41 +101,73 @@ type http2Client struct { maxStreams int // the per-stream outbound flow control window size set by the peer. streamSendQuota uint32 + // goAwayID records the Last-Stream-ID in the GoAway frame from the server. + goAwayID uint32 + // prevGoAway ID records the Last-Stream-ID in the previous GOAway frame. + prevGoAwayID uint32 +} + +func dial(fn func(context.Context, string) (net.Conn, error), ctx context.Context, addr string) (net.Conn, error) { + if fn != nil { + return fn(ctx, addr) + } + return dialContext(ctx, "tcp", addr) +} + +func isTemporary(err error) bool { + switch err { + case io.EOF: + // Connection closures may be resolved upon retry, and are thus + // treated as temporary. + return true + case context.DeadlineExceeded: + // In Go 1.7, context.DeadlineExceeded implements Timeout(), and this + // special case is not needed. Until then, we need to keep this + // clause. + return true + } + + switch err := err.(type) { + case interface { + Temporary() bool + }: + return err.Temporary() + case interface { + Timeout() bool + }: + // Timeouts may be resolved upon retry, and are thus treated as + // temporary. + return err.Timeout() + } + return false } // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 // and starts to receive messages on it. Non-nil error returns if construction // fails. -func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err error) { - if opts.Dialer == nil { - // Set the default Dialer. - opts.Dialer = func(addr string, timeout time.Duration) (net.Conn, error) { - return net.DialTimeout("tcp", addr, timeout) - } - } +func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ ClientTransport, err error) { scheme := "http" - startT := time.Now() - timeout := opts.Timeout - conn, connErr := opts.Dialer(addr, timeout) - if connErr != nil { - return nil, ConnectionErrorf("transport: %v", connErr) + conn, err := dial(opts.Dialer, ctx, addr) + if err != nil { + return nil, ConnectionErrorf(true, err, "transport: %v", err) } + // Any further errors will close the underlying connection + defer func(conn net.Conn) { + if err != nil { + conn.Close() + } + }(conn) var authInfo credentials.AuthInfo - if opts.TransportCredentials != nil { + if creds := opts.TransportCredentials; creds != nil { scheme = "https" - if timeout > 0 { - timeout -= time.Since(startT) - } - conn, authInfo, connErr = opts.TransportCredentials.ClientHandshake(addr, conn, timeout) - } - if connErr != nil { - return nil, ConnectionErrorf("transport: %v", connErr) - } - defer func() { + conn, authInfo, err = creds.ClientHandshake(ctx, addr, conn) if err != nil { - conn.Close() + // Credentials handshake errors are typically considered permanent + // to avoid retrying on e.g. bad certificates. + temp := isTemporary(err) + return nil, ConnectionErrorf(temp, err, "transport: %v", err) } - }() + } ua := primaryUA if opts.UserAgent != "" { ua = opts.UserAgent + " " + ua @@ -147,6 +183,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e writableChan: make(chan int, 1), shutdownChan: make(chan struct{}), errorChan: make(chan struct{}), + goAway: make(chan struct{}), framer: newFramer(conn), hBuf: &buf, hEnc: hpack.NewEncoder(&buf), @@ -168,11 +205,11 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e n, err := t.conn.Write(clientPreface) if err != nil { t.Close() - return nil, ConnectionErrorf("transport: %v", err) + return nil, ConnectionErrorf(true, err, "transport: %v", err) } if n != len(clientPreface) { t.Close() - return nil, ConnectionErrorf("transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface)) + return nil, ConnectionErrorf(true, err, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface)) } if initialWindowSize != defaultWindowSize { err = t.framer.writeSettings(true, http2.Setting{ @@ -184,13 +221,13 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e } if err != nil { t.Close() - return nil, ConnectionErrorf("transport: %v", err) + return nil, ConnectionErrorf(true, err, "transport: %v", err) } // Adjust the connection flow control window if needed. if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 { if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil { t.Close() - return nil, ConnectionErrorf("transport: %v", err) + return nil, ConnectionErrorf(true, err, "transport: %v", err) } } go t.controller() @@ -202,6 +239,8 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { // TODO(zhaoq): Handle uint32 overflow of Stream.id. s := &Stream{ id: t.nextID, + done: make(chan struct{}), + goAway: make(chan struct{}), method: callHdr.Method, sendCompress: callHdr.SendCompress, buf: newRecvBuffer(), @@ -216,8 +255,9 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { // Make a stream be able to cancel the pending operations by itself. s.ctx, s.cancel = context.WithCancel(ctx) s.dec = &recvBufferReader{ - ctx: s.ctx, - recv: s.buf, + ctx: s.ctx, + goAway: s.goAway, + recv: s.buf, } return s } @@ -271,6 +311,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea t.mu.Unlock() return nil, ErrConnClosing } + if t.state == draining { + t.mu.Unlock() + return nil, ErrStreamDrain + } if t.state != reachable { t.mu.Unlock() return nil, ErrConnClosing @@ -278,7 +322,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea checkStreamsQuota := t.streamsQuota != nil t.mu.Unlock() if checkStreamsQuota { - sq, err := wait(ctx, t.shutdownChan, t.streamsQuota.acquire()) + sq, err := wait(ctx, nil, nil, t.shutdownChan, t.streamsQuota.acquire()) if err != nil { return nil, err } @@ -287,7 +331,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea t.streamsQuota.add(sq - 1) } } - if _, err := wait(ctx, t.shutdownChan, t.writableChan); err != nil { + if _, err := wait(ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil { // Return the quota back now because there is no stream returned to the caller. if _, ok := err.(StreamError); ok && checkStreamsQuota { t.streamsQuota.add(1) @@ -295,6 +339,15 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea return nil, err } t.mu.Lock() + if t.state == draining { + t.mu.Unlock() + if checkStreamsQuota { + t.streamsQuota.add(1) + } + // Need to make t writable again so that the rpc in flight can still proceed. + t.writableChan <- 0 + return nil, ErrStreamDrain + } if t.state != reachable { t.mu.Unlock() return nil, ErrConnClosing @@ -329,7 +382,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress}) } if timeout > 0 { - t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: timeoutEncode(timeout)}) + t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)}) } for k, v := range authData { // Capital header names are illegal in HTTP/2. @@ -384,7 +437,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea } if err != nil { t.notifyError(err) - return nil, ConnectionErrorf("transport: %v", err) + return nil, ConnectionErrorf(true, err, "transport: %v", err) } } t.writableChan <- 0 @@ -403,22 +456,17 @@ func (t *http2Client) CloseStream(s *Stream, err error) { if t.streamsQuota != nil { updateStreams = true } - if t.state == draining && len(t.activeStreams) == 1 { + delete(t.activeStreams, s.id) + if t.state == draining && len(t.activeStreams) == 0 { // The transport is draining and s is the last live stream on t. t.mu.Unlock() t.Close() return } - delete(t.activeStreams, s.id) t.mu.Unlock() if updateStreams { t.streamsQuota.add(1) } - // In case stream sending and receiving are invoked in separate - // goroutines (e.g., bi-directional streaming), the caller needs - // to call cancel on the stream to interrupt the blocking on - // other goroutines. - s.cancel() s.mu.Lock() if q := s.fc.resetPendingData(); q > 0 { if n := t.fc.onRead(q); n > 0 { @@ -445,13 +493,13 @@ func (t *http2Client) CloseStream(s *Stream, err error) { // accessed any more. func (t *http2Client) Close() (err error) { t.mu.Lock() - if t.state == reachable { - close(t.errorChan) - } if t.state == closing { t.mu.Unlock() return } + if t.state == reachable || t.state == draining { + close(t.errorChan) + } t.state = closing t.mu.Unlock() close(t.shutdownChan) @@ -475,10 +523,35 @@ func (t *http2Client) Close() (err error) { func (t *http2Client) GracefulClose() error { t.mu.Lock() - if t.state == closing { + switch t.state { + case unreachable: + // The server may close the connection concurrently. t is not available for + // any streams. Close it now. + t.mu.Unlock() + t.Close() + return nil + case closing: t.mu.Unlock() return nil } + // Notify the streams which were initiated after the server sent GOAWAY. + select { + case <-t.goAway: + n := t.prevGoAwayID + if n == 0 && t.nextID > 1 { + n = t.nextID - 2 + } + m := t.goAwayID + 2 + if m == 2 { + m = 1 + } + for i := m; i <= n; i += 2 { + if s, ok := t.activeStreams[i]; ok { + close(s.goAway) + } + } + default: + } if t.state == draining { t.mu.Unlock() return nil @@ -504,15 +577,15 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { size := http2MaxFrameLen s.sendQuotaPool.add(0) // Wait until the stream has some quota to send the data. - sq, err := wait(s.ctx, t.shutdownChan, s.sendQuotaPool.acquire()) + sq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, s.sendQuotaPool.acquire()) if err != nil { return err } t.sendQuotaPool.add(0) // Wait until the transport has some quota to send the data. - tq, err := wait(s.ctx, t.shutdownChan, t.sendQuotaPool.acquire()) + tq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.sendQuotaPool.acquire()) if err != nil { - if _, ok := err.(StreamError); ok { + if _, ok := err.(StreamError); ok || err == io.EOF { t.sendQuotaPool.cancel() } return err @@ -544,8 +617,8 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { // Indicate there is a writer who is about to write a data frame. t.framer.adjustNumWriters(1) // Got some quota. Try to acquire writing privilege on the transport. - if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { - if _, ok := err.(StreamError); ok { + if _, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.writableChan); err != nil { + if _, ok := err.(StreamError); ok || err == io.EOF { // Return the connection quota back. t.sendQuotaPool.add(len(p)) } @@ -578,7 +651,7 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { // invoked. if err := t.framer.writeData(forceFlush, s.id, endStream, p); err != nil { t.notifyError(err) - return ConnectionErrorf("transport: %v", err) + return ConnectionErrorf(true, err, "transport: %v", err) } if t.framer.adjustNumWriters(-1) == 0 { t.framer.flushWrite() @@ -593,11 +666,7 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { } s.mu.Lock() if s.state != streamDone { - if s.state == streamReadDone { - s.state = streamDone - } else { - s.state = streamWriteDone - } + s.state = streamWriteDone } s.mu.Unlock() return nil @@ -630,7 +699,7 @@ func (t *http2Client) updateWindow(s *Stream, n uint32) { func (t *http2Client) handleData(f *http2.DataFrame) { size := len(f.Data()) if err := t.fc.onData(uint32(size)); err != nil { - t.notifyError(ConnectionErrorf("%v", err)) + t.notifyError(ConnectionErrorf(true, err, "%v", err)) return } // Select the right stream to dispatch. @@ -655,6 +724,7 @@ func (t *http2Client) handleData(f *http2.DataFrame) { s.state = streamDone s.statusCode = codes.Internal s.statusDesc = err.Error() + close(s.done) s.mu.Unlock() s.write(recvMsg{err: io.EOF}) t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl}) @@ -672,13 +742,14 @@ func (t *http2Client) handleData(f *http2.DataFrame) { // the read direction is closed, and set the status appropriately. if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) { s.mu.Lock() - if s.state == streamWriteDone { - s.state = streamDone - } else { - s.state = streamReadDone + if s.state == streamDone { + s.mu.Unlock() + return } + s.state = streamDone s.statusCode = codes.Internal s.statusDesc = "server closed the stream without sending trailers" + close(s.done) s.mu.Unlock() s.write(recvMsg{err: io.EOF}) } @@ -704,6 +775,8 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) { grpclog.Println("transport: http2Client.handleRSTStream found no mapped gRPC status for the received http2 error ", f.ErrCode) s.statusCode = codes.Unknown } + s.statusDesc = fmt.Sprintf("stream terminated by RST_STREAM with error code: %d", f.ErrCode) + close(s.done) s.mu.Unlock() s.write(recvMsg{err: io.EOF}) } @@ -728,7 +801,32 @@ func (t *http2Client) handlePing(f *http2.PingFrame) { } func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { - // TODO(zhaoq): GoAwayFrame handler to be implemented + t.mu.Lock() + if t.state == reachable || t.state == draining { + if f.LastStreamID > 0 && f.LastStreamID%2 != 1 { + t.mu.Unlock() + t.notifyError(ConnectionErrorf(true, nil, "received illegal http2 GOAWAY frame: stream ID %d is even", f.LastStreamID)) + return + } + select { + case <-t.goAway: + id := t.goAwayID + // t.goAway has been closed (i.e.,multiple GoAways). + if id < f.LastStreamID { + t.mu.Unlock() + t.notifyError(ConnectionErrorf(true, nil, "received illegal http2 GOAWAY frame: previously recv GOAWAY frame with LastStramID %d, currently recv %d", id, f.LastStreamID)) + return + } + t.prevGoAwayID = id + t.goAwayID = f.LastStreamID + t.mu.Unlock() + return + default: + } + t.goAwayID = f.LastStreamID + close(t.goAway) + } + t.mu.Unlock() } func (t *http2Client) handleWindowUpdate(f *http2.WindowUpdateFrame) { @@ -780,11 +878,11 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { if len(state.mdata) > 0 { s.trailer = state.mdata } - s.state = streamDone s.statusCode = state.statusCode s.statusDesc = state.statusDesc + close(s.done) + s.state = streamDone s.mu.Unlock() - s.write(recvMsg{err: io.EOF}) } @@ -937,13 +1035,22 @@ func (t *http2Client) Error() <-chan struct{} { return t.errorChan } +func (t *http2Client) GoAway() <-chan struct{} { + return t.goAway +} + func (t *http2Client) notifyError(err error) { t.mu.Lock() - defer t.mu.Unlock() // make sure t.errorChan is closed only once. + if t.state == draining { + t.mu.Unlock() + t.Close() + return + } if t.state == reachable { t.state = unreachable close(t.errorChan) grpclog.Printf("transport: http2Client.notifyError got notified that the client transport was broken %v.", err) } + t.mu.Unlock() } -- cgit v1.2.3