mirror of https://github.com/pkg/sftp.git
factor out response struct
This commit is contained in:
parent
8b3c376b5a
commit
e307459f45
|
@ -1,6 +1,7 @@
|
|||
package sftp
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"sync"
|
||||
|
@ -121,6 +122,8 @@ func (rs *RequestServer) packetWorker() error {
|
|||
for pkt := range rs.pktChan {
|
||||
// handle packet specific pre-processing
|
||||
var handle string
|
||||
var rpkt encoding.BinaryMarshaler
|
||||
var err error
|
||||
switch pkt := pkt.(type) {
|
||||
case *sshFxInitPacket:
|
||||
err := rs.sendPacket(sshFxVersionPacket{sftpProtocolVersion, nil})
|
||||
|
@ -149,12 +152,13 @@ func (rs *RequestServer) packetWorker() error {
|
|||
}
|
||||
|
||||
request, ok := rs.getRequest(handle)
|
||||
if !ok { return rs.sendError(pkt, syscall.EBADF) }
|
||||
|
||||
if !ok { rpkt = statusFromError(pkt, syscall.EBADF) }
|
||||
request.populate(pkt)
|
||||
resp := request.handleRequest(rs.Handlers)
|
||||
if resp.err != nil { rs.sendError(resp.pkt, resp.err) }
|
||||
rs.sendPacket(resp.pkt)
|
||||
rpkt, err = request.handleRequest(rs.Handlers)
|
||||
if err != nil { rpkt = statusFromError(pkt, err) }
|
||||
|
||||
err = rs.sendPacket(rpkt)
|
||||
if err != nil { return err }
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
11
request.go
11
request.go
|
@ -7,12 +7,6 @@ import (
|
|||
"syscall"
|
||||
)
|
||||
|
||||
// response passed back to packet handling code
|
||||
type response struct {
|
||||
pkt resp_packet
|
||||
err error
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
// Get, Put, SetStat, Rename, Rmdir, Mkdir, Symlink, List, Stat, Readlink
|
||||
Method string
|
||||
|
@ -36,7 +30,7 @@ func newRequest(path string) *Request {
|
|||
}
|
||||
|
||||
// called from worker to handle packet/request
|
||||
func (r *Request) handleRequest(handlers Handlers) response {
|
||||
func (r *Request) handleRequest(handlers Handlers) (resp_packet, error) {
|
||||
var err error
|
||||
var rpkt resp_packet
|
||||
switch r.Method {
|
||||
|
@ -49,8 +43,7 @@ func (r *Request) handleRequest(handlers Handlers) response {
|
|||
case "List", "Stat", "Readlink":
|
||||
rpkt, err = fileinfo(handlers.FileInfo, r)
|
||||
}
|
||||
if err != nil { return response{nil, err} }
|
||||
return response{rpkt, nil}
|
||||
return rpkt, err
|
||||
}
|
||||
|
||||
// wrap FileReader handler
|
||||
|
|
|
@ -92,34 +92,34 @@ func TestGetMethod(t *testing.T) {
|
|||
request := testRequest("Get")
|
||||
// req.length is 4, so we test reads in 4 byte chunks
|
||||
for _, txt := range []string{"file-", "data."} {
|
||||
resp := request.handleRequest(handlers)
|
||||
assert.Nil(t, resp.err)
|
||||
pkt := resp.pkt.(*sshFxpDataPacket)
|
||||
assert.Equal(t, pkt.id(), uint32(1))
|
||||
assert.Equal(t, string(pkt.Data), txt)
|
||||
pkt, err := request.handleRequest(handlers)
|
||||
assert.Nil(t, err)
|
||||
dpkt := pkt.(*sshFxpDataPacket)
|
||||
assert.Equal(t, dpkt.id(), uint32(1))
|
||||
assert.Equal(t, string(dpkt.Data), txt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutMethod(t *testing.T) {
|
||||
handlers := newTestHandlers()
|
||||
request := testRequest("Put")
|
||||
resp := request.handleRequest(handlers)
|
||||
assert.Nil(t, resp.err)
|
||||
pkt, err := request.handleRequest(handlers)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, handlers.getOut().String(), "file-data.")
|
||||
statusOk(t, resp.pkt)
|
||||
statusOk(t, pkt)
|
||||
}
|
||||
|
||||
func TestCmdrMethod(t *testing.T) {
|
||||
handlers := newTestHandlers()
|
||||
request := testRequest("Mkdir")
|
||||
resp := request.handleRequest(handlers)
|
||||
assert.Nil(t, resp.err)
|
||||
statusOk(t, resp.pkt)
|
||||
pkt, err := request.handleRequest(handlers)
|
||||
assert.Nil(t, err)
|
||||
statusOk(t, pkt)
|
||||
|
||||
handlers.returnError()
|
||||
resp = request.handleRequest(handlers)
|
||||
assert.Nil(t, resp.pkt)
|
||||
assert.Equal(t, resp.err, testError)
|
||||
pkt, err = request.handleRequest(handlers)
|
||||
assert.Nil(t, pkt)
|
||||
assert.Equal(t, err, testError)
|
||||
}
|
||||
|
||||
func TestInfoListMethod(t *testing.T) { testInfoMethod(t, "List") }
|
||||
|
@ -127,19 +127,19 @@ func TestInfoReadlinkMethod(t *testing.T) { testInfoMethod(t, "Readlink") }
|
|||
func TestInfoStatMethod(t *testing.T) {
|
||||
handlers := newTestHandlers()
|
||||
request := testRequest("Stat")
|
||||
resp := request.handleRequest(handlers)
|
||||
assert.Nil(t, resp.err)
|
||||
pkt := resp.pkt.(*sshFxpStatResponse)
|
||||
assert.Equal(t, pkt.info.Name(), "request_test.go")
|
||||
pkt, err := request.handleRequest(handlers)
|
||||
assert.Nil(t, err)
|
||||
spkt := pkt.(*sshFxpStatResponse)
|
||||
assert.Equal(t, spkt.info.Name(), "request_test.go")
|
||||
}
|
||||
|
||||
func testInfoMethod(t *testing.T, method string) {
|
||||
handlers := newTestHandlers()
|
||||
request := testRequest(method)
|
||||
resp := request.handleRequest(handlers)
|
||||
assert.Nil(t, resp.err)
|
||||
pkt, ok := resp.pkt.(*sshFxpNamePacket)
|
||||
pkt, err := request.handleRequest(handlers)
|
||||
assert.Nil(t, err)
|
||||
npkt, ok := pkt.(*sshFxpNamePacket)
|
||||
assert.True(t, ok)
|
||||
assert.IsType(t, sshFxpNameAttr{}, pkt.NameAttrs[0])
|
||||
assert.Equal(t, pkt.NameAttrs[0].Name, "request_test.go")
|
||||
assert.IsType(t, sshFxpNameAttr{}, npkt.NameAttrs[0])
|
||||
assert.Equal(t, npkt.NameAttrs[0].Name, "request_test.go")
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue