mirror of https://github.com/pkg/sftp.git
address code review
This commit is contained in:
parent
d1903fbd46
commit
f3501dc6ba
30
client.go
30
client.go
|
@ -363,7 +363,10 @@ func (c *Client) ReadDirContext(ctx context.Context, p string) ([]os.FileInfo, e
|
|||
filename, data = unmarshalString(data)
|
||||
_, data = unmarshalString(data) // discard longname
|
||||
var attr *FileStat
|
||||
attr, data = unmarshalAttrs(data)
|
||||
attr, data, err = unmarshalAttrs(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if filename == "." || filename == ".." {
|
||||
continue
|
||||
}
|
||||
|
@ -434,8 +437,8 @@ func (c *Client) Lstat(p string) (os.FileInfo, error) {
|
|||
if sid != id {
|
||||
return nil, &unexpectedIDErr{id, sid}
|
||||
}
|
||||
attr, _ := unmarshalAttrs(data)
|
||||
return fileInfoFromStat(attr, path.Base(p)), nil
|
||||
attr, _, err := unmarshalAttrs(data)
|
||||
return fileInfoFromStat(attr, path.Base(p)), err
|
||||
case sshFxpStatus:
|
||||
return nil, normaliseError(unmarshalStatus(id, data))
|
||||
default:
|
||||
|
@ -660,8 +663,8 @@ func (c *Client) stat(path string) (*FileStat, error) {
|
|||
if sid != id {
|
||||
return nil, &unexpectedIDErr{id, sid}
|
||||
}
|
||||
attr, _ := unmarshalAttrs(data)
|
||||
return attr, nil
|
||||
attr, _, err := unmarshalAttrs(data)
|
||||
return attr, err
|
||||
case sshFxpStatus:
|
||||
return nil, normaliseError(unmarshalStatus(id, data))
|
||||
default:
|
||||
|
@ -684,8 +687,8 @@ func (c *Client) fstat(handle string) (*FileStat, error) {
|
|||
if sid != id {
|
||||
return nil, &unexpectedIDErr{id, sid}
|
||||
}
|
||||
attr, _ := unmarshalAttrs(data)
|
||||
return attr, nil
|
||||
attr, _, err := unmarshalAttrs(data)
|
||||
return attr, err
|
||||
case sshFxpStatus:
|
||||
return nil, normaliseError(unmarshalStatus(id, data))
|
||||
default:
|
||||
|
@ -974,8 +977,8 @@ func (c *Client) RemoveAll(path string) error {
|
|||
|
||||
// File represents a remote file.
|
||||
type File struct {
|
||||
c *Client
|
||||
path string
|
||||
c *Client
|
||||
path string
|
||||
|
||||
mu sync.RWMutex
|
||||
handle string
|
||||
|
@ -992,6 +995,10 @@ func (f *File) Close() error {
|
|||
return os.ErrClosed
|
||||
}
|
||||
|
||||
// When `openssh-portable/sftp-server.c` is doing `handle_close`,
|
||||
// it will unconditionally mark the handle as unused,
|
||||
// so we need to also unconditionally mark this handle as invalid.
|
||||
|
||||
handle := f.handle
|
||||
f.handle = ""
|
||||
|
||||
|
@ -1485,6 +1492,8 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) {
|
|||
}
|
||||
}
|
||||
|
||||
// Stat returns the FileInfo structure describing file. If there is an
|
||||
// error.
|
||||
func (f *File) Stat() (os.FileInfo, error) {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
|
@ -1496,8 +1505,6 @@ func (f *File) Stat() (os.FileInfo, error) {
|
|||
return f.stat()
|
||||
}
|
||||
|
||||
// Stat returns the FileInfo structure describing file. If there is an
|
||||
// error.
|
||||
func (f *File) stat() (os.FileInfo, error) {
|
||||
fs, err := f.c.fstat(f.handle)
|
||||
if err != nil {
|
||||
|
@ -2055,7 +2062,6 @@ func (f *File) Sync() error {
|
|||
return os.ErrClosed
|
||||
}
|
||||
|
||||
|
||||
id := f.c.nextID()
|
||||
typ, data, err := f.c.sendPacket(context.Background(), nil, &sshFxpFsyncPacket{
|
||||
ID: id,
|
||||
|
|
89
packet.go
89
packet.go
|
@ -174,36 +174,69 @@ func unmarshalStringSafe(b []byte) (string, []byte, error) {
|
|||
return string(b[:n]), b[n:], nil
|
||||
}
|
||||
|
||||
func unmarshalAttrs(b []byte) (*FileStat, []byte) {
|
||||
flags, b := unmarshalUint32(b)
|
||||
func unmarshalAttrs(b []byte) (*FileStat, []byte, error) {
|
||||
flags, b, err := unmarshalUint32Safe(b)
|
||||
if err != nil {
|
||||
return nil, b, err
|
||||
}
|
||||
return unmarshalFileStat(flags, b)
|
||||
}
|
||||
|
||||
func unmarshalFileStat(flags uint32, b []byte) (*FileStat, []byte) {
|
||||
func unmarshalFileStat(flags uint32, b []byte) (*FileStat, []byte, error) {
|
||||
var fs FileStat
|
||||
var err error
|
||||
|
||||
if flags&sshFileXferAttrSize == sshFileXferAttrSize {
|
||||
fs.Size, b, _ = unmarshalUint64Safe(b)
|
||||
fs.Size, b, err = unmarshalUint64Safe(b)
|
||||
if err != nil {
|
||||
return nil, b, err
|
||||
}
|
||||
}
|
||||
if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID {
|
||||
fs.UID, b, _ = unmarshalUint32Safe(b)
|
||||
fs.GID, b, _ = unmarshalUint32Safe(b)
|
||||
fs.UID, b, err = unmarshalUint32Safe(b)
|
||||
if err != nil {
|
||||
return nil, b, err
|
||||
}
|
||||
fs.GID, b, err = unmarshalUint32Safe(b)
|
||||
if err != nil {
|
||||
return nil, b, err
|
||||
}
|
||||
}
|
||||
if flags&sshFileXferAttrPermissions == sshFileXferAttrPermissions {
|
||||
fs.Mode, b, _ = unmarshalUint32Safe(b)
|
||||
fs.Mode, b, err = unmarshalUint32Safe(b)
|
||||
if err != nil {
|
||||
return nil, b, err
|
||||
}
|
||||
}
|
||||
if flags&sshFileXferAttrACmodTime == sshFileXferAttrACmodTime {
|
||||
fs.Atime, b, _ = unmarshalUint32Safe(b)
|
||||
fs.Mtime, b, _ = unmarshalUint32Safe(b)
|
||||
fs.Atime, b, err = unmarshalUint32Safe(b)
|
||||
if err != nil {
|
||||
return nil, b, err
|
||||
}
|
||||
fs.Mtime, b, err = unmarshalUint32Safe(b)
|
||||
if err != nil {
|
||||
return nil, b, err
|
||||
}
|
||||
}
|
||||
if flags&sshFileXferAttrExtended == sshFileXferAttrExtended {
|
||||
var count uint32
|
||||
count, b, _ = unmarshalUint32Safe(b)
|
||||
count, b, err = unmarshalUint32Safe(b)
|
||||
if err != nil {
|
||||
return nil, b, err
|
||||
}
|
||||
|
||||
ext := make([]StatExtended, count)
|
||||
for i := uint32(0); i < count; i++ {
|
||||
var typ string
|
||||
var data string
|
||||
typ, b, _ = unmarshalStringSafe(b)
|
||||
data, b, _ = unmarshalStringSafe(b)
|
||||
typ, b, err = unmarshalStringSafe(b)
|
||||
if err != nil {
|
||||
return nil, b, err
|
||||
}
|
||||
data, b, err = unmarshalStringSafe(b)
|
||||
if err != nil {
|
||||
return nil, b, err
|
||||
}
|
||||
ext[i] = StatExtended{
|
||||
ExtType: typ,
|
||||
ExtData: data,
|
||||
|
@ -211,7 +244,7 @@ func unmarshalFileStat(flags uint32, b []byte) (*FileStat, []byte) {
|
|||
}
|
||||
fs.Extended = ext
|
||||
}
|
||||
return &fs, b
|
||||
return &fs, b, nil
|
||||
}
|
||||
|
||||
func unmarshalStatus(id uint32, data []byte) error {
|
||||
|
@ -734,15 +767,15 @@ func (p *sshFxpOpenPacket) UnmarshalBinary(b []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *sshFxpOpenPacket) unmarshalFileStat(flags uint32) *FileStat {
|
||||
func (p *sshFxpOpenPacket) unmarshalFileStat(flags uint32) (*FileStat, error) {
|
||||
switch attrs := p.Attrs.(type) {
|
||||
case *FileStat:
|
||||
return attrs
|
||||
return attrs, nil
|
||||
case []byte:
|
||||
fs, _ := unmarshalFileStat(flags, attrs)
|
||||
return fs
|
||||
fs, _, err := unmarshalFileStat(flags, attrs)
|
||||
return fs, err
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid type in unmarshalFileStat: %T", attrs))
|
||||
return nil, fmt.Errorf("invalid type in unmarshalFileStat: %T", attrs)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1030,15 +1063,15 @@ func (p *sshFxpSetstatPacket) UnmarshalBinary(b []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *sshFxpSetstatPacket) unmarshalFileStat(flags uint32) *FileStat {
|
||||
func (p *sshFxpSetstatPacket) unmarshalFileStat(flags uint32) (*FileStat, error) {
|
||||
switch attrs := p.Attrs.(type) {
|
||||
case *FileStat:
|
||||
return attrs
|
||||
return attrs, nil
|
||||
case []byte:
|
||||
fs, _ := unmarshalFileStat(flags, attrs)
|
||||
return fs
|
||||
fs, _, err := unmarshalFileStat(flags, attrs)
|
||||
return fs, err
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid type in unmarshalFileStat: %T", attrs))
|
||||
return nil, fmt.Errorf("invalid type in unmarshalFileStat: %T", attrs)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1055,15 +1088,15 @@ func (p *sshFxpFsetstatPacket) UnmarshalBinary(b []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *sshFxpFsetstatPacket) unmarshalFileStat(flags uint32) *FileStat {
|
||||
func (p *sshFxpFsetstatPacket) unmarshalFileStat(flags uint32) (*FileStat, error) {
|
||||
switch attrs := p.Attrs.(type) {
|
||||
case *FileStat:
|
||||
return attrs
|
||||
return attrs, nil
|
||||
case []byte:
|
||||
fs, _ := unmarshalFileStat(flags, attrs)
|
||||
return fs
|
||||
fs, _, err := unmarshalFileStat(flags, attrs)
|
||||
return fs, err
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid type in unmarshalFileStat: %T", attrs))
|
||||
return nil, fmt.Errorf("invalid type in unmarshalFileStat: %T", attrs)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -284,7 +284,10 @@ func TestUnmarshalAttrs(t *testing.T) {
|
|||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got, _ := unmarshalAttrs(tt.b)
|
||||
got, _, err := unmarshalAttrs(tt.b)
|
||||
if err != nil {
|
||||
t.Fatal("unexpected error:", err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("unmarshalAttrs(% X):\n- got: %#v\n- want: %#v", tt.b, got, tt.want)
|
||||
}
|
||||
|
@ -389,11 +392,11 @@ func TestSendPacket(t *testing.T) {
|
|||
},
|
||||
{
|
||||
packet: &sshFxpOpenPacket{
|
||||
ID: 3,
|
||||
Path: "/foo",
|
||||
ID: 3,
|
||||
Path: "/foo",
|
||||
Pflags: toPflags(os.O_WRONLY | os.O_CREATE | os.O_TRUNC),
|
||||
Flags: sshFileXferAttrPermissions,
|
||||
Attrs: &FileStat{
|
||||
Attrs: &FileStat{
|
||||
Mode: 0o755,
|
||||
},
|
||||
},
|
||||
|
|
|
@ -52,6 +52,6 @@ func (r *Request) AttrFlags() FileAttrFlags {
|
|||
// Attributes parses file attributes byte blob and return them in a
|
||||
// FileStat object.
|
||||
func (r *Request) Attributes() *FileStat {
|
||||
fs, _ := unmarshalFileStat(r.Flags, r.Attrs)
|
||||
fs, _, _ := unmarshalFileStat(r.Flags, r.Attrs)
|
||||
return fs
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRequestPflags(t *testing.T) {
|
||||
|
@ -33,7 +34,8 @@ func TestRequestAttributes(t *testing.T) {
|
|||
at := []byte{}
|
||||
at = marshalUint32(at, 1)
|
||||
at = marshalUint32(at, 2)
|
||||
testFs, _ := unmarshalFileStat(fl, at)
|
||||
testFs, _, err := unmarshalFileStat(fl, at)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, fa, *testFs)
|
||||
// Size and Mode
|
||||
fa = FileStat{Mode: 0700, Size: 99}
|
||||
|
@ -41,7 +43,8 @@ func TestRequestAttributes(t *testing.T) {
|
|||
at = []byte{}
|
||||
at = marshalUint64(at, 99)
|
||||
at = marshalUint32(at, 0700)
|
||||
testFs, _ = unmarshalFileStat(fl, at)
|
||||
testFs, _, err = unmarshalFileStat(fl, at)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, fa, *testFs)
|
||||
// FileMode
|
||||
assert.True(t, testFs.FileMode().IsRegular())
|
||||
|
@ -50,7 +53,16 @@ func TestRequestAttributes(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestRequestAttributesEmpty(t *testing.T) {
|
||||
fs, b := unmarshalFileStat(sshFileXferAttrAll, nil)
|
||||
fs, b, err := unmarshalFileStat(sshFileXferAttrAll, []byte{
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // size
|
||||
0x00, 0x00, 0x00, 0x00, // mode
|
||||
0x00, 0x00, 0x00, 0x00, // mtime
|
||||
0x00, 0x00, 0x00, 0x00, // atime
|
||||
0x00, 0x00, 0x00, 0x00, // uid
|
||||
0x00, 0x00, 0x00, 0x00, // gid
|
||||
0x00, 0x00, 0x00, 0x00, // extended_count
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, &FileStat{
|
||||
Extended: []StatExtended{},
|
||||
}, fs)
|
||||
|
|
28
server.go
28
server.go
|
@ -13,6 +13,7 @@ import (
|
|||
"strconv"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -462,10 +463,13 @@ func (p *sshFxpOpenPacket) respond(svr *Server) responsePacket {
|
|||
}
|
||||
|
||||
mode := os.FileMode(0o644)
|
||||
// Like OpenSSH, we only handle permissions here, if the file is being created.
|
||||
// Like OpenSSH, we only handle permissions here, and only when the file is being created.
|
||||
// Otherwise, the permissions are ignored.
|
||||
if p.Flags & sshFileXferAttrPermissions != 0 {
|
||||
fs := p.unmarshalFileStat(p.Flags)
|
||||
if p.Flags&sshFileXferAttrPermissions != 0 {
|
||||
fs, err := p.unmarshalFileStat(p.Flags)
|
||||
if err != nil {
|
||||
return statusFromError(p.ID, err)
|
||||
}
|
||||
mode = fs.FileMode() & os.ModePerm
|
||||
}
|
||||
|
||||
|
@ -507,9 +511,7 @@ func (p *sshFxpSetstatPacket) respond(svr *Server) responsePacket {
|
|||
|
||||
debug("setstat name %q", path)
|
||||
|
||||
fs := p.unmarshalFileStat(p.Flags)
|
||||
|
||||
var err error
|
||||
fs, err := p.unmarshalFileStat(p.Flags)
|
||||
|
||||
if (p.Flags & sshFileXferAttrSize) != 0 {
|
||||
if err == nil {
|
||||
|
@ -545,9 +547,7 @@ func (p *sshFxpFsetstatPacket) respond(svr *Server) responsePacket {
|
|||
|
||||
debug("fsetstat name %q", path)
|
||||
|
||||
fs := p.unmarshalFileStat(p.Flags)
|
||||
|
||||
var err error
|
||||
fs, err := p.unmarshalFileStat(p.Flags)
|
||||
|
||||
if (p.Flags & sshFileXferAttrSize) != 0 {
|
||||
if err == nil {
|
||||
|
@ -561,7 +561,15 @@ func (p *sshFxpFsetstatPacket) respond(svr *Server) responsePacket {
|
|||
}
|
||||
if (p.Flags & sshFileXferAttrACmodTime) != 0 {
|
||||
if err == nil {
|
||||
err = os.Chtimes(path, fs.AccessTime(), fs.ModTime())
|
||||
switch f := interface{}(f).(type) {
|
||||
case interface {
|
||||
Chtimes(atime, mtime time.Time) error
|
||||
}:
|
||||
// future-compatible, if any when *os.File supports Chtimes.
|
||||
err = f.Chtimes(fs.AccessTime(), fs.ModTime())
|
||||
default:
|
||||
err = os.Chtimes(path, fs.AccessTime(), fs.ModTime())
|
||||
}
|
||||
}
|
||||
}
|
||||
if (p.Flags & sshFileXferAttrUIDGID) != 0 {
|
||||
|
|
|
@ -606,7 +606,7 @@ ls -l /usr/bin/
|
|||
// words[7] as timestamps on dirs can vary for things like /tmp
|
||||
case 8:
|
||||
// words[8] can either have full path or just the filename
|
||||
bad = !strings.HasSuffix(opWord, "/" + goWord)
|
||||
bad = !strings.HasSuffix(opWord, "/"+goWord)
|
||||
default:
|
||||
bad = true
|
||||
}
|
||||
|
|
|
@ -228,11 +228,11 @@ func TestOpenWithPermissions(t *testing.T) {
|
|||
id2 := client.nextID()
|
||||
|
||||
typ, data, err := client.sendPacket(ctx, nil, &sshFxpOpenPacket{
|
||||
ID: id1,
|
||||
Path: tmppath,
|
||||
ID: id1,
|
||||
Path: tmppath,
|
||||
Pflags: pflags,
|
||||
Flags: sshFileXferAttrPermissions,
|
||||
Attrs: &FileStat{
|
||||
Attrs: &FileStat{
|
||||
Mode: 0o745,
|
||||
},
|
||||
})
|
||||
|
@ -259,11 +259,11 @@ func TestOpenWithPermissions(t *testing.T) {
|
|||
|
||||
// Existing files should not have their permissions changed.
|
||||
typ, data, err = client.sendPacket(ctx, nil, &sshFxpOpenPacket{
|
||||
ID: id2,
|
||||
Path: tmppath,
|
||||
ID: id2,
|
||||
Path: tmppath,
|
||||
Pflags: pflags,
|
||||
Flags: sshFileXferAttrPermissions,
|
||||
Attrs: &FileStat{
|
||||
Attrs: &FileStat{
|
||||
Mode: 0o755,
|
||||
},
|
||||
})
|
||||
|
|
Loading…
Reference in New Issue