factor out request channels

This commit is contained in:
John Eikenberry 2016-07-11 20:19:49 -07:00
parent a253a470f0
commit f3ebdef6de
2 changed files with 80 additions and 75 deletions

View File

@ -40,10 +40,10 @@ type Handlers struct {
// Server that abstracts the sftp protocol for a http request-like protocol
type RequestServer struct {
Handlers *Handlers
serverConn
Handlers Handlers
debugStream io.Writer
pktChan chan rxPacket
pktChan chan packet
openRequests map[string]*Request
openRequestLock sync.RWMutex
}
@ -59,7 +59,7 @@ func NewRequestServer(rwc io.ReadWriteCloser) (*RequestServer, error) {
},
},
debugStream: ioutil.Discard,
pktChan: make(chan rxPacket, sftpServerWorkerCount),
pktChan: make(chan packet, sftpServerWorkerCount),
openRequests: make(map[string]*Request),
}
@ -85,7 +85,6 @@ func (rs *RequestServer) closeRequest(handle string) {
defer rs.openRequestLock.Unlock()
if _, ok := rs.openRequests[handle]; ok {
delete(rs.openRequests, handle)
// Do Requests need cleanup?
}
}
@ -108,7 +107,9 @@ func (rs *RequestServer) Serve() error {
for {
pktType, pktBytes, err = rs.recvPacket()
if err != nil { break }
rs.pktChan <- rxPacket{fxp(pktType), pktBytes}
pkt, err := makePacket(rxPacket{fxp(pktType), pktBytes})
if err != nil { break }
rs.pktChan <- pkt
}
close(rs.pktChan) // shuts down sftpServerWorkers
@ -116,16 +117,8 @@ func (rs *RequestServer) Serve() error {
return err
}
// make packet
// handle special cases
// convert to request
// call RequestHandler
// send feedback
func (rs *RequestServer) packetWorker() error {
for p := range rs.pktChan {
pkt, err := makePacket(p)
if err != nil { return err }
for pkt := range rs.pktChan {
// handle packet specific pre-processing
var handle string
switch pkt := pkt.(type) {
@ -133,34 +126,33 @@ func (rs *RequestServer) packetWorker() error {
err := rs.sendPacket(sshFxVersionPacket{sftpProtocolVersion, nil})
if err != nil { return err }
continue
case *sshFxpOpenPacket:
handle = rs.nextRequest(newRequest(pkt.getPath()))
err := rs.sendPacket(sshFxpHandlePacket{pkt.id(), handle})
if err != nil { return err }
continue
case *sshFxpOpendirPacket:
handle = rs.nextRequest(newRequest(pkt.getPath()))
err := rs.sendPacket(sshFxpHandlePacket{pkt.id(), handle})
if err != nil { return err }
continue
case *sshFxpClosePacket:
handle = pkt.getHandle()
rs.closeRequest(handle)
err := rs.sendError(pkt, nil)
if err != nil { return err }
continue
case *sshFxpOpenPacket:
handle = rs.nextRequest(newRequest(pkt.getPath(), *rs.Handlers))
err := rs.sendPacket(sshFxpHandlePacket{pkt.id(), handle})
if err != nil { return err }
continue
case *sshFxpOpendirPacket:
handle = rs.nextRequest(newRequest(pkt.getPath(), *rs.Handlers))
err := rs.sendPacket(sshFxpHandlePacket{pkt.id(), handle})
if err != nil { return err }
continue
case hasHandle:
handle = pkt.getHandle()
case hasPath:
handle = rs.nextRequest(newRequest(pkt.getPath(), *rs.Handlers))
handle = rs.nextRequest(newRequest(pkt.getPath()))
}
request, ok := rs.getRequest(handle)
if !ok { return rs.sendError(pkt, syscall.EBADF) }
// send packet to request handler and wait for response
request.pktChan <- pkt
resp := <-request.rspChan
if resp.err != nil { rs.sendError(resp.pkt, err) }
resp := request.handleRequest(rs.Handlers, pkt)
if resp.err != nil { rs.sendError(resp.pkt, resp.err) }
rs.sendPacket(resp.pkt)
}
return nil

View File

@ -20,85 +20,87 @@ type Request struct {
Pflags uint32
Attrs []byte // convert to sub-struct
Target string // for renames and sym-links
data []byte
length uint32
pktChan chan packet
rspChan chan response
handlers Handlers
// packet data
pkt_id uint32
data []byte
length uint32
// reader/writer from handlers
put_writer io.Writer
get_reader io.Reader
}
func newRequest(path string, handlers Handlers) *Request {
request := &Request{Filepath: path, handlers: handlers}
go request.requestWorker()
func newRequest(path string) *Request {
request := &Request{Filepath: path}
return request
}
func (r *Request) close() {
close(r.pktChan)
close(r.rspChan)
}
func (r *Request) requestWorker() {
for pkt := range r.pktChan {
r.populate(pkt)
handlers := r.handlers
var err error
var rpkt resp_packet
switch r.Method {
case "Get":
rpkt, err = fileget(handlers.FileGet, r, pkt.id())
case "Put":
rpkt, err = fileput(handlers.FilePut, r, pkt.id())
case "SetStat", "Rename", "Rmdir", "Mkdir", "Symlink":
rpkt, err = filecmd(handlers.FileCmd, r, pkt.id())
case "List", "Stat", "Readlink":
rpkt, err = fileinfo(handlers.FileInfo, r, pkt.id())
}
if err != nil { r.rspChan <- response{nil, err} }
r.rspChan <- response{rpkt, nil}
func (r *Request) handleRequest(handlers Handlers, pkt packet) response {
r.populate(pkt)
var err error
var rpkt resp_packet
switch r.Method {
case "Get":
rpkt, err = fileget(handlers.FileGet, r)
case "Put":
rpkt, err = fileput(handlers.FilePut, r)
case "SetStat", "Rename", "Rmdir", "Mkdir", "Symlink":
rpkt, err = filecmd(handlers.FileCmd, r)
case "List", "Stat", "Readlink":
rpkt, err = fileinfo(handlers.FileInfo, r)
}
if err != nil { return response{nil, err} }
return response{rpkt, nil}
}
func fileget(h FileReader, r *Request, pkt_id uint32) (resp_packet, error) {
reader, err := h.Fileread(r)
if err != nil { return nil, syscall.EBADF }
func fileget(h FileReader, r *Request) (resp_packet, error) {
if r.get_reader == nil {
reader, err := h.Fileread(r)
if err != nil { return nil, syscall.EBADF }
r.get_reader = reader
}
reader := r.get_reader
data := make([]byte, clamp(r.length, maxTxPacket))
n, err := reader.Read(data)
if err != nil && (err != io.EOF || n == 0) { return nil, err }
return &sshFxpDataPacket{
ID: pkt_id,
ID: r.pkt_id,
Length: uint32(n),
Data: r.data[:n],
}, nil
}
func fileput(h FileWriter, r *Request, pkt_id uint32) (resp_packet, error) {
writer, err := h.Filewrite(r)
if err != nil { return nil, syscall.EBADF }
_, err = writer.Write(r.data)
func fileput(h FileWriter, r *Request) (resp_packet, error) {
if r.put_writer == nil {
writer, err := h.Filewrite(r)
if err != nil { return nil, syscall.EBADF }
r.put_writer = writer
}
writer := r.put_writer
_, err := writer.Write(r.data)
if err != nil { return nil, err }
return &sshFxpStatusPacket{
ID: pkt_id,
ID: r.pkt_id,
StatusError: StatusError{
Code: ssh_FX_OK,
}}, nil
}
func filecmd(h FileCmder, r *Request, pkt_id uint32) (resp_packet, error) {
func filecmd(h FileCmder, r *Request) (resp_packet, error) {
err := h.Filecmd(r)
if err != nil { return nil, err }
return sshFxpStatusPacket{
ID: pkt_id,
ID: r.pkt_id,
StatusError: StatusError{
Code: ssh_FX_OK,
}}, nil
}
func fileinfo(h FileInfoer, r *Request, pkt_id uint32) (resp_packet, error) {
func fileinfo(h FileInfoer, r *Request) (resp_packet, error) {
finfo, err := h.Fileinfo(r)
if err != nil { return nil, err }
switch r.Method {
case "List":
dirname := path.Base(r.Filepath)
ret := sshFxpNamePacket{ID: pkt_id}
ret := sshFxpNamePacket{ID: r.pkt_id}
for _, fi := range finfo {
ret.NameAttrs = append(ret.NameAttrs, sshFxpNameAttr{
Name: fi.Name(),
@ -112,7 +114,7 @@ func fileinfo(h FileInfoer, r *Request, pkt_id uint32) (resp_packet, error) {
return nil, err
}
return &sshFxpStatResponse{
ID: pkt_id,
ID: r.pkt_id,
info: finfo[0],
}, nil
case "Readlink":
@ -121,7 +123,7 @@ func fileinfo(h FileInfoer, r *Request, pkt_id uint32) (resp_packet, error) {
return nil, err
}
return sshFxpNamePacket{
ID: pkt_id,
ID: r.pkt_id,
NameAttrs: []sshFxpNameAttr{{
Name: finfo[0].Name(),
LongName: finfo[0].Name(),
@ -139,36 +141,47 @@ func (r *Request) populate(p interface{}) {
r.Method = "Setstat"
r.Pflags = p.Flags
r.Attrs = p.Attrs.([]byte)
r.pkt_id = p.id()
case *sshFxpFsetstatPacket:
r.Method = "Setstat"
r.Pflags = p.Flags
r.Attrs = p.Attrs.([]byte)
r.pkt_id = p.id()
case *sshFxpRenamePacket:
r.Method = "Rename"
r.Target = p.Newpath
r.pkt_id = p.id()
case *sshFxpSymlinkPacket:
r.Method = "Symlink"
r.Target = p.Linkpath
r.pkt_id = p.id()
case *sshFxpReadPacket:
r.Method = "Get"
r.length = p.Len
r.pkt_id = p.id()
case *sshFxpWritePacket:
r.Method = "Put"
r.data = p.Data
r.length = p.Length
r.pkt_id = p.id()
// below here method and path are all the data
case *sshFxpReaddirPacket:
r.Method = "List"
r.pkt_id = p.id()
case *sshFxpStatPacket, *sshFxpLstatPacket, *sshFxpFstatPacket,
*sshFxpRealpathPacket, *sshFxpRemovePacket:
r.Method = "Stat"
r.pkt_id = p.(packet).id()
case *sshFxpRmdirPacket:
r.Method = "Rmdir"
r.pkt_id = p.id()
case *sshFxpReadlinkPacket:
r.Method = "Readlink"
r.pkt_id = p.id()
// special cases
case *sshFxpMkdirPacket:
r.Method = "Mkdir"
r.pkt_id = p.id()
//r.Attrs are ignored in ./packet.go
}
}