mirror of https://github.com/pkg/sftp.git
request server: add WithStartDirectory option
This commit is contained in:
parent
aa9a37d639
commit
98b35dcfc3
|
@ -27,6 +27,8 @@ type RequestServer struct {
|
||||||
*serverConn
|
*serverConn
|
||||||
pktMgr *packetManager
|
pktMgr *packetManager
|
||||||
|
|
||||||
|
startDirectory string
|
||||||
|
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
handleCount int
|
handleCount int
|
||||||
openRequests map[string]*Request
|
openRequests map[string]*Request
|
||||||
|
@ -47,6 +49,14 @@ func WithRSAllocator() RequestServerOption {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithStartDirectory sets a start directory to use as base for relative paths.
|
||||||
|
// If unset the default is "/"
|
||||||
|
func WithStartDirectory(startDirectory string) RequestServerOption {
|
||||||
|
return func(rs *RequestServer) {
|
||||||
|
rs.startDirectory = cleanPath(startDirectory)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// NewRequestServer creates/allocates/returns new RequestServer.
|
// NewRequestServer creates/allocates/returns new RequestServer.
|
||||||
// Normally there will be one server per user-session.
|
// Normally there will be one server per user-session.
|
||||||
func NewRequestServer(rwc io.ReadWriteCloser, h Handlers, options ...RequestServerOption) *RequestServer {
|
func NewRequestServer(rwc io.ReadWriteCloser, h Handlers, options ...RequestServerOption) *RequestServer {
|
||||||
|
@ -62,6 +72,8 @@ func NewRequestServer(rwc io.ReadWriteCloser, h Handlers, options ...RequestServ
|
||||||
serverConn: svrConn,
|
serverConn: svrConn,
|
||||||
pktMgr: newPktMgr(svrConn),
|
pktMgr: newPktMgr(svrConn),
|
||||||
|
|
||||||
|
startDirectory: "/",
|
||||||
|
|
||||||
openRequests: make(map[string]*Request),
|
openRequests: make(map[string]*Request),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -210,11 +222,11 @@ func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedR
|
||||||
if realPather, ok := rs.Handlers.FileList.(RealPathFileLister); ok {
|
if realPather, ok := rs.Handlers.FileList.(RealPathFileLister); ok {
|
||||||
realPath = realPather.RealPath(pkt.getPath())
|
realPath = realPather.RealPath(pkt.getPath())
|
||||||
} else {
|
} else {
|
||||||
realPath = cleanPath(pkt.getPath())
|
realPath = cleanPathWithBase(rs.startDirectory, pkt.getPath())
|
||||||
}
|
}
|
||||||
rpkt = cleanPacketPath(pkt, realPath)
|
rpkt = cleanPacketPath(pkt, realPath)
|
||||||
case *sshFxpOpendirPacket:
|
case *sshFxpOpendirPacket:
|
||||||
request := requestFromPacket(ctx, pkt)
|
request := requestFromPacket(ctx, pkt, rs.startDirectory)
|
||||||
handle := rs.nextRequest(request)
|
handle := rs.nextRequest(request)
|
||||||
rpkt = request.opendir(rs.Handlers, pkt)
|
rpkt = request.opendir(rs.Handlers, pkt)
|
||||||
if _, ok := rpkt.(*sshFxpHandlePacket); !ok {
|
if _, ok := rpkt.(*sshFxpHandlePacket); !ok {
|
||||||
|
@ -222,7 +234,7 @@ func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedR
|
||||||
rs.closeRequest(handle)
|
rs.closeRequest(handle)
|
||||||
}
|
}
|
||||||
case *sshFxpOpenPacket:
|
case *sshFxpOpenPacket:
|
||||||
request := requestFromPacket(ctx, pkt)
|
request := requestFromPacket(ctx, pkt, rs.startDirectory)
|
||||||
handle := rs.nextRequest(request)
|
handle := rs.nextRequest(request)
|
||||||
rpkt = request.open(rs.Handlers, pkt)
|
rpkt = request.open(rs.Handlers, pkt)
|
||||||
if _, ok := rpkt.(*sshFxpHandlePacket); !ok {
|
if _, ok := rpkt.(*sshFxpHandlePacket); !ok {
|
||||||
|
@ -235,7 +247,10 @@ func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedR
|
||||||
if !ok {
|
if !ok {
|
||||||
rpkt = statusFromError(pkt.ID, EBADF)
|
rpkt = statusFromError(pkt.ID, EBADF)
|
||||||
} else {
|
} else {
|
||||||
request = NewRequest("Stat", request.Filepath)
|
request = &Request{
|
||||||
|
Method: "Stat",
|
||||||
|
Filepath: cleanPathWithBase(rs.startDirectory, request.Filepath),
|
||||||
|
}
|
||||||
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
|
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
|
||||||
}
|
}
|
||||||
case *sshFxpFsetstatPacket:
|
case *sshFxpFsetstatPacket:
|
||||||
|
@ -244,15 +259,24 @@ func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedR
|
||||||
if !ok {
|
if !ok {
|
||||||
rpkt = statusFromError(pkt.ID, EBADF)
|
rpkt = statusFromError(pkt.ID, EBADF)
|
||||||
} else {
|
} else {
|
||||||
request = NewRequest("Setstat", request.Filepath)
|
request = &Request{
|
||||||
|
Method: "Setstat",
|
||||||
|
Filepath: cleanPathWithBase(rs.startDirectory, request.Filepath),
|
||||||
|
}
|
||||||
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
|
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
|
||||||
}
|
}
|
||||||
case *sshFxpExtendedPacketPosixRename:
|
case *sshFxpExtendedPacketPosixRename:
|
||||||
request := NewRequest("PosixRename", pkt.Oldpath)
|
request := &Request{
|
||||||
request.Target = pkt.Newpath
|
Method: "PosixRename",
|
||||||
|
Filepath: cleanPathWithBase(rs.startDirectory, pkt.Oldpath),
|
||||||
|
Target: cleanPathWithBase(rs.startDirectory, pkt.Newpath),
|
||||||
|
}
|
||||||
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
|
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
|
||||||
case *sshFxpExtendedPacketStatVFS:
|
case *sshFxpExtendedPacketStatVFS:
|
||||||
request := NewRequest("StatVFS", pkt.Path)
|
request := &Request{
|
||||||
|
Method: "StatVFS",
|
||||||
|
Filepath: cleanPathWithBase(rs.startDirectory, pkt.Path),
|
||||||
|
}
|
||||||
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
|
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
|
||||||
case hasHandle:
|
case hasHandle:
|
||||||
handle := pkt.getHandle()
|
handle := pkt.getHandle()
|
||||||
|
@ -263,7 +287,7 @@ func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedR
|
||||||
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
|
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
|
||||||
}
|
}
|
||||||
case hasPath:
|
case hasPath:
|
||||||
request := requestFromPacket(ctx, pkt)
|
request := requestFromPacket(ctx, pkt, rs.startDirectory)
|
||||||
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
|
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
|
||||||
request.close()
|
request.close()
|
||||||
default:
|
default:
|
||||||
|
|
|
@ -37,7 +37,7 @@ func (cs csPair) testHandler() *root {
|
||||||
|
|
||||||
const sock = "/tmp/rstest.sock"
|
const sock = "/tmp/rstest.sock"
|
||||||
|
|
||||||
func clientRequestServerPair(t *testing.T) *csPair {
|
func clientRequestServerPair(t *testing.T, options ...RequestServerOption) *csPair {
|
||||||
skipIfWindows(t)
|
skipIfWindows(t)
|
||||||
skipIfPlan9(t)
|
skipIfPlan9(t)
|
||||||
|
|
||||||
|
@ -62,7 +62,6 @@ func clientRequestServerPair(t *testing.T) *csPair {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
handlers := InMemHandler()
|
handlers := InMemHandler()
|
||||||
var options []RequestServerOption
|
|
||||||
if *testAllocator {
|
if *testAllocator {
|
||||||
options = append(options, WithRSAllocator())
|
options = append(options, WithRSAllocator())
|
||||||
}
|
}
|
||||||
|
@ -781,6 +780,37 @@ func TestRequestStatVFSError(t *testing.T) {
|
||||||
checkRequestServerAllocator(t, p)
|
checkRequestServerAllocator(t, p)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRequestStartDirOption(t *testing.T) {
|
||||||
|
startDir := "/start/dir"
|
||||||
|
p := clientRequestServerPair(t, WithStartDirectory(startDir))
|
||||||
|
defer p.Close()
|
||||||
|
|
||||||
|
// create the start directory
|
||||||
|
err := p.cli.MkdirAll(startDir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// the working directory must be the defined start directory
|
||||||
|
wd, err := p.cli.Getwd()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, startDir, wd)
|
||||||
|
// upload a file using a relative path, it must be uploaded to the start directory
|
||||||
|
fileName := "file.txt"
|
||||||
|
_, err = putTestFile(p.cli, fileName, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
// we must be able to stat the file using both a relative and an absolute path
|
||||||
|
for _, filePath := range []string{fileName, path.Join(startDir, fileName)} {
|
||||||
|
fi, err := p.cli.Stat(filePath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, fileName, fi.Name())
|
||||||
|
}
|
||||||
|
// list dir contents using a relative path
|
||||||
|
entries, err := p.cli.ReadDir(".")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, entries, 1)
|
||||||
|
// delete the file using a relative path
|
||||||
|
err = p.cli.Remove(fileName)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
func TestCleanDisconnect(t *testing.T) {
|
func TestCleanDisconnect(t *testing.T) {
|
||||||
p := clientRequestServerPair(t)
|
p := clientRequestServerPair(t)
|
||||||
defer p.Close()
|
defer p.Close()
|
||||||
|
@ -831,6 +861,7 @@ func TestRealPath(t *testing.T) {
|
||||||
func TestCleanPath(t *testing.T) {
|
func TestCleanPath(t *testing.T) {
|
||||||
assert.Equal(t, "/", cleanPath("/"))
|
assert.Equal(t, "/", cleanPath("/"))
|
||||||
assert.Equal(t, "/", cleanPath("."))
|
assert.Equal(t, "/", cleanPath("."))
|
||||||
|
assert.Equal(t, "/", cleanPath(""))
|
||||||
assert.Equal(t, "/", cleanPath("/."))
|
assert.Equal(t, "/", cleanPath("/."))
|
||||||
assert.Equal(t, "/", cleanPath("/a/.."))
|
assert.Equal(t, "/", cleanPath("/a/.."))
|
||||||
assert.Equal(t, "/a/c", cleanPath("/a/b/../c"))
|
assert.Equal(t, "/a/c", cleanPath("/a/b/../c"))
|
||||||
|
|
14
request.go
14
request.go
|
@ -168,9 +168,11 @@ func (r *Request) copy() *Request {
|
||||||
}
|
}
|
||||||
|
|
||||||
// New Request initialized based on packet data
|
// New Request initialized based on packet data
|
||||||
func requestFromPacket(ctx context.Context, pkt hasPath) *Request {
|
func requestFromPacket(ctx context.Context, pkt hasPath, baseDir string) *Request {
|
||||||
method := requestMethod(pkt)
|
request := &Request{
|
||||||
request := NewRequest(method, pkt.getPath())
|
Method: requestMethod(pkt),
|
||||||
|
Filepath: cleanPathWithBase(baseDir, pkt.getPath()),
|
||||||
|
}
|
||||||
request.ctx, request.cancelCtx = context.WithCancel(ctx)
|
request.ctx, request.cancelCtx = context.WithCancel(ctx)
|
||||||
|
|
||||||
switch p := pkt.(type) {
|
switch p := pkt.(type) {
|
||||||
|
@ -180,13 +182,13 @@ func requestFromPacket(ctx context.Context, pkt hasPath) *Request {
|
||||||
request.Flags = p.Flags
|
request.Flags = p.Flags
|
||||||
request.Attrs = p.Attrs.([]byte)
|
request.Attrs = p.Attrs.([]byte)
|
||||||
case *sshFxpRenamePacket:
|
case *sshFxpRenamePacket:
|
||||||
request.Target = cleanPath(p.Newpath)
|
request.Target = cleanPathWithBase(baseDir, p.Newpath)
|
||||||
case *sshFxpSymlinkPacket:
|
case *sshFxpSymlinkPacket:
|
||||||
// NOTE: given a POSIX compliant signature: symlink(target, linkpath string)
|
// NOTE: given a POSIX compliant signature: symlink(target, linkpath string)
|
||||||
// this makes Request.Target the linkpath, and Request.Filepath the target.
|
// this makes Request.Target the linkpath, and Request.Filepath the target.
|
||||||
request.Target = cleanPath(p.Linkpath)
|
request.Target = cleanPathWithBase(baseDir, p.Linkpath)
|
||||||
case *sshFxpExtendedPacketHardlink:
|
case *sshFxpExtendedPacketHardlink:
|
||||||
request.Target = cleanPath(p.Newpath)
|
request.Target = cleanPathWithBase(baseDir, p.Newpath)
|
||||||
}
|
}
|
||||||
return request
|
return request
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue