mirror of https://github.com/pkg/sftp.git
				
				
				
			
		
			
				
	
	
		
			871 lines
		
	
	
		
			24 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			871 lines
		
	
	
		
			24 KiB
		
	
	
	
		
			Go
		
	
	
	
| package sftp
 | |
| 
 | |
| // sftp server integration tests
 | |
| // enable with -integration
 | |
| // example invokation (darwin): gofmt -w `find . -name \*.go` && (cd server_standalone/ ; go build -tags debug) && go test -tags debug github.com/pkg/sftp -integration -v -sftp /usr/libexec/sftp-server -run ServerCompareSubsystems
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"crypto/ecdsa"
 | |
| 	"crypto/elliptic"
 | |
| 	crand "crypto/rand"
 | |
| 	"crypto/x509"
 | |
| 	"encoding/hex"
 | |
| 	"encoding/pem"
 | |
| 	"flag"
 | |
| 	"fmt"
 | |
| 	"io/ioutil"
 | |
| 	"math/rand"
 | |
| 	"net"
 | |
| 	"os"
 | |
| 	"os/exec"
 | |
| 	"path"
 | |
| 	"path/filepath"
 | |
| 	"regexp"
 | |
| 	"runtime"
 | |
| 	"strconv"
 | |
| 	"strings"
 | |
| 	"testing"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/kr/fs"
 | |
| 	"github.com/stretchr/testify/assert"
 | |
| 	"golang.org/x/crypto/ssh"
 | |
| )
 | |
| 
 | |
