collect all marshal/unmarshal functions into packet.go

This commit is contained in:
Cassondra Foesch 2021-07-21 14:23:51 +00:00
parent 792ae58b7e
commit ba854bee45
8 changed files with 240 additions and 198 deletions

View File

@ -95,79 +95,3 @@ func fileStatFromInfo(fi os.FileInfo) (uint32, FileStat) {
return flags, fileStat return flags, fileStat
} }
func unmarshalAttrs(b []byte) (*FileStat, []byte) {
flags, b := unmarshalUint32(b)
return getFileStat(flags, b)
}
func getFileStat(flags uint32, b []byte) (*FileStat, []byte) {
var fs FileStat
if flags&sshFileXferAttrSize == sshFileXferAttrSize {
fs.Size, b, _ = unmarshalUint64Safe(b)
}
if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID {
fs.UID, b, _ = unmarshalUint32Safe(b)
}
if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID {
fs.GID, b, _ = unmarshalUint32Safe(b)
}
if flags&sshFileXferAttrPermissions == sshFileXferAttrPermissions {
fs.Mode, b, _ = unmarshalUint32Safe(b)
}
if flags&sshFileXferAttrACmodTime == sshFileXferAttrACmodTime {
fs.Atime, b, _ = unmarshalUint32Safe(b)
fs.Mtime, b, _ = unmarshalUint32Safe(b)
}
if flags&sshFileXferAttrExtended == sshFileXferAttrExtended {
var count uint32
count, b, _ = unmarshalUint32Safe(b)
ext := make([]StatExtended, count)
for i := uint32(0); i < count; i++ {
var typ string
var data string
typ, b, _ = unmarshalStringSafe(b)
data, b, _ = unmarshalStringSafe(b)
ext[i] = StatExtended{typ, data}
}
fs.Extended = ext
}
return &fs, b
}
func marshalFileInfo(b []byte, fi os.FileInfo) []byte {
// attributes variable struct, and also variable per protocol version
// spec version 3 attributes:
// uint32 flags
// uint64 size present only if flag SSH_FILEXFER_ATTR_SIZE
// uint32 uid present only if flag SSH_FILEXFER_ATTR_UIDGID
// uint32 gid present only if flag SSH_FILEXFER_ATTR_UIDGID
// uint32 permissions present only if flag SSH_FILEXFER_ATTR_PERMISSIONS
// uint32 atime present only if flag SSH_FILEXFER_ACMODTIME
// uint32 mtime present only if flag SSH_FILEXFER_ACMODTIME
// uint32 extended_count present only if flag SSH_FILEXFER_ATTR_EXTENDED
// string extended_type
// string extended_data
// ... more extended data (extended_type - extended_data pairs),
// so that number of pairs equals extended_count
flags, fileStat := fileStatFromInfo(fi)
b = marshalUint32(b, flags)
if flags&sshFileXferAttrSize != 0 {
b = marshalUint64(b, fileStat.Size)
}
if flags&sshFileXferAttrUIDGID != 0 {
b = marshalUint32(b, fileStat.UID)
b = marshalUint32(b, fileStat.GID)
}
if flags&sshFileXferAttrPermissions != 0 {
b = marshalUint32(b, fileStat.Mode)
}
if flags&sshFileXferAttrACmodTime != 0 {
b = marshalUint32(b, fileStat.Atime)
b = marshalUint32(b, fileStat.Mtime)
}
return b
}

View File

@ -1,45 +1,8 @@
package sftp package sftp
import ( import (
"bytes"
"os" "os"
"reflect"
"testing"
"time"
) )
// ensure that attrs implemenst os.FileInfo // ensure that attrs implemenst os.FileInfo
var _ os.FileInfo = new(fileInfo) var _ os.FileInfo = new(fileInfo)
var unmarshalAttrsTests = []struct {
b []byte
want *fileInfo
rest []byte
}{
{marshal(nil, struct{ Flags uint32 }{}), &fileInfo{mtime: time.Unix(int64(0), 0)}, nil},
{marshal(nil, struct {
Flags uint32
Size uint64
}{sshFileXferAttrSize, 20}), &fileInfo{size: 20, mtime: time.Unix(int64(0), 0)}, nil},
{marshal(nil, struct {
Flags uint32
Size uint64
Permissions uint32
}{sshFileXferAttrSize | sshFileXferAttrPermissions, 20, 0644}), &fileInfo{size: 20, mode: os.FileMode(0644), mtime: time.Unix(int64(0), 0)}, nil},
{marshal(nil, struct {
Flags uint32
Size uint64
UID, GID, Permissions uint32
}{sshFileXferAttrSize | sshFileXferAttrUIDGID | sshFileXferAttrUIDGID | sshFileXferAttrPermissions, 20, 1000, 1000, 0644}), &fileInfo{size: 20, mode: os.FileMode(0644), mtime: time.Unix(int64(0), 0)}, nil},
}
func TestUnmarshalAttrs(t *testing.T) {
for _, tt := range unmarshalAttrsTests {
stat, rest := unmarshalAttrs(tt.b)
got := fileInfoFromStat(stat, "")
tt.want.sys = got.Sys()
if !reflect.DeepEqual(got, tt.want) || !bytes.Equal(tt.rest, rest) {
t.Errorf("unmarshalAttrs(%#v): want %#v, %#v, got: %#v, %#v", tt.b, tt.want, tt.rest, got, rest)
}
}
}

View File

