mirror of https://github.com/pkg/sftp.git
collect all marshal/unmarshal functions into packet.go
This commit is contained in:
parent
792ae58b7e
commit
ba854bee45
76
attrs.go
76
attrs.go
|
@ -95,79 +95,3 @@ func fileStatFromInfo(fi os.FileInfo) (uint32, FileStat) {
|
|||
|
||||
return flags, fileStat
|
||||
}
|
||||
|
||||
func unmarshalAttrs(b []byte) (*FileStat, []byte) {
|
||||
flags, b := unmarshalUint32(b)
|
||||
return getFileStat(flags, b)
|
||||
}
|
||||
|
||||
func getFileStat(flags uint32, b []byte) (*FileStat, []byte) {
|
||||
var fs FileStat
|
||||
if flags&sshFileXferAttrSize == sshFileXferAttrSize {
|
||||
fs.Size, b, _ = unmarshalUint64Safe(b)
|
||||
}
|
||||
if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID {
|
||||
fs.UID, b, _ = unmarshalUint32Safe(b)
|
||||
}
|
||||
if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID {
|
||||
fs.GID, b, _ = unmarshalUint32Safe(b)
|
||||
}
|
||||
if flags&sshFileXferAttrPermissions == sshFileXferAttrPermissions {
|
||||
fs.Mode, b, _ = unmarshalUint32Safe(b)
|
||||
}
|
||||
if flags&sshFileXferAttrACmodTime == sshFileXferAttrACmodTime {
|
||||
fs.Atime, b, _ = unmarshalUint32Safe(b)
|
||||
fs.Mtime, b, _ = unmarshalUint32Safe(b)
|
||||
}
|
||||
if flags&sshFileXferAttrExtended == sshFileXferAttrExtended {
|
||||
var count uint32
|
||||
count, b, _ = unmarshalUint32Safe(b)
|
||||
ext := make([]StatExtended, count)
|
||||
for i := uint32(0); i < count; i++ {
|
||||
var typ string
|
||||
var data string
|
||||
typ, b, _ = unmarshalStringSafe(b)
|
||||
data, b, _ = unmarshalStringSafe(b)
|
||||
ext[i] = StatExtended{typ, data}
|
||||
}
|
||||
fs.Extended = ext
|
||||
}
|
||||
return &fs, b
|
||||
}
|
||||
|
||||
func marshalFileInfo(b []byte, fi os.FileInfo) []byte {
|
||||
// attributes variable struct, and also variable per protocol version
|
||||
// spec version 3 attributes:
|
||||
// uint32 flags
|
||||
// uint64 size present only if flag SSH_FILEXFER_ATTR_SIZE
|
||||
// uint32 uid present only if flag SSH_FILEXFER_ATTR_UIDGID
|
||||
// uint32 gid present only if flag SSH_FILEXFER_ATTR_UIDGID
|
||||
// uint32 permissions present only if flag SSH_FILEXFER_ATTR_PERMISSIONS
|
||||
// uint32 atime present only if flag SSH_FILEXFER_ACMODTIME
|
||||
// uint32 mtime present only if flag SSH_FILEXFER_ACMODTIME
|
||||
// uint32 extended_count present only if flag SSH_FILEXFER_ATTR_EXTENDED
|
||||
// string extended_type
|
||||
// string extended_data
|
||||
// ... more extended data (extended_type - extended_data pairs),
|
||||
// so that number of pairs equals extended_count
|
||||
|
||||
flags, fileStat := fileStatFromInfo(fi)
|
||||
|
||||
b = marshalUint32(b, flags)
|
||||
if flags&sshFileXferAttrSize != 0 {
|
||||
b = marshalUint64(b, fileStat.Size)
|
||||
}
|
||||
if flags&sshFileXferAttrUIDGID != 0 {
|
||||
b = marshalUint32(b, fileStat.UID)
|
||||
b = marshalUint32(b, fileStat.GID)
|
||||
}
|
||||
if flags&sshFileXferAttrPermissions != 0 {
|
||||
b = marshalUint32(b, fileStat.Mode)
|
||||
}
|
||||
if flags&sshFileXferAttrACmodTime != 0 {
|
||||
b = marshalUint32(b, fileStat.Atime)
|
||||
b = marshalUint32(b, fileStat.Mtime)
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
|
|
@ -1,45 +1,8 @@
|
|||
package sftp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ensure that attrs implemenst os.FileInfo
|
||||
var _ os.FileInfo = new(fileInfo)
|
||||
|
||||
var unmarshalAttrsTests = []struct {
|
||||
b []byte
|
||||
want *fileInfo
|
||||
rest []byte
|
||||
}{
|
||||
{marshal(nil, struct{ Flags uint32 }{}), &fileInfo{mtime: time.Unix(int64(0), 0)}, nil},
|
||||
{marshal(nil, struct {
|
||||
Flags uint32
|
||||
Size uint64
|
||||
}{sshFileXferAttrSize, 20}), &fileInfo{size: 20, mtime: time.Unix(int64(0), 0)}, nil},
|
||||
{marshal(nil, struct {
|
||||
Flags uint32
|
||||
Size uint64
|
||||
Permissions uint32
|
||||
}{sshFileXferAttrSize | sshFileXferAttrPermissions, 20, 0644}), &fileInfo{size: 20, mode: os.FileMode(0644), mtime: time.Unix(int64(0), 0)}, nil},
|
||||
{marshal(nil, struct {
|
||||
Flags uint32
|
||||
Size uint64
|
||||
UID, GID, Permissions uint32
|
||||
}{sshFileXferAttrSize | sshFileXferAttrUIDGID | sshFileXferAttrUIDGID | sshFileXferAttrPermissions, 20, 1000, 1000, 0644}), &fileInfo{size: 20, mode: os.FileMode(0644), mtime: time.Unix(int64(0), 0)}, nil},
|
||||
}
|
||||
|
||||
func TestUnmarshalAttrs(t *testing.T) {
|
||||
for _, tt := range unmarshalAttrsTests {
|
||||
stat, rest := unmarshalAttrs(tt.b)
|
||||
got := fileInfoFromStat(stat, "")
|
||||
tt.want.sys = got.Sys()
|
||||
if !reflect.DeepEqual(got, tt.want) || !bytes.Equal(tt.rest, rest) {
|
||||
t.Errorf("unmarshalAttrs(%#v): want %#v, %#v, got: %#v, %#v", tt.b, tt.want, tt.rest, got, rest)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
22
client.go
22
client.go
|
@ -1892,28 +1892,6 @@ func normaliseError(err error) error {
|
|||
}
|
||||
}
|
||||
|
||||
func unmarshalStatus(id uint32, data []byte) error {
|
||||
sid, data := unmarshalUint32(data)
|
||||
if sid != id {
|
||||
return &unexpectedIDErr{id, sid}
|
||||
}
|
||||
code, data := unmarshalUint32(data)
|
||||
msg, data, _ := unmarshalStringSafe(data)
|
||||
lang, _, _ := unmarshalStringSafe(data)
|
||||
return &StatusError{
|
||||
Code: code,
|
||||
msg: msg,
|
||||
lang: lang,
|
||||
}
|
||||
}
|
||||
|
||||
func marshalStatus(b []byte, err StatusError) []byte {
|
||||
b = marshalUint32(b, err.Code)
|
||||
b = marshalString(b, err.msg)
|
||||
b = marshalString(b, err.lang)
|
||||
return b
|
||||
}
|
||||
|
||||
// flags converts the flags passed to OpenFile into ssh flags.
|
||||
// Unsupported flags are ignored.
|
||||
func flags(f int) uint32 {
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/kr/fs"
|
||||
|
@ -89,64 +88,6 @@ func TestFlags(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalStatus(t *testing.T) {
|
||||
requestID := uint32(1)
|
||||
|
||||
id := marshalUint32([]byte{}, requestID)
|
||||
idCode := marshalUint32(id, sshFxFailure)
|
||||
idCodeMsg := marshalString(idCode, "err msg")
|
||||
idCodeMsgLang := marshalString(idCodeMsg, "lang tag")
|
||||
|
||||
var tests = []struct {
|
||||
desc string
|
||||
reqID uint32
|
||||
status []byte
|
||||
want error
|
||||
}{
|
||||
{
|
||||
desc: "well-formed status",
|
||||
reqID: 1,
|
||||
status: idCodeMsgLang,
|
||||
want: &StatusError{
|
||||
Code: sshFxFailure,
|
||||
msg: "err msg",
|
||||
lang: "lang tag",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "missing error message and language tag",
|
||||
reqID: 1,
|
||||
status: idCode,
|
||||
want: &StatusError{
|
||||
Code: sshFxFailure,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "missing language tag",
|
||||
reqID: 1,
|
||||
status: idCodeMsg,
|
||||
want: &StatusError{
|
||||
Code: sshFxFailure,
|
||||
msg: "err msg",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "request identifier mismatch",
|
||||
reqID: 2,
|
||||
status: idCodeMsgLang,
|
||||
want: &unexpectedIDErr{2, requestID},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := unmarshalStatus(tt.reqID, tt.status)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("unmarshalStatus(%v, %v), test %q\n- want: %#v\n- got: %#v",
|
||||
requestID, tt.status, tt.desc, tt.want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type packetSizeTest struct {
|
||||
size int
|
||||
valid bool
|
||||
|
|
103
packet.go
103
packet.go
|
@ -37,6 +37,50 @@ func marshalString(b []byte, v string) []byte {
|
|||
return append(marshalUint32(b, uint32(len(v))), v...)
|
||||
}
|
||||
|
||||
func marshalFileInfo(b []byte, fi os.FileInfo) []byte {
|
||||
// attributes variable struct, and also variable per protocol version
|
||||
// spec version 3 attributes:
|
||||
// uint32 flags
|
||||
// uint64 size present only if flag SSH_FILEXFER_ATTR_SIZE
|
||||
// uint32 uid present only if flag SSH_FILEXFER_ATTR_UIDGID
|
||||
// uint32 gid present only if flag SSH_FILEXFER_ATTR_UIDGID
|
||||
// uint32 permissions present only if flag SSH_FILEXFER_ATTR_PERMISSIONS
|
||||
// uint32 atime present only if flag SSH_FILEXFER_ACMODTIME
|
||||
// uint32 mtime present only if flag SSH_FILEXFER_ACMODTIME
|
||||
// uint32 extended_count present only if flag SSH_FILEXFER_ATTR_EXTENDED
|
||||
// string extended_type
|
||||
// string extended_data
|
||||
// ... more extended data (extended_type - extended_data pairs),
|
||||
// so that number of pairs equals extended_count
|
||||
|
||||
flags, fileStat := fileStatFromInfo(fi)
|
||||
|
||||
b = marshalUint32(b, flags)
|
||||
if flags&sshFileXferAttrSize != 0 {
|
||||
b = marshalUint64(b, fileStat.Size)
|
||||
}
|
||||
if flags&sshFileXferAttrUIDGID != 0 {
|
||||
b = marshalUint32(b, fileStat.UID)
|
||||
b = marshalUint32(b, fileStat.GID)
|
||||
}
|
||||
if flags&sshFileXferAttrPermissions != 0 {
|
||||
b = marshalUint32(b, fileStat.Mode)
|
||||
}
|
||||
if flags&sshFileXferAttrACmodTime != 0 {
|
||||
b = marshalUint32(b, fileStat.Atime)
|
||||
b = marshalUint32(b, fileStat.Mtime)
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func marshalStatus(b []byte, err StatusError) []byte {
|
||||
b = marshalUint32(b, err.Code)
|
||||
b = marshalString(b, err.msg)
|
||||
b = marshalString(b, err.lang)
|
||||
return b
|
||||
}
|
||||
|
||||
func marshal(b []byte, v interface{}) []byte {
|
||||
if v == nil {
|
||||
return b
|
||||
|
@ -115,6 +159,63 @@ func unmarshalStringSafe(b []byte) (string, []byte, error) {
|
|||
return string(b[:n]), b[n:], nil
|
||||
}
|
||||
|
||||
func unmarshalAttrs(b []byte) (*FileStat, []byte) {
|
||||
flags, b := unmarshalUint32(b)
|
||||
return unmarshalFileStat(flags, b)
|
||||
}
|
||||
|
||||
func unmarshalFileStat(flags uint32, b []byte) (*FileStat, []byte) {
|
||||
var fs FileStat
|
||||
if flags&sshFileXferAttrSize == sshFileXferAttrSize {
|
||||
fs.Size, b, _ = unmarshalUint64Safe(b)
|
||||
}
|
||||
if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID {
|
||||
fs.UID, b, _ = unmarshalUint32Safe(b)
|
||||
}
|
||||
if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID {
|
||||
fs.GID, b, _ = unmarshalUint32Safe(b)
|
||||
}
|
||||
if flags&sshFileXferAttrPermissions == sshFileXferAttrPermissions {
|
||||
fs.Mode, b, _ = unmarshalUint32Safe(b)
|
||||
}
|
||||
if flags&sshFileXferAttrACmodTime == sshFileXferAttrACmodTime {
|
||||
fs.Atime, b, _ = unmarshalUint32Safe(b)
|
||||
fs.Mtime, b, _ = unmarshalUint32Safe(b)
|
||||
}
|
||||
if flags&sshFileXferAttrExtended == sshFileXferAttrExtended {
|
||||
var count uint32
|
||||
count, b, _ = unmarshalUint32Safe(b)
|
||||
ext := make([]StatExtended, count)
|
||||
for i := uint32(0); i < count; i++ {
|
||||
var typ string
|
||||
var data string
|
||||
typ, b, _ = unmarshalStringSafe(b)
|
||||
data, b, _ = unmarshalStringSafe(b)
|
||||
ext[i] = StatExtended{
|
||||
ExtType: typ,
|
||||
ExtData: data,
|
||||
}
|
||||
}
|
||||
fs.Extended = ext
|
||||
}
|
||||
return &fs, b
|
||||
}
|
||||
|
||||
func unmarshalStatus(id uint32, data []byte) error {
|
||||
sid, data := unmarshalUint32(data)
|
||||
if sid != id {
|
||||
return &unexpectedIDErr{id, sid}
|
||||
}
|
||||
code, data := unmarshalUint32(data)
|
||||
msg, data, _ := unmarshalStringSafe(data)
|
||||
lang, _, _ := unmarshalStringSafe(data)
|
||||
return &StatusError{
|
||||
Code: code,
|
||||
msg: msg,
|
||||
lang: lang,
|
||||
}
|
||||
}
|
||||
|
||||
type packetMarshaler interface {
|
||||
marshalPacket() (header, payload []byte, err error)
|
||||
}
|
||||
|
@ -639,11 +740,13 @@ const dataHeaderLen = 4 + 1 + 4 + 4
|
|||
|
||||
func (p *sshFxpReadPacket) getDataSlice(alloc *allocator, orderID uint32) []byte {
|
||||
dataLen := clamp(p.Len, maxTxPacket)
|
||||
|
||||
if alloc != nil {
|
||||
// GetPage returns a slice with capacity = maxMsgLength this is enough to avoid new allocations in
|
||||
// sshFxpDataPacket.MarshalBinary
|
||||
return alloc.GetPage(orderID)[:dataLen]
|
||||
}
|
||||
|
||||
// allocate with extra space for the header
|
||||
return make([]byte, dataLen, dataLen+dataHeaderLen)
|
||||
}
|
||||
|
|
133
packet_test.go
133
packet_test.go
|
@ -6,6 +6,7 @@ import (
|
|||
"errors"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
|
@ -217,6 +218,138 @@ func TestUnmarshalString(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalAttrs(t *testing.T) {
|
||||
var tests = []struct {
|
||||
b []byte
|
||||
want *FileStat
|
||||
}{
|
||||
{
|
||||
b: []byte{0x00, 0x00, 0x00, 0x00},
|
||||
want: &FileStat{},
|
||||
},
|
||||
{
|
||||
b: []byte{
|
||||
0x00, 0x00, 0x00, byte(sshFileXferAttrSize),
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 20,
|
||||
},
|
||||
want: &FileStat{
|
||||
Size: 20,
|
||||
},
|
||||
},
|
||||
{
|
||||
b: []byte{
|
||||
0x00, 0x00, 0x00, byte(sshFileXferAttrSize | sshFileXferAttrPermissions),
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 20,
|
||||
0x00, 0x00, 0x01, 0xA4,
|
||||
},
|
||||
want: &FileStat{
|
||||
Size: 20,
|
||||
Mode: 0644,
|
||||
},
|
||||
},
|
||||
{
|
||||
b: []byte{
|
||||
0x00, 0x00, 0x00, byte(sshFileXferAttrSize | sshFileXferAttrPermissions | sshFileXferAttrUIDGID),
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 20,
|
||||
0x00, 0x00, 0x03, 0xE8,
|
||||
0x00, 0x00, 0x03, 0xE9,
|
||||
0x00, 0x00, 0x01, 0xA4,
|
||||
},
|
||||
want: &FileStat{
|
||||
Size: 20,
|
||||
Mode: 0644,
|
||||
UID: 1000,
|
||||
GID: 1001,
|
||||
},
|
||||
},
|
||||
{
|
||||
b: []byte{
|
||||
0x00, 0x00, 0x00, byte(sshFileXferAttrSize | sshFileXferAttrPermissions | sshFileXferAttrUIDGID | sshFileXferAttrACmodTime),
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 20,
|
||||
0x00, 0x00, 0x03, 0xE8,
|
||||
0x00, 0x00, 0x03, 0xE9,
|
||||
0x00, 0x00, 0x01, 0xA4,
|
||||
0x00, 0x00, 0x00, 42,
|
||||
0x00, 0x00, 0x00, 13,
|
||||
},
|
||||
want: &FileStat{
|
||||
Size: 20,
|
||||
Mode: 0644,
|
||||
UID: 1000,
|
||||
GID: 1001,
|
||||
Atime: 42,
|
||||
Mtime: 13,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got, _ := unmarshalAttrs(tt.b)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("unmarshalAttrs(% X):\n- got: %#v\n- want: %#v", tt.b, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalStatus(t *testing.T) {
|
||||
var requestID uint32 = 1
|
||||
|
||||
id := marshalUint32(nil, requestID)
|
||||
idCode := marshalUint32(id, sshFxFailure)
|
||||
idCodeMsg := marshalString(idCode, "err msg")
|
||||
idCodeMsgLang := marshalString(idCodeMsg, "lang tag")
|
||||
|
||||
var tests = []struct {
|
||||
desc string
|
||||
reqID uint32
|
||||
status []byte
|
||||
want error
|
||||
}{
|
||||
{
|
||||
desc: "well-formed status",
|
||||
status: idCodeMsgLang,
|
||||
want: &StatusError{
|
||||
Code: sshFxFailure,
|
||||
msg: "err msg",
|
||||
lang: "lang tag",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "missing language tag",
|
||||
status: idCodeMsg,
|
||||
want: &StatusError{
|
||||
Code: sshFxFailure,
|
||||
msg: "err msg",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "missing error message and language tag",
|
||||
status: idCode,
|
||||
want: &StatusError{
|
||||
Code: sshFxFailure,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
got := unmarshalStatus(1, tt.status)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("unmarshalStatus(1, % X):\n- got: %#v\n- want: %#v", tt.status, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
got := unmarshalStatus(2, idCodeMsgLang)
|
||||
want := &unexpectedIDErr{
|
||||
want: 2,
|
||||
got: 1,
|
||||
}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("unmarshalStatus(2, % X):\n- got: %#v\n- want: %#v", idCodeMsgLang, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendPacket(t *testing.T) {
|
||||
var tests = []struct {
|
||||
packet encoding.BinaryMarshaler
|
||||
|
|
|
@ -58,6 +58,6 @@ func (a FileStat) FileMode() os.FileMode {
|
|||
// Attributes parses file attributes byte blob and return them in a
|
||||
// FileStat object.
|
||||
func (r *Request) Attributes() *FileStat {
|
||||
fs, _ := getFileStat(r.Flags, r.Attrs)
|
||||
fs, _ := unmarshalFileStat(r.Flags, r.Attrs)
|
||||
return fs
|
||||
}
|
||||
|
|
|
@ -33,7 +33,7 @@ func TestRequestAttributes(t *testing.T) {
|
|||
at := []byte{}
|
||||
at = marshalUint32(at, 1)
|
||||
at = marshalUint32(at, 2)
|
||||
testFs, _ := getFileStat(fl, at)
|
||||
testFs, _ := unmarshalFileStat(fl, at)
|
||||
assert.Equal(t, fa, *testFs)
|
||||
// Size and Mode
|
||||
fa = FileStat{Mode: 700, Size: 99}
|
||||
|
@ -41,7 +41,7 @@ func TestRequestAttributes(t *testing.T) {
|
|||
at = []byte{}
|
||||
at = marshalUint64(at, 99)
|
||||
at = marshalUint32(at, 700)
|
||||
testFs, _ = getFileStat(fl, at)
|
||||
testFs, _ = unmarshalFileStat(fl, at)
|
||||
assert.Equal(t, fa, *testFs)
|
||||
// FileMode
|
||||
assert.True(t, testFs.FileMode().IsRegular())
|
||||
|
@ -50,7 +50,7 @@ func TestRequestAttributes(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestRequestAttributesEmpty(t *testing.T) {
|
||||
fs, b := getFileStat(sshFileXferAttrAll, nil)
|
||||
fs, b := unmarshalFileStat(sshFileXferAttrAll, nil)
|
||||
assert.Equal(t, &FileStat{
|
||||
Extended: []StatExtended{},
|
||||
}, fs)
|
||||
|
|
Loading…
Reference in New Issue