mirror of https://github.com/pkg/sftp.git
ensure packets are processed in order
File operations that happen after the open packet has been received, like reading/writing, can be done with the pool as the order they are run in doesn't matter (the packets contain the file offsets). Command operations, on the other hand, need to be serialized. This flips between a pool of workers for file operations and a single worker for everything else. It flips on Open and Close packets.
This commit is contained in:
parent
5fd073bcc3
commit
d1bd7b3f9c
|
@ -1,6 +1,9 @@
|
|||
package sftp
|
||||
|
||||
import "encoding"
|
||||
import (
|
||||
"encoding"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// The goal of the packetManager is to keep the outgoing packets in the same
|
||||
// order as the incoming. This is due to some sftp clients requiring this
|
||||
|
@ -17,6 +20,7 @@ type packetManager struct {
|
|||
incoming requestPacketIDs
|
||||
outgoing responsePackets
|
||||
sender packetSender // connection object
|
||||
working *sync.WaitGroup
|
||||
}
|
||||
|
||||
func newPktMgr(sender packetSender) packetManager {
|
||||
|
@ -27,6 +31,7 @@ func newPktMgr(sender packetSender) packetManager {
|
|||
incoming: make([]uint32, 0, sftpServerWorkerCount),
|
||||
outgoing: make([]responsePacket, 0, sftpServerWorkerCount),
|
||||
sender: sender,
|
||||
working: &sync.WaitGroup{},
|
||||
}
|
||||
go s.worker()
|
||||
return s
|
||||
|
@ -35,12 +40,14 @@ func newPktMgr(sender packetSender) packetManager {
|
|||
// register incoming packets to be handled
|
||||
// send id of 0 for packets without id
|
||||
func (s packetManager) incomingPacket(pkt requestPacket) {
|
||||
s.working.Add(1)
|
||||
s.requests <- pkt // buffer == sftpServerWorkerCount
|
||||
}
|
||||
|
||||
// register outgoing packets as being ready
|
||||
func (s packetManager) readyPacket(pkt responsePacket) {
|
||||
s.responses <- pkt
|
||||
s.working.Done()
|
||||
}
|
||||
|
||||
// shut down packetManager worker
|
||||
|
|
58
server.go
58
server.go
|
@ -278,20 +278,58 @@ func handlePacket(s *Server, p interface{}) error {
|
|||
}
|
||||
}
|
||||
|
||||
type requestChan chan requestPacket
|
||||
|
||||
func (svr *Server) sftpServerWorkers(worker func(requestChan)) requestChan {
|
||||
|
||||
rwChan := make(chan requestPacket, sftpServerWorkerCount)
|
||||
for i := 0; i < sftpServerWorkerCount; i++ {
|
||||
go worker(rwChan)
|
||||
}
|
||||
|
||||
cmdChan := make(chan requestPacket)
|
||||
go worker(cmdChan)
|
||||
|
||||
pktChan := make(chan requestPacket, sftpServerWorkerCount)
|
||||
go func() {
|
||||
// start with cmdChan
|
||||
curChan := cmdChan
|
||||
for pkt := range pktChan {
|
||||
// on file open packet, switch to rwChan
|
||||
switch pkt.(type) {
|
||||
case *sshFxpOpenPacket:
|
||||
curChan = rwChan
|
||||
// on file close packet, switch back to cmdChan
|
||||
// after waiting for any reads/writes to finish
|
||||
case *sshFxpClosePacket:
|
||||
// wait for rwChan to finish
|
||||
svr.pktMgr.working.Wait()
|
||||
// stop using rwChan
|
||||
curChan = cmdChan
|
||||
}
|
||||
svr.pktMgr.incomingPacket(pkt)
|
||||
curChan <- pkt
|
||||
}
|
||||
close(rwChan)
|
||||
close(cmdChan)
|
||||
}()
|
||||
|
||||
return pktChan
|
||||
}
|
||||
|
||||
// Serve serves SFTP connections until the streams stop or the SFTP subsystem
|
||||
// is stopped.
|
||||
func (svr *Server) Serve() error {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(sftpServerWorkerCount)
|
||||
pktChan := make(chan requestPacket, sftpServerWorkerCount)
|
||||
for i := 0; i < sftpServerWorkerCount; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := svr.sftpServerWorker(pktChan); err != nil {
|
||||
svr.conn.Close() // shuts down recvPacket
|
||||
}
|
||||
}()
|
||||
wg.Add(1)
|
||||
workerFunc := func(ch requestChan) {
|
||||
wg.Add(1)
|
||||
defer wg.Done()
|
||||
if err := svr.sftpServerWorker(ch); err != nil {
|
||||
svr.conn.Close() // shuts down recvPacket
|
||||
}
|
||||
}
|
||||
pktChan := svr.sftpServerWorkers(workerFunc)
|
||||
|
||||
var err error
|
||||
var pkt requestPacket
|
||||
|
@ -310,9 +348,9 @@ func (svr *Server) Serve() error {
|
|||
break
|
||||
}
|
||||
|
||||
svr.pktMgr.incomingPacket(pkt)
|
||||
pktChan <- pkt
|
||||
}
|
||||
wg.Done()
|
||||
|
||||
close(pktChan) // shuts down sftpServerWorkers
|
||||
wg.Wait() // wait for all workers to exit
|
||||
|
|
Loading…
Reference in New Issue