diff --git a/attrs.go b/attrs.go index 7020d3a..daa18bc 100644 --- a/attrs.go +++ b/attrs.go @@ -95,79 +95,3 @@ func fileStatFromInfo(fi os.FileInfo) (uint32, FileStat) { return flags, fileStat } - -func unmarshalAttrs(b []byte) (*FileStat, []byte) { - flags, b := unmarshalUint32(b) - return getFileStat(flags, b) -} - -func getFileStat(flags uint32, b []byte) (*FileStat, []byte) { - var fs FileStat - if flags&sshFileXferAttrSize == sshFileXferAttrSize { - fs.Size, b, _ = unmarshalUint64Safe(b) - } - if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID { - fs.UID, b, _ = unmarshalUint32Safe(b) - } - if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID { - fs.GID, b, _ = unmarshalUint32Safe(b) - } - if flags&sshFileXferAttrPermissions == sshFileXferAttrPermissions { - fs.Mode, b, _ = unmarshalUint32Safe(b) - } - if flags&sshFileXferAttrACmodTime == sshFileXferAttrACmodTime { - fs.Atime, b, _ = unmarshalUint32Safe(b) - fs.Mtime, b, _ = unmarshalUint32Safe(b) - } - if flags&sshFileXferAttrExtended == sshFileXferAttrExtended { - var count uint32 - count, b, _ = unmarshalUint32Safe(b) - ext := make([]StatExtended, count) - for i := uint32(0); i < count; i++ { - var typ string - var data string - typ, b, _ = unmarshalStringSafe(b) - data, b, _ = unmarshalStringSafe(b) - ext[i] = StatExtended{typ, data} - } - fs.Extended = ext - } - return &fs, b -} - -func marshalFileInfo(b []byte, fi os.FileInfo) []byte { - // attributes variable struct, and also variable per protocol version - // spec version 3 attributes: - // uint32 flags - // uint64 size present only if flag SSH_FILEXFER_ATTR_SIZE - // uint32 uid present only if flag SSH_FILEXFER_ATTR_UIDGID - // uint32 gid present only if flag SSH_FILEXFER_ATTR_UIDGID - // uint32 permissions present only if flag SSH_FILEXFER_ATTR_PERMISSIONS - // uint32 atime present only if flag SSH_FILEXFER_ACMODTIME - // uint32 mtime present only if flag SSH_FILEXFER_ACMODTIME - // uint32 extended_count present only if flag SSH_FILEXFER_ATTR_EXTENDED - // string extended_type - // string extended_data - // ... more extended data (extended_type - extended_data pairs), - // so that number of pairs equals extended_count - - flags, fileStat := fileStatFromInfo(fi) - - b = marshalUint32(b, flags) - if flags&sshFileXferAttrSize != 0 { - b = marshalUint64(b, fileStat.Size) - } - if flags&sshFileXferAttrUIDGID != 0 { - b = marshalUint32(b, fileStat.UID) - b = marshalUint32(b, fileStat.GID) - } - if flags&sshFileXferAttrPermissions != 0 { - b = marshalUint32(b, fileStat.Mode) - } - if flags&sshFileXferAttrACmodTime != 0 { - b = marshalUint32(b, fileStat.Atime) - b = marshalUint32(b, fileStat.Mtime) - } - - return b -} diff --git a/attrs_test.go b/attrs_test.go index 18d4f5c..a755df6 100644 --- a/attrs_test.go +++ b/attrs_test.go @@ -1,45 +1,8 @@ package sftp import ( - "bytes" "os" - "reflect" - "testing" - "time" ) // ensure that attrs implemenst os.FileInfo var _ os.FileInfo = new(fileInfo) - -var unmarshalAttrsTests = []struct { - b []byte - want *fileInfo - rest []byte -}{ - {marshal(nil, struct{ Flags uint32 }{}), &fileInfo{mtime: time.Unix(int64(0), 0)}, nil}, - {marshal(nil, struct { - Flags uint32 - Size uint64 - }{sshFileXferAttrSize, 20}), &fileInfo{size: 20, mtime: time.Unix(int64(0), 0)}, nil}, - {marshal(nil, struct { - Flags uint32 - Size uint64 - Permissions uint32 - }{sshFileXferAttrSize | sshFileXferAttrPermissions, 20, 0644}), &fileInfo{size: 20, mode: os.FileMode(0644), mtime: time.Unix(int64(0), 0)}, nil}, - {marshal(nil, struct { - Flags uint32 - Size uint64 - UID, GID, Permissions uint32 - }{sshFileXferAttrSize | sshFileXferAttrUIDGID | sshFileXferAttrUIDGID | sshFileXferAttrPermissions, 20, 1000, 1000, 0644}), &fileInfo{size: 20, mode: os.FileMode(0644), mtime: time.Unix(int64(0), 0)}, nil}, -} - -func TestUnmarshalAttrs(t *testing.T) { - for _, tt := range unmarshalAttrsTests { - stat, rest := unmarshalAttrs(tt.b) - got := fileInfoFromStat(stat, "") - tt.want.sys = got.Sys() - if !reflect.DeepEqual(got, tt.want) || !bytes.Equal(tt.rest, rest) { - t.Errorf("unmarshalAttrs(%#v): want %#v, %#v, got: %#v, %#v", tt.b, tt.want, tt.rest, got, rest) - } - } -} diff --git a/client.go b/client.go index 0569094..bf21fc3 100644 --- a/client.go +++ b/client.go @@ -1892,28 +1892,6 @@ func normaliseError(err error) error { } } -func unmarshalStatus(id uint32, data []byte) error { - sid, data := unmarshalUint32(data) - if sid != id { - return &unexpectedIDErr{id, sid} - } - code, data := unmarshalUint32(data) - msg, data, _ := unmarshalStringSafe(data) - lang, _, _ := unmarshalStringSafe(data) - return &StatusError{ - Code: code, - msg: msg, - lang: lang, - } -} - -func marshalStatus(b []byte, err StatusError) []byte { - b = marshalUint32(b, err.Code) - b = marshalString(b, err.msg) - b = marshalString(b, err.lang) - return b -} - // flags converts the flags passed to OpenFile into ssh flags. // Unsupported flags are ignored. func flags(f int) uint32 { diff --git a/client_test.go b/client_test.go index 8a8c51b..4577ca2 100644 --- a/client_test.go +++ b/client_test.go @@ -5,7 +5,6 @@ import ( "errors" "io" "os" - "reflect" "testing" "github.com/kr/fs" @@ -89,64 +88,6 @@ func TestFlags(t *testing.T) { } } -func TestUnmarshalStatus(t *testing.T) { - requestID := uint32(1) - - id := marshalUint32([]byte{}, requestID) - idCode := marshalUint32(id, sshFxFailure) - idCodeMsg := marshalString(idCode, "err msg") - idCodeMsgLang := marshalString(idCodeMsg, "lang tag") - - var tests = []struct { - desc string - reqID uint32 - status []byte - want error - }{ - { - desc: "well-formed status", - reqID: 1, - status: idCodeMsgLang, - want: &StatusError{ - Code: sshFxFailure, - msg: "err msg", - lang: "lang tag", - }, - }, - { - desc: "missing error message and language tag", - reqID: 1, - status: idCode, - want: &StatusError{ - Code: sshFxFailure, - }, - }, - { - desc: "missing language tag", - reqID: 1, - status: idCodeMsg, - want: &StatusError{ - Code: sshFxFailure, - msg: "err msg", - }, - }, - { - desc: "request identifier mismatch", - reqID: 2, - status: idCodeMsgLang, - want: &unexpectedIDErr{2, requestID}, - }, - } - - for _, tt := range tests { - got := unmarshalStatus(tt.reqID, tt.status) - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("unmarshalStatus(%v, %v), test %q\n- want: %#v\n- got: %#v", - requestID, tt.status, tt.desc, tt.want, got) - } - } -} - type packetSizeTest struct { size int valid bool diff --git a/packet.go b/packet.go index 2b2e592..96a8ede 100644 --- a/packet.go +++ b/packet.go @@ -37,6 +37,50 @@ func marshalString(b []byte, v string) []byte { return append(marshalUint32(b, uint32(len(v))), v...) } +func marshalFileInfo(b []byte, fi os.FileInfo) []byte { + // attributes variable struct, and also variable per protocol version + // spec version 3 attributes: + // uint32 flags + // uint64 size present only if flag SSH_FILEXFER_ATTR_SIZE + // uint32 uid present only if flag SSH_FILEXFER_ATTR_UIDGID + // uint32 gid present only if flag SSH_FILEXFER_ATTR_UIDGID + // uint32 permissions present only if flag SSH_FILEXFER_ATTR_PERMISSIONS + // uint32 atime present only if flag SSH_FILEXFER_ACMODTIME + // uint32 mtime present only if flag SSH_FILEXFER_ACMODTIME + // uint32 extended_count present only if flag SSH_FILEXFER_ATTR_EXTENDED + // string extended_type + // string extended_data + // ... more extended data (extended_type - extended_data pairs), + // so that number of pairs equals extended_count + + flags, fileStat := fileStatFromInfo(fi) + + b = marshalUint32(b, flags) + if flags&sshFileXferAttrSize != 0 { + b = marshalUint64(b, fileStat.Size) + } + if flags&sshFileXferAttrUIDGID != 0 { + b = marshalUint32(b, fileStat.UID) + b = marshalUint32(b, fileStat.GID) + } + if flags&sshFileXferAttrPermissions != 0 { + b = marshalUint32(b, fileStat.Mode) + } + if flags&sshFileXferAttrACmodTime != 0 { + b = marshalUint32(b, fileStat.Atime) + b = marshalUint32(b, fileStat.Mtime) + } + + return b +} + +func marshalStatus(b []byte, err StatusError) []byte { + b = marshalUint32(b, err.Code) + b = marshalString(b, err.msg) + b = marshalString(b, err.lang) + return b +} + func marshal(b []byte, v interface{}) []byte { if v == nil { return b @@ -115,6 +159,63 @@ func unmarshalStringSafe(b []byte) (string, []byte, error) { return string(b[:n]), b[n:], nil } +func unmarshalAttrs(b []byte) (*FileStat, []byte) { + flags, b := unmarshalUint32(b) + return unmarshalFileStat(flags, b) +} + +func unmarshalFileStat(flags uint32, b []byte) (*FileStat, []byte) { + var fs FileStat + if flags&sshFileXferAttrSize == sshFileXferAttrSize { + fs.Size, b, _ = unmarshalUint64Safe(b) + } + if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID { + fs.UID, b, _ = unmarshalUint32Safe(b) + } + if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID { + fs.GID, b, _ = unmarshalUint32Safe(b) + } + if flags&sshFileXferAttrPermissions == sshFileXferAttrPermissions { + fs.Mode, b, _ = unmarshalUint32Safe(b) + } + if flags&sshFileXferAttrACmodTime == sshFileXferAttrACmodTime { + fs.Atime, b, _ = unmarshalUint32Safe(b) + fs.Mtime, b, _ = unmarshalUint32Safe(b) + } + if flags&sshFileXferAttrExtended == sshFileXferAttrExtended { + var count uint32 + count, b, _ = unmarshalUint32Safe(b) + ext := make([]StatExtended, count) + for i := uint32(0); i < count; i++ { + var typ string + var data string + typ, b, _ = unmarshalStringSafe(b) + data, b, _ = unmarshalStringSafe(b) + ext[i] = StatExtended{ + ExtType: typ, + ExtData: data, + } + } + fs.Extended = ext + } + return &fs, b +} + +func unmarshalStatus(id uint32, data []byte) error { + sid, data := unmarshalUint32(data) + if sid != id { + return &unexpectedIDErr{id, sid} + } + code, data := unmarshalUint32(data) + msg, data, _ := unmarshalStringSafe(data) + lang, _, _ := unmarshalStringSafe(data) + return &StatusError{ + Code: code, + msg: msg, + lang: lang, + } +} + type packetMarshaler interface { marshalPacket() (header, payload []byte, err error) } @@ -639,11 +740,13 @@ const dataHeaderLen = 4 + 1 + 4 + 4 func (p *sshFxpReadPacket) getDataSlice(alloc *allocator, orderID uint32) []byte { dataLen := clamp(p.Len, maxTxPacket) + if alloc != nil { // GetPage returns a slice with capacity = maxMsgLength this is enough to avoid new allocations in // sshFxpDataPacket.MarshalBinary return alloc.GetPage(orderID)[:dataLen] } + // allocate with extra space for the header return make([]byte, dataLen, dataLen+dataHeaderLen) } diff --git a/packet_test.go b/packet_test.go index c7deb5a..cbee5e4 100644 --- a/packet_test.go +++ b/packet_test.go @@ -6,6 +6,7 @@ import ( "errors" "io/ioutil" "os" + "reflect" "testing" ) @@ -217,6 +218,138 @@ func TestUnmarshalString(t *testing.T) { } } +func TestUnmarshalAttrs(t *testing.T) { + var tests = []struct { + b []byte + want *FileStat + }{ + { + b: []byte{0x00, 0x00, 0x00, 0x00}, + want: &FileStat{}, + }, + { + b: []byte{ + 0x00, 0x00, 0x00, byte(sshFileXferAttrSize), + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 20, + }, + want: &FileStat{ + Size: 20, + }, + }, + { + b: []byte{ + 0x00, 0x00, 0x00, byte(sshFileXferAttrSize | sshFileXferAttrPermissions), + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 20, + 0x00, 0x00, 0x01, 0xA4, + }, + want: &FileStat{ + Size: 20, + Mode: 0644, + }, + }, + { + b: []byte{ + 0x00, 0x00, 0x00, byte(sshFileXferAttrSize | sshFileXferAttrPermissions | sshFileXferAttrUIDGID), + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 20, + 0x00, 0x00, 0x03, 0xE8, + 0x00, 0x00, 0x03, 0xE9, + 0x00, 0x00, 0x01, 0xA4, + }, + want: &FileStat{ + Size: 20, + Mode: 0644, + UID: 1000, + GID: 1001, + }, + }, + { + b: []byte{ + 0x00, 0x00, 0x00, byte(sshFileXferAttrSize | sshFileXferAttrPermissions | sshFileXferAttrUIDGID | sshFileXferAttrACmodTime), + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 20, + 0x00, 0x00, 0x03, 0xE8, + 0x00, 0x00, 0x03, 0xE9, + 0x00, 0x00, 0x01, 0xA4, + 0x00, 0x00, 0x00, 42, + 0x00, 0x00, 0x00, 13, + }, + want: &FileStat{ + Size: 20, + Mode: 0644, + UID: 1000, + GID: 1001, + Atime: 42, + Mtime: 13, + }, + }, + } + + for _, tt := range tests { + got, _ := unmarshalAttrs(tt.b) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("unmarshalAttrs(% X):\n- got: %#v\n- want: %#v", tt.b, got, tt.want) + } + } +} + +func TestUnmarshalStatus(t *testing.T) { + var requestID uint32 = 1 + + id := marshalUint32(nil, requestID) + idCode := marshalUint32(id, sshFxFailure) + idCodeMsg := marshalString(idCode, "err msg") + idCodeMsgLang := marshalString(idCodeMsg, "lang tag") + + var tests = []struct { + desc string + reqID uint32 + status []byte + want error + }{ + { + desc: "well-formed status", + status: idCodeMsgLang, + want: &StatusError{ + Code: sshFxFailure, + msg: "err msg", + lang: "lang tag", + }, + }, + { + desc: "missing language tag", + status: idCodeMsg, + want: &StatusError{ + Code: sshFxFailure, + msg: "err msg", + }, + }, + { + desc: "missing error message and language tag", + status: idCode, + want: &StatusError{ + Code: sshFxFailure, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + got := unmarshalStatus(1, tt.status) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("unmarshalStatus(1, % X):\n- got: %#v\n- want: %#v", tt.status, got, tt.want) + } + }) + } + + got := unmarshalStatus(2, idCodeMsgLang) + want := &unexpectedIDErr{ + want: 2, + got: 1, + } + if !reflect.DeepEqual(got, want) { + t.Errorf("unmarshalStatus(2, % X):\n- got: %#v\n- want: %#v", idCodeMsgLang, got, want) + } +} + func TestSendPacket(t *testing.T) { var tests = []struct { packet encoding.BinaryMarshaler diff --git a/request-attrs.go b/request-attrs.go index 7c2e5c1..b5c95b4 100644 --- a/request-attrs.go +++ b/request-attrs.go @@ -58,6 +58,6 @@ func (a FileStat) FileMode() os.FileMode { // Attributes parses file attributes byte blob and return them in a // FileStat object. func (r *Request) Attributes() *FileStat { - fs, _ := getFileStat(r.Flags, r.Attrs) + fs, _ := unmarshalFileStat(r.Flags, r.Attrs) return fs } diff --git a/request-attrs_test.go b/request-attrs_test.go index 423a17c..3e1b096 100644 --- a/request-attrs_test.go +++ b/request-attrs_test.go @@ -33,7 +33,7 @@ func TestRequestAttributes(t *testing.T) { at := []byte{} at = marshalUint32(at, 1) at = marshalUint32(at, 2) - testFs, _ := getFileStat(fl, at) + testFs, _ := unmarshalFileStat(fl, at) assert.Equal(t, fa, *testFs) // Size and Mode fa = FileStat{Mode: 700, Size: 99} @@ -41,7 +41,7 @@ func TestRequestAttributes(t *testing.T) { at = []byte{} at = marshalUint64(at, 99) at = marshalUint32(at, 700) - testFs, _ = getFileStat(fl, at) + testFs, _ = unmarshalFileStat(fl, at) assert.Equal(t, fa, *testFs) // FileMode assert.True(t, testFs.FileMode().IsRegular()) @@ -50,7 +50,7 @@ func TestRequestAttributes(t *testing.T) { } func TestRequestAttributesEmpty(t *testing.T) { - fs, b := getFileStat(sshFileXferAttrAll, nil) + fs, b := unmarshalFileStat(sshFileXferAttrAll, nil) assert.Equal(t, &FileStat{ Extended: []StatExtended{}, }, fs)