refactor server response to allow for extending

Instead of sendPacket/sendError being sprayed all over the place, this
change has all those places instead return a responsePacket (eventually)
back to the main handling function which then calls sendPacket in one
place.

Behaviour of the code should remain exactly the same.

This makes it much easier to work with the response packets (eg. for the
packet ordering issue I'm working on).
This commit is contained in:
John Eikenberry 2018-07-25 14:54:02 -07:00
parent b50b1f9eaf
commit 1afc1d9a78
4 changed files with 95 additions and 94 deletions

View File

@ -882,9 +882,9 @@ func (p sshFxpExtendedPacket) readonly() bool {
return p.SpecificPacket.readonly() return p.SpecificPacket.readonly()
} }
func (p sshFxpExtendedPacket) respond(svr *Server) error { func (p sshFxpExtendedPacket) respond(svr *Server) responsePacket {
if p.SpecificPacket == nil { if p.SpecificPacket == nil {
return nil return statusFromError(p, nil)
} }
return p.SpecificPacket.respond(svr) return p.SpecificPacket.respond(svr)
} }
@ -954,7 +954,7 @@ func (p *sshFxpExtendedPacketPosixRename) UnmarshalBinary(b []byte) error {
return nil return nil
} }
func (p sshFxpExtendedPacketPosixRename) respond(s *Server) error { func (p sshFxpExtendedPacketPosixRename) respond(s *Server) responsePacket {
err := os.Rename(p.Oldpath, p.Newpath) err := os.Rename(p.Oldpath, p.Newpath)
return s.sendError(&p, err) return statusFromError(p, err)
} }

169
server.go
View File

