From ba4840c52becf73c2749c9ef0f2f09ed0b9d5c7f Mon Sep 17 00:00:00 2001 From: Niall Sheridan Date: Sun, 12 Feb 2017 22:24:33 +0000 Subject: Update dependencies --- vendor/golang.org/x/crypto/ssh/handshake.go | 423 +++++++++++++++++++--------- 1 file changed, 294 insertions(+), 129 deletions(-) (limited to 'vendor/golang.org/x/crypto/ssh/handshake.go') diff --git a/vendor/golang.org/x/crypto/ssh/handshake.go b/vendor/golang.org/x/crypto/ssh/handshake.go index 37d42e4..8de6506 100644 --- a/vendor/golang.org/x/crypto/ssh/handshake.go +++ b/vendor/golang.org/x/crypto/ssh/handshake.go @@ -19,6 +19,11 @@ import ( // messages are wrong when using ECDH. const debugHandshake = false +// chanSize sets the amount of buffering SSH connections. This is +// primarily for testing: setting chanSize=0 uncovers deadlocks more +// quickly. +const chanSize = 16 + // keyingTransport is a packet based transport that supports key // changes. It need not be thread-safe. It should pass through // msgNewKeys in both directions. @@ -53,34 +58,58 @@ type handshakeTransport struct { incoming chan []byte readError error + mu sync.Mutex + writeError error + sentInitPacket []byte + sentInitMsg *kexInitMsg + pendingPackets [][]byte // Used when a key exchange is in progress. + + // If the read loop wants to schedule a kex, it pings this + // channel, and the write loop will send out a kex + // message. + requestKex chan struct{} + + // If the other side requests or confirms a kex, its kexInit + // packet is sent here for the write loop to find it. + startKex chan *pendingKex + // data for host key checking hostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error dialAddress string remoteAddr net.Addr - readSinceKex uint64 + // Algorithms agreed in the last key exchange. + algorithms *algorithms + + readPacketsLeft uint32 + readBytesLeft int64 - // Protects the writing side of the connection - mu sync.Mutex - cond *sync.Cond - sentInitPacket []byte - sentInitMsg *kexInitMsg - writtenSinceKex uint64 - writeError error + writePacketsLeft uint32 + writeBytesLeft int64 // The session ID or nil if first kex did not complete yet. sessionID []byte } +type pendingKex struct { + otherInit []byte + done chan error +} + func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport { t := &handshakeTransport{ conn: conn, serverVersion: serverVersion, clientVersion: clientVersion, - incoming: make(chan []byte, 16), - config: config, + incoming: make(chan []byte, chanSize), + requestKex: make(chan struct{}, 1), + startKex: make(chan *pendingKex, 1), + + config: config, } - t.cond = sync.NewCond(&t.mu) + + // We always start with a mandatory key exchange. + t.requestKex <- struct{}{} return t } @@ -95,6 +124,7 @@ func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byt t.hostKeyAlgorithms = supportedHostKeyAlgos } go t.readLoop() + go t.kexLoop() return t } @@ -102,6 +132,7 @@ func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byt t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) t.hostKeys = config.hostKeys go t.readLoop() + go t.kexLoop() return t } @@ -109,6 +140,20 @@ func (t *handshakeTransport) getSessionID() []byte { return t.sessionID } +// waitSession waits for the session to be established. This should be +// the first thing to call after instantiating handshakeTransport. +func (t *handshakeTransport) waitSession() error { + p, err := t.readPacket() + if err != nil { + return err + } + if p[0] != msgNewKeys { + return fmt.Errorf("ssh: first packet should be msgNewKeys") + } + + return nil +} + func (t *handshakeTransport) id() string { if len(t.hostKeys) > 0 { return "server" @@ -116,6 +161,20 @@ func (t *handshakeTransport) id() string { return "client" } +func (t *handshakeTransport) printPacket(p []byte, write bool) { + action := "got" + if write { + action = "sent" + } + + if p[0] == msgChannelData || p[0] == msgChannelExtendedData { + log.Printf("%s %s data (packet %d bytes)", t.id(), action, len(p)) + } else { + msg, err := decode(p) + log.Printf("%s %s %T %v (%v)", t.id(), action, msg, msg, err) + } +} + func (t *handshakeTransport) readPacket() ([]byte, error) { p, ok := <-t.incoming if !ok { @@ -125,8 +184,10 @@ func (t *handshakeTransport) readPacket() ([]byte, error) { } func (t *handshakeTransport) readLoop() { + first := true for { - p, err := t.readOnePacket() + p, err := t.readOnePacket(first) + first = false if err != nil { t.readError = err close(t.incoming) @@ -138,67 +199,204 @@ func (t *handshakeTransport) readLoop() { t.incoming <- p } - // If we can't read, declare the writing part dead too. + // Stop writers too. + t.recordWriteError(t.readError) + + // Unblock the writer should it wait for this. + close(t.startKex) + + // Don't close t.requestKex; it's also written to from writePacket. +} + +func (t *handshakeTransport) pushPacket(p []byte) error { + if debugHandshake { + t.printPacket(p, true) + } + return t.conn.writePacket(p) +} + +func (t *handshakeTransport) getWriteError() error { t.mu.Lock() defer t.mu.Unlock() - if t.writeError == nil { - t.writeError = t.readError + return t.writeError +} + +func (t *handshakeTransport) recordWriteError(err error) { + t.mu.Lock() + defer t.mu.Unlock() + if t.writeError == nil && err != nil { + t.writeError = err } - t.cond.Broadcast() } -func (t *handshakeTransport) readOnePacket() ([]byte, error) { - if t.readSinceKex > t.config.RekeyThreshold { - if err := t.requestKeyChange(); err != nil { - return nil, err +func (t *handshakeTransport) requestKeyExchange() { + select { + case t.requestKex <- struct{}{}: + default: + // something already requested a kex, so do nothing. + } +} + +func (t *handshakeTransport) kexLoop() { + +write: + for t.getWriteError() == nil { + var request *pendingKex + var sent bool + + for request == nil || !sent { + var ok bool + select { + case request, ok = <-t.startKex: + if !ok { + break write + } + case <-t.requestKex: + break + } + + if !sent { + if err := t.sendKexInit(); err != nil { + t.recordWriteError(err) + break + } + sent = true + } + } + + if err := t.getWriteError(); err != nil { + if request != nil { + request.done <- err + } + break + } + + // We're not servicing t.requestKex, but that is OK: + // we never block on sending to t.requestKex. + + // We're not servicing t.startKex, but the remote end + // has just sent us a kexInitMsg, so it can't send + // another key change request, until we close the done + // channel on the pendingKex request. + + err := t.enterKeyExchange(request.otherInit) + + t.mu.Lock() + t.writeError = err + t.sentInitPacket = nil + t.sentInitMsg = nil + t.writePacketsLeft = packetRekeyThreshold + if t.config.RekeyThreshold > 0 { + t.writeBytesLeft = int64(t.config.RekeyThreshold) + } else if t.algorithms != nil { + t.writeBytesLeft = t.algorithms.w.rekeyBytes() + } + + // we have completed the key exchange. Since the + // reader is still blocked, it is safe to clear out + // the requestKex channel. This avoids the situation + // where: 1) we consumed our own request for the + // initial kex, and 2) the kex from the remote side + // caused another send on the requestKex channel, + clear: + for { + select { + case <-t.requestKex: + // + default: + break clear + } + } + + request.done <- t.writeError + + // kex finished. Push packets that we received while + // the kex was in progress. Don't look at t.startKex + // and don't increment writtenSinceKex: if we trigger + // another kex while we are still busy with the last + // one, things will become very confusing. + for _, p := range t.pendingPackets { + t.writeError = t.pushPacket(p) + if t.writeError != nil { + break + } } + t.pendingPackets = t.pendingPackets[:0] + t.mu.Unlock() } + // drain startKex channel. We don't service t.requestKex + // because nobody does blocking sends there. + go func() { + for init := range t.startKex { + init.done <- t.writeError + } + }() + + // Unblock reader. + t.conn.Close() +} + +// The protocol uses uint32 for packet counters, so we can't let them +// reach 1<<32. We will actually read and write more packets than +// this, though: the other side may send more packets, and after we +// hit this limit on writing we will send a few more packets for the +// key exchange itself. +const packetRekeyThreshold = (1 << 31) + +func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) { p, err := t.conn.readPacket() if err != nil { return nil, err } - t.readSinceKex += uint64(len(p)) + if t.readPacketsLeft > 0 { + t.readPacketsLeft-- + } else { + t.requestKeyExchange() + } + + if t.readBytesLeft > 0 { + t.readBytesLeft -= int64(len(p)) + } else { + t.requestKeyExchange() + } + if debugHandshake { - if p[0] == msgChannelData || p[0] == msgChannelExtendedData { - log.Printf("%s got data (packet %d bytes)", t.id(), len(p)) - } else { - msg, err := decode(p) - log.Printf("%s got %T %v (%v)", t.id(), msg, msg, err) - } + t.printPacket(p, false) + } + + if first && p[0] != msgKexInit { + return nil, fmt.Errorf("ssh: first packet should be msgKexInit") } + if p[0] != msgKexInit { return p, nil } - t.mu.Lock() - firstKex := t.sessionID == nil - err = t.enterKeyExchangeLocked(p) - if err != nil { - // drop connection - t.conn.Close() - t.writeError = err + kex := pendingKex{ + done: make(chan error, 1), + otherInit: p, } + t.startKex <- &kex + err = <-kex.done if debugHandshake { log.Printf("%s exited key exchange (first %v), err %v", t.id(), firstKex, err) } - // Unblock writers. - t.sentInitMsg = nil - t.sentInitPacket = nil - t.cond.Broadcast() - t.writtenSinceKex = 0 - t.mu.Unlock() - if err != nil { return nil, err } - t.readSinceKex = 0 + t.readPacketsLeft = packetRekeyThreshold + if t.config.RekeyThreshold > 0 { + t.readBytesLeft = int64(t.config.RekeyThreshold) + } else { + t.readBytesLeft = t.algorithms.r.rekeyBytes() + } // By default, a key exchange is hidden from higher layers by // translating it into msgIgnore. @@ -213,61 +411,16 @@ func (t *handshakeTransport) readOnePacket() ([]byte, error) { return successPacket, nil } -// keyChangeCategory describes whether a key exchange is the first on a -// connection, or a subsequent one. -type keyChangeCategory bool - -const ( - firstKeyExchange keyChangeCategory = true - subsequentKeyExchange keyChangeCategory = false -) - -// sendKexInit sends a key change message, and returns the message -// that was sent. After initiating the key change, all writes will be -// blocked until the change is done, and a failed key change will -// close the underlying transport. This function is safe for -// concurrent use by multiple goroutines. -func (t *handshakeTransport) sendKexInit(isFirst keyChangeCategory) error { - var err error - +// sendKexInit sends a key change message. +func (t *handshakeTransport) sendKexInit() error { t.mu.Lock() - // If this is the initial key change, but we already have a sessionID, - // then do nothing because the key exchange has already completed - // asynchronously. - if !isFirst || t.sessionID == nil { - _, _, err = t.sendKexInitLocked(isFirst) - } - t.mu.Unlock() - if err != nil { - return err - } - if isFirst { - if packet, err := t.readPacket(); err != nil { - return err - } else if packet[0] != msgNewKeys { - return unexpectedMessageError(msgNewKeys, packet[0]) - } - } - return nil -} - -func (t *handshakeTransport) requestInitialKeyChange() error { - return t.sendKexInit(firstKeyExchange) -} - -func (t *handshakeTransport) requestKeyChange() error { - return t.sendKexInit(subsequentKeyExchange) -} - -// sendKexInitLocked sends a key change message. t.mu must be locked -// while this happens. -func (t *handshakeTransport) sendKexInitLocked(isFirst keyChangeCategory) (*kexInitMsg, []byte, error) { - // kexInits may be sent either in response to the other side, - // or because our side wants to initiate a key change, so we - // may have already sent a kexInit. In that case, don't send a - // second kexInit. + defer t.mu.Unlock() if t.sentInitMsg != nil { - return t.sentInitMsg, t.sentInitPacket, nil + // kexInits may be sent either in response to the other side, + // or because our side wants to initiate a key change, so we + // may have already sent a kexInit. In that case, don't send a + // second kexInit. + return nil } msg := &kexInitMsg{ @@ -295,53 +448,65 @@ func (t *handshakeTransport) sendKexInitLocked(isFirst keyChangeCategory) (*kexI packetCopy := make([]byte, len(packet)) copy(packetCopy, packet) - if err := t.conn.writePacket(packetCopy); err != nil { - return nil, nil, err + if err := t.pushPacket(packetCopy); err != nil { + return err } t.sentInitMsg = msg t.sentInitPacket = packet - return msg, packet, nil + + return nil } func (t *handshakeTransport) writePacket(p []byte) error { + switch p[0] { + case msgKexInit: + return errors.New("ssh: only handshakeTransport can send kexInit") + case msgNewKeys: + return errors.New("ssh: only handshakeTransport can send newKeys") + } + t.mu.Lock() defer t.mu.Unlock() + if t.writeError != nil { + return t.writeError + } - if t.writtenSinceKex > t.config.RekeyThreshold { - t.sendKexInitLocked(subsequentKeyExchange) + if t.sentInitMsg != nil { + // Copy the packet so the writer can reuse the buffer. + cp := make([]byte, len(p)) + copy(cp, p) + t.pendingPackets = append(t.pendingPackets, cp) + return nil } - for t.sentInitMsg != nil && t.writeError == nil { - t.cond.Wait() + + if t.writeBytesLeft > 0 { + t.writeBytesLeft -= int64(len(p)) + } else { + t.requestKeyExchange() } - if t.writeError != nil { - return t.writeError + + if t.writePacketsLeft > 0 { + t.writePacketsLeft-- + } else { + t.requestKeyExchange() } - t.writtenSinceKex += uint64(len(p)) - switch p[0] { - case msgKexInit: - return errors.New("ssh: only handshakeTransport can send kexInit") - case msgNewKeys: - return errors.New("ssh: only handshakeTransport can send newKeys") - default: - return t.conn.writePacket(p) + if err := t.pushPacket(p); err != nil { + t.writeError = err } + + return nil } func (t *handshakeTransport) Close() error { return t.conn.Close() } -// enterKeyExchange runs the key exchange. t.mu must be held while running this. -func (t *handshakeTransport) enterKeyExchangeLocked(otherInitPacket []byte) error { +func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error { if debugHandshake { log.Printf("%s entered key exchange", t.id()) } - myInit, myInitPacket, err := t.sendKexInitLocked(subsequentKeyExchange) - if err != nil { - return err - } otherInit := &kexInitMsg{} if err := Unmarshal(otherInitPacket, otherInit); err != nil { @@ -352,20 +517,20 @@ func (t *handshakeTransport) enterKeyExchangeLocked(otherInitPacket []byte) erro clientVersion: t.clientVersion, serverVersion: t.serverVersion, clientKexInit: otherInitPacket, - serverKexInit: myInitPacket, + serverKexInit: t.sentInitPacket, } clientInit := otherInit - serverInit := myInit + serverInit := t.sentInitMsg if len(t.hostKeys) == 0 { - clientInit = myInit - serverInit = otherInit + clientInit, serverInit = serverInit, clientInit - magics.clientKexInit = myInitPacket + magics.clientKexInit = t.sentInitPacket magics.serverKexInit = otherInitPacket } - algs, err := findAgreedAlgorithms(clientInit, serverInit) + var err error + t.algorithms, err = findAgreedAlgorithms(clientInit, serverInit) if err != nil { return err } @@ -388,16 +553,16 @@ func (t *handshakeTransport) enterKeyExchangeLocked(otherInitPacket []byte) erro } } - kex, ok := kexAlgoMap[algs.kex] + kex, ok := kexAlgoMap[t.algorithms.kex] if !ok { - return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex) + return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.kex) } var result *kexResult if len(t.hostKeys) > 0 { - result, err = t.server(kex, algs, &magics) + result, err = t.server(kex, t.algorithms, &magics) } else { - result, err = t.client(kex, algs, &magics) + result, err = t.client(kex, t.algorithms, &magics) } if err != nil { @@ -409,7 +574,7 @@ func (t *handshakeTransport) enterKeyExchangeLocked(otherInitPacket []byte) erro } result.SessionID = t.sessionID - t.conn.prepareKeyChange(algs, result) + t.conn.prepareKeyChange(t.algorithms, result) if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil { return err } -- cgit v1.2.3