rework client to prevent after-close usage, and support perm at open

This commit is contained in:
Cassondra Foesch 2024-01-19 00:20:23 +00:00
parent 22452ea54d
commit d1903fbd46
9 changed files with 381 additions and 110 deletions

View File

@ -32,10 +32,10 @@ func (fi *fileInfo) Name() string { return fi.name }
func (fi *fileInfo) Size() int64 { return int64(fi.stat.Size) }
// Mode returns file mode bits.
func (fi *fileInfo) Mode() os.FileMode { return toFileMode(fi.stat.Mode) }
func (fi *fileInfo) Mode() os.FileMode { return fi.stat.FileMode() }
// ModTime returns the last modification time of the file.
func (fi *fileInfo) ModTime() time.Time { return time.Unix(int64(fi.stat.Mtime), 0) }
func (fi *fileInfo) ModTime() time.Time { return fi.stat.ModTime() }
// IsDir returns true if the file is a directory.
func (fi *fileInfo) IsDir() bool { return fi.Mode().IsDir() }
@ -56,6 +56,21 @@ type FileStat struct {
Extended []StatExtended
}
// ModTime returns the Mtime SFTP file attribute converted to a time.Time
func (fs *FileStat) ModTime() time.Time {
return time.Unix(int64(fs.Mtime), 0)
}
// AccessTime returns the Atime SFTP file attribute converted to a time.Time
func (fs *FileStat) AccessTime() time.Time {
return time.Unix(int64(fs.Atime), 0)
}
// FileMode returns the Mode SFTP file attribute converted to an os.FileMode
func (fs *FileStat) FileMode() os.FileMode {
return toFileMode(fs.Mode)
}
// StatExtended contains additional, extended information for a FileStat.
type StatExtended struct {
ExtType string

156
client.go
View File

@ -257,7 +257,7 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie
// read/write at the same time. For those services you will need to use
// `client.OpenFile(os.O_WRONLY|os.O_CREATE|os.O_TRUNC)`.
func (c *Client) Create(path string) (*File, error) {
return c.open(path, flags(os.O_RDWR|os.O_CREATE|os.O_TRUNC))
return c.open(path, toPflags(os.O_RDWR|os.O_CREATE|os.O_TRUNC))
}
const sftpProtocolVersion = 3 // https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt
@ -510,7 +510,7 @@ func (c *Client) Symlink(oldname, newname string) error {
}
}
func (c *Client) setfstat(handle string, flags uint32, attrs interface{}) error {
func (c *Client) fsetstat(handle string, flags uint32, attrs interface{}) error {
id := c.nextID()
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpFsetstatPacket{
ID: id,
@ -590,14 +590,14 @@ func (c *Client) Truncate(path string, size int64) error {
// returned file can be used for reading; the associated file descriptor
// has mode O_RDONLY.
func (c *Client) Open(path string) (*File, error) {
return c.open(path, flags(os.O_RDONLY))
return c.open(path, toPflags(os.O_RDONLY))
}
// OpenFile is the generalized open call; most users will use Open or
// Create instead. It opens the named file with specified flag (O_RDONLY
// etc.). If successful, methods on the returned File can be used for I/O.
func (c *Client) OpenFile(path string, f int) (*File, error) {
return c.open(path, flags(f))
return c.open(path, toPflags(f))
}
func (c *Client) open(path string, pflags uint32) (*File, error) {
@ -976,16 +976,26 @@ func (c *Client) RemoveAll(path string) error {
type File struct {
c *Client
path string
handle string
mu sync.Mutex
mu sync.RWMutex
handle string
offset int64 // current offset within remote file
}
// Close closes the File, rendering it unusable for I/O. It returns an
// error, if any.
func (f *File) Close() error {
return f.c.close(f.handle)
f.mu.Lock()
defer f.mu.Unlock()
if f.handle == "" {
return os.ErrClosed
}
handle := f.handle
f.handle = ""
return f.c.close(handle)
}
// Name returns the name of the file as presented to Open or Create.
@ -1006,7 +1016,11 @@ func (f *File) Read(b []byte) (int, error) {
f.mu.Lock()
defer f.mu.Unlock()
n, err := f.ReadAt(b, f.offset)
if f.handle == "" {
return 0, os.ErrClosed
}
n, err := f.readAt(b, f.offset)
f.offset += int64(n)
return n, err
}
@ -1071,6 +1085,17 @@ func (f *File) readAtSequential(b []byte, off int64) (read int, err error) {
// the number of bytes read and an error, if any. ReadAt follows io.ReaderAt semantics,
// so the file offset is not altered during the read.
func (f *File) ReadAt(b []byte, off int64) (int, error) {
f.mu.RLock()
defer f.mu.RUnlock()
if f.handle == "" {
return 0, os.ErrClosed
}
return f.readAt(b, off)
}
func (f *File) readAt(b []byte, off int64) (int, error) {
if len(b) <= f.c.maxPacket {
// This should be able to be serviced with 1/2 requests.
// So, just do it directly.
@ -1267,6 +1292,10 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) {
f.mu.Lock()
defer f.mu.Unlock()
if f.handle == "" {
return 0, os.ErrClosed
}
if f.c.disableConcurrentReads {
return f.writeToSequential(w)
}
@ -1456,9 +1485,20 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) {
}
}
func (f *File) Stat() (os.FileInfo, error) {
f.mu.RLock()
defer f.mu.RUnlock()
if f.handle == "" {
return nil, os.ErrClosed
}
return f.stat()
}
// Stat returns the FileInfo structure describing file. If there is an
// error.
func (f *File) Stat() (os.FileInfo, error) {
func (f *File) stat() (os.FileInfo, error) {
fs, err := f.c.fstat(f.handle)
if err != nil {
return nil, err
@ -1478,7 +1518,11 @@ func (f *File) Write(b []byte) (int, error) {
f.mu.Lock()
defer f.mu.Unlock()
n, err := f.WriteAt(b, f.offset)
if f.handle == "" {
return 0, os.ErrClosed
}
n, err := f.writeAt(b, f.offset)
f.offset += int64(n)
return n, err
}
@ -1636,6 +1680,17 @@ func (f *File) writeAtConcurrent(b []byte, off int64) (int, error) {
// the number of bytes written and an error, if any. WriteAt follows io.WriterAt semantics,
// so the file offset is not altered during the write.
func (f *File) WriteAt(b []byte, off int64) (written int, err error) {
f.mu.RLock()
defer f.mu.RUnlock()
if f.handle == "" {
return 0, os.ErrClosed
}
return f.writeAt(b, off)
}
func (f *File) writeAt(b []byte, off int64) (written int, err error) {
if len(b) <= f.c.maxPacket {
// We can do this in one write.
return f.writeChunkAt(nil, b, off)
@ -1675,6 +1730,17 @@ func (f *File) WriteAt(b []byte, off int64) (written int, err error) {
//
// Otherwise, the given concurrency will be capped by the Client's max concurrency.
func (f *File) ReadFromWithConcurrency(r io.Reader, concurrency int) (read int64, err error) {
f.mu.Lock()
defer f.mu.Unlock()
if f.handle == "" {
return 0, os.ErrClosed
}
return f.readFromWithConcurrency(r, concurrency)
}
func (f *File) readFromWithConcurrency(r io.Reader, concurrency int) (read int64, err error) {
// Split the write into multiple maxPacket sized concurrent writes.
// This allows writes with a suitably large reader
// to transfer data at a much faster rate due to overlapping round trip times.
@ -1824,6 +1890,10 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
f.mu.Lock()
defer f.mu.Unlock()
if f.handle == "" {
return 0, os.ErrClosed
}
if f.c.useConcurrentWrites {
var remain int64
switch r := r.(type) {
@ -1845,7 +1915,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
if remain < 0 {
// We can strongly assert that we want default max concurrency here.
return f.ReadFromWithConcurrency(r, f.c.maxConcurrentRequests)
return f.readFromWithConcurrency(r, f.c.maxConcurrentRequests)
}
if remain > int64(f.c.maxPacket) {
@ -1860,7 +1930,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
concurrency64 = int64(f.c.maxConcurrentRequests)
}
return f.ReadFromWithConcurrency(r, int(concurrency64))
return f.readFromWithConcurrency(r, int(concurrency64))
}
}
@ -1903,12 +1973,16 @@ func (f *File) Seek(offset int64, whence int) (int64, error) {
f.mu.Lock()
defer f.mu.Unlock()
if f.handle == "" {
return 0, os.ErrClosed
}
switch whence {
case io.SeekStart:
case io.SeekCurrent:
offset += f.offset
case io.SeekEnd:
fi, err := f.Stat()
fi, err := f.stat()
if err != nil {
return f.offset, err
}
@ -1927,20 +2001,61 @@ func (f *File) Seek(offset int64, whence int) (int64, error) {
// Chown changes the uid/gid of the current file.
func (f *File) Chown(uid, gid int) error {
return f.c.Chown(f.path, uid, gid)
f.mu.RLock()
defer f.mu.RUnlock()
if f.handle == "" {
return os.ErrClosed
}
return f.c.fsetstat(f.handle, sshFileXferAttrUIDGID, &FileStat{
UID: uint32(uid),
GID: uint32(gid),
})
}
// Chmod changes the permissions of the current file.
//
// See Client.Chmod for details.
func (f *File) Chmod(mode os.FileMode) error {
return f.c.setfstat(f.handle, sshFileXferAttrPermissions, toChmodPerm(mode))
f.mu.RLock()
defer f.mu.RUnlock()
if f.handle == "" {
return os.ErrClosed
}
return f.c.fsetstat(f.handle, sshFileXferAttrPermissions, toChmodPerm(mode))
}
// Truncate sets the size of the current file. Although it may be safely assumed
// that if the size is less than its current size it will be truncated to fit,
// the SFTP protocol does not specify what behavior the server should do when setting
// size greater than the current size.
// We send a SSH_FXP_FSETSTAT here since we have a file handle
func (f *File) Truncate(size int64) error {
f.mu.RLock()
defer f.mu.RUnlock()
if f.handle == "" {
return os.ErrClosed
}
return f.c.fsetstat(f.handle, sshFileXferAttrSize, uint64(size))
}
// Sync requests a flush of the contents of a File to stable storage.
//
// Sync requires the server to support the fsync@openssh.com extension.
func (f *File) Sync() error {
f.mu.Lock()
defer f.mu.Unlock()
if f.handle == "" {
return os.ErrClosed
}
id := f.c.nextID()
typ, data, err := f.c.sendPacket(context.Background(), nil, &sshFxpFsyncPacket{
ID: id,
@ -1957,15 +2072,6 @@ func (f *File) Sync() error {
}
}
// Truncate sets the size of the current file. Although it may be safely assumed
// that if the size is less than its current size it will be truncated to fit,
// the SFTP protocol does not specify what behavior the server should do when setting
// size greater than the current size.
// We send a SSH_FXP_FSETSTAT here since we have a file handle
func (f *File) Truncate(size int64) error {
return f.c.setfstat(f.handle, sshFileXferAttrSize, uint64(size))
}
// normaliseError normalises an error into a more standard form that can be
// checked against stdlib errors like io.EOF or os.ErrNotExist.
func normaliseError(err error) error {
@ -1990,7 +2096,7 @@ func normaliseError(err error) error {
// flags converts the flags passed to OpenFile into ssh flags.
// Unsupported flags are ignored.
func flags(f int) uint32 {
func toPflags(f int) uint32 {
var out uint32
switch f & os.O_WRONLY {
case os.O_WRONLY:

View File

@ -81,7 +81,7 @@ var flagsTests = []struct {
func TestFlags(t *testing.T) {
for i, tt := range flagsTests {
got := flags(tt.flags)
got := toPflags(tt.flags)
if got != tt.want {
t.Errorf("test %v: flags(%x): want: %x, got: %x", i, tt.flags, tt.want, got)
}

View File

@ -56,6 +56,11 @@ func marshalFileInfo(b []byte, fi os.FileInfo) []byte {
flags, fileStat := fileStatFromInfo(fi)
b = marshalUint32(b, flags)
return marshalFileStat(b, flags, fileStat)
}
func marshalFileStat(b []byte, flags uint32, fileStat *FileStat) []byte {
if flags&sshFileXferAttrSize != 0 {
b = marshalUint64(b, fileStat.Size)
}
@ -91,10 +96,9 @@ func marshalStatus(b []byte, err StatusError) []byte {
}
func marshal(b []byte, v interface{}) []byte {
if v == nil {
return b
}
switch v := v.(type) {
case nil:
return b
case uint8:
return append(b, v)
case uint32:
@ -103,6 +107,8 @@ func marshal(b []byte, v interface{}) []byte {
return marshalUint64(b, v)
case string:
return marshalString(b, v)
case []byte:
return append(b, v...)
case os.FileInfo:
return marshalFileInfo(b, v)
default:
@ -180,8 +186,6 @@ func unmarshalFileStat(flags uint32, b []byte) (*FileStat, []byte) {
}
if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID {
fs.UID, b, _ = unmarshalUint32Safe(b)
}
if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID {
fs.GID, b, _ = unmarshalUint32Safe(b)
}
if flags&sshFileXferAttrPermissions == sshFileXferAttrPermissions {
@ -681,12 +685,13 @@ type sshFxpOpenPacket struct {
ID uint32
Path string
Pflags uint32
Flags uint32 // ignored
Flags uint32
Attrs interface{}
}
func (p *sshFxpOpenPacket) id() uint32 { return p.ID }
func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) {
func (p *sshFxpOpenPacket) marshalPacket() ([]byte, []byte, error) {
l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
4 + len(p.Path) +
4 + 4
@ -698,7 +703,20 @@ func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) {
b = marshalUint32(b, p.Pflags)
b = marshalUint32(b, p.Flags)
return b, nil
switch attrs := p.Attrs.(type) {
case os.FileInfo:
_, fs := fileStatFromInfo(attrs) // we throw away the flags, and override with those in packet.
return b, marshalFileStat(nil, p.Flags, fs), nil
case *FileStat:
return b, marshalFileStat(nil, p.Flags, attrs), nil
}
return b, marshal(nil, p.Attrs), nil
}
func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) {
header, payload, err := p.marshalPacket()
return append(header, payload...), err
}
func (p *sshFxpOpenPacket) UnmarshalBinary(b []byte) error {
@ -709,12 +727,25 @@ func (p *sshFxpOpenPacket) UnmarshalBinary(b []byte) error {
return err
} else if p.Pflags, b, err = unmarshalUint32Safe(b); err != nil {
return err
} else if p.Flags, _, err = unmarshalUint32Safe(b); err != nil {
} else if p.Flags, b, err = unmarshalUint32Safe(b); err != nil {
return err
}
p.Attrs = b
return nil
}
func (p *sshFxpOpenPacket) unmarshalFileStat(flags uint32) *FileStat {
switch attrs := p.Attrs.(type) {
case *FileStat:
return attrs
case []byte:
fs, _ := unmarshalFileStat(flags, attrs)
return fs
default:
panic(fmt.Sprintf("invalid type in unmarshalFileStat: %T", attrs))
}
}
type sshFxpReadPacket struct {
ID uint32
Len uint32
@ -943,9 +974,15 @@ func (p *sshFxpSetstatPacket) marshalPacket() ([]byte, []byte, error) {
b = marshalString(b, p.Path)
b = marshalUint32(b, p.Flags)
payload := marshal(nil, p.Attrs)
switch attrs := p.Attrs.(type) {
case os.FileInfo:
_, fs := fileStatFromInfo(attrs) // we throw away the flags, and override with those in packet.
return b, marshalFileStat(nil, p.Flags, fs), nil
case *FileStat:
return b, marshalFileStat(nil, p.Flags, attrs), nil
}
return b, payload, nil
return b, marshal(nil, p.Attrs), nil
}
func (p *sshFxpSetstatPacket) MarshalBinary() ([]byte, error) {
@ -964,9 +1001,15 @@ func (p *sshFxpFsetstatPacket) marshalPacket() ([]byte, []byte, error) {
b = marshalString(b, p.Handle)
b = marshalUint32(b, p.Flags)
payload := marshal(nil, p.Attrs)
switch attrs := p.Attrs.(type) {
case os.FileInfo:
_, fs := fileStatFromInfo(attrs) // we throw away the flags, and override with those in packet.
return b, marshalFileStat(nil, p.Flags, fs), nil
case *FileStat:
return b, marshalFileStat(nil, p.Flags, attrs), nil
}
return b, payload, nil
return b, marshal(nil, p.Attrs), nil
}
func (p *sshFxpFsetstatPacket) MarshalBinary() ([]byte, error) {
@ -987,6 +1030,18 @@ func (p *sshFxpSetstatPacket) UnmarshalBinary(b []byte) error {
return nil
}
func (p *sshFxpSetstatPacket) unmarshalFileStat(flags uint32) *FileStat {
switch attrs := p.Attrs.(type) {
case *FileStat:
return attrs
case []byte:
fs, _ := unmarshalFileStat(flags, attrs)
return fs
default:
panic(fmt.Sprintf("invalid type in unmarshalFileStat: %T", attrs))
}
}
func (p *sshFxpFsetstatPacket) UnmarshalBinary(b []byte) error {
var err error
if p.ID, b, err = unmarshalUint32Safe(b); err != nil {
@ -1000,6 +1055,18 @@ func (p *sshFxpFsetstatPacket) UnmarshalBinary(b []byte) error {
return nil
}
func (p *sshFxpFsetstatPacket) unmarshalFileStat(flags uint32) *FileStat {
switch attrs := p.Attrs.(type) {
case *FileStat:
return attrs
case []byte:
fs, _ := unmarshalFileStat(flags, attrs)
return fs
default:
panic(fmt.Sprintf("invalid type in unmarshalFileStat: %T", attrs))
}
}
type sshFxpHandlePacket struct {
ID uint32
Handle string

View File

@ -376,7 +376,7 @@ func TestSendPacket(t *testing.T) {
packet: &sshFxpOpenPacket{
ID: 1,
Path: "/foo",
Pflags: flags(os.O_RDONLY),
Pflags: toPflags(os.O_RDONLY),
},
want: []byte{
0x0, 0x0, 0x0, 0x15,
@ -387,6 +387,26 @@ func TestSendPacket(t *testing.T) {
0x0, 0x0, 0x0, 0x0,
},
},
{
packet: &sshFxpOpenPacket{
ID: 3,
Path: "/foo",
Pflags: toPflags(os.O_WRONLY | os.O_CREATE | os.O_TRUNC),
Flags: sshFileXferAttrPermissions,
Attrs: &FileStat{
Mode: 0o755,
},
},
want: []byte{
0x0, 0x0, 0x0, 0x19,
0x3,
0x0, 0x0, 0x0, 0x3,
0x0, 0x0, 0x0, 0x4, '/', 'f', 'o', 'o',
0x0, 0x0, 0x0, 0x1a,
0x0, 0x0, 0x0, 0x4,
0x0, 0x0, 0x1, 0xed,
},
},
{
packet: &sshFxpWritePacket{
ID: 124,
@ -409,10 +429,7 @@ func TestSendPacket(t *testing.T) {
ID: 31,
Path: "/bar",
Flags: sshFileXferAttrUIDGID,
Attrs: struct {
UID uint32
GID uint32
}{
Attrs: &FileStat{
UID: 1000,
GID: 100,
},
@ -611,7 +628,7 @@ func BenchmarkMarshalOpen(b *testing.B) {
benchMarshal(b, &sshFxpOpenPacket{
ID: 1,
Path: "/home/test/some/random/path",
Pflags: flags(os.O_RDONLY),
Pflags: toPflags(os.O_RDONLY),
})
}

View File

@ -3,7 +3,6 @@ package sftp
// Methods on the Request object to make working with the Flags bitmasks and
// Attr(ibutes) byte blob easier. Use Pflags() when working with an Open/Write
// request and AttrFlags() and Attributes() when working with SetStat requests.
import "os"
// FileOpenFlags defines Open and Write Flags. Correlate directly with with os.OpenFile flags
// (https://golang.org/pkg/os/#pkg-constants).
@ -50,11 +49,6 @@ func (r *Request) AttrFlags() FileAttrFlags {
return newFileAttrFlags(r.Flags)
}
// FileMode returns the Mode SFTP file attributes wrapped as os.FileMode
func (a FileStat) FileMode() os.FileMode {
return os.FileMode(a.Mode)
}
// Attributes parses file attributes byte blob and return them in a
// FileStat object.
func (r *Request) Attributes() *FileStat {

View File

@ -13,7 +13,6 @@ import (
"strconv"
"sync"
"syscall"
"time"
)
const (
@ -462,7 +461,15 @@ func (p *sshFxpOpenPacket) respond(svr *Server) responsePacket {
osFlags |= os.O_EXCL
}
f, err := os.OpenFile(svr.toLocalPath(p.Path), osFlags, 0o644)
mode := os.FileMode(0o644)
// Like OpenSSH, we only handle permissions here, if the file is being created.
// Otherwise, the permissions are ignored.
if p.Flags & sshFileXferAttrPermissions != 0 {
fs := p.unmarshalFileStat(p.Flags)
mode = fs.FileMode() & os.ModePerm
}
f, err := os.OpenFile(svr.toLocalPath(p.Path), osFlags, mode)
if err != nil {
return statusFromError(p.ID, err)
}
@ -496,43 +503,32 @@ func (p *sshFxpReaddirPacket) respond(svr *Server) responsePacket {
}
func (p *sshFxpSetstatPacket) respond(svr *Server) responsePacket {
// additional unmarshalling is required for each possibility here
b := p.Attrs.([]byte)
path := svr.toLocalPath(p.Path)
debug("setstat name %q", path)
fs := p.unmarshalFileStat(p.Flags)
var err error
p.Path = svr.toLocalPath(p.Path)
debug("setstat name \"%s\"", p.Path)
if (p.Flags & sshFileXferAttrSize) != 0 {
var size uint64
if size, b, err = unmarshalUint64Safe(b); err == nil {
err = os.Truncate(p.Path, int64(size))
if err == nil {
err = os.Truncate(path, int64(fs.Size))
}
}
if (p.Flags & sshFileXferAttrPermissions) != 0 {
var mode uint32
if mode, b, err = unmarshalUint32Safe(b); err == nil {
err = os.Chmod(p.Path, os.FileMode(mode))
if err == nil {
err = os.Chmod(path, fs.FileMode())
}
}
if (p.Flags & sshFileXferAttrACmodTime) != 0 {
var atime uint32
var mtime uint32
if atime, b, err = unmarshalUint32Safe(b); err != nil {
} else if mtime, b, err = unmarshalUint32Safe(b); err != nil {
} else {
atimeT := time.Unix(int64(atime), 0)
mtimeT := time.Unix(int64(mtime), 0)
err = os.Chtimes(p.Path, atimeT, mtimeT)
if err == nil {
err = os.Chtimes(path, fs.AccessTime(), fs.ModTime())
}
}
if (p.Flags & sshFileXferAttrUIDGID) != 0 {
var uid uint32
var gid uint32
if uid, b, err = unmarshalUint32Safe(b); err != nil {
} else if gid, _, err = unmarshalUint32Safe(b); err != nil {
} else {
err = os.Chown(p.Path, int(uid), int(gid))
if err == nil {
err = os.Chown(path, int(fs.UID), int(fs.GID))
}
}
@ -545,41 +541,32 @@ func (p *sshFxpFsetstatPacket) respond(svr *Server) responsePacket {
return statusFromError(p.ID, EBADF)
}
// additional unmarshalling is required for each possibility here
b := p.Attrs.([]byte)
path := f.Name()
debug("fsetstat name %q", path)
fs := p.unmarshalFileStat(p.Flags)
var err error
debug("fsetstat name \"%s\"", f.Name())
if (p.Flags & sshFileXferAttrSize) != 0 {
var size uint64
if size, b, err = unmarshalUint64Safe(b); err == nil {
err = f.Truncate(int64(size))
if err == nil {
err = f.Truncate(int64(fs.Size))
}
}
if (p.Flags & sshFileXferAttrPermissions) != 0 {
var mode uint32
if mode, b, err = unmarshalUint32Safe(b); err == nil {
err = f.Chmod(os.FileMode(mode))
if err == nil {
err = f.Chmod(fs.FileMode())
}
}
if (p.Flags & sshFileXferAttrACmodTime) != 0 {
var atime uint32
var mtime uint32
if atime, b, err = unmarshalUint32Safe(b); err != nil {
} else if mtime, b, err = unmarshalUint32Safe(b); err != nil {
} else {
atimeT := time.Unix(int64(atime), 0)
mtimeT := time.Unix(int64(mtime), 0)
err = os.Chtimes(f.Name(), atimeT, mtimeT)
if err == nil {
err = os.Chtimes(path, fs.AccessTime(), fs.ModTime())
}
}
if (p.Flags & sshFileXferAttrUIDGID) != 0 {
var uid uint32
var gid uint32
if uid, b, err = unmarshalUint32Safe(b); err != nil {
} else if gid, _, err = unmarshalUint32Safe(b); err != nil {
} else {
err = f.Chown(int(uid), int(gid))
if err == nil {
err = f.Chown(int(fs.UID), int(fs.GID))
}
}

View File

@ -591,18 +591,25 @@ ls -l /usr/bin/
goWords := spaceRegex.Split(goLine, -1)
opWords := spaceRegex.Split(opLine, -1)
// some fields are allowed to be different..
// words[2] and [3] as these are users & groups
// words[1] as the link count for directories like proc is unstable
// during testing as processes are created/destroyed.
// words[7] as timestamp on dirs can very for things like /tmp
for j, goWord := range goWords {
if j >= len(opWords) {
bad = true
break
}
opWord := opWords[j]
if goWord != opWord && j != 1 && j != 2 && j != 3 && j != 7 {
bad = true
if goWord != opWord {
switch j {
case 1, 2, 3, 7:
// words[1] as the link count for directories like proc is unstable
// words[2] and [3] as these are users & groups
// words[7] as timestamps on dirs can vary for things like /tmp
case 8:
// words[8] can either have full path or just the filename
bad = !strings.HasSuffix(opWord, "/" + goWord)
default:
bad = true
}
}
}
}

View File

@ -178,21 +178,22 @@ func TestOpenStatRace(t *testing.T) {
// openpacket finishes to fast to trigger race in tests
// need to add a small sleep on server to openpackets somehow
tmppath := path.Join(os.TempDir(), "stat_race")
pflags := flags(os.O_RDWR | os.O_CREATE | os.O_TRUNC)
pflags := toPflags(os.O_RDWR | os.O_CREATE | os.O_TRUNC)
ch := make(chan result, 3)
id1 := client.nextID()
id2 := client.nextID()
client.dispatchRequest(ch, &sshFxpOpenPacket{
ID: id1,
Path: tmppath,
Pflags: pflags,
})
id2 := client.nextID()
client.dispatchRequest(ch, &sshFxpLstatPacket{
ID: id2,
Path: tmppath,
})
testreply := func(id uint32) {
r := <-ch
require.NoError(t, r.err)
switch r.typ {
case sshFxpAttrs, sshFxpHandle: // ignore
case sshFxpStatus:
@ -208,6 +209,83 @@ func TestOpenStatRace(t *testing.T) {
checkServerAllocator(t, server)
}
func TestOpenWithPermissions(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
skipIfWindows(t)
client, server := clientServerPair(t)
defer client.Close()
defer server.Close()
tmppath := path.Join(os.TempDir(), "open_permissions")
defer os.Remove(tmppath)
pflags := toPflags(os.O_RDWR | os.O_CREATE | os.O_TRUNC)
id1 := client.nextID()
id2 := client.nextID()
typ, data, err := client.sendPacket(ctx, nil, &sshFxpOpenPacket{
ID: id1,
Path: tmppath,
Pflags: pflags,
Flags: sshFileXferAttrPermissions,
Attrs: &FileStat{
Mode: 0o745,
},
})
if err != nil {
t.Fatal("unexpected error:", err)
}
switch typ {
case sshFxpHandle:
// do nothing, we can just leave the handle open.
case sshFxpStatus:
t.Fatal("unexpected status:", normaliseError(unmarshalStatus(id1, data)))
default:
t.Fatal("unpexpected packet type:", unimplementedPacketErr(typ))
}
stat, err := os.Stat(tmppath)
if err != nil {
t.Fatal("unexpected error:", err)
}
if stat.Mode()&os.ModePerm != 0o745 {
t.Errorf("stat.Mode() = %v was expecting 0o745", stat.Mode())
}
// Existing files should not have their permissions changed.
typ, data, err = client.sendPacket(ctx, nil, &sshFxpOpenPacket{
ID: id2,
Path: tmppath,
Pflags: pflags,
Flags: sshFileXferAttrPermissions,
Attrs: &FileStat{
Mode: 0o755,
},
})
if err != nil {
t.Fatal("unexpected error:", err)
}
switch typ {
case sshFxpHandle:
// do nothing, we can just leave the handle open.
case sshFxpStatus:
t.Fatal("unexpected status:", normaliseError(unmarshalStatus(id2, data)))
default:
t.Fatal("unpexpected packet type:", unimplementedPacketErr(typ))
}
if stat.Mode()&os.ModePerm != 0o745 {
t.Errorf("stat.Mode() = %v, was expecting unchanged 0o745", stat.Mode())
}
checkServerAllocator(t, server)
}
// Ensure that proper error codes are returned for non existent files, such
// that they are mapped back to a 'not exists' error on the client side.
func TestStatNonExistent(t *testing.T) {