aboutsummaryrefslogtreecommitdiff
path: root/vendor/google.golang.org/grpc/stream.go
diff options
context:
space:
mode:
authorNiall Sheridan <nsheridan@gmail.com>2016-12-28 21:18:36 +0000
committerNiall Sheridan <nsheridan@gmail.com>2016-12-28 21:18:36 +0000
commit73ef85bc5db590c22689e11be20737a3dd88168f (patch)
treefe393a6f0776bca1889b2113ab341a2922e25d10 /vendor/google.golang.org/grpc/stream.go
parent9e573e571fe878ed32947cae5a6d43cb5d72d3bb (diff)
Update dependencies
Diffstat (limited to 'vendor/google.golang.org/grpc/stream.go')
-rw-r--r--vendor/google.golang.org/grpc/stream.go146
1 files changed, 120 insertions, 26 deletions
diff --git a/vendor/google.golang.org/grpc/stream.go b/vendor/google.golang.org/grpc/stream.go
index 4681054..d3a4deb 100644
--- a/vendor/google.golang.org/grpc/stream.go
+++ b/vendor/google.golang.org/grpc/stream.go
@@ -45,6 +45,7 @@ import (
"golang.org/x/net/trace"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
+ "google.golang.org/grpc/stats"
"google.golang.org/grpc/transport"
)
@@ -97,7 +98,7 @@ type ClientStream interface {
// NewClientStream creates a new Stream for the client side. This is called
// by generated code.
-func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) {
+func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
if cc.dopts.streamInt != nil {
return cc.dopts.streamInt(ctx, desc, cc, method, newClientStream, opts...)
}
@@ -106,11 +107,18 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
var (
- t transport.ClientTransport
- s *transport.Stream
- put func()
+ t transport.ClientTransport
+ s *transport.Stream
+ put func()
+ cancel context.CancelFunc
)
c := defaultCallInfo
+ if mc, ok := cc.getMethodConfig(method); ok {
+ c.failFast = !mc.WaitForReady
+ if mc.Timeout > 0 {
+ ctx, cancel = context.WithTimeout(ctx, mc.Timeout)
+ }
+ }
for _, o := range opts {
if err := o.before(&c); err != nil {
return nil, toRPCErr(err)
@@ -143,6 +151,25 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
}
}()
}
+ if stats.On() {
+ ctx = stats.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method})
+ begin := &stats.Begin{
+ Client: true,
+ BeginTime: time.Now(),
+ FailFast: c.failFast,
+ }
+ stats.HandleRPC(ctx, begin)
+ }
+ defer func() {
+ if err != nil && stats.On() {
+ // Only handle end stats if err != nil.
+ end := &stats.End{
+ Client: true,
+ Error: err,
+ }
+ stats.HandleRPC(ctx, end)
+ }
+ }()
gopts := BalancerGetOptions{
BlockingWait: !c.failFast,
}
@@ -180,12 +207,13 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
break
}
cs := &clientStream{
- opts: opts,
- c: c,
- desc: desc,
- codec: cc.dopts.codec,
- cp: cc.dopts.cp,
- dc: cc.dopts.dc,
+ opts: opts,
+ c: c,
+ desc: desc,
+ codec: cc.dopts.codec,
+ cp: cc.dopts.cp,
+ dc: cc.dopts.dc,
+ cancel: cancel,
put: put,
t: t,
@@ -194,6 +222,8 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
tracing: EnableTracing,
trInfo: trInfo,
+
+ statsCtx: ctx,
}
if cc.dopts.cp != nil {
cs.cbuf = new(bytes.Buffer)
@@ -227,16 +257,17 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
// clientStream implements a client side Stream.
type clientStream struct {
- opts []CallOption
- c callInfo
- t transport.ClientTransport
- s *transport.Stream
- p *parser
- desc *StreamDesc
- codec Codec
- cp Compressor
- cbuf *bytes.Buffer
- dc Decompressor
+ opts []CallOption
+ c callInfo
+ t transport.ClientTransport
+ s *transport.Stream
+ p *parser
+ desc *StreamDesc
+ codec Codec
+ cp Compressor
+ cbuf *bytes.Buffer
+ dc Decompressor
+ cancel context.CancelFunc
tracing bool // set to EnableTracing when the clientStream is created.
@@ -246,6 +277,11 @@ type clientStream struct {
// trInfo.tr is set when the clientStream is created (if EnableTracing is true),
// and is set to nil when the clientStream's finish method is called.
trInfo traceInfo
+
+ // statsCtx keeps the user context for stats handling.
+ // All stats collection should use the statsCtx (instead of the stream context)
+ // so that all the generated stats for a particular RPC can be associated in the processing phase.
+ statsCtx context.Context
}
func (cs *clientStream) Context() context.Context {
@@ -274,6 +310,8 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
}
cs.mu.Unlock()
}
+ // TODO Investigate how to signal the stats handling party.
+ // generate error stats if err != nil && err != io.EOF?
defer func() {
if err != nil {
cs.finish(err)
@@ -296,7 +334,13 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
}
err = toRPCErr(err)
}()
- out, err := encode(cs.codec, m, cs.cp, cs.cbuf)
+ var outPayload *stats.OutPayload
+ if stats.On() {
+ outPayload = &stats.OutPayload{
+ Client: true,
+ }
+ }
+ out, err := encode(cs.codec, m, cs.cp, cs.cbuf, outPayload)
defer func() {
if cs.cbuf != nil {
cs.cbuf.Reset()
@@ -305,11 +349,37 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
if err != nil {
return Errorf(codes.Internal, "grpc: %v", err)
}
- return cs.t.Write(cs.s, out, &transport.Options{Last: false})
+ err = cs.t.Write(cs.s, out, &transport.Options{Last: false})
+ if err == nil && outPayload != nil {
+ outPayload.SentTime = time.Now()
+ stats.HandleRPC(cs.statsCtx, outPayload)
+ }
+ return err
}
func (cs *clientStream) RecvMsg(m interface{}) (err error) {
- err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32)
+ defer func() {
+ if err != nil && stats.On() {
+ // Only generate End if err != nil.
+ // If err == nil, it's not the last RecvMsg.
+ // The last RecvMsg gets either an RPC error or io.EOF.
+ end := &stats.End{
+ Client: true,
+ EndTime: time.Now(),
+ }
+ if err != io.EOF {
+ end.Error = toRPCErr(err)
+ }
+ stats.HandleRPC(cs.statsCtx, end)
+ }
+ }()
+ var inPayload *stats.InPayload
+ if stats.On() {
+ inPayload = &stats.InPayload{
+ Client: true,
+ }
+ }
+ err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32, inPayload)
defer func() {
// err != nil indicates the termination of the stream.
if err != nil {
@@ -324,11 +394,15 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
}
cs.mu.Unlock()
}
+ if inPayload != nil {
+ stats.HandleRPC(cs.statsCtx, inPayload)
+ }
if !cs.desc.ClientStreams || cs.desc.ServerStreams {
return
}
// Special handling for client streaming rpc.
- err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32)
+ // This recv expects EOF or errors, so we don't collect inPayload.
+ err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32, nil)
cs.closeTransportStream(err)
if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
@@ -384,6 +458,11 @@ func (cs *clientStream) closeTransportStream(err error) {
}
func (cs *clientStream) finish(err error) {
+ defer func() {
+ if cs.cancel != nil {
+ cs.cancel()
+ }
+ }()
cs.mu.Lock()
defer cs.mu.Unlock()
for _, o := range cs.opts {
@@ -482,7 +561,11 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
ss.mu.Unlock()
}
}()
- out, err := encode(ss.codec, m, ss.cp, ss.cbuf)
+ var outPayload *stats.OutPayload
+ if stats.On() {
+ outPayload = &stats.OutPayload{}
+ }
+ out, err := encode(ss.codec, m, ss.cp, ss.cbuf, outPayload)
defer func() {
if ss.cbuf != nil {
ss.cbuf.Reset()
@@ -495,6 +578,10 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
if err := ss.t.Write(ss.s, out, &transport.Options{Last: false}); err != nil {
return toRPCErr(err)
}
+ if outPayload != nil {
+ outPayload.SentTime = time.Now()
+ stats.HandleRPC(ss.s.Context(), outPayload)
+ }
return nil
}
@@ -513,7 +600,11 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
ss.mu.Unlock()
}
}()
- if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize); err != nil {
+ var inPayload *stats.InPayload
+ if stats.On() {
+ inPayload = &stats.InPayload{}
+ }
+ if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize, inPayload); err != nil {
if err == io.EOF {
return err
}
@@ -522,5 +613,8 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
}
return toRPCErr(err)
}
+ if inPayload != nil {
+ stats.HandleRPC(ss.s.Context(), inPayload)
+ }
return nil
}