sftp/request-server.go

210 lines
4.9 KiB
Go
Raw Normal View History

package sftp
import (
2017-03-15 09:02:17 +08:00
"encoding"
"io"
"os"
"path/filepath"
"strconv"
"sync"
2016-07-09 03:38:35 +08:00
"syscall"
"github.com/pkg/errors"
)
2016-07-09 03:38:35 +08:00
// maxTxPacket is 1<<15 (32768).
// NOTE(review): presumably the upper bound, in bytes, on the payload of a
// single outgoing packet; it is not referenced in this part of the file, so
// confirm against its users.
var maxTxPacket uint32 = 1 << 15
2016-07-09 03:38:35 +08:00
// handleHandler is a function type that maps one string to another.
// NOTE(review): it is not referenced in this part of the file; presumably the
// argument and result are file handles — confirm against its users.
type handleHandler func(string) string
2016-07-26 02:42:18 +08:00
// Handlers contains the 4 SFTP server request handlers.
2016-07-09 03:38:35 +08:00
type Handlers struct {
FileGet FileReader
FilePut FileWriter
FileCmd FileCmder
FileInfo FileInfoer
2016-07-09 03:38:35 +08:00
}
2016-07-26 02:42:18 +08:00
// RequestServer abstracts the sftp protocol with an http request-like protocol
type RequestServer struct {
serverConn
2016-07-12 11:19:49 +08:00
Handlers Handlers
2017-03-14 09:24:32 +08:00
pktChan chan requestPacket
2017-03-15 09:02:17 +08:00
pktMgr packetManager
openRequests map[string]Request
2016-07-09 03:38:35 +08:00
openRequestLock sync.RWMutex
handleCount int
}
2016-07-26 02:42:18 +08:00
// NewRequestServer creates/allocates/returns new RequestServer.
// Normally there there will be one server per user-session.
func NewRequestServer(rwc io.ReadWriteCloser, h Handlers) *RequestServer {
2017-03-15 09:02:17 +08:00
svrConn := serverConn{
conn: conn{
Reader: rwc,
WriteCloser: rwc,
},
2017-03-15 09:02:17 +08:00
}
return &RequestServer{
serverConn: svrConn,
Handlers: h,
2017-03-14 09:24:32 +08:00
pktChan: make(chan requestPacket, sftpServerWorkerCount),
2017-03-15 09:02:17 +08:00
pktMgr: newPktMgr(&svrConn),
openRequests: make(map[string]Request),
}
}
func (rs *RequestServer) nextRequest(r Request) string {
2016-07-09 03:38:35 +08:00
rs.openRequestLock.Lock()
defer rs.openRequestLock.Unlock()
rs.handleCount++
handle := strconv.Itoa(rs.handleCount)
rs.openRequests[handle] = r
return handle
2016-07-09 03:38:35 +08:00
}
func (rs *RequestServer) getRequest(handle string) (Request, bool) {
2016-08-02 11:22:06 +08:00
rs.openRequestLock.RLock()
defer rs.openRequestLock.RUnlock()
2016-07-09 03:38:35 +08:00
r, ok := rs.openRequests[handle]
return r, ok
}
func (rs *RequestServer) closeRequest(handle string) {
rs.openRequestLock.Lock()
defer rs.openRequestLock.Unlock()
if r, ok := rs.openRequests[handle]; ok {
r.close()
2016-07-09 03:38:35 +08:00
delete(rs.openRequests, handle)
}
}
2016-07-26 02:42:18 +08:00
// Close the read/write/closer to trigger exiting the main server loop
2016-07-23 07:20:00 +08:00
func (rs *RequestServer) Close() error { return rs.conn.Close() }
2016-07-21 07:48:31 +08:00
2016-07-26 02:42:18 +08:00
// Serve requests for user session
2016-07-09 03:38:35 +08:00
func (rs *RequestServer) Serve() error {
var wg sync.WaitGroup
wg.Add(sftpServerWorkerCount)
for i := 0; i < sftpServerWorkerCount; i++ {
go func() {
defer wg.Done()
2016-07-09 03:38:35 +08:00
if err := rs.packetWorker(); err != nil {
rs.conn.Close() // shuts down recvPacket
}
}()
}
var err error
2017-03-15 09:02:17 +08:00
var pkt requestPacket
var pktType uint8
var pktBytes []byte
for {
2016-07-09 03:38:35 +08:00
pktType, pktBytes, err = rs.recvPacket()
2016-07-19 02:50:45 +08:00
if err != nil {
break
}
2017-03-15 09:02:17 +08:00
pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes})
2016-07-19 02:50:45 +08:00
if err != nil {
2017-03-15 09:02:17 +08:00
debug("makePacket err: %v", err)
rs.conn.Close() // shuts down recvPacket
2016-07-19 02:50:45 +08:00
break
}
2017-03-15 09:02:17 +08:00
rs.pktMgr.incomingPacket(pkt)
2016-07-12 11:19:49 +08:00
rs.pktChan <- pkt
}
2016-07-09 03:38:35 +08:00
close(rs.pktChan) // shuts down sftpServerWorkers
wg.Wait() // wait for all workers to exit
2017-03-15 09:02:17 +08:00
rs.pktMgr.close() // shuts down packetManager
return err
}
2016-07-09 03:38:35 +08:00
func (rs *RequestServer) packetWorker() error {
2016-07-12 11:19:49 +08:00
for pkt := range rs.pktChan {
2016-07-26 02:52:07 +08:00
var rpkt responsePacket
switch pkt := pkt.(type) {
case *sshFxInitPacket:
2016-07-13 08:36:12 +08:00
rpkt = sshFxVersionPacket{sftpProtocolVersion, nil}
2016-07-12 11:19:49 +08:00
case *sshFxpClosePacket:
handle := pkt.getHandle()
2016-07-12 11:19:49 +08:00
rs.closeRequest(handle)
2016-07-13 08:36:12 +08:00
rpkt = statusFromError(pkt, nil)
case *sshFxpRealpathPacket:
rpkt = cleanPath(pkt)
case isOpener:
handle := rs.nextRequest(requestFromPacket(pkt))
rpkt = sshFxpHandlePacket{pkt.id(), handle}
2016-07-13 08:36:12 +08:00
case hasHandle:
handle := pkt.getHandle()
request, ok := rs.getRequest(handle)
request.update(pkt)
if !ok {
rpkt = statusFromError(pkt, syscall.EBADF)
} else {
rpkt = rs.handle(request, pkt)
}
case hasPath:
request := requestFromPacket(pkt)
rpkt = rs.handle(request, pkt)
default:
return errors.Errorf("unexpected packet type %T", pkt)
}
2016-07-12 11:19:49 +08:00
err := rs.sendPacket(rpkt)
2016-07-19 02:50:45 +08:00
if err != nil {
return err
}
}
return nil
}
2016-07-13 08:36:12 +08:00
2016-07-26 02:52:07 +08:00
func cleanPath(pkt *sshFxpRealpathPacket) responsePacket {
2016-07-19 03:40:41 +08:00
path := pkt.getPath()
if !filepath.IsAbs(path) {
2016-07-23 07:20:00 +08:00
path = "/" + path
} // all paths are absolute
2016-07-19 03:40:41 +08:00
cleaned_path := filepath.Clean(path)
return &sshFxpNamePacket{
ID: pkt.id(),
NameAttrs: []sshFxpNameAttr{{
Name: cleaned_path,
LongName: cleaned_path,
Attrs: emptyFileStat,
}},
}
}
2017-03-14 09:24:32 +08:00
func (rs *RequestServer) handle(request Request, pkt requestPacket) responsePacket {
// fmt.Println("Request Method: ", request.Method)
rpkt, err := request.handle(rs.Handlers)
if err != nil {
err = errorAdapter(err)
rpkt = statusFromError(pkt, err)
2016-07-19 02:50:45 +08:00
}
2016-07-13 08:36:12 +08:00
return rpkt
}
2017-03-15 09:02:17 +08:00
// Wrap underlying connection methods to use packetManager
func (rs *RequestServer) sendPacket(m encoding.BinaryMarshaler) error {
2017-03-15 09:02:17 +08:00
if pkt, ok := m.(responsePacket); ok {
rs.pktMgr.readyPacket(pkt)
} else {
return errors.Errorf("unexpected packet type %T", m)
}
return nil
}
func (rs *RequestServer) sendError(p ider, err error) error {
2017-03-15 09:02:17 +08:00
return rs.sendPacket(statusFromError(p, err))
}
// errorAdapter maps os.ErrNotExist to syscall.ENOENT and returns every other
// error unchanged.
//
// os.ErrNotExist should convert to ssh_FX_NO_SUCH_FILE, but statusFromError
// does not recognize it; it does recognize syscall.ENOENT.
func errorAdapter(err error) error {
	switch err {
	case os.ErrNotExist:
		return syscall.ENOENT
	default:
		return err
	}
}