diff --git a/request-server.go b/request-server.go index b0405cb..fc64cac 100644 --- a/request-server.go +++ b/request-server.go @@ -1,6 +1,7 @@ package sftp import ( + "encoding" "io" "io/ioutil" "sync" @@ -121,6 +122,8 @@ func (rs *RequestServer) packetWorker() error { for pkt := range rs.pktChan { // handle packet specific pre-processing var handle string + var rpkt encoding.BinaryMarshaler + var err error switch pkt := pkt.(type) { case *sshFxInitPacket: err := rs.sendPacket(sshFxVersionPacket{sftpProtocolVersion, nil}) @@ -149,12 +152,13 @@ func (rs *RequestServer) packetWorker() error { } request, ok := rs.getRequest(handle) - if !ok { return rs.sendError(pkt, syscall.EBADF) } - + if !ok { rpkt = statusFromError(pkt, syscall.EBADF) } request.populate(pkt) - resp := request.handleRequest(rs.Handlers) - if resp.err != nil { rs.sendError(resp.pkt, resp.err) } - rs.sendPacket(resp.pkt) + rpkt, err = request.handleRequest(rs.Handlers) + if err != nil { rpkt = statusFromError(pkt, err) } + + err = rs.sendPacket(rpkt) + if err != nil { return err } } return nil } diff --git a/request.go b/request.go index 64a7f10..621644f 100644 --- a/request.go +++ b/request.go @@ -7,12 +7,6 @@ import ( "syscall" ) -// response passed back to packet handling code -type response struct { - pkt resp_packet - err error -} - type Request struct { // Get, Put, SetStat, Rename, Rmdir, Mkdir, Symlink, List, Stat, Readlink Method string @@ -36,7 +30,7 @@ func newRequest(path string) *Request { } // called from worker to handle packet/request -func (r *Request) handleRequest(handlers Handlers) response { +func (r *Request) handleRequest(handlers Handlers) (resp_packet, error) { var err error var rpkt resp_packet switch r.Method { @@ -49,8 +43,7 @@ func (r *Request) handleRequest(handlers Handlers) response { case "List", "Stat", "Readlink": rpkt, err = fileinfo(handlers.FileInfo, r) } - if err != nil { return response{nil, err} } - return response{rpkt, nil} + return rpkt, err } // wrap FileReader handler diff --git a/request_test.go b/request_test.go index d621ba3..0f78dd6 100644 --- a/request_test.go +++ b/request_test.go @@ -92,34 +92,34 @@ func TestGetMethod(t *testing.T) { request := testRequest("Get") // req.length is 4, so we test reads in 4 byte chunks for _, txt := range []string{"file-", "data."} { - resp := request.handleRequest(handlers) - assert.Nil(t, resp.err) - pkt := resp.pkt.(*sshFxpDataPacket) - assert.Equal(t, pkt.id(), uint32(1)) - assert.Equal(t, string(pkt.Data), txt) + pkt, err := request.handleRequest(handlers) + assert.Nil(t, err) + dpkt := pkt.(*sshFxpDataPacket) + assert.Equal(t, dpkt.id(), uint32(1)) + assert.Equal(t, string(dpkt.Data), txt) } } func TestPutMethod(t *testing.T) { handlers := newTestHandlers() request := testRequest("Put") - resp := request.handleRequest(handlers) - assert.Nil(t, resp.err) + pkt, err := request.handleRequest(handlers) + assert.Nil(t, err) assert.Equal(t, handlers.getOut().String(), "file-data.") - statusOk(t, resp.pkt) + statusOk(t, pkt) } func TestCmdrMethod(t *testing.T) { handlers := newTestHandlers() request := testRequest("Mkdir") - resp := request.handleRequest(handlers) - assert.Nil(t, resp.err) - statusOk(t, resp.pkt) + pkt, err := request.handleRequest(handlers) + assert.Nil(t, err) + statusOk(t, pkt) handlers.returnError() - resp = request.handleRequest(handlers) - assert.Nil(t, resp.pkt) - assert.Equal(t, resp.err, testError) + pkt, err = request.handleRequest(handlers) + assert.Nil(t, pkt) + assert.Equal(t, err, testError) } func TestInfoListMethod(t *testing.T) { testInfoMethod(t, "List") } @@ -127,19 +127,19 @@ func TestInfoReadlinkMethod(t *testing.T) { testInfoMethod(t, "Readlink") } func TestInfoStatMethod(t *testing.T) { handlers := newTestHandlers() request := testRequest("Stat") - resp := request.handleRequest(handlers) - assert.Nil(t, resp.err) - pkt := resp.pkt.(*sshFxpStatResponse) - assert.Equal(t, pkt.info.Name(), "request_test.go") + pkt, err := request.handleRequest(handlers) + assert.Nil(t, err) + spkt := pkt.(*sshFxpStatResponse) + assert.Equal(t, spkt.info.Name(), "request_test.go") } func testInfoMethod(t *testing.T, method string) { handlers := newTestHandlers() request := testRequest(method) - resp := request.handleRequest(handlers) - assert.Nil(t, resp.err) - pkt, ok := resp.pkt.(*sshFxpNamePacket) + pkt, err := request.handleRequest(handlers) + assert.Nil(t, err) + npkt, ok := pkt.(*sshFxpNamePacket) assert.True(t, ok) - assert.IsType(t, sshFxpNameAttr{}, pkt.NameAttrs[0]) - assert.Equal(t, pkt.NameAttrs[0].Name, "request_test.go") + assert.IsType(t, sshFxpNameAttr{}, npkt.NameAttrs[0]) + assert.Equal(t, npkt.NameAttrs[0].Name, "request_test.go") }