| func TestMain(m *testing.M) {
 | |
| 	sftpClientLocation, _ := exec.LookPath("sftp")
 | |
| 	testSftpClientBin = flag.String("sftp_client", sftpClientLocation, "location of the sftp client binary")
 | |
| 
 | |
| 	lookSFTPServer := []string{
 | |
| 		"/usr/libexec/sftp-server",
 | |
| 		"/usr/lib/openssh/sftp-server",
 | |
| 		"/usr/lib/ssh/sftp-server",
 | |
| 	}
 | |
| 	sftpServer, _ := exec.LookPath("sftp-server")
 | |
| 	if len(sftpServer) == 0 {
 | |
| 		for _, location := range lookSFTPServer {
 | |
| 			if _, err := os.Stat(location); err == nil {
 | |
| 				sftpServer = location
 | |
| 				break
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 	testSftp = flag.String("sftp", sftpServer, "location of the sftp server binary")
 | |
| 	flag.Parse()
 | |
| 
 | |
| 	os.Exit(m.Run())
 | |
| }
 | |
| 
 | |
| func skipIfWindows(t testing.TB) {
 | |
| 	if runtime.GOOS == "windows" {
 | |
| 		t.Skip("skipping test on windows")
 | |
| 	}
 | |
| }
 | |
| 
 | |
| var testServerImpl = flag.Bool("testserver", false, "perform integration tests against sftp package server instance")
 | |
| var testIntegration = flag.Bool("integration", false, "perform integration tests against sftp server process")
 | |
| var testAllocator = flag.Bool("allocator", false, "perform tests using the allocator")
 | |
| var testSftp *string
 | |
| 
 | |
| var testSftpClientBin *string
 | |
| var sshServerDebugStream = ioutil.Discard
 | |
| var sftpServerDebugStream = ioutil.Discard
 | |
| var sftpClientDebugStream = ioutil.Discard
 | |
| 
 | |
| const (
 | |
| 	GolangSFTP  = true
 | |
| 	OpenSSHSFTP = false
 | |
| )
 | |
| 
 | |
| var (
 | |
| 	hostPrivateKeySigner ssh.Signer
 | |
| 	privKey              = []byte(`
 | |
| -----BEGIN RSA PRIVATE KEY-----
 | |
| MIIEowIBAAKCAQEArhp7SqFnXVZAgWREL9Ogs+miy4IU/m0vmdkoK6M97G9NX/Pj
 | |
| wf8I/3/ynxmcArbt8Rc4JgkjT2uxx/NqR0yN42N1PjO5Czu0dms1PSqcKIJdeUBV
 | |
| 7gdrKSm9Co4d2vwfQp5mg47eG4w63pz7Drk9+VIyi9YiYH4bve7WnGDswn4ycvYZ
 | |
| slV5kKnjlfCdPig+g5P7yQYud0cDWVwyA0+kxvL6H3Ip+Fu8rLDZn4/P1WlFAIuc
 | |
| PAf4uEKDGGmC2URowi5eesYR7f6GN/HnBs2776laNlAVXZUmYTUfOGagwLsEkx8x
 | |
| XdNqntfbs2MOOoK+myJrNtcB9pCrM0H6um19uQIDAQABAoIBABkWr9WdVKvalgkP
 | |
| TdQmhu3mKRNyd1wCl+1voZ5IM9Ayac/98UAvZDiNU4Uhx52MhtVLJ0gz4Oa8+i16
 | |
| IkKMAZZW6ro/8dZwkBzQbieWUFJ2Fso2PyvB3etcnGU8/Yhk9IxBDzy+BbuqhYE2
 | |
| 1ebVQtz+v1HvVZzaD11bYYm/Xd7Y28QREVfFen30Q/v3dv7dOteDE/RgDS8Czz7w
 | |
| jMW32Q8JL5grz7zPkMK39BLXsTcSYcaasT2ParROhGJZDmbgd3l33zKCVc1zcj9B
 | |
| SA47QljGd09Tys958WWHgtj2o7bp9v1Ufs4LnyKgzrB80WX1ovaSQKvd5THTLchO
 | |
| kLIhUAECgYEA2doGXy9wMBmTn/hjiVvggR1aKiBwUpnB87Hn5xCMgoECVhFZlT6l
 | |
| WmZe7R2klbtG1aYlw+y+uzHhoVDAJW9AUSV8qoDUwbRXvBVlp+In5wIqJ+VjfivK
 | |
| zgIfzomL5NvDz37cvPmzqIeySTowEfbQyq7CUQSoDtE9H97E2wWZhDkCgYEAzJdJ
 | |
| k+NSFoTkHhfD3L0xCDHpRV3gvaOeew8524fVtVUq53X8m91ng4AX1r74dCUYwwiF
 | |
| gqTtSSJfx2iH1xKnNq28M9uKg7wOrCKrRqNPnYUO3LehZEC7rwUr26z4iJDHjjoB
 | |
| uBcS7nw0LJ+0Zeg1IF+aIdZGV3MrAKnrzWPixYECgYBsffX6ZWebrMEmQ89eUtFF
 | |
| u9ZxcGI/4K8ErC7vlgBD5ffB4TYZ627xzFWuBLs4jmHCeNIJ9tct5rOVYN+wRO1k
 | |
| /CRPzYUnSqb+1jEgILL6istvvv+DkE+ZtNkeRMXUndWwel94BWsBnUKe0UmrSJ3G
 | |
| sq23J3iCmJW2T3z+DpXbkQKBgQCK+LUVDNPE0i42NsRnm+fDfkvLP7Kafpr3Umdl
 | |
| tMY474o+QYn+wg0/aPJIf9463rwMNyyhirBX/k57IIktUdFdtfPicd2MEGETElWv
 | |
| nN1GzYxD50Rs2f/jKisZhEwqT9YNyV9DkgDdGGdEbJNYqbv0qpwDIg8T9foe8E1p
 | |
| bdErgQKBgAt290I3L316cdxIQTkJh1DlScN/unFffITwu127WMr28Jt3mq3cZpuM
 | |
| Aecey/eEKCj+Rlas5NDYKsB18QIuAw+qqWyq0LAKLiAvP1965Rkc4PLScl3MgJtO
 | |
| QYa37FK0p8NcDeUuF86zXBVutwS5nJLchHhKfd590ks57OROtm29
 | |
| -----END RSA PRIVATE KEY-----
 | |
| `)
 | |
| )
 | |
| 
 | |
| func init() {
 | |
| 	var err error
 | |
| 	hostPrivateKeySigner, err = ssh.ParsePrivateKey(privKey)
 | |
| 	if err != nil {
 | |
| 		panic(err)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func keyAuth(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
 | |
| 	permissions := &ssh.Permissions{
 | |
| 		CriticalOptions: map[string]string{},
 | |
| 		Extensions:      map[string]string{},
 | |
| 	}
 | |
| 	return permissions, nil
 | |
| }
 | |
| 
 | |
| func pwAuth(conn ssh.ConnMetadata, pw []byte) (*ssh.Permissions, error) {
 | |
| 	permissions := &ssh.Permissions{
 | |
| 		CriticalOptions: map[string]string{},
 | |
| 		Extensions:      map[string]string{},
 | |
| 	}
 | |
| 	return permissions, nil
 | |
| }
 | |
| 
 | |
| func basicServerConfig() *ssh.ServerConfig {
 | |
| 	config := ssh.ServerConfig{
 | |
| 		Config: ssh.Config{
 | |
| 			MACs: []string{"hmac-sha1"},
 | |
| 		},
 | |
| 		PasswordCallback:  pwAuth,
 | |
| 		PublicKeyCallback: keyAuth,
 | |
| 	}
 | |
| 	config.AddHostKey(hostPrivateKeySigner)
 | |
| 	return &config
 | |
| }
 | |
| 
 | |
| type sshServer struct {
 | |
| 	useSubsystem bool
 | |
| 	conn         net.Conn
 | |
| 	config       *ssh.ServerConfig
 | |
| 	sshConn      *ssh.ServerConn
 | |
| 	newChans     <-chan ssh.NewChannel
 | |
| 	newReqs      <-chan *ssh.Request
 | |
| }
 | |
| 
 | |
| func sshServerFromConn(conn net.Conn, useSubsystem bool, config *ssh.ServerConfig) (*sshServer, error) {
 | |
| 	// From a standard TCP connection to an encrypted SSH connection
 | |
| 	sshConn, newChans, newReqs, err := ssh.NewServerConn(conn, config)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	svr := &sshServer{useSubsystem, conn, config, sshConn, newChans, newReqs}
 | |
| 	svr.listenChannels()
 | |
| 	return svr, nil
 | |
| }
 | |
| 
 | |
| func (svr *sshServer) Wait() error {
 | |
| 	return svr.sshConn.Wait()
 | |
| }
 | |
| 
 | |
| func (svr *sshServer) Close() error {
 | |
| 	return svr.sshConn.Close()
 | |
| }
 | |
| 
 | |
| func (svr *sshServer) listenChannels() {
 | |
| 	go func() {
 | |
| 		for chanReq := range svr.newChans {
 | |
| 			go svr.handleChanReq(chanReq)
 | |
| 		}
 | |
| 	}()
 | |
| 	go func() {
 | |
| 		for req := range svr.newReqs {
 | |
| 			go svr.handleReq(req)
 | |
| 		}
 | |
| 	}()
 | |
| }
 | |
| 
 | |
| func (svr *sshServer) handleReq(req *ssh.Request) {
 | |
| 	switch req.Type {
 | |
| 	default:
 | |
| 		rejectRequest(req)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| type sshChannelServer struct {
 | |
| 	svr     *sshServer
 | |
| 	chanReq ssh.NewChannel
 | |
| 	ch      ssh.Channel
 | |
| 	newReqs <-chan *ssh.Request
 | |
| }
 | |
| 
 | |
| type sshSessionChannelServer struct {
 | |
| 	*sshChannelServer
 | |
| 	env []string
 | |
| }
 | |
| 
 | |
| func (svr *sshServer) handleChanReq(chanReq ssh.NewChannel) {
 | |
| 	fmt.Fprintf(sshServerDebugStream, "channel request: %v, extra: '%v'\n", chanReq.ChannelType(), hex.EncodeToString(chanReq.ExtraData()))
 | |
| 	switch chanReq.ChannelType() {
 | |
| 	case "session":
 | |
| 		if ch, reqs, err := chanReq.Accept(); err != nil {
 | |
| 			fmt.Fprintf(sshServerDebugStream, "fail to accept channel request: %v\n", err)
 | |
| 			chanReq.Reject(ssh.ResourceShortage, "channel accept failure")
 | |
| 		} else {
 | |
| 			chsvr := &sshSessionChannelServer{
 | |
| 				sshChannelServer: &sshChannelServer{svr, chanReq, ch, reqs},
 | |
| 				env:              append([]string{}, os.Environ()...),
 | |
| 			}
 | |
| 			chsvr.handle()
 | |
| 		}
 | |
| 	default:
 | |
| 		chanReq.Reject(ssh.UnknownChannelType, "channel type is not a session")
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (chsvr *sshSessionChannelServer) handle() {
 | |
| 	// should maybe do something here...
 | |
| 	go chsvr.handleReqs()
 | |
| }
 | |
| 
 | |
| func (chsvr *sshSessionChannelServer) handleReqs() {
 | |
| 	for req := range chsvr.newReqs {
 | |
| 		chsvr.handleReq(req)
 | |
| 	}
 | |
| 	fmt.Fprintf(sshServerDebugStream, "ssh server session channel complete\n")
 | |
| }
 | |
| 
 | |
| func (chsvr *sshSessionChannelServer) handleReq(req *ssh.Request) {
 | |
| 	switch req.Type {
 | |
| 	case "env":
 | |
| 		chsvr.handleEnv(req)
 | |
| 	case "subsystem":
 | |
| 		chsvr.handleSubsystem(req)
 | |
| 	default:
 | |
| 		rejectRequest(req)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func rejectRequest(req *ssh.Request) error {
 | |
| 	fmt.Fprintf(sshServerDebugStream, "ssh rejecting request, type: %s\n", req.Type)
 | |
| 	err := req.Reply(false, []byte{})
 | |
| 	if err != nil {
 | |
| 		fmt.Fprintf(sshServerDebugStream, "ssh request reply had error: %v\n", err)
 | |
| 	}
 | |
| 	return err
 | |
| }
 | |
| 
 | |
| func rejectRequestUnmarshalError(req *ssh.Request, s interface{}, err error) error {
 | |
| 	fmt.Fprintf(sshServerDebugStream, "ssh request unmarshaling error, type '%T': %v\n", s, err)
 | |
| 	rejectRequest(req)
 | |
| 	return err
 | |
| }
 | |
| 
 | |
| // env request form:
 | |
| type sshEnvRequest struct {
 | |
| 	Envvar string
 | |
| 	Value  string
 | |
| }
 | |
| 
 | |
| func (chsvr *sshSessionChannelServer) handleEnv(req *ssh.Request) error {
 | |
| 	envReq := &sshEnvRequest{}
 | |
| 	if err := ssh.Unmarshal(req.Payload, envReq); err != nil {
 | |
| 		return rejectRequestUnmarshalError(req, envReq, err)
 | |
| 	}
 | |
| 	req.Reply(true, nil)
 | |
| 
 | |
| 	found := false
 | |
| 	for i, envstr := range chsvr.env {
 | |
| 		if strings.HasPrefix(envstr, envReq.Envvar+"=") {
 | |
| 			found = true
 | |
| 			chsvr.env[i] = envReq.Envvar + "=" + envReq.Value
 | |
| 		}
 | |
| 	}
 | |
| 	if !found {
 | |
| 		chsvr.env = append(chsvr.env, envReq.Envvar+"="+envReq.Value)
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // Payload: int: command size, string: command
 | |
| type sshSubsystemRequest struct {
 | |
| 	Name string
 | |
| }
 | |
| 
 | |
| type sshSubsystemExitStatus struct {
 | |
| 	Status uint32
 | |
| }
 | |
| 
 | |
| func (chsvr *sshSessionChannelServer) handleSubsystem(req *ssh.Request) error {
 | |
| 	defer func() {
 | |
| 		err1 := chsvr.ch.CloseWrite()
 | |
| 		err2 := chsvr.ch.Close()
 | |
| 		fmt.Fprintf(sshServerDebugStream, "ssh server subsystem request complete, err: %v %v\n", err1, err2)
 | |
| 	}()
 | |
| 
 | |
| 	subsystemReq := &sshSubsystemRequest{}
 | |
| 	if err := ssh.Unmarshal(req.Payload, subsystemReq); err != nil {
 | |
| 		return rejectRequestUnmarshalError(req, subsystemReq, err)
 | |
| 	}
 | |
| 
 | |
| 	// reply to the ssh client
 | |
| 
 | |
| 	// no idea if this is actually correct spec-wise.
 | |
| 	// just enough for an sftp server to start.
 | |
| 	if subsystemReq.Name != "sftp" {
 | |
| 		return req.Reply(false, nil)
 | |
| 	}
 | |
| 
 | |
| 	req.Reply(true, nil)
 | |
| 
 | |
| 	if !chsvr.svr.useSubsystem {
 | |
| 		// use the openssh sftp server backend; this is to test the ssh code, not the sftp code,
 | |
| 		// or is used for comparison between our sftp subsystem and the openssh sftp subsystem
 | |
| 		cmd := exec.Command(*testSftp, "-e", "-l", "DEBUG") // log to stderr
 | |
| 		cmd.Stdin = chsvr.ch
 | |
| 		cmd.Stdout = chsvr.ch
 | |
| 		cmd.Stderr = sftpServerDebugStream
 | |
| 		if err := cmd.Start(); err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 		return cmd.Wait()
 | |
| 	}
 | |
| 
 | |
| 	sftpServer, err := NewServer(
 | |
| 		chsvr.ch,
 | |
| 		WithDebug(sftpServerDebugStream),
 | |
| 	)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	// wait for the session to close
 | |
| 	runErr := sftpServer.Serve()
 | |
| 	exitStatus := uint32(1)
 | |
| 	if runErr == nil {
 | |
| 		exitStatus = uint32(0)
 | |
| 	}
 | |
| 
 | |
| 	_, exitStatusErr := chsvr.ch.SendRequest("exit-status", false, ssh.Marshal(sshSubsystemExitStatus{exitStatus}))
 | |
| 	return exitStatusErr
 | |
| }
 | |
| 
 | |
| // starts an ssh server to test. returns: host string and port
 | |
| func testServer(t *testing.T, useSubsystem bool, readonly bool) (net.Listener, string, int) {
 | |
| 	if !*testIntegration {
 | |
| 		t.Skip("skipping integration test")
 | |
| 	}
 | |
| 
 | |
| 	listener, err := net.Listen("tcp", "127.0.0.1:0")
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	host, portStr, err := net.SplitHostPort(listener.Addr().String())
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	port, err := strconv.Atoi(portStr)
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	go func() {
 | |
| 		for {
 | |
| 			conn, err := listener.Accept()
 | |
| 			if err != nil {
 | |
| 				fmt.Fprintf(sshServerDebugStream, "ssh server socket closed: %v\n", err)
 | |
| 				break
 | |
| 			}
 | |
| 
 | |
| 			go func() {
 | |
| 				defer conn.Close()
 | |
| 				sshSvr, err := sshServerFromConn(conn, useSubsystem, basicServerConfig())
 | |
| 				if err != nil {
 | |
| 					t.Error(err)
 | |
| 					return
 | |
| 				}
 | |
| 				err = sshSvr.Wait()
 | |
| 				fmt.Fprintf(sshServerDebugStream, "ssh server finished, err: %v\n", err)
 | |
| 			}()
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	return listener, host, port
 | |
| }
 | |
| 
 | |
| func makeDummyKey() (string, error) {
 | |
| 	priv, err := ecdsa.GenerateKey(elliptic.P256(), crand.Reader)
 | |
| 	if err != nil {
 | |
| 		return "", fmt.Errorf("cannot generate key: %v", err)
 | |
| 	}
 | |
| 	der, err := x509.MarshalECPrivateKey(priv)
 | |
| 	if err != nil {
 | |
| 		return "", fmt.Errorf("cannot marshal key: %v", err)
 | |
| 	}
 | |
| 	block := &pem.Block{Type: "EC PRIVATE KEY", Bytes: der}
 | |
| 	f, err := ioutil.TempFile("", "sftp-test-key-")
 | |
| 	if err != nil {
 | |
| 		return "", fmt.Errorf("cannot create temp file: %v", err)
 | |
| 	}
 | |
| 	defer func() {
 | |
| 		if f != nil {
 | |
| 			_ = f.Close()
 | |
| 			_ = os.Remove(f.Name())
 | |
| 		}
 | |
| 	}()
 | |
| 	if err := pem.Encode(f, block); err != nil {
 | |
| 		return "", fmt.Errorf("cannot write key: %v", err)
 | |
| 	}
 | |
| 	if err := f.Close(); err != nil {
 | |
| 		return "", fmt.Errorf("error closing key file: %v", err)
 | |
| 	}
 | |
| 	path := f.Name()
 | |
| 	f = nil
 | |
| 	return path, nil
 | |
| }
 | |
| 
 | |
| func runSftpClient(t *testing.T, script string, path string, host string, port int) (string, error) {
 | |
| 	// if sftp client binary is unavailable, skip test
 | |
| 	if _, err := os.Stat(*testSftpClientBin); err != nil {
 | |
| 		t.Skip("sftp client binary unavailable")
 | |
| 	}
 | |
| 
 | |
| 	// make a dummy key so we don't rely on ssh-agent
 | |
| 	dummyKey, err := makeDummyKey()
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 	defer os.Remove(dummyKey)
 | |
| 
 | |
| 	args := []string{
 | |
| 		// "-vvvv",
 | |
| 		"-b", "-",
 | |
| 		"-o", "StrictHostKeyChecking=no",
 | |
| 		"-o", "LogLevel=ERROR",
 | |
| 		"-o", "UserKnownHostsFile /dev/null",
 | |
| 		// do not trigger ssh-agent prompting
 | |
| 		"-o", "IdentityFile=" + dummyKey,
 | |
| 		"-o", "IdentitiesOnly=yes",
 | |
| 		"-P", fmt.Sprintf("%d", port), fmt.Sprintf("%s:%s", host, path),
 | |
| 	}
 | |
| 	cmd := exec.Command(*testSftpClientBin, args...)
 | |
| 	var stdout bytes.Buffer
 | |
| 	var stderr bytes.Buffer
 | |
| 	cmd.Stdin = bytes.NewBufferString(script)
 | |
| 	cmd.Stdout = &stdout
 | |
| 	cmd.Stderr = &stderr
 | |
| 	if err := cmd.Start(); err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 	err = cmd.Wait()
 | |
| 	if err != nil {
 | |
| 		err = fmt.Errorf("%v: %s", err, stderr.String())
 | |
| 	}
 | |
| 	return stdout.String(), err
 | |
| }
 | |
| 
 | |
| // assert.Eventually seems to have a data rate on macOS with go 1.14 so replace it with this simpler function
 | |
| func waitForCondition(t *testing.T, condition func() bool) {
 | |
| 	start := time.Now()
 | |
| 	tick := 10 * time.Millisecond
 | |
| 	waitFor := 100 * time.Millisecond
 | |
| 	for !condition() {
 | |
| 		time.Sleep(tick)
 | |
| 		if time.Since(start) > waitFor {
 | |
| 			break
 | |
| 		}
 | |
| 	}
 | |
| 	assert.True(t, condition())
 | |
| }
 | |
| 
 | |
| func checkAllocatorBeforeServerClose(t *testing.T, alloc *allocator) {
 | |
| 	if alloc != nil {
 | |
| 		// before closing the server we are, generally, waiting for new packets in recvPacket and we have a page allocated.
 | |
| 		// Sometime the sendPacket returns some milliseconds after the client receives the response, and so we have 2
 | |
| 		// allocated pages here, so wait some milliseconds. To avoid crashes we must be sure to not release the pages
 | |
| 		// too soon.
 | |
| 		waitForCondition(t, func() bool { return alloc.countUsedPages() <= 1 })
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func checkAllocatorAfterServerClose(t *testing.T, alloc *allocator) {
 | |
| 	if alloc != nil {
 | |
| 		// wait for the server cleanup
 | |
| 		waitForCondition(t, func() bool { return alloc.countUsedPages() == 0 })
 | |
| 		waitForCondition(t, func() bool { return alloc.countAvailablePages() == 0 })
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestServerCompareSubsystems(t *testing.T) {
 | |
| 	listenerGo, hostGo, portGo := testServer(t, GolangSFTP, READONLY)
 | |
| 	listenerOp, hostOp, portOp := testServer(t, OpenSSHSFTP, READONLY)
 | |
| 	defer listenerGo.Close()
 | |
| 	defer listenerOp.Close()
 | |
| 
 | |
| 	script := `
 | |
| ls /
 | |
| ls -l /
 | |
| ls /dev/
 | |
| ls -l /dev/
 | |
| ls -l /etc/
 | |
| ls -l /bin/
 | |
| ls -l /usr/bin/
 | |
| `
 | |
| 	outputGo, err := runSftpClient(t, script, "/", hostGo, portGo)
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	outputOp, err := runSftpClient(t, script, "/", hostOp, portOp)
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	newlineRegex := regexp.MustCompile(`\r*\n`)
 | |
| 	spaceRegex := regexp.MustCompile(`\s+`)
 | |
| 	outputGoLines := newlineRegex.Split(outputGo, -1)
 | |
| 	outputOpLines := newlineRegex.Split(outputOp, -1)
 | |
| 
 | |
| 	for i, goLine := range outputGoLines {
 | |
| 		if i > len(outputOpLines) {
 | |
| 			t.Fatalf("output line count differs")
 | |
| 		}
 | |
| 		opLine := outputOpLines[i]
 | |
| 		bad := false
 | |
| 		if goLine != opLine {
 | |
| 			goWords := spaceRegex.Split(goLine, -1)
 | |
| 			opWords := spaceRegex.Split(opLine, -1)
 | |
| 			// some fields are allowed to be different..
 | |
| 			// words[2] and [3] as these are users & groups
 | |
| 			// words[1] as the link count for directories like proc is unstable
 | |
| 			// during testing as processes are created/destroyed.
 | |
| 			// words[7] as timestamp on dirs can very for things like /tmp
 | |
| 			for j, goWord := range goWords {
 | |
| 				if j > len(opWords) {
 | |
| 					bad = true
 | |
| 				}
 | |
| 				opWord := opWords[j]
 | |
| 				if goWord != opWord && j != 1 && j != 2 && j != 3 && j != 7 {
 | |
| 					bad = true
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		if bad {
 | |
| 			t.Errorf("outputs differ, go:\n%v\nopenssh:\n%v\n", goLine, opLine)
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| var rng = rand.New(rand.NewSource(time.Now().Unix()))
 | |
| 
 | |
| func randData(length int) []byte {
 | |
| 	data := make([]byte, length)
 | |
| 	for i := 0; i < length; i++ {
 | |
| 		data[i] = byte(rng.Uint32())
 | |
| 	}
 | |
| 	return data
 | |
| }
 | |
| 
 | |
| func randName() string {
 | |
| 	return "sftp." + hex.EncodeToString(randData(16))
 | |
| }
 | |
| 
 | |
| func TestServerMkdirRmdir(t *testing.T) {
 | |
| 	listenerGo, hostGo, portGo := testServer(t, GolangSFTP, READONLY)
 | |
| 	defer listenerGo.Close()
 | |
| 
 | |
| 	tmpDir := "/tmp/" + randName()
 | |
| 	defer os.RemoveAll(tmpDir)
 | |
| 
 | |
| 	// mkdir remote
 | |
| 	if _, err := runSftpClient(t, "mkdir "+tmpDir, "/", hostGo, portGo); err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	// directory should now exist
 | |
| 	if _, err := os.Stat(tmpDir); err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	// now remove the directory
 | |
| 	if _, err := runSftpClient(t, "rmdir "+tmpDir, "/", hostGo, portGo); err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	if _, err := os.Stat(tmpDir); err == nil {
 | |
| 		t.Fatal("should have error after deleting the directory")
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestServerLink(t *testing.T) {
 | |
| 	skipIfWindows(t) // No hard links on windows.
 | |
| 	listenerGo, hostGo, portGo := testServer(t, GolangSFTP, READONLY)
 | |
| 	defer listenerGo.Close()
 | |
| 
 | |
| 	tmpFileLocalData := randData(999)
 | |
| 
 | |
| 	linkdest := "/tmp/" + randName()
 | |
| 	defer os.RemoveAll(linkdest)
 | |
| 	if err := ioutil.WriteFile(linkdest, tmpFileLocalData, 0644); err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	link := "/tmp/" + randName()
 | |
| 	defer os.RemoveAll(link)
 | |
| 
 | |
| 	// now create a hard link within the new directory
 | |
| 	if output, err := runSftpClient(t, fmt.Sprintf("ln %s %s", linkdest, link), "/", hostGo, portGo); err != nil {
 | |
| 		t.Fatalf("failed: %v %v", err, string(output))
 | |
| 	}
 | |
| 
 | |
| 	// file should now exist and be the same size as linkdest
 | |
| 	if stat, err := os.Lstat(link); err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	} else if int(stat.Size()) != len(tmpFileLocalData) {
 | |
| 		t.Fatalf("wrong size: %v", len(tmpFileLocalData))
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestServerSymlink(t *testing.T) {
 | |
| 	skipIfWindows(t) // No symlinks on windows.
 | |
| 	listenerGo, hostGo, portGo := testServer(t, GolangSFTP, READONLY)
 | |
| 	defer listenerGo.Close()
 | |
| 
 | |
| 	link := "/tmp/" + randName()
 | |
| 	defer os.RemoveAll(link)
 | |
| 
 | |
| 	// now create a symbolic link within the new directory
 | |
| 	if output, err := runSftpClient(t, "symlink /bin/sh "+link, "/", hostGo, portGo); err != nil {
 | |
| 		t.Fatalf("failed: %v %v", err, string(output))
 | |
| 	}
 | |
| 
 | |
| 	// symlink should now exist
 | |
| 	if stat, err := os.Lstat(link); err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	} else if (stat.Mode() & os.ModeSymlink) != os.ModeSymlink {
 | |
| 		t.Fatalf("is not a symlink: %v", stat.Mode())
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestServerPut(t *testing.T) {
 | |
| 	listenerGo, hostGo, portGo := testServer(t, GolangSFTP, READONLY)
 | |
| 	defer listenerGo.Close()
 | |
| 
 | |
| 	tmpFileLocal := "/tmp/" + randName()
 | |
| 	tmpFileRemote := "/tmp/" + randName()
 | |
| 	defer os.RemoveAll(tmpFileLocal)
 | |
| 	defer os.RemoveAll(tmpFileRemote)
 | |
| 
 | |
| 	t.Logf("put: local %v remote %v", tmpFileLocal, tmpFileRemote)
 | |
| 
 | |
| 	// create a file with random contents. This will be the local file pushed to the server
 | |
| 	tmpFileLocalData := randData(10 * 1024 * 1024)
 | |
| 	if err := ioutil.WriteFile(tmpFileLocal, tmpFileLocalData, 0644); err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	// sftp the file to the server
 | |
| 	if output, err := runSftpClient(t, "put "+tmpFileLocal+" "+tmpFileRemote, "/", hostGo, portGo); err != nil {
 | |
| 		t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output)
 | |
| 	}
 | |
| 
 | |
| 	// tmpFile2 should now exist, with the same contents
 | |
| 	if tmpFileRemoteData, err := ioutil.ReadFile(tmpFileRemote); err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	} else if string(tmpFileLocalData) != string(tmpFileRemoteData) {
 | |
| 		t.Fatal("contents of file incorrect after put")
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestServerResume(t *testing.T) {
 | |
| 	listenerGo, hostGo, portGo := testServer(t, GolangSFTP, READONLY)
 | |
| 	defer listenerGo.Close()
 | |
| 
 | |
| 	tmpFileLocal := "/tmp/" + randName()
 | |
| 	tmpFileRemote := "/tmp/" + randName()
 | |
| 	defer os.RemoveAll(tmpFileLocal)
 | |
| 	defer os.RemoveAll(tmpFileRemote)
 | |
| 
 | |
| 	t.Logf("put: local %v remote %v", tmpFileLocal, tmpFileRemote)
 | |
| 
 | |
| 	// create a local file with random contents to be pushed to the server
 | |
| 	tmpFileLocalData := randData(2 * 1024 * 1024)
 | |
| 	// only write half the data to simulate a split upload
 | |
| 	half := 1024 * 1024
 | |
| 	err := ioutil.WriteFile(tmpFileLocal, tmpFileLocalData[:half], 0644)
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	// sftp the first half of the file to the server
 | |
| 	output, err := runSftpClient(t, "put "+tmpFileLocal+" "+tmpFileRemote,
 | |
| 		"/", hostGo, portGo)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output)
 | |
| 	}
 | |
| 
 | |
| 	// write the full file out
 | |
| 	err = ioutil.WriteFile(tmpFileLocal, tmpFileLocalData, 0644)
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	// re-sftp the full file with the append flag set
 | |
| 	output, err = runSftpClient(t, "put -a "+tmpFileLocal+" "+tmpFileRemote,
 | |
| 		"/", hostGo, portGo)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output)
 | |
| 	}
 | |
| 
 | |
| 	// tmpFileRemote should now exist, with the same contents
 | |
| 	if tmpFileRemoteData, err := ioutil.ReadFile(tmpFileRemote); err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	} else if string(tmpFileLocalData) != string(tmpFileRemoteData) {
 | |
| 		t.Fatal("contents of file incorrect after put")
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestServerGet(t *testing.T) {
 | |
| 	listenerGo, hostGo, portGo := testServer(t, GolangSFTP, READONLY)
 | |
| 	defer listenerGo.Close()
 | |
| 
 | |
| 	tmpFileLocal := "/tmp/" + randName()
 | |
| 	tmpFileRemote := "/tmp/" + randName()
 | |
| 	defer os.RemoveAll(tmpFileLocal)
 | |
| 	defer os.RemoveAll(tmpFileRemote)
 | |
| 
 | |
| 	t.Logf("get: local %v remote %v", tmpFileLocal, tmpFileRemote)
 | |
| 
 | |
| 	// create a file with random contents. This will be the remote file pulled from the server
 | |
| 	tmpFileRemoteData := randData(10 * 1024 * 1024)
 | |
| 	if err := ioutil.WriteFile(tmpFileRemote, tmpFileRemoteData, 0644); err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	// sftp the file to the server
 | |
| 	if output, err := runSftpClient(t, "get "+tmpFileRemote+" "+tmpFileLocal, "/", hostGo, portGo); err != nil {
 | |
| 		t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output)
 | |
| 	}
 | |
| 
 | |
| 	// tmpFile2 should now exist, with the same contents
 | |
| 	if tmpFileLocalData, err := ioutil.ReadFile(tmpFileLocal); err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	} else if string(tmpFileLocalData) != string(tmpFileRemoteData) {
 | |
| 		t.Fatal("contents of file incorrect after put")
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func compareDirectoriesRecursive(t *testing.T, aroot, broot string) {
 | |
| 	walker := fs.Walk(aroot)
 | |
| 	for walker.Step() {
 | |
| 		if err := walker.Err(); err != nil {
 | |
| 			t.Fatal(err)
 | |
| 		}
 | |
| 		// find paths
 | |
| 		aPath := walker.Path()
 | |
| 		aRel, err := filepath.Rel(aroot, aPath)
 | |
| 		if err != nil {
 | |
| 			t.Fatalf("could not find relative path for %v: %v", aPath, err)
 | |
| 		}
 | |
| 		bPath := path.Join(broot, aRel)
 | |
| 
 | |
| 		if aRel == "." {
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		//t.Logf("comparing: %v a: %v b %v", aRel, aPath, bPath)
 | |
| 
 | |
| 		// if a is a link, the sftp recursive copy won't have copied it. ignore
 | |
| 		aLink, err := os.Lstat(aPath)
 | |
| 		if err != nil {
 | |
| 			t.Fatalf("could not lstat %v: %v", aPath, err)
 | |
| 		}
 | |
| 		if aLink.Mode()&os.ModeSymlink != 0 {
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		// stat the files
 | |
| 		aFile, err := os.Stat(aPath)
 | |
| 		if err != nil {
 | |
| 			t.Fatalf("could not stat %v: %v", aPath, err)
 | |
| 		}
 | |
| 		bFile, err := os.Stat(bPath)
 | |
| 		if err != nil {
 | |
| 			t.Fatalf("could not stat %v: %v", bPath, err)
 | |
| 		}
 | |
| 
 | |
| 		// compare stats, with some leniency for the timestamp
 | |
| 		if aFile.Mode() != bFile.Mode() {
 | |
| 			t.Fatalf("modes different for %v: %v vs %v", aRel, aFile.Mode(), bFile.Mode())
 | |
| 		}
 | |
| 		if !aFile.IsDir() {
 | |
| 			if aFile.Size() != bFile.Size() {
 | |
| 				t.Fatalf("sizes different for %v: %v vs %v", aRel, aFile.Size(), bFile.Size())
 | |
| 			}
 | |
| 		}
 | |
| 		timeDiff := aFile.ModTime().Sub(bFile.ModTime())
 | |
| 		if timeDiff > time.Second || timeDiff < -time.Second {
 | |
| 			t.Fatalf("mtimes different for %v: %v vs %v", aRel, aFile.ModTime(), bFile.ModTime())
 | |
| 		}
 | |
| 
 | |
| 		// compare contents
 | |
| 		if !aFile.IsDir() {
 | |
| 			if aContents, err := ioutil.ReadFile(aPath); err != nil {
 | |
| 				t.Fatal(err)
 | |
| 			} else if bContents, err := ioutil.ReadFile(bPath); err != nil {
 | |
| 				t.Fatal(err)
 | |
| 			} else if string(aContents) != string(bContents) {
 | |
| 				t.Fatalf("contents different for %v", aRel)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestServerPutRecursive(t *testing.T) {
 | |
| 	listenerGo, hostGo, portGo := testServer(t, GolangSFTP, READONLY)
 | |
| 	defer listenerGo.Close()
 | |
| 
 | |
| 	dirLocal, err := os.Getwd()
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	tmpDirRemote := "/tmp/" + randName()
 | |
| 	defer os.RemoveAll(tmpDirRemote)
 | |
| 
 | |
| 	t.Logf("put recursive: local %v remote %v", dirLocal, tmpDirRemote)
 | |
| 
 | |
| 	// push this directory (source code etc) recursively to the server
 | |
| 	if output, err := runSftpClient(t, "mkdir "+tmpDirRemote+"\r\nput -r -P "+dirLocal+"/ "+tmpDirRemote+"/", "/", hostGo, portGo); err != nil {
 | |
| 		t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output)
 | |
| 	}
 | |
| 
 | |
| 	compareDirectoriesRecursive(t, dirLocal, path.Join(tmpDirRemote, path.Base(dirLocal)))
 | |
| }
 | |
| 
 | |
| func TestServerGetRecursive(t *testing.T) {
 | |
| 	listenerGo, hostGo, portGo := testServer(t, GolangSFTP, READONLY)
 | |
| 	defer listenerGo.Close()
 | |
| 
 | |
| 	dirRemote, err := os.Getwd()
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	tmpDirLocal := "/tmp/" + randName()
 | |
| 	defer os.RemoveAll(tmpDirLocal)
 | |
| 
 | |
| 	t.Logf("get recursive: local %v remote %v", tmpDirLocal, dirRemote)
 | |
| 
 | |
| 	// pull this directory (source code etc) recursively from the server
 | |
| 	if output, err := runSftpClient(t, "lmkdir "+tmpDirLocal+"\r\nget -r -P "+dirRemote+"/ "+tmpDirLocal+"/", "/", hostGo, portGo); err != nil {
 | |
| 		t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output)
 | |
| 	}
 | |
| 
 | |
| 	compareDirectoriesRecursive(t, dirRemote, path.Join(tmpDirLocal, path.Base(dirRemote)))
 | |
| }
 |