Merge pull request #67 from mdlayher/master

server: simplify code and improve readability
This commit is contained in:
Dave Cheney 2015-12-30 22:20:32 +01:00
commit e85cb2364d
1 changed files with 210 additions and 162 deletions

372
server.go
View File

@ -53,9 +53,9 @@ func (svr *Server) closeHandle(handle string) error {
if f, ok := svr.openFiles[handle]; ok {
delete(svr.openFiles, handle)
return f.Close()
} else {
return syscall.EBADF
}
return syscall.EBADF
}
func (svr *Server) getHandle(handle string) (*os.File, bool) {
@ -75,11 +75,12 @@ type serverRespondablePacket interface {
// A subsequent call to Serve() is required.
func NewServer(in io.Reader, out io.WriteCloser, debugStream io.Writer, debugLevel int, readOnly bool, rootDir string) (*Server, error) {
if rootDir == "" {
if wd, err := os.Getwd(); err != nil {
wd, err := os.Getwd()
if err != nil {
return nil, err
} else {
rootDir = wd
}
rootDir = wd
}
return &Server{
in: in,
@ -108,10 +109,13 @@ func (svr *Server) rxPackets() error {
for {
pktType, pktBytes, err := recvPacket(svr.in)
if err == io.EOF {
switch err {
case nil:
break
case io.EOF:
fmt.Fprintf(svr.debugStream, "rxPackets loop done\n")
return nil
} else if err != nil {
default:
fmt.Fprintf(svr.debugStream, "recvPacket error: %v\n", err)
return err
}
@ -123,14 +127,15 @@ func (svr *Server) rxPackets() error {
// Up to N parallel servers
func (svr *Server) sftpServerWorker(doneChan chan error) {
for pkt := range svr.pktChan {
if pkt, err := svr.decodePacket(pkt.pktType, pkt.pktBytes); err != nil {
dPkt, err := svr.decodePacket(pkt.pktType, pkt.pktBytes)
if err != nil {
fmt.Fprintf(svr.debugStream, "decodePacket error: %v\n", err)
doneChan <- err
return
} else {
//fmt.Fprintf(svr.debugStream, "pkt: %T %v\n", pkt, pkt)
pkt.respond(svr)
}
//fmt.Fprintf(svr.debugStream, "pkt: %T %v\n", pkt, pkt)
dPkt.respond(svr)
}
doneChan <- nil
}
@ -229,30 +234,45 @@ func (p sshFxpStatResponse) MarshalBinary() ([]byte, error) {
func (p sshFxpLstatPacket) respond(svr *Server) error {
// stat the requested file
if info, err := os.Lstat(p.Path); err != nil {
info, err := os.Lstat(p.Path)
if err != nil {
return svr.sendPacket(statusFromError(p.Id, err))
} else {
return svr.sendPacket(sshFxpStatResponse{p.Id, info})
}
return svr.sendPacket(sshFxpStatResponse{
Id: p.Id,
info: info,
})
}
func (p sshFxpStatPacket) respond(svr *Server) error {
// stat the requested file
if info, err := os.Stat(p.Path); err != nil {
info, err := os.Stat(p.Path)
if err != nil {
return svr.sendPacket(statusFromError(p.Id, err))
} else {
return svr.sendPacket(sshFxpStatResponse{p.Id, info})
}
return svr.sendPacket(sshFxpStatResponse{
Id: p.Id,
info: info,
})
}
func (p sshFxpFstatPacket) respond(svr *Server) error {
if f, ok := svr.getHandle(p.Handle); !ok {
f, ok := svr.getHandle(p.Handle)
if !ok {
return svr.sendPacket(statusFromError(p.Id, syscall.EBADF))
} else if info, err := f.Stat(); err != nil {
return svr.sendPacket(statusFromError(p.Id, err))
} else {
return svr.sendPacket(sshFxpStatResponse{p.Id, info})
}
info, err := f.Stat()
if err != nil {
return svr.sendPacket(statusFromError(p.Id, err))
}
return svr.sendPacket(sshFxpStatResponse{
Id: p.Id,
info: info,
})
}
func (p sshFxpMkdirPacket) respond(svr *Server) error {
@ -299,28 +319,49 @@ func (p sshFxpSymlinkPacket) respond(svr *Server) error {
var emptyFileStat = []interface{}{uint32(0)}
func (p sshFxpReadlinkPacket) respond(svr *Server) error {
if f, err := os.Readlink(p.Path); err != nil {
f, err := os.Readlink(p.Path)
if err != nil {
return svr.sendPacket(statusFromError(p.Id, err))
} else {
return svr.sendPacket(sshFxpNamePacket{p.Id, []sshFxpNameAttr{sshFxpNameAttr{f, f, emptyFileStat}}})
}
return svr.sendPacket(sshFxpNamePacket{
Id: p.Id,
NameAttrs: []sshFxpNameAttr{{
Name: f,
LongName: f,
Attrs: emptyFileStat,
}},
})
}
func (p sshFxpRealpathPacket) respond(svr *Server) error {
if f, err := filepath.Abs(p.Path); err != nil {
f, err := filepath.Abs(p.Path)
if err != nil {
return svr.sendPacket(statusFromError(p.Id, err))
} else {
f = filepath.Clean(f)
return svr.sendPacket(sshFxpNamePacket{p.Id, []sshFxpNameAttr{sshFxpNameAttr{f, f, emptyFileStat}}})
}
f = filepath.Clean(f)
return svr.sendPacket(sshFxpNamePacket{
Id: p.Id,
NameAttrs: []sshFxpNameAttr{{
Name: f,
LongName: f,
Attrs: emptyFileStat,
}},
})
}
func (p sshFxpOpendirPacket) respond(svr *Server) error {
return sshFxpOpenPacket{p.Id, p.Path, ssh_FXF_READ, 0}.respond(svr)
return sshFxpOpenPacket{
Id: p.Id,
Path: p.Path,
Pflags: ssh_FXF_READ,
}.respond(svr)
}
func (p sshFxpOpenPacket) respond(svr *Server) error {
osFlags := 0
var osFlags int
if p.Pflags&ssh_FXF_READ != 0 && p.Pflags&ssh_FXF_WRITE != 0 {
if svr.readOnly {
return svr.sendPacket(statusFromError(p.Id, syscall.EPERM))
@ -352,12 +393,13 @@ func (p sshFxpOpenPacket) respond(svr *Server) error {
osFlags |= os.O_EXCL
}
if f, err := os.OpenFile(p.Path, osFlags, 0644); err != nil {
f, err := os.OpenFile(p.Path, osFlags, 0644)
if err != nil {
return svr.sendPacket(statusFromError(p.Id, err))
} else {
handle := svr.nextHandle(f)
return svr.sendPacket(sshFxpHandlePacket{p.Id, handle})
}
handle := svr.nextHandle(f)
return svr.sendPacket(sshFxpHandlePacket{p.Id, handle})
}
func (p sshFxpClosePacket) respond(svr *Server) error {
@ -365,20 +407,27 @@ func (p sshFxpClosePacket) respond(svr *Server) error {
}
func (p sshFxpReadPacket) respond(svr *Server) error {
if f, ok := svr.getHandle(p.Handle); !ok {
f, ok := svr.getHandle(p.Handle)
if !ok {
return svr.sendPacket(statusFromError(p.Id, syscall.EBADF))
} else {
if p.Len > svr.maxTxPacket {
p.Len = svr.maxTxPacket
}
ret := sshFxpDataPacket{Id: p.Id, Length: p.Len, Data: make([]byte, p.Len)}
if n, err := f.ReadAt(ret.Data, int64(p.Offset)); err != nil && (err != io.EOF || n == 0) {
return svr.sendPacket(statusFromError(p.Id, err))
} else {
ret.Length = uint32(n)
return svr.sendPacket(ret)
}
}
if p.Len > svr.maxTxPacket {
p.Len = svr.maxTxPacket
}
ret := sshFxpDataPacket{
Id: p.Id,
Length: p.Len,
Data: make([]byte, p.Len),
}
n, err := f.ReadAt(ret.Data, int64(p.Offset))
if err != nil && (err != io.EOF || n == 0) {
return svr.sendPacket(statusFromError(p.Id, err))
}
ret.Length = uint32(n)
return svr.sendPacket(ret)
}
func (p sshFxpWritePacket) respond(svr *Server) error {
@ -386,146 +435,147 @@ func (p sshFxpWritePacket) respond(svr *Server) error {
// shouldn't really get here, the open should have failed
return svr.sendPacket(statusFromError(p.Id, syscall.EPERM))
}
if f, ok := svr.getHandle(p.Handle); !ok {
f, ok := svr.getHandle(p.Handle)
if !ok {
return svr.sendPacket(statusFromError(p.Id, syscall.EBADF))
} else {
_, err := f.WriteAt(p.Data, int64(p.Offset))
return svr.sendPacket(statusFromError(p.Id, err))
}
_, err := f.WriteAt(p.Data, int64(p.Offset))
return svr.sendPacket(statusFromError(p.Id, err))
}
func (p sshFxpReaddirPacket) respond(svr *Server) error {
if f, ok := svr.getHandle(p.Handle); !ok {
f, ok := svr.getHandle(p.Handle)
if !ok {
return svr.sendPacket(statusFromError(p.Id, syscall.EBADF))
} else {
dirname := ""
dirents := []os.FileInfo{}
var err error = nil
dirname = f.Name()
dirents, err = f.Readdir(128)
if err != nil {
return svr.sendPacket(statusFromError(p.Id, err))
}
ret := sshFxpNamePacket{p.Id, nil}
for _, dirent := range dirents {
ret.NameAttrs = append(ret.NameAttrs, sshFxpNameAttr{
dirent.Name(),
runLs(dirname, dirent),
[]interface{}{dirent},
})
}
return svr.sendPacket(ret)
}
dirname := f.Name()
dirents, err := f.Readdir(128)
if err != nil {
return svr.sendPacket(statusFromError(p.Id, err))
}
ret := sshFxpNamePacket{Id: p.Id}
for _, dirent := range dirents {
ret.NameAttrs = append(ret.NameAttrs, sshFxpNameAttr{
Name: dirent.Name(),
LongName: runLs(dirname, dirent),
Attrs: []interface{}{dirent},
})
}
return svr.sendPacket(ret)
}
func (p sshFxpSetstatPacket) respond(svr *Server) error {
if svr.readOnly {
return svr.sendPacket(statusFromError(p.Id, syscall.EPERM))
} else {
// additional unmarshalling is required for each possibility here
b := p.Attrs.([]byte)
var err error = nil
debug("setstat name \"%s\"", p.Path)
if (p.Flags & ssh_FILEXFER_ATTR_SIZE) != 0 {
var size uint64 = 0
if size, b, err = unmarshalUint64Safe(b); err == nil {
err = os.Truncate(p.Path, int64(size))
}
}
if (p.Flags & ssh_FILEXFER_ATTR_PERMISSIONS) != 0 {
var mode uint32 = 0
if mode, b, err = unmarshalUint32Safe(b); err == nil {
err = os.Chmod(p.Path, os.FileMode(mode))
}
}
if (p.Flags & ssh_FILEXFER_ATTR_ACMODTIME) != 0 {
var atime uint32 = 0
var mtime uint32 = 0
if atime, b, err = unmarshalUint32Safe(b); err != nil {
} else if mtime, b, err = unmarshalUint32Safe(b); err != nil {
} else {
atimeT := time.Unix(int64(atime), 0)
mtimeT := time.Unix(int64(mtime), 0)
err = os.Chtimes(p.Path, atimeT, mtimeT)
}
}
if (p.Flags & ssh_FILEXFER_ATTR_UIDGID) != 0 {
var uid uint32 = 0
var gid uint32 = 0
if uid, b, err = unmarshalUint32Safe(b); err != nil {
} else if gid, b, err = unmarshalUint32Safe(b); err != nil {
} else {
err = os.Chown(p.Path, int(uid), int(gid))
}
}
return svr.sendPacket(statusFromError(p.Id, err))
}
// additional unmarshalling is required for each possibility here
b := p.Attrs.([]byte)
var err error
debug("setstat name \"%s\"", p.Path)
if (p.Flags & ssh_FILEXFER_ATTR_SIZE) != 0 {
var size uint64
if size, b, err = unmarshalUint64Safe(b); err == nil {
err = os.Truncate(p.Path, int64(size))
}
}
if (p.Flags & ssh_FILEXFER_ATTR_PERMISSIONS) != 0 {
var mode uint32
if mode, b, err = unmarshalUint32Safe(b); err == nil {
err = os.Chmod(p.Path, os.FileMode(mode))
}
}
if (p.Flags & ssh_FILEXFER_ATTR_ACMODTIME) != 0 {
var atime uint32
var mtime uint32
if atime, b, err = unmarshalUint32Safe(b); err != nil {
} else if mtime, b, err = unmarshalUint32Safe(b); err != nil {
} else {
atimeT := time.Unix(int64(atime), 0)
mtimeT := time.Unix(int64(mtime), 0)
err = os.Chtimes(p.Path, atimeT, mtimeT)
}
}
if (p.Flags & ssh_FILEXFER_ATTR_UIDGID) != 0 {
var uid uint32
var gid uint32
if uid, b, err = unmarshalUint32Safe(b); err != nil {
} else if gid, b, err = unmarshalUint32Safe(b); err != nil {
} else {
err = os.Chown(p.Path, int(uid), int(gid))
}
}
return svr.sendPacket(statusFromError(p.Id, err))
}
func (p sshFxpFsetstatPacket) respond(svr *Server) error {
if svr.readOnly {
return svr.sendPacket(statusFromError(p.Id, syscall.EPERM))
} else if f, ok := svr.getHandle(p.Handle); !ok {
return svr.sendPacket(statusFromError(p.Id, syscall.EBADF))
} else {
// additional unmarshalling is required for each possibility here
b := p.Attrs.([]byte)
var err error = nil
debug("fsetstat name \"%s\"", f.Name())
if (p.Flags & ssh_FILEXFER_ATTR_SIZE) != 0 {
var size uint64 = 0
if size, b, err = unmarshalUint64Safe(b); err == nil {
err = f.Truncate(int64(size))
}
}
if (p.Flags & ssh_FILEXFER_ATTR_PERMISSIONS) != 0 {
var mode uint32 = 0
if mode, b, err = unmarshalUint32Safe(b); err == nil {
err = f.Chmod(os.FileMode(mode))
}
}
if (p.Flags & ssh_FILEXFER_ATTR_ACMODTIME) != 0 {
var atime uint32 = 0
var mtime uint32 = 0
if atime, b, err = unmarshalUint32Safe(b); err != nil {
} else if mtime, b, err = unmarshalUint32Safe(b); err != nil {
} else {
atimeT := time.Unix(int64(atime), 0)
mtimeT := time.Unix(int64(mtime), 0)
err = os.Chtimes(f.Name(), atimeT, mtimeT)
}
}
if (p.Flags & ssh_FILEXFER_ATTR_UIDGID) != 0 {
var uid uint32 = 0
var gid uint32 = 0
if uid, b, err = unmarshalUint32Safe(b); err != nil {
} else if gid, b, err = unmarshalUint32Safe(b); err != nil {
} else {
err = f.Chown(int(uid), int(gid))
}
}
return svr.sendPacket(statusFromError(p.Id, err))
}
f, ok := svr.getHandle(p.Handle)
if !ok {
return svr.sendPacket(statusFromError(p.Id, syscall.EBADF))
}
// additional unmarshalling is required for each possibility here
b := p.Attrs.([]byte)
var err error
debug("fsetstat name \"%s\"", f.Name())
if (p.Flags & ssh_FILEXFER_ATTR_SIZE) != 0 {
var size uint64
if size, b, err = unmarshalUint64Safe(b); err == nil {
err = f.Truncate(int64(size))
}
}
if (p.Flags & ssh_FILEXFER_ATTR_PERMISSIONS) != 0 {
var mode uint32
if mode, b, err = unmarshalUint32Safe(b); err == nil {
err = f.Chmod(os.FileMode(mode))
}
}
if (p.Flags & ssh_FILEXFER_ATTR_ACMODTIME) != 0 {
var atime uint32
var mtime uint32
if atime, b, err = unmarshalUint32Safe(b); err != nil {
} else if mtime, b, err = unmarshalUint32Safe(b); err != nil {
} else {
atimeT := time.Unix(int64(atime), 0)
mtimeT := time.Unix(int64(mtime), 0)
err = os.Chtimes(f.Name(), atimeT, mtimeT)
}
}
if (p.Flags & ssh_FILEXFER_ATTR_UIDGID) != 0 {
var uid uint32
var gid uint32
if uid, b, err = unmarshalUint32Safe(b); err != nil {
} else if gid, b, err = unmarshalUint32Safe(b); err != nil {
} else {
err = f.Chown(int(uid), int(gid))
}
}
return svr.sendPacket(statusFromError(p.Id, err))
}
func errnoToSshErr(errno syscall.Errno) uint32 {
if errno == 0 {
switch errno {
case 0:
return ssh_FX_OK
} else if errno == syscall.ENOENT {
case syscall.ENOENT:
return ssh_FX_NO_SUCH_FILE
} else if errno == syscall.EPERM {
case syscall.EPERM:
return ssh_FX_PERMISSION_DENIED
} else {
return ssh_FX_FAILURE
}
return uint32(errno)
return ssh_FX_FAILURE
}
func statusFromError(id uint32, err error) sshFxpStatusPacket {
@ -542,8 +592,6 @@ func statusFromError(id uint32, err error) sshFxpStatusPacket {
// ssh_FX_CONNECTION_LOST = 7
// ssh_FX_OP_UNSUPPORTED = 8
Code: ssh_FX_OK,
msg: "",
lang: "",
},
}
if err != nil {