// Mirror of https://github.com/pkg/sftp.git (231 lines, 4.8 KiB, Go).
package sftp
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"io"
|
|
"os"
|
|
"reflect"
|
|
"testing"
|
|
|
|
"github.com/kr/fs"
|
|
)
|
|
|
|
// assert that *Client implements fs.FileSystem
|
|
var _ fs.FileSystem = new(Client)
|
|
|
|
// assert that *File implements io.ReadWriteCloser
|
|
var _ io.ReadWriteCloser = new(File)
|
|
|
|
func TestNormaliseError(t *testing.T) {
|
|
var (
|
|
ok = &StatusError{Code: sshFxOk}
|
|
eof = &StatusError{Code: sshFxEOF}
|
|
fail = &StatusError{Code: sshFxFailure}
|
|
noSuchFile = &StatusError{Code: sshFxNoSuchFile}
|
|
foo = errors.New("foo")
|
|
)
|
|
|
|
var tests = []struct {
|
|
desc string
|
|
err error
|
|
want error
|
|
}{
|
|
{
|
|
desc: "nil error",
|
|
},
|
|
{
|
|
desc: "not *StatusError",
|
|
err: foo,
|
|
want: foo,
|
|
},
|
|
{
|
|
desc: "*StatusError with ssh_FX_EOF",
|
|
err: eof,
|
|
want: io.EOF,
|
|
},
|
|
{
|
|
desc: "*StatusError with ssh_FX_NO_SUCH_FILE",
|
|
err: noSuchFile,
|
|
want: os.ErrNotExist,
|
|
},
|
|
{
|
|
desc: "*StatusError with ssh_FX_OK",
|
|
err: ok,
|
|
},
|
|
{
|
|
desc: "*StatusError with ssh_FX_FAILURE",
|
|
err: fail,
|
|
want: fail,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
got := normaliseError(tt.err)
|
|
if got != tt.want {
|
|
t.Errorf("normaliseError(%#v), test %q\n- want: %#v\n- got: %#v",
|
|
tt.err, tt.desc, tt.want, got)
|
|
}
|
|
}
|
|
}
|
|
|
|
var flagsTests = []struct {
|
|
flags int
|
|
want uint32
|
|
}{
|
|
{os.O_RDONLY, sshFxfRead},
|
|
{os.O_WRONLY, sshFxfWrite},
|
|
{os.O_RDWR, sshFxfRead | sshFxfWrite},
|
|
{os.O_RDWR | os.O_CREATE | os.O_TRUNC, sshFxfRead | sshFxfWrite | sshFxfCreat | sshFxfTrunc},
|
|
{os.O_WRONLY | os.O_APPEND, sshFxfWrite | sshFxfAppend},
|
|
}
|
|
|
|
func TestFlags(t *testing.T) {
|
|
for i, tt := range flagsTests {
|
|
got := flags(tt.flags)
|
|
if got != tt.want {
|
|
t.Errorf("test %v: flags(%x): want: %x, got: %x", i, tt.flags, tt.want, got)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestUnmarshalStatus(t *testing.T) {
|
|
requestID := uint32(1)
|
|
|
|
id := marshalUint32([]byte{}, requestID)
|
|
idCode := marshalUint32(id, sshFxFailure)
|
|
idCodeMsg := marshalString(idCode, "err msg")
|
|
idCodeMsgLang := marshalString(idCodeMsg, "lang tag")
|
|
|
|
var tests = []struct {
|
|
desc string
|
|
reqID uint32
|
|
status []byte
|
|
want error
|
|
}{
|
|
{
|
|
desc: "well-formed status",
|
|
reqID: 1,
|
|
status: idCodeMsgLang,
|
|
want: &StatusError{
|
|
Code: sshFxFailure,
|
|
msg: "err msg",
|
|
lang: "lang tag",
|
|
},
|
|
},
|
|
{
|
|
desc: "missing error message and language tag",
|
|
reqID: 1,
|
|
status: idCode,
|
|
want: &StatusError{
|
|
Code: sshFxFailure,
|
|
},
|
|
},
|
|
{
|
|
desc: "missing language tag",
|
|
reqID: 1,
|
|
status: idCodeMsg,
|
|
want: &StatusError{
|
|
Code: sshFxFailure,
|
|
msg: "err msg",
|
|
},
|
|
},
|
|
{
|
|
desc: "request identifier mismatch",
|
|
reqID: 2,
|
|
status: idCodeMsgLang,
|
|
want: &unexpectedIDErr{2, requestID},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
got := unmarshalStatus(tt.reqID, tt.status)
|
|
if !reflect.DeepEqual(got, tt.want) {
|
|
t.Errorf("unmarshalStatus(%v, %v), test %q\n- want: %#v\n- got: %#v",
|
|
requestID, tt.status, tt.desc, tt.want, got)
|
|
}
|
|
}
|
|
}
|
|
|
|
// packetSizeTest is a single case for the max-packet option tests: the size
// handed to the option and whether the option should accept it.
type packetSizeTest struct {
	size  int  // value passed to the option
	valid bool // true if the option is expected to return a nil error
}
|
|
|
|
var maxPacketCheckedTests = []packetSizeTest{
|
|
{size: 0, valid: false},
|
|
{size: 1, valid: true},
|
|
{size: 32768, valid: true},
|
|
{size: 32769, valid: false},
|
|
}
|
|
|
|
var maxPacketUncheckedTests = []packetSizeTest{
|
|
{size: 0, valid: false},
|
|
{size: 1, valid: true},
|
|
{size: 32768, valid: true},
|
|
{size: 32769, valid: true},
|
|
}
|
|
|
|
func TestMaxPacketChecked(t *testing.T) {
|
|
for _, tt := range maxPacketCheckedTests {
|
|
testMaxPacketOption(t, MaxPacketChecked(tt.size), tt)
|
|
}
|
|
}
|
|
|
|
func TestMaxPacketUnchecked(t *testing.T) {
|
|
for _, tt := range maxPacketUncheckedTests {
|
|
testMaxPacketOption(t, MaxPacketUnchecked(tt.size), tt)
|
|
}
|
|
}
|
|
|
|
func TestMaxPacket(t *testing.T) {
|
|
for _, tt := range maxPacketCheckedTests {
|
|
testMaxPacketOption(t, MaxPacket(tt.size), tt)
|
|
}
|
|
}
|
|
|
|
func testMaxPacketOption(t *testing.T, o ClientOption, tt packetSizeTest) {
|
|
var c Client
|
|
|
|
err := o(&c)
|
|
if (err == nil) != tt.valid {
|
|
t.Errorf("MaxPacketChecked(%v)\n- want: %v\n- got: %v", tt.size, tt.valid, err == nil)
|
|
}
|
|
if c.maxPacket != tt.size && tt.valid {
|
|
t.Errorf("MaxPacketChecked(%v)\n- want: %v\n- got: %v", tt.size, tt.size, c.maxPacket)
|
|
}
|
|
}
|
|
|
|
func testFstatOption(t *testing.T, o ClientOption, value bool) {
|
|
var c Client
|
|
|
|
err := o(&c)
|
|
if err == nil && c.useFstat != value {
|
|
t.Errorf("UseFStat(%v)\n- want: %v\n- got: %v", value, value, c.useFstat)
|
|
}
|
|
}
|
|
|
|
func TestUseFstatChecked(t *testing.T) {
|
|
testFstatOption(t, UseFstat(true), true)
|
|
testFstatOption(t, UseFstat(false), false)
|
|
}
|
|
|
|
// sink is a write-only endpoint that throws away whatever is written to it.
// It is used as the client's outgoing side when only the replies matter.
type sink struct{}

// Close is a no-op that always succeeds.
func (s *sink) Close() error {
	return nil
}

// Write discards p and reports every byte as written.
func (s *sink) Write(p []byte) (int, error) {
	written := len(p)
	return written, nil
}
|
|
|
|
func TestClientZeroLengthPacket(t *testing.T) {
|
|
// Packet length zero (never valid). This used to crash the client.
|
|
packet := []byte{0, 0, 0, 0}
|
|
|
|
r := bytes.NewReader(packet)
|
|
c, err := NewClientPipe(r, &sink{})
|
|
if err == nil {
|
|
t.Error("expected an error, got nil")
|
|
}
|
|
if c != nil {
|
|
c.Close()
|
|
}
|
|
}
|