factor out response struct

This commit is contained in:
John Eikenberry 2016-07-12 16:50:59 -07:00
parent 8b3c376b5a
commit e307459f45
3 changed files with 34 additions and 37 deletions

View File

@ -1,6 +1,7 @@
package sftp package sftp
import ( import (
"encoding"
"io" "io"
"io/ioutil" "io/ioutil"
"sync" "sync"
@ -121,6 +122,8 @@ func (rs *RequestServer) packetWorker() error {
for pkt := range rs.pktChan { for pkt := range rs.pktChan {
// handle packet specific pre-processing // handle packet specific pre-processing
var handle string var handle string
var rpkt encoding.BinaryMarshaler
var err error
switch pkt := pkt.(type) { switch pkt := pkt.(type) {
case *sshFxInitPacket: case *sshFxInitPacket:
err := rs.sendPacket(sshFxVersionPacket{sftpProtocolVersion, nil}) err := rs.sendPacket(sshFxVersionPacket{sftpProtocolVersion, nil})
@ -149,12 +152,13 @@ func (rs *RequestServer) packetWorker() error {
} }
request, ok := rs.getRequest(handle) request, ok := rs.getRequest(handle)
if !ok { return rs.sendError(pkt, syscall.EBADF) } if !ok { rpkt = statusFromError(pkt, syscall.EBADF) }
request.populate(pkt) request.populate(pkt)
resp := request.handleRequest(rs.Handlers) rpkt, err = request.handleRequest(rs.Handlers)
if resp.err != nil { rs.sendError(resp.pkt, resp.err) } if err != nil { rpkt = statusFromError(pkt, err) }
rs.sendPacket(resp.pkt)
err = rs.sendPacket(rpkt)
if err != nil { return err }
} }
return nil return nil
} }

View File

@ -7,12 +7,6 @@ import (
"syscall" "syscall"
) )
// response passed back to packet handling code
type response struct {
pkt resp_packet
err error
}
type Request struct { type Request struct {
// Get, Put, SetStat, Rename, Rmdir, Mkdir, Symlink, List, Stat, Readlink // Get, Put, SetStat, Rename, Rmdir, Mkdir, Symlink, List, Stat, Readlink
Method string Method string
@ -36,7 +30,7 @@ func newRequest(path string) *Request {
} }
// called from worker to handle packet/request // called from worker to handle packet/request
func (r *Request) handleRequest(handlers Handlers) response { func (r *Request) handleRequest(handlers Handlers) (resp_packet, error) {
var err error var err error
var rpkt resp_packet var rpkt resp_packet
switch r.Method { switch r.Method {
@ -49,8 +43,7 @@ func (r *Request) handleRequest(handlers Handlers) response {
case "List", "Stat", "Readlink": case "List", "Stat", "Readlink":
rpkt, err = fileinfo(handlers.FileInfo, r) rpkt, err = fileinfo(handlers.FileInfo, r)
} }
if err != nil { return response{nil, err} } return rpkt, err
return response{rpkt, nil}
} }
// wrap FileReader handler // wrap FileReader handler

View File

@ -92,34 +92,34 @@ func TestGetMethod(t *testing.T) {
request := testRequest("Get") request := testRequest("Get")
// req.length is 4, so we test reads in 4 byte chunks // req.length is 4, so we test reads in 4 byte chunks
for _, txt := range []string{"file-", "data."} { for _, txt := range []string{"file-", "data."} {
resp := request.handleRequest(handlers) pkt, err := request.handleRequest(handlers)
assert.Nil(t, resp.err) assert.Nil(t, err)
pkt := resp.pkt.(*sshFxpDataPacket) dpkt := pkt.(*sshFxpDataPacket)
assert.Equal(t, pkt.id(), uint32(1)) assert.Equal(t, dpkt.id(), uint32(1))
assert.Equal(t, string(pkt.Data), txt) assert.Equal(t, string(dpkt.Data), txt)
} }
} }
func TestPutMethod(t *testing.T) { func TestPutMethod(t *testing.T) {
handlers := newTestHandlers() handlers := newTestHandlers()
request := testRequest("Put") request := testRequest("Put")
resp := request.handleRequest(handlers) pkt, err := request.handleRequest(handlers)
assert.Nil(t, resp.err) assert.Nil(t, err)
assert.Equal(t, handlers.getOut().String(), "file-data.") assert.Equal(t, handlers.getOut().String(), "file-data.")
statusOk(t, resp.pkt) statusOk(t, pkt)
} }
func TestCmdrMethod(t *testing.T) { func TestCmdrMethod(t *testing.T) {
handlers := newTestHandlers() handlers := newTestHandlers()
request := testRequest("Mkdir") request := testRequest("Mkdir")
resp := request.handleRequest(handlers) pkt, err := request.handleRequest(handlers)
assert.Nil(t, resp.err) assert.Nil(t, err)
statusOk(t, resp.pkt) statusOk(t, pkt)
handlers.returnError() handlers.returnError()
resp = request.handleRequest(handlers) pkt, err = request.handleRequest(handlers)
assert.Nil(t, resp.pkt) assert.Nil(t, pkt)
assert.Equal(t, resp.err, testError) assert.Equal(t, err, testError)
} }
func TestInfoListMethod(t *testing.T) { testInfoMethod(t, "List") } func TestInfoListMethod(t *testing.T) { testInfoMethod(t, "List") }
@ -127,19 +127,19 @@ func TestInfoReadlinkMethod(t *testing.T) { testInfoMethod(t, "Readlink") }
func TestInfoStatMethod(t *testing.T) { func TestInfoStatMethod(t *testing.T) {
handlers := newTestHandlers() handlers := newTestHandlers()
request := testRequest("Stat") request := testRequest("Stat")
resp := request.handleRequest(handlers) pkt, err := request.handleRequest(handlers)
assert.Nil(t, resp.err) assert.Nil(t, err)
pkt := resp.pkt.(*sshFxpStatResponse) spkt := pkt.(*sshFxpStatResponse)
assert.Equal(t, pkt.info.Name(), "request_test.go") assert.Equal(t, spkt.info.Name(), "request_test.go")
} }
func testInfoMethod(t *testing.T, method string) { func testInfoMethod(t *testing.T, method string) {
handlers := newTestHandlers() handlers := newTestHandlers()
request := testRequest(method) request := testRequest(method)
resp := request.handleRequest(handlers) pkt, err := request.handleRequest(handlers)
assert.Nil(t, resp.err) assert.Nil(t, err)
pkt, ok := resp.pkt.(*sshFxpNamePacket) npkt, ok := pkt.(*sshFxpNamePacket)
assert.True(t, ok) assert.True(t, ok)
assert.IsType(t, sshFxpNameAttr{}, pkt.NameAttrs[0]) assert.IsType(t, sshFxpNameAttr{}, npkt.NameAttrs[0])
assert.Equal(t, pkt.NameAttrs[0].Name, "request_test.go") assert.Equal(t, npkt.NameAttrs[0].Name, "request_test.go")
} }