mirror of https://github.com/pkg/sftp.git
rework client to prevent after-close usage, and support perm at open
This commit is contained in:
parent
22452ea54d
commit
d1903fbd46
19
attrs.go
19
attrs.go
|
@ -32,10 +32,10 @@ func (fi *fileInfo) Name() string { return fi.name }
|
|||
func (fi *fileInfo) Size() int64 { return int64(fi.stat.Size) }
|
||||
|
||||
// Mode returns file mode bits.
|
||||
func (fi *fileInfo) Mode() os.FileMode { return toFileMode(fi.stat.Mode) }
|
||||
func (fi *fileInfo) Mode() os.FileMode { return fi.stat.FileMode() }
|
||||
|
||||
// ModTime returns the last modification time of the file.
|
||||
func (fi *fileInfo) ModTime() time.Time { return time.Unix(int64(fi.stat.Mtime), 0) }
|
||||
func (fi *fileInfo) ModTime() time.Time { return fi.stat.ModTime() }
|
||||
|
||||
// IsDir returns true if the file is a directory.
|
||||
func (fi *fileInfo) IsDir() bool { return fi.Mode().IsDir() }
|
||||
|
@ -56,6 +56,21 @@ type FileStat struct {
|
|||
Extended []StatExtended
|
||||
}
|
||||
|
||||
// ModTime returns the Mtime SFTP file attribute converted to a time.Time
|
||||
func (fs *FileStat) ModTime() time.Time {
|
||||
return time.Unix(int64(fs.Mtime), 0)
|
||||
}
|
||||
|
||||
// AccessTime returns the Atime SFTP file attribute converted to a time.Time
|
||||
func (fs *FileStat) AccessTime() time.Time {
|
||||
return time.Unix(int64(fs.Atime), 0)
|
||||
}
|
||||
|
||||
// FileMode returns the Mode SFTP file attribute converted to an os.FileMode
|
||||
func (fs *FileStat) FileMode() os.FileMode {
|
||||
return toFileMode(fs.Mode)
|
||||
}
|
||||
|
||||
// StatExtended contains additional, extended information for a FileStat.
|
||||
type StatExtended struct {
|
||||
ExtType string
|
||||
|
|
156
client.go
156
client.go
|
@ -257,7 +257,7 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie
|
|||
// read/write at the same time. For those services you will need to use
|
||||
// `client.OpenFile(os.O_WRONLY|os.O_CREATE|os.O_TRUNC)`.
|
||||
func (c *Client) Create(path string) (*File, error) {
|
||||
return c.open(path, flags(os.O_RDWR|os.O_CREATE|os.O_TRUNC))
|
||||
return c.open(path, toPflags(os.O_RDWR|os.O_CREATE|os.O_TRUNC))
|
||||
}
|
||||
|
||||
const sftpProtocolVersion = 3 // https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt
|
||||
|
@ -510,7 +510,7 @@ func (c *Client) Symlink(oldname, newname string) error {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *Client) setfstat(handle string, flags uint32, attrs interface{}) error {
|
||||
func (c *Client) fsetstat(handle string, flags uint32, attrs interface{}) error {
|
||||
id := c.nextID()
|
||||
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpFsetstatPacket{
|
||||
ID: id,
|
||||
|
@ -590,14 +590,14 @@ func (c *Client) Truncate(path string, size int64) error {
|
|||
// returned file can be used for reading; the associated file descriptor
|
||||
// has mode O_RDONLY.
|
||||
func (c *Client) Open(path string) (*File, error) {
|
||||
return c.open(path, flags(os.O_RDONLY))
|
||||
return c.open(path, toPflags(os.O_RDONLY))
|
||||
}
|
||||
|
||||
// OpenFile is the generalized open call; most users will use Open or
|
||||
// Create instead. It opens the named file with specified flag (O_RDONLY
|
||||
// etc.). If successful, methods on the returned File can be used for I/O.
|
||||
func (c *Client) OpenFile(path string, f int) (*File, error) {
|
||||
return c.open(path, flags(f))
|
||||
return c.open(path, toPflags(f))
|
||||
}
|
||||
|
||||
func (c *Client) open(path string, pflags uint32) (*File, error) {
|
||||
|
@ -976,16 +976,26 @@ func (c *Client) RemoveAll(path string) error {
|
|||
type File struct {
|
||||
c *Client
|
||||
path string
|
||||
handle string
|
||||
|
||||
mu sync.Mutex
|
||||
mu sync.RWMutex
|
||||
handle string
|
||||
offset int64 // current offset within remote file
|
||||
}
|
||||
|
||||
// Close closes the File, rendering it unusable for I/O. It returns an
|
||||
// error, if any.
|
||||
func (f *File) Close() error {
|
||||
return f.c.close(f.handle)
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
if f.handle == "" {
|
||||
return os.ErrClosed
|
||||
}
|
||||
|
||||
handle := f.handle
|
||||
f.handle = ""
|
||||
|
||||
return f.c.close(handle)
|
||||
}
|
||||
|
||||
// Name returns the name of the file as presented to Open or Create.
|
||||
|
@ -1006,7 +1016,11 @@ func (f *File) Read(b []byte) (int, error) {
|
|||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
n, err := f.ReadAt(b, f.offset)
|
||||
if f.handle == "" {
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
|
||||
n, err := f.readAt(b, f.offset)
|
||||
f.offset += int64(n)
|
||||
return n, err
|
||||
}
|
||||
|
@ -1071,6 +1085,17 @@ func (f *File) readAtSequential(b []byte, off int64) (read int, err error) {
|
|||
// the number of bytes read and an error, if any. ReadAt follows io.ReaderAt semantics,
|
||||
// so the file offset is not altered during the read.
|
||||
func (f *File) ReadAt(b []byte, off int64) (int, error) {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
|
||||
if f.handle == "" {
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
|
||||
return f.readAt(b, off)
|
||||
}
|
||||
|
||||
func (f *File) readAt(b []byte, off int64) (int, error) {
|
||||
if len(b) <= f.c.maxPacket {
|
||||
// This should be able to be serviced with 1/2 requests.
|
||||
// So, just do it directly.
|
||||
|
@ -1267,6 +1292,10 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) {
|
|||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
if f.handle == "" {
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
|
||||
if f.c.disableConcurrentReads {
|
||||
return f.writeToSequential(w)
|
||||
}
|
||||
|
@ -1456,9 +1485,20 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (f *File) Stat() (os.FileInfo, error) {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
|
||||
if f.handle == "" {
|
||||
return nil, os.ErrClosed
|
||||
}
|
||||
|
||||
return f.stat()
|
||||
}
|
||||
|
||||
// Stat returns the FileInfo structure describing file. If there is an
|
||||
// error.
|
||||
func (f *File) Stat() (os.FileInfo, error) {
|
||||
func (f *File) stat() (os.FileInfo, error) {
|
||||
fs, err := f.c.fstat(f.handle)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -1478,7 +1518,11 @@ func (f *File) Write(b []byte) (int, error) {
|
|||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
n, err := f.WriteAt(b, f.offset)
|
||||
if f.handle == "" {
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
|
||||
n, err := f.writeAt(b, f.offset)
|
||||
f.offset += int64(n)
|
||||
return n, err
|
||||
}
|
||||
|
@ -1636,6 +1680,17 @@ func (f *File) writeAtConcurrent(b []byte, off int64) (int, error) {
|
|||
// the number of bytes written and an error, if any. WriteAt follows io.WriterAt semantics,
|
||||
// so the file offset is not altered during the write.
|
||||
func (f *File) WriteAt(b []byte, off int64) (written int, err error) {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
|
||||
if f.handle == "" {
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
|
||||
return f.writeAt(b, off)
|
||||
}
|
||||
|
||||
func (f *File) writeAt(b []byte, off int64) (written int, err error) {
|
||||
if len(b) <= f.c.maxPacket {
|
||||
// We can do this in one write.
|
||||
return f.writeChunkAt(nil, b, off)
|
||||
|
@ -1675,6 +1730,17 @@ func (f *File) WriteAt(b []byte, off int64) (written int, err error) {
|
|||
//
|
||||
// Otherwise, the given concurrency will be capped by the Client's max concurrency.
|
||||
func (f *File) ReadFromWithConcurrency(r io.Reader, concurrency int) (read int64, err error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
if f.handle == "" {
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
|
||||
return f.readFromWithConcurrency(r, concurrency)
|
||||
}
|
||||
|
||||
func (f *File) readFromWithConcurrency(r io.Reader, concurrency int) (read int64, err error) {
|
||||
// Split the write into multiple maxPacket sized concurrent writes.
|
||||
// This allows writes with a suitably large reader
|
||||
// to transfer data at a much faster rate due to overlapping round trip times.
|
||||
|
@ -1824,6 +1890,10 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
|
|||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
if f.handle == "" {
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
|
||||
if f.c.useConcurrentWrites {
|
||||
var remain int64
|
||||
switch r := r.(type) {
|
||||
|
@ -1845,7 +1915,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
|
|||
|
||||
if remain < 0 {
|
||||
// We can strongly assert that we want default max concurrency here.
|
||||
return f.ReadFromWithConcurrency(r, f.c.maxConcurrentRequests)
|
||||
return f.readFromWithConcurrency(r, f.c.maxConcurrentRequests)
|
||||
}
|
||||
|
||||
if remain > int64(f.c.maxPacket) {
|
||||
|
@ -1860,7 +1930,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
|
|||
concurrency64 = int64(f.c.maxConcurrentRequests)
|
||||
}
|
||||
|
||||
return f.ReadFromWithConcurrency(r, int(concurrency64))
|
||||
return f.readFromWithConcurrency(r, int(concurrency64))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1903,12 +1973,16 @@ func (f *File) Seek(offset int64, whence int) (int64, error) {
|
|||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
if f.handle == "" {
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
|
||||
switch whence {
|
||||
case io.SeekStart:
|
||||
case io.SeekCurrent:
|
||||
offset += f.offset
|
||||
case io.SeekEnd:
|
||||
fi, err := f.Stat()
|
||||
fi, err := f.stat()
|
||||
if err != nil {
|
||||
return f.offset, err
|
||||
}
|
||||
|
@ -1927,20 +2001,61 @@ func (f *File) Seek(offset int64, whence int) (int64, error) {
|
|||
|
||||
// Chown changes the uid/gid of the current file.
|
||||
func (f *File) Chown(uid, gid int) error {
|
||||
return f.c.Chown(f.path, uid, gid)
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
|
||||
if f.handle == "" {
|
||||
return os.ErrClosed
|
||||
}
|
||||
|
||||
return f.c.fsetstat(f.handle, sshFileXferAttrUIDGID, &FileStat{
|
||||
UID: uint32(uid),
|
||||
GID: uint32(gid),
|
||||
})
|
||||
}
|
||||
|
||||
// Chmod changes the permissions of the current file.
|
||||
//
|
||||
// See Client.Chmod for details.
|
||||
func (f *File) Chmod(mode os.FileMode) error {
|
||||
return f.c.setfstat(f.handle, sshFileXferAttrPermissions, toChmodPerm(mode))
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
|
||||
if f.handle == "" {
|
||||
return os.ErrClosed
|
||||
}
|
||||
|
||||
return f.c.fsetstat(f.handle, sshFileXferAttrPermissions, toChmodPerm(mode))
|
||||
}
|
||||
|
||||
// Truncate sets the size of the current file. Although it may be safely assumed
|
||||
// that if the size is less than its current size it will be truncated to fit,
|
||||
// the SFTP protocol does not specify what behavior the server should do when setting
|
||||
// size greater than the current size.
|
||||
// We send a SSH_FXP_FSETSTAT here since we have a file handle
|
||||
func (f *File) Truncate(size int64) error {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
|
||||
if f.handle == "" {
|
||||
return os.ErrClosed
|
||||
}
|
||||
|
||||
return f.c.fsetstat(f.handle, sshFileXferAttrSize, uint64(size))
|
||||
}
|
||||
|
||||
// Sync requests a flush of the contents of a File to stable storage.
|
||||
//
|
||||
// Sync requires the server to support the fsync@openssh.com extension.
|
||||
func (f *File) Sync() error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
if f.handle == "" {
|
||||
return os.ErrClosed
|
||||
}
|
||||
|
||||
|
||||
id := f.c.nextID()
|
||||
typ, data, err := f.c.sendPacket(context.Background(), nil, &sshFxpFsyncPacket{
|
||||
ID: id,
|
||||
|
@ -1957,15 +2072,6 @@ func (f *File) Sync() error {
|
|||
}
|
||||
}
|
||||
|
||||
// Truncate sets the size of the current file. Although it may be safely assumed
|
||||
// that if the size is less than its current size it will be truncated to fit,
|
||||
// the SFTP protocol does not specify what behavior the server should do when setting
|
||||
// size greater than the current size.
|
||||
// We send a SSH_FXP_FSETSTAT here since we have a file handle
|
||||
func (f *File) Truncate(size int64) error {
|
||||
return f.c.setfstat(f.handle, sshFileXferAttrSize, uint64(size))
|
||||
}
|
||||
|
||||
// normaliseError normalises an error into a more standard form that can be
|
||||
// checked against stdlib errors like io.EOF or os.ErrNotExist.
|
||||
func normaliseError(err error) error {
|
||||
|
@ -1990,7 +2096,7 @@ func normaliseError(err error) error {
|
|||
|
||||
// flags converts the flags passed to OpenFile into ssh flags.
|
||||
// Unsupported flags are ignored.
|
||||
func flags(f int) uint32 {
|
||||
func toPflags(f int) uint32 {
|
||||
var out uint32
|
||||
switch f & os.O_WRONLY {
|
||||
case os.O_WRONLY:
|
||||
|
|
|
@ -81,7 +81,7 @@ var flagsTests = []struct {
|
|||
|
||||
func TestFlags(t *testing.T) {
|
||||
for i, tt := range flagsTests {
|
||||
got := flags(tt.flags)
|
||||
got := toPflags(tt.flags)
|
||||
if got != tt.want {
|
||||
t.Errorf("test %v: flags(%x): want: %x, got: %x", i, tt.flags, tt.want, got)
|
||||
}
|
||||
|
|
93
packet.go
93
packet.go
|
@ -56,6 +56,11 @@ func marshalFileInfo(b []byte, fi os.FileInfo) []byte {
|
|||
flags, fileStat := fileStatFromInfo(fi)
|
||||
|
||||
b = marshalUint32(b, flags)
|
||||
|
||||
return marshalFileStat(b, flags, fileStat)
|
||||
}
|
||||
|
||||
func marshalFileStat(b []byte, flags uint32, fileStat *FileStat) []byte {
|
||||
if flags&sshFileXferAttrSize != 0 {
|
||||
b = marshalUint64(b, fileStat.Size)
|
||||
}
|
||||
|
@ -91,10 +96,9 @@ func marshalStatus(b []byte, err StatusError) []byte {
|
|||
}
|
||||
|
||||
func marshal(b []byte, v interface{}) []byte {
|
||||
if v == nil {
|
||||
return b
|
||||
}
|
||||
switch v := v.(type) {
|
||||
case nil:
|
||||
return b
|
||||
case uint8:
|
||||
return append(b, v)
|
||||
case uint32:
|
||||
|
@ -103,6 +107,8 @@ func marshal(b []byte, v interface{}) []byte {
|
|||
return marshalUint64(b, v)
|
||||
case string:
|
||||
return marshalString(b, v)
|
||||
case []byte:
|
||||
return append(b, v...)
|
||||
case os.FileInfo:
|
||||
return marshalFileInfo(b, v)
|
||||
default:
|
||||
|
@ -180,8 +186,6 @@ func unmarshalFileStat(flags uint32, b []byte) (*FileStat, []byte) {
|
|||
}
|
||||
if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID {
|
||||
fs.UID, b, _ = unmarshalUint32Safe(b)
|
||||
}
|
||||
if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID {
|
||||
fs.GID, b, _ = unmarshalUint32Safe(b)
|
||||
}
|
||||
if flags&sshFileXferAttrPermissions == sshFileXferAttrPermissions {
|
||||
|
@ -681,12 +685,13 @@ type sshFxpOpenPacket struct {
|
|||
ID uint32
|
||||
Path string
|
||||
Pflags uint32
|
||||
Flags uint32 // ignored
|
||||
Flags uint32
|
||||
Attrs interface{}
|
||||
}
|
||||
|
||||
func (p *sshFxpOpenPacket) id() uint32 { return p.ID }
|
||||
|
||||
func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) {
|
||||
func (p *sshFxpOpenPacket) marshalPacket() ([]byte, []byte, error) {
|
||||
l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
|
||||
4 + len(p.Path) +
|
||||
4 + 4
|
||||
|
@ -698,7 +703,20 @@ func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) {
|
|||
b = marshalUint32(b, p.Pflags)
|
||||
b = marshalUint32(b, p.Flags)
|
||||
|
||||
return b, nil
|
||||
switch attrs := p.Attrs.(type) {
|
||||
case os.FileInfo:
|
||||
_, fs := fileStatFromInfo(attrs) // we throw away the flags, and override with those in packet.
|
||||
return b, marshalFileStat(nil, p.Flags, fs), nil
|
||||
case *FileStat:
|
||||
return b, marshalFileStat(nil, p.Flags, attrs), nil
|
||||
}
|
||||
|
||||
return b, marshal(nil, p.Attrs), nil
|
||||
}
|
||||
|
||||
func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) {
|
||||
header, payload, err := p.marshalPacket()
|
||||
return append(header, payload...), err
|
||||
}
|
||||
|
||||
func (p *sshFxpOpenPacket) UnmarshalBinary(b []byte) error {
|
||||
|
@ -709,12 +727,25 @@ func (p *sshFxpOpenPacket) UnmarshalBinary(b []byte) error {
|
|||
return err
|
||||
} else if p.Pflags, b, err = unmarshalUint32Safe(b); err != nil {
|
||||
return err
|
||||
} else if p.Flags, _, err = unmarshalUint32Safe(b); err != nil {
|
||||
} else if p.Flags, b, err = unmarshalUint32Safe(b); err != nil {
|
||||
return err
|
||||
}
|
||||
p.Attrs = b
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *sshFxpOpenPacket) unmarshalFileStat(flags uint32) *FileStat {
|
||||
switch attrs := p.Attrs.(type) {
|
||||
case *FileStat:
|
||||
return attrs
|
||||
case []byte:
|
||||
fs, _ := unmarshalFileStat(flags, attrs)
|
||||
return fs
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid type in unmarshalFileStat: %T", attrs))
|
||||
}
|
||||
}
|
||||
|
||||
type sshFxpReadPacket struct {
|
||||
ID uint32
|
||||
Len uint32
|
||||
|
@ -943,9 +974,15 @@ func (p *sshFxpSetstatPacket) marshalPacket() ([]byte, []byte, error) {
|
|||
b = marshalString(b, p.Path)
|
||||
b = marshalUint32(b, p.Flags)
|
||||
|
||||
payload := marshal(nil, p.Attrs)
|
||||
switch attrs := p.Attrs.(type) {
|
||||
case os.FileInfo:
|
||||
_, fs := fileStatFromInfo(attrs) // we throw away the flags, and override with those in packet.
|
||||
return b, marshalFileStat(nil, p.Flags, fs), nil
|
||||
case *FileStat:
|
||||
return b, marshalFileStat(nil, p.Flags, attrs), nil
|
||||
}
|
||||
|
||||
return b, payload, nil
|
||||
return b, marshal(nil, p.Attrs), nil
|
||||
}
|
||||
|
||||
func (p *sshFxpSetstatPacket) MarshalBinary() ([]byte, error) {
|
||||
|
@ -964,9 +1001,15 @@ func (p *sshFxpFsetstatPacket) marshalPacket() ([]byte, []byte, error) {
|
|||
b = marshalString(b, p.Handle)
|
||||
b = marshalUint32(b, p.Flags)
|
||||
|
||||
payload := marshal(nil, p.Attrs)
|
||||
switch attrs := p.Attrs.(type) {
|
||||
case os.FileInfo:
|
||||
_, fs := fileStatFromInfo(attrs) // we throw away the flags, and override with those in packet.
|
||||
return b, marshalFileStat(nil, p.Flags, fs), nil
|
||||
case *FileStat:
|
||||
return b, marshalFileStat(nil, p.Flags, attrs), nil
|
||||
}
|
||||
|
||||
return b, payload, nil
|
||||
return b, marshal(nil, p.Attrs), nil
|
||||
}
|
||||
|
||||
func (p *sshFxpFsetstatPacket) MarshalBinary() ([]byte, error) {
|
||||
|
@ -987,6 +1030,18 @@ func (p *sshFxpSetstatPacket) UnmarshalBinary(b []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *sshFxpSetstatPacket) unmarshalFileStat(flags uint32) *FileStat {
|
||||
switch attrs := p.Attrs.(type) {
|
||||
case *FileStat:
|
||||
return attrs
|
||||
case []byte:
|
||||
fs, _ := unmarshalFileStat(flags, attrs)
|
||||
return fs
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid type in unmarshalFileStat: %T", attrs))
|
||||
}
|
||||
}
|
||||
|
||||
func (p *sshFxpFsetstatPacket) UnmarshalBinary(b []byte) error {
|
||||
var err error
|
||||
if p.ID, b, err = unmarshalUint32Safe(b); err != nil {
|
||||
|
@ -1000,6 +1055,18 @@ func (p *sshFxpFsetstatPacket) UnmarshalBinary(b []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *sshFxpFsetstatPacket) unmarshalFileStat(flags uint32) *FileStat {
|
||||
switch attrs := p.Attrs.(type) {
|
||||
case *FileStat:
|
||||
return attrs
|
||||
case []byte:
|
||||
fs, _ := unmarshalFileStat(flags, attrs)
|
||||
return fs
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid type in unmarshalFileStat: %T", attrs))
|
||||
}
|
||||
}
|
||||
|
||||
type sshFxpHandlePacket struct {
|
||||
ID uint32
|
||||
Handle string
|
||||
|
|
|
@ -376,7 +376,7 @@ func TestSendPacket(t *testing.T) {
|
|||
packet: &sshFxpOpenPacket{
|
||||
ID: 1,
|
||||
Path: "/foo",
|
||||
Pflags: flags(os.O_RDONLY),
|
||||
Pflags: toPflags(os.O_RDONLY),
|
||||
},
|
||||
want: []byte{
|
||||
0x0, 0x0, 0x0, 0x15,
|
||||
|
@ -387,6 +387,26 @@ func TestSendPacket(t *testing.T) {
|
|||
0x0, 0x0, 0x0, 0x0,
|
||||
},
|
||||
},
|
||||
{
|
||||
packet: &sshFxpOpenPacket{
|
||||
ID: 3,
|
||||
Path: "/foo",
|
||||
Pflags: toPflags(os.O_WRONLY | os.O_CREATE | os.O_TRUNC),
|
||||
Flags: sshFileXferAttrPermissions,
|
||||
Attrs: &FileStat{
|
||||
Mode: 0o755,
|
||||
},
|
||||
},
|
||||
want: []byte{
|
||||
0x0, 0x0, 0x0, 0x19,
|
||||
0x3,
|
||||
0x0, 0x0, 0x0, 0x3,
|
||||
0x0, 0x0, 0x0, 0x4, '/', 'f', 'o', 'o',
|
||||
0x0, 0x0, 0x0, 0x1a,
|
||||
0x0, 0x0, 0x0, 0x4,
|
||||
0x0, 0x0, 0x1, 0xed,
|
||||
},
|
||||
},
|
||||
{
|
||||
packet: &sshFxpWritePacket{
|
||||
ID: 124,
|
||||
|
@ -409,10 +429,7 @@ func TestSendPacket(t *testing.T) {
|
|||
ID: 31,
|
||||
Path: "/bar",
|
||||
Flags: sshFileXferAttrUIDGID,
|
||||
Attrs: struct {
|
||||
UID uint32
|
||||
GID uint32
|
||||
}{
|
||||
Attrs: &FileStat{
|
||||
UID: 1000,
|
||||
GID: 100,
|
||||
},
|
||||
|
@ -611,7 +628,7 @@ func BenchmarkMarshalOpen(b *testing.B) {
|
|||
benchMarshal(b, &sshFxpOpenPacket{
|
||||
ID: 1,
|
||||
Path: "/home/test/some/random/path",
|
||||
Pflags: flags(os.O_RDONLY),
|
||||
Pflags: toPflags(os.O_RDONLY),
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -3,7 +3,6 @@ package sftp
|
|||
// Methods on the Request object to make working with the Flags bitmasks and
|
||||
// Attr(ibutes) byte blob easier. Use Pflags() when working with an Open/Write
|
||||
// request and AttrFlags() and Attributes() when working with SetStat requests.
|
||||
import "os"
|
||||
|
||||
// FileOpenFlags defines Open and Write Flags. Correlate directly with with os.OpenFile flags
|
||||
// (https://golang.org/pkg/os/#pkg-constants).
|
||||
|
@ -50,11 +49,6 @@ func (r *Request) AttrFlags() FileAttrFlags {
|
|||
return newFileAttrFlags(r.Flags)
|
||||
}
|
||||
|
||||
// FileMode returns the Mode SFTP file attributes wrapped as os.FileMode
|
||||
func (a FileStat) FileMode() os.FileMode {
|
||||
return os.FileMode(a.Mode)
|
||||
}
|
||||
|
||||
// Attributes parses file attributes byte blob and return them in a
|
||||
// FileStat object.
|
||||
func (r *Request) Attributes() *FileStat {
|
||||
|
|
87
server.go
87
server.go
|
@ -13,7 +13,6 @@ import (
|
|||
"strconv"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -462,7 +461,15 @@ func (p *sshFxpOpenPacket) respond(svr *Server) responsePacket {
|
|||
osFlags |= os.O_EXCL
|
||||
}
|
||||
|
||||
f, err := os.OpenFile(svr.toLocalPath(p.Path), osFlags, 0o644)
|
||||
mode := os.FileMode(0o644)
|
||||
// Like OpenSSH, we only handle permissions here, if the file is being created.
|
||||
// Otherwise, the permissions are ignored.
|
||||
if p.Flags & sshFileXferAttrPermissions != 0 {
|
||||
fs := p.unmarshalFileStat(p.Flags)
|
||||
mode = fs.FileMode() & os.ModePerm
|
||||
}
|
||||
|
||||
f, err := os.OpenFile(svr.toLocalPath(p.Path), osFlags, mode)
|
||||
if err != nil {
|
||||
return statusFromError(p.ID, err)
|
||||
}
|
||||
|
@ -496,43 +503,32 @@ func (p *sshFxpReaddirPacket) respond(svr *Server) responsePacket {
|
|||
}
|
||||
|
||||
func (p *sshFxpSetstatPacket) respond(svr *Server) responsePacket {
|
||||
// additional unmarshalling is required for each possibility here
|
||||
b := p.Attrs.([]byte)
|
||||
path := svr.toLocalPath(p.Path)
|
||||
|
||||
debug("setstat name %q", path)
|
||||
|
||||
fs := p.unmarshalFileStat(p.Flags)
|
||||
|
||||
var err error
|
||||
|
||||
p.Path = svr.toLocalPath(p.Path)
|
||||
|
||||
debug("setstat name \"%s\"", p.Path)
|
||||
if (p.Flags & sshFileXferAttrSize) != 0 {
|
||||
var size uint64
|
||||
if size, b, err = unmarshalUint64Safe(b); err == nil {
|
||||
err = os.Truncate(p.Path, int64(size))
|
||||
if err == nil {
|
||||
err = os.Truncate(path, int64(fs.Size))
|
||||
}
|
||||
}
|
||||
if (p.Flags & sshFileXferAttrPermissions) != 0 {
|
||||
var mode uint32
|
||||
if mode, b, err = unmarshalUint32Safe(b); err == nil {
|
||||
err = os.Chmod(p.Path, os.FileMode(mode))
|
||||
if err == nil {
|
||||
err = os.Chmod(path, fs.FileMode())
|
||||
}
|
||||
}
|
||||
if (p.Flags & sshFileXferAttrACmodTime) != 0 {
|
||||
var atime uint32
|
||||
var mtime uint32
|
||||
if atime, b, err = unmarshalUint32Safe(b); err != nil {
|
||||
} else if mtime, b, err = unmarshalUint32Safe(b); err != nil {
|
||||
} else {
|
||||
atimeT := time.Unix(int64(atime), 0)
|
||||
mtimeT := time.Unix(int64(mtime), 0)
|
||||
err = os.Chtimes(p.Path, atimeT, mtimeT)
|
||||
if err == nil {
|
||||
err = os.Chtimes(path, fs.AccessTime(), fs.ModTime())
|
||||
}
|
||||
}
|
||||
if (p.Flags & sshFileXferAttrUIDGID) != 0 {
|
||||
var uid uint32
|
||||
var gid uint32
|
||||
if uid, b, err = unmarshalUint32Safe(b); err != nil {
|
||||
} else if gid, _, err = unmarshalUint32Safe(b); err != nil {
|
||||
} else {
|
||||
err = os.Chown(p.Path, int(uid), int(gid))
|
||||
if err == nil {
|
||||
err = os.Chown(path, int(fs.UID), int(fs.GID))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -545,41 +541,32 @@ func (p *sshFxpFsetstatPacket) respond(svr *Server) responsePacket {
|
|||
return statusFromError(p.ID, EBADF)
|
||||
}
|
||||
|
||||
// additional unmarshalling is required for each possibility here
|
||||
b := p.Attrs.([]byte)
|
||||
path := f.Name()
|
||||
|
||||
debug("fsetstat name %q", path)
|
||||
|
||||
fs := p.unmarshalFileStat(p.Flags)
|
||||
|
||||
var err error
|
||||
|
||||
debug("fsetstat name \"%s\"", f.Name())
|
||||
if (p.Flags & sshFileXferAttrSize) != 0 {
|
||||
var size uint64
|
||||
if size, b, err = unmarshalUint64Safe(b); err == nil {
|
||||
err = f.Truncate(int64(size))
|
||||
if err == nil {
|
||||
err = f.Truncate(int64(fs.Size))
|
||||
}
|
||||
}
|
||||
if (p.Flags & sshFileXferAttrPermissions) != 0 {
|
||||
var mode uint32
|
||||
if mode, b, err = unmarshalUint32Safe(b); err == nil {
|
||||
err = f.Chmod(os.FileMode(mode))
|
||||
if err == nil {
|
||||
err = f.Chmod(fs.FileMode())
|
||||
}
|
||||
}
|
||||
if (p.Flags & sshFileXferAttrACmodTime) != 0 {
|
||||
var atime uint32
|
||||
var mtime uint32
|
||||
if atime, b, err = unmarshalUint32Safe(b); err != nil {
|
||||
} else if mtime, b, err = unmarshalUint32Safe(b); err != nil {
|
||||
} else {
|
||||
atimeT := time.Unix(int64(atime), 0)
|
||||
mtimeT := time.Unix(int64(mtime), 0)
|
||||
err = os.Chtimes(f.Name(), atimeT, mtimeT)
|
||||
if err == nil {
|
||||
err = os.Chtimes(path, fs.AccessTime(), fs.ModTime())
|
||||
}
|
||||
}
|
||||
if (p.Flags & sshFileXferAttrUIDGID) != 0 {
|
||||
var uid uint32
|
||||
var gid uint32
|
||||
if uid, b, err = unmarshalUint32Safe(b); err != nil {
|
||||
} else if gid, _, err = unmarshalUint32Safe(b); err != nil {
|
||||
} else {
|
||||
err = f.Chown(int(uid), int(gid))
|
||||
if err == nil {
|
||||
err = f.Chown(int(fs.UID), int(fs.GID))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -591,21 +591,28 @@ ls -l /usr/bin/
|
|||
goWords := spaceRegex.Split(goLine, -1)
|
||||
opWords := spaceRegex.Split(opLine, -1)
|
||||
// some fields are allowed to be different..
|
||||
// words[2] and [3] as these are users & groups
|
||||
// words[1] as the link count for directories like proc is unstable
|
||||
// during testing as processes are created/destroyed.
|
||||
// words[7] as timestamp on dirs can very for things like /tmp
|
||||
for j, goWord := range goWords {
|
||||
if j >= len(opWords) {
|
||||
bad = true
|
||||
break
|
||||
}
|
||||
opWord := opWords[j]
|
||||
if goWord != opWord && j != 1 && j != 2 && j != 3 && j != 7 {
|
||||
if goWord != opWord {
|
||||
switch j {
|
||||
case 1, 2, 3, 7:
|
||||
// words[1] as the link count for directories like proc is unstable
|
||||
// words[2] and [3] as these are users & groups
|
||||
// words[7] as timestamps on dirs can vary for things like /tmp
|
||||
case 8:
|
||||
// words[8] can either have full path or just the filename
|
||||
bad = !strings.HasSuffix(opWord, "/" + goWord)
|
||||
default:
|
||||
bad = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if bad {
|
||||
t.Errorf("outputs differ\n go: %q\nopenssh: %q\n", goLine, opLine)
|
||||
|
|
|
@ -178,21 +178,22 @@ func TestOpenStatRace(t *testing.T) {
|
|||
// openpacket finishes to fast to trigger race in tests
|
||||
// need to add a small sleep on server to openpackets somehow
|
||||
tmppath := path.Join(os.TempDir(), "stat_race")
|
||||
pflags := flags(os.O_RDWR | os.O_CREATE | os.O_TRUNC)
|
||||
pflags := toPflags(os.O_RDWR | os.O_CREATE | os.O_TRUNC)
|
||||
ch := make(chan result, 3)
|
||||
id1 := client.nextID()
|
||||
id2 := client.nextID()
|
||||
client.dispatchRequest(ch, &sshFxpOpenPacket{
|
||||
ID: id1,
|
||||
Path: tmppath,
|
||||
Pflags: pflags,
|
||||
})
|
||||
id2 := client.nextID()
|
||||
client.dispatchRequest(ch, &sshFxpLstatPacket{
|
||||
ID: id2,
|
||||
Path: tmppath,
|
||||
})
|
||||
testreply := func(id uint32) {
|
||||
r := <-ch
|
||||
require.NoError(t, r.err)
|
||||
switch r.typ {
|
||||
case sshFxpAttrs, sshFxpHandle: // ignore
|
||||
case sshFxpStatus:
|
||||
|
@ -208,6 +209,83 @@ func TestOpenStatRace(t *testing.T) {
|
|||
checkServerAllocator(t, server)
|
||||
}
|
||||
|
||||
func TestOpenWithPermissions(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
skipIfWindows(t)
|
||||
|
||||
client, server := clientServerPair(t)
|
||||
defer client.Close()
|
||||
defer server.Close()
|
||||
|
||||
tmppath := path.Join(os.TempDir(), "open_permissions")
|
||||
defer os.Remove(tmppath)
|
||||
|
||||
pflags := toPflags(os.O_RDWR | os.O_CREATE | os.O_TRUNC)
|
||||
|
||||
id1 := client.nextID()
|
||||
id2 := client.nextID()
|
||||
|
||||
typ, data, err := client.sendPacket(ctx, nil, &sshFxpOpenPacket{
|
||||
ID: id1,
|
||||
Path: tmppath,
|
||||
Pflags: pflags,
|
||||
Flags: sshFileXferAttrPermissions,
|
||||
Attrs: &FileStat{
|
||||
Mode: 0o745,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal("unexpected error:", err)
|
||||
}
|
||||
switch typ {
|
||||
case sshFxpHandle:
|
||||
// do nothing, we can just leave the handle open.
|
||||
case sshFxpStatus:
|
||||
t.Fatal("unexpected status:", normaliseError(unmarshalStatus(id1, data)))
|
||||
default:
|
||||
t.Fatal("unpexpected packet type:", unimplementedPacketErr(typ))
|
||||
}
|
||||
|
||||
stat, err := os.Stat(tmppath)
|
||||
if err != nil {
|
||||
t.Fatal("unexpected error:", err)
|
||||
}
|
||||
|
||||
if stat.Mode()&os.ModePerm != 0o745 {
|
||||
t.Errorf("stat.Mode() = %v was expecting 0o745", stat.Mode())
|
||||
}
|
||||
|
||||
// Existing files should not have their permissions changed.
|
||||
typ, data, err = client.sendPacket(ctx, nil, &sshFxpOpenPacket{
|
||||
ID: id2,
|
||||
Path: tmppath,
|
||||
Pflags: pflags,
|
||||
Flags: sshFileXferAttrPermissions,
|
||||
Attrs: &FileStat{
|
||||
Mode: 0o755,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal("unexpected error:", err)
|
||||
}
|
||||
switch typ {
|
||||
case sshFxpHandle:
|
||||
// do nothing, we can just leave the handle open.
|
||||
case sshFxpStatus:
|
||||
t.Fatal("unexpected status:", normaliseError(unmarshalStatus(id2, data)))
|
||||
default:
|
||||
t.Fatal("unpexpected packet type:", unimplementedPacketErr(typ))
|
||||
}
|
||||
|
||||
if stat.Mode()&os.ModePerm != 0o745 {
|
||||
t.Errorf("stat.Mode() = %v, was expecting unchanged 0o745", stat.Mode())
|
||||
}
|
||||
|
||||
checkServerAllocator(t, server)
|
||||
}
|
||||
|
||||
// Ensure that proper error codes are returned for non existent files, such
|
||||
// that they are mapped back to a 'not exists' error on the client side.
|
||||
func TestStatNonExistent(t *testing.T) {
|
||||
|
|
Loading…
Reference in New Issue