From 1afc1d9a7850373c208051c5e4a3c04f13f4fd02 Mon Sep 17 00:00:00 2001 From: John Eikenberry Date: Wed, 25 Jul 2018 14:54:02 -0700 Subject: [PATCH] refactor server response to allow for extending Instead of sendPacket/sendError being sprayed all over the place, this change has all those places instead return a responsePacket (eventually) back to the main handling function which then calls sendPacket in one place. Behaviour of the code should remain exactly the same. This makes it much easier to work with the response packets (eg. for the packet ordering issue I'm working on). --- packet.go | 8 +- server.go | 169 ++++++++++++++++++++-------------------- server_statvfs_impl.go | 8 +- server_statvfs_stubs.go | 4 +- 4 files changed, 95 insertions(+), 94 deletions(-) diff --git a/packet.go b/packet.go index 5183d94..4eec06d 100644 --- a/packet.go +++ b/packet.go @@ -882,9 +882,9 @@ func (p sshFxpExtendedPacket) readonly() bool { return p.SpecificPacket.readonly() } -func (p sshFxpExtendedPacket) respond(svr *Server) error { +func (p sshFxpExtendedPacket) respond(svr *Server) responsePacket { if p.SpecificPacket == nil { - return nil + return statusFromError(p, nil) } return p.SpecificPacket.respond(svr) } @@ -954,7 +954,7 @@ func (p *sshFxpExtendedPacketPosixRename) UnmarshalBinary(b []byte) error { return nil } -func (p sshFxpExtendedPacketPosixRename) respond(s *Server) error { +func (p sshFxpExtendedPacketPosixRename) respond(s *Server) responsePacket { err := os.Rename(p.Oldpath, p.Newpath) - return s.sendError(&p, err) + return statusFromError(p, err) } diff --git a/server.go b/server.go index 9a97ca9..a5dd2cd 100644 --- a/server.go +++ b/server.go @@ -66,7 +66,7 @@ func (svr *Server) getHandle(handle string) (*os.File, bool) { type serverRespondablePacket interface { encoding.BinaryUnmarshaler id() uint32 - respond(svr *Server) error + respond(svr *Server) responsePacket } // NewServer creates a new Server instance around the provided streams, serving @@ -140,7 +140,7 @@ func (svr *Server) sftpServerWorker(pktChan chan requestPacket) error { // If server is operating read-only and a write operation is requested, // return permission denied if !readonly && svr.readOnly { - svr.sendError(pkt, syscall.EPERM) + svr.sendPacket(statusFromError(pkt, syscall.EPERM)) continue } @@ -152,133 +152,138 @@ func (svr *Server) sftpServerWorker(pktChan chan requestPacket) error { } func handlePacket(s *Server, p requestPacket) error { + var rpkt responsePacket switch p := p.(type) { case *sshFxInitPacket: - return s.sendPacket(sshFxVersionPacket{Version: sftpProtocolVersion}) + rpkt = sshFxVersionPacket{Version: sftpProtocolVersion} case *sshFxpStatPacket: // stat the requested file info, err := os.Stat(p.Path) - if err != nil { - return s.sendError(p, err) - } - return s.sendPacket(sshFxpStatResponse{ + rpkt = sshFxpStatResponse{ ID: p.ID, info: info, - }) + } + if err != nil { + rpkt = statusFromError(p, err) + } case *sshFxpLstatPacket: // stat the requested file info, err := os.Lstat(p.Path) - if err != nil { - return s.sendError(p, err) - } - return s.sendPacket(sshFxpStatResponse{ + rpkt = sshFxpStatResponse{ ID: p.ID, info: info, - }) + } + if err != nil { + rpkt = statusFromError(p, err) + } case *sshFxpFstatPacket: + fmt.Println("fstat") f, ok := s.getHandle(p.Handle) - if !ok { - return s.sendError(p, syscall.EBADF) + var err error = syscall.EBADF + var info os.FileInfo + if ok { + info, err = f.Stat() + rpkt = sshFxpStatResponse{ + ID: p.ID, + info: info, + } } - - info, err := f.Stat() if err != nil { - return s.sendError(p, err) + rpkt = statusFromError(p, err) } - - return s.sendPacket(sshFxpStatResponse{ - ID: p.ID, - info: info, - }) case *sshFxpMkdirPacket: // TODO FIXME: ignore flags field err := os.Mkdir(p.Path, 0755) - return s.sendError(p, err) + rpkt = statusFromError(p, err) case *sshFxpRmdirPacket: err := os.Remove(p.Path) - return s.sendError(p, err) + rpkt = statusFromError(p, err) case *sshFxpRemovePacket: err := os.Remove(p.Filename) - return s.sendError(p, err) + rpkt = statusFromError(p, err) case *sshFxpRenamePacket: err := os.Rename(p.Oldpath, p.Newpath) - return s.sendError(p, err) + rpkt = statusFromError(p, err) case *sshFxpSymlinkPacket: err := os.Symlink(p.Targetpath, p.Linkpath) - return s.sendError(p, err) + rpkt = statusFromError(p, err) case *sshFxpClosePacket: - return s.sendError(p, s.closeHandle(p.Handle)) + rpkt = statusFromError(p, s.closeHandle(p.Handle)) case *sshFxpReadlinkPacket: f, err := os.Readlink(p.Path) - if err != nil { - return s.sendError(p, err) - } - - return s.sendPacket(sshFxpNamePacket{ + rpkt = sshFxpNamePacket{ ID: p.ID, NameAttrs: []sshFxpNameAttr{{ Name: f, LongName: f, Attrs: emptyFileStat, }}, - }) - + } + if err != nil { + rpkt = statusFromError(p, err) + } case *sshFxpRealpathPacket: f, err := filepath.Abs(p.Path) - if err != nil { - return s.sendError(p, err) - } f = cleanPath(f) - return s.sendPacket(sshFxpNamePacket{ + rpkt = sshFxpNamePacket{ ID: p.ID, NameAttrs: []sshFxpNameAttr{{ Name: f, LongName: f, Attrs: emptyFileStat, }}, - }) + } + if err != nil { + rpkt = statusFromError(p, err) + } case *sshFxpOpendirPacket: if stat, err := os.Stat(p.Path); err != nil { - return s.sendError(p, err) + rpkt = statusFromError(p, err) } else if !stat.IsDir() { - return s.sendError(p, &os.PathError{ + rpkt = statusFromError(p, &os.PathError{ Path: p.Path, Err: syscall.ENOTDIR}) + } else { + rpkt = sshFxpOpenPacket{ + ID: p.ID, + Path: p.Path, + Pflags: ssh_FXF_READ, + }.respond(s) } - return sshFxpOpenPacket{ - ID: p.ID, - Path: p.Path, - Pflags: ssh_FXF_READ, - }.respond(s) case *sshFxpReadPacket: + var err error = syscall.EBADF f, ok := s.getHandle(p.Handle) - if !ok { - return s.sendError(p, syscall.EBADF) + if ok { + err = nil + data := make([]byte, clamp(p.Len, s.maxTxPacket)) + n, _err := f.ReadAt(data, int64(p.Offset)) + if _err != nil && (_err != io.EOF || n == 0) { + err = _err + } + rpkt = sshFxpDataPacket{ + ID: p.ID, + Length: uint32(n), + Data: data[:n], + } + } + if err != nil { + rpkt = statusFromError(p, err) } - data := make([]byte, clamp(p.Len, s.maxTxPacket)) - n, err := f.ReadAt(data, int64(p.Offset)) - if err != nil && (err != io.EOF || n == 0) { - return s.sendError(p, err) - } - return s.sendPacket(sshFxpDataPacket{ - ID: p.ID, - Length: uint32(n), - Data: data[:n], - }) case *sshFxpWritePacket: f, ok := s.getHandle(p.Handle) - if !ok { - return s.sendError(p, syscall.EBADF) + var err error = syscall.EBADF + if ok { + _, err = f.WriteAt(p.Data, int64(p.Offset)) } - - _, err := f.WriteAt(p.Data, int64(p.Offset)) - return s.sendError(p, err) + rpkt = statusFromError(p, err) case serverRespondablePacket: - err := p.respond(s) - return errors.Wrap(err, "pkt.respond failed") + rpkt = p.respond(s) default: return errors.Errorf("unexpected packet type %T", p) } + + s.sendPacket(rpkt) + return nil } // Serve serves SFTP connections until the streams stop or the SFTP subsystem @@ -342,10 +347,6 @@ func (svr *Server) sendPacket(pkt responsePacket) error { return nil } -func (svr *Server) sendError(p requestPacket, err error) error { - return svr.sendPacket(statusFromError(p, err)) -} - type ider interface { id() uint32 } @@ -380,7 +381,7 @@ func (p sshFxpOpenPacket) hasPflags(flags ...uint32) bool { return true } -func (p sshFxpOpenPacket) respond(svr *Server) error { +func (p sshFxpOpenPacket) respond(svr *Server) responsePacket { var osFlags int if p.hasPflags(ssh_FXF_READ, ssh_FXF_WRITE) { osFlags |= os.O_RDWR @@ -390,7 +391,7 @@ func (p sshFxpOpenPacket) respond(svr *Server) error { osFlags |= os.O_RDONLY } else { // how are they opening? - return svr.sendError(&p, syscall.EINVAL) + return statusFromError(p, syscall.EINVAL) } if p.hasPflags(ssh_FXF_APPEND) { @@ -408,23 +409,23 @@ func (p sshFxpOpenPacket) respond(svr *Server) error { f, err := os.OpenFile(p.Path, osFlags, 0644) if err != nil { - return svr.sendError(&p, err) + return statusFromError(p, err) } handle := svr.nextHandle(f) - return svr.sendPacket(sshFxpHandlePacket{ID: p.id(), Handle: handle}) + return sshFxpHandlePacket{ID: p.id(), Handle: handle} } -func (p sshFxpReaddirPacket) respond(svr *Server) error { +func (p sshFxpReaddirPacket) respond(svr *Server) responsePacket { f, ok := svr.getHandle(p.Handle) if !ok { - return svr.sendError(&p, syscall.EBADF) + return statusFromError(p, syscall.EBADF) } dirname := f.Name() dirents, err := f.Readdir(128) if err != nil { - return svr.sendError(&p, err) + return statusFromError(p, err) } ret := sshFxpNamePacket{ID: p.ID} @@ -435,10 +436,10 @@ func (p sshFxpReaddirPacket) respond(svr *Server) error { Attrs: []interface{}{dirent}, }) } - return svr.sendPacket(ret) + return ret } -func (p sshFxpSetstatPacket) respond(svr *Server) error { +func (p sshFxpSetstatPacket) respond(svr *Server) responsePacket { // additional unmarshalling is required for each possibility here b := p.Attrs.([]byte) var err error @@ -477,13 +478,13 @@ func (p sshFxpSetstatPacket) respond(svr *Server) error { } } - return svr.sendError(&p, err) + return statusFromError(p, err) } -func (p sshFxpFsetstatPacket) respond(svr *Server) error { +func (p sshFxpFsetstatPacket) respond(svr *Server) responsePacket { f, ok := svr.getHandle(p.Handle) if !ok { - return svr.sendError(&p, syscall.EBADF) + return statusFromError(p, syscall.EBADF) } // additional unmarshalling is required for each possibility here @@ -524,7 +525,7 @@ func (p sshFxpFsetstatPacket) respond(svr *Server) error { } } - return svr.sendError(&p, err) + return statusFromError(p, err) } // translateErrno translates a syscall error number to a SFTP error code. diff --git a/server_statvfs_impl.go b/server_statvfs_impl.go index c37a34a..4cf91dc 100644 --- a/server_statvfs_impl.go +++ b/server_statvfs_impl.go @@ -9,17 +9,17 @@ import ( "syscall" ) -func (p sshFxpExtendedPacketStatVFS) respond(svr *Server) error { +func (p sshFxpExtendedPacketStatVFS) respond(svr *Server) responsePacket { stat := &syscall.Statfs_t{} if err := syscall.Statfs(p.Path, stat); err != nil { - return svr.sendPacket(statusFromError(p, err)) + return statusFromError(p, err) } retPkt, err := statvfsFromStatfst(stat) if err != nil { - return svr.sendPacket(statusFromError(p, err)) + return statusFromError(p, err) } retPkt.ID = p.ID - return svr.sendPacket(retPkt) + return retPkt } diff --git a/server_statvfs_stubs.go b/server_statvfs_stubs.go index 3fe4078..c6f6164 100644 --- a/server_statvfs_stubs.go +++ b/server_statvfs_stubs.go @@ -6,6 +6,6 @@ import ( "syscall" ) -func (p sshFxpExtendedPacketStatVFS) respond(svr *Server) error { - return syscall.ENOTSUP +func (p sshFxpExtendedPacketStatVFS) respond(svr *Server) responsePacket { + return statusFromError(p, syscall.ENOTSUP) }