diff --git a/encoding/ssh/filexfer/buffer.go b/encoding/ssh/filexfer/buffer.go index 4091e13..d8e9d97 100644 --- a/encoding/ssh/filexfer/buffer.go +++ b/encoding/ssh/filexfer/buffer.go @@ -29,16 +29,19 @@ func NewBuffer(buf []byte) *Buffer { } } -// NewMarshalBuffer creates an initializes a new Buffer ready to start marshaling a Packet into. -// It prepopulates 4 bytes for length, the 1-byte packetType, and the 4-byte requestID. -// It preallocates enough space for an additional size bytes of data above the prepopulated values. -func NewMarshalBuffer(packetType PacketType, requestID uint32, size int) *Buffer { - buf := NewBuffer(make([]byte, 4, 4+1+4+size)) +// NewMarshalBuffer creates a new Buffer ready to start marshaling a Packet into. +// It preallocates enough space for uint32(length), uint8(type), uint32(request-id) and size more bytes. +func NewMarshalBuffer(size int) *Buffer { + return NewBuffer(make([]byte, 4+1+4+size)) +} - buf.AppendUint8(uint8(packetType)) - buf.AppendUint32(requestID) +// StartPacket resets and initializes the Buffer to be ready to start marshaling a Packet body into. +// It truncates the buffer, reserves space for uint32(length), then appends the packetType and requestID. +func (b *Buffer) StartPacket(packetType PacketType, requestID uint32) { + b.b = append(b.b[:0], make([]byte, 4)...) - return buf + b.AppendUint8(uint8(packetType)) + b.AppendUint32(requestID) } // Bytes returns a slice of length b.Len() holding the unconsumed bytes in the Buffer. @@ -65,9 +68,11 @@ func (b *Buffer) Packet(payload []byte) (header, payloadPassThru []byte, err err } // Len returns the number of unconsumed bytes in the Buffer. -func (b *Buffer) Len() int { - return len(b.b) - b.off -} +func (b *Buffer) Len() int { return len(b.b) - b.off } + +// Cap returns the capacity of the Buffer’s underlying byte slice, +// that is, the total space allocated for the buffer’s data. +func (b *Buffer) Cap() int { return cap(b.b) } // ConsumeUint8 consumes a single byte from the Buffer. // If Buffer does not have enough data, it will return ErrShortPacket. diff --git a/encoding/ssh/filexfer/extended_packets.go b/encoding/ssh/filexfer/extended_packets.go index 94c6310..de04c39 100644 --- a/encoding/ssh/filexfer/extended_packets.go +++ b/encoding/ssh/filexfer/extended_packets.go @@ -50,12 +50,15 @@ type ExtendedPacket struct { // MarshalPacket returns p as a two-part binary encoding of p. // // The Data is marshaled into binary, and returned as the payload. -func (p *ExtendedPacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { - size := 4 + len(p.ExtendedRequest) // string(extended-request) +func (p *ExtendedPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + size := 4 + len(p.ExtendedRequest) // string(extended-request) + buf = NewMarshalBuffer(size) + } - b := NewMarshalBuffer(PacketTypeExtended, reqid, size) - - b.AppendString(p.ExtendedRequest) + buf.StartPacket(PacketTypeExtended, reqid) + buf.AppendString(p.ExtendedRequest) if p.Data != nil { payload, err = p.Data.MarshalBinary() @@ -64,7 +67,7 @@ func (p *ExtendedPacket) MarshalPacket(reqid uint32) (header, payload []byte, er } } - return b.Packet(payload) + return buf.Packet(payload) } // UnmarshalPacketBody unmarshals the packet body from the given Buffer. @@ -93,8 +96,13 @@ type ExtendedReplyPacket struct { // MarshalPacket returns p as a two-part binary encoding of p. // // The Data is marshaled into binary, and returned as the payload. -func (p *ExtendedReplyPacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { - b := NewMarshalBuffer(PacketTypeExtendedReply, reqid, 0) +func (p *ExtendedReplyPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + buf = NewMarshalBuffer(0) + } + + buf.StartPacket(PacketTypeExtendedReply, reqid) if p.Data != nil { payload, err = p.Data.MarshalBinary() @@ -103,7 +111,7 @@ func (p *ExtendedReplyPacket) MarshalPacket(reqid uint32) (header, payload []byt } } - return b.Packet(payload) + return buf.Packet(payload) } // UnmarshalPacketBody unmarshals the packet body from the given Buffer. diff --git a/encoding/ssh/filexfer/extended_packets_test.go b/encoding/ssh/filexfer/extended_packets_test.go index a362774..668ef57 100644 --- a/encoding/ssh/filexfer/extended_packets_test.go +++ b/encoding/ssh/filexfer/extended_packets_test.go @@ -42,7 +42,7 @@ func TestExtendedPacketNoData(t *testing.T) { ExtendedRequest: extendedRequest, } - data, err := ComposePacket(p.MarshalPacket(id)) + data, err := ComposePacket(p.MarshalPacket(id, nil)) if err != nil { t.Fatal("unexpected error:", err) } @@ -86,7 +86,7 @@ func TestExtendedPacketTestData(t *testing.T) { }, } - data, err := ComposePacket(p.MarshalPacket(id)) + data, err := ComposePacket(p.MarshalPacket(id, nil)) if err != nil { t.Fatal("unexpected error:", err) } @@ -153,7 +153,7 @@ func TestExtendedReplyNoData(t *testing.T) { p := &ExtendedReplyPacket{} - data, err := ComposePacket(p.MarshalPacket(id)) + data, err := ComposePacket(p.MarshalPacket(id, nil)) if err != nil { t.Fatal("unexpected error:", err) } @@ -190,7 +190,7 @@ func TestExtendedReplyPacketTestData(t *testing.T) { }, } - data, err := ComposePacket(p.MarshalPacket(id)) + data, err := ComposePacket(p.MarshalPacket(id, nil)) if err != nil { t.Fatal("unexpected error:", err) } diff --git a/encoding/ssh/filexfer/filexfer.go b/encoding/ssh/filexfer/filexfer.go index d45fd35..0b26169 100644 --- a/encoding/ssh/filexfer/filexfer.go +++ b/encoding/ssh/filexfer/filexfer.go @@ -1,8 +1,23 @@ +// Package filexfer implements the wire encoding for secsh-filexfer as described in https://tools.ietf.org/html/draft-ietf-secsh-filexfer-02 package filexfer // Packet defines the behavior of an SFTP packet. type Packet interface { - MarshalPacket(reqid uint32) (header, payload []byte, err error) + // MarshalPacket is the primary intended way to encode a packet. + // The request-id for the packet is set from reqid. + // + // An optional buffer may be given in b. + // If the buffer has a minimum capacity, it shall be truncated and used to marshal the header into. + // The minimum capacity for the packet must be a constant expression, and should be at least 9. + // + // It shall return the main body of the encoded packet in header, + // and may optionally return an additional payload to be written immediately after the header. + // + // It shall encode in the first 4-bytes of the header the proper length of the rest of the header+payload. + MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) + + // UnmarshalPacketBody decodes a packet body from the given Buffer. + // It is assumed that the common header values of the length, type and request-id have already been consumed. UnmarshalPacketBody(buf *Buffer) error } diff --git a/encoding/ssh/filexfer/handle_packets.go b/encoding/ssh/filexfer/handle_packets.go index 4841bce..33670dc 100644 --- a/encoding/ssh/filexfer/handle_packets.go +++ b/encoding/ssh/filexfer/handle_packets.go @@ -6,14 +6,17 @@ type ClosePacket struct { } // MarshalPacket returns p as a two-part binary encoding of p. -func (p *ClosePacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { - size := 4 + len(p.Handle) // string(handle) +func (p *ClosePacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + size := 4 + len(p.Handle) // string(handle) + buf = NewMarshalBuffer(size) + } - b := NewMarshalBuffer(PacketTypeClose, reqid, size) + buf.StartPacket(PacketTypeClose, reqid) + buf.AppendString(p.Handle) - b.AppendString(p.Handle) - - return b.Packet(payload) + return buf.Packet(payload) } // UnmarshalPacketBody unmarshals the packet body from the given Buffer. @@ -34,17 +37,20 @@ type ReadPacket struct { } // MarshalPacket returns p as a two-part binary encoding of p. -func (p *ReadPacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { - // string(handle) + uint64(offset) + uint32(len) - size := 4 + len(p.Handle) + 8 + 4 +func (p *ReadPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + // string(handle) + uint64(offset) + uint32(len) + size := 4 + len(p.Handle) + 8 + 4 + buf = NewMarshalBuffer(size) + } - b := NewMarshalBuffer(PacketTypeRead, reqid, size) + buf.StartPacket(PacketTypeRead, reqid) + buf.AppendString(p.Handle) + buf.AppendUint64(p.Offset) + buf.AppendUint32(p.Len) - b.AppendString(p.Handle) - b.AppendUint64(p.Offset) - b.AppendUint32(p.Len) - - return b.Packet(payload) + return buf.Packet(payload) } // UnmarshalPacketBody unmarshals the packet body from the given Buffer. @@ -73,17 +79,20 @@ type WritePacket struct { } // MarshalPacket returns p as a two-part binary encoding of p. -func (p *WritePacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { - // string(handle) + uint64(offset) + uint32(len(data)); data content in payload - size := 4 + len(p.Handle) + 8 + 4 +func (p *WritePacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + // string(handle) + uint64(offset) + uint32(len(data)); data content in payload + size := 4 + len(p.Handle) + 8 + 4 + buf = NewMarshalBuffer(size) + } - b := NewMarshalBuffer(PacketTypeWrite, reqid, size) + buf.StartPacket(PacketTypeWrite, reqid) + buf.AppendString(p.Handle) + buf.AppendUint64(p.Offset) + buf.AppendUint32(uint32(len(p.Data))) - b.AppendString(p.Handle) - b.AppendUint64(p.Offset) - b.AppendUint32(uint32(len(p.Data))) - - return b.Packet(p.Data) + return buf.Packet(p.Data) } // UnmarshalPacketBody unmarshals the packet body from the given Buffer. @@ -110,14 +119,17 @@ type FStatPacket struct { } // MarshalPacket returns p as a two-part binary encoding of p. -func (p *FStatPacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { - size := 4 + len(p.Handle) // string(handle) +func (p *FStatPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + size := 4 + len(p.Handle) // string(handle) + buf = NewMarshalBuffer(size) + } - b := NewMarshalBuffer(PacketTypeFStat, reqid, size) + buf.StartPacket(PacketTypeFStat, reqid) + buf.AppendString(p.Handle) - b.AppendString(p.Handle) - - return b.Packet(payload) + return buf.Packet(payload) } // UnmarshalPacketBody unmarshals the packet body from the given Buffer. @@ -137,16 +149,19 @@ type FSetstatPacket struct { } // MarshalPacket returns p as a two-part binary encoding of p. -func (p *FSetstatPacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { - size := 4 + len(p.Handle) + p.Attrs.Len() // string(handle) + ATTRS(attrs) +func (p *FSetstatPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + size := 4 + len(p.Handle) + p.Attrs.Len() // string(handle) + ATTRS(attrs) + buf = NewMarshalBuffer(size) + } - b := NewMarshalBuffer(PacketTypeFSetstat, reqid, size) + buf.StartPacket(PacketTypeFSetstat, reqid) + buf.AppendString(p.Handle) - b.AppendString(p.Handle) + p.Attrs.MarshalInto(buf) - p.Attrs.MarshalInto(b) - - return b.Packet(payload) + return buf.Packet(payload) } // UnmarshalPacketBody unmarshals the packet body from the given Buffer. @@ -165,14 +180,17 @@ type ReadDirPacket struct { } // MarshalPacket returns p as a two-part binary encoding of p. -func (p *ReadDirPacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { - size := 4 + len(p.Handle) // string(handle) +func (p *ReadDirPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + size := 4 + len(p.Handle) // string(handle) + buf = NewMarshalBuffer(size) + } - b := NewMarshalBuffer(PacketTypeReadDir, reqid, size) + buf.StartPacket(PacketTypeReadDir, reqid) + buf.AppendString(p.Handle) - b.AppendString(p.Handle) - - return b.Packet(payload) + return buf.Packet(payload) } // UnmarshalPacketBody unmarshals the packet body from the given Buffer. diff --git a/encoding/ssh/filexfer/handle_packets_test.go b/encoding/ssh/filexfer/handle_packets_test.go index b8599ad..10fdc53 100644 --- a/encoding/ssh/filexfer/handle_packets_test.go +++ b/encoding/ssh/filexfer/handle_packets_test.go @@ -17,7 +17,7 @@ func TestClosePacket(t *testing.T) { Handle: "somehandle", } - data, err := ComposePacket(p.MarshalPacket(id)) + data, err := ComposePacket(p.MarshalPacket(id, nil)) if err != nil { t.Fatal("unexpected error:", err) } @@ -61,7 +61,7 @@ func TestReadPacket(t *testing.T) { Len: length, } - data, err := ComposePacket(p.MarshalPacket(id)) + data, err := ComposePacket(p.MarshalPacket(id, nil)) if err != nil { t.Fatal("unexpected error:", err) } @@ -116,7 +116,7 @@ func TestWritePacket(t *testing.T) { Data: payload, } - data, err := ComposePacket(p.MarshalPacket(id)) + data, err := ComposePacket(p.MarshalPacket(id, nil)) if err != nil { t.Fatal("unexpected error:", err) } @@ -166,7 +166,7 @@ func TestFStatPacket(t *testing.T) { Handle: "somehandle", } - data, err := ComposePacket(p.MarshalPacket(id)) + data, err := ComposePacket(p.MarshalPacket(id, nil)) if err != nil { t.Fatal("unexpected error:", err) } @@ -211,7 +211,7 @@ func TestFSetstatPacket(t *testing.T) { }, } - data, err := ComposePacket(p.MarshalPacket(id)) + data, err := ComposePacket(p.MarshalPacket(id, nil)) if err != nil { t.Fatal("unexpected error:", err) } @@ -253,7 +253,7 @@ func TestReadDirPacket(t *testing.T) { Handle: "somehandle", } - data, err := ComposePacket(p.MarshalPacket(id)) + data, err := ComposePacket(p.MarshalPacket(id, nil)) if err != nil { t.Fatal("unexpected error:", err) } diff --git a/encoding/ssh/filexfer/open_packets.go b/encoding/ssh/filexfer/open_packets.go index 4284d0f..b0e25c2 100644 --- a/encoding/ssh/filexfer/open_packets.go +++ b/encoding/ssh/filexfer/open_packets.go @@ -18,18 +18,21 @@ type OpenPacket struct { } // MarshalPacket returns p as a two-part binary encoding of p. -func (p *OpenPacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { - // string(filename) + uint32(pflags) + ATTRS(attrs) - size := 4 + len(p.Filename) + 4 + p.Attrs.Len() +func (p *OpenPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + // string(filename) + uint32(pflags) + ATTRS(attrs) + size := 4 + len(p.Filename) + 4 + p.Attrs.Len() + buf = NewMarshalBuffer(size) + } - b := NewMarshalBuffer(PacketTypeOpen, reqid, size) + buf.StartPacket(PacketTypeOpen, reqid) + buf.AppendString(p.Filename) + buf.AppendUint32(p.PFlags) - b.AppendString(p.Filename) - b.AppendUint32(p.PFlags) + p.Attrs.MarshalInto(buf) - p.Attrs.MarshalInto(b) - - return b.Packet(payload) + return buf.Packet(payload) } // UnmarshalPacketBody unmarshals the packet body from the given Buffer. @@ -52,14 +55,17 @@ type OpenDirPacket struct { } // MarshalPacket returns p as a two-part binary encoding of p. -func (p *OpenDirPacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { - size := 4 + len(p.Path) // string(path) +func (p *OpenDirPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + size := 4 + len(p.Path) // string(path) + buf = NewMarshalBuffer(size) + } - b := NewMarshalBuffer(PacketTypeOpenDir, reqid, size) + buf.StartPacket(PacketTypeOpenDir, reqid) + buf.AppendString(p.Path) - b.AppendString(p.Path) - - return b.Packet(payload) + return buf.Packet(payload) } // UnmarshalPacketBody unmarshals the packet body from the given Buffer. diff --git a/encoding/ssh/filexfer/open_packets_test.go b/encoding/ssh/filexfer/open_packets_test.go index f4b0a44..8637cd0 100644 --- a/encoding/ssh/filexfer/open_packets_test.go +++ b/encoding/ssh/filexfer/open_packets_test.go @@ -23,7 +23,7 @@ func TestOpenPacket(t *testing.T) { }, } - data, err := ComposePacket(p.MarshalPacket(id)) + data, err := ComposePacket(p.MarshalPacket(id, nil)) if err != nil { t.Fatal("unexpected error:", err) } @@ -78,7 +78,7 @@ func TestOpenDirPacket(t *testing.T) { Path: path, } - data, err := ComposePacket(p.MarshalPacket(id)) + data, err := ComposePacket(p.MarshalPacket(id, nil)) if err != nil { t.Fatal("unexpected error:", err) } diff --git a/encoding/ssh/filexfer/openssh/hardlink.go b/encoding/ssh/filexfer/openssh/hardlink.go index f279f52..838ff2d 100644 --- a/encoding/ssh/filexfer/openssh/hardlink.go +++ b/encoding/ssh/filexfer/openssh/hardlink.go @@ -27,13 +27,13 @@ type HardlinkExtendedPacket struct { } // MarshalPacket returns ep as a two-part binary encoding of the full extended packet. -func (ep *HardlinkExtendedPacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { +func (ep *HardlinkExtendedPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { p := &sshfx.ExtendedPacket{ ExtendedRequest: extensionHardlink, Data: ep, } - return p.MarshalPacket(reqid) + return p.MarshalPacket(reqid, b) } // MarshalInto encodes ep into the binary encoding of the hardlink@openssh.com extended packet-specific data. diff --git a/encoding/ssh/filexfer/openssh/openssh.go b/encoding/ssh/filexfer/openssh/openssh.go new file mode 100644 index 0000000..f93ff17 --- /dev/null +++ b/encoding/ssh/filexfer/openssh/openssh.go @@ -0,0 +1,2 @@ +// Package openssh implements the openssh secsh-filexfer extensions as described in https://github.com/openssh/openssh-portable/blob/master/PROTOCOL +package openssh diff --git a/encoding/ssh/filexfer/openssh/posix-rename.go b/encoding/ssh/filexfer/openssh/posix-rename.go index b4c45c4..8c8313a 100644 --- a/encoding/ssh/filexfer/openssh/posix-rename.go +++ b/encoding/ssh/filexfer/openssh/posix-rename.go @@ -27,13 +27,13 @@ type PosixRenameExtendedPacket struct { } // MarshalPacket returns ep as a two-part binary encoding of the full extended packet. -func (ep *PosixRenameExtendedPacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { +func (ep *PosixRenameExtendedPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { p := &sshfx.ExtendedPacket{ ExtendedRequest: extensionPosixRename, Data: ep, } - return p.MarshalPacket(reqid) + return p.MarshalPacket(reqid, b) } // MarshalInto encodes ep into the binary encoding of the hardlink@openssh.com extended packet-specific data. diff --git a/encoding/ssh/filexfer/openssh/statvfs.go b/encoding/ssh/filexfer/openssh/statvfs.go index cf41cc3..5a723d4 100644 --- a/encoding/ssh/filexfer/openssh/statvfs.go +++ b/encoding/ssh/filexfer/openssh/statvfs.go @@ -26,13 +26,13 @@ type StatVFSExtendedPacket struct { } // MarshalPacket returns ep as a two-part binary encoding of the full extended packet. -func (ep *StatVFSExtendedPacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { +func (ep *StatVFSExtendedPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { p := &sshfx.ExtendedPacket{ ExtendedRequest: extensionStatVFS, Data: ep, } - return p.MarshalPacket(reqid) + return p.MarshalPacket(reqid, b) } // MarshalInto encodes ep into the binary encoding of the statvfs@openssh.com extended packet-specific data. @@ -89,13 +89,13 @@ type FStatVFSExtendedPacket struct { } // MarshalPacket returns ep as a two-part binary encoding of the full extended packet. -func (ep *FStatVFSExtendedPacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { +func (ep *FStatVFSExtendedPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { p := &sshfx.ExtendedPacket{ ExtendedRequest: extensionFStatVFS, Data: ep, } - return p.MarshalPacket(reqid) + return p.MarshalPacket(reqid, b) } // MarshalInto encodes ep into the binary encoding of the statvfs@openssh.com extended packet-specific data. @@ -153,11 +153,11 @@ type StatVFSExtendedReplyPacket struct { } // MarshalPacket returns ep as a two-part binary encoding of the full extended reply packet. -func (ep *StatVFSExtendedReplyPacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { +func (ep *StatVFSExtendedReplyPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { p := &sshfx.ExtendedReplyPacket{ Data: ep, } - return p.MarshalPacket(reqid) + return p.MarshalPacket(reqid, b) } // MarshalInto encodes ep into the binary encoding of the (f)statvfs@openssh.com extended reply packet-specific data. diff --git a/encoding/ssh/filexfer/packets.go b/encoding/ssh/filexfer/packets.go index e32f1c4..baa2af4 100644 --- a/encoding/ssh/filexfer/packets.go +++ b/encoding/ssh/filexfer/packets.go @@ -64,15 +64,20 @@ type RawPacket struct { // MarshalPacket returns p as a two-part binary encoding of p. // // The internal p.RequestID is overridden by the reqid argument. -func (p *RawPacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { - b := NewMarshalBuffer(p.Type, reqid, 0) +func (p *RawPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + buf = NewMarshalBuffer(0) + } - return b.Packet(p.Data.Bytes()) + buf.StartPacket(p.Type, reqid) + + return buf.Packet(p.Data.Bytes()) } // MarshalBinary returns p as the binary encoding of p. func (p *RawPacket) MarshalBinary() ([]byte, error) { - return ComposePacket(p.MarshalPacket(p.RequestID)) + return ComposePacket(p.MarshalPacket(p.RequestID, nil)) } // UnmarshalFrom decodes a RawPacket from the given Buffer into p. @@ -139,7 +144,8 @@ func (p *RawPacket) ReadFrom(r io.Reader, b []byte) error { return p.UnmarshalBinary(b) } -// RequestPacket decodes a full request packet from the internal Data based on the Type. +// RequestPacket decodes a full RequestPacket from the internal Data based on the Type. +// Prefer using a RequestPacket directly, rather than going indirectly through RawPacket. func (p *RawPacket) RequestPacket() (*RequestPacket, error) { packet, err := newPacketFromType(p.Type) if err != nil { @@ -173,17 +179,17 @@ func (p *RequestPacket) Reset() { // MarshalPacket returns p as a two-part binary encoding of p. // // The internal p.RequestID is overridden by the reqid argument. -func (p *RequestPacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { +func (p *RequestPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { if p.Request == nil { return nil, nil, errors.New("empty request packet") } - return p.Request.MarshalPacket(reqid) + return p.Request.MarshalPacket(reqid, b) } // MarshalBinary returns p as the binary encoding of p. func (p *RequestPacket) MarshalBinary() ([]byte, error) { - return ComposePacket(p.MarshalPacket(p.RequestID)) + return ComposePacket(p.MarshalPacket(p.RequestID, nil)) } // UnmarshalFrom decodes a RequestPacket from the given Buffer into p. diff --git a/encoding/ssh/filexfer/path_packets.go b/encoding/ssh/filexfer/path_packets.go index 45ec5a5..f7085f3 100644 --- a/encoding/ssh/filexfer/path_packets.go +++ b/encoding/ssh/filexfer/path_packets.go @@ -6,14 +6,17 @@ type LStatPacket struct { } // MarshalPacket returns p as a two-part binary encoding of p. -func (p *LStatPacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { - size := 4 + len(p.Path) // string(path) +func (p *LStatPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + size := 4 + len(p.Path) // string(path) + buf = NewMarshalBuffer(size) + } - b := NewMarshalBuffer(PacketTypeLStat, reqid, size) + buf.StartPacket(PacketTypeLStat, reqid) + buf.AppendString(p.Path) - b.AppendString(p.Path) - - return b.Packet(payload) + return buf.Packet(payload) } // UnmarshalPacketBody unmarshals the packet body from the given Buffer. @@ -33,16 +36,19 @@ type SetstatPacket struct { } // MarshalPacket returns p as a two-part binary encoding of p. -func (p *SetstatPacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { - size := 4 + len(p.Path) + p.Attrs.Len() // string(path) + ATTRS(attrs) +func (p *SetstatPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + size := 4 + len(p.Path) + p.Attrs.Len() // string(path) + ATTRS(attrs) + buf = NewMarshalBuffer(size) + } - b := NewMarshalBuffer(PacketTypeSetstat, reqid, size) + buf.StartPacket(PacketTypeSetstat, reqid) + buf.AppendString(p.Path) - b.AppendString(p.Path) + p.Attrs.MarshalInto(buf) - p.Attrs.MarshalInto(b) - - return b.Packet(payload) + return buf.Packet(payload) } // UnmarshalPacketBody unmarshals the packet body from the given Buffer. @@ -61,14 +67,17 @@ type RemovePacket struct { } // MarshalPacket returns p as a two-part binary encoding of p. -func (p *RemovePacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { - size := 4 + len(p.Path) // string(path) +func (p *RemovePacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + size := 4 + len(p.Path) // string(path) + buf = NewMarshalBuffer(size) + } - b := NewMarshalBuffer(PacketTypeRemove, reqid, size) + buf.StartPacket(PacketTypeRemove, reqid) + buf.AppendString(p.Path) - b.AppendString(p.Path) - - return b.Packet(payload) + return buf.Packet(payload) } // UnmarshalPacketBody unmarshals the packet body from the given Buffer. @@ -88,16 +97,19 @@ type MkdirPacket struct { } // MarshalPacket returns p as a two-part binary encoding of p. -func (p *MkdirPacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { - size := 4 + len(p.Path) + p.Attrs.Len() // string(path) + ATTRS(attrs) +func (p *MkdirPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + size := 4 + len(p.Path) + p.Attrs.Len() // string(path) + ATTRS(attrs) + buf = NewMarshalBuffer(size) + } - b := NewMarshalBuffer(PacketTypeMkdir, reqid, size) + buf.StartPacket(PacketTypeMkdir, reqid) + buf.AppendString(p.Path) - b.AppendString(p.Path) + p.Attrs.MarshalInto(buf) - p.Attrs.MarshalInto(b) - - return b.Packet(payload) + return buf.Packet(payload) } // UnmarshalPacketBody unmarshals the packet body from the given Buffer. @@ -116,14 +128,17 @@ type RmdirPacket struct { } // MarshalPacket returns p as a two-part binary encoding of p. -func (p *RmdirPacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { - size := 4 + len(p.Path) // string(path) +func (p *RmdirPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + size := 4 + len(p.Path) // string(path) + buf = NewMarshalBuffer(size) + } - b := NewMarshalBuffer(PacketTypeRmdir, reqid, size) + buf.StartPacket(PacketTypeRmdir, reqid) + buf.AppendString(p.Path) - b.AppendString(p.Path) - - return b.Packet(payload) + return buf.Packet(payload) } // UnmarshalPacketBody unmarshals the packet body from the given Buffer. @@ -142,14 +157,17 @@ type RealPathPacket struct { } // MarshalPacket returns p as a two-part binary encoding of p. -func (p *RealPathPacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { - size := 4 + len(p.Path) // string(path) +func (p *RealPathPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + size := 4 + len(p.Path) // string(path) + buf = NewMarshalBuffer(size) + } - b := NewMarshalBuffer(PacketTypeRealPath, reqid, size) + buf.StartPacket(PacketTypeRealPath, reqid) + buf.AppendString(p.Path) - b.AppendString(p.Path) - - return b.Packet(payload) + return buf.Packet(payload) } // UnmarshalPacketBody unmarshals the packet body from the given Buffer. @@ -168,14 +186,17 @@ type StatPacket struct { } // MarshalPacket returns p as a two-part binary encoding of p. -func (p *StatPacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { - size := 4 + len(p.Path) // string(path) +func (p *StatPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + size := 4 + len(p.Path) // string(path) + buf = NewMarshalBuffer(size) + } - b := NewMarshalBuffer(PacketTypeStat, reqid, size) + buf.StartPacket(PacketTypeStat, reqid) + buf.AppendString(p.Path) - b.AppendString(p.Path) - - return b.Packet(payload) + return buf.Packet(payload) } // UnmarshalPacketBody unmarshals the packet body from the given Buffer. @@ -195,16 +216,19 @@ type RenamePacket struct { } // MarshalPacket returns p as a two-part binary encoding of p. -func (p *RenamePacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { - // string(oldpath) + string(newpath) - size := 4 + len(p.OldPath) + 4 + len(p.NewPath) +func (p *RenamePacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + // string(oldpath) + string(newpath) + size := 4 + len(p.OldPath) + 4 + len(p.NewPath) + buf = NewMarshalBuffer(size) + } - b := NewMarshalBuffer(PacketTypeRename, reqid, size) + buf.StartPacket(PacketTypeRename, reqid) + buf.AppendString(p.OldPath) + buf.AppendString(p.NewPath) - b.AppendString(p.OldPath) - b.AppendString(p.NewPath) - - return b.Packet(payload) + return buf.Packet(payload) } // UnmarshalPacketBody unmarshals the packet body from the given Buffer. @@ -227,14 +251,17 @@ type ReadLinkPacket struct { } // MarshalPacket returns p as a two-part binary encoding of p. -func (p *ReadLinkPacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { - size := 4 + len(p.Path) // string(path) +func (p *ReadLinkPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + size := 4 + len(p.Path) // string(path) + buf = NewMarshalBuffer(size) + } - b := NewMarshalBuffer(PacketTypeReadLink, reqid, size) + buf.StartPacket(PacketTypeReadLink, reqid) + buf.AppendString(p.Path) - b.AppendString(p.Path) - - return b.Packet(payload) + return buf.Packet(payload) } // UnmarshalPacketBody unmarshals the packet body from the given Buffer. @@ -258,17 +285,21 @@ type SymlinkPacket struct { } // MarshalPacket returns p as a two-part binary encoding of p. -func (p *SymlinkPacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { - // string(targetpath) + string(linkpath) - size := 4 + len(p.TargetPath) + 4 + len(p.LinkPath) +func (p *SymlinkPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + // string(targetpath) + string(linkpath) + size := 4 + len(p.TargetPath) + 4 + len(p.LinkPath) + buf = NewMarshalBuffer(size) + } - b := NewMarshalBuffer(PacketTypeSymlink, reqid, size) + buf.StartPacket(PacketTypeSymlink, reqid) // Arguments were inadvertently reversed. - b.AppendString(p.TargetPath) - b.AppendString(p.LinkPath) + buf.AppendString(p.TargetPath) + buf.AppendString(p.LinkPath) - return b.Packet(payload) + return buf.Packet(payload) } // UnmarshalPacketBody unmarshals the packet body from the given Buffer. diff --git a/encoding/ssh/filexfer/path_packets_test.go b/encoding/ssh/filexfer/path_packets_test.go index aba3200..852145e 100644 --- a/encoding/ssh/filexfer/path_packets_test.go +++ b/encoding/ssh/filexfer/path_packets_test.go @@ -17,7 +17,7 @@ func TestLStatPacket(t *testing.T) { Path: path, } - data, err := ComposePacket(p.MarshalPacket(id)) + data, err := ComposePacket(p.MarshalPacket(id, nil)) if err != nil { t.Fatal("unexpected error:", err) } @@ -62,7 +62,7 @@ func TestSetstatPacket(t *testing.T) { }, } - data, err := ComposePacket(p.MarshalPacket(id)) + data, err := ComposePacket(p.MarshalPacket(id, nil)) if err != nil { t.Fatal("unexpected error:", err) } @@ -112,7 +112,7 @@ func TestRemovePacket(t *testing.T) { Path: path, } - data, err := ComposePacket(p.MarshalPacket(id)) + data, err := ComposePacket(p.MarshalPacket(id, nil)) if err != nil { t.Fatal("unexpected error:", err) } @@ -157,7 +157,7 @@ func TestMkdirPacket(t *testing.T) { }, } - data, err := ComposePacket(p.MarshalPacket(id)) + data, err := ComposePacket(p.MarshalPacket(id, nil)) if err != nil { t.Fatal("unexpected error:", err) } @@ -207,7 +207,7 @@ func TestRmdirPacket(t *testing.T) { Path: path, } - data, err := ComposePacket(p.MarshalPacket(id)) + data, err := ComposePacket(p.MarshalPacket(id, nil)) if err != nil { t.Fatal("unexpected error:", err) } @@ -247,7 +247,7 @@ func TestRealPathPacket(t *testing.T) { Path: path, } - data, err := ComposePacket(p.MarshalPacket(id)) + data, err := ComposePacket(p.MarshalPacket(id, nil)) if err != nil { t.Fatal("unexpected error:", err) } @@ -287,7 +287,7 @@ func TestStatPacket(t *testing.T) { Path: path, } - data, err := ComposePacket(p.MarshalPacket(id)) + data, err := ComposePacket(p.MarshalPacket(id, nil)) if err != nil { t.Fatal("unexpected error:", err) } @@ -329,7 +329,7 @@ func TestRenamePacket(t *testing.T) { NewPath: newpath, } - data, err := ComposePacket(p.MarshalPacket(id)) + data, err := ComposePacket(p.MarshalPacket(id, nil)) if err != nil { t.Fatal("unexpected error:", err) } @@ -374,7 +374,7 @@ func TestReadLinkPacket(t *testing.T) { Path: path, } - data, err := ComposePacket(p.MarshalPacket(id)) + data, err := ComposePacket(p.MarshalPacket(id, nil)) if err != nil { t.Fatal("unexpected error:", err) } @@ -416,7 +416,7 @@ func TestSymlinkPacket(t *testing.T) { TargetPath: targetpath, } - data, err := ComposePacket(p.MarshalPacket(id)) + data, err := ComposePacket(p.MarshalPacket(id, nil)) if err != nil { t.Fatal("unexpected error:", err) } diff --git a/encoding/ssh/filexfer/response_packets.go b/encoding/ssh/filexfer/response_packets.go index 7c4b116..831b443 100644 --- a/encoding/ssh/filexfer/response_packets.go +++ b/encoding/ssh/filexfer/response_packets.go @@ -10,17 +10,20 @@ type StatusPacket struct { } // MarshalPacket returns p as a two-part binary encoding of p. -func (p *StatusPacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { - // uint32(error/status code) + string(error message) + string(language tag) - size := 4 + 4 + len(p.ErrorMessage) + 4 + len(p.LanguageTag) +func (p *StatusPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + // uint32(error/status code) + string(error message) + string(language tag) + size := 4 + 4 + len(p.ErrorMessage) + 4 + len(p.LanguageTag) + buf = NewMarshalBuffer(size) + } - b := NewMarshalBuffer(PacketTypeStatus, reqid, size) + buf.StartPacket(PacketTypeStatus, reqid) + buf.AppendUint32(uint32(p.StatusCode)) + buf.AppendString(p.ErrorMessage) + buf.AppendString(p.LanguageTag) - b.AppendUint32(uint32(p.StatusCode)) - b.AppendString(p.ErrorMessage) - b.AppendString(p.LanguageTag) - - return b.Packet(payload) + return buf.Packet(payload) } // UnmarshalPacketBody unmarshals the packet body from the given Buffer. @@ -49,14 +52,17 @@ type HandlePacket struct { } // MarshalPacket returns p as a two-part binary encoding of p. -func (p *HandlePacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { - size := 4 + len(p.Handle) // string(handle) +func (p *HandlePacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + size := 4 + len(p.Handle) // string(handle) + buf = NewMarshalBuffer(size) + } - b := NewMarshalBuffer(PacketTypeHandle, reqid, size) + buf.StartPacket(PacketTypeHandle, reqid) + buf.AppendString(p.Handle) - b.AppendString(p.Handle) - - return b.Packet(payload) + return buf.Packet(payload) } // UnmarshalPacketBody unmarshals the packet body from the given Buffer. @@ -75,14 +81,17 @@ type DataPacket struct { } // MarshalPacket returns p as a two-part binary encoding of p. -func (p *DataPacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { - size := 4 // uint32(len(data)); data content in payload +func (p *DataPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + size := 4 // uint32(len(data)); data content in payload + buf = NewMarshalBuffer(size) + } - b := NewMarshalBuffer(PacketTypeData, reqid, size) + buf.StartPacket(PacketTypeData, reqid) + buf.AppendUint32(uint32(len(p.Data))) - b.AppendUint32(uint32(len(p.Data))) - - return b.Packet(p.Data) + return buf.Packet(p.Data) } // UnmarshalPacketBody unmarshals the packet body from the given Buffer. @@ -101,22 +110,26 @@ type NamePacket struct { } // MarshalPacket returns p as a two-part binary encoding of p. -func (p *NamePacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { - size := 4 // uint32(len(entries)) +func (p *NamePacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + size := 4 // uint32(len(entries)) - for _, e := range p.Entries { - size += e.Len() + for _, e := range p.Entries { + size += e.Len() + } + + buf = NewMarshalBuffer(size) } - b := NewMarshalBuffer(PacketTypeName, reqid, size) - - b.AppendUint32(uint32(len(p.Entries))) + buf.StartPacket(PacketTypeName, reqid) + buf.AppendUint32(uint32(len(p.Entries))) for _, e := range p.Entries { - e.MarshalInto(b) + e.MarshalInto(buf) } - return b.Packet(payload) + return buf.Packet(payload) } // UnmarshalPacketBody unmarshals the packet body from the given Buffer. @@ -147,14 +160,17 @@ type AttrsPacket struct { } // MarshalPacket returns p as a two-part binary encoding of p. -func (p *AttrsPacket) MarshalPacket(reqid uint32) (header, payload []byte, err error) { - size := p.Attrs.Len() // ATTRS(attrs) +func (p *AttrsPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + size := p.Attrs.Len() // ATTRS(attrs) + buf = NewMarshalBuffer(size) + } - b := NewMarshalBuffer(PacketTypeAttrs, reqid, size) + buf.StartPacket(PacketTypeAttrs, reqid) + p.Attrs.MarshalInto(buf) - p.Attrs.MarshalInto(b) - - return b.Packet(payload) + return buf.Packet(payload) } // UnmarshalPacketBody unmarshals the packet body from the given Buffer. diff --git a/encoding/ssh/filexfer/response_packets_test.go b/encoding/ssh/filexfer/response_packets_test.go index e39d709..23e42bb 100644 --- a/encoding/ssh/filexfer/response_packets_test.go +++ b/encoding/ssh/filexfer/response_packets_test.go @@ -21,7 +21,7 @@ func TestStatusPacket(t *testing.T) { LanguageTag: languageTag, } - data, err := ComposePacket(p.MarshalPacket(id)) + data, err := ComposePacket(p.MarshalPacket(id, nil)) if err != nil { t.Fatal("unexpected error:", err) } @@ -71,7 +71,7 @@ func TestHandlePacket(t *testing.T) { Handle: "somehandle", } - data, err := ComposePacket(p.MarshalPacket(id)) + data, err := ComposePacket(p.MarshalPacket(id, nil)) if err != nil { t.Fatal("unexpected error:", err) } @@ -112,7 +112,7 @@ func TestDataPacket(t *testing.T) { Data: payload, } - data, err := ComposePacket(p.MarshalPacket(id)) + data, err := ComposePacket(p.MarshalPacket(id, nil)) if err != nil { t.Fatal("unexpected error:", err) } @@ -171,7 +171,7 @@ func TestNamePacket(t *testing.T) { }, } - data, err := ComposePacket(p.MarshalPacket(id)) + data, err := ComposePacket(p.MarshalPacket(id, nil)) if err != nil { t.Fatal("unexpected error:", err) } @@ -240,7 +240,7 @@ func TestAttrsPacket(t *testing.T) { }, } - data, err := ComposePacket(p.MarshalPacket(id)) + data, err := ComposePacket(p.MarshalPacket(id, nil)) if err != nil { t.Fatal("unexpected error:", err) }