mirror of https://github.com/pkg/sftp.git
125 lines
3.1 KiB
Go
125 lines
3.1 KiB
Go
|
package sftp
|
||
|
|
||
|
import (
|
||
|
"encoding"
|
||
|
"sort"
|
||
|
)
|
||
|
|
||
|
// --------------------------------------------------------------------
|
||
|
// Process with 2 branch select, listening to each channel.
|
||
|
// 0) start of loop
|
||
|
|
||
|
// Branch A
|
||
|
// 1) Wait for ids to come in and add them to id list.
|
||
|
|
||
|
// Branch B
|
||
|
// 1) Wait for a packet comes in.
|
||
|
// 2) Add it to the packet list.
|
||
|
// 3) The heads of each list are then compared and if they have the same ids
|
||
|
// the packet is sent out and the entries removed.
|
||
|
// 4) Goto step 2 Until the lists are emptied or the ids don't match.
|
||
|
// 5) Goto step 0.
|
||
|
// --------------------------------------------------------------------
|
||
|
|
||
|
type packetSender interface {
|
||
|
sendPacket(encoding.BinaryMarshaler) error
|
||
|
}
|
||
|
|
||
|
type packetManager struct {
|
||
|
requests chan requestPacket
|
||
|
responses chan responsePacket
|
||
|
fini chan struct{}
|
||
|
incoming []uint32
|
||
|
outgoing []responsePacket
|
||
|
sender packetSender // connection object
|
||
|
}
|
||
|
|
||
|
func newPktMgr(sender packetSender) packetManager {
|
||
|
s := packetManager{
|
||
|
requests: make(chan requestPacket, sftpServerWorkerCount),
|
||
|
responses: make(chan responsePacket, sftpServerWorkerCount),
|
||
|
fini: make(chan struct{}),
|
||
|
incoming: make([]uint32, 0, sftpServerWorkerCount),
|
||
|
outgoing: make([]responsePacket, 0, sftpServerWorkerCount),
|
||
|
sender: sender,
|
||
|
}
|
||
|
go s.worker()
|
||
|
return s
|
||
|
}
|
||
|
|
||
|
// register incoming packets to be handled
|
||
|
// send id of 0 for packets without id
|
||
|
func (s packetManager) incomingPacket(pkt requestPacket) {
|
||
|
s.requests <- pkt // buffer == sftpServerWorkerCount
|
||
|
}
|
||
|
|
||
|
// register outgoing packets as being ready
|
||
|
func (s packetManager) readyPacket(pkt responsePacket) {
|
||
|
s.responses <- pkt
|
||
|
}
|
||
|
|
||
|
// shut down packetManager worker
|
||
|
func (s packetManager) close() {
|
||
|
close(s.fini)
|
||
|
}
|
||
|
|
||
|
// process packets
|
||
|
func (s *packetManager) worker() {
|
||
|
for {
|
||
|
select {
|
||
|
case pkt := <-s.requests:
|
||
|
debug("incoming id: %v", pkt.id())
|
||
|
s.incoming = append(s.incoming, pkt.id())
|
||
|
if len(s.incoming) > 1 {
|
||
|
sort.Slice(s.incoming, func(i, j int) bool {
|
||
|
return s.incoming[i] < s.incoming[j]
|
||
|
})
|
||
|
}
|
||
|
case pkt := <-s.responses:
|
||
|
debug("outgoing pkt: %v", pkt.id())
|
||
|
s.outgoing = append(s.outgoing, pkt)
|
||
|
if len(s.outgoing) > 1 {
|
||
|
sort.Slice(s.outgoing, func(i, j int) bool {
|
||
|
return s.outgoing[i].id() < s.outgoing[j].id()
|
||
|
})
|
||
|
}
|
||
|
case <-s.fini:
|
||
|
return
|
||
|
}
|
||
|
s.maybeSendPackets()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// send as many packets as are ready
|
||
|
func (s *packetManager) maybeSendPackets() {
|
||
|
for {
|
||
|
if len(s.outgoing) == 0 || len(s.incoming) == 0 {
|
||
|
debug("break! -- outgoing: %v; incoming: %v",
|
||
|
len(s.outgoing), len(s.incoming))
|
||
|
break
|
||
|
}
|
||
|
out := s.outgoing[0]
|
||
|
in := s.incoming[0]
|
||
|
debug("incoming: %v", s.incoming)
|
||
|
debug("outgoing: %v", outfilter(s.outgoing))
|
||
|
if in == out.id() {
|
||
|
s.sender.sendPacket(out)
|
||
|
// pop off heads
|
||
|
copy(s.incoming, s.incoming[1:]) // shift left
|
||
|
s.incoming = s.incoming[:len(s.incoming)-1] // remove last
|
||
|
copy(s.outgoing, s.outgoing[1:]) // shift left
|
||
|
s.outgoing = s.outgoing[:len(s.outgoing)-1] // remove last
|
||
|
} else {
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func outfilter(o []responsePacket) []uint32 {
|
||
|
res := make([]uint32, 0, len(o))
|
||
|
for _, v := range o {
|
||
|
res = append(res, v.id())
|
||
|
}
|
||
|
return res
|
||
|
}
|