// Copyright 2013 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package ssh import ( "encoding/binary" "fmt" "io" "log" "sync" "sync/atomic" ) // debugMux, if set, causes messages in the connection protocol to be // logged. const debugMux = false // chanList is a thread safe channel list. type chanList struct { // protects concurrent access to chans sync.Mutex // chans are indexed by the local id of the channel, which the // other side should send in the PeersId field. chans []*channel // This is a debugging aid: it offsets all IDs by this // amount. This helps distinguish otherwise identical // server/client muxes offset uint32 } // Assigns a channel ID to the given channel. func (c *chanList) add(ch *channel) uint32 { c.Lock() defer c.Unlock() for i := range c.chans { if c.chans[i] == nil { c.chans[i] = ch return uint32(i) + c.offset } } c.chans = append(c.chans, ch) return uint32(len(c.chans)-1) + c.offset } // getChan returns the channel for the given ID. func (c *chanList) getChan(id uint32) *channel { id -= c.offset c.Lock() defer c.Unlock() if id < uint32(len(c.chans)) { return c.chans[id] } return nil } func (c *chanList) remove(id uint32) { id -= c.offset c.Lock() if id < uint32(len(c.chans)) { c.chans[id] = nil } c.Unlock() } // dropAll forgets all channels it knows, returning them in a slice. func (c *chanList) dropAll() []*channel { c.Lock() defer c.Unlock() var r []*channel for _, ch := range c.chans { if ch == nil { continue } r = append(r, ch) } c.chans = nil return r } // mux represents the state for the SSH connection protocol, which // multiplexes many channels onto a single packet transport. 
type mux struct { conn packetConn chanList chanList incomingChannels chan NewChannel globalSentMu sync.Mutex globalResponses chan interface{} incomingRequests chan *Request errCond *sync.Cond err error } // When debugging, each new chanList instantiation has a different // offset. var globalOff uint32 func (m *mux) Wait() error { m.errCond.L.Lock() defer m.errCond.L.Unlock() for m.err == nil { m.errCond.Wait() } return m.err } // newMux returns a mux that runs over the given connection. func newMux(p packetConn) *mux { m := &mux{ conn: p, incomingChannels: make(chan NewChannel, 16), globalResponses: make(chan interface{}, 1), incomingRequests: make(chan *Request, 16), errCond: newCond(), } if debugMux { m.chanList.offset = atomic.AddUint32(&globalOff, 1) } go m.loop() return m } func (m *mux) sendMessage(msg interface{}) error { p := Marshal(msg) if debugMux { log.Printf("send global(%d): %#v", m.chanList.offset, msg) } return m.conn.writePacket(p) } func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) { if wantReply { m.globalSentMu.Lock() defer m.globalSentMu.Unlock() } if err := m.sendMessage(globalRequestMsg{ Type: name, WantReply: wantReply, Data: payload, }); err != nil { return false, nil, err } if !wantReply { return false, nil, nil } msg, ok := <-m.globalResponses if !ok { return false, nil, io.EOF } switch msg := msg.(type) { case *globalRequestFailureMsg: return false, msg.Data, nil case *globalRequestSuccessMsg: return true, msg.Data, nil default: return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg) } } // ackRequest must be called after processing a global request that // has WantReply set. func (m *mux) ackRequest(ok bool, data []byte) error { if ok { return m.sendMessage(globalRequestSuccessMsg{Data: data}) } return m.sendMessage(globalRequestFailureMsg{Data: data}) } func (m *mux) Close() error { return m.conn.Close() } // loop runs the connection machine. 
// It will process packets until an
// error is encountered. To synchronize on loop exit, use mux.Wait.
//
// On exit it closes every channel still registered, closes the
// incoming channel, incoming request and global response queues,
// closes the connection, and finally publishes the terminating error.
func (m *mux) loop() {
	var err error
	for err == nil {
		err = m.onePacket()
	}

	// Forget all registered channels and close each one.
	for _, ch := range m.chanList.dropAll() {
		ch.close()
	}

	close(m.incomingChannels)
	close(m.incomingRequests)
	// A SendRequest blocked on this queue observes the close and
	// returns io.EOF.
	close(m.globalResponses)

	m.conn.Close()

	// Record the error under the condition's lock and wake every
	// goroutine blocked in Wait.
	m.errCond.L.Lock()
	m.err = err
	m.errCond.Broadcast()
	m.errCond.L.Unlock()

	if debugMux {
		log.Println("loop exit", err)
	}
}

