diff --git a/request-server.go b/request-server.go index 469d993..a38dad9 100644 --- a/request-server.go +++ b/request-server.go @@ -40,10 +40,10 @@ type Handlers struct { // Server that abstracts the sftp protocol for a http request-like protocol type RequestServer struct { - Handlers *Handlers serverConn + Handlers Handlers debugStream io.Writer - pktChan chan rxPacket + pktChan chan packet openRequests map[string]*Request openRequestLock sync.RWMutex } @@ -59,7 +59,7 @@ func NewRequestServer(rwc io.ReadWriteCloser) (*RequestServer, error) { }, }, debugStream: ioutil.Discard, - pktChan: make(chan rxPacket, sftpServerWorkerCount), + pktChan: make(chan packet, sftpServerWorkerCount), openRequests: make(map[string]*Request), } @@ -85,7 +85,6 @@ func (rs *RequestServer) closeRequest(handle string) { defer rs.openRequestLock.Unlock() if _, ok := rs.openRequests[handle]; ok { delete(rs.openRequests, handle) - // Do Requests need cleanup? } } @@ -108,7 +107,9 @@ func (rs *RequestServer) Serve() error { for { pktType, pktBytes, err = rs.recvPacket() if err != nil { break } - rs.pktChan <- rxPacket{fxp(pktType), pktBytes} + pkt, err := makePacket(rxPacket{fxp(pktType), pktBytes}) + if err != nil { break } + rs.pktChan <- pkt } close(rs.pktChan) // shuts down sftpServerWorkers @@ -116,16 +117,8 @@ func (rs *RequestServer) Serve() error { return err } -// make packet -// handle special cases -// convert to request -// call RequestHandler -// send feedback func (rs *RequestServer) packetWorker() error { - for p := range rs.pktChan { - pkt, err := makePacket(p) - if err != nil { return err } - + for pkt := range rs.pktChan { // handle packet specific pre-processing var handle string switch pkt := pkt.(type) { @@ -133,34 +126,33 @@ func (rs *RequestServer) packetWorker() error { err := rs.sendPacket(sshFxVersionPacket{sftpProtocolVersion, nil}) if err != nil { return err } continue + case *sshFxpOpenPacket: + handle = rs.nextRequest(newRequest(pkt.getPath())) + err := rs.sendPacket(sshFxpHandlePacket{pkt.id(), handle}) + if err != nil { return err } + continue + case *sshFxpOpendirPacket: + handle = rs.nextRequest(newRequest(pkt.getPath())) + err := rs.sendPacket(sshFxpHandlePacket{pkt.id(), handle}) + if err != nil { return err } + continue case *sshFxpClosePacket: handle = pkt.getHandle() rs.closeRequest(handle) err := rs.sendError(pkt, nil) if err != nil { return err } continue - case *sshFxpOpenPacket: - handle = rs.nextRequest(newRequest(pkt.getPath(), *rs.Handlers)) - err := rs.sendPacket(sshFxpHandlePacket{pkt.id(), handle}) - if err != nil { return err } - continue - case *sshFxpOpendirPacket: - handle = rs.nextRequest(newRequest(pkt.getPath(), *rs.Handlers)) - err := rs.sendPacket(sshFxpHandlePacket{pkt.id(), handle}) - if err != nil { return err } - continue case hasHandle: handle = pkt.getHandle() case hasPath: - handle = rs.nextRequest(newRequest(pkt.getPath(), *rs.Handlers)) + handle = rs.nextRequest(newRequest(pkt.getPath())) } + request, ok := rs.getRequest(handle) if !ok { return rs.sendError(pkt, syscall.EBADF) } - // send packet to request handler and wait for response - request.pktChan <- pkt - resp := <-request.rspChan - if resp.err != nil { rs.sendError(resp.pkt, err) } + resp := request.handleRequest(rs.Handlers, pkt) + if resp.err != nil { rs.sendError(resp.pkt, resp.err) } rs.sendPacket(resp.pkt) } return nil diff --git a/request.go b/request.go index aaeedfa..42e62d2 100644 --- a/request.go +++ b/request.go @@ -20,85 +20,87 @@ type Request struct { Pflags uint32 Attrs []byte // convert to sub-struct Target string // for renames and sym-links - data []byte - length uint32 - pktChan chan packet - rspChan chan response - handlers Handlers + // packet data + pkt_id uint32 + data []byte + length uint32 + // reader/writer from handlers + put_writer io.Writer + get_reader io.Reader } -func newRequest(path string, handlers Handlers) *Request { - request := &Request{Filepath: path, handlers: handlers} - go request.requestWorker() +func newRequest(path string) *Request { + request := &Request{Filepath: path} return request } -func (r *Request) close() { - close(r.pktChan) - close(r.rspChan) -} - -func (r *Request) requestWorker() { - for pkt := range r.pktChan { - r.populate(pkt) - handlers := r.handlers - var err error - var rpkt resp_packet - switch r.Method { - case "Get": - rpkt, err = fileget(handlers.FileGet, r, pkt.id()) - case "Put": - rpkt, err = fileput(handlers.FilePut, r, pkt.id()) - case "SetStat", "Rename", "Rmdir", "Mkdir", "Symlink": - rpkt, err = filecmd(handlers.FileCmd, r, pkt.id()) - case "List", "Stat", "Readlink": - rpkt, err = fileinfo(handlers.FileInfo, r, pkt.id()) - } - if err != nil { r.rspChan <- response{nil, err} } - r.rspChan <- response{rpkt, nil} +func (r *Request) handleRequest(handlers Handlers, pkt packet) response { + r.populate(pkt) + var err error + var rpkt resp_packet + switch r.Method { + case "Get": + rpkt, err = fileget(handlers.FileGet, r) + case "Put": + rpkt, err = fileput(handlers.FilePut, r) + case "SetStat", "Rename", "Rmdir", "Mkdir", "Symlink": + rpkt, err = filecmd(handlers.FileCmd, r) + case "List", "Stat", "Readlink": + rpkt, err = fileinfo(handlers.FileInfo, r) } + if err != nil { return response{nil, err} } + return response{rpkt, nil} } -func fileget(h FileReader, r *Request, pkt_id uint32) (resp_packet, error) { - reader, err := h.Fileread(r) - if err != nil { return nil, syscall.EBADF } +func fileget(h FileReader, r *Request) (resp_packet, error) { + if r.get_reader == nil { + reader, err := h.Fileread(r) + if err != nil { return nil, syscall.EBADF } + r.get_reader = reader + } + reader := r.get_reader data := make([]byte, clamp(r.length, maxTxPacket)) n, err := reader.Read(data) if err != nil && (err != io.EOF || n == 0) { return nil, err } return &sshFxpDataPacket{ - ID: pkt_id, + ID: r.pkt_id, Length: uint32(n), Data: r.data[:n], }, nil } -func fileput(h FileWriter, r *Request, pkt_id uint32) (resp_packet, error) { - writer, err := h.Filewrite(r) - if err != nil { return nil, syscall.EBADF } - _, err = writer.Write(r.data) +func fileput(h FileWriter, r *Request) (resp_packet, error) { + if r.put_writer == nil { + writer, err := h.Filewrite(r) + if err != nil { return nil, syscall.EBADF } + r.put_writer = writer + } + writer := r.put_writer + + _, err := writer.Write(r.data) if err != nil { return nil, err } return &sshFxpStatusPacket{ - ID: pkt_id, + ID: r.pkt_id, StatusError: StatusError{ Code: ssh_FX_OK, }}, nil } -func filecmd(h FileCmder, r *Request, pkt_id uint32) (resp_packet, error) { +func filecmd(h FileCmder, r *Request) (resp_packet, error) { err := h.Filecmd(r) if err != nil { return nil, err } return sshFxpStatusPacket{ - ID: pkt_id, + ID: r.pkt_id, StatusError: StatusError{ Code: ssh_FX_OK, }}, nil } -func fileinfo(h FileInfoer, r *Request, pkt_id uint32) (resp_packet, error) { +func fileinfo(h FileInfoer, r *Request) (resp_packet, error) { finfo, err := h.Fileinfo(r) if err != nil { return nil, err } switch r.Method { case "List": dirname := path.Base(r.Filepath) - ret := sshFxpNamePacket{ID: pkt_id} + ret := sshFxpNamePacket{ID: r.pkt_id} for _, fi := range finfo { ret.NameAttrs = append(ret.NameAttrs, sshFxpNameAttr{ Name: fi.Name(), @@ -112,7 +114,7 @@ func fileinfo(h FileInfoer, r *Request, pkt_id uint32) (resp_packet, error) { return nil, err } return &sshFxpStatResponse{ - ID: pkt_id, + ID: r.pkt_id, info: finfo[0], }, nil case "Readlink": @@ -121,7 +123,7 @@ func fileinfo(h FileInfoer, r *Request, pkt_id uint32) (resp_packet, error) { return nil, err } return sshFxpNamePacket{ - ID: pkt_id, + ID: r.pkt_id, NameAttrs: []sshFxpNameAttr{{ Name: finfo[0].Name(), LongName: finfo[0].Name(), @@ -139,36 +141,47 @@ func (r *Request) populate(p interface{}) { r.Method = "Setstat" r.Pflags = p.Flags r.Attrs = p.Attrs.([]byte) + r.pkt_id = p.id() case *sshFxpFsetstatPacket: r.Method = "Setstat" r.Pflags = p.Flags r.Attrs = p.Attrs.([]byte) + r.pkt_id = p.id() case *sshFxpRenamePacket: r.Method = "Rename" r.Target = p.Newpath + r.pkt_id = p.id() case *sshFxpSymlinkPacket: r.Method = "Symlink" r.Target = p.Linkpath + r.pkt_id = p.id() case *sshFxpReadPacket: r.Method = "Get" r.length = p.Len + r.pkt_id = p.id() case *sshFxpWritePacket: r.Method = "Put" r.data = p.Data r.length = p.Length + r.pkt_id = p.id() // below here method and path are all the data case *sshFxpReaddirPacket: r.Method = "List" + r.pkt_id = p.id() case *sshFxpStatPacket, *sshFxpLstatPacket, *sshFxpFstatPacket, *sshFxpRealpathPacket, *sshFxpRemovePacket: r.Method = "Stat" + r.pkt_id = p.(packet).id() case *sshFxpRmdirPacket: r.Method = "Rmdir" + r.pkt_id = p.id() case *sshFxpReadlinkPacket: r.Method = "Readlink" + r.pkt_id = p.id() // special cases case *sshFxpMkdirPacket: r.Method = "Mkdir" + r.pkt_id = p.id() //r.Attrs are ignored in ./packet.go } }