code that manages incoming/outgoing packet order

Makes sure that outgoing packets order matches incoming packets order.
This is not required by spec but some clients seem to require it (eg.
winscp).
This commit is contained in:
John Eikenberry 2017-03-14 17:49:31 -07:00
parent 243a742d21
commit ded0784f45
2 changed files with 205 additions and 0 deletions

124
packet-manager.go Normal file
View File

@ -0,0 +1,124 @@
package sftp
import (
"encoding"
"sort"
)
// --------------------------------------------------------------------
// Process with 2 branch select, listening to each channel.
// 0) start of loop
// Branch A
// 1) Wait for ids to come in and add them to id list.
// Branch B
// 1) Wait for a packet comes in.
// 2) Add it to the packet list.
// 3) The heads of each list are then compared and if they have the same ids
// the packet is sent out and the entries removed.
// 4) Goto step 2 Until the lists are emptied or the ids don't match.
// 5) Goto step 0.
// --------------------------------------------------------------------
type packetSender interface {
sendPacket(encoding.BinaryMarshaler) error
}
type packetManager struct {
requests chan requestPacket
responses chan responsePacket
fini chan struct{}
incoming []uint32
outgoing []responsePacket
sender packetSender // connection object
}
func newPktMgr(sender packetSender) packetManager {
s := packetManager{
requests: make(chan requestPacket, sftpServerWorkerCount),
responses: make(chan responsePacket, sftpServerWorkerCount),
fini: make(chan struct{}),
incoming: make([]uint32, 0, sftpServerWorkerCount),
outgoing: make([]responsePacket, 0, sftpServerWorkerCount),
sender: sender,
}
go s.worker()
return s
}
// register incoming packets to be handled
// send id of 0 for packets without id
func (s packetManager) incomingPacket(pkt requestPacket) {
s.requests <- pkt // buffer == sftpServerWorkerCount
}
// register outgoing packets as being ready
func (s packetManager) readyPacket(pkt responsePacket) {
s.responses <- pkt
}
// shut down packetManager worker
func (s packetManager) close() {
close(s.fini)
}
// process packets
func (s *packetManager) worker() {
for {
select {
case pkt := <-s.requests:
debug("incoming id: %v", pkt.id())
s.incoming = append(s.incoming, pkt.id())
if len(s.incoming) > 1 {
sort.Slice(s.incoming, func(i, j int) bool {
return s.incoming[i] < s.incoming[j]
})
}
case pkt := <-s.responses:
debug("outgoing pkt: %v", pkt.id())
s.outgoing = append(s.outgoing, pkt)
if len(s.outgoing) > 1 {
sort.Slice(s.outgoing, func(i, j int) bool {
return s.outgoing[i].id() < s.outgoing[j].id()
})
}
case <-s.fini:
return
}
s.maybeSendPackets()
}
}
// send as many packets as are ready
func (s *packetManager) maybeSendPackets() {
for {
if len(s.outgoing) == 0 || len(s.incoming) == 0 {
debug("break! -- outgoing: %v; incoming: %v",
len(s.outgoing), len(s.incoming))
break
}
out := s.outgoing[0]
in := s.incoming[0]
debug("incoming: %v", s.incoming)
debug("outgoing: %v", outfilter(s.outgoing))
if in == out.id() {
s.sender.sendPacket(out)
// pop off heads
copy(s.incoming, s.incoming[1:]) // shift left
s.incoming = s.incoming[:len(s.incoming)-1] // remove last
copy(s.outgoing, s.outgoing[1:]) // shift left
s.outgoing = s.outgoing[:len(s.outgoing)-1] // remove last
} else {
break
}
}
}
func outfilter(o []responsePacket) []uint32 {
res := make([]uint32, 0, len(o))
for _, v := range o {
res = append(res, v.id())
}
return res
}

81
packet-manager_test.go Normal file
View File

@ -0,0 +1,81 @@
package sftp
import (
"encoding"
"testing"
"github.com/stretchr/testify/assert"
)
type _sender struct {
sent chan encoding.BinaryMarshaler
}
func newsender() *_sender {
return &_sender{make(chan encoding.BinaryMarshaler)}
}
func (s _sender) sendPacket(p encoding.BinaryMarshaler) error {
s.sent <- p
return nil
}
type fakepacket uint32
func (fakepacket) MarshalBinary() ([]byte, error) {
return []byte{}, nil
}
func (fakepacket) UnmarshalBinary([]byte) error {
return nil
}
func (f fakepacket) id() uint32 {
return uint32(f)
}
type pair struct {
in fakepacket
out fakepacket
}
var ttable1 = []pair{
pair{fakepacket(0), fakepacket(0)},
pair{fakepacket(1), fakepacket(1)},
pair{fakepacket(2), fakepacket(2)},
pair{fakepacket(3), fakepacket(3)},
}
var ttable2 = []pair{
pair{fakepacket(0), fakepacket(0)},
pair{fakepacket(1), fakepacket(4)},
pair{fakepacket(2), fakepacket(1)},
pair{fakepacket(3), fakepacket(3)},
pair{fakepacket(4), fakepacket(2)},
}
var tables = [][]pair{ttable1, ttable2}
func TestPacketManager(t *testing.T) {
sender := newsender()
s := newPktMgr(sender)
// go func() {
// for _ = range s.workers {
// }
// }()
for i := range tables {
table := tables[i]
for _, p := range table {
s.incomingPacket(p.in)
}
for _, p := range table {
s.readyPacket(p.out)
}
for _, p := range table {
pkt := <-sender.sent
id := pkt.(fakepacket).id()
assert.Equal(t, id, p.in.id())
}
}
s.close()
}