// onePacket reads and processes one packet.
//
// The first byte of the packet selects the handler: channel-open and
// global messages are handled by the mux itself, everything else is
// routed to the channel named in the packet. It returns any read or
// handler error, and an error if the packet is too short to carry a
// channel ID or names a channel that is not registered.
func (m *mux) onePacket() error {
	packet, err := m.conn.readPacket()
	if err != nil {
		return err
	}

	if debugMux {
		// Data packets are only summarized by size; other packets are
		// decoded so the message can be printed.
		if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData {
			log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet))
		} else {
			// The decode error is ignored: p is used for logging only.
			p, _ := decode(packet)
			log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet))
		}
	}

	switch packet[0] {
	case msgChannelOpen:
		return m.handleChannelOpen(packet)
	case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
		return m.handleGlobalPacket(packet)
	}

	// assume a channel packet: one message type byte followed by a
	// 4-byte big-endian channel ID, hence the minimum length of 5.
	if len(packet) < 5 {
		return parseError(packet[0])
	}
	id := binary.BigEndian.Uint32(packet[1:])
	ch := m.chanList.getChan(id)
	if ch == nil {
		return fmt.Errorf("ssh: invalid channel %d", id)
	}

	return ch.handlePacket(packet)
}

// handleGlobalPacket decodes a global message and queues it: requests
// from the peer go to incomingRequests, and success/failure replies go
// to globalResponses. It returns the decode error, if any.
func (m *mux) handleGlobalPacket(packet []byte) error {
	msg, err := decode(packet)
	if err != nil {
		return err
	}

	switch msg := msg.(type) {
	case *globalRequestMsg:
		// This send blocks the read loop while the queue is full.
		m.incomingRequests <- &Request{
			Type:      msg.Type,
			WantReply: msg.WantReply,
			Payload:   msg.Data,
			mux:       m,
		}
	case *globalRequestSuccessMsg, *globalRequestFailureMsg:
		m.globalResponses <- msg
	default:
		// NOTE(review): onePacket only routes the three global message
		// types here, so this is presumably unreachable as long as
		// decode maps them to the types above — confirm against decode.
		panic(fmt.Sprintf("not a global message %#v", msg))
	}

	return nil
}

// handleChannelOpen schedules a channel to be Accept()ed.
func (m *mux) handleChannelOpen(packet []byte) error { var msg channelOpenMsg if err := Unmarshal(packet, &msg); err != nil { return err } if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { failMsg := channelOpenFailureMsg{ PeersId: msg.PeersId, Reason: ConnectionFailed, Message: "invalid request", Language: "en_US.UTF-8", } return m.sendMessage(failMsg) } c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData) c.remoteId = msg.PeersId c.maxRemotePayload = msg.MaxPacketSize c.remoteWin.add(msg.PeersWindow) m.incomingChannels <- c return nil } func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) { ch, err := m.openChannel(chanType, extra) if err != nil { return nil, nil, err } return ch, ch.incomingRequests, nil } func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) { ch := m.newChannel(chanType, channelOutbound, extra) ch.maxIncomingPayload = channelMaxPacket open := channelOpenMsg{ ChanType: chanType, PeersWindow: ch.myWindow, MaxPacketSize: ch.maxIncomingPayload, TypeSpecificData: extra, PeersId: ch.localId, } if err := m.sendMessage(open); err != nil { return nil, err } switch msg := (<-ch.msg).(type) { case *channelOpenConfirmMsg: return ch, nil case *channelOpenFailureMsg: return nil, &OpenChannelError{msg.Reason, msg.Message} default: return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg) } }