request server: add WithStartDirectory option

This commit is contained in:
Nicola Murino 2022-03-02 18:33:41 +01:00
parent aa9a37d639
commit 98b35dcfc3
3 changed files with 74 additions and 17 deletions

View File

@ -27,6 +27,8 @@ type RequestServer struct {
*serverConn *serverConn
pktMgr *packetManager pktMgr *packetManager
startDirectory string
mu sync.RWMutex mu sync.RWMutex
handleCount int handleCount int
openRequests map[string]*Request openRequests map[string]*Request
@ -47,6 +49,14 @@ func WithRSAllocator() RequestServerOption {
} }
} }
// WithStartDirectory sets a start directory to use as base for relative paths.
// If unset the default is "/"
func WithStartDirectory(startDirectory string) RequestServerOption {
return func(rs *RequestServer) {
rs.startDirectory = cleanPath(startDirectory)
}
}
// NewRequestServer creates/allocates/returns new RequestServer. // NewRequestServer creates/allocates/returns new RequestServer.
// Normally there will be one server per user-session. // Normally there will be one server per user-session.
func NewRequestServer(rwc io.ReadWriteCloser, h Handlers, options ...RequestServerOption) *RequestServer { func NewRequestServer(rwc io.ReadWriteCloser, h Handlers, options ...RequestServerOption) *RequestServer {
@ -62,6 +72,8 @@ func NewRequestServer(rwc io.ReadWriteCloser, h Handlers, options ...RequestServ
serverConn: svrConn, serverConn: svrConn,
pktMgr: newPktMgr(svrConn), pktMgr: newPktMgr(svrConn),
startDirectory: "/",
openRequests: make(map[string]*Request), openRequests: make(map[string]*Request),
} }
@ -210,11 +222,11 @@ func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedR
if realPather, ok := rs.Handlers.FileList.(RealPathFileLister); ok { if realPather, ok := rs.Handlers.FileList.(RealPathFileLister); ok {
realPath = realPather.RealPath(pkt.getPath()) realPath = realPather.RealPath(pkt.getPath())
} else { } else {
realPath = cleanPath(pkt.getPath()) realPath = cleanPathWithBase(rs.startDirectory, pkt.getPath())
} }
rpkt = cleanPacketPath(pkt, realPath) rpkt = cleanPacketPath(pkt, realPath)
case *sshFxpOpendirPacket: case *sshFxpOpendirPacket:
request := requestFromPacket(ctx, pkt) request := requestFromPacket(ctx, pkt, rs.startDirectory)
handle := rs.nextRequest(request) handle := rs.nextRequest(request)
rpkt = request.opendir(rs.Handlers, pkt) rpkt = request.opendir(rs.Handlers, pkt)
if _, ok := rpkt.(*sshFxpHandlePacket); !ok { if _, ok := rpkt.(*sshFxpHandlePacket); !ok {
@ -222,7 +234,7 @@ func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedR
rs.closeRequest(handle) rs.closeRequest(handle)
} }
case *sshFxpOpenPacket: case *sshFxpOpenPacket:
request := requestFromPacket(ctx, pkt) request := requestFromPacket(ctx, pkt, rs.startDirectory)
handle := rs.nextRequest(request) handle := rs.nextRequest(request)
rpkt = request.open(rs.Handlers, pkt) rpkt = request.open(rs.Handlers, pkt)
if _, ok := rpkt.(*sshFxpHandlePacket); !ok { if _, ok := rpkt.(*sshFxpHandlePacket); !ok {
@ -235,7 +247,10 @@ func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedR
if !ok { if !ok {
rpkt = statusFromError(pkt.ID, EBADF) rpkt = statusFromError(pkt.ID, EBADF)
} else { } else {
request = NewRequest("Stat", request.Filepath) request = &Request{
Method: "Stat",
Filepath: cleanPathWithBase(rs.startDirectory, request.Filepath),
}
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
} }
case *sshFxpFsetstatPacket: case *sshFxpFsetstatPacket:
@ -244,15 +259,24 @@ func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedR
if !ok { if !ok {
rpkt = statusFromError(pkt.ID, EBADF) rpkt = statusFromError(pkt.ID, EBADF)
} else { } else {
request = NewRequest("Setstat", request.Filepath) request = &Request{
Method: "Setstat",
Filepath: cleanPathWithBase(rs.startDirectory, request.Filepath),
}
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
} }
case *sshFxpExtendedPacketPosixRename: case *sshFxpExtendedPacketPosixRename:
request := NewRequest("PosixRename", pkt.Oldpath) request := &Request{
request.Target = pkt.Newpath Method: "PosixRename",
Filepath: cleanPathWithBase(rs.startDirectory, pkt.Oldpath),
Target: cleanPathWithBase(rs.startDirectory, pkt.Newpath),
}
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
case *sshFxpExtendedPacketStatVFS: case *sshFxpExtendedPacketStatVFS:
request := NewRequest("StatVFS", pkt.Path) request := &Request{
Method: "StatVFS",
Filepath: cleanPathWithBase(rs.startDirectory, pkt.Path),
}
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
case hasHandle: case hasHandle:
handle := pkt.getHandle() handle := pkt.getHandle()
@ -263,7 +287,7 @@ func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedR
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
} }
case hasPath: case hasPath:
request := requestFromPacket(ctx, pkt) request := requestFromPacket(ctx, pkt, rs.startDirectory)
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
request.close() request.close()
default: default:

View File

@ -37,7 +37,7 @@ func (cs csPair) testHandler() *root {
const sock = "/tmp/rstest.sock" const sock = "/tmp/rstest.sock"
func clientRequestServerPair(t *testing.T) *csPair { func clientRequestServerPair(t *testing.T, options ...RequestServerOption) *csPair {
skipIfWindows(t) skipIfWindows(t)
skipIfPlan9(t) skipIfPlan9(t)
@ -62,7 +62,6 @@ func clientRequestServerPair(t *testing.T) *csPair {
require.NoError(t, err) require.NoError(t, err)
handlers := InMemHandler() handlers := InMemHandler()
var options []RequestServerOption
if *testAllocator { if *testAllocator {
options = append(options, WithRSAllocator()) options = append(options, WithRSAllocator())
} }
@ -781,6 +780,37 @@ func TestRequestStatVFSError(t *testing.T) {
checkRequestServerAllocator(t, p) checkRequestServerAllocator(t, p)
} }
func TestRequestStartDirOption(t *testing.T) {
startDir := "/start/dir"
p := clientRequestServerPair(t, WithStartDirectory(startDir))
defer p.Close()
// create the start directory
err := p.cli.MkdirAll(startDir)
require.NoError(t, err)
// the working directory must be the defined start directory
wd, err := p.cli.Getwd()
require.NoError(t, err)
require.Equal(t, startDir, wd)
// upload a file using a relative path, it must be uploaded to the start directory
fileName := "file.txt"
_, err = putTestFile(p.cli, fileName, "")
require.NoError(t, err)
// we must be able to stat the file using both a relative and an absolute path
for _, filePath := range []string{fileName, path.Join(startDir, fileName)} {
fi, err := p.cli.Stat(filePath)
require.NoError(t, err)
assert.Equal(t, fileName, fi.Name())
}
// list dir contents using a relative path
entries, err := p.cli.ReadDir(".")
assert.NoError(t, err)
assert.Len(t, entries, 1)
// delete the file using a relative path
err = p.cli.Remove(fileName)
assert.NoError(t, err)
}
func TestCleanDisconnect(t *testing.T) { func TestCleanDisconnect(t *testing.T) {
p := clientRequestServerPair(t) p := clientRequestServerPair(t)
defer p.Close() defer p.Close()
@ -831,6 +861,7 @@ func TestRealPath(t *testing.T) {
func TestCleanPath(t *testing.T) { func TestCleanPath(t *testing.T) {
assert.Equal(t, "/", cleanPath("/")) assert.Equal(t, "/", cleanPath("/"))
assert.Equal(t, "/", cleanPath(".")) assert.Equal(t, "/", cleanPath("."))
assert.Equal(t, "/", cleanPath(""))
assert.Equal(t, "/", cleanPath("/.")) assert.Equal(t, "/", cleanPath("/."))
assert.Equal(t, "/", cleanPath("/a/..")) assert.Equal(t, "/", cleanPath("/a/.."))
assert.Equal(t, "/a/c", cleanPath("/a/b/../c")) assert.Equal(t, "/a/c", cleanPath("/a/b/../c"))

View File

@ -168,9 +168,11 @@ func (r *Request) copy() *Request {
} }
// New Request initialized based on packet data // New Request initialized based on packet data
func requestFromPacket(ctx context.Context, pkt hasPath) *Request { func requestFromPacket(ctx context.Context, pkt hasPath, baseDir string) *Request {
method := requestMethod(pkt) request := &Request{
request := NewRequest(method, pkt.getPath()) Method: requestMethod(pkt),
Filepath: cleanPathWithBase(baseDir, pkt.getPath()),
}
request.ctx, request.cancelCtx = context.WithCancel(ctx) request.ctx, request.cancelCtx = context.WithCancel(ctx)
switch p := pkt.(type) { switch p := pkt.(type) {
@ -180,13 +182,13 @@ func requestFromPacket(ctx context.Context, pkt hasPath) *Request {
request.Flags = p.Flags request.Flags = p.Flags
request.Attrs = p.Attrs.([]byte) request.Attrs = p.Attrs.([]byte)
case *sshFxpRenamePacket: case *sshFxpRenamePacket:
request.Target = cleanPath(p.Newpath) request.Target = cleanPathWithBase(baseDir, p.Newpath)
case *sshFxpSymlinkPacket: case *sshFxpSymlinkPacket:
// NOTE: given a POSIX compliant signature: symlink(target, linkpath string) // NOTE: given a POSIX compliant signature: symlink(target, linkpath string)
// this makes Request.Target the linkpath, and Request.Filepath the target. // this makes Request.Target the linkpath, and Request.Filepath the target.
request.Target = cleanPath(p.Linkpath) request.Target = cleanPathWithBase(baseDir, p.Linkpath)
case *sshFxpExtendedPacketHardlink: case *sshFxpExtendedPacketHardlink:
request.Target = cleanPath(p.Newpath) request.Target = cleanPathWithBase(baseDir, p.Newpath)
} }
return request return request
} }