factor out response struct

This commit is contained in:
John Eikenberry 2016-07-12 16:50:59 -07:00
parent 8b3c376b5a
commit e307459f45
3 changed files with 34 additions and 37 deletions

View File

@ -1,6 +1,7 @@
package sftp
import (
"encoding"
"io"
"io/ioutil"
"sync"
@ -121,6 +122,8 @@ func (rs *RequestServer) packetWorker() error {
for pkt := range rs.pktChan {
// handle packet specific pre-processing
var handle string
var rpkt encoding.BinaryMarshaler
var err error
switch pkt := pkt.(type) {
case *sshFxInitPacket:
err := rs.sendPacket(sshFxVersionPacket{sftpProtocolVersion, nil})
@ -149,12 +152,13 @@ func (rs *RequestServer) packetWorker() error {
}
request, ok := rs.getRequest(handle)
if !ok { return rs.sendError(pkt, syscall.EBADF) }
if !ok { rpkt = statusFromError(pkt, syscall.EBADF) }
request.populate(pkt)
resp := request.handleRequest(rs.Handlers)
if resp.err != nil { rs.sendError(resp.pkt, resp.err) }
rs.sendPacket(resp.pkt)
rpkt, err = request.handleRequest(rs.Handlers)
if err != nil { rpkt = statusFromError(pkt, err) }
err = rs.sendPacket(rpkt)
if err != nil { return err }
}
return nil
}

View File

@ -7,12 +7,6 @@ import (
"syscall"
)
// response passed back to packet handling code
type response struct {
pkt resp_packet
err error
}
type Request struct {
// Get, Put, SetStat, Rename, Rmdir, Mkdir, Symlink, List, Stat, Readlink
Method string
@ -36,7 +30,7 @@ func newRequest(path string) *Request {
}
// called from worker to handle packet/request
func (r *Request) handleRequest(handlers Handlers) response {
func (r *Request) handleRequest(handlers Handlers) (resp_packet, error) {
var err error
var rpkt resp_packet
switch r.Method {
@ -49,8 +43,7 @@ func (r *Request) handleRequest(handlers Handlers) response {
case "List", "Stat", "Readlink":
rpkt, err = fileinfo(handlers.FileInfo, r)
}
if err != nil { return response{nil, err} }
return response{rpkt, nil}
return rpkt, err
}
// wrap FileReader handler

View File

@ -92,34 +92,34 @@ func TestGetMethod(t *testing.T) {
request := testRequest("Get")
// req.length is 4, so we test reads in 4 byte chunks
for _, txt := range []string{"file-", "data."} {
resp := request.handleRequest(handlers)
assert.Nil(t, resp.err)
pkt := resp.pkt.(*sshFxpDataPacket)
assert.Equal(t, pkt.id(), uint32(1))
assert.Equal(t, string(pkt.Data), txt)
pkt, err := request.handleRequest(handlers)
assert.Nil(t, err)
dpkt := pkt.(*sshFxpDataPacket)
assert.Equal(t, dpkt.id(), uint32(1))
assert.Equal(t, string(dpkt.Data), txt)
}
}
func TestPutMethod(t *testing.T) {
handlers := newTestHandlers()
request := testRequest("Put")
resp := request.handleRequest(handlers)
assert.Nil(t, resp.err)
pkt, err := request.handleRequest(handlers)
assert.Nil(t, err)
assert.Equal(t, handlers.getOut().String(), "file-data.")
statusOk(t, resp.pkt)
statusOk(t, pkt)
}
func TestCmdrMethod(t *testing.T) {
handlers := newTestHandlers()
request := testRequest("Mkdir")
resp := request.handleRequest(handlers)
assert.Nil(t, resp.err)
statusOk(t, resp.pkt)
pkt, err := request.handleRequest(handlers)
assert.Nil(t, err)
statusOk(t, pkt)
handlers.returnError()
resp = request.handleRequest(handlers)
assert.Nil(t, resp.pkt)
assert.Equal(t, resp.err, testError)
pkt, err = request.handleRequest(handlers)
assert.Nil(t, pkt)
assert.Equal(t, err, testError)
}
func TestInfoListMethod(t *testing.T) { testInfoMethod(t, "List") }
@ -127,19 +127,19 @@ func TestInfoReadlinkMethod(t *testing.T) { testInfoMethod(t, "Readlink") }
func TestInfoStatMethod(t *testing.T) {
handlers := newTestHandlers()
request := testRequest("Stat")
resp := request.handleRequest(handlers)
assert.Nil(t, resp.err)
pkt := resp.pkt.(*sshFxpStatResponse)
assert.Equal(t, pkt.info.Name(), "request_test.go")
pkt, err := request.handleRequest(handlers)
assert.Nil(t, err)
spkt := pkt.(*sshFxpStatResponse)
assert.Equal(t, spkt.info.Name(), "request_test.go")
}
func testInfoMethod(t *testing.T, method string) {
handlers := newTestHandlers()
request := testRequest(method)
resp := request.handleRequest(handlers)
assert.Nil(t, resp.err)
pkt, ok := resp.pkt.(*sshFxpNamePacket)
pkt, err := request.handleRequest(handlers)
assert.Nil(t, err)
npkt, ok := pkt.(*sshFxpNamePacket)
assert.True(t, ok)
assert.IsType(t, sshFxpNameAttr{}, pkt.NameAttrs[0])
assert.Equal(t, pkt.NameAttrs[0].Name, "request_test.go")
assert.IsType(t, sshFxpNameAttr{}, npkt.NameAttrs[0])
assert.Equal(t, npkt.NameAttrs[0].Name, "request_test.go")
}