sftp/request-server.go

153 lines
3.5 KiB
Go
Raw Normal View History

package sftp
import (
"io"
"io/ioutil"
"sync"
2016-07-09 03:38:35 +08:00
"syscall"
)
// Design notes:
//
// The server takes a dataHandler and an openHandler as arguments and starts
// up packet handlers. The packet handlers convert packets into data values,
// call dataHandler with each value, and are then done with that packet/data.
//
// dataHandler should call Handler() on the data to process it and reply to
// the client.
//
// The tricky part of reading/writing is spinning up workers to handle all
// packets. Two options for dispatching the data values:
//
// Switch on the packet ID:
//   + only one type plus constants
//   - duplicates the SFTP protocol IDs
//
// Switch on the data's Go type:
//   + types act as types
//   + type.Handle could enforce the type of its argument
//   - requires a dummy interface used only for typing

// maxTxPacket is 1<<15 (32 KiB). It is not referenced in this part of the
// file; the name suggests it caps the size of packets the server transmits.
// NOTE(review): confirm the exact meaning at its use sites.
var maxTxPacket uint32 = 1 << 15

// handleHandler is a function type mapping one string to another. It is not
// used in this part of the file; presumably it maps a handle to a handle or
// path. NOTE(review): confirm its purpose or remove it if unused.
type handleHandler func(string) string
2016-07-09 03:38:35 +08:00
type Handlers struct {
FileGet FileReader
FilePut FileWriter
FileCmd FileCmder
FileInfo FileInfoer
2016-07-09 03:38:35 +08:00
}
// Server that abstracts the sftp protocol for a http request-like protocol
type RequestServer struct {
2016-07-12 02:06:08 +08:00
Handlers *Handlers
serverConn
2016-07-09 03:38:35 +08:00
debugStream io.Writer
pktChan chan rxPacket
openRequests map[string]*Request
openRequestLock sync.RWMutex
}
// simple factory function
// one server per user-session
2016-07-12 01:58:51 +08:00
func NewRequestServer(rwc io.ReadWriteCloser) (*RequestServer, error) {
s := &RequestServer{
serverConn: serverConn{
conn: conn{
Reader: rwc,
WriteCloser: rwc,
},
},
2016-07-09 03:38:35 +08:00
debugStream: ioutil.Discard,
pktChan: make(chan rxPacket, sftpServerWorkerCount),
openRequests: make(map[string]*Request),
}
return s, nil
}
2016-07-09 03:38:35 +08:00
// nextRequest registers r in the open-request table under its file path and
// returns that path, which callers use as the request's handle. An existing
// entry for the same path is replaced.
func (rs *RequestServer) nextRequest(r *Request) string {
	handle := r.Filepath
	rs.openRequestLock.Lock()
	defer rs.openRequestLock.Unlock()
	rs.openRequests[handle] = r
	return handle
}
// getRequest looks up the open Request registered under handle; the bool
// result reports whether one was found. The lookup only reads the table, so
// it takes the read side of openRequestLock and concurrent packet workers can
// resolve handles in parallel.
func (rs *RequestServer) getRequest(handle string) (*Request, bool) {
	rs.openRequestLock.RLock()
	defer rs.openRequestLock.RUnlock()
	r, ok := rs.openRequests[handle]
	return r, ok
}
// closeRequest removes the open Request registered under handle. Closing a
// handle that is not registered is a no-op.
func (rs *RequestServer) closeRequest(handle string) {
	rs.openRequestLock.Lock()
	defer rs.openRequestLock.Unlock()
	// delete already ignores missing keys, so no existence check is needed.
	// TODO: decide whether a closed Request needs cleanup beyond removal.
	delete(rs.openRequests, handle)
}
// start serving requests from user session
2016-07-09 03:38:35 +08:00
func (rs *RequestServer) Serve() error {
var wg sync.WaitGroup
wg.Add(sftpServerWorkerCount)
for i := 0; i < sftpServerWorkerCount; i++ {
go func() {
defer wg.Done()
2016-07-09 03:38:35 +08:00
if err := rs.packetWorker(); err != nil {
rs.conn.Close() // shuts down recvPacket
}
}()
}
var err error
var pktType uint8
var pktBytes []byte
for {
2016-07-09 03:38:35 +08:00
pktType, pktBytes, err = rs.recvPacket()
if err != nil { break }
2016-07-09 03:38:35 +08:00
rs.pktChan <- rxPacket{fxp(pktType), pktBytes}
}
2016-07-09 03:38:35 +08:00
close(rs.pktChan) // shuts down sftpServerWorkers
wg.Wait() // wait for all workers to exit
return err
}
// make packet
// handle special cases
// convert to request
// call RequestHandler
// send feedback
2016-07-09 03:38:35 +08:00
func (rs *RequestServer) packetWorker() error {
for p := range rs.pktChan {
pkt, err := makePacket(p)
if err != nil { return err }
2016-07-09 03:38:35 +08:00
// handle packet specific pre-processing
var handle string
switch pkt := pkt.(type) {
case *sshFxInitPacket:
2016-07-09 03:38:35 +08:00
err := rs.sendPacket(sshFxVersionPacket{sftpProtocolVersion, nil})
if err != nil { return err }
continue
case *sshFxpClosePacket:
handle = pkt.getHandle()
rs.closeRequest(handle)
err := rs.sendError(pkt, nil)
if err != nil { return err }
continue
case hasPath:
2016-07-09 08:22:52 +08:00
handle = rs.nextRequest(newRequest(pkt.getPath(), rs))
case hasHandle:
2016-07-09 03:38:35 +08:00
handle = pkt.getHandle()
}
2016-07-09 03:38:35 +08:00
request, ok := rs.getRequest(handle)
if !ok { return rs.sendError(pkt, syscall.EBADF) }
request.pktChan <- pkt
}
return nil
}