mirror of https://github.com/pkg/sftp.git
				
				
				
			
		
			
				
	
	
		
			218 lines
		
	
	
		
			5.8 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			218 lines
		
	
	
		
			5.8 KiB
		
	
	
	
		
			Go
		
	
	
	
| package sftp
 | |
| 
 | |
| import (
 | |
| 	"encoding"
 | |
| 	"sort"
 | |
| 	"sync"
 | |
| )
 | |
| 
 | |
| // The goal of the packetManager is to keep the outgoing packets in the same
 | |
| // order as the incoming as is requires by section 7 of the RFC.
 | |
| 
 | |
| type packetManager struct {
 | |
| 	requests    chan orderedPacket
 | |
| 	responses   chan orderedPacket
 | |
| 	fini        chan struct{}
 | |
| 	incoming    orderedPackets
 | |
| 	outgoing    orderedPackets
 | |
| 	sender      packetSender // connection object
 | |
| 	working     *sync.WaitGroup
 | |
| 	packetCount uint32
 | |
| 	// it is not nil if the allocator is enabled
 | |
| 	alloc *allocator
 | |
| }
 | |
| 
 | |
// packetSender abstracts the connection that finished responses are written
// to; maybeSendPackets hands it one packet at a time, in request order.
type packetSender interface {
	// sendPacket is expected to serialize pkt and write it to the peer; the
	// returned error reports a failed send.
	sendPacket(pkt encoding.BinaryMarshaler) error
}
 | |
| 
 | |
| func newPktMgr(sender packetSender) *packetManager {
 | |
| 	s := &packetManager{
 | |
| 		requests:  make(chan orderedPacket, SftpServerWorkerCount),
 | |
| 		responses: make(chan orderedPacket, SftpServerWorkerCount),
 | |
| 		fini:      make(chan struct{}),
 | |
| 		incoming:  make([]orderedPacket, 0, SftpServerWorkerCount),
 | |
| 		outgoing:  make([]orderedPacket, 0, SftpServerWorkerCount),
 | |
| 		sender:    sender,
 | |
| 		working:   &sync.WaitGroup{},
 | |
| 	}
 | |
| 	go s.controller()
 | |
| 	return s
 | |
| }
 | |
| 
 | |
| //// packet ordering
 | |
| func (s *packetManager) newOrderID() uint32 {
 | |
| 	s.packetCount++
 | |
| 	return s.packetCount
 | |
| }
 | |
| 
 | |
| // returns the next orderID without incrementing it.
 | |
| // This is used before receiving a new packet, with the allocator enabled, to associate
 | |
| // the slice allocated for the received packet with the orderID that will be used to mark
 | |
| // the allocated slices for reuse once the request is served
 | |
| func (s *packetManager) getNextOrderID() uint32 {
 | |
| 	return s.packetCount + 1
 | |
| }
 | |
| 
 | |
| type orderedRequest struct {
 | |
| 	requestPacket
 | |
| 	orderid uint32
 | |
| }
 | |
| 
 | |
| func (s *packetManager) newOrderedRequest(p requestPacket) orderedRequest {
 | |
| 	return orderedRequest{requestPacket: p, orderid: s.newOrderID()}
 | |
| }
 | |
| func (p orderedRequest) orderID() uint32       { return p.orderid }
 | |
| func (p orderedRequest) setOrderID(oid uint32) { p.orderid = oid }
 | |
| 
 | |
| type orderedResponse struct {
 | |
| 	responsePacket
 | |
| 	orderid uint32
 | |
| }
 | |
| 
 | |
| func (s *packetManager) newOrderedResponse(p responsePacket, id uint32,
 | |
| ) orderedResponse {
 | |
| 	return orderedResponse{responsePacket: p, orderid: id}
 | |
| }
 | |
| func (p orderedResponse) orderID() uint32       { return p.orderid }
 | |
| func (p orderedResponse) setOrderID(oid uint32) { p.orderid = oid }
 | |
| 
 | |
// orderedPacket is anything the packet manager can queue. orderID fixes the
// packet's place in the output order; id is the packet's own identifier
// (presumably the SFTP request id, which is what the debug logging prints).
type orderedPacket interface {
	orderID() uint32
	id() uint32
}

// orderedPackets is a queue of orderedPacket values; Sort orders it by
// orderID.
type orderedPackets []orderedPacket
 | |
| 
 | |
| func (o orderedPackets) Sort() {
 | |
| 	sort.Slice(o, func(i, j int) bool {
 | |
| 		return o[i].orderID() < o[j].orderID()
 | |
| 	})
 | |
| }
 | |
| 
 | |
| //// packet registry
 | |
| // register incoming packets to be handled
 | |
| func (s *packetManager) incomingPacket(pkt orderedRequest) {
 | |
| 	s.working.Add(1)
 | |
| 	s.requests <- pkt
 | |
| }
 | |
| 
 | |
| // register outgoing packets as being ready
 | |
| func (s *packetManager) readyPacket(pkt orderedResponse) {
 | |
| 	s.responses <- pkt
 | |
| 	s.working.Done()
 | |
| }
 | |
| 
 | |
| // shut down packetManager controller
 | |
| func (s *packetManager) close() {
 | |
| 	// pause until current packets are processed
 | |
| 	s.working.Wait()
 | |
| 	close(s.fini)
 | |
| }
 | |
| 
 | |
| // Passed a worker function, returns a channel for incoming packets.
 | |
| // Keep process packet responses in the order they are received while
 | |
| // maximizing throughput of file transfers.
 | |
| func (s *packetManager) workerChan(runWorker func(chan orderedRequest),
 | |
| ) chan orderedRequest {
 | |
| 
 | |
| 	// multiple workers for faster read/writes
 | |
| 	rwChan := make(chan orderedRequest, SftpServerWorkerCount)
 | |
| 	for i := 0; i < SftpServerWorkerCount; i++ {
 | |
| 		runWorker(rwChan)
 | |
| 	}
 | |
| 
 | |
| 	// single worker to enforce sequential processing of everything else
 | |
| 	cmdChan := make(chan orderedRequest)
 | |
| 	runWorker(cmdChan)
 | |
| 
 | |
| 	pktChan := make(chan orderedRequest, SftpServerWorkerCount)
 | |
| 	go func() {
 | |
| 		for pkt := range pktChan {
 | |
| 			switch pkt.requestPacket.(type) {
 | |
| 			case *sshFxpReadPacket, *sshFxpWritePacket:
 | |
| 				s.incomingPacket(pkt)
 | |
| 				rwChan <- pkt
 | |
| 				continue
 | |
| 			case *sshFxpClosePacket:
 | |
| 				// wait for reads/writes to finish when file is closed
 | |
| 				// incomingPacket() call must occur after this
 | |
| 				s.working.Wait()
 | |
| 			}
 | |
| 			s.incomingPacket(pkt)
 | |
| 			// all non-RW use sequential cmdChan
 | |
| 			cmdChan <- pkt
 | |
| 		}
 | |
| 		close(rwChan)
 | |
| 		close(cmdChan)
 | |
| 		s.close()
 | |
| 	}()
 | |
| 
 | |
| 	return pktChan
 | |
| }
 | |
| 
 | |
| // process packets
 | |
| func (s *packetManager) controller() {
 | |
| 	for {
 | |
| 		select {
 | |
| 		case pkt := <-s.requests:
 | |
| 			debug("incoming id (oid): %v (%v)", pkt.id(), pkt.orderID())
 | |
| 			s.incoming = append(s.incoming, pkt)
 | |
| 			s.incoming.Sort()
 | |
| 		case pkt := <-s.responses:
 | |
| 			debug("outgoing id (oid): %v (%v)", pkt.id(), pkt.orderID())
 | |
| 			s.outgoing = append(s.outgoing, pkt)
 | |
| 			s.outgoing.Sort()
 | |
| 		case <-s.fini:
 | |
| 			return
 | |
| 		}
 | |
| 		s.maybeSendPackets()
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // send as many packets as are ready
 | |
| func (s *packetManager) maybeSendPackets() {
 | |
| 	for {
 | |
| 		if len(s.outgoing) == 0 || len(s.incoming) == 0 {
 | |
| 			debug("break! -- outgoing: %v; incoming: %v",
 | |
| 				len(s.outgoing), len(s.incoming))
 | |
| 			break
 | |
| 		}
 | |
| 		out := s.outgoing[0]
 | |
| 		in := s.incoming[0]
 | |
| 		// debug("incoming: %v", ids(s.incoming))
 | |
| 		// debug("outgoing: %v", ids(s.outgoing))
 | |
| 		if in.orderID() == out.orderID() {
 | |
| 			debug("Sending packet: %v", out.id())
 | |
| 			s.sender.sendPacket(out.(encoding.BinaryMarshaler))
 | |
| 			if s.alloc != nil {
 | |
| 				// mark for reuse the slices allocated for this request
 | |
| 				s.alloc.ReleasePages(in.orderID())
 | |
| 			}
 | |
| 			// pop off heads
 | |
| 			copy(s.incoming, s.incoming[1:])            // shift left
 | |
| 			s.incoming[len(s.incoming)-1] = nil         // clear last
 | |
| 			s.incoming = s.incoming[:len(s.incoming)-1] // remove last
 | |
| 			copy(s.outgoing, s.outgoing[1:])            // shift left
 | |
| 			s.outgoing[len(s.outgoing)-1] = nil         // clear last
 | |
| 			s.outgoing = s.outgoing[:len(s.outgoing)-1] // remove last
 | |
| 		} else {
 | |
| 			break
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // func oids(o []orderedPacket) []uint32 {
 | |
| // 	res := make([]uint32, 0, len(o))
 | |
| // 	for _, v := range o {
 | |
| // 		res = append(res, v.orderId())
 | |
| // 	}
 | |
| // 	return res
 | |
| // }
 | |
| // func ids(o []orderedPacket) []uint32 {
 | |
| // 	res := make([]uint32, 0, len(o))
 | |
| // 	for _, v := range o {
 | |
| // 		res = append(res, v.id())
 | |
| // 	}
 | |
| // 	return res
 | |
| // }
 |