// Copyright 2016 The CMux Authors. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or // implied. See the License for the specific language governing // permissions and limitations under the License. package cmux import ( "fmt" "io" "net" "sync" "time" ) // Matcher matches a connection based on its content. type Matcher func(io.Reader) bool // MatchWriter is a match that can also write response (say to do handshake). type MatchWriter func(io.Writer, io.Reader) bool // ErrorHandler handles an error and returns whether // the mux should continue serving the listener. type ErrorHandler func(error) bool var _ net.Error = ErrNotMatched{} // ErrNotMatched is returned whenever a connection is not matched by any of // the matchers registered in the multiplexer. type ErrNotMatched struct { c net.Conn } func (e ErrNotMatched) Error() string { return fmt.Sprintf("mux: connection %v not matched by an matcher", e.c.RemoteAddr()) } // Temporary implements the net.Error interface. func (e ErrNotMatched) Temporary() bool { return true } // Timeout implements the net.Error interface. func (e ErrNotMatched) Timeout() bool { return false } type errListenerClosed string func (e errListenerClosed) Error() string { return string(e) } func (e errListenerClosed) Temporary() bool { return false } func (e errListenerClosed) Timeout() bool { return false } // ErrListenerClosed is returned from muxListener.Accept when the underlying // listener is closed. var ErrListenerClosed = errListenerClosed("mux: listener closed") // for readability of readTimeout var noTimeout time.Duration // New instantiates a new connection multiplexer. func New(l net.Listener) CMux { return &cMux{ root: l, bufLen: 1024, errh: func(_ error) bool { return true }, donec: make(chan struct{}), readTimeout: noTimeout, } } // CMux is a multiplexer for network connections. type CMux interface { // Match returns a net.Listener that sees (i.e., accepts) only // the connections matched by at least one of the matcher. // // The order used to call Match determines the priority of matchers. Match(...Matcher) net.Listener // MatchWithWriters returns a net.Listener that accepts only the // connections that matched by at least of the matcher writers. // // Prefer Matchers over MatchWriters, since the latter can write on the // connection before the actual handler. // // The order used to call Match determines the priority of matchers. MatchWithWriters(...MatchWriter) net.Listener // Serve starts multiplexing the listener. Serve blocks and perhaps // should be invoked concurrently within a go routine. Serve() error // HandleError registers an error handler that handles listener errors. HandleError(ErrorHandler) // sets a timeout for the read of matchers SetReadTimeout(time.Duration) } type matchersListener struct { ss []MatchWriter l muxListener } type cMux struct { root net.Listener bufLen int errh ErrorHandler donec chan struct{} sls []matchersListener readTimeout time.Duration } func matchersToMatchWriters(matchers []Matcher) []MatchWriter { mws := make([]MatchWriter, 0, len(matchers)) for _, m := range matchers { mws = append(mws, func(w io.Writer, r io.Reader) bool { return m(r) }) } return mws } func (m *cMux) Match(matchers ...Matcher) net.Listener { mws := matchersToMatchWriters(matchers) return m.MatchWithWriters(mws...) } func (m *cMux) MatchWithWriters(matchers ...MatchWriter) net.Listener { ml := muxListener{ Listener: m.root, connc: make(chan net.Conn, m.bufLen), } m.sls = append(m.sls, matchersListener{ss: matchers, l: ml}) return ml } func (m *cMux) SetReadTimeout(t time.Duration) { m.readTimeout = t } func (m *cMux) Serve() error { var wg sync.WaitGroup defer func() { close(m.donec) wg.Wait() for _, sl := range m.sls { close(sl.l.connc) // Drain the connections enqueued for the listener. for c := range sl.l.connc { _ = c.Close() } } }() for { c, err := m.root.Accept() if err != nil { if !m.handleErr(err) { return err } continue } wg.Add(1) go m.serve(c, m.donec, &wg) } } func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) { defer wg.Done() muc := newMuxConn(c) if m.readTimeout > noTimeout { _ = c.SetReadDeadline(time.Now().Add(m.readTimeout)) } for _, sl := range m.sls { for _, s := range sl.ss { matched := s(muc.Conn, muc.startSniffing()) if matched { muc.doneSniffing() if m.readTimeout > noTimeout { _ = c.SetReadDeadline(time.Time{}) } select { case sl.l.connc <- muc: case <-donec: _ = c.Close() } return } } } _ = c.Close() err := ErrNotMatched{c: c} if !m.handleErr(err) { _ = m.root.Close() } } func (m *cMux) HandleError(h ErrorHandler) { m.errh = h } func (m *cMux) handleErr(err error) bool { if !m.errh(err) { return false } if ne, ok := err.(net.Error); ok { return ne.Temporary() } return false } type muxListener struct { net.Listener connc chan net.Conn } func (l muxListener) Accept() (net.Conn, error) { c, ok := <-l.connc if !ok { return nil, ErrListenerClosed } return c, nil } // MuxConn wraps a net.Conn and provides transparent sniffing of connection data. type MuxConn struct { net.Conn buf bufferedReader } func newMuxConn(c net.Conn) *MuxConn { return &MuxConn{ Conn: c, buf: bufferedReader{source: c}, } } // From the io.Reader documentation: // // When Read encounters an error or end-of-file condition after // successfully reading n > 0 bytes, it returns the number of // bytes read. It may return the (non-nil) error from the same call // or return the error (and n == 0) from a subsequent call. // An instance of this general case is that a Reader returning // a non-zero number of bytes at the end of the input stream may // return either err == EOF or err == nil. The next Read should // return 0, EOF. func (m *MuxConn) Read(p []byte) (int, error) { return m.buf.Read(p) } func (m *MuxConn) startSniffing() io.Reader { m.buf.reset(true) return &m.buf } func (m *MuxConn) doneSniffing() { m.buf.reset(false) }