mirror of https://github.com/pkg/sftp.git
factor out response struct
This commit is contained in:
parent
8b3c376b5a
commit
e307459f45
|
@ -1,6 +1,7 @@
|
||||||
package sftp
|
package sftp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -121,6 +122,8 @@ func (rs *RequestServer) packetWorker() error {
|
||||||
for pkt := range rs.pktChan {
|
for pkt := range rs.pktChan {
|
||||||
// handle packet specific pre-processing
|
// handle packet specific pre-processing
|
||||||
var handle string
|
var handle string
|
||||||
|
var rpkt encoding.BinaryMarshaler
|
||||||
|
var err error
|
||||||
switch pkt := pkt.(type) {
|
switch pkt := pkt.(type) {
|
||||||
case *sshFxInitPacket:
|
case *sshFxInitPacket:
|
||||||
err := rs.sendPacket(sshFxVersionPacket{sftpProtocolVersion, nil})
|
err := rs.sendPacket(sshFxVersionPacket{sftpProtocolVersion, nil})
|
||||||
|
@ -149,12 +152,13 @@ func (rs *RequestServer) packetWorker() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
request, ok := rs.getRequest(handle)
|
request, ok := rs.getRequest(handle)
|
||||||
if !ok { return rs.sendError(pkt, syscall.EBADF) }
|
if !ok { rpkt = statusFromError(pkt, syscall.EBADF) }
|
||||||
|
|
||||||
request.populate(pkt)
|
request.populate(pkt)
|
||||||
resp := request.handleRequest(rs.Handlers)
|
rpkt, err = request.handleRequest(rs.Handlers)
|
||||||
if resp.err != nil { rs.sendError(resp.pkt, resp.err) }
|
if err != nil { rpkt = statusFromError(pkt, err) }
|
||||||
rs.sendPacket(resp.pkt)
|
|
||||||
|
err = rs.sendPacket(rpkt)
|
||||||
|
if err != nil { return err }
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
11
request.go
11
request.go
|
@ -7,12 +7,6 @@ import (
|
||||||
"syscall"
|
"syscall"
|
||||||
)
|
)
|
||||||
|
|
||||||
// response passed back to packet handling code
|
|
||||||
type response struct {
|
|
||||||
pkt resp_packet
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
type Request struct {
|
type Request struct {
|
||||||
// Get, Put, SetStat, Rename, Rmdir, Mkdir, Symlink, List, Stat, Readlink
|
// Get, Put, SetStat, Rename, Rmdir, Mkdir, Symlink, List, Stat, Readlink
|
||||||
Method string
|
Method string
|
||||||
|
@ -36,7 +30,7 @@ func newRequest(path string) *Request {
|
||||||
}
|
}
|
||||||
|
|
||||||
// called from worker to handle packet/request
|
// called from worker to handle packet/request
|
||||||
func (r *Request) handleRequest(handlers Handlers) response {
|
func (r *Request) handleRequest(handlers Handlers) (resp_packet, error) {
|
||||||
var err error
|
var err error
|
||||||
var rpkt resp_packet
|
var rpkt resp_packet
|
||||||
switch r.Method {
|
switch r.Method {
|
||||||
|
@ -49,8 +43,7 @@ func (r *Request) handleRequest(handlers Handlers) response {
|
||||||
case "List", "Stat", "Readlink":
|
case "List", "Stat", "Readlink":
|
||||||
rpkt, err = fileinfo(handlers.FileInfo, r)
|
rpkt, err = fileinfo(handlers.FileInfo, r)
|
||||||
}
|
}
|
||||||
if err != nil { return response{nil, err} }
|
return rpkt, err
|
||||||
return response{rpkt, nil}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// wrap FileReader handler
|
// wrap FileReader handler
|
||||||
|
|
|
@ -92,34 +92,34 @@ func TestGetMethod(t *testing.T) {
|
||||||
request := testRequest("Get")
|
request := testRequest("Get")
|
||||||
// req.length is 4, so we test reads in 4 byte chunks
|
// req.length is 4, so we test reads in 4 byte chunks
|
||||||
for _, txt := range []string{"file-", "data."} {
|
for _, txt := range []string{"file-", "data."} {
|
||||||
resp := request.handleRequest(handlers)
|
pkt, err := request.handleRequest(handlers)
|
||||||
assert.Nil(t, resp.err)
|
assert.Nil(t, err)
|
||||||
pkt := resp.pkt.(*sshFxpDataPacket)
|
dpkt := pkt.(*sshFxpDataPacket)
|
||||||
assert.Equal(t, pkt.id(), uint32(1))
|
assert.Equal(t, dpkt.id(), uint32(1))
|
||||||
assert.Equal(t, string(pkt.Data), txt)
|
assert.Equal(t, string(dpkt.Data), txt)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPutMethod(t *testing.T) {
|
func TestPutMethod(t *testing.T) {
|
||||||
handlers := newTestHandlers()
|
handlers := newTestHandlers()
|
||||||
request := testRequest("Put")
|
request := testRequest("Put")
|
||||||
resp := request.handleRequest(handlers)
|
pkt, err := request.handleRequest(handlers)
|
||||||
assert.Nil(t, resp.err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, handlers.getOut().String(), "file-data.")
|
assert.Equal(t, handlers.getOut().String(), "file-data.")
|
||||||
statusOk(t, resp.pkt)
|
statusOk(t, pkt)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCmdrMethod(t *testing.T) {
|
func TestCmdrMethod(t *testing.T) {
|
||||||
handlers := newTestHandlers()
|
handlers := newTestHandlers()
|
||||||
request := testRequest("Mkdir")
|
request := testRequest("Mkdir")
|
||||||
resp := request.handleRequest(handlers)
|
pkt, err := request.handleRequest(handlers)
|
||||||
assert.Nil(t, resp.err)
|
assert.Nil(t, err)
|
||||||
statusOk(t, resp.pkt)
|
statusOk(t, pkt)
|
||||||
|
|
||||||
handlers.returnError()
|
handlers.returnError()
|
||||||
resp = request.handleRequest(handlers)
|
pkt, err = request.handleRequest(handlers)
|
||||||
assert.Nil(t, resp.pkt)
|
assert.Nil(t, pkt)
|
||||||
assert.Equal(t, resp.err, testError)
|
assert.Equal(t, err, testError)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInfoListMethod(t *testing.T) { testInfoMethod(t, "List") }
|
func TestInfoListMethod(t *testing.T) { testInfoMethod(t, "List") }
|
||||||
|
@ -127,19 +127,19 @@ func TestInfoReadlinkMethod(t *testing.T) { testInfoMethod(t, "Readlink") }
|
||||||
func TestInfoStatMethod(t *testing.T) {
|
func TestInfoStatMethod(t *testing.T) {
|
||||||
handlers := newTestHandlers()
|
handlers := newTestHandlers()
|
||||||
request := testRequest("Stat")
|
request := testRequest("Stat")
|
||||||
resp := request.handleRequest(handlers)
|
pkt, err := request.handleRequest(handlers)
|
||||||
assert.Nil(t, resp.err)
|
assert.Nil(t, err)
|
||||||
pkt := resp.pkt.(*sshFxpStatResponse)
|
spkt := pkt.(*sshFxpStatResponse)
|
||||||
assert.Equal(t, pkt.info.Name(), "request_test.go")
|
assert.Equal(t, spkt.info.Name(), "request_test.go")
|
||||||
}
|
}
|
||||||
|
|
||||||
func testInfoMethod(t *testing.T, method string) {
|
func testInfoMethod(t *testing.T, method string) {
|
||||||
handlers := newTestHandlers()
|
handlers := newTestHandlers()
|
||||||
request := testRequest(method)
|
request := testRequest(method)
|
||||||
resp := request.handleRequest(handlers)
|
pkt, err := request.handleRequest(handlers)
|
||||||
assert.Nil(t, resp.err)
|
assert.Nil(t, err)
|
||||||
pkt, ok := resp.pkt.(*sshFxpNamePacket)
|
npkt, ok := pkt.(*sshFxpNamePacket)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
assert.IsType(t, sshFxpNameAttr{}, pkt.NameAttrs[0])
|
assert.IsType(t, sshFxpNameAttr{}, npkt.NameAttrs[0])
|
||||||
assert.Equal(t, pkt.NameAttrs[0].Name, "request_test.go")
|
assert.Equal(t, npkt.NameAttrs[0].Name, "request_test.go")
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue