sftp/localfs/localfs_integration_test.go

927 lines
16 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//go:build integration
// +build integration
package localfs
import (
"bytes"
"cmp"
"context"
"errors"
"flag"
"fmt"
"io"
"io/fs"
"os"
"os/exec"
"os/user"
"path"
"path/filepath"
"runtime"
"slices"
"sync"
"testing"
"time"
goCmp "github.com/google/go-cmp/cmp"
sftp "github.com/pkg/sftp/v2"
sshfx "github.com/pkg/sftp/v2/encoding/ssh/filexfer"
)
var (
myUID string
myGID string
)
type Pipe struct {
io.Reader
io.WriteCloser
}
func toRemotePath(p string) string {
p = filepath.ToSlash(p)
if !path.IsAbs(p) {
return "/" + p
}
return p
}
var srv = &sftp.Server{
Handler: handler,
}
var cl *sftp.Client
var (
cwd string
clData []byte
clInfo fs.FileInfo
)
func teeToFile(wr io.WriteCloser, filename string) (io.WriteCloser, io.Closer) {
f, err := os.Create(filename)
if err != nil {
panic(err)
}
prd, pwr := io.Pipe()
go func() {
buf := make([]byte, 32*1<<10)
for {
n, err := prd.Read(buf)
if err != nil {
wr.Close()
if !errors.Is(err, io.EOF) {
panic(err)
}
return
}
if _, err := f.Write(buf[:n]); err != nil {
panic(err)
}
if _, err := wr.Write(buf[:n]); err != nil {
panic(err)
}
}
}()
return pwr, f
}
type delayedWrite struct {
c <-chan time.Time
b []byte
}
type delayedWriter struct {
delay time.Duration
ch chan *delayedWrite
mu sync.Mutex
closing chan struct{}
closed chan struct{}
}
func newDelayedWriter(wr io.WriteCloser, delay time.Duration) io.WriteCloser {
dw := &delayedWriter{
delay: delay,
ch: make(chan *delayedWrite, 128),
closing: make(chan struct{}),
closed: make(chan struct{}),
}
go func() {
defer func() {
wr.Close()
close(dw.closed)
}()
for msg := range dw.ch {
select {
case <-msg.c:
n, err := wr.Write(msg.b)
if n < len(msg.b) {
err = cmp.Or(err, io.ErrShortWrite)
}
if err != nil {
panic(fmt.Sprintf("delayedWriter: write error: %v", err))
}
case <-dw.closing:
for range dw.ch {
// Drain and throw away all outstanding writes.
}
return
}
}
}()
return dw
}
func (dw *delayedWriter) Write(b []byte) (int, error) {
write := &delayedWrite{
c: time.After(dw.delay),
b: slices.Clone(b),
}
select {
case <-dw.closing:
return 0, fs.ErrClosed
case dw.ch <- write:
// No need to queue up a background send.
default:
go func() {
select {
case <-dw.closing:
case dw.ch <- write:
}
}()
}
return len(b), nil
}
func (dw *delayedWriter) Close() error {
dw.mu.Lock()
defer dw.mu.Unlock()
select {
case <-dw.closing:
default:
close(dw.closing)
close(dw.ch)
}
<-dw.closed // wait for writer to finish up.
return nil
}
var (
logOutbound = flag.Bool("log-outbound", false, "if set client.log and server.log records all outbound bytes")
testServerImpl = flag.Bool("testserver", true, "perform tests against the sftp package server instance")
debugLevel = flag.String("l", "QUIET", "sftp-server debug level")
delay = flag.Duration("delay", 0, "the amount of delay in the client writer")
)
func TestMain(m *testing.M) {
flag.Parse()
var err error
switch {
case *testServerImpl:
err = withServerImpl(m)
default:
err = withOpenSSHImpl(m)
}
if err != nil {
fmt.Fprintln(os.Stdout, "FAIL:", err)
os.Exit(1)
}
}
func withOpenSSHImpl(m *testing.M) error {
sftpServerLocations := []string{
"/usr/libexec/ssh/sftp-server",
"/usr/libexec/sftp-server",
"/usr/lib/openssh/sftp-server",
"/usr/lib/ssh/sftp-server",
`C:\Program Files\Git\usr\lib\ssh\sftp-server.exe`,
}
sftpServer, _ := exec.LookPath("sftp-server")
if sftpServer == "" {
for _, loc := range sftpServerLocations {
if _, err := os.Stat(loc); err == nil {
sftpServer = loc
break
}
}
if sftpServer == "" {
return fmt.Errorf("could not find sftp-server")
}
}
cmd := exec.Command(sftpServer, "-e", "-l", *debugLevel) // log to stderr
cmd.Stderr = os.Stdout
pw, err := cmd.StdinPipe()
if err != nil {
return err
}
pr, err := cmd.StdoutPipe()
if err != nil {
return err
}
if err := cmd.Start(); err != nil {
return fmt.Errorf("could not start sftp-server: %w", err)
}
return runClient(m, pr, pw)
}
func withServerImpl(m *testing.M) error {
srvConn, clConn := new(Pipe), new(Pipe)
var err error
srvConn.Reader, clConn.WriteCloser, err = os.Pipe()
if err != nil {
return err
}
clConn.Reader, srvConn.WriteCloser, err = os.Pipe()
if err != nil {
return err
}
if *logOutbound {
var clCloser io.Closer
clConn.WriteCloser, clCloser = teeToFile(clConn.WriteCloser, "client.log")
defer clCloser.Close()
var srvCloser io.Closer
srvConn.WriteCloser, srvCloser = teeToFile(srvConn.WriteCloser, "server.log")
defer srvCloser.Close()
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
if err := srv.Serve(srvConn); err != nil {
panic(err)
}
}()
defer func() {
if err := srv.GracefulStop(); err != nil {
panic(err)
}
wg.Wait()
}()
return runClient(m, clConn.Reader, clConn.WriteCloser)
}
func runClient(m *testing.M, rd io.Reader, wr io.WriteCloser) error {
if *delay != 0 {
wr = newDelayedWriter(wr, *delay)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if runtime.GOOS != "windows" {
me, err := user.Current()
if err != nil {
return err
}
myUID = me.Uid
myGID = me.Gid
}
var err error
cl, err = sftp.NewClientPipe(ctx, rd, wr)
if err != nil {
return err
}
defer cl.Close()
defer func() {
cl.ReportPoolMetrics(os.Stdout)
}()
cwd, err = cl.RealPath(".")
if err != nil {
return err
}
// Read "client.go" to test against various ways to read a file.
clData, err = os.ReadFile("../client.go")
if err != nil {
return err
}
clInfo, err = os.Stat("../client.go")
if err != nil {
return err
}
retCode := m.Run()
if retCode != 0 {
return fmt.Errorf("ret code: %d", retCode)
}
return nil
}
func TestMkdir(t *testing.T) {
dir := t.TempDir()
mkdirTarget := filepath.Join(dir, "test-mkdir")
t.Log("mkdir:", mkdirTarget)
if err := cl.Mkdir(toRemotePath(mkdirTarget), 0755); err != nil {
t.Fatal(err)
}
fi, err := os.Stat(mkdirTarget)
if err != nil {
t.Fatal(err)
}
if !fi.IsDir() {
t.Errorf("expected %q to be a directory", mkdirTarget)
}
if err := cl.Mkdir(toRemotePath(mkdirTarget), 0755); err == nil {
t.Errorf("expected to error")
}
if err := cl.Remove(toRemotePath(mkdirTarget)); err != nil {
t.Fatal(err)
}
_, err = os.Stat(mkdirTarget)
if !errors.Is(err, fs.ErrNotExist) {
t.Errorf("expected %q to not exist", mkdirTarget)
}
if err := cl.Remove(toRemotePath(mkdirTarget)); err == nil {
t.Errorf("expected to error")
}
}
func TestOpenReadWrite(t *testing.T) {
dir := t.TempDir()
rdwrTarget := filepath.Join(dir, "test-rdwr")
t.Log("openfile:", rdwrTarget)
f, err := cl.OpenFile(toRemotePath(rdwrTarget), os.O_RDWR|os.O_CREATE, 0644)
if err != nil {
t.Fatal(err)
}
defer func() {
t.Log("close:", rdwrTarget)
err := f.Close()
if err != nil && !errors.Is(err, fs.ErrClosed) {
t.Error(err)
}
}()
n, err := f.Write([]byte("foo\n"))
if err != nil {
t.Fatal(err)
}
if n != 4 {
t.Errorf("Write() = %d, expected 5", n)
}
if _, err := f.Seek(0, os.SEEK_SET); err != nil {
t.Fatal(err)
}
data := make([]byte, 42)
n, err = f.Read(data)
if err != nil && !errors.Is(err, io.EOF) {
t.Fatal(err)
}
data = data[:n]
if n != 4 {
t.Errorf("Read() = %d, expected 4", n)
}
if string(data) != "foo\n" {
t.Errorf("Read() = %#v, but expected %#v", string(data), "foo\n")
}
fi, err := f.Stat()
if err != nil {
t.Fatal(err)
}
sys := fi.Sys()
attrs, ok := sys.(*sshfx.Attributes)
if !ok {
t.Fatalf("Stat() = %T, wrong type", sys)
}
if attrs.Size != 4 {
t.Errorf("Stat().Size = %d, but expected 4", attrs.Size)
}
if got := fmt.Sprint(attrs.UID); myUID != "" && got != myUID {
t.Errorf("Stat().UID = %s, but expected %s", got, myUID)
}
if got := fmt.Sprint(attrs.GID); myGID != "" && got != myGID {
t.Errorf("Stat().GID = %s, but expected %s", got, myGID)
}
permExpect := sshfx.FileMode(0644)
if runtime.GOOS == "windows" {
permExpect = 0666
}
if expect := sshfx.ModeRegular | permExpect; attrs.Permissions != expect {
t.Errorf("Stat().Permissions = %#o, but expected %#o", attrs.Permissions, expect)
}
if attrs.ATime == 0 {
t.Errorf("Stat().ATime shouldnt be zero")
}
if attrs.MTime == 0 {
t.Errorf("Stat().MTime shouldnt be zero")
}
switch runtime.GOOS {
case "windows":
if err := f.Close(); err != nil {
t.Error(err)
}
fallthrough
default:
// outside of windows you can delete an open file.
if err := cl.Remove(rdwrTarget); err != nil {
t.Error(err)
}
}
if _, err := os.Stat(rdwrTarget); !errors.Is(err, fs.ErrNotExist) {
t.Errorf("expected %q to not exist", rdwrTarget)
}
}
func TestRename(t *testing.T) {
dir := t.TempDir()
renFrom := filepath.Join(dir, "test-ren-from")
renTo := filepath.Join(dir, "test-ren-to")
t.Log(renFrom, "->", renTo)
if err := cl.WriteFile(renFrom, []byte("foo"), 0644); err != nil {
t.Fatal(err)
}
if _, err := os.Stat(renFrom); err != nil {
t.Fatal(err)
}
if err := cl.Rename(renFrom, renTo); err != nil {
t.Fatal(err)
}
if _, err := os.Stat(renFrom); !errors.Is(err, fs.ErrNotExist) {
t.Errorf("expected %q to not exist", renFrom)
}
if _, err := os.Stat(renTo); err != nil {
t.Error(err)
}
}
func TestReadDirRoot(t *testing.T) {
fis, err := cl.ReadDir("/")
if err != nil {
t.Fatal(err)
}
if len(fis) == 0 {
t.Errorf("ReadDir() []fs.FileInfo shouldnt be empty")
}
for _, fi := range fis {
switch fi := fi.(type) {
case *sshfx.NameEntry:
t.Log(fi.Longname)
case fs.FileInfo:
t.Log(sftp.FormatLongname(fi, handler))
}
}
}
func TestCWD(t *testing.T) {
t.Log("readdir:", cwd)
fis, err := cl.ReadDir(cwd)
if err != nil {
t.Fatal(err)
}
for _, fi := range fis {
switch fi := fi.(type) {
case *sshfx.NameEntry:
t.Log(fi.Longname)
case fs.FileInfo:
t.Log(sftp.FormatLongname(fi, handler))
}
}
}
func TestOneBigRead(t *testing.T) {
f, err := cl.Open("../client.go")
if err != nil {
t.Fatal(err)
}
defer func() {
t.Log("close: ../client.go")
if err := f.Close(); err != nil {
t.Error(err)
}
}()
fi, err := f.Stat()
if err != nil {
t.Fatal(err)
}
b := make([]byte, fi.Size()*1024)
n, err := f.Read(b)
if err != nil && !errors.Is(err, io.EOF) {
t.Fatal(err)
}
b = b[:n]
if !compareClientData(t, "f.Read", "read", b, n) {
return
}
}
func TestReadAll(t *testing.T) {
f, err := cl.Open("../client.go")
if err != nil {
t.Fatal(err)
}
defer func() {
t.Log("close: ../client.go")
if err := f.Close(); err != nil {
t.Error(err)
}
}()
b, err := io.ReadAll(f)
if err != nil {
t.Fatal(err)
}
if !compareClientData(t, "f.Read", "read", b, len(b)) {
return
}
}
func TestWriteTo(t *testing.T) {
f, err := cl.Open("../client.go")
if err != nil {
t.Fatal(err)
}
defer func() {
t.Log("close: ../client.go")
if err := f.Close(); err != nil {
t.Error(err)
}
}()
fi, err := f.Stat()
if err != nil {
t.Fatal(err)
}
buf := new(bytes.Buffer)
buf.Grow(int(fi.Size() + 1024))
n, err := f.WriteTo(buf)
if err != nil {
t.Fatal(err)
}
if !compareClientData(t, "f.WriteTo", "writeto", buf.Bytes(), n) {
return
}
}
func TestCopy(t *testing.T) {
f, err := cl.Open("../client.go")
if err != nil {
t.Fatal(err)
}
defer func() {
t.Log("close: ../client.go")
if err := f.Close(); err != nil {
t.Error(err)
}
}()
fi, err := f.Stat()
if err != nil {
t.Fatal(err)
}
buf := new(bytes.Buffer)
buf.Grow(int(fi.Size() + 1024))
n, err := io.Copy(buf, f)
if err != nil {
t.Fatal(err)
}
if !compareClientData(t, "io.Copy", "copy", buf.Bytes(), n) {
return
}
}
func TestReadFile(t *testing.T) {
b, err := cl.ReadFile("../client.go")
if err != nil {
t.Fatal(err)
}
if !compareClientData(t, "cl.ReadFile", "readfile", b, len(b)) {
return
}
}
func compareClientData[SZ int | int64](t *testing.T, op, logfile string, b []byte, n SZ) bool {
t.Helper()
if int64(n) != clInfo.Size() {
t.Errorf("%s() = %d bytes, expected %d", op, n, clInfo.Size())
}
if diff := goCmp.Diff(clData, b); diff != "" {
t.Error("os.ReadFile differs from", op)
if testing.Verbose() {
t.Error(diff)
}
os.WriteFile(fmt.Sprintf("client-%s.out", logfile), b, 0644)
return false
}
return true
}
func TestReadFrom(t *testing.T) {
dir := t.TempDir()
readFromTarget := filepath.Join(dir, "client-readfrom.go")
f, err := cl.Create(toRemotePath(readFromTarget))
if err != nil {
t.Fatal(err)
}
defer func() {
t.Log("close: client-readfrom.go")
if err := f.Close(); err != nil {
t.Error(err)
}
}()
n, err := f.ReadFrom(bytes.NewReader(clData))
if err != nil {
t.Fatal(err)
}
b, err := os.ReadFile(readFromTarget)
if err != nil {
t.Fatal(err)
}
if int64(len(b)) != n {
t.Errorf("ReadFile() = %d, but ReadFrom() = %d", len(b), n)
}
if !compareClientData(t, "f.ReadFrom", "readfrom", b, n) {
return
}
}
func TestStatVFS(t *testing.T) {
if !*testServerImpl {
t.Skip("not testing against localfs server implementation")
}
if _, ok := any(handler).(sftp.StatVFSServerHandler); !ok {
t.Skip("handler does not implement statvfs")
}
dir := t.TempDir()
targetNotExist := filepath.Join(dir, "statvfs-does-not-exist")
_, err := cl.StatVFS(toRemotePath(targetNotExist))
if !errors.Is(err, fs.ErrNotExist) {
t.Fatalf("unexpected error, got %v, should be fs.ErrNotFound", err)
}
resp, err := cl.StatVFS(toRemotePath(dir))
if err != nil {
t.Fatal(err)
}
t.Logf("%+v", resp)
}
var benchBuf []byte
func benchHelperWriteTo(b *testing.B, length int) {
b.Helper()
dir := b.TempDir()
target := filepath.Join(dir, "bench-writeto")
remote := toRemotePath(target)
if err := os.WriteFile(target, nil, 0644); err != nil {
b.Fatal(err)
}
defer os.Remove(remote)
if err := os.Truncate(target, int64(length)); err != nil {
b.Fatal(err)
}
buf := bytes.NewBuffer(benchBuf)
if grow := length - buf.Cap(); grow > 0 {
buf.Grow(grow)
}
defer func() {
benchBuf = buf.Bytes()
}()
b.ResetTimer()
b.SetBytes(int64(length))
for range b.N {
buf.Reset()
f, err := cl.Open(remote)
if err != nil {
b.Fatal(err)
}
n, err := f.WriteTo(buf)
if err != nil {
b.Fatal(err)
}
if n != int64(length) {
b.Fatal(io.ErrShortWrite)
}
f.Close()
if buf.Len() != length {
b.Fatalf("wrong buffer size: got %d, want %d", buf.Len(), length)
}
}
}
func BenchmarkWriteTo1KiB(b *testing.B) {
benchHelperWriteTo(b, 1<<10)
}
func BenchmarkWriteTo1MiB(b *testing.B) {
benchHelperWriteTo(b, 1<<20)
}
func BenchmarkWriteTo4MiB(b *testing.B) {
benchHelperWriteTo(b, 1<<22)
}
func BenchmarkWriteTo10MiB(b *testing.B) {
benchHelperWriteTo(b, 10*(1<<20)+123)
}
func BenchmarkWriteTo64MiB(b *testing.B) {
benchHelperWriteTo(b, 1<<26)
}
func benchHelperReadFrom(b *testing.B, length int) {
b.Helper()
dir := b.TempDir()
target := filepath.Join(dir, "bench-readfrom")
remote := toRemotePath(target)
if err := os.WriteFile(target, nil, 0644); err != nil {
b.Fatal(err)
}
defer os.Remove(remote)
data := benchBuf[:cap(benchBuf)]
if len(data) < length {
data = make([]byte, length)
}
data = data[:length]
defer func() {
benchBuf = data
}()
buf := new(bytes.Reader)
b.ResetTimer()
b.SetBytes(int64(length))
for range b.N {
buf.Reset(data)
f, err := cl.Create(remote)
if err != nil {
b.Fatal(err)
}
n, err := f.ReadFrom(buf)
if err != nil {
b.Fatal(err)
}
if n != int64(length) {
b.Fatal(io.ErrShortWrite)
}
f.Close()
if buf.Len() != 0 {
b.Fatalf("wrong buffer size: got %d, want %d", buf.Len(), 0)
}
}
}
func BenchmarkReadFrom1KiB(b *testing.B) {
benchHelperReadFrom(b, 1<<10)
}
func BenchmarkReadFrom1MiB(b *testing.B) {
benchHelperReadFrom(b, 1<<20)
}
func BenchmarkReadFrom4MiB(b *testing.B) {
benchHelperReadFrom(b, 1<<22)
}
func BenchmarkReadFrom10MiB(b *testing.B) {
benchHelperReadFrom(b, 10*(1<<20)+123)
}
func BenchmarkReadFrom64MiB(b *testing.B) {
benchHelperReadFrom(b, 1<<26)
}