| 
									
										
										
										
											2013-11-06 12:00:04 +08:00
										 |  |  | package sftp | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-11-07 14:43:06 +08:00
										 |  |  | import ( | 
					
						
							| 
									
										
										
										
											2020-10-30 16:36:20 +08:00
										 |  |  | 	"bytes" | 
					
						
							| 
									
										
										
										
											2016-01-04 03:54:19 +08:00
										 |  |  | 	"errors" | 
					
						
							| 
									
										
										
										
											2013-11-07 14:43:06 +08:00
										 |  |  | 	"io" | 
					
						
							| 
									
										
										
										
											2013-11-14 12:32:21 +08:00
										 |  |  | 	"os" | 
					
						
							| 
									
										
										
										
											2013-11-07 14:43:06 +08:00
										 |  |  | 	"testing" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-11-11 09:57:03 +08:00
										 |  |  | 	"github.com/kr/fs" | 
					
						
							| 
									
										
										
										
											2013-11-07 14:43:06 +08:00
										 |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // assert that *Client implements fs.FileSystem
 | 
					
						
							|  |  |  | var _ fs.FileSystem = new(Client) | 
					
						
							| 
									
										
										
										
											2013-11-06 12:00:04 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2013-11-14 12:32:21 +08:00
										 |  |  | // assert that *File implements io.ReadWriteCloser
 | 
					
						
							|  |  |  | var _ io.ReadWriteCloser = new(File) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-01-04 03:54:19 +08:00
										 |  |  | func TestNormaliseError(t *testing.T) { | 
					
						
							|  |  |  | 	var ( | 
					
						
							| 
									
										
										
										
											2019-08-30 23:04:37 +08:00
										 |  |  | 		ok         = &StatusError{Code: sshFxOk} | 
					
						
							|  |  |  | 		eof        = &StatusError{Code: sshFxEOF} | 
					
						
							|  |  |  | 		fail       = &StatusError{Code: sshFxFailure} | 
					
						
							|  |  |  | 		noSuchFile = &StatusError{Code: sshFxNoSuchFile} | 
					
						
							| 
									
										
										
										
											2016-01-04 03:54:19 +08:00
										 |  |  | 		foo        = errors.New("foo") | 
					
						
							|  |  |  | 	) | 
					
						
							| 
									
										
										
										
											2013-11-06 12:00:04 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-01-04 03:54:19 +08:00
										 |  |  | 	var tests = []struct { | 
					
						
							|  |  |  | 		desc string | 
					
						
							|  |  |  | 		err  error | 
					
						
							|  |  |  | 		want error | 
					
						
							|  |  |  | 	}{ | 
					
						
							|  |  |  | 		{ | 
					
						
							|  |  |  | 			desc: "nil error", | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		{ | 
					
						
							|  |  |  | 			desc: "not *StatusError", | 
					
						
							|  |  |  | 			err:  foo, | 
					
						
							|  |  |  | 			want: foo, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		{ | 
					
						
							|  |  |  | 			desc: "*StatusError with ssh_FX_EOF", | 
					
						
							|  |  |  | 			err:  eof, | 
					
						
							|  |  |  | 			want: io.EOF, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		{ | 
					
						
							|  |  |  | 			desc: "*StatusError with ssh_FX_NO_SUCH_FILE", | 
					
						
							|  |  |  | 			err:  noSuchFile, | 
					
						
							|  |  |  | 			want: os.ErrNotExist, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		{ | 
					
						
							|  |  |  | 			desc: "*StatusError with ssh_FX_OK", | 
					
						
							|  |  |  | 			err:  ok, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		{ | 
					
						
							|  |  |  | 			desc: "*StatusError with ssh_FX_FAILURE", | 
					
						
							|  |  |  | 			err:  fail, | 
					
						
							|  |  |  | 			want: fail, | 
					
						
							|  |  |  | 		}, | 
					
						
							| 
									
										
										
										
											2013-11-06 12:00:04 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-01-04 03:54:19 +08:00
										 |  |  | 	for _, tt := range tests { | 
					
						
							|  |  |  | 		got := normaliseError(tt.err) | 
					
						
							| 
									
										
										
										
											2013-11-06 12:00:04 +08:00
										 |  |  | 		if got != tt.want { | 
					
						
							| 
									
										
										
										
											2016-01-04 03:54:19 +08:00
										 |  |  | 			t.Errorf("normaliseError(%#v), test %q\n- want: %#v\n-  got: %#v", | 
					
						
							|  |  |  | 				tt.err, tt.desc, tt.want, got) | 
					
						
							| 
									
										
										
										
											2013-11-06 12:00:04 +08:00
										 |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2013-11-14 12:32:21 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | var flagsTests = []struct { | 
					
						
							|  |  |  | 	flags int | 
					
						
							|  |  |  | 	want  uint32 | 
					
						
							|  |  |  | }{ | 
					
						
							| 
									
										
										
										
											2019-08-30 23:04:37 +08:00
										 |  |  | 	{os.O_RDONLY, sshFxfRead}, | 
					
						
							|  |  |  | 	{os.O_WRONLY, sshFxfWrite}, | 
					
						
							|  |  |  | 	{os.O_RDWR, sshFxfRead | sshFxfWrite}, | 
					
						
							|  |  |  | 	{os.O_RDWR | os.O_CREATE | os.O_TRUNC, sshFxfRead | sshFxfWrite | sshFxfCreat | sshFxfTrunc}, | 
					
						
							|  |  |  | 	{os.O_WRONLY | os.O_APPEND, sshFxfWrite | sshFxfAppend}, | 
					
						
							| 
									
										
										
										
											2013-11-14 12:32:21 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func TestFlags(t *testing.T) { | 
					
						
							|  |  |  | 	for i, tt := range flagsTests { | 
					
						
							|  |  |  | 		got := flags(tt.flags) | 
					
						
							|  |  |  | 		if got != tt.want { | 
					
						
							|  |  |  | 			t.Errorf("test %v: flags(%x): want: %x, got: %x", i, tt.flags, tt.want, got) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2015-10-31 08:06:26 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-01-16 03:12:47 +08:00
										 // packetSizeTest is a single table entry for the max-packet ClientOption
// tests: a size to hand to the option, and whether the option should accept it.
type packetSizeTest struct {
	// size is the value passed to the option constructor.
	size int
	// valid reports whether applying the option is expected to return nil.
	valid bool
}
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | var maxPacketCheckedTests = []packetSizeTest{ | 
					
						
							|  |  |  | 	{size: 0, valid: false}, | 
					
						
							|  |  |  | 	{size: 1, valid: true}, | 
					
						
							|  |  |  | 	{size: 32768, valid: true}, | 
					
						
							|  |  |  | 	{size: 32769, valid: false}, | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | var maxPacketUncheckedTests = []packetSizeTest{ | 
					
						
							|  |  |  | 	{size: 0, valid: false}, | 
					
						
							|  |  |  | 	{size: 1, valid: true}, | 
					
						
							|  |  |  | 	{size: 32768, valid: true}, | 
					
						
							|  |  |  | 	{size: 32769, valid: true}, | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func TestMaxPacketChecked(t *testing.T) { | 
					
						
							|  |  |  | 	for _, tt := range maxPacketCheckedTests { | 
					
						
							|  |  |  | 		testMaxPacketOption(t, MaxPacketChecked(tt.size), tt) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func TestMaxPacketUnchecked(t *testing.T) { | 
					
						
							|  |  |  | 	for _, tt := range maxPacketUncheckedTests { | 
					
						
							|  |  |  | 		testMaxPacketOption(t, MaxPacketUnchecked(tt.size), tt) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func TestMaxPacket(t *testing.T) { | 
					
						
							|  |  |  | 	for _, tt := range maxPacketCheckedTests { | 
					
						
							|  |  |  | 		testMaxPacketOption(t, MaxPacket(tt.size), tt) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func testMaxPacketOption(t *testing.T, o ClientOption, tt packetSizeTest) { | 
					
						
							|  |  |  | 	var c Client | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	err := o(&c) | 
					
						
							|  |  |  | 	if (err == nil) != tt.valid { | 
					
						
							|  |  |  | 		t.Errorf("MaxPacketChecked(%v)\n- want: %v\n- got: %v", tt.size, tt.valid, err == nil) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	if c.maxPacket != tt.size && tt.valid { | 
					
						
							|  |  |  | 		t.Errorf("MaxPacketChecked(%v)\n- want: %v\n- got: %v", tt.size, tt.size, c.maxPacket) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2020-01-07 06:54:28 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | func testFstatOption(t *testing.T, o ClientOption, value bool) { | 
					
						
							|  |  |  | 	var c Client | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	err := o(&c) | 
					
						
							|  |  |  | 	if err == nil && c.useFstat != value { | 
					
						
							|  |  |  | 		t.Errorf("UseFStat(%v)\n- want: %v\n- got: %v", value, value, c.useFstat) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func TestUseFstatChecked(t *testing.T) { | 
					
						
							|  |  |  | 	testFstatOption(t, UseFstat(true), true) | 
					
						
							|  |  |  | 	testFstatOption(t, UseFstat(false), false) | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2020-10-30 16:36:20 +08:00
										 |  |  | 
 | 
					
						
							// sink is a do-nothing writer with a Close method. Tests pass it to
// NewClientPipe as the outbound side when only the inbound stream matters.
type sink struct{}

// Write reports all of p as written and stores nothing.
func (*sink) Write(p []byte) (int, error) {
	return len(p), nil
}

// Close is a no-op that always succeeds.
func (*sink) Close() error {
	return nil
}
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func TestClientZeroLengthPacket(t *testing.T) { | 
					
						
							|  |  |  | 	// Packet length zero (never valid). This used to crash the client.
 | 
					
						
							|  |  |  | 	packet := []byte{0, 0, 0, 0} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	r := bytes.NewReader(packet) | 
					
						
							|  |  |  | 	c, err := NewClientPipe(r, &sink{}) | 
					
						
							|  |  |  | 	if err == nil { | 
					
						
							|  |  |  | 		t.Error("expected an error, got nil") | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	if c != nil { | 
					
						
							|  |  |  | 		c.Close() | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2021-03-12 02:56:45 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | func TestClientShortPacket(t *testing.T) { | 
					
						
							|  |  |  | 	// init packet too short.
 | 
					
						
							|  |  |  | 	packet := []byte{0, 0, 0, 1, 2} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	r := bytes.NewReader(packet) | 
					
						
							|  |  |  | 	_, err := NewClientPipe(r, &sink{}) | 
					
						
							|  |  |  | 	if !errors.Is(err, errShortPacket) { | 
					
						
							|  |  |  | 		t.Fatalf("expected error: %v, got: %v", errShortPacket, err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2021-03-16 00:53:09 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | // Issue #418: panic in clientConn.recv when the sid is incomplete.
 | 
					
						
							|  |  |  | func TestClientNoSid(t *testing.T) { | 
					
						
							|  |  |  | 	stream := new(bytes.Buffer) | 
					
						
							|  |  |  | 	sendPacket(stream, &sshFxVersionPacket{Version: sftpProtocolVersion}) | 
					
						
							|  |  |  | 	// Next packet has the sid cut short after two bytes.
 | 
					
						
							|  |  |  | 	stream.Write([]byte{0, 0, 0, 10, 0, 0}) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	c, err := NewClientPipe(stream, &sink{}) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		t.Fatal(err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	_, err = c.Stat("anything") | 
					
						
							|  |  |  | 	if !errors.Is(err, ErrSSHFxConnectionLost) { | 
					
						
							|  |  |  | 		t.Fatal("expected ErrSSHFxConnectionLost, got", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } |