mirror of https://github.com/pkg/sftp.git
fix issue with leaking requests in handle cache
This commit is contained in:
parent
b6f2e2d29e
commit
f4432147d1
|
@ -7,6 +7,8 @@ import (
|
|||
"strconv"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var maxTxPacket uint32 = 1 << 15
|
||||
|
@ -109,32 +111,35 @@ func (rs *RequestServer) Serve() error {
|
|||
|
||||
func (rs *RequestServer) packetWorker() error {
|
||||
for pkt := range rs.pktChan {
|
||||
// fmt.Println("Incoming Packet: ", pkt, reflect.TypeOf(pkt))
|
||||
var handle string
|
||||
var rpkt responsePacket
|
||||
var err error
|
||||
switch pkt := pkt.(type) {
|
||||
case *sshFxInitPacket:
|
||||
rpkt = sshFxVersionPacket{sftpProtocolVersion, nil}
|
||||
case *sshFxpClosePacket:
|
||||
handle = pkt.getHandle()
|
||||
handle := pkt.getHandle()
|
||||
rs.closeRequest(handle)
|
||||
rpkt = statusFromError(pkt, nil)
|
||||
case *sshFxpRealpathPacket:
|
||||
rpkt = cleanPath(pkt)
|
||||
case isOpener:
|
||||
handle = rs.nextRequest(newRequest(pkt.getPath()))
|
||||
handle := rs.nextRequest(newRequest(pkt.getPath()))
|
||||
rpkt = sshFxpHandlePacket{pkt.id(), handle}
|
||||
case hasPath:
|
||||
handle = rs.nextRequest(newRequest(pkt.getPath()))
|
||||
rpkt = rs.request(handle, pkt)
|
||||
case hasHandle:
|
||||
handle = pkt.getHandle()
|
||||
rpkt = rs.request(handle, pkt)
|
||||
handle := pkt.getHandle()
|
||||
request, ok := rs.getRequest(handle)
|
||||
if !ok {
|
||||
rpkt = statusFromError(pkt, syscall.EBADF)
|
||||
} else {
|
||||
rpkt = rs.handle(request, pkt)
|
||||
}
|
||||
case hasPath:
|
||||
request := newRequest(pkt.getPath())
|
||||
rpkt = rs.handle(request, pkt)
|
||||
default:
|
||||
return errors.Errorf("unexpected packet type %T", pkt)
|
||||
}
|
||||
|
||||
// fmt.Println("Reply Packet: ", rpkt, reflect.TypeOf(rpkt))
|
||||
err = rs.sendPacket(rpkt)
|
||||
err := rs.sendPacket(rpkt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -159,20 +164,14 @@ func cleanPath(pkt *sshFxpRealpathPacket) responsePacket {
|
|||
}
|
||||
}
|
||||
|
||||
func (rs *RequestServer) request(handle string, pkt packet) responsePacket {
|
||||
var rpkt responsePacket
|
||||
var err error
|
||||
if request, ok := rs.getRequest(handle); ok {
|
||||
// called here to keep packet handling out of request for testing
|
||||
request.populate(pkt)
|
||||
// fmt.Println("Request Method: ", request.Method)
|
||||
rpkt, err = request.handle(rs.Handlers)
|
||||
if err != nil {
|
||||
err = errorAdapter(err)
|
||||
rpkt = statusFromError(pkt, err)
|
||||
}
|
||||
} else {
|
||||
rpkt = statusFromError(pkt, syscall.EBADF)
|
||||
func (rs *RequestServer) handle(request *Request, pkt packet) responsePacket {
|
||||
// called here to keep packet handling out of request for testing
|
||||
request.populate(pkt)
|
||||
// fmt.Println("Request Method: ", request.Method)
|
||||
rpkt, err := request.handle(rs.Handlers)
|
||||
if err != nil {
|
||||
err = errorAdapter(err)
|
||||
rpkt = statusFromError(pkt, err)
|
||||
}
|
||||
return rpkt
|
||||
}
|
||||
|
|
|
@ -61,6 +61,19 @@ func TestRequestCache(t *testing.T) {
|
|||
assert.Len(t, p.svr.openRequests, 0)
|
||||
}
|
||||
|
||||
func TestRequestCacheState(t *testing.T) {
|
||||
// test operation that uses open/close
|
||||
p := clientRequestServerPair(t)
|
||||
defer p.Close()
|
||||
_, err := putTestFile(p.cli, "/foo", "hello")
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, p.svr.openRequests, 0)
|
||||
// test operation that doesn't open/close
|
||||
err = p.cli.Remove("/foo")
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, p.svr.openRequests, 0)
|
||||
}
|
||||
|
||||
func putTestFile(cli *Client, path, content string) (int, error) {
|
||||
w, err := cli.Create(path)
|
||||
if err == nil {
|
||||
|
|
Loading…
Reference in New Issue