mirror of https://github.com/pkg/sftp.git
factor out request channels
This commit is contained in:
parent
a253a470f0
commit
f3ebdef6de
|
@ -40,10 +40,10 @@ type Handlers struct {
|
|||
|
||||
// Server that abstracts the sftp protocol for a http request-like protocol
|
||||
type RequestServer struct {
|
||||
Handlers *Handlers
|
||||
serverConn
|
||||
Handlers Handlers
|
||||
debugStream io.Writer
|
||||
pktChan chan rxPacket
|
||||
pktChan chan packet
|
||||
openRequests map[string]*Request
|
||||
openRequestLock sync.RWMutex
|
||||
}
|
||||
|
@ -59,7 +59,7 @@ func NewRequestServer(rwc io.ReadWriteCloser) (*RequestServer, error) {
|
|||
},
|
||||
},
|
||||
debugStream: ioutil.Discard,
|
||||
pktChan: make(chan rxPacket, sftpServerWorkerCount),
|
||||
pktChan: make(chan packet, sftpServerWorkerCount),
|
||||
openRequests: make(map[string]*Request),
|
||||
}
|
||||
|
||||
|
@ -85,7 +85,6 @@ func (rs *RequestServer) closeRequest(handle string) {
|
|||
defer rs.openRequestLock.Unlock()
|
||||
if _, ok := rs.openRequests[handle]; ok {
|
||||
delete(rs.openRequests, handle)
|
||||
// Do Requests need cleanup?
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -108,7 +107,9 @@ func (rs *RequestServer) Serve() error {
|
|||
for {
|
||||
pktType, pktBytes, err = rs.recvPacket()
|
||||
if err != nil { break }
|
||||
rs.pktChan <- rxPacket{fxp(pktType), pktBytes}
|
||||
pkt, err := makePacket(rxPacket{fxp(pktType), pktBytes})
|
||||
if err != nil { break }
|
||||
rs.pktChan <- pkt
|
||||
}
|
||||
|
||||
close(rs.pktChan) // shuts down sftpServerWorkers
|
||||
|
@ -116,16 +117,8 @@ func (rs *RequestServer) Serve() error {
|
|||
return err
|
||||
}
|
||||
|
||||
// make packet
|
||||
// handle special cases
|
||||
// convert to request
|
||||
// call RequestHandler
|
||||
// send feedback
|
||||
func (rs *RequestServer) packetWorker() error {
|
||||
for p := range rs.pktChan {
|
||||
pkt, err := makePacket(p)
|
||||
if err != nil { return err }
|
||||
|
||||
for pkt := range rs.pktChan {
|
||||
// handle packet specific pre-processing
|
||||
var handle string
|
||||
switch pkt := pkt.(type) {
|
||||
|
@ -133,34 +126,33 @@ func (rs *RequestServer) packetWorker() error {
|
|||
err := rs.sendPacket(sshFxVersionPacket{sftpProtocolVersion, nil})
|
||||
if err != nil { return err }
|
||||
continue
|
||||
case *sshFxpOpenPacket:
|
||||
handle = rs.nextRequest(newRequest(pkt.getPath()))
|
||||
err := rs.sendPacket(sshFxpHandlePacket{pkt.id(), handle})
|
||||
if err != nil { return err }
|
||||
continue
|
||||
case *sshFxpOpendirPacket:
|
||||
handle = rs.nextRequest(newRequest(pkt.getPath()))
|
||||
err := rs.sendPacket(sshFxpHandlePacket{pkt.id(), handle})
|
||||
if err != nil { return err }
|
||||
continue
|
||||
case *sshFxpClosePacket:
|
||||
handle = pkt.getHandle()
|
||||
rs.closeRequest(handle)
|
||||
err := rs.sendError(pkt, nil)
|
||||
if err != nil { return err }
|
||||
continue
|
||||
case *sshFxpOpenPacket:
|
||||
handle = rs.nextRequest(newRequest(pkt.getPath(), *rs.Handlers))
|
||||
err := rs.sendPacket(sshFxpHandlePacket{pkt.id(), handle})
|
||||
if err != nil { return err }
|
||||
continue
|
||||
case *sshFxpOpendirPacket:
|
||||
handle = rs.nextRequest(newRequest(pkt.getPath(), *rs.Handlers))
|
||||
err := rs.sendPacket(sshFxpHandlePacket{pkt.id(), handle})
|
||||
if err != nil { return err }
|
||||
continue
|
||||
case hasHandle:
|
||||
handle = pkt.getHandle()
|
||||
case hasPath:
|
||||
handle = rs.nextRequest(newRequest(pkt.getPath(), *rs.Handlers))
|
||||
handle = rs.nextRequest(newRequest(pkt.getPath()))
|
||||
}
|
||||
|
||||
request, ok := rs.getRequest(handle)
|
||||
if !ok { return rs.sendError(pkt, syscall.EBADF) }
|
||||
|
||||
// send packet to request handler and wait for response
|
||||
request.pktChan <- pkt
|
||||
resp := <-request.rspChan
|
||||
if resp.err != nil { rs.sendError(resp.pkt, err) }
|
||||
resp := request.handleRequest(rs.Handlers, pkt)
|
||||
if resp.err != nil { rs.sendError(resp.pkt, resp.err) }
|
||||
rs.sendPacket(resp.pkt)
|
||||
}
|
||||
return nil
|
||||
|
|
105
request.go
105
request.go
|
@ -20,85 +20,87 @@ type Request struct {
|
|||
Pflags uint32
|
||||
Attrs []byte // convert to sub-struct
|
||||
Target string // for renames and sym-links
|
||||
data []byte
|
||||
length uint32
|
||||
pktChan chan packet
|
||||
rspChan chan response
|
||||
handlers Handlers
|
||||
// packet data
|
||||
pkt_id uint32
|
||||
data []byte
|
||||
length uint32
|
||||
// reader/writer from handlers
|
||||
put_writer io.Writer
|
||||
get_reader io.Reader
|
||||
}
|
||||
|
||||
func newRequest(path string, handlers Handlers) *Request {
|
||||
request := &Request{Filepath: path, handlers: handlers}
|
||||
go request.requestWorker()
|
||||
func newRequest(path string) *Request {
|
||||
request := &Request{Filepath: path}
|
||||
return request
|
||||
}
|
||||
|
||||
func (r *Request) close() {
|
||||
close(r.pktChan)
|
||||
close(r.rspChan)
|
||||
}
|
||||
|
||||
func (r *Request) requestWorker() {
|
||||
for pkt := range r.pktChan {
|
||||
r.populate(pkt)
|
||||
handlers := r.handlers
|
||||
var err error
|
||||
var rpkt resp_packet
|
||||
switch r.Method {
|
||||
case "Get":
|
||||
rpkt, err = fileget(handlers.FileGet, r, pkt.id())
|
||||
case "Put":
|
||||
rpkt, err = fileput(handlers.FilePut, r, pkt.id())
|
||||
case "SetStat", "Rename", "Rmdir", "Mkdir", "Symlink":
|
||||
rpkt, err = filecmd(handlers.FileCmd, r, pkt.id())
|
||||
case "List", "Stat", "Readlink":
|
||||
rpkt, err = fileinfo(handlers.FileInfo, r, pkt.id())
|
||||
}
|
||||
if err != nil { r.rspChan <- response{nil, err} }
|
||||
r.rspChan <- response{rpkt, nil}
|
||||
func (r *Request) handleRequest(handlers Handlers, pkt packet) response {
|
||||
r.populate(pkt)
|
||||
var err error
|
||||
var rpkt resp_packet
|
||||
switch r.Method {
|
||||
case "Get":
|
||||
rpkt, err = fileget(handlers.FileGet, r)
|
||||
case "Put":
|
||||
rpkt, err = fileput(handlers.FilePut, r)
|
||||
case "SetStat", "Rename", "Rmdir", "Mkdir", "Symlink":
|
||||
rpkt, err = filecmd(handlers.FileCmd, r)
|
||||
case "List", "Stat", "Readlink":
|
||||
rpkt, err = fileinfo(handlers.FileInfo, r)
|
||||
}
|
||||
if err != nil { return response{nil, err} }
|
||||
return response{rpkt, nil}
|
||||
}
|
||||
|
||||
func fileget(h FileReader, r *Request, pkt_id uint32) (resp_packet, error) {
|
||||
reader, err := h.Fileread(r)
|
||||
if err != nil { return nil, syscall.EBADF }
|
||||
func fileget(h FileReader, r *Request) (resp_packet, error) {
|
||||
if r.get_reader == nil {
|
||||
reader, err := h.Fileread(r)
|
||||
if err != nil { return nil, syscall.EBADF }
|
||||
r.get_reader = reader
|
||||
}
|
||||
reader := r.get_reader
|
||||
data := make([]byte, clamp(r.length, maxTxPacket))
|
||||
n, err := reader.Read(data)
|
||||
if err != nil && (err != io.EOF || n == 0) { return nil, err }
|
||||
return &sshFxpDataPacket{
|
||||
ID: pkt_id,
|
||||
ID: r.pkt_id,
|
||||
Length: uint32(n),
|
||||
Data: r.data[:n],
|
||||
}, nil
|
||||
}
|
||||
func fileput(h FileWriter, r *Request, pkt_id uint32) (resp_packet, error) {
|
||||
writer, err := h.Filewrite(r)
|
||||
if err != nil { return nil, syscall.EBADF }
|
||||
_, err = writer.Write(r.data)
|
||||
func fileput(h FileWriter, r *Request) (resp_packet, error) {
|
||||
if r.put_writer == nil {
|
||||
writer, err := h.Filewrite(r)
|
||||
if err != nil { return nil, syscall.EBADF }
|
||||
r.put_writer = writer
|
||||
}
|
||||
writer := r.put_writer
|
||||
|
||||
_, err := writer.Write(r.data)
|
||||
if err != nil { return nil, err }
|
||||
return &sshFxpStatusPacket{
|
||||
ID: pkt_id,
|
||||
ID: r.pkt_id,
|
||||
StatusError: StatusError{
|
||||
Code: ssh_FX_OK,
|
||||
}}, nil
|
||||
}
|
||||
func filecmd(h FileCmder, r *Request, pkt_id uint32) (resp_packet, error) {
|
||||
func filecmd(h FileCmder, r *Request) (resp_packet, error) {
|
||||
err := h.Filecmd(r)
|
||||
if err != nil { return nil, err }
|
||||
return sshFxpStatusPacket{
|
||||
ID: pkt_id,
|
||||
ID: r.pkt_id,
|
||||
StatusError: StatusError{
|
||||
Code: ssh_FX_OK,
|
||||
}}, nil
|
||||
}
|
||||
func fileinfo(h FileInfoer, r *Request, pkt_id uint32) (resp_packet, error) {
|
||||
func fileinfo(h FileInfoer, r *Request) (resp_packet, error) {
|
||||
finfo, err := h.Fileinfo(r)
|
||||
if err != nil { return nil, err }
|
||||
|
||||
switch r.Method {
|
||||
case "List":
|
||||
dirname := path.Base(r.Filepath)
|
||||
ret := sshFxpNamePacket{ID: pkt_id}
|
||||
ret := sshFxpNamePacket{ID: r.pkt_id}
|
||||
for _, fi := range finfo {
|
||||
ret.NameAttrs = append(ret.NameAttrs, sshFxpNameAttr{
|
||||
Name: fi.Name(),
|
||||
|
@ -112,7 +114,7 @@ func fileinfo(h FileInfoer, r *Request, pkt_id uint32) (resp_packet, error) {
|
|||
return nil, err
|
||||
}
|
||||
return &sshFxpStatResponse{
|
||||
ID: pkt_id,
|
||||
ID: r.pkt_id,
|
||||
info: finfo[0],
|
||||
}, nil
|
||||
case "Readlink":
|
||||
|
@ -121,7 +123,7 @@ func fileinfo(h FileInfoer, r *Request, pkt_id uint32) (resp_packet, error) {
|
|||
return nil, err
|
||||
}
|
||||
return sshFxpNamePacket{
|
||||
ID: pkt_id,
|
||||
ID: r.pkt_id,
|
||||
NameAttrs: []sshFxpNameAttr{{
|
||||
Name: finfo[0].Name(),
|
||||
LongName: finfo[0].Name(),
|
||||
|
@ -139,36 +141,47 @@ func (r *Request) populate(p interface{}) {
|
|||
r.Method = "Setstat"
|
||||
r.Pflags = p.Flags
|
||||
r.Attrs = p.Attrs.([]byte)
|
||||
r.pkt_id = p.id()
|
||||
case *sshFxpFsetstatPacket:
|
||||
r.Method = "Setstat"
|
||||
r.Pflags = p.Flags
|
||||
r.Attrs = p.Attrs.([]byte)
|
||||
r.pkt_id = p.id()
|
||||
case *sshFxpRenamePacket:
|
||||
r.Method = "Rename"
|
||||
r.Target = p.Newpath
|
||||
r.pkt_id = p.id()
|
||||
case *sshFxpSymlinkPacket:
|
||||
r.Method = "Symlink"
|
||||
r.Target = p.Linkpath
|
||||
r.pkt_id = p.id()
|
||||
case *sshFxpReadPacket:
|
||||
r.Method = "Get"
|
||||
r.length = p.Len
|
||||
r.pkt_id = p.id()
|
||||
case *sshFxpWritePacket:
|
||||
r.Method = "Put"
|
||||
r.data = p.Data
|
||||
r.length = p.Length
|
||||
r.pkt_id = p.id()
|
||||
// below here method and path are all the data
|
||||
case *sshFxpReaddirPacket:
|
||||
r.Method = "List"
|
||||
r.pkt_id = p.id()
|
||||
case *sshFxpStatPacket, *sshFxpLstatPacket, *sshFxpFstatPacket,
|
||||
*sshFxpRealpathPacket, *sshFxpRemovePacket:
|
||||
r.Method = "Stat"
|
||||
r.pkt_id = p.(packet).id()
|
||||
case *sshFxpRmdirPacket:
|
||||
r.Method = "Rmdir"
|
||||
r.pkt_id = p.id()
|
||||
case *sshFxpReadlinkPacket:
|
||||
r.Method = "Readlink"
|
||||
r.pkt_id = p.id()
|
||||
// special cases
|
||||
case *sshFxpMkdirPacket:
|
||||
r.Method = "Mkdir"
|
||||
r.pkt_id = p.id()
|
||||
//r.Attrs are ignored in ./packet.go
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue