aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/soheilhy/cmux/cmux.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/soheilhy/cmux/cmux.go')
-rw-r--r--vendor/github.com/soheilhy/cmux/cmux.go269
1 files changed, 269 insertions, 0 deletions
diff --git a/vendor/github.com/soheilhy/cmux/cmux.go b/vendor/github.com/soheilhy/cmux/cmux.go
new file mode 100644
index 0000000..9de6b0a
--- /dev/null
+++ b/vendor/github.com/soheilhy/cmux/cmux.go
@@ -0,0 +1,269 @@
+// Copyright 2016 The CMux Authors. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+// implied. See the License for the specific language governing
+// permissions and limitations under the License.
+
+package cmux
+
+import (
+ "fmt"
+ "io"
+ "net"
+ "sync"
+ "time"
+)
+
+// Matcher matches a connection based on its content.
+type Matcher func(io.Reader) bool
+
+// MatchWriter is a match that can also write response (say to do handshake).
+type MatchWriter func(io.Writer, io.Reader) bool
+
+// ErrorHandler handles an error and returns whether
+// the mux should continue serving the listener.
+type ErrorHandler func(error) bool
+
+var _ net.Error = ErrNotMatched{}
+
+// ErrNotMatched is returned whenever a connection is not matched by any of
+// the matchers registered in the multiplexer.
+type ErrNotMatched struct {
+ c net.Conn
+}
+
+func (e ErrNotMatched) Error() string {
+ return fmt.Sprintf("mux: connection %v not matched by an matcher",
+ e.c.RemoteAddr())
+}
+
+// Temporary implements the net.Error interface.
+func (e ErrNotMatched) Temporary() bool { return true }
+
+// Timeout implements the net.Error interface.
+func (e ErrNotMatched) Timeout() bool { return false }
+
+type errListenerClosed string
+
+func (e errListenerClosed) Error() string { return string(e) }
+func (e errListenerClosed) Temporary() bool { return false }
+func (e errListenerClosed) Timeout() bool { return false }
+
+// ErrListenerClosed is returned from muxListener.Accept when the underlying
+// listener is closed.
+var ErrListenerClosed = errListenerClosed("mux: listener closed")
+
+// for readability of readTimeout
+var noTimeout time.Duration
+
+// New instantiates a new connection multiplexer.
+func New(l net.Listener) CMux {
+ return &cMux{
+ root: l,
+ bufLen: 1024,
+ errh: func(_ error) bool { return true },
+ donec: make(chan struct{}),
+ readTimeout: noTimeout,
+ }
+}
+
+// CMux is a multiplexer for network connections.
+type CMux interface {
+ // Match returns a net.Listener that sees (i.e., accepts) only
+ // the connections matched by at least one of the matcher.
+ //
+ // The order used to call Match determines the priority of matchers.
+ Match(...Matcher) net.Listener
+ // MatchWithWriters returns a net.Listener that accepts only the
+ // connections that matched by at least of the matcher writers.
+ //
+ // Prefer Matchers over MatchWriters, since the latter can write on the
+ // connection before the actual handler.
+ //
+ // The order used to call Match determines the priority of matchers.
+ MatchWithWriters(...MatchWriter) net.Listener
+ // Serve starts multiplexing the listener. Serve blocks and perhaps
+ // should be invoked concurrently within a go routine.
+ Serve() error
+ // HandleError registers an error handler that handles listener errors.
+ HandleError(ErrorHandler)
+ // sets a timeout for the read of matchers
+ SetReadTimeout(time.Duration)
+}
+
+type matchersListener struct {
+ ss []MatchWriter
+ l muxListener
+}
+
+type cMux struct {
+ root net.Listener
+ bufLen int
+ errh ErrorHandler
+ donec chan struct{}
+ sls []matchersListener
+ readTimeout time.Duration
+}
+
+func matchersToMatchWriters(matchers []Matcher) []MatchWriter {
+ mws := make([]MatchWriter, 0, len(matchers))
+ for _, m := range matchers {
+ mws = append(mws, func(w io.Writer, r io.Reader) bool {
+ return m(r)
+ })
+ }
+ return mws
+}
+
+func (m *cMux) Match(matchers ...Matcher) net.Listener {
+ mws := matchersToMatchWriters(matchers)
+ return m.MatchWithWriters(mws...)
+}
+
+func (m *cMux) MatchWithWriters(matchers ...MatchWriter) net.Listener {
+ ml := muxListener{
+ Listener: m.root,
+ connc: make(chan net.Conn, m.bufLen),
+ }
+ m.sls = append(m.sls, matchersListener{ss: matchers, l: ml})
+ return ml
+}
+
+func (m *cMux) SetReadTimeout(t time.Duration) {
+ m.readTimeout = t
+}
+
+func (m *cMux) Serve() error {
+ var wg sync.WaitGroup
+
+ defer func() {
+ close(m.donec)
+ wg.Wait()
+
+ for _, sl := range m.sls {
+ close(sl.l.connc)
+ // Drain the connections enqueued for the listener.
+ for c := range sl.l.connc {
+ _ = c.Close()
+ }
+ }
+ }()
+
+ for {
+ c, err := m.root.Accept()
+ if err != nil {
+ if !m.handleErr(err) {
+ return err
+ }
+ continue
+ }
+
+ wg.Add(1)
+ go m.serve(c, m.donec, &wg)
+ }
+}
+
+func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) {
+ defer wg.Done()
+
+ muc := newMuxConn(c)
+ if m.readTimeout > noTimeout {
+ _ = c.SetReadDeadline(time.Now().Add(m.readTimeout))
+ }
+ for _, sl := range m.sls {
+ for _, s := range sl.ss {
+ matched := s(muc.Conn, muc.startSniffing())
+ if matched {
+ muc.doneSniffing()
+ if m.readTimeout > noTimeout {
+ _ = c.SetReadDeadline(time.Time{})
+ }
+ select {
+ case sl.l.connc <- muc:
+ case <-donec:
+ _ = c.Close()
+ }
+ return
+ }
+ }
+ }
+
+ _ = c.Close()
+ err := ErrNotMatched{c: c}
+ if !m.handleErr(err) {
+ _ = m.root.Close()
+ }
+}
+
+func (m *cMux) HandleError(h ErrorHandler) {
+ m.errh = h
+}
+
+func (m *cMux) handleErr(err error) bool {
+ if !m.errh(err) {
+ return false
+ }
+
+ if ne, ok := err.(net.Error); ok {
+ return ne.Temporary()
+ }
+
+ return false
+}
+
+type muxListener struct {
+ net.Listener
+ connc chan net.Conn
+}
+
+func (l muxListener) Accept() (net.Conn, error) {
+ c, ok := <-l.connc
+ if !ok {
+ return nil, ErrListenerClosed
+ }
+ return c, nil
+}
+
+// MuxConn wraps a net.Conn and provides transparent sniffing of connection data.
+type MuxConn struct {
+ net.Conn
+ buf bufferedReader
+}
+
+func newMuxConn(c net.Conn) *MuxConn {
+ return &MuxConn{
+ Conn: c,
+ buf: bufferedReader{source: c},
+ }
+}
+
+// From the io.Reader documentation:
+//
+// When Read encounters an error or end-of-file condition after
+// successfully reading n > 0 bytes, it returns the number of
+// bytes read. It may return the (non-nil) error from the same call
+// or return the error (and n == 0) from a subsequent call.
+// An instance of this general case is that a Reader returning
+// a non-zero number of bytes at the end of the input stream may
+// return either err == EOF or err == nil. The next Read should
+// return 0, EOF.
+func (m *MuxConn) Read(p []byte) (int, error) {
+ return m.buf.Read(p)
+}
+
+func (m *MuxConn) startSniffing() io.Reader {
+ m.buf.reset(true)
+ return &m.buf
+}
+
+func (m *MuxConn) doneSniffing() {
+ m.buf.reset(false)
+}