diff --git a/client.go b/client.go index 11c36a8..f3a74e4 100644 --- a/client.go +++ b/client.go @@ -34,8 +34,9 @@ type clientConn struct { reqid atomic.Uint32 rd io.Reader - resPool *sync.WorkPool[result] + wait func() error + resPool *sync.WorkPool[result] bufPool *sync.SlicePool[[]byte, byte] pktPool *sync.Pool[sshfx.RawPacket] @@ -107,6 +108,26 @@ func (c *clientConn) getChan(reqid uint32) (chan<- result, bool) { func (c *clientConn) Wait() error { <-c.closed + + if c.wait == nil { + return c.err + } + + if err := c.wait(); err != nil { + + // TODO: when when https://github.com/golang/go/issues/35025 is fixed, + // we can remove this if block entirely. + // Right now, it's always going to return this, so it is not useful. + // But we have this code here so that as soon as the ssh library is updated, + // we can return a possibly more useful error. + if err.Error() == "ssh: session not started" { + return c.err + } + + return err + } + + // c.wait returned no error; so let's return something maybe more useful. return c.err } @@ -114,7 +135,8 @@ func (c *clientConn) disconnect(err error) { c.mu.Lock() defer c.mu.Unlock() - c.err = err + c.err = cmp.Or(c.err, err) + select { case <-c.closed: // already closed @@ -365,11 +387,20 @@ func WithMaxPacketLength(length int) ClientOption { } } +// CopyStderrTo specifies a writer to which the standard error of the remote sftp-server command should be written. +func CopyStderrTo(wr io.WriteCloser) ClientOption { + return func(cl *Client) error { + cl.stderrTo = wr + return nil + } +} + // Client represents an SFTP session on a *ssh.ClientConn SSH connection. // Multiple clients can be active on a single SSH connection, // and a client may be called concurrently from multiple goroutines. type Client struct { - conn clientConn + conn clientConn + stderrTo io.WriteCloser maxPacket uint32 maxDataLen int @@ -554,24 +585,30 @@ func NewClient(ctx context.Context, conn *ssh.Client, opts ...ClientOption) (*Cl return nil, err } + pw, err := s.StdinPipe() + if err != nil { + s.Close() + return nil, err + } + + pr, err := s.StdoutPipe() + if err != nil { + s.Close() + return nil, err + } + + perr, err := s.StderrPipe() + if err != nil { + s.Close() + return nil, err + } + if err := s.RequestSubsystem("sftp"); err != nil { s.Close() return nil, err } - w, err := s.StdinPipe() - if err != nil { - s.Close() - return nil, err - } - - r, err := s.StdoutPipe() - if err != nil { - s.Close() - return nil, err - } - - return NewClientPipe(ctx, r, w, opts...) + return newClientPipe(ctx, pr, perr, pw, s.Wait, opts) } // NewClientPipe creates a new SFTP client given a Reader and WriteCloser. @@ -579,10 +616,16 @@ func NewClient(ctx context.Context, conn *ssh.Client, opts ...ClientOption) (*Cl // // The given context is only used for the negotiation of init and version packets. func NewClientPipe(ctx context.Context, rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Client, error) { + return newClientPipe(ctx, rd, nil, wr, nil, opts) +} + +func newClientPipe(ctx context.Context, rd, stderr io.Reader, wr io.WriteCloser, wait func() error, opts []ClientOption) (*Client, error) { cl := &Client{ conn: clientConn{ - rd: rd, - wr: wr, + rd: rd, + wr: wr, + + wait: wait, closed: make(chan struct{}), }, @@ -597,6 +640,27 @@ func NewClientPipe(ctx context.Context, rd io.Reader, wr io.WriteCloser, opts .. } } + if stderr != nil { + wr := io.Discard + if cl.stderrTo != nil { + wr = cl.stderrTo + } + + go func() { + defer func() { + if closer, ok := wr.(io.Closer); ok { + if err := closer.Close(); err != nil { + cl.conn.disconnect(fmt.Errorf("error closing stderrTo: %w", err)) + } + } + }() + + if _, err := io.Copy(wr, stderr); err != nil { + cl.conn.disconnect(fmt.Errorf("error copying stderr: %w", err)) + } + }() + } + exts, err := cl.conn.handshake(ctx, cl.maxPacket) if err != nil { return nil, err @@ -605,7 +669,6 @@ func NewClientPipe(ctx context.Context, rd io.Reader, wr io.WriteCloser, opts .. cl.exts = exts cl.conn.resPool = sync.NewWorkPool[result](cl.maxInflight) - cl.conn.bufPool = sync.NewSlicePool[[]byte](cl.maxInflight, int(cl.maxPacket)) cl.conn.pktPool = sync.NewPool[sshfx.RawPacket](cl.maxInflight) diff --git a/localfs/localfs_integration_test.go b/localfs/localfs_integration_test.go index 38b7cfc..edc4159 100644 --- a/localfs/localfs_integration_test.go +++ b/localfs/localfs_integration_test.go @@ -213,6 +213,7 @@ func TestMain(m *testing.M) { func withOpenSSHImpl(m *testing.M) error { sftpServerLocations := []string{ + "/usr/libexec/ssh/sftp-server", "/usr/libexec/sftp-server", "/usr/lib/openssh/sftp-server", "/usr/lib/ssh/sftp-server",