sftp/request-server.go

329 lines
8.4 KiB
Go
Raw Normal View History

package sftp
import (
"context"
"errors"
"io"
"path"
"path/filepath"
"strconv"
"sync"
)
2016-07-09 03:38:35 +08:00
// maxTxPacket is 32 KiB (1<<15), typed uint32. It is not referenced in this
// part of the file; presumably it caps per-packet data transfer sizes
// elsewhere in the package — TODO confirm.
var maxTxPacket = uint32(1 << 15)
2016-07-26 02:42:18 +08:00
// Handlers contains the 4 SFTP server request handlers.
2016-07-09 03:38:35 +08:00
type Handlers struct {
FileGet FileReader
FilePut FileWriter
FileCmd FileCmder
FileList FileLister
2016-07-09 03:38:35 +08:00
}
2016-07-26 02:42:18 +08:00
// RequestServer abstracts the sftp protocol with an http request-like protocol
type RequestServer struct {
2021-07-21 22:31:14 +08:00
Handlers Handlers
*serverConn
2021-07-21 22:31:14 +08:00
pktMgr *packetManager
startDirectory string
2021-07-21 22:31:14 +08:00
mu sync.RWMutex
handleCount int
openRequests map[string]*Request
}
// RequestServerOption configures a RequestServer. Options are passed to
// NewRequestServer and applied, in order, after the server's defaults are set.
type RequestServerOption func(*RequestServer)
// WithRSAllocator enables the allocator.
//
// After a packet has been processed, the slices allocated for it are kept
// in memory and reused for new packets. A single allocator is shared by the
// packet manager and the connection, and Serve frees it when it returns.
//
// The allocator is experimental.
func WithRSAllocator() RequestServerOption {
	return func(rs *RequestServer) {
		shared := newAllocator()
		rs.pktMgr.alloc, rs.conn.alloc = shared, shared
	}
}
// WithStartDirectory sets the directory used as the base for relative
// client paths. The value is normalized with cleanPath, so a relative
// startDirectory is itself taken relative to "/".
// If unset the default is "/".
func WithStartDirectory(startDirectory string) RequestServerOption {
	dir := cleanPath(startDirectory)
	return func(rs *RequestServer) {
		rs.startDirectory = dir
	}
}
// NewRequestServer creates/allocates/returns new RequestServer.
// Normally there will be one server per user-session.
func NewRequestServer(rwc io.ReadWriteCloser, h Handlers, options ...RequestServerOption) *RequestServer {
svrConn := &serverConn{
2017-03-15 09:02:17 +08:00
conn: conn{
Reader: rwc,
WriteCloser: rwc,
},
2017-03-15 09:02:17 +08:00
}
rs := &RequestServer{
2021-07-21 22:31:14 +08:00
Handlers: h,
serverConn: svrConn,
pktMgr: newPktMgr(svrConn),
startDirectory: "/",
openRequests: make(map[string]*Request),
}
for _, o := range options {
o(rs)
}
return rs
}
// New Open packet/Request
func (rs *RequestServer) nextRequest(r *Request) string {
2021-07-21 22:31:14 +08:00
rs.mu.Lock()
defer rs.mu.Unlock()
rs.handleCount++
2021-07-21 22:31:14 +08:00
r.handle = strconv.Itoa(rs.handleCount)
rs.openRequests[r.handle] = r
return r.handle
2016-07-09 03:38:35 +08:00
}
// Returns Request from openRequests, bool is false if it is missing.
//
// The Requests in openRequests work essentially as open file descriptors that
// you can do different things with. What you are doing with it are denoted by
// the first packet of that type (read/write/etc).
func (rs *RequestServer) getRequest(handle string) (*Request, bool) {
2021-07-21 22:31:14 +08:00
rs.mu.RLock()
defer rs.mu.RUnlock()
2016-07-09 03:38:35 +08:00
r, ok := rs.openRequests[handle]
return r, ok
2016-07-09 03:38:35 +08:00
}
// Close the Request and clear from openRequests map
func (rs *RequestServer) closeRequest(handle string) error {
2021-07-21 22:31:14 +08:00
rs.mu.Lock()
defer rs.mu.Unlock()
if r, ok := rs.openRequests[handle]; ok {
2016-07-09 03:38:35 +08:00
delete(rs.openRequests, handle)
return r.close()
2016-07-09 03:38:35 +08:00
}
2021-07-21 22:31:14 +08:00
return EBADF
2016-07-09 03:38:35 +08:00
}
2016-07-26 02:42:18 +08:00
// Close the read/write/closer to trigger exiting the main server loop
2016-07-23 07:20:00 +08:00
func (rs *RequestServer) Close() error { return rs.conn.Close() }
2016-07-21 07:48:31 +08:00
func (rs *RequestServer) serveLoop(pktChan chan<- orderedRequest) error {
defer close(pktChan) // shuts down sftpServerWorkers
var err error
2017-03-15 09:02:17 +08:00
var pkt requestPacket
var pktType uint8
var pktBytes []byte
for {
pktType, pktBytes, err = rs.serverConn.recvPacket(rs.pktMgr.getNextOrderID())
2016-07-19 02:50:45 +08:00
if err != nil {
// we don't care about releasing allocated pages here, the server will quit and the allocator freed
return err
2016-07-19 02:50:45 +08:00
}
2017-03-15 09:02:17 +08:00
pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes})
2016-07-19 02:50:45 +08:00
if err != nil {
switch {
case errors.Is(err, errUnknownExtendedPacket):
// do nothing
default:
debug("makePacket err: %v", err)
rs.conn.Close() // shuts down recvPacket
return err
}
2016-07-19 02:50:45 +08:00
}
pktChan <- rs.pktMgr.newOrderedRequest(pkt)
}
}
// Serve requests for user session
func (rs *RequestServer) Serve() error {
defer func() {
if rs.pktMgr.alloc != nil {
rs.pktMgr.alloc.Free()
}
}()
2021-07-21 22:31:14 +08:00
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
2021-07-21 22:31:14 +08:00
var wg sync.WaitGroup
runWorker := func(ch chan orderedRequest) {
wg.Add(1)
go func() {
defer wg.Done()
if err := rs.packetWorker(ctx, ch); err != nil {
rs.conn.Close() // shuts down recvPacket
}
}()
}
pktChan := rs.pktMgr.workerChan(runWorker)
err := rs.serveLoop(pktChan)
wg.Wait() // wait for all workers to exit
2021-07-21 22:31:14 +08:00
rs.mu.Lock()
defer rs.mu.Unlock()
// make sure all open requests are properly closed
// (eg. possible on dropped connections, client crashes, etc.)
for handle, req := range rs.openRequests {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
req.transferError(err)
delete(rs.openRequests, handle)
req.close()
}
return err
}
2021-07-21 22:31:14 +08:00
func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedRequest) error {
for pkt := range pktChan {
orderID := pkt.orderID()
if epkt, ok := pkt.requestPacket.(*sshFxpExtendedPacket); ok {
if epkt.SpecificPacket != nil {
pkt.requestPacket = epkt.SpecificPacket
}
}
2016-07-26 02:52:07 +08:00
var rpkt responsePacket
switch pkt := pkt.requestPacket.(type) {
case *sshFxInitPacket:
rpkt = &sshFxVersionPacket{Version: sftpProtocolVersion, Extensions: sftpExtensions}
2016-07-12 11:19:49 +08:00
case *sshFxpClosePacket:
handle := pkt.getHandle()
rpkt = statusFromError(pkt.ID, rs.closeRequest(handle))
case *sshFxpRealpathPacket:
var realPath string
2021-04-28 00:06:39 +08:00
if realPather, ok := rs.Handlers.FileList.(RealPathFileLister); ok {
realPath = realPather.RealPath(pkt.getPath())
} else {
realPath = cleanPathWithBase(rs.startDirectory, pkt.getPath())
}
rpkt = cleanPacketPath(pkt, realPath)
case *sshFxpOpendirPacket:
request := requestFromPacket(ctx, pkt, rs.startDirectory)
handle := rs.nextRequest(request)
rpkt = request.opendir(rs.Handlers, pkt)
if _, ok := rpkt.(*sshFxpHandlePacket); !ok {
// if we return an error we have to remove the handle from the active ones
rs.closeRequest(handle)
}
case *sshFxpOpenPacket:
request := requestFromPacket(ctx, pkt, rs.startDirectory)
handle := rs.nextRequest(request)
rpkt = request.open(rs.Handlers, pkt)
if _, ok := rpkt.(*sshFxpHandlePacket); !ok {
// if we return an error we have to remove the handle from the active ones
rs.closeRequest(handle)
}
case *sshFxpFstatPacket:
handle := pkt.getHandle()
request, ok := rs.getRequest(handle)
if !ok {
rpkt = statusFromError(pkt.ID, EBADF)
} else {
request = &Request{
Method: "Stat",
Filepath: cleanPathWithBase(rs.startDirectory, request.Filepath),
}
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
}
case *sshFxpFsetstatPacket:
handle := pkt.getHandle()
request, ok := rs.getRequest(handle)
if !ok {
rpkt = statusFromError(pkt.ID, EBADF)
} else {
request = &Request{
Method: "Setstat",
Filepath: cleanPathWithBase(rs.startDirectory, request.Filepath),
}
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
}
2019-10-27 13:59:47 +08:00
case *sshFxpExtendedPacketPosixRename:
request := &Request{
Method: "PosixRename",
Filepath: cleanPathWithBase(rs.startDirectory, pkt.Oldpath),
Target: cleanPathWithBase(rs.startDirectory, pkt.Newpath),
}
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
2021-02-11 02:13:19 +08:00
case *sshFxpExtendedPacketStatVFS:
request := &Request{
Method: "StatVFS",
Filepath: cleanPathWithBase(rs.startDirectory, pkt.Path),
}
2021-02-11 02:13:19 +08:00
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
2016-07-13 08:36:12 +08:00
case hasHandle:
handle := pkt.getHandle()
request, ok := rs.getRequest(handle)
if !ok {
rpkt = statusFromError(pkt.id(), EBADF)
} else {
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
}
case hasPath:
request := requestFromPacket(ctx, pkt, rs.startDirectory)
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
request.close()
default:
rpkt = statusFromError(pkt.id(), ErrSSHFxOpUnsupported)
}
2016-07-12 11:19:49 +08:00
rs.pktMgr.readyPacket(
rs.pktMgr.newOrderedResponse(rpkt, orderID))
}
return nil
}
2016-07-13 08:36:12 +08:00
// clean and return name packet for file
func cleanPacketPath(pkt *sshFxpRealpathPacket, realPath string) responsePacket {
return &sshFxpNamePacket{
ID: pkt.id(),
NameAttrs: []*sshFxpNameAttr{
2021-02-23 05:29:35 +08:00
{
Name: realPath,
LongName: realPath,
Attrs: emptyFileStat,
},
},
}
}
// Makes sure we have a clean POSIX (/) absolute path to work with
func cleanPath(p string) string {
2021-04-28 00:06:39 +08:00
return cleanPathWithBase("/", p)
}
2021-04-28 00:06:39 +08:00
// cleanPathWithBase cleans p and converts it to forward slashes. An absolute
// result is returned as is; a relative one is joined onto base (and cleaned
// again by path.Join). An empty p therefore yields path.Join(base, ".").
func cleanPathWithBase(base, p string) string {
	cleaned := filepath.ToSlash(filepath.Clean(p))
	if path.IsAbs(cleaned) {
		return cleaned
	}
	return path.Join(base, cleaned)
}