Improve benchmarks and errors

This commit is contained in:
Cassondra Foesch 2021-03-17 11:03:24 +00:00
parent 460ad57385
commit f1e28f8a88
9 changed files with 470 additions and 253 deletions

View File

@ -43,10 +43,10 @@ type ClientOption func(*Client) error
func MaxPacketChecked(size int) ClientOption { func MaxPacketChecked(size int) ClientOption {
return func(c *Client) error { return func(c *Client) error {
if size < 1 { if size < 1 {
return errors.Errorf("size must be greater or equal to 1") return errors.New("size must be greater or equal to 1")
} }
if size > 32768 { if size > 32768 {
return errors.Errorf("sizes larger than 32KB might not work with all servers") return errors.New("sizes larger than 32KB might not work with all servers")
} }
c.maxPacket = size c.maxPacket = size
return nil return nil
@ -65,7 +65,7 @@ func MaxPacketChecked(size int) ClientOption {
func MaxPacketUnchecked(size int) ClientOption { func MaxPacketUnchecked(size int) ClientOption {
return func(c *Client) error { return func(c *Client) error {
if size < 1 { if size < 1 {
return errors.Errorf("size must be greater or equal to 1") return errors.New("size must be greater or equal to 1")
} }
c.maxPacket = size c.maxPacket = size
return nil return nil
@ -90,7 +90,7 @@ func MaxPacket(size int) ClientOption {
func MaxConcurrentRequestsPerFile(n int) ClientOption { func MaxConcurrentRequestsPerFile(n int) ClientOption {
return func(c *Client) error { return func(c *Client) error {
if n < 1 { if n < 1 {
return errors.Errorf("n must be greater or equal to 1") return errors.New("n must be greater or equal to 1")
} }
c.maxConcurrentRequests = n c.maxConcurrentRequests = n
return nil return nil

View File

@ -6,7 +6,6 @@ package sftp
import ( import (
"bytes" "bytes"
"crypto/sha1" "crypto/sha1"
"encoding"
"errors" "errors"
"io" "io"
"io/ioutil" "io/ioutil"
@ -1490,26 +1489,48 @@ func TestClientReadFrom(t *testing.T) {
var errFakeNet = errors.New("Fake network issue") var errFakeNet = errors.New("Fake network issue")
func TestClientReadFromDeadlock(t *testing.T) { func TestClientReadFromDeadlock(t *testing.T) {
clientWriteDeadlock(t, 1, func(f *File) { for i := 0; i < 5; i++ {
clientWriteDeadlock(t, i, func(f *File) {
b := make([]byte, 32768*4) b := make([]byte, 32768*4)
content := bytes.NewReader(b) content := bytes.NewReader(b)
_, err := f.ReadFrom(content) _, err := f.ReadFrom(content)
if err != errFakeNet { if !errors.Is(err, errFakeNet) {
t.Fatal("Didn't recieve correct error:", err) t.Fatal("Didn't recieve correct error:", err)
} }
}) })
}
} }
// Write has exact same problem // Write has exact same problem
func TestClientWriteDeadlock(t *testing.T) { func TestClientWriteDeadlock(t *testing.T) {
clientWriteDeadlock(t, 1, func(f *File) { for i := 0; i < 5; i++ {
clientWriteDeadlock(t, i, func(f *File) {
b := make([]byte, 32768*4) b := make([]byte, 32768*4)
_, err := f.Write(b) _, err := f.Write(b)
if err != errFakeNet { if !errors.Is(err, errFakeNet) {
t.Fatal("Didn't recieve correct error:", err) t.Fatal("Didn't recieve correct error:", err)
} }
}) })
}
}
type timeBombWriter struct {
count int
w io.WriteCloser
}
func (w *timeBombWriter) Write(b []byte) (int, error) {
if w.count < 1 {
return 0, errFakeNet
}
w.count--
return w.w.Write(b)
}
func (w *timeBombWriter) Close() error {
return w.w.Close()
} }
// shared body for both previous tests // shared body for both previous tests
@ -1534,20 +1555,13 @@ func clientWriteDeadlock(t *testing.T, N int, badfunc func(*File)) {
} }
defer w.Close() defer w.Close()
// Override sendPacket with failing version // Override the clienConn Writer with a failing version
// Replicates network error/drop part way through (after 1 good packet) // Replicates network error/drop part way through (after N good writes)
count := 0 wrap := sftp.clientConn.conn.WriteCloser
sendPacketTest := func(w io.Writer, m encoding.BinaryMarshaler) error { sftp.clientConn.conn.WriteCloser = &timeBombWriter{
count++ count: N,
if count > N { w: wrap,
return errFakeNet
} }
return sendPacket(w, m)
}
sftp.clientConn.conn.sendPacketTest = sendPacketTest
defer func() {
sftp.clientConn.conn.sendPacketTest = nil
}()
// this locked (before the fix) // this locked (before the fix)
badfunc(w) badfunc(w)
@ -1555,27 +1569,31 @@ func clientWriteDeadlock(t *testing.T, N int, badfunc func(*File)) {
// Read/WriteTo has this issue as well // Read/WriteTo has this issue as well
func TestClientReadDeadlock(t *testing.T) { func TestClientReadDeadlock(t *testing.T) {
clientReadDeadlock(t, 1, func(f *File) { for i := 0; i < 5; i++ {
clientReadDeadlock(t, i, func(f *File) {
b := make([]byte, 32768*4) b := make([]byte, 32768*4)
_, err := f.Read(b) _, err := f.Read(b)
if err != errFakeNet { if !errors.Is(err, errFakeNet) {
t.Fatal("Didn't recieve correct error:", err) t.Fatal("Didn't recieve correct error:", err)
} }
}) })
}
} }
func TestClientWriteToDeadlock(t *testing.T) { func TestClientWriteToDeadlock(t *testing.T) {
clientReadDeadlock(t, 2, func(f *File) { for i := 0; i < 5; i++ {
clientReadDeadlock(t, i, func(f *File) {
b := make([]byte, 32768*4) b := make([]byte, 32768*4)
buf := bytes.NewBuffer(b) buf := bytes.NewBuffer(b)
_, err := f.WriteTo(buf) _, err := f.WriteTo(buf)
if err != errFakeNet { if !errors.Is(err, errFakeNet) {
t.Fatal("Didn't recieve correct error:", err) t.Fatal("Didn't recieve correct error:", err)
} }
}) })
}
} }
func clientReadDeadlock(t *testing.T, N int, badfunc func(*File)) { func clientReadDeadlock(t *testing.T, N int, badfunc func(*File)) {
@ -1611,21 +1629,13 @@ func clientReadDeadlock(t *testing.T, N int, badfunc func(*File)) {
} }
defer r.Close() defer r.Close()
// Override sendPacket with failing version // Override the clienConn Writer with a failing version
// Replicates network error/drop part way through (after 1 good packet) // Replicates network error/drop part way through (after N good writes)
count := 0 wrap := sftp.clientConn.conn.WriteCloser
sendPacketTest := func(w io.Writer, m encoding.BinaryMarshaler) error { sftp.clientConn.conn.WriteCloser = &timeBombWriter{
count++ count: N,
if count > N { w: wrap,
return errFakeNet
} }
return sendPacket(w, m)
}
sftp.clientConn.conn.sendPacketTest = sendPacketTest
defer func() {
sftp.clientConn.conn.sendPacketTest = nil
}()
// this locked (before the fix) // this locked (before the fix)
badfunc(r) badfunc(r)
@ -2444,6 +2454,28 @@ func BenchmarkReadFrom4MiBDelay150Msec(b *testing.B) {
benchmarkReadFrom(b, 4*1024*1024, 150*time.Millisecond) benchmarkReadFrom(b, 4*1024*1024, 150*time.Millisecond)
} }
// writeToBuffer implements the relevant parts of bytes.Buffer,
// but does not release its internal buffer when Reset.
//
// Release its internal memory when Reset is good for avoiding memory leaks,
// but not great for memory benchmarks, as this fills up a lot of irrelevant allocations.
type writeToBuffer struct {
b []byte
}
func (w *writeToBuffer) Len() int {
return len(w.b)
}
func (w *writeToBuffer) Reset() {
w.b = w.b[:0]
}
func (w *writeToBuffer) Write(b []byte) (int, error) {
w.b = append(w.b, b...)
return len(b), nil
}
func benchmarkWriteTo(b *testing.B, bufsize int, delay time.Duration) { func benchmarkWriteTo(b *testing.B, bufsize int, delay time.Duration) {
size := 10*1024*1024 + 123 // ~10MiB size := 10*1024*1024 + 123 // ~10MiB
@ -2466,7 +2498,9 @@ func benchmarkWriteTo(b *testing.B, bufsize int, delay time.Duration) {
b.ResetTimer() b.ResetTimer()
b.SetBytes(int64(size)) b.SetBytes(int64(size))
buf := new(bytes.Buffer) buf := &writeToBuffer{
b: make([]byte, 0, size),
}
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
buf.Reset() buf.Reset()

View File

@ -16,8 +16,6 @@ type conn struct {
// this is the same allocator used in packet manager // this is the same allocator used in packet manager
alloc *allocator alloc *allocator
sync.Mutex // used to serialise writes to sendPacket sync.Mutex // used to serialise writes to sendPacket
// sendPacketTest is needed to replicate packet issues in testing
sendPacketTest func(w io.Writer, m encoding.BinaryMarshaler) error
} }
// the orderID is used in server mode if the allocator is enabled. // the orderID is used in server mode if the allocator is enabled.
@ -29,9 +27,7 @@ func (c *conn) recvPacket(orderID uint32) (uint8, []byte, error) {
func (c *conn) sendPacket(m encoding.BinaryMarshaler) error { func (c *conn) sendPacket(m encoding.BinaryMarshaler) error {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
if c.sendPacketTest != nil {
return c.sendPacketTest(c, m)
}
return sendPacket(c, m) return sendPacket(c, m)
} }
@ -91,7 +87,7 @@ func (c *clientConn) recv() error {
// This is an unexpected occurrence. Send the error // This is an unexpected occurrence. Send the error
// back to all listeners so that they terminate // back to all listeners so that they terminate
// gracefully. // gracefully.
return errors.Errorf("sid not found: %v", sid) return errors.Errorf("sid not found: %d", sid)
} }
ch <- result{typ: typ, data: data} ch <- result{typ: typ, data: data}

View File

@ -131,7 +131,7 @@ func ExampleClient_Mkdir_parents() {
fi, err = client.Stat(parents) fi, err = client.Stat(parents)
if err == nil { if err == nil {
if !fi.IsDir() { if !fi.IsDir() {
return fmt.Errorf("File exists: %s", parents) return fmt.Errorf("file exists: %s", parents)
} }
} }
} }

3
go.sum
View File

@ -15,9 +15,12 @@ golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad/go.mod h1:jdWPYTVW3xRLrWP
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4 h1:myAQVi0cGEoqQVR5POX+8RR2mrocKqNN1hmeMqhX27k=
golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221 h1:/ZHdbVpdR/jk3g30/d4yUL0JU9kksj8+F/bnQUVLGDM=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -133,7 +133,7 @@ func marshalPacket(m encoding.BinaryMarshaler) (header, payload []byte, err erro
func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error { func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error {
header, payload, err := marshalPacket(m) header, payload, err := marshalPacket(m)
if err != nil { if err != nil {
return errors.Errorf("binary marshaller failed: %v", err) return errors.Wrap(err, "binary marshaller failed")
} }
length := len(header) + len(payload) - 4 // subtract the uint32(length) from the start length := len(header) + len(payload) - 4 // subtract the uint32(length) from the start
@ -146,12 +146,12 @@ func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error {
binary.BigEndian.PutUint32(header[:4], uint32(length)) binary.BigEndian.PutUint32(header[:4], uint32(length))
if _, err := w.Write(header); err != nil { if _, err := w.Write(header); err != nil {
return errors.Errorf("failed to send packet: %v", err) return errors.Wrap(err, "failed to send packet")
} }
if len(payload) > 0 { if len(payload) > 0 {
if _, err := w.Write(payload); err != nil { if _, err := w.Write(payload); err != nil {
return errors.Errorf("failed to send packet payload: %v", err) return errors.Wrap(err, "failed to send packet payload")
} }
} }

View File

@ -3,120 +3,174 @@ package sftp
import ( import (
"bytes" "bytes"
"encoding" "encoding"
"errors"
"io/ioutil"
"os" "os"
"testing" "testing"
) )
var marshalUint32Tests = []struct { func TestMarshalUint32(t *testing.T) {
var tests = []struct {
v uint32 v uint32
want []byte want []byte
}{ }{
{1, []byte{0, 0, 0, 1}}, {0, []byte{0, 0, 0, 0}},
{256, []byte{0, 0, 1, 0}}, {42, []byte{0, 0, 0, 42}},
{42 << 8, []byte{0, 0, 42, 0}},
{42 << 16, []byte{0, 42, 0, 0}},
{42 << 24, []byte{42, 0, 0, 0}},
{^uint32(0), []byte{255, 255, 255, 255}}, {^uint32(0), []byte{255, 255, 255, 255}},
} }
func TestMarshalUint32(t *testing.T) { for _, tt := range tests {
for _, tt := range marshalUint32Tests {
got := marshalUint32(nil, tt.v) got := marshalUint32(nil, tt.v)
if !bytes.Equal(tt.want, got) { if !bytes.Equal(tt.want, got) {
t.Errorf("marshalUint32(%d): want %v, got %v", tt.v, tt.want, got) t.Errorf("marshalUint32(%d) = %#v, want %#v", tt.v, got, tt.want)
} }
} }
} }
var marshalUint64Tests = []struct {
v uint64
want []byte
}{
{1, []byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1}},
{256, []byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0}},
{^uint64(0), []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}},
{1 << 32, []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0}},
}
func TestMarshalUint64(t *testing.T) { func TestMarshalUint64(t *testing.T) {
for _, tt := range marshalUint64Tests { var tests = []struct {
v uint64
want []byte
}{
{0, []byte{0, 0, 0, 0, 0, 0, 0, 0}},
{42, []byte{0, 0, 0, 0, 0, 0, 0, 42}},
{42 << 8, []byte{0, 0, 0, 0, 0, 0, 42, 0}},
{42 << 16, []byte{0, 0, 0, 0, 0, 42, 0, 0}},
{42 << 24, []byte{0, 0, 0, 0, 42, 0, 0, 0}},
{42 << 32, []byte{0, 0, 0, 42, 0, 0, 0, 0}},
{42 << 40, []byte{0, 0, 42, 0, 0, 0, 0, 0}},
{42 << 48, []byte{0, 42, 0, 0, 0, 0, 0, 0}},
{42 << 56, []byte{42, 0, 0, 0, 0, 0, 0, 0}},
{^uint64(0), []byte{255, 255, 255, 255, 255, 255, 255, 255}},
}
for _, tt := range tests {
got := marshalUint64(nil, tt.v) got := marshalUint64(nil, tt.v)
if !bytes.Equal(tt.want, got) { if !bytes.Equal(tt.want, got) {
t.Errorf("marshalUint64(%d): want %#v, got %#v", tt.v, tt.want, got) t.Errorf("marshalUint64(%d) = %#v, want %#v", tt.v, got, tt.want)
} }
} }
} }
var marshalStringTests = []struct {
v string
want []byte
}{
{"", []byte{0, 0, 0, 0}},
{"/foo", []byte{0x0, 0x0, 0x0, 0x4, 0x2f, 0x66, 0x6f, 0x6f}},
}
func TestMarshalString(t *testing.T) { func TestMarshalString(t *testing.T) {
for _, tt := range marshalStringTests { var tests = []struct {
v string
want []byte
}{
{"", []byte{0, 0, 0, 0}},
{"/", []byte{0x0, 0x0, 0x0, 0x01, '/'}},
{"/foo", []byte{0x0, 0x0, 0x0, 0x4, '/', 'f', 'o', 'o'}},
{"\x00bar", []byte{0x0, 0x0, 0x0, 0x4, 0, 'b', 'a', 'r'}},
{"b\x00ar", []byte{0x0, 0x0, 0x0, 0x4, 'b', 0, 'a', 'r'}},
{"ba\x00r", []byte{0x0, 0x0, 0x0, 0x4, 'b', 'a', 0, 'r'}},
{"bar\x00", []byte{0x0, 0x0, 0x0, 0x4, 'b', 'a', 'r', 0}},
}
for _, tt := range tests {
got := marshalString(nil, tt.v) got := marshalString(nil, tt.v)
if !bytes.Equal(tt.want, got) { if !bytes.Equal(tt.want, got) {
t.Errorf("marshalString(%q): want %#v, got %#v", tt.v, tt.want, got) t.Errorf("marshalString(%q) = %#v, want %#v", tt.v, got, tt.want)
} }
} }
} }
var marshalTests = []struct {
v interface{}
want []byte
}{
{uint8(1), []byte{1}},
{byte(1), []byte{1}},
{uint32(1), []byte{0, 0, 0, 1}},
{uint64(1), []byte{0, 0, 0, 0, 0, 0, 0, 1}},
{"foo", []byte{0x0, 0x0, 0x0, 0x3, 0x66, 0x6f, 0x6f}},
{[]uint32{1, 2, 3, 4}, []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x4}},
}
func TestMarshal(t *testing.T) { func TestMarshal(t *testing.T) {
for _, tt := range marshalTests { type Struct struct {
X, Y, Z uint32
}
var tests = []struct {
v interface{}
want []byte
}{
{uint8(42), []byte{42}},
{uint32(42 << 8), []byte{0, 0, 42, 0}},
{uint64(42 << 32), []byte{0, 0, 0, 42, 0, 0, 0, 0}},
{"foo", []byte{0x0, 0x0, 0x0, 0x3, 'f', 'o', 'o'}},
{Struct{1, 2, 3}, []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x3}},
{[]uint32{1, 2, 3}, []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x3}},
}
for _, tt := range tests {
got := marshal(nil, tt.v) got := marshal(nil, tt.v)
if !bytes.Equal(tt.want, got) { if !bytes.Equal(tt.want, got) {
t.Errorf("marshal(%v): want %#v, got %#v", tt.v, tt.want, got) t.Errorf("marshal(%#v) = %#v, want %#v", tt.v, got, tt.want)
} }
} }
} }
var unmarshalUint32Tests = []struct {
b []byte
want uint32
rest []byte
}{
{[]byte{0, 0, 0, 0}, 0, nil},
{[]byte{0, 0, 1, 0}, 256, nil},
{[]byte{255, 0, 0, 255}, 4278190335, nil},
}
func TestUnmarshalUint32(t *testing.T) { func TestUnmarshalUint32(t *testing.T) {
for _, tt := range unmarshalUint32Tests { testBuffer := []byte{
got, rest := unmarshalUint32(tt.b) 0, 0, 0, 0,
if got != tt.want || !bytes.Equal(rest, tt.rest) { 0, 0, 0, 42,
t.Errorf("unmarshalUint32(%v): want %v, %#v, got %v, %#v", tt.b, tt.want, tt.rest, got, rest) 0, 0, 42, 0,
0, 42, 0, 0,
42, 0, 0, 0,
255, 0, 0, 254,
} }
}
}
var unmarshalUint64Tests = []struct { var wants = []uint32{
b []byte 0,
want uint64 42,
rest []byte 42 << 8,
}{ 42 << 16,
{[]byte{0, 0, 0, 0, 0, 0, 0, 0}, 0, nil}, 42 << 24,
{[]byte{0, 0, 0, 0, 0, 0, 1, 0}, 256, nil}, 255<<24 | 254,
{[]byte{255, 0, 0, 0, 0, 0, 0, 255}, 18374686479671623935, nil}, }
var i int
for len(testBuffer) > 0 {
got, rest := unmarshalUint32(testBuffer)
if got != wants[i] {
t.Fatalf("unmarshalUint32(%#v) = %d, want %d", testBuffer[:4], got, wants[i])
}
i++
testBuffer = rest
}
} }
func TestUnmarshalUint64(t *testing.T) { func TestUnmarshalUint64(t *testing.T) {
for _, tt := range unmarshalUint64Tests { testBuffer := []byte{
got, rest := unmarshalUint64(tt.b) 0, 0, 0, 0, 0, 0, 0, 0,
if got != tt.want || !bytes.Equal(rest, tt.rest) { 0, 0, 0, 0, 0, 0, 0, 42,
t.Errorf("unmarshalUint64(%v): want %v, %#v, got %v, %#v", tt.b, tt.want, tt.rest, got, rest) 0, 0, 0, 0, 0, 0, 42, 0,
0, 0, 0, 0, 0, 42, 0, 0,
0, 0, 0, 0, 42, 0, 0, 0,
0, 0, 0, 42, 0, 0, 0, 0,
0, 0, 42, 0, 0, 0, 0, 0,
0, 42, 0, 0, 0, 0, 0, 0,
42, 0, 0, 0, 0, 0, 0, 0,
255, 0, 0, 0, 0, 0, 0, 254,
} }
var wants = []uint64{
0,
42,
42 << 8,
42 << 16,
42 << 24,
42 << 32,
42 << 40,
42 << 48,
42 << 56,
255<<56 | 254,
}
var i int
for len(testBuffer) > 0 {
got, rest := unmarshalUint64(testBuffer)
if got != wants[i] {
t.Fatalf("unmarshalUint64(%#v) = %d, want %d", testBuffer[:8], got, wants[i])
}
i++
testBuffer = rest
} }
} }
@ -130,85 +184,193 @@ var unmarshalStringTests = []struct {
} }
func TestUnmarshalString(t *testing.T) { func TestUnmarshalString(t *testing.T) {
for _, tt := range unmarshalStringTests { testBuffer := []byte{
got, rest := unmarshalString(tt.b) 0, 0, 0, 0,
if got != tt.want || !bytes.Equal(rest, tt.rest) { 0, 0, 0, 1, '/',
t.Errorf("unmarshalUint64(%v): want %q, %#v, got %q, %#v", tt.b, tt.want, tt.rest, got, rest) 0, 0, 0, 4, '/', 'f', 'o', 'o',
0, 0, 0, 4, 0, 'b', 'a', 'r',
0, 0, 0, 4, 'b', 0, 'a', 'r',
0, 0, 0, 4, 'b', 'a', 0, 'r',
0, 0, 0, 4, 'b', 'a', 'r', 0,
} }
var wants = []string{
"",
"/",
"/foo",
"\x00bar",
"b\x00ar",
"ba\x00r",
"bar\x00",
}
var i int
for len(testBuffer) > 0 {
got, rest := unmarshalString(testBuffer)
if got != wants[i] {
t.Fatalf("unmarshalUint64(%#v...) = %q, want %q", testBuffer[:4], got, wants[i])
}
i++
testBuffer = rest
} }
} }
var sendPacketTests = []struct { type nopCloserBuffer struct {
p encoding.BinaryMarshaler bytes.Buffer
want []byte }
}{
{&sshFxInitPacket{
Version: 3,
Extensions: []extensionPair{
{"posix-rename@openssh.com", "1"},
},
}, []byte{0x0, 0x0, 0x0, 0x26, 0x1, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x18, 0x70, 0x6f, 0x73, 0x69, 0x78, 0x2d, 0x72, 0x65, 0x6e, 0x61, 0x6d, 0x65, 0x40, 0x6f, 0x70, 0x65, 0x6e, 0x73, 0x73, 0x68, 0x2e, 0x63, 0x6f, 0x6d, 0x0, 0x0, 0x0, 0x1, 0x31}},
{&sshFxpOpenPacket{ func (*nopCloserBuffer) Close() error {
ID: 1, return nil
Path: "/foo",
Pflags: flags(os.O_RDONLY),
}, []byte{0x0, 0x0, 0x0, 0x15, 0x3, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x4, 0x2f, 0x66, 0x6f, 0x6f, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0}},
{&sshFxpWritePacket{
ID: 124,
Handle: "foo",
Offset: 13,
Length: uint32(len([]byte("bar"))),
Data: []byte("bar"),
}, []byte{0x0, 0x0, 0x0, 0x1b, 0x6, 0x0, 0x0, 0x0, 0x7c, 0x0, 0x0, 0x0, 0x3, 0x66, 0x6f, 0x6f, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xd, 0x0, 0x0, 0x0, 0x3, 0x62, 0x61, 0x72}},
{&sshFxpSetstatPacket{
ID: 31,
Path: "/bar",
Flags: flags(os.O_WRONLY),
Attrs: struct {
UID uint32
GID uint32
}{1000, 100},
}, []byte{0x0, 0x0, 0x0, 0x19, 0x9, 0x0, 0x0, 0x0, 0x1f, 0x0, 0x0, 0x0, 0x4, 0x2f, 0x62, 0x61, 0x72, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x3, 0xe8, 0x0, 0x0, 0x0, 0x64}},
} }
func TestSendPacket(t *testing.T) { func TestSendPacket(t *testing.T) {
for _, tt := range sendPacketTests { var tests = []struct {
var w bytes.Buffer packet encoding.BinaryMarshaler
sendPacket(&w, tt.p) want []byte
if got := w.Bytes(); !bytes.Equal(tt.want, got) { }{
t.Errorf("sendPacket(%v): want %#v, got %#v", tt.p, tt.want, got) {
} packet: &sshFxInitPacket{
}
}
func sp(p encoding.BinaryMarshaler) []byte {
var w bytes.Buffer
sendPacket(&w, p)
return w.Bytes()
}
var recvPacketTests = []struct {
b []byte
want uint8
rest []byte
}{
{sp(&sshFxInitPacket{
Version: 3, Version: 3,
Extensions: []extensionPair{ Extensions: []extensionPair{
{"posix-rename@openssh.com", "1"}, {"posix-rename@openssh.com", "1"},
}, },
}), sshFxpInit, []byte{0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x18, 0x70, 0x6f, 0x73, 0x69, 0x78, 0x2d, 0x72, 0x65, 0x6e, 0x61, 0x6d, 0x65, 0x40, 0x6f, 0x70, 0x65, 0x6e, 0x73, 0x73, 0x68, 0x2e, 0x63, 0x6f, 0x6d, 0x0, 0x0, 0x0, 0x1, 0x31}}, },
want: []byte{
0x0, 0x0, 0x0, 0x26,
0x1,
0x0, 0x0, 0x0, 0x3,
0x0, 0x0, 0x0, 0x18,
'p', 'o', 's', 'i', 'x', '-', 'r', 'e', 'n', 'a', 'm', 'e', '@', 'o', 'p', 'e', 'n', 's', 's', 'h', '.', 'c', 'o', 'm',
0x0, 0x0, 0x0, 0x1,
'1',
},
},
{
packet: &sshFxpOpenPacket{
ID: 1,
Path: "/foo",
Pflags: flags(os.O_RDONLY),
},
want: []byte{
0x0, 0x0, 0x0, 0x15,
0x3,
0x0, 0x0, 0x0, 0x1,
0x0, 0x0, 0x0, 0x4, '/', 'f', 'o', 'o',
0x0, 0x0, 0x0, 0x1,
0x0, 0x0, 0x0, 0x0,
},
},
{
packet: &sshFxpWritePacket{
ID: 124,
Handle: "foo",
Offset: 13,
Length: uint32(len("bar")),
Data: []byte("bar"),
},
want: []byte{
0x0, 0x0, 0x0, 0x1b,
0x6,
0x0, 0x0, 0x0, 0x7c,
0x0, 0x0, 0x0, 0x3, 'f', 'o', 'o',
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xd,
0x0, 0x0, 0x0, 0x3, 'b', 'a', 'r',
},
},
{
packet: &sshFxpSetstatPacket{
ID: 31,
Path: "/bar",
Flags: sshFileXferAttrUIDGID,
Attrs: struct {
UID uint32
GID uint32
}{
UID: 1000,
GID: 100,
},
},
want: []byte{
0x0, 0x0, 0x0, 0x19,
0x9,
0x0, 0x0, 0x0, 0x1f,
0x0, 0x0, 0x0, 0x4, '/', 'b', 'a', 'r',
0x0, 0x0, 0x0, 0x2,
0x0, 0x0, 0x3, 0xe8,
0x0, 0x0, 0x0, 0x64,
},
},
}
for _, tt := range tests {
b := new(bytes.Buffer)
sendPacket(b, tt.packet)
if got := b.Bytes(); !bytes.Equal(tt.want, got) {
t.Errorf("sendPacket(%v): got %x want %x", tt.packet, tt.want, got)
}
}
}
func sp(data encoding.BinaryMarshaler) []byte {
b := new(bytes.Buffer)
sendPacket(b, data)
return b.Bytes()
} }
func TestRecvPacket(t *testing.T) { func TestRecvPacket(t *testing.T) {
var recvPacketTests = []struct {
b []byte
want uint8
body []byte
wantErr error
}{
{
b: sp(&sshFxInitPacket{
Version: 3,
Extensions: []extensionPair{
{"posix-rename@openssh.com", "1"},
},
}),
want: sshFxpInit,
body: []byte{
0x0, 0x0, 0x0, 0x3,
0x0, 0x0, 0x0, 0x18,
'p', 'o', 's', 'i', 'x', '-', 'r', 'e', 'n', 'a', 'm', 'e', '@', 'o', 'p', 'e', 'n', 's', 's', 'h', '.', 'c', 'o', 'm',
0x0, 0x0, 0x0, 0x01,
'1',
},
},
{
b: []byte{
0x0, 0x0, 0x0, 0x0,
},
wantErr: errShortPacket,
},
}
for _, tt := range recvPacketTests { for _, tt := range recvPacketTests {
r := bytes.NewReader(tt.b) r := bytes.NewReader(tt.b)
got, rest, _ := recvPacket(r, nil, 0)
if got != tt.want || !bytes.Equal(rest, tt.rest) { got, body, err := recvPacket(r, nil, 0)
t.Errorf("recvPacket(%#v): want %v, %#v, got %v, %#v", tt.b, tt.want, tt.rest, got, rest) if tt.wantErr == nil {
if err != nil {
t.Fatalf("recvPacket(%#v): unexpected error: %v", tt.b, err)
}
} else {
if !errors.Is(err, tt.wantErr) {
t.Fatalf("recvPacket(%#v) = %v, want %v", tt.b, err, tt.wantErr)
}
}
if got != tt.want {
t.Errorf("recvPacket(%#v) = %#v, want %#v", tt.b, got, tt.want)
}
if !bytes.Equal(body, tt.body) {
t.Errorf("recvPacket(%#v) = %#v, want %#v", tt.b, body, tt.body)
} }
} }
} }
@ -297,49 +459,49 @@ func TestSSHFxpOpenPackethasPflags(t *testing.T) {
} }
} }
func BenchmarkMarshalInit(b *testing.B) { func benchMarshal(b *testing.B, packet encoding.BinaryMarshaler) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
sp(&sshFxInitPacket{ sendPacket(ioutil.Discard, packet)
}
}
func BenchmarkMarshalInit(b *testing.B) {
benchMarshal(b, &sshFxInitPacket{
Version: 3, Version: 3,
Extensions: []extensionPair{ Extensions: []extensionPair{
{"posix-rename@openssh.com", "1"}, {"posix-rename@openssh.com", "1"},
}, },
}) })
}
} }
func BenchmarkMarshalOpen(b *testing.B) { func BenchmarkMarshalOpen(b *testing.B) {
for i := 0; i < b.N; i++ { benchMarshal(b, &sshFxpOpenPacket{
sp(&sshFxpOpenPacket{
ID: 1, ID: 1,
Path: "/home/test/some/random/path", Path: "/home/test/some/random/path",
Pflags: flags(os.O_RDONLY), Pflags: flags(os.O_RDONLY),
}) })
}
} }
func BenchmarkMarshalWriteWorstCase(b *testing.B) { func BenchmarkMarshalWriteWorstCase(b *testing.B) {
data := make([]byte, 32*1024) data := make([]byte, 32*1024)
for i := 0; i < b.N; i++ {
sp(&sshFxpWritePacket{ benchMarshal(b, &sshFxpWritePacket{
ID: 1, ID: 1,
Handle: "someopaquehandle", Handle: "someopaquehandle",
Offset: 0, Offset: 0,
Length: uint32(len(data)), Length: uint32(len(data)),
Data: data, Data: data,
}) })
}
} }
func BenchmarkMarshalWrite1k(b *testing.B) { func BenchmarkMarshalWrite1k(b *testing.B) {
data := make([]byte, 1024) data := make([]byte, 1025)
for i := 0; i < b.N; i++ {
sp(&sshFxpWritePacket{ benchMarshal(b, &sshFxpWritePacket{
ID: 1, ID: 1,
Handle: "someopaquehandle", Handle: "someopaquehandle",
Offset: 0, Offset: 0,
Length: uint32(len(data)), Length: uint32(len(data)),
Data: data, Data: data,
}) })
}
} }

View File

@ -408,16 +408,16 @@ func testServer(t *testing.T, useSubsystem bool, readonly bool) (net.Listener, s
func makeDummyKey() (string, error) { func makeDummyKey() (string, error) {
priv, err := ecdsa.GenerateKey(elliptic.P256(), crand.Reader) priv, err := ecdsa.GenerateKey(elliptic.P256(), crand.Reader)
if err != nil { if err != nil {
return "", fmt.Errorf("cannot generate key: %v", err) return "", fmt.Errorf("cannot generate key: %w", err)
} }
der, err := x509.MarshalECPrivateKey(priv) der, err := x509.MarshalECPrivateKey(priv)
if err != nil { if err != nil {
return "", fmt.Errorf("cannot marshal key: %v", err) return "", fmt.Errorf("cannot marshal key: %w", err)
} }
block := &pem.Block{Type: "EC PRIVATE KEY", Bytes: der} block := &pem.Block{Type: "EC PRIVATE KEY", Bytes: der}
f, err := ioutil.TempFile("", "sftp-test-key-") f, err := ioutil.TempFile("", "sftp-test-key-")
if err != nil { if err != nil {
return "", fmt.Errorf("cannot create temp file: %v", err) return "", fmt.Errorf("cannot create temp file: %w", err)
} }
defer func() { defer func() {
if f != nil { if f != nil {
@ -426,16 +426,34 @@ func makeDummyKey() (string, error) {
} }
}() }()
if err := pem.Encode(f, block); err != nil { if err := pem.Encode(f, block); err != nil {
return "", fmt.Errorf("cannot write key: %v", err) return "", fmt.Errorf("cannot write key: %w", err)
} }
if err := f.Close(); err != nil { if err := f.Close(); err != nil {
return "", fmt.Errorf("error closing key file: %v", err) return "", fmt.Errorf("error closing key file: %w", err)
} }
path := f.Name() path := f.Name()
f = nil f = nil
return path, nil return path, nil
} }
type execError struct {
path string
stderr string
err error
}
func (e *execError) Error() string {
return fmt.Sprintf("%s: %v: %s", e.path, e.err, e.stderr)
}
func (e *execError) Unwrap() error {
return e.err
}
func (e *execError) Cause() error {
return e.err
}
func runSftpClient(t *testing.T, script string, path string, host string, port int) (string, error) { func runSftpClient(t *testing.T, script string, path string, host string, port int) (string, error) {
// if sftp client binary is unavailable, skip test // if sftp client binary is unavailable, skip test
if _, err := os.Stat(*testSftpClientBin); err != nil { if _, err := os.Stat(*testSftpClientBin); err != nil {
@ -471,7 +489,11 @@ func runSftpClient(t *testing.T, script string, path string, host string, port i
} }
err = cmd.Wait() err = cmd.Wait()
if err != nil { if err != nil {
err = fmt.Errorf("%v: %s", err, stderr.String()) err = &execError{
path: cmd.Path,
stderr: stderr.String(),
err: err,
}
} }
return stdout.String(), err return stdout.String(), err
} }

View File

@ -200,15 +200,15 @@ func unimplementedPacketErr(u uint8) error {
type unexpectedIDErr struct{ want, got uint32 } type unexpectedIDErr struct{ want, got uint32 }
func (u *unexpectedIDErr) Error() string { func (u *unexpectedIDErr) Error() string {
return fmt.Sprintf("sftp: unexpected id: want %v, got %v", u.want, u.got) return fmt.Sprintf("sftp: unexpected id: want %d, got %d", u.want, u.got)
} }
func unimplementedSeekWhence(whence int) error { func unimplementedSeekWhence(whence int) error {
return errors.Errorf("sftp: unimplemented seek whence %v", whence) return errors.Errorf("sftp: unimplemented seek whence %d", whence)
} }
func unexpectedCount(want, got uint32) error { func unexpectedCount(want, got uint32) error {
return errors.Errorf("sftp: unexpected count: want %v, got %v", want, got) return errors.Errorf("sftp: unexpected count: want %d, got %d", want, got)
} }
type unexpectedVersionErr struct{ want, got uint32 } type unexpectedVersionErr struct{ want, got uint32 }
@ -239,7 +239,7 @@ func getSupportedExtensionByName(extensionName string) (sshExtensionPair, error)
return supportedExtension, nil return supportedExtension, nil
} }
} }
return sshExtensionPair{}, fmt.Errorf("Unsupported extension: %v", extensionName) return sshExtensionPair{}, fmt.Errorf("unsupported extension: %s", extensionName)
} }
// SetSFTPExtensions allows to customize the supported server extensions. // SetSFTPExtensions allows to customize the supported server extensions.