mirror of https://github.com/pkg/sftp.git
Improve benchmarks and errors
This commit is contained in:
parent
460ad57385
commit
f1e28f8a88
|
@ -43,10 +43,10 @@ type ClientOption func(*Client) error
|
|||
func MaxPacketChecked(size int) ClientOption {
|
||||
return func(c *Client) error {
|
||||
if size < 1 {
|
||||
return errors.Errorf("size must be greater or equal to 1")
|
||||
return errors.New("size must be greater or equal to 1")
|
||||
}
|
||||
if size > 32768 {
|
||||
return errors.Errorf("sizes larger than 32KB might not work with all servers")
|
||||
return errors.New("sizes larger than 32KB might not work with all servers")
|
||||
}
|
||||
c.maxPacket = size
|
||||
return nil
|
||||
|
@ -65,7 +65,7 @@ func MaxPacketChecked(size int) ClientOption {
|
|||
func MaxPacketUnchecked(size int) ClientOption {
|
||||
return func(c *Client) error {
|
||||
if size < 1 {
|
||||
return errors.Errorf("size must be greater or equal to 1")
|
||||
return errors.New("size must be greater or equal to 1")
|
||||
}
|
||||
c.maxPacket = size
|
||||
return nil
|
||||
|
@ -90,7 +90,7 @@ func MaxPacket(size int) ClientOption {
|
|||
func MaxConcurrentRequestsPerFile(n int) ClientOption {
|
||||
return func(c *Client) error {
|
||||
if n < 1 {
|
||||
return errors.Errorf("n must be greater or equal to 1")
|
||||
return errors.New("n must be greater or equal to 1")
|
||||
}
|
||||
c.maxConcurrentRequests = n
|
||||
return nil
|
||||
|
|
|
@ -6,7 +6,6 @@ package sftp
|
|||
import (
|
||||
"bytes"
|
||||
"crypto/sha1"
|
||||
"encoding"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
|
@ -1490,27 +1489,49 @@ func TestClientReadFrom(t *testing.T) {
|
|||
var errFakeNet = errors.New("Fake network issue")
|
||||
|
||||
func TestClientReadFromDeadlock(t *testing.T) {
|
||||
clientWriteDeadlock(t, 1, func(f *File) {
|
||||
for i := 0; i < 5; i++ {
|
||||
clientWriteDeadlock(t, i, func(f *File) {
|
||||
b := make([]byte, 32768*4)
|
||||
content := bytes.NewReader(b)
|
||||
_, err := f.ReadFrom(content)
|
||||
if err != errFakeNet {
|
||||
if !errors.Is(err, errFakeNet) {
|
||||
t.Fatal("Didn't recieve correct error:", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Write has exact same problem
|
||||
func TestClientWriteDeadlock(t *testing.T) {
|
||||
clientWriteDeadlock(t, 1, func(f *File) {
|
||||
for i := 0; i < 5; i++ {
|
||||
clientWriteDeadlock(t, i, func(f *File) {
|
||||
b := make([]byte, 32768*4)
|
||||
|
||||
_, err := f.Write(b)
|
||||
if err != errFakeNet {
|
||||
if !errors.Is(err, errFakeNet) {
|
||||
t.Fatal("Didn't recieve correct error:", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type timeBombWriter struct {
|
||||
count int
|
||||
w io.WriteCloser
|
||||
}
|
||||
|
||||
func (w *timeBombWriter) Write(b []byte) (int, error) {
|
||||
if w.count < 1 {
|
||||
return 0, errFakeNet
|
||||
}
|
||||
|
||||
w.count--
|
||||
return w.w.Write(b)
|
||||
}
|
||||
|
||||
func (w *timeBombWriter) Close() error {
|
||||
return w.w.Close()
|
||||
}
|
||||
|
||||
// shared body for both previous tests
|
||||
func clientWriteDeadlock(t *testing.T, N int, badfunc func(*File)) {
|
||||
|
@ -1534,20 +1555,13 @@ func clientWriteDeadlock(t *testing.T, N int, badfunc func(*File)) {
|
|||
}
|
||||
defer w.Close()
|
||||
|
||||
// Override sendPacket with failing version
|
||||
// Replicates network error/drop part way through (after 1 good packet)
|
||||
count := 0
|
||||
sendPacketTest := func(w io.Writer, m encoding.BinaryMarshaler) error {
|
||||
count++
|
||||
if count > N {
|
||||
return errFakeNet
|
||||
// Override the clienConn Writer with a failing version
|
||||
// Replicates network error/drop part way through (after N good writes)
|
||||
wrap := sftp.clientConn.conn.WriteCloser
|
||||
sftp.clientConn.conn.WriteCloser = &timeBombWriter{
|
||||
count: N,
|
||||
w: wrap,
|
||||
}
|
||||
return sendPacket(w, m)
|
||||
}
|
||||
sftp.clientConn.conn.sendPacketTest = sendPacketTest
|
||||
defer func() {
|
||||
sftp.clientConn.conn.sendPacketTest = nil
|
||||
}()
|
||||
|
||||
// this locked (before the fix)
|
||||
badfunc(w)
|
||||
|
@ -1555,28 +1569,32 @@ func clientWriteDeadlock(t *testing.T, N int, badfunc func(*File)) {
|
|||
|
||||
// Read/WriteTo has this issue as well
|
||||
func TestClientReadDeadlock(t *testing.T) {
|
||||
clientReadDeadlock(t, 1, func(f *File) {
|
||||
for i := 0; i < 5; i++ {
|
||||
clientReadDeadlock(t, i, func(f *File) {
|
||||
b := make([]byte, 32768*4)
|
||||
|
||||
_, err := f.Read(b)
|
||||
if err != errFakeNet {
|
||||
if !errors.Is(err, errFakeNet) {
|
||||
t.Fatal("Didn't recieve correct error:", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientWriteToDeadlock(t *testing.T) {
|
||||
clientReadDeadlock(t, 2, func(f *File) {
|
||||
for i := 0; i < 5; i++ {
|
||||
clientReadDeadlock(t, i, func(f *File) {
|
||||
b := make([]byte, 32768*4)
|
||||
|
||||
buf := bytes.NewBuffer(b)
|
||||
|
||||
_, err := f.WriteTo(buf)
|
||||
if err != errFakeNet {
|
||||
if !errors.Is(err, errFakeNet) {
|
||||
t.Fatal("Didn't recieve correct error:", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func clientReadDeadlock(t *testing.T, N int, badfunc func(*File)) {
|
||||
if !*testServerImpl {
|
||||
|
@ -1611,21 +1629,13 @@ func clientReadDeadlock(t *testing.T, N int, badfunc func(*File)) {
|
|||
}
|
||||
defer r.Close()
|
||||
|
||||
// Override sendPacket with failing version
|
||||
// Replicates network error/drop part way through (after 1 good packet)
|
||||
count := 0
|
||||
sendPacketTest := func(w io.Writer, m encoding.BinaryMarshaler) error {
|
||||
count++
|
||||
if count > N {
|
||||
return errFakeNet
|
||||
// Override the clienConn Writer with a failing version
|
||||
// Replicates network error/drop part way through (after N good writes)
|
||||
wrap := sftp.clientConn.conn.WriteCloser
|
||||
sftp.clientConn.conn.WriteCloser = &timeBombWriter{
|
||||
count: N,
|
||||
w: wrap,
|
||||
}
|
||||
return sendPacket(w, m)
|
||||
}
|
||||
|
||||
sftp.clientConn.conn.sendPacketTest = sendPacketTest
|
||||
defer func() {
|
||||
sftp.clientConn.conn.sendPacketTest = nil
|
||||
}()
|
||||
|
||||
// this locked (before the fix)
|
||||
badfunc(r)
|
||||
|
@ -2444,6 +2454,28 @@ func BenchmarkReadFrom4MiBDelay150Msec(b *testing.B) {
|
|||
benchmarkReadFrom(b, 4*1024*1024, 150*time.Millisecond)
|
||||
}
|
||||
|
||||
// writeToBuffer implements the relevant parts of bytes.Buffer,
|
||||
// but does not release its internal buffer when Reset.
|
||||
//
|
||||
// Release its internal memory when Reset is good for avoiding memory leaks,
|
||||
// but not great for memory benchmarks, as this fills up a lot of irrelevant allocations.
|
||||
type writeToBuffer struct {
|
||||
b []byte
|
||||
}
|
||||
|
||||
func (w *writeToBuffer) Len() int {
|
||||
return len(w.b)
|
||||
}
|
||||
|
||||
func (w *writeToBuffer) Reset() {
|
||||
w.b = w.b[:0]
|
||||
}
|
||||
|
||||
func (w *writeToBuffer) Write(b []byte) (int, error) {
|
||||
w.b = append(w.b, b...)
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func benchmarkWriteTo(b *testing.B, bufsize int, delay time.Duration) {
|
||||
size := 10*1024*1024 + 123 // ~10MiB
|
||||
|
||||
|
@ -2466,7 +2498,9 @@ func benchmarkWriteTo(b *testing.B, bufsize int, delay time.Duration) {
|
|||
b.ResetTimer()
|
||||
b.SetBytes(int64(size))
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
buf := &writeToBuffer{
|
||||
b: make([]byte, 0, size),
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
buf.Reset()
|
||||
|
|
8
conn.go
8
conn.go
|
@ -16,8 +16,6 @@ type conn struct {
|
|||
// this is the same allocator used in packet manager
|
||||
alloc *allocator
|
||||
sync.Mutex // used to serialise writes to sendPacket
|
||||
// sendPacketTest is needed to replicate packet issues in testing
|
||||
sendPacketTest func(w io.Writer, m encoding.BinaryMarshaler) error
|
||||
}
|
||||
|
||||
// the orderID is used in server mode if the allocator is enabled.
|
||||
|
@ -29,9 +27,7 @@ func (c *conn) recvPacket(orderID uint32) (uint8, []byte, error) {
|
|||
func (c *conn) sendPacket(m encoding.BinaryMarshaler) error {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
if c.sendPacketTest != nil {
|
||||
return c.sendPacketTest(c, m)
|
||||
}
|
||||
|
||||
return sendPacket(c, m)
|
||||
}
|
||||
|
||||
|
@ -91,7 +87,7 @@ func (c *clientConn) recv() error {
|
|||
// This is an unexpected occurrence. Send the error
|
||||
// back to all listeners so that they terminate
|
||||
// gracefully.
|
||||
return errors.Errorf("sid not found: %v", sid)
|
||||
return errors.Errorf("sid not found: %d", sid)
|
||||
}
|
||||
|
||||
ch <- result{typ: typ, data: data}
|
||||
|
|
|
@ -131,7 +131,7 @@ func ExampleClient_Mkdir_parents() {
|
|||
fi, err = client.Stat(parents)
|
||||
if err == nil {
|
||||
if !fi.IsDir() {
|
||||
return fmt.Errorf("File exists: %s", parents)
|
||||
return fmt.Errorf("file exists: %s", parents)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
3
go.sum
3
go.sum
|
@ -15,9 +15,12 @@ golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad/go.mod h1:jdWPYTVW3xRLrWP
|
|||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4 h1:myAQVi0cGEoqQVR5POX+8RR2mrocKqNN1hmeMqhX27k=
|
||||
golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221 h1:/ZHdbVpdR/jk3g30/d4yUL0JU9kksj8+F/bnQUVLGDM=
|
||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
|
|
@ -133,7 +133,7 @@ func marshalPacket(m encoding.BinaryMarshaler) (header, payload []byte, err erro
|
|||
func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error {
|
||||
header, payload, err := marshalPacket(m)
|
||||
if err != nil {
|
||||
return errors.Errorf("binary marshaller failed: %v", err)
|
||||
return errors.Wrap(err, "binary marshaller failed")
|
||||
}
|
||||
|
||||
length := len(header) + len(payload) - 4 // subtract the uint32(length) from the start
|
||||
|
@ -146,12 +146,12 @@ func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error {
|
|||
binary.BigEndian.PutUint32(header[:4], uint32(length))
|
||||
|
||||
if _, err := w.Write(header); err != nil {
|
||||
return errors.Errorf("failed to send packet: %v", err)
|
||||
return errors.Wrap(err, "failed to send packet")
|
||||
}
|
||||
|
||||
if len(payload) > 0 {
|
||||
if _, err := w.Write(payload); err != nil {
|
||||
return errors.Errorf("failed to send packet payload: %v", err)
|
||||
return errors.Wrap(err, "failed to send packet payload")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
442
packet_test.go
442
packet_test.go
|
@ -3,120 +3,174 @@ package sftp
|
|||
import (
|
||||
"bytes"
|
||||
"encoding"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var marshalUint32Tests = []struct {
|
||||
func TestMarshalUint32(t *testing.T) {
|
||||
var tests = []struct {
|
||||
v uint32
|
||||
want []byte
|
||||
}{
|
||||
{1, []byte{0, 0, 0, 1}},
|
||||
{256, []byte{0, 0, 1, 0}},
|
||||
{0, []byte{0, 0, 0, 0}},
|
||||
{42, []byte{0, 0, 0, 42}},
|
||||
{42 << 8, []byte{0, 0, 42, 0}},
|
||||
{42 << 16, []byte{0, 42, 0, 0}},
|
||||
{42 << 24, []byte{42, 0, 0, 0}},
|
||||
{^uint32(0), []byte{255, 255, 255, 255}},
|
||||
}
|
||||
|
||||
func TestMarshalUint32(t *testing.T) {
|
||||
for _, tt := range marshalUint32Tests {
|
||||
for _, tt := range tests {
|
||||
got := marshalUint32(nil, tt.v)
|
||||
if !bytes.Equal(tt.want, got) {
|
||||
t.Errorf("marshalUint32(%d): want %v, got %v", tt.v, tt.want, got)
|
||||
t.Errorf("marshalUint32(%d) = %#v, want %#v", tt.v, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var marshalUint64Tests = []struct {
|
||||
v uint64
|
||||
want []byte
|
||||
}{
|
||||
{1, []byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1}},
|
||||
{256, []byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0}},
|
||||
{^uint64(0), []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}},
|
||||
{1 << 32, []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0}},
|
||||
}
|
||||
|
||||
func TestMarshalUint64(t *testing.T) {
|
||||
for _, tt := range marshalUint64Tests {
|
||||
var tests = []struct {
|
||||
v uint64
|
||||
want []byte
|
||||
}{
|
||||
{0, []byte{0, 0, 0, 0, 0, 0, 0, 0}},
|
||||
{42, []byte{0, 0, 0, 0, 0, 0, 0, 42}},
|
||||
{42 << 8, []byte{0, 0, 0, 0, 0, 0, 42, 0}},
|
||||
{42 << 16, []byte{0, 0, 0, 0, 0, 42, 0, 0}},
|
||||
{42 << 24, []byte{0, 0, 0, 0, 42, 0, 0, 0}},
|
||||
{42 << 32, []byte{0, 0, 0, 42, 0, 0, 0, 0}},
|
||||
{42 << 40, []byte{0, 0, 42, 0, 0, 0, 0, 0}},
|
||||
{42 << 48, []byte{0, 42, 0, 0, 0, 0, 0, 0}},
|
||||
{42 << 56, []byte{42, 0, 0, 0, 0, 0, 0, 0}},
|
||||
{^uint64(0), []byte{255, 255, 255, 255, 255, 255, 255, 255}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := marshalUint64(nil, tt.v)
|
||||
if !bytes.Equal(tt.want, got) {
|
||||
t.Errorf("marshalUint64(%d): want %#v, got %#v", tt.v, tt.want, got)
|
||||
t.Errorf("marshalUint64(%d) = %#v, want %#v", tt.v, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var marshalStringTests = []struct {
|
||||
func TestMarshalString(t *testing.T) {
|
||||
var tests = []struct {
|
||||
v string
|
||||
want []byte
|
||||
}{
|
||||
{"", []byte{0, 0, 0, 0}},
|
||||
{"/foo", []byte{0x0, 0x0, 0x0, 0x4, 0x2f, 0x66, 0x6f, 0x6f}},
|
||||
{"/", []byte{0x0, 0x0, 0x0, 0x01, '/'}},
|
||||
{"/foo", []byte{0x0, 0x0, 0x0, 0x4, '/', 'f', 'o', 'o'}},
|
||||
{"\x00bar", []byte{0x0, 0x0, 0x0, 0x4, 0, 'b', 'a', 'r'}},
|
||||
{"b\x00ar", []byte{0x0, 0x0, 0x0, 0x4, 'b', 0, 'a', 'r'}},
|
||||
{"ba\x00r", []byte{0x0, 0x0, 0x0, 0x4, 'b', 'a', 0, 'r'}},
|
||||
{"bar\x00", []byte{0x0, 0x0, 0x0, 0x4, 'b', 'a', 'r', 0}},
|
||||
}
|
||||
|
||||
func TestMarshalString(t *testing.T) {
|
||||
for _, tt := range marshalStringTests {
|
||||
for _, tt := range tests {
|
||||
got := marshalString(nil, tt.v)
|
||||
if !bytes.Equal(tt.want, got) {
|
||||
t.Errorf("marshalString(%q): want %#v, got %#v", tt.v, tt.want, got)
|
||||
t.Errorf("marshalString(%q) = %#v, want %#v", tt.v, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var marshalTests = []struct {
|
||||
v interface{}
|
||||
want []byte
|
||||
}{
|
||||
{uint8(1), []byte{1}},
|
||||
{byte(1), []byte{1}},
|
||||
{uint32(1), []byte{0, 0, 0, 1}},
|
||||
{uint64(1), []byte{0, 0, 0, 0, 0, 0, 0, 1}},
|
||||
{"foo", []byte{0x0, 0x0, 0x0, 0x3, 0x66, 0x6f, 0x6f}},
|
||||
{[]uint32{1, 2, 3, 4}, []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x4}},
|
||||
}
|
||||
|
||||
func TestMarshal(t *testing.T) {
|
||||
for _, tt := range marshalTests {
|
||||
got := marshal(nil, tt.v)
|
||||
if !bytes.Equal(tt.want, got) {
|
||||
t.Errorf("marshal(%v): want %#v, got %#v", tt.v, tt.want, got)
|
||||
}
|
||||
}
|
||||
type Struct struct {
|
||||
X, Y, Z uint32
|
||||
}
|
||||
|
||||
var unmarshalUint32Tests = []struct {
|
||||
b []byte
|
||||
want uint32
|
||||
rest []byte
|
||||
var tests = []struct {
|
||||
v interface{}
|
||||
want []byte
|
||||
}{
|
||||
{[]byte{0, 0, 0, 0}, 0, nil},
|
||||
{[]byte{0, 0, 1, 0}, 256, nil},
|
||||
{[]byte{255, 0, 0, 255}, 4278190335, nil},
|
||||
{uint8(42), []byte{42}},
|
||||
{uint32(42 << 8), []byte{0, 0, 42, 0}},
|
||||
{uint64(42 << 32), []byte{0, 0, 0, 42, 0, 0, 0, 0}},
|
||||
{"foo", []byte{0x0, 0x0, 0x0, 0x3, 'f', 'o', 'o'}},
|
||||
{Struct{1, 2, 3}, []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x3}},
|
||||
{[]uint32{1, 2, 3}, []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x3}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := marshal(nil, tt.v)
|
||||
if !bytes.Equal(tt.want, got) {
|
||||
t.Errorf("marshal(%#v) = %#v, want %#v", tt.v, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalUint32(t *testing.T) {
|
||||
for _, tt := range unmarshalUint32Tests {
|
||||
got, rest := unmarshalUint32(tt.b)
|
||||
if got != tt.want || !bytes.Equal(rest, tt.rest) {
|
||||
t.Errorf("unmarshalUint32(%v): want %v, %#v, got %v, %#v", tt.b, tt.want, tt.rest, got, rest)
|
||||
}
|
||||
}
|
||||
testBuffer := []byte{
|
||||
0, 0, 0, 0,
|
||||
0, 0, 0, 42,
|
||||
0, 0, 42, 0,
|
||||
0, 42, 0, 0,
|
||||
42, 0, 0, 0,
|
||||
255, 0, 0, 254,
|
||||
}
|
||||
|
||||
var unmarshalUint64Tests = []struct {
|
||||
b []byte
|
||||
want uint64
|
||||
rest []byte
|
||||
}{
|
||||
{[]byte{0, 0, 0, 0, 0, 0, 0, 0}, 0, nil},
|
||||
{[]byte{0, 0, 0, 0, 0, 0, 1, 0}, 256, nil},
|
||||
{[]byte{255, 0, 0, 0, 0, 0, 0, 255}, 18374686479671623935, nil},
|
||||
var wants = []uint32{
|
||||
0,
|
||||
42,
|
||||
42 << 8,
|
||||
42 << 16,
|
||||
42 << 24,
|
||||
255<<24 | 254,
|
||||
}
|
||||
|
||||
var i int
|
||||
for len(testBuffer) > 0 {
|
||||
got, rest := unmarshalUint32(testBuffer)
|
||||
|
||||
if got != wants[i] {
|
||||
t.Fatalf("unmarshalUint32(%#v) = %d, want %d", testBuffer[:4], got, wants[i])
|
||||
}
|
||||
|
||||
i++
|
||||
testBuffer = rest
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalUint64(t *testing.T) {
|
||||
for _, tt := range unmarshalUint64Tests {
|
||||
got, rest := unmarshalUint64(tt.b)
|
||||
if got != tt.want || !bytes.Equal(rest, tt.rest) {
|
||||
t.Errorf("unmarshalUint64(%v): want %v, %#v, got %v, %#v", tt.b, tt.want, tt.rest, got, rest)
|
||||
testBuffer := []byte{
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 42,
|
||||
0, 0, 0, 0, 0, 0, 42, 0,
|
||||
0, 0, 0, 0, 0, 42, 0, 0,
|
||||
0, 0, 0, 0, 42, 0, 0, 0,
|
||||
0, 0, 0, 42, 0, 0, 0, 0,
|
||||
0, 0, 42, 0, 0, 0, 0, 0,
|
||||
0, 42, 0, 0, 0, 0, 0, 0,
|
||||
42, 0, 0, 0, 0, 0, 0, 0,
|
||||
255, 0, 0, 0, 0, 0, 0, 254,
|
||||
}
|
||||
|
||||
var wants = []uint64{
|
||||
0,
|
||||
42,
|
||||
42 << 8,
|
||||
42 << 16,
|
||||
42 << 24,
|
||||
42 << 32,
|
||||
42 << 40,
|
||||
42 << 48,
|
||||
42 << 56,
|
||||
255<<56 | 254,
|
||||
}
|
||||
|
||||
var i int
|
||||
for len(testBuffer) > 0 {
|
||||
got, rest := unmarshalUint64(testBuffer)
|
||||
|
||||
if got != wants[i] {
|
||||
t.Fatalf("unmarshalUint64(%#v) = %d, want %d", testBuffer[:8], got, wants[i])
|
||||
}
|
||||
|
||||
i++
|
||||
testBuffer = rest
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -130,85 +184,193 @@ var unmarshalStringTests = []struct {
|
|||
}
|
||||
|
||||
func TestUnmarshalString(t *testing.T) {
|
||||
for _, tt := range unmarshalStringTests {
|
||||
got, rest := unmarshalString(tt.b)
|
||||
if got != tt.want || !bytes.Equal(rest, tt.rest) {
|
||||
t.Errorf("unmarshalUint64(%v): want %q, %#v, got %q, %#v", tt.b, tt.want, tt.rest, got, rest)
|
||||
}
|
||||
}
|
||||
testBuffer := []byte{
|
||||
0, 0, 0, 0,
|
||||
0, 0, 0, 1, '/',
|
||||
0, 0, 0, 4, '/', 'f', 'o', 'o',
|
||||
0, 0, 0, 4, 0, 'b', 'a', 'r',
|
||||
0, 0, 0, 4, 'b', 0, 'a', 'r',
|
||||
0, 0, 0, 4, 'b', 'a', 0, 'r',
|
||||
0, 0, 0, 4, 'b', 'a', 'r', 0,
|
||||
}
|
||||
|
||||
var sendPacketTests = []struct {
|
||||
p encoding.BinaryMarshaler
|
||||
want []byte
|
||||
}{
|
||||
{&sshFxInitPacket{
|
||||
Version: 3,
|
||||
Extensions: []extensionPair{
|
||||
{"posix-rename@openssh.com", "1"},
|
||||
},
|
||||
}, []byte{0x0, 0x0, 0x0, 0x26, 0x1, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x18, 0x70, 0x6f, 0x73, 0x69, 0x78, 0x2d, 0x72, 0x65, 0x6e, 0x61, 0x6d, 0x65, 0x40, 0x6f, 0x70, 0x65, 0x6e, 0x73, 0x73, 0x68, 0x2e, 0x63, 0x6f, 0x6d, 0x0, 0x0, 0x0, 0x1, 0x31}},
|
||||
var wants = []string{
|
||||
"",
|
||||
"/",
|
||||
"/foo",
|
||||
"\x00bar",
|
||||
"b\x00ar",
|
||||
"ba\x00r",
|
||||
"bar\x00",
|
||||
}
|
||||
|
||||
{&sshFxpOpenPacket{
|
||||
ID: 1,
|
||||
Path: "/foo",
|
||||
Pflags: flags(os.O_RDONLY),
|
||||
}, []byte{0x0, 0x0, 0x0, 0x15, 0x3, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x4, 0x2f, 0x66, 0x6f, 0x6f, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0}},
|
||||
var i int
|
||||
for len(testBuffer) > 0 {
|
||||
got, rest := unmarshalString(testBuffer)
|
||||
|
||||
{&sshFxpWritePacket{
|
||||
ID: 124,
|
||||
Handle: "foo",
|
||||
Offset: 13,
|
||||
Length: uint32(len([]byte("bar"))),
|
||||
Data: []byte("bar"),
|
||||
}, []byte{0x0, 0x0, 0x0, 0x1b, 0x6, 0x0, 0x0, 0x0, 0x7c, 0x0, 0x0, 0x0, 0x3, 0x66, 0x6f, 0x6f, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xd, 0x0, 0x0, 0x0, 0x3, 0x62, 0x61, 0x72}},
|
||||
if got != wants[i] {
|
||||
t.Fatalf("unmarshalUint64(%#v...) = %q, want %q", testBuffer[:4], got, wants[i])
|
||||
}
|
||||
|
||||
{&sshFxpSetstatPacket{
|
||||
ID: 31,
|
||||
Path: "/bar",
|
||||
Flags: flags(os.O_WRONLY),
|
||||
Attrs: struct {
|
||||
UID uint32
|
||||
GID uint32
|
||||
}{1000, 100},
|
||||
}, []byte{0x0, 0x0, 0x0, 0x19, 0x9, 0x0, 0x0, 0x0, 0x1f, 0x0, 0x0, 0x0, 0x4, 0x2f, 0x62, 0x61, 0x72, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x3, 0xe8, 0x0, 0x0, 0x0, 0x64}},
|
||||
i++
|
||||
testBuffer = rest
|
||||
}
|
||||
}
|
||||
|
||||
type nopCloserBuffer struct {
|
||||
bytes.Buffer
|
||||
}
|
||||
|
||||
func (*nopCloserBuffer) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSendPacket(t *testing.T) {
|
||||
for _, tt := range sendPacketTests {
|
||||
var w bytes.Buffer
|
||||
sendPacket(&w, tt.p)
|
||||
if got := w.Bytes(); !bytes.Equal(tt.want, got) {
|
||||
t.Errorf("sendPacket(%v): want %#v, got %#v", tt.p, tt.want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func sp(p encoding.BinaryMarshaler) []byte {
|
||||
var w bytes.Buffer
|
||||
sendPacket(&w, p)
|
||||
return w.Bytes()
|
||||
}
|
||||
|
||||
var recvPacketTests = []struct {
|
||||
b []byte
|
||||
want uint8
|
||||
rest []byte
|
||||
var tests = []struct {
|
||||
packet encoding.BinaryMarshaler
|
||||
want []byte
|
||||
}{
|
||||
{sp(&sshFxInitPacket{
|
||||
{
|
||||
packet: &sshFxInitPacket{
|
||||
Version: 3,
|
||||
Extensions: []extensionPair{
|
||||
{"posix-rename@openssh.com", "1"},
|
||||
},
|
||||
}), sshFxpInit, []byte{0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x18, 0x70, 0x6f, 0x73, 0x69, 0x78, 0x2d, 0x72, 0x65, 0x6e, 0x61, 0x6d, 0x65, 0x40, 0x6f, 0x70, 0x65, 0x6e, 0x73, 0x73, 0x68, 0x2e, 0x63, 0x6f, 0x6d, 0x0, 0x0, 0x0, 0x1, 0x31}},
|
||||
},
|
||||
want: []byte{
|
||||
0x0, 0x0, 0x0, 0x26,
|
||||
0x1,
|
||||
0x0, 0x0, 0x0, 0x3,
|
||||
0x0, 0x0, 0x0, 0x18,
|
||||
'p', 'o', 's', 'i', 'x', '-', 'r', 'e', 'n', 'a', 'm', 'e', '@', 'o', 'p', 'e', 'n', 's', 's', 'h', '.', 'c', 'o', 'm',
|
||||
0x0, 0x0, 0x0, 0x1,
|
||||
'1',
|
||||
},
|
||||
},
|
||||
{
|
||||
packet: &sshFxpOpenPacket{
|
||||
ID: 1,
|
||||
Path: "/foo",
|
||||
Pflags: flags(os.O_RDONLY),
|
||||
},
|
||||
want: []byte{
|
||||
0x0, 0x0, 0x0, 0x15,
|
||||
0x3,
|
||||
0x0, 0x0, 0x0, 0x1,
|
||||
0x0, 0x0, 0x0, 0x4, '/', 'f', 'o', 'o',
|
||||
0x0, 0x0, 0x0, 0x1,
|
||||
0x0, 0x0, 0x0, 0x0,
|
||||
},
|
||||
},
|
||||
{
|
||||
packet: &sshFxpWritePacket{
|
||||
ID: 124,
|
||||
Handle: "foo",
|
||||
Offset: 13,
|
||||
Length: uint32(len("bar")),
|
||||
Data: []byte("bar"),
|
||||
},
|
||||
want: []byte{
|
||||
0x0, 0x0, 0x0, 0x1b,
|
||||
0x6,
|
||||
0x0, 0x0, 0x0, 0x7c,
|
||||
0x0, 0x0, 0x0, 0x3, 'f', 'o', 'o',
|
||||
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xd,
|
||||
0x0, 0x0, 0x0, 0x3, 'b', 'a', 'r',
|
||||
},
|
||||
},
|
||||
{
|
||||
packet: &sshFxpSetstatPacket{
|
||||
ID: 31,
|
||||
Path: "/bar",
|
||||
Flags: sshFileXferAttrUIDGID,
|
||||
Attrs: struct {
|
||||
UID uint32
|
||||
GID uint32
|
||||
}{
|
||||
UID: 1000,
|
||||
GID: 100,
|
||||
},
|
||||
},
|
||||
want: []byte{
|
||||
0x0, 0x0, 0x0, 0x19,
|
||||
0x9,
|
||||
0x0, 0x0, 0x0, 0x1f,
|
||||
0x0, 0x0, 0x0, 0x4, '/', 'b', 'a', 'r',
|
||||
0x0, 0x0, 0x0, 0x2,
|
||||
0x0, 0x0, 0x3, 0xe8,
|
||||
0x0, 0x0, 0x0, 0x64,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
b := new(bytes.Buffer)
|
||||
sendPacket(b, tt.packet)
|
||||
if got := b.Bytes(); !bytes.Equal(tt.want, got) {
|
||||
t.Errorf("sendPacket(%v): got %x want %x", tt.packet, tt.want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func sp(data encoding.BinaryMarshaler) []byte {
|
||||
b := new(bytes.Buffer)
|
||||
sendPacket(b, data)
|
||||
return b.Bytes()
|
||||
}
|
||||
|
||||
func TestRecvPacket(t *testing.T) {
|
||||
var recvPacketTests = []struct {
|
||||
b []byte
|
||||
|
||||
want uint8
|
||||
body []byte
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
b: sp(&sshFxInitPacket{
|
||||
Version: 3,
|
||||
Extensions: []extensionPair{
|
||||
{"posix-rename@openssh.com", "1"},
|
||||
},
|
||||
}),
|
||||
want: sshFxpInit,
|
||||
body: []byte{
|
||||
0x0, 0x0, 0x0, 0x3,
|
||||
0x0, 0x0, 0x0, 0x18,
|
||||
'p', 'o', 's', 'i', 'x', '-', 'r', 'e', 'n', 'a', 'm', 'e', '@', 'o', 'p', 'e', 'n', 's', 's', 'h', '.', 'c', 'o', 'm',
|
||||
0x0, 0x0, 0x0, 0x01,
|
||||
'1',
|
||||
},
|
||||
},
|
||||
{
|
||||
b: []byte{
|
||||
0x0, 0x0, 0x0, 0x0,
|
||||
},
|
||||
wantErr: errShortPacket,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range recvPacketTests {
|
||||
r := bytes.NewReader(tt.b)
|
||||
got, rest, _ := recvPacket(r, nil, 0)
|
||||
if got != tt.want || !bytes.Equal(rest, tt.rest) {
|
||||
t.Errorf("recvPacket(%#v): want %v, %#v, got %v, %#v", tt.b, tt.want, tt.rest, got, rest)
|
||||
|
||||
got, body, err := recvPacket(r, nil, 0)
|
||||
if tt.wantErr == nil {
|
||||
if err != nil {
|
||||
t.Fatalf("recvPacket(%#v): unexpected error: %v", tt.b, err)
|
||||
}
|
||||
} else {
|
||||
if !errors.Is(err, tt.wantErr) {
|
||||
t.Fatalf("recvPacket(%#v) = %v, want %v", tt.b, err, tt.wantErr)
|
||||
}
|
||||
}
|
||||
|
||||
if got != tt.want {
|
||||
t.Errorf("recvPacket(%#v) = %#v, want %#v", tt.b, got, tt.want)
|
||||
}
|
||||
|
||||
if !bytes.Equal(body, tt.body) {
|
||||
t.Errorf("recvPacket(%#v) = %#v, want %#v", tt.b, body, tt.body)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -297,31 +459,33 @@ func TestSSHFxpOpenPackethasPflags(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func BenchmarkMarshalInit(b *testing.B) {
|
||||
func benchMarshal(b *testing.B, packet encoding.BinaryMarshaler) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
sp(&sshFxInitPacket{
|
||||
sendPacket(ioutil.Discard, packet)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMarshalInit(b *testing.B) {
|
||||
benchMarshal(b, &sshFxInitPacket{
|
||||
Version: 3,
|
||||
Extensions: []extensionPair{
|
||||
{"posix-rename@openssh.com", "1"},
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMarshalOpen(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
sp(&sshFxpOpenPacket{
|
||||
benchMarshal(b, &sshFxpOpenPacket{
|
||||
ID: 1,
|
||||
Path: "/home/test/some/random/path",
|
||||
Pflags: flags(os.O_RDONLY),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMarshalWriteWorstCase(b *testing.B) {
|
||||
data := make([]byte, 32*1024)
|
||||
for i := 0; i < b.N; i++ {
|
||||
sp(&sshFxpWritePacket{
|
||||
|
||||
benchMarshal(b, &sshFxpWritePacket{
|
||||
ID: 1,
|
||||
Handle: "someopaquehandle",
|
||||
Offset: 0,
|
||||
|
@ -329,12 +493,11 @@ func BenchmarkMarshalWriteWorstCase(b *testing.B) {
|
|||
Data: data,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMarshalWrite1k(b *testing.B) {
|
||||
data := make([]byte, 1024)
|
||||
for i := 0; i < b.N; i++ {
|
||||
sp(&sshFxpWritePacket{
|
||||
data := make([]byte, 1025)
|
||||
|
||||
benchMarshal(b, &sshFxpWritePacket{
|
||||
ID: 1,
|
||||
Handle: "someopaquehandle",
|
||||
Offset: 0,
|
||||
|
@ -342,4 +505,3 @@ func BenchmarkMarshalWrite1k(b *testing.B) {
|
|||
Data: data,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -408,16 +408,16 @@ func testServer(t *testing.T, useSubsystem bool, readonly bool) (net.Listener, s
|
|||
func makeDummyKey() (string, error) {
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), crand.Reader)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cannot generate key: %v", err)
|
||||
return "", fmt.Errorf("cannot generate key: %w", err)
|
||||
}
|
||||
der, err := x509.MarshalECPrivateKey(priv)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cannot marshal key: %v", err)
|
||||
return "", fmt.Errorf("cannot marshal key: %w", err)
|
||||
}
|
||||
block := &pem.Block{Type: "EC PRIVATE KEY", Bytes: der}
|
||||
f, err := ioutil.TempFile("", "sftp-test-key-")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cannot create temp file: %v", err)
|
||||
return "", fmt.Errorf("cannot create temp file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if f != nil {
|
||||
|
@ -426,16 +426,34 @@ func makeDummyKey() (string, error) {
|
|||
}
|
||||
}()
|
||||
if err := pem.Encode(f, block); err != nil {
|
||||
return "", fmt.Errorf("cannot write key: %v", err)
|
||||
return "", fmt.Errorf("cannot write key: %w", err)
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
return "", fmt.Errorf("error closing key file: %v", err)
|
||||
return "", fmt.Errorf("error closing key file: %w", err)
|
||||
}
|
||||
path := f.Name()
|
||||
f = nil
|
||||
return path, nil
|
||||
}
|
||||
|
||||
type execError struct {
|
||||
path string
|
||||
stderr string
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *execError) Error() string {
|
||||
return fmt.Sprintf("%s: %v: %s", e.path, e.err, e.stderr)
|
||||
}
|
||||
|
||||
func (e *execError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
func (e *execError) Cause() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
func runSftpClient(t *testing.T, script string, path string, host string, port int) (string, error) {
|
||||
// if sftp client binary is unavailable, skip test
|
||||
if _, err := os.Stat(*testSftpClientBin); err != nil {
|
||||
|
@ -471,7 +489,11 @@ func runSftpClient(t *testing.T, script string, path string, host string, port i
|
|||
}
|
||||
err = cmd.Wait()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("%v: %s", err, stderr.String())
|
||||
err = &execError{
|
||||
path: cmd.Path,
|
||||
stderr: stderr.String(),
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
return stdout.String(), err
|
||||
}
|
||||
|
|
8
sftp.go
8
sftp.go
|
@ -200,15 +200,15 @@ func unimplementedPacketErr(u uint8) error {
|
|||
type unexpectedIDErr struct{ want, got uint32 }
|
||||
|
||||
func (u *unexpectedIDErr) Error() string {
|
||||
return fmt.Sprintf("sftp: unexpected id: want %v, got %v", u.want, u.got)
|
||||
return fmt.Sprintf("sftp: unexpected id: want %d, got %d", u.want, u.got)
|
||||
}
|
||||
|
||||
func unimplementedSeekWhence(whence int) error {
|
||||
return errors.Errorf("sftp: unimplemented seek whence %v", whence)
|
||||
return errors.Errorf("sftp: unimplemented seek whence %d", whence)
|
||||
}
|
||||
|
||||
func unexpectedCount(want, got uint32) error {
|
||||
return errors.Errorf("sftp: unexpected count: want %v, got %v", want, got)
|
||||
return errors.Errorf("sftp: unexpected count: want %d, got %d", want, got)
|
||||
}
|
||||
|
||||
type unexpectedVersionErr struct{ want, got uint32 }
|
||||
|
@ -239,7 +239,7 @@ func getSupportedExtensionByName(extensionName string) (sshExtensionPair, error)
|
|||
return supportedExtension, nil
|
||||
}
|
||||
}
|
||||
return sshExtensionPair{}, fmt.Errorf("Unsupported extension: %v", extensionName)
|
||||
return sshExtensionPair{}, fmt.Errorf("unsupported extension: %s", extensionName)
|
||||
}
|
||||
|
||||
// SetSFTPExtensions allows to customize the supported server extensions.
|
||||
|
|
Loading…
Reference in New Issue