@ -66,7 +66,7 @@ func (svr *Server) getHandle(handle string) (*os.File, bool) {
type serverRespondablePacket interface { type serverRespondablePacket interface {
encoding.BinaryUnmarshaler encoding.BinaryUnmarshaler
id() uint32 id() uint32
respond(svr *Server) error respond(svr *Server) responsePacket
} }
// NewServer creates a new Server instance around the provided streams, serving // NewServer creates a new Server instance around the provided streams, serving
@ -140,7 +140,7 @@ func (svr *Server) sftpServerWorker(pktChan chan requestPacket) error {
// If server is operating read-only and a write operation is requested, // If server is operating read-only and a write operation is requested,
// return permission denied // return permission denied
if !readonly && svr.readOnly { if !readonly && svr.readOnly {
svr.sendError(pkt, syscall.EPERM) svr.sendPacket(statusFromError(pkt, syscall.EPERM))
continue continue
} }
@ -152,133 +152,138 @@ func (svr *Server) sftpServerWorker(pktChan chan requestPacket) error {
} }
func handlePacket(s *Server, p requestPacket) error { func handlePacket(s *Server, p requestPacket) error {
var rpkt responsePacket
switch p := p.(type) { switch p := p.(type) {
case *sshFxInitPacket: case *sshFxInitPacket:
return s.sendPacket(sshFxVersionPacket{Version: sftpProtocolVersion}) rpkt = sshFxVersionPacket{Version: sftpProtocolVersion}
case *sshFxpStatPacket: case *sshFxpStatPacket:
// stat the requested file // stat the requested file
info, err := os.Stat(p.Path) info, err := os.Stat(p.Path)
if err != nil { rpkt = sshFxpStatResponse{
return s.sendError(p, err)
}
return s.sendPacket(sshFxpStatResponse{
ID: p.ID, ID: p.ID,
info: info, info: info,
}) }
if err != nil {
rpkt = statusFromError(p, err)
}
case *sshFxpLstatPacket: case *sshFxpLstatPacket:
// stat the requested file // stat the requested file
info, err := os.Lstat(p.Path) info, err := os.Lstat(p.Path)
if err != nil { rpkt = sshFxpStatResponse{
return s.sendError(p, err)
}
return s.sendPacket(sshFxpStatResponse{
ID: p.ID, ID: p.ID,
info: info, info: info,
}) }
if err != nil {
rpkt = statusFromError(p, err)
}
case *sshFxpFstatPacket: case *sshFxpFstatPacket:
fmt.Println("fstat")
f, ok := s.getHandle(p.Handle) f, ok := s.getHandle(p.Handle)
if !ok { var err error = syscall.EBADF
return s.sendError(p, syscall.EBADF) var info os.FileInfo
if ok {
info, err = f.Stat()
rpkt = sshFxpStatResponse{
ID: p.ID,
info: info,
}
} }
info, err := f.Stat()
if err != nil { if err != nil {
return s.sendError(p, err) rpkt = statusFromError(p, err)
} }
return s.sendPacket(sshFxpStatResponse{
ID: p.ID,
info: info,
})
case *sshFxpMkdirPacket: case *sshFxpMkdirPacket:
// TODO FIXME: ignore flags field // TODO FIXME: ignore flags field
err := os.Mkdir(p.Path, 0755) err := os.Mkdir(p.Path, 0755)
return s.sendError(p, err) rpkt = statusFromError(p, err)
case *sshFxpRmdirPacket: case *sshFxpRmdirPacket:
err := os.Remove(p.Path) err := os.Remove(p.Path)
return s.sendError(p, err) rpkt = statusFromError(p, err)
case *sshFxpRemovePacket: case *sshFxpRemovePacket:
err := os.Remove(p.Filename) err := os.Remove(p.Filename)
return s.sendError(p, err) rpkt = statusFromError(p, err)
case *sshFxpRenamePacket: case *sshFxpRenamePacket:
err := os.Rename(p.Oldpath, p.Newpath) err := os.Rename(p.Oldpath, p.Newpath)
return s.sendError(p, err) rpkt = statusFromError(p, err)
case *sshFxpSymlinkPacket: case *sshFxpSymlinkPacket:
err := os.Symlink(p.Targetpath, p.Linkpath) err := os.Symlink(p.Targetpath, p.Linkpath)
return s.sendError(p, err) rpkt = statusFromError(p, err)
case *sshFxpClosePacket: case *sshFxpClosePacket:
return s.sendError(p, s.closeHandle(p.Handle)) rpkt = statusFromError(p, s.closeHandle(p.Handle))
case *sshFxpReadlinkPacket: case *sshFxpReadlinkPacket:
f, err := os.Readlink(p.Path) f, err := os.Readlink(p.Path)
if err != nil { rpkt = sshFxpNamePacket{
return s.sendError(p, err)
}
return s.sendPacket(sshFxpNamePacket{
ID: p.ID, ID: p.ID,
NameAttrs: []sshFxpNameAttr{{ NameAttrs: []sshFxpNameAttr{{
Name: f, Name: f,
LongName: f, LongName: f,
Attrs: emptyFileStat, Attrs: emptyFileStat,
}}, }},
}) }
if err != nil {
rpkt = statusFromError(p, err)
}
case *sshFxpRealpathPacket: case *sshFxpRealpathPacket:
f, err := filepath.Abs(p.Path) f, err := filepath.Abs(p.Path)
if err != nil {
return s.sendError(p, err)
}
f = cleanPath(f) f = cleanPath(f)
return s.sendPacket(sshFxpNamePacket{ rpkt = sshFxpNamePacket{
ID: p.ID, ID: p.ID,
NameAttrs: []sshFxpNameAttr{{ NameAttrs: []sshFxpNameAttr{{
Name: f, Name: f,
LongName: f, LongName: f,
Attrs: emptyFileStat, Attrs: emptyFileStat,
}}, }},
}) }
if err != nil {
rpkt = statusFromError(p, err)
}
case *sshFxpOpendirPacket: case *sshFxpOpendirPacket:
if stat, err := os.Stat(p.Path); err != nil { if stat, err := os.Stat(p.Path); err != nil {
return s.sendError(p, err) rpkt = statusFromError(p, err)
} else if !stat.IsDir() { } else if !stat.IsDir() {
return s.sendError(p, &os.PathError{ rpkt = statusFromError(p, &os.PathError{
Path: p.Path, Err: syscall.ENOTDIR}) Path: p.Path, Err: syscall.ENOTDIR})
} else {
rpkt = sshFxpOpenPacket{
ID: p.ID,
Path: p.Path,
Pflags: ssh_FXF_READ,
}.respond(s)
} }
return sshFxpOpenPacket{
ID: p.ID,
Path: p.Path,
Pflags: ssh_FXF_READ,
}.respond(s)
case *sshFxpReadPacket: case *sshFxpReadPacket:
var err error = syscall.EBADF
f, ok := s.getHandle(p.Handle) f, ok := s.getHandle(p.Handle)
if !ok { if ok {
return s.sendError(p, syscall.EBADF) err = nil
data := make([]byte, clamp(p.Len, s.maxTxPacket))
n, _err := f.ReadAt(data, int64(p.Offset))
if _err != nil && (_err != io.EOF || n == 0) {
err = _err
}
rpkt = sshFxpDataPacket{
ID: p.ID,
Length: uint32(n),
Data: data[:n],
}
}
if err != nil {
rpkt = statusFromError(p, err)
} }
data := make([]byte, clamp(p.Len, s.maxTxPacket))
n, err := f.ReadAt(data, int64(p.Offset))
if err != nil && (err != io.EOF || n == 0) {
return s.sendError(p, err)
}
return s.sendPacket(sshFxpDataPacket{
ID: p.ID,
Length: uint32(n),
Data: data[:n],
})
case *sshFxpWritePacket: case *sshFxpWritePacket:
f, ok := s.getHandle(p.Handle) f, ok := s.getHandle(p.Handle)
if !ok { var err error = syscall.EBADF
return s.sendError(p, syscall.EBADF) if ok {
_, err = f.WriteAt(p.Data, int64(p.Offset))
} }
rpkt = statusFromError(p, err)
_, err := f.WriteAt(p.Data, int64(p.Offset))
return s.sendError(p, err)
case serverRespondablePacket: case serverRespondablePacket:
err := p.respond(s) rpkt = p.respond(s)
return errors.Wrap(err, "pkt.respond failed")
default: default:
return errors.Errorf("unexpected packet type %T", p) return errors.Errorf("unexpected packet type %T", p)
} }
s.sendPacket(rpkt)
return nil
} }
// Serve serves SFTP connections until the streams stop or the SFTP subsystem // Serve serves SFTP connections until the streams stop or the SFTP subsystem
@ -342,10 +347,6 @@ func (svr *Server) sendPacket(pkt responsePacket) error {
return nil return nil
} }
func (svr *Server) sendError(p requestPacket, err error) error {
return svr.sendPacket(statusFromError(p, err))
}
type ider interface { type ider interface {
id() uint32 id() uint32
} }
@ -380,7 +381,7 @@ func (p sshFxpOpenPacket) hasPflags(flags ...uint32) bool {
return true return true
} }
func (p sshFxpOpenPacket) respond(svr *Server) error { func (p sshFxpOpenPacket) respond(svr *Server) responsePacket {
var osFlags int var osFlags int
if p.hasPflags(ssh_FXF_READ, ssh_FXF_WRITE) { if p.hasPflags(ssh_FXF_READ, ssh_FXF_WRITE) {
osFlags |= os.O_RDWR osFlags |= os.O_RDWR
@ -390,7 +391,7 @@ func (p sshFxpOpenPacket) respond(svr *Server) error {
osFlags |= os.O_RDONLY osFlags |= os.O_RDONLY
} else { } else {
// how are they opening? // how are they opening?
return svr.sendError(&p, syscall.EINVAL) return statusFromError(p, syscall.EINVAL)
} }
if p.hasPflags(ssh_FXF_APPEND) { if p.hasPflags(ssh_FXF_APPEND) {
@ -408,23 +409,23 @@ func (p sshFxpOpenPacket) respond(svr *Server) error {
f, err := os.OpenFile(p.Path, osFlags, 0644) f, err := os.OpenFile(p.Path, osFlags, 0644)
if err != nil { if err != nil {
return svr.sendError(&p, err) return statusFromError(p, err)
} }
handle := svr.nextHandle(f) handle := svr.nextHandle(f)
return svr.sendPacket(sshFxpHandlePacket{ID: p.id(), Handle: handle}) return sshFxpHandlePacket{ID: p.id(), Handle: handle}
} }
func (p sshFxpReaddirPacket) respond(svr *Server) error { func (p sshFxpReaddirPacket) respond(svr *Server) responsePacket {
f, ok := svr.getHandle(p.Handle) f, ok := svr.getHandle(p.Handle)
if !ok { if !ok {
return svr.sendError(&p, syscall.EBADF) return statusFromError(p, syscall.EBADF)
} }
dirname := f.Name() dirname := f.Name()
dirents, err := f.Readdir(128) dirents, err := f.Readdir(128)
if err != nil { if err != nil {
return svr.sendError(&p, err) return statusFromError(p, err)
} }
ret := sshFxpNamePacket{ID: p.ID} ret := sshFxpNamePacket{ID: p.ID}
@ -435,10 +436,10 @@ func (p sshFxpReaddirPacket) respond(svr *Server) error {
Attrs: []interface{}{dirent}, Attrs: []interface{}{dirent},
}) })
} }
return svr.sendPacket(ret) return ret
} }
func (p sshFxpSetstatPacket) respond(svr *Server) error { func (p sshFxpSetstatPacket) respond(svr *Server) responsePacket {
// additional unmarshalling is required for each possibility here // additional unmarshalling is required for each possibility here
b := p.Attrs.([]byte) b := p.Attrs.([]byte)
var err error var err error
@ -477,13 +478,13 @@ func (p sshFxpSetstatPacket) respond(svr *Server) error {
} }
} }
return svr.sendError(&p, err) return statusFromError(p, err)
} }
func (p sshFxpFsetstatPacket) respond(svr *Server) error { func (p sshFxpFsetstatPacket) respond(svr *Server) responsePacket {
f, ok := svr.getHandle(p.Handle) f, ok := svr.getHandle(p.Handle)
if !ok { if !ok {
return svr.sendError(&p, syscall.EBADF) return statusFromError(p, syscall.EBADF)
} }
// additional unmarshalling is required for each possibility here // additional unmarshalling is required for each possibility here
@ -524,7 +525,7 @@ func (p sshFxpFsetstatPacket) respond(svr *Server) error {
} }
} }
return svr.sendError(&p, err) return statusFromError(p, err)
} }
// translateErrno translates a syscall error number to a SFTP error code. // translateErrno translates a syscall error number to a SFTP error code.

View File

@ -9,17 +9,17 @@ import (
"syscall" "syscall"
) )
func (p sshFxpExtendedPacketStatVFS) respond(svr *Server) error { func (p sshFxpExtendedPacketStatVFS) respond(svr *Server) responsePacket {
stat := &syscall.Statfs_t{} stat := &syscall.Statfs_t{}
if err := syscall.Statfs(p.Path, stat); err != nil { if err := syscall.Statfs(p.Path, stat); err != nil {
return svr.sendPacket(statusFromError(p, err)) return statusFromError(p, err)
} }
retPkt, err := statvfsFromStatfst(stat) retPkt, err := statvfsFromStatfst(stat)
if err != nil { if err != nil {
return svr.sendPacket(statusFromError(p, err)) return statusFromError(p, err)
} }
retPkt.ID = p.ID retPkt.ID = p.ID
return svr.sendPacket(retPkt) return retPkt
} }

View File

@ -6,6 +6,6 @@ import (
"syscall" "syscall"
) )
func (p sshFxpExtendedPacketStatVFS) respond(svr *Server) error { func (p sshFxpExtendedPacketStatVFS) respond(svr *Server) responsePacket {
return syscall.ENOTSUP return statusFromError(p, syscall.ENOTSUP)
} }