diff --git a/server.go b/server.go index 901977f..2a3df5e 100644 --- a/server.go +++ b/server.go @@ -122,7 +122,10 @@ type rxPacket struct { // Up to N parallel servers func (svr *Server) sftpServerWorker() error { for p := range svr.pktChan { - var pkt serverRespondablePacket + var pkt interface { + encoding.BinaryUnmarshaler + id() uint32 + } var readonly = true switch p.pktType { case ssh_FXP_INIT: @@ -181,7 +184,7 @@ func (svr *Server) sftpServerWorker() error { return err } - // handle SFP_OPENDIR specially + // handle FXP_OPENDIR specially switch pkt := pkt.(type) { case *sshFxpOpenPacket: readonly = pkt.readonly() @@ -198,14 +201,113 @@ func (svr *Server) sftpServerWorker() error { continue } - if err := pkt.respond(svr); err != nil { - return errors.Wrap(err, "pkt.respond failed") + if err := handlePacket(svr, pkt); err != nil { + return err } - } return nil } +func handlePacket(s *Server, p interface{}) error { + switch p := p.(type) { + case *sshFxInitPacket: + return s.sendPacket(sshFxVersionPacket{sftpProtocolVersion, nil}) + case *sshFxpStatPacket: + // stat the requested file + info, err := os.Stat(p.Path) + if err != nil { + return s.sendError(p, err) + } + return s.sendPacket(sshFxpStatResponse{ + ID: p.ID, + info: info, + }) + case *sshFxpLstatPacket: + // stat the requested file + info, err := os.Lstat(p.Path) + if err != nil { + return s.sendError(p, err) + } + return s.sendPacket(sshFxpStatResponse{ + ID: p.ID, + info: info, + }) + case *sshFxpFstatPacket: + f, ok := s.getHandle(p.Handle) + if !ok { + return s.sendError(p, syscall.EBADF) + } + + info, err := f.Stat() + if err != nil { + return s.sendError(p, err) + } + + return s.sendPacket(sshFxpStatResponse{ + ID: p.ID, + info: info, + }) + case *sshFxpMkdirPacket: + // TODO FIXME: ignore flags field + err := os.Mkdir(p.Path, 0755) + return s.sendError(p, err) + case *sshFxpRmdirPacket: + err := os.Remove(p.Path) + return s.sendError(p, err) + case *sshFxpRemovePacket: + err := os.Remove(p.Filename) + return s.sendError(p, err) + case *sshFxpRenamePacket: + err := os.Rename(p.Oldpath, p.Newpath) + return s.sendError(p, err) + case *sshFxpSymlinkPacket: + err := os.Symlink(p.Targetpath, p.Linkpath) + return s.sendError(p, err) + case *sshFxpClosePacket: + return s.sendError(p, s.closeHandle(p.Handle)) + case *sshFxpReadlinkPacket: + f, err := os.Readlink(p.Path) + if err != nil { + return s.sendError(p, err) + } + + return s.sendPacket(sshFxpNamePacket{ + ID: p.ID, + NameAttrs: []sshFxpNameAttr{{ + Name: f, + LongName: f, + Attrs: emptyFileStat, + }}, + }) + + case *sshFxpRealpathPacket: + f, err := filepath.Abs(p.Path) + if err != nil { + return s.sendError(p, err) + } + f = filepath.Clean(f) + return s.sendPacket(sshFxpNamePacket{ + ID: p.ID, + NameAttrs: []sshFxpNameAttr{{ + Name: f, + LongName: f, + Attrs: emptyFileStat, + }}, + }) + case *sshFxpOpendirPacket: + return sshFxpOpenPacket{ + ID: p.ID, + Path: p.Path, + Pflags: ssh_FXF_READ, + }.respond(s) + case serverRespondablePacket: + err := p.respond(s) + return errors.Wrap(err, "pkt.respond failed") + default: + return errors.Errorf("unexpected packet type %T", p) + } +} + // Serve serves SFTP connections until the streams stop or the SFTP subsystem // is stopped. func (svr *Server) Serve() error { @@ -250,10 +352,6 @@ func (s *Server) sendError(p id, err error) error { return s.sendPacket(statusFromError(p, err)) } -func (p sshFxInitPacket) respond(svr *Server) error { - return svr.sendPacket(sshFxVersionPacket{sftpProtocolVersion, nil}) -} - // The init packet has no ID, so we just return a zero-value ID func (p sshFxInitPacket) id() uint32 { return 0 } @@ -269,119 +367,8 @@ func (p sshFxpStatResponse) MarshalBinary() ([]byte, error) { return b, nil } -func (p sshFxpLstatPacket) respond(svr *Server) error { - // stat the requested file - info, err := os.Lstat(p.Path) - if err != nil { - return svr.sendError(p, err) - } - - return svr.sendPacket(sshFxpStatResponse{ - ID: p.ID, - info: info, - }) -} - -func (p sshFxpStatPacket) respond(svr *Server) error { - // stat the requested file - info, err := os.Stat(p.Path) - if err != nil { - return svr.sendError(p, err) - } - - return svr.sendPacket(sshFxpStatResponse{ - ID: p.ID, - info: info, - }) -} - -func (p sshFxpFstatPacket) respond(svr *Server) error { - f, ok := svr.getHandle(p.Handle) - if !ok { - return svr.sendError(p, syscall.EBADF) - } - - info, err := f.Stat() - if err != nil { - return svr.sendError(p, err) - } - - return svr.sendPacket(sshFxpStatResponse{ - ID: p.ID, - info: info, - }) -} - -func (p sshFxpMkdirPacket) respond(svr *Server) error { - // TODO FIXME: ignore flags field - err := os.Mkdir(p.Path, 0755) - return svr.sendError(p, err) -} - -func (p sshFxpRmdirPacket) respond(svr *Server) error { - err := os.Remove(p.Path) - return svr.sendError(p, err) -} - -func (p sshFxpRemovePacket) respond(svr *Server) error { - err := os.Remove(p.Filename) - return svr.sendError(p, err) -} - -func (p sshFxpRenamePacket) respond(svr *Server) error { - err := os.Rename(p.Oldpath, p.Newpath) - return svr.sendError(p, err) -} - -func (p sshFxpSymlinkPacket) respond(svr *Server) error { - err := os.Symlink(p.Targetpath, p.Linkpath) - return svr.sendError(p, err) -} - var emptyFileStat = []interface{}{uint32(0)} -func (p sshFxpReadlinkPacket) respond(svr *Server) error { - f, err := os.Readlink(p.Path) - if err != nil { - return svr.sendError(p, err) - } - - return svr.sendPacket(sshFxpNamePacket{ - ID: p.ID, - NameAttrs: []sshFxpNameAttr{{ - Name: f, - LongName: f, - Attrs: emptyFileStat, - }}, - }) -} - -func (p sshFxpRealpathPacket) respond(svr *Server) error { - f, err := filepath.Abs(p.Path) - if err != nil { - return svr.sendError(p, err) - } - - f = filepath.Clean(f) - - return svr.sendPacket(sshFxpNamePacket{ - ID: p.ID, - NameAttrs: []sshFxpNameAttr{{ - Name: f, - LongName: f, - Attrs: emptyFileStat, - }}, - }) -} - -func (p sshFxpOpendirPacket) respond(svr *Server) error { - return sshFxpOpenPacket{ - ID: p.ID, - Path: p.Path, - Pflags: ssh_FXF_READ, - }.respond(svr) -} - func (p sshFxpOpenPacket) readonly() bool { return !p.hasPflags(ssh_FXF_WRITE) } @@ -392,7 +379,6 @@ func (p sshFxpOpenPacket) hasPflags(flags ...uint32) bool { return false } } - return true } @@ -431,10 +417,6 @@ func (p sshFxpOpenPacket) respond(svr *Server) error { return svr.sendPacket(sshFxpHandlePacket{p.ID, handle}) } -func (p sshFxpClosePacket) respond(svr *Server) error { - return svr.sendError(p, svr.closeHandle(p.Handle)) -} - func (p sshFxpReadPacket) respond(svr *Server) error { f, ok := svr.getHandle(p.Handle) if !ok {