aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/pkg/sftp/conn.go
blob: 9b03d112f3818674e33c56345130240d54f3afe4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
package sftp

import (
	"encoding"
	"io"
	"sync"

	"github.com/pkg/errors"
)

// conn implements a bidirectional channel on which client and server
// connections are multiplexed. Packets are read from the embedded io.Reader
// and written to (and closed through) the embedded io.WriteCloser. The
// embedded mutex is taken by sendPacket so that concurrent packet writes are
// serialised rather than interleaved on the wire.
type conn struct {
	io.Reader
	io.WriteCloser
	sync.Mutex // used to serialise writes to sendPacket
}

// recvPacket reads a single packet from this connection by delegating to the
// package-level recvPacket helper, with the connection itself as the source.
// The results are passed through unchanged.
func (c *conn) recvPacket() (uint8, []byte, error) {
	pktType, payload, err := recvPacket(c)
	return pktType, payload, err
}

// sendPacket writes m to this connection through the package-level sendPacket
// helper. The connection's mutex is held for the whole call (and released even
// if the helper panics) so that writes from concurrent callers never interleave.
func (c *conn) sendPacket(m encoding.BinaryMarshaler) error {
	c.Mutex.Lock()
	defer c.Mutex.Unlock()
	return sendPacket(c, m)
}

// clientConn is the client side of an SFTP session. It embeds the shared conn
// and records every request that is still awaiting a reply, so that replies
// read by recv can be routed back to the goroutine that issued the request.
//
// Note: the struct's own embedded sync.Mutex is shallower than the one inside
// conn, so c.Lock/c.Unlock on a clientConn take the inflight lock. The write
// lock is only reached through c.conn (e.g. c.conn.sendPacket).
type clientConn struct {
	conn
	wg         sync.WaitGroup           // loop calls Done on exit; Close waits on it
	sync.Mutex                          // protects inflight
	inflight   map[uint32]chan<- result // outstanding requests
}

// Close closes the SFTP session. It closes the underlying connection and then
// blocks until the receive loop has signalled completion on wg. The value it
// returns is the error from closing the connection.
func (c *clientConn) Close() error {
	// Deferred so the wait happens after the transport is closed. Closing it
	// first is presumably what lets recv's blocking read return so loop can exit.
	defer c.wg.Wait()
	closeErr := c.conn.Close()
	return closeErr
}

// loop drives the client's receive side. It is presumably run on its own
// goroutine, since it signals wg when it exits. It runs recv until recv
// returns, and if recv returns an error it forwards that error to every
// request still waiting for a reply.
func (c *clientConn) loop() {
	defer c.wg.Done()
	if err := c.recv(); err != nil {
		c.broadcastErr(err)
	}
}

// recv continuously reads packets from the server and forwards each one to the
// channel registered in inflight under the packet's request id. The entry is
// removed from inflight before delivery, so each listener receives at most one
// reply from here.
//
// recv only returns on error: either the read failed, or a reply arrived whose
// id has no registered listener. The underlying connection is closed on return.
func (c *clientConn) recv() error {
	defer c.conn.Close()
	for {
		typ, data, err := c.recvPacket()
		if err != nil {
			return err
		}
		// Only the leading request id is needed for routing. The full payload
		// (including the id) is forwarded untouched in data, so the second return
		// of unmarshalUint32 is deliberately discarded.
		// NOTE(review): assumes unmarshalUint32 copes with a payload shorter than
		// four bytes — confirm against its definition.
		sid, _ := unmarshalUint32(data)
		c.Lock()
		ch, ok := c.inflight[sid]
		delete(c.inflight, sid)
		c.Unlock()
		if !ok {
			// This is an unexpected occurrence. Send the error
			// back to all listeners so that they terminate
			// gracefully.
			return errors.Errorf("sid: %v not found", sid)
		}
		ch <- result{typ: typ, data: data}
	}
}

// result captures the outcome of a single request. It holds either the type
// and payload of the reply packet received from the server, or an error when
// the request could not be sent or the connection failed before a reply arrived.
type result struct {
	typ  byte   // packet type of the reply, as returned by recvPacket
	data []byte // raw reply payload; still begins with the request id
	err  error  // non-nil when no reply could be obtained
}

// idmarshaler is an outgoing request packet. It can serialise itself for the
// wire, and it exposes the request id that is used as the inflight key to
// match the request with the server's reply.
type idmarshaler interface {
	encoding.BinaryMarshaler
	id() uint32
}

// sendPacket issues request p and blocks until its reply (or an error) is
// delivered. It returns the reply's packet type, its payload, and any error.
// The reply channel has room for one value, so the goroutine that delivers the
// single result does not have to wait for this caller to be scheduled.
func (c *clientConn) sendPacket(p idmarshaler) (byte, []byte, error) {
	reply := make(chan result, 1)
	c.dispatchRequest(reply, p)
	res := <-reply
	return res.typ, res.data, res.err
}

// dispatchRequest registers ch as the listener for p's request id and then
// writes p to the server. Exactly one result is later delivered on ch:
//   - by recv when the matching reply arrives,
//   - by broadcastErr if the receive loop fails, or
//   - directly here if the write itself fails.
//
// Whichever of these removes ch from inflight (under the lock) owns that single
// delivery; the others must not send on ch.
func (c *clientConn) dispatchRequest(ch chan<- result, p idmarshaler) {
	sid := p.id()

	// Register before writing so that a fast reply always finds its listener.
	c.Lock()
	c.inflight[sid] = ch
	c.Unlock()

	// Write without holding the inflight lock. recv needs that lock to route
	// replies, so a write that blocks (e.g. on transport flow control) would
	// otherwise also stop replies from being consumed, which can deadlock.
	err := c.conn.sendPacket(p)
	if err == nil {
		return
	}

	c.Lock()
	_, owned := c.inflight[sid]
	delete(c.inflight, sid)
	c.Unlock()
	// If recv or broadcastErr already claimed ch, they deliver its result.
	// Sending a second value could block forever on a full channel.
	if owned {
		ch <- result{err: err}
	}
}

// broadcastErr sends an error to all goroutines waiting for a response. Each
// listener is removed from inflight as it is claimed, so no other path
// (recv or a failing dispatchRequest) delivers to it a second time.
func (c *clientConn) broadcastErr(err error) {
	c.Lock()
	listeners := make([]chan<- result, 0, len(c.inflight))
	for sid, ch := range c.inflight {
		listeners = append(listeners, ch)
		delete(c.inflight, sid) // deleting during range is permitted in Go
	}
	c.Unlock()
	// Deliver outside the lock so a slow listener cannot stall other users of inflight.
	for _, ch := range listeners {
		ch <- result{err: err}
	}
}

// serverConn is the server side of an SFTP session. It embeds the shared conn,
// so it inherits its serialised sendPacket, and adds server-specific helpers
// such as sendError.
type serverConn struct {
	conn
}

// sendError converts err into a status packet for request p (via
// statusFromError) and writes it using the connection's serialised sendPacket.
// It returns any error from that write.
func (s *serverConn) sendError(p id, err error) error {
	status := statusFromError(p, err)
	return s.sendPacket(status)
}