diff --git a/packet.go b/packet.go index 25ea2cf..fd0be41 100644 --- a/packet.go +++ b/packet.go @@ -1,7 +1,9 @@ package sftp import ( + "bytes" "encoding" + "encoding/binary" "fmt" "io" "os" @@ -11,7 +13,8 @@ import ( ) var ( - errShortPacket = errors.New("packet too short") + errShortPacket = errors.New("packet too short") + errUnknownExtendedPacket = errors.New("unknown extended packet") ) const ( @@ -832,3 +835,69 @@ func (p *StatVFS) TotalSpace() uint64 { func (p *StatVFS) FreeSpace() uint64 { return p.Frsize * p.Bfree } + +// Convert to ssh_FXP_EXTENDED_REPLY packet binary format +func (p *StatVFS) MarshalBinary() ([]byte, error) { + buf := &bytes.Buffer{} + buf.Write([]byte{ssh_FXP_EXTENDED_REPLY}) + if err := binary.Write(buf, binary.BigEndian, p); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +type sshFxpExtendedPacket struct { + ID uint32 + ExtendedRequest string + SpecificPacket interface { + serverRespondablePacket + readonly() bool + } +} + +func (p sshFxpExtendedPacket) id() uint32 { return p.ID } +func (p sshFxpExtendedPacket) readonly() bool { return p.SpecificPacket.readonly() } + +func (p sshFxpExtendedPacket) respond(svr *Server) error { + return p.SpecificPacket.respond(svr) +} + +func (p *sshFxpExtendedPacket) UnmarshalBinary(b []byte) error { + var err error + bOrig := b + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.ExtendedRequest, b, err = unmarshalStringSafe(b); err != nil { + return err + } + + // specific unmarshalling + switch p.ExtendedRequest { + case "statvfs@openssh.com": + p.SpecificPacket = &sshFxpExtendedPacketStatVFS{} + default: + return errUnknownExtendedPacket + } + + return p.SpecificPacket.UnmarshalBinary(bOrig) +} + +type sshFxpExtendedPacketStatVFS struct { + ID uint32 + ExtendedRequest string + Path string +} + +func (p sshFxpExtendedPacketStatVFS) id() uint32 { return p.ID } +func (p sshFxpExtendedPacketStatVFS) readonly() bool { return true } +func (p *sshFxpExtendedPacketStatVFS) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.ExtendedRequest, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Path, b, err = unmarshalStringSafe(b); err != nil { + return err + } + return nil +} diff --git a/server.go b/server.go index 07c4ef5..f554535 100644 --- a/server.go +++ b/server.go @@ -171,6 +171,8 @@ func (svr *Server) sftpServerWorker() error { case ssh_FXP_SYMLINK: pkt = &sshFxpSymlinkPacket{} readonly = false + case ssh_FXP_EXTENDED: + pkt = &sshFxpExtendedPacket{} default: return errors.Errorf("unhandled packet type: %s", p.pktType) } @@ -182,6 +184,8 @@ func (svr *Server) sftpServerWorker() error { switch pkt := pkt.(type) { case *sshFxpOpenPacket: readonly = pkt.readonly() + case *sshFxpExtendedPacket: + readonly = pkt.SpecificPacket.readonly() } // If server is operating read-only and a write operation is requested, diff --git a/server_standalone/main.go b/server_standalone/main.go index f08e152..646e99d 100644 --- a/server_standalone/main.go +++ b/server_standalone/main.go @@ -17,10 +17,12 @@ func main() { var ( readOnly bool debugStderr bool + debugLevel string ) flag.BoolVar(&readOnly, "R", false, "read-only server") flag.BoolVar(&debugStderr, "e", false, "debug to stderr") + flag.StringVar(&debugLevel, "l", "none", "debug level (ignored)") flag.Parse() debugStream := ioutil.Discard diff --git a/server_statvfs_darwin.go b/server_statvfs_darwin.go new file mode 100644 index 0000000..8c01dac --- /dev/null +++ b/server_statvfs_darwin.go @@ -0,0 +1,21 @@ +package sftp + +import ( + "syscall" +) + +func statvfsFromStatfst(stat *syscall.Statfs_t) (*StatVFS, error) { + return &StatVFS{ + Bsize: uint64(stat.Bsize), + Frsize: uint64(stat.Bsize), // fragment size is a linux thing; use block size here + Blocks: stat.Blocks, + Bfree: stat.Bfree, + Bavail: stat.Bavail, + Files: stat.Files, + Ffree: stat.Ffree, + Favail: stat.Ffree, // not sure how to calculate Favail + Fsid: uint64(uint64(stat.Fsid.Val[1])<<32 | uint64(stat.Fsid.Val[0])), // endianness? + Flag: uint64(stat.Flags), // assuming POSIX? + Namemax: 1024, // man 2 statfs shows: #define MAXPATHLEN 1024 + }, nil +} diff --git a/server_statvfs_impl.go b/server_statvfs_impl.go new file mode 100644 index 0000000..1685c33 --- /dev/null +++ b/server_statvfs_impl.go @@ -0,0 +1,25 @@ +// +build darwin linux + +// fill in statvfs structure with OS specific values +// Statfs_t is different per-kernel, and only exists on some unixes (not Solaris for instance) + +package sftp + +import ( + "syscall" +) + +func (p sshFxpExtendedPacketStatVFS) respond(svr *Server) error { + stat := &syscall.Statfs_t{} + if err := syscall.Statfs(p.Path, stat); err != nil { + return svr.sendPacket(statusFromError(p.ID, err)) + } + + retPkt, err := statvfsFromStatfst(stat) + if err != nil { + return svr.sendPacket(statusFromError(p.ID, err)) + } + retPkt.ID = p.ID + + return svr.sendPacket(retPkt) +} diff --git a/server_statvfs_linux.go b/server_statvfs_linux.go new file mode 100644 index 0000000..ab1dccf --- /dev/null +++ b/server_statvfs_linux.go @@ -0,0 +1,21 @@ +package sftp + +import ( + "syscall" +) + +func statvfsFromStatfst(stat *syscall.Statfs_t) (*StatVFS, error) { + return &StatVFS{ + Bsize: uint64(stat.Bsize), + Frsize: uint64(stat.Frsize), + Blocks: stat.Blocks, + Bfree: stat.Bfree, + Bavail: stat.Bavail, + Files: stat.Files, + Ffree: stat.Ffree, + Favail: stat.Ffree, // not sure how to calculate Favail + Fsid: uint64(uint64(stat.Fsid.X__val[1])<<32 | uint64(stat.Fsid.X__val[0])), // endianness? + Flag: uint64(stat.Flags), // assuming POSIX? + Namemax: uint64(stat.Namelen), + }, nil +} diff --git a/server_statvfs_stubs.go b/server_statvfs_stubs.go new file mode 100644 index 0000000..3fe4078 --- /dev/null +++ b/server_statvfs_stubs.go @@ -0,0 +1,11 @@ +// +build !darwin,!linux + +package sftp + +import ( + "syscall" +) + +func (p sshFxpExtendedPacketStatVFS) respond(svr *Server) error { + return syscall.ENOTSUP +} diff --git a/server_test.go b/server_test.go new file mode 100644 index 0000000..6d36057 --- /dev/null +++ b/server_test.go @@ -0,0 +1,60 @@ +package sftp + +import ( + "errors" + "testing" +) + +var errClientRecvFinished = errors.New("client recv finished") + +func clientServerPair(t *testing.T) (*Client, *Server) { + c, s := netPipe(t) + server, err := NewServer(s) + if err != nil { + t.Fatal(err) + } + go server.Serve() + client, err := NewClientPipe(c, c) + if err != nil { + t.Fatal(err) + } + return client, server +} + +type sshFxpTestBadExtendedPacket struct { + ID uint32 + Extension string + Data string +} + +func (p sshFxpTestBadExtendedPacket) id() uint32 { return p.ID } + +func (p sshFxpTestBadExtendedPacket) MarshalBinary() ([]byte, error) { + l := 1 + 4 + 4 + // type(byte) + uint32 + uint32 + len(p.Extension) + + len(p.Data) + + b := make([]byte, 0, l) + b = append(b, ssh_FXP_EXTENDED) + b = marshalUint32(b, p.ID) + b = marshalString(b, p.Extension) + b = marshalString(b, p.Data) + return b, nil +} + +// test that errors are sent back when we request an invalid extended packet operation +func TestInvalidExtendedPacket(t *testing.T) { + client, _ := clientServerPair(t) + defer client.Close() + badPacket := sshFxpTestBadExtendedPacket{client.nextID(), "thisDoesn'tExist", "foobar"} + _, _, err := client.sendRequest(badPacket) + if err != nil { + t.Log(err) + } else { + t.Fatal("expected error from bad packet") + } + + // try to stat a file; the client should have shut down. + filePath := "/etc/passwd" + _, err = client.Stat(filePath) +}