@ -1892,28 +1892,6 @@ func normaliseError(err error) error {
} }
} }
func unmarshalStatus(id uint32, data []byte) error {
sid, data := unmarshalUint32(data)
if sid != id {
return &unexpectedIDErr{id, sid}
}
code, data := unmarshalUint32(data)
msg, data, _ := unmarshalStringSafe(data)
lang, _, _ := unmarshalStringSafe(data)
return &StatusError{
Code: code,
msg: msg,
lang: lang,
}
}
func marshalStatus(b []byte, err StatusError) []byte {
b = marshalUint32(b, err.Code)
b = marshalString(b, err.msg)
b = marshalString(b, err.lang)
return b
}
// flags converts the flags passed to OpenFile into ssh flags. // flags converts the flags passed to OpenFile into ssh flags.
// Unsupported flags are ignored. // Unsupported flags are ignored.
func flags(f int) uint32 { func flags(f int) uint32 {

View File

@ -5,7 +5,6 @@ import (
"errors" "errors"
"io" "io"
"os" "os"
"reflect"
"testing" "testing"
"github.com/kr/fs" "github.com/kr/fs"
@ -89,64 +88,6 @@ func TestFlags(t *testing.T) {
} }
} }
func TestUnmarshalStatus(t *testing.T) {
requestID := uint32(1)
id := marshalUint32([]byte{}, requestID)
idCode := marshalUint32(id, sshFxFailure)
idCodeMsg := marshalString(idCode, "err msg")
idCodeMsgLang := marshalString(idCodeMsg, "lang tag")
var tests = []struct {
desc string
reqID uint32
status []byte
want error
}{
{
desc: "well-formed status",
reqID: 1,
status: idCodeMsgLang,
want: &StatusError{
Code: sshFxFailure,
msg: "err msg",
lang: "lang tag",
},
},
{
desc: "missing error message and language tag",
reqID: 1,
status: idCode,
want: &StatusError{
Code: sshFxFailure,
},
},
{
desc: "missing language tag",
reqID: 1,
status: idCodeMsg,
want: &StatusError{
Code: sshFxFailure,
msg: "err msg",
},
},
{
desc: "request identifier mismatch",
reqID: 2,
status: idCodeMsgLang,
want: &unexpectedIDErr{2, requestID},
},
}
for _, tt := range tests {
got := unmarshalStatus(tt.reqID, tt.status)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("unmarshalStatus(%v, %v), test %q\n- want: %#v\n- got: %#v",
requestID, tt.status, tt.desc, tt.want, got)
}
}
}
type packetSizeTest struct { type packetSizeTest struct {
size int size int
valid bool valid bool

103
packet.go
View File

@ -37,6 +37,50 @@ func marshalString(b []byte, v string) []byte {
return append(marshalUint32(b, uint32(len(v))), v...) return append(marshalUint32(b, uint32(len(v))), v...)
} }
func marshalFileInfo(b []byte, fi os.FileInfo) []byte {
// attributes variable struct, and also variable per protocol version
// spec version 3 attributes:
// uint32 flags
// uint64 size present only if flag SSH_FILEXFER_ATTR_SIZE
// uint32 uid present only if flag SSH_FILEXFER_ATTR_UIDGID
// uint32 gid present only if flag SSH_FILEXFER_ATTR_UIDGID
// uint32 permissions present only if flag SSH_FILEXFER_ATTR_PERMISSIONS
// uint32 atime present only if flag SSH_FILEXFER_ACMODTIME
// uint32 mtime present only if flag SSH_FILEXFER_ACMODTIME
// uint32 extended_count present only if flag SSH_FILEXFER_ATTR_EXTENDED
// string extended_type
// string extended_data
// ... more extended data (extended_type - extended_data pairs),
// so that number of pairs equals extended_count
flags, fileStat := fileStatFromInfo(fi)
b = marshalUint32(b, flags)
if flags&sshFileXferAttrSize != 0 {
b = marshalUint64(b, fileStat.Size)
}
if flags&sshFileXferAttrUIDGID != 0 {
b = marshalUint32(b, fileStat.UID)
b = marshalUint32(b, fileStat.GID)
}
if flags&sshFileXferAttrPermissions != 0 {
b = marshalUint32(b, fileStat.Mode)
}
if flags&sshFileXferAttrACmodTime != 0 {
b = marshalUint32(b, fileStat.Atime)
b = marshalUint32(b, fileStat.Mtime)
}
return b
}
func marshalStatus(b []byte, err StatusError) []byte {
b = marshalUint32(b, err.Code)
b = marshalString(b, err.msg)
b = marshalString(b, err.lang)
return b
}
func marshal(b []byte, v interface{}) []byte { func marshal(b []byte, v interface{}) []byte {
if v == nil { if v == nil {
return b return b
@ -115,6 +159,63 @@ func unmarshalStringSafe(b []byte) (string, []byte, error) {
return string(b[:n]), b[n:], nil return string(b[:n]), b[n:], nil
} }
func unmarshalAttrs(b []byte) (*FileStat, []byte) {
flags, b := unmarshalUint32(b)
return unmarshalFileStat(flags, b)
}
func unmarshalFileStat(flags uint32, b []byte) (*FileStat, []byte) {
var fs FileStat
if flags&sshFileXferAttrSize == sshFileXferAttrSize {
fs.Size, b, _ = unmarshalUint64Safe(b)
}
if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID {
fs.UID, b, _ = unmarshalUint32Safe(b)
}
if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID {
fs.GID, b, _ = unmarshalUint32Safe(b)
}
if flags&sshFileXferAttrPermissions == sshFileXferAttrPermissions {
fs.Mode, b, _ = unmarshalUint32Safe(b)
}
if flags&sshFileXferAttrACmodTime == sshFileXferAttrACmodTime {
fs.Atime, b, _ = unmarshalUint32Safe(b)
fs.Mtime, b, _ = unmarshalUint32Safe(b)
}
if flags&sshFileXferAttrExtended == sshFileXferAttrExtended {
var count uint32
count, b, _ = unmarshalUint32Safe(b)
ext := make([]StatExtended, count)
for i := uint32(0); i < count; i++ {
var typ string
var data string
typ, b, _ = unmarshalStringSafe(b)
data, b, _ = unmarshalStringSafe(b)
ext[i] = StatExtended{
ExtType: typ,
ExtData: data,
}
}
fs.Extended = ext
}
return &fs, b
}
func unmarshalStatus(id uint32, data []byte) error {
sid, data := unmarshalUint32(data)
if sid != id {
return &unexpectedIDErr{id, sid}
}
code, data := unmarshalUint32(data)
msg, data, _ := unmarshalStringSafe(data)
lang, _, _ := unmarshalStringSafe(data)
return &StatusError{
Code: code,
msg: msg,
lang: lang,
}
}
type packetMarshaler interface { type packetMarshaler interface {
marshalPacket() (header, payload []byte, err error) marshalPacket() (header, payload []byte, err error)
} }
@ -639,11 +740,13 @@ const dataHeaderLen = 4 + 1 + 4 + 4
func (p *sshFxpReadPacket) getDataSlice(alloc *allocator, orderID uint32) []byte { func (p *sshFxpReadPacket) getDataSlice(alloc *allocator, orderID uint32) []byte {
dataLen := clamp(p.Len, maxTxPacket) dataLen := clamp(p.Len, maxTxPacket)
if alloc != nil { if alloc != nil {
// GetPage returns a slice with capacity = maxMsgLength this is enough to avoid new allocations in // GetPage returns a slice with capacity = maxMsgLength this is enough to avoid new allocations in
// sshFxpDataPacket.MarshalBinary // sshFxpDataPacket.MarshalBinary
return alloc.GetPage(orderID)[:dataLen] return alloc.GetPage(orderID)[:dataLen]
} }
// allocate with extra space for the header // allocate with extra space for the header
return make([]byte, dataLen, dataLen+dataHeaderLen) return make([]byte, dataLen, dataLen+dataHeaderLen)
} }

View File

@ -6,6 +6,7 @@ import (
"errors" "errors"
"io/ioutil" "io/ioutil"
"os" "os"
"reflect"
"testing" "testing"
) )
@ -217,6 +218,138 @@ func TestUnmarshalString(t *testing.T) {
} }
} }
func TestUnmarshalAttrs(t *testing.T) {
var tests = []struct {
b []byte
want *FileStat
}{
{
b: []byte{0x00, 0x00, 0x00, 0x00},
want: &FileStat{},
},
{
b: []byte{
0x00, 0x00, 0x00, byte(sshFileXferAttrSize),
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 20,
},
want: &FileStat{
Size: 20,
},
},
{
b: []byte{
0x00, 0x00, 0x00, byte(sshFileXferAttrSize | sshFileXferAttrPermissions),
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 20,
0x00, 0x00, 0x01, 0xA4,
},
want: &FileStat{
Size: 20,
Mode: 0644,
},
},
{
b: []byte{
0x00, 0x00, 0x00, byte(sshFileXferAttrSize | sshFileXferAttrPermissions | sshFileXferAttrUIDGID),
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 20,
0x00, 0x00, 0x03, 0xE8,
0x00, 0x00, 0x03, 0xE9,
0x00, 0x00, 0x01, 0xA4,
},
want: &FileStat{
Size: 20,
Mode: 0644,
UID: 1000,
GID: 1001,
},
},
{
b: []byte{
0x00, 0x00, 0x00, byte(sshFileXferAttrSize | sshFileXferAttrPermissions | sshFileXferAttrUIDGID | sshFileXferAttrACmodTime),
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 20,
0x00, 0x00, 0x03, 0xE8,
0x00, 0x00, 0x03, 0xE9,
0x00, 0x00, 0x01, 0xA4,
0x00, 0x00, 0x00, 42,
0x00, 0x00, 0x00, 13,
},
want: &FileStat{
Size: 20,
Mode: 0644,
UID: 1000,
GID: 1001,
Atime: 42,
Mtime: 13,
},
},
}
for _, tt := range tests {
got, _ := unmarshalAttrs(tt.b)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("unmarshalAttrs(% X):\n- got: %#v\n- want: %#v", tt.b, got, tt.want)
}
}
}
func TestUnmarshalStatus(t *testing.T) {
var requestID uint32 = 1
id := marshalUint32(nil, requestID)
idCode := marshalUint32(id, sshFxFailure)
idCodeMsg := marshalString(idCode, "err msg")
idCodeMsgLang := marshalString(idCodeMsg, "lang tag")
var tests = []struct {
desc string
reqID uint32
status []byte
want error
}{
{
desc: "well-formed status",
status: idCodeMsgLang,
want: &StatusError{
Code: sshFxFailure,
msg: "err msg",
lang: "lang tag",
},
},
{
desc: "missing language tag",
status: idCodeMsg,
want: &StatusError{
Code: sshFxFailure,
msg: "err msg",
},
},
{
desc: "missing error message and language tag",
status: idCode,
want: &StatusError{
Code: sshFxFailure,
},
},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
got := unmarshalStatus(1, tt.status)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("unmarshalStatus(1, % X):\n- got: %#v\n- want: %#v", tt.status, got, tt.want)
}
})
}
got := unmarshalStatus(2, idCodeMsgLang)
want := &unexpectedIDErr{
want: 2,
got: 1,
}
if !reflect.DeepEqual(got, want) {
t.Errorf("unmarshalStatus(2, % X):\n- got: %#v\n- want: %#v", idCodeMsgLang, got, want)
}
}
func TestSendPacket(t *testing.T) { func TestSendPacket(t *testing.T) {
var tests = []struct { var tests = []struct {
packet encoding.BinaryMarshaler packet encoding.BinaryMarshaler

View File

@ -58,6 +58,6 @@ func (a FileStat) FileMode() os.FileMode {
// Attributes parses file attributes byte blob and return them in a // Attributes parses file attributes byte blob and return them in a
// FileStat object. // FileStat object.
func (r *Request) Attributes() *FileStat { func (r *Request) Attributes() *FileStat {
fs, _ := getFileStat(r.Flags, r.Attrs) fs, _ := unmarshalFileStat(r.Flags, r.Attrs)
return fs return fs
} }

View File

@ -33,7 +33,7 @@ func TestRequestAttributes(t *testing.T) {
at := []byte{} at := []byte{}
at = marshalUint32(at, 1) at = marshalUint32(at, 1)
at = marshalUint32(at, 2) at = marshalUint32(at, 2)
testFs, _ := getFileStat(fl, at) testFs, _ := unmarshalFileStat(fl, at)
assert.Equal(t, fa, *testFs) assert.Equal(t, fa, *testFs)
// Size and Mode // Size and Mode
fa = FileStat{Mode: 700, Size: 99} fa = FileStat{Mode: 700, Size: 99}
@ -41,7 +41,7 @@ func TestRequestAttributes(t *testing.T) {
at = []byte{} at = []byte{}
at = marshalUint64(at, 99) at = marshalUint64(at, 99)
at = marshalUint32(at, 700) at = marshalUint32(at, 700)
testFs, _ = getFileStat(fl, at) testFs, _ = unmarshalFileStat(fl, at)
assert.Equal(t, fa, *testFs) assert.Equal(t, fa, *testFs)
// FileMode // FileMode
assert.True(t, testFs.FileMode().IsRegular()) assert.True(t, testFs.FileMode().IsRegular())
@ -50,7 +50,7 @@ func TestRequestAttributes(t *testing.T) {
} }
func TestRequestAttributesEmpty(t *testing.T) { func TestRequestAttributesEmpty(t *testing.T) {
fs, b := getFileStat(sshFileXferAttrAll, nil) fs, b := unmarshalFileStat(sshFileXferAttrAll, nil)
assert.Equal(t, &FileStat{ assert.Equal(t, &FileStat{
Extended: []StatExtended{}, Extended: []StatExtended{},
}, fs) }, fs)