refactor server response to allow for extending

Instead of sendPacket/sendError being sprayed all over the place, this
change has all those places instead return a responsePacket (eventually)
back to the main handling function which then calls sendPacket in one
place.

Behaviour of the code should remain exactly the same.

This makes it much easier to work with the response packets (eg. for the
packet ordering issue I'm working on).
This commit is contained in:
John Eikenberry 2018-07-25 14:54:02 -07:00
parent b50b1f9eaf
commit 1afc1d9a78
4 changed files with 95 additions and 94 deletions

View File

@ -882,9 +882,9 @@ func (p sshFxpExtendedPacket) readonly() bool {
return p.SpecificPacket.readonly()
}
func (p sshFxpExtendedPacket) respond(svr *Server) error {
func (p sshFxpExtendedPacket) respond(svr *Server) responsePacket {
if p.SpecificPacket == nil {
return nil
return statusFromError(p, nil)
}
return p.SpecificPacket.respond(svr)
}
@ -954,7 +954,7 @@ func (p *sshFxpExtendedPacketPosixRename) UnmarshalBinary(b []byte) error {
return nil
}
func (p sshFxpExtendedPacketPosixRename) respond(s *Server) error {
func (p sshFxpExtendedPacketPosixRename) respond(s *Server) responsePacket {
err := os.Rename(p.Oldpath, p.Newpath)
return s.sendError(&p, err)
return statusFromError(p, err)
}

169
server.go
View File

@ -66,7 +66,7 @@ func (svr *Server) getHandle(handle string) (*os.File, bool) {
type serverRespondablePacket interface {
encoding.BinaryUnmarshaler
id() uint32
respond(svr *Server) error
respond(svr *Server) responsePacket
}
// NewServer creates a new Server instance around the provided streams, serving
@ -140,7 +140,7 @@ func (svr *Server) sftpServerWorker(pktChan chan requestPacket) error {
// If server is operating read-only and a write operation is requested,
// return permission denied
if !readonly && svr.readOnly {
svr.sendError(pkt, syscall.EPERM)
svr.sendPacket(statusFromError(pkt, syscall.EPERM))
continue
}
@ -152,133 +152,138 @@ func (svr *Server) sftpServerWorker(pktChan chan requestPacket) error {
}
func handlePacket(s *Server, p requestPacket) error {
var rpkt responsePacket
switch p := p.(type) {
case *sshFxInitPacket:
return s.sendPacket(sshFxVersionPacket{Version: sftpProtocolVersion})
rpkt = sshFxVersionPacket{Version: sftpProtocolVersion}
case *sshFxpStatPacket:
// stat the requested file
info, err := os.Stat(p.Path)
if err != nil {
return s.sendError(p, err)
}
return s.sendPacket(sshFxpStatResponse{
rpkt = sshFxpStatResponse{
ID: p.ID,
info: info,
})
}
if err != nil {
rpkt = statusFromError(p, err)
}
case *sshFxpLstatPacket:
// stat the requested file
info, err := os.Lstat(p.Path)
if err != nil {
return s.sendError(p, err)
}
return s.sendPacket(sshFxpStatResponse{
rpkt = sshFxpStatResponse{
ID: p.ID,
info: info,
})
}
if err != nil {
rpkt = statusFromError(p, err)
}
case *sshFxpFstatPacket:
fmt.Println("fstat")
f, ok := s.getHandle(p.Handle)
if !ok {
return s.sendError(p, syscall.EBADF)
var err error = syscall.EBADF
var info os.FileInfo
if ok {
info, err = f.Stat()
rpkt = sshFxpStatResponse{
ID: p.ID,
info: info,
}
}
info, err := f.Stat()
if err != nil {
return s.sendError(p, err)
rpkt = statusFromError(p, err)
}
return s.sendPacket(sshFxpStatResponse{
ID: p.ID,
info: info,
})
case *sshFxpMkdirPacket:
// TODO FIXME: ignore flags field
err := os.Mkdir(p.Path, 0755)
return s.sendError(p, err)
rpkt = statusFromError(p, err)
case *sshFxpRmdirPacket:
err := os.Remove(p.Path)
return s.sendError(p, err)
rpkt = statusFromError(p, err)
case *sshFxpRemovePacket:
err := os.Remove(p.Filename)
return s.sendError(p, err)
rpkt = statusFromError(p, err)
case *sshFxpRenamePacket:
err := os.Rename(p.Oldpath, p.Newpath)
return s.sendError(p, err)
rpkt = statusFromError(p, err)
case *sshFxpSymlinkPacket:
err := os.Symlink(p.Targetpath, p.Linkpath)
return s.sendError(p, err)
rpkt = statusFromError(p, err)
case *sshFxpClosePacket:
return s.sendError(p, s.closeHandle(p.Handle))
rpkt = statusFromError(p, s.closeHandle(p.Handle))
case *sshFxpReadlinkPacket:
f, err := os.Readlink(p.Path)
if err != nil {
return s.sendError(p, err)
}
return s.sendPacket(sshFxpNamePacket{
rpkt = sshFxpNamePacket{
ID: p.ID,
NameAttrs: []sshFxpNameAttr{{
Name: f,
LongName: f,
Attrs: emptyFileStat,
}},
})
}
if err != nil {
rpkt = statusFromError(p, err)
}
case *sshFxpRealpathPacket:
f, err := filepath.Abs(p.Path)
if err != nil {
return s.sendError(p, err)
}
f = cleanPath(f)
return s.sendPacket(sshFxpNamePacket{
rpkt = sshFxpNamePacket{
ID: p.ID,
NameAttrs: []sshFxpNameAttr{{
Name: f,
LongName: f,
Attrs: emptyFileStat,
}},
})
}
if err != nil {
rpkt = statusFromError(p, err)
}
case *sshFxpOpendirPacket:
if stat, err := os.Stat(p.Path); err != nil {
return s.sendError(p, err)
rpkt = statusFromError(p, err)
} else if !stat.IsDir() {
return s.sendError(p, &os.PathError{
rpkt = statusFromError(p, &os.PathError{
Path: p.Path, Err: syscall.ENOTDIR})
} else {
rpkt = sshFxpOpenPacket{
ID: p.ID,
Path: p.Path,
Pflags: ssh_FXF_READ,
}.respond(s)
}
return sshFxpOpenPacket{
ID: p.ID,
Path: p.Path,
Pflags: ssh_FXF_READ,
}.respond(s)
case *sshFxpReadPacket:
var err error = syscall.EBADF
f, ok := s.getHandle(p.Handle)
if !ok {
return s.sendError(p, syscall.EBADF)
if ok {
err = nil
data := make([]byte, clamp(p.Len, s.maxTxPacket))
n, _err := f.ReadAt(data, int64(p.Offset))
if _err != nil && (_err != io.EOF || n == 0) {
err = _err
}
rpkt = sshFxpDataPacket{
ID: p.ID,
Length: uint32(n),
Data: data[:n],
}
}
if err != nil {
rpkt = statusFromError(p, err)
}
data := make([]byte, clamp(p.Len, s.maxTxPacket))
n, err := f.ReadAt(data, int64(p.Offset))
if err != nil && (err != io.EOF || n == 0) {
return s.sendError(p, err)
}
return s.sendPacket(sshFxpDataPacket{
ID: p.ID,
Length: uint32(n),
Data: data[:n],
})
case *sshFxpWritePacket:
f, ok := s.getHandle(p.Handle)
if !ok {
return s.sendError(p, syscall.EBADF)
var err error = syscall.EBADF
if ok {
_, err = f.WriteAt(p.Data, int64(p.Offset))
}
_, err := f.WriteAt(p.Data, int64(p.Offset))
return s.sendError(p, err)
rpkt = statusFromError(p, err)
case serverRespondablePacket:
err := p.respond(s)
return errors.Wrap(err, "pkt.respond failed")
rpkt = p.respond(s)
default:
return errors.Errorf("unexpected packet type %T", p)
}
s.sendPacket(rpkt)
return nil
}
// Serve serves SFTP connections until the streams stop or the SFTP subsystem
@ -342,10 +347,6 @@ func (svr *Server) sendPacket(pkt responsePacket) error {
return nil
}
func (svr *Server) sendError(p requestPacket, err error) error {
return svr.sendPacket(statusFromError(p, err))
}
type ider interface {
id() uint32
}
@ -380,7 +381,7 @@ func (p sshFxpOpenPacket) hasPflags(flags ...uint32) bool {
return true
}
func (p sshFxpOpenPacket) respond(svr *Server) error {
func (p sshFxpOpenPacket) respond(svr *Server) responsePacket {
var osFlags int
if p.hasPflags(ssh_FXF_READ, ssh_FXF_WRITE) {
osFlags |= os.O_RDWR
@ -390,7 +391,7 @@ func (p sshFxpOpenPacket) respond(svr *Server) error {
osFlags |= os.O_RDONLY
} else {
// how are they opening?
return svr.sendError(&p, syscall.EINVAL)
return statusFromError(p, syscall.EINVAL)
}
if p.hasPflags(ssh_FXF_APPEND) {
@ -408,23 +409,23 @@ func (p sshFxpOpenPacket) respond(svr *Server) error {
f, err := os.OpenFile(p.Path, osFlags, 0644)
if err != nil {
return svr.sendError(&p, err)
return statusFromError(p, err)
}
handle := svr.nextHandle(f)
return svr.sendPacket(sshFxpHandlePacket{ID: p.id(), Handle: handle})
return sshFxpHandlePacket{ID: p.id(), Handle: handle}
}
func (p sshFxpReaddirPacket) respond(svr *Server) error {
func (p sshFxpReaddirPacket) respond(svr *Server) responsePacket {
f, ok := svr.getHandle(p.Handle)
if !ok {
return svr.sendError(&p, syscall.EBADF)
return statusFromError(p, syscall.EBADF)
}
dirname := f.Name()
dirents, err := f.Readdir(128)
if err != nil {
return svr.sendError(&p, err)
return statusFromError(p, err)
}
ret := sshFxpNamePacket{ID: p.ID}
@ -435,10 +436,10 @@ func (p sshFxpReaddirPacket) respond(svr *Server) error {
Attrs: []interface{}{dirent},
})
}
return svr.sendPacket(ret)
return ret
}
func (p sshFxpSetstatPacket) respond(svr *Server) error {
func (p sshFxpSetstatPacket) respond(svr *Server) responsePacket {
// additional unmarshalling is required for each possibility here
b := p.Attrs.([]byte)
var err error
@ -477,13 +478,13 @@ func (p sshFxpSetstatPacket) respond(svr *Server) error {
}
}
return svr.sendError(&p, err)
return statusFromError(p, err)
}
func (p sshFxpFsetstatPacket) respond(svr *Server) error {
func (p sshFxpFsetstatPacket) respond(svr *Server) responsePacket {
f, ok := svr.getHandle(p.Handle)
if !ok {
return svr.sendError(&p, syscall.EBADF)
return statusFromError(p, syscall.EBADF)
}
// additional unmarshalling is required for each possibility here
@ -524,7 +525,7 @@ func (p sshFxpFsetstatPacket) respond(svr *Server) error {
}
}
return svr.sendError(&p, err)
return statusFromError(p, err)
}
// translateErrno translates a syscall error number to a SFTP error code.

View File

@ -9,17 +9,17 @@ import (
"syscall"
)
func (p sshFxpExtendedPacketStatVFS) respond(svr *Server) error {
func (p sshFxpExtendedPacketStatVFS) respond(svr *Server) responsePacket {
stat := &syscall.Statfs_t{}
if err := syscall.Statfs(p.Path, stat); err != nil {
return svr.sendPacket(statusFromError(p, err))
return statusFromError(p, err)
}
retPkt, err := statvfsFromStatfst(stat)
if err != nil {
return svr.sendPacket(statusFromError(p, err))
return statusFromError(p, err)
}
retPkt.ID = p.ID
return svr.sendPacket(retPkt)
return retPkt
}

View File

@ -6,6 +6,6 @@ import (
"syscall"
)
func (p sshFxpExtendedPacketStatVFS) respond(svr *Server) error {
return syscall.ENOTSUP
func (p sshFxpExtendedPacketStatVFS) respond(svr *Server) responsePacket {
return statusFromError(p, syscall.ENOTSUP)
}