diff --git a/client.go b/client.go index 0d4d9a9..430344e 100644 --- a/client.go +++ b/client.go @@ -158,6 +158,17 @@ func UseFstat(value bool) ClientOption { } } +// UseStderr is used to indicate that you intend to read from the standard error of the remote sftp-server command. +// This does not actually get or set the standard error, +// instead this simply prevents the standard error from being discarded. +// You will still need to call [Client.StderrPipe] to get the reader. +func UseStderr() ClientOption { + return func(c *Client) error { + c.useStderr = true + return nil + } +} + // Client represents an SFTP session on a *ssh.ClientConn SSH connection. // Multiple Clients can be active on a single SSH connection, and a Client // may be called concurrently from multiple Goroutines. @@ -166,6 +177,9 @@ func UseFstat(value bool) ClientOption { type Client struct { clientConn + stderr io.Reader + useStderr bool + ext map[string]string // Extensions (name -> data). maxPacket int // max packet size read or written. @@ -186,9 +200,7 @@ func NewClient(conn *ssh.Client, opts ...ClientOption) (*Client, error) { if err != nil { return nil, err } - if err := s.RequestSubsystem("sftp"); err != nil { - return nil, err - } + pw, err := s.StdinPipe() if err != nil { return nil, err @@ -197,15 +209,27 @@ func NewClient(conn *ssh.Client, opts ...ClientOption) (*Client, error) { if err != nil { return nil, err } + perr, err := s.StderrPipe() + if err != nil { + return nil, err + } - return NewClientPipe(pr, pw, opts...) + if err := s.RequestSubsystem("sftp"); err != nil { + return nil, err + } + + return newClientPipe(pr, pw, perr, s.Wait, opts...) } // NewClientPipe creates a new SFTP client given a Reader and a WriteCloser. // This can be used for connecting to an SFTP server over TCP/TLS or by using // the system's ssh client program (e.g. via exec.Command). func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Client, error) { - sftp := &Client{ + return newClientPipe(rd, wr, nil, nil, opts...) +} + +func newClientPipe(rd io.Reader, wr io.WriteCloser, stderr io.Reader, wait func() error, opts ...ClientOption) (*Client, error) { + c := &Client{ clientConn: clientConn{ conn: conn{ Reader: rd, @@ -213,6 +237,7 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie }, inflight: make(map[uint32]chan<- result), closed: make(chan struct{}), + wait: wait, }, ext: make(map[string]string), @@ -222,32 +247,59 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie } for _, opt := range opts { - if err := opt(sftp); err != nil { + if err := opt(c); err != nil { wr.Close() return nil, err } } - if err := sftp.sendInit(); err != nil { + if stderr != nil { + if !c.useStderr { + go func() { + _, err := io.Copy(io.Discard, stderr) + if err != nil { + debug("error discarding stderr: %v", err) + } + }() + + } else { + // Only set c.stderr when we're not discarding it. + c.stderr = stderr + } + } + + if err := c.sendInit(); err != nil { wr.Close() return nil, fmt.Errorf("error sending init packet to server: %w", err) } - if err := sftp.recvVersion(); err != nil { + if err := c.recvVersion(); err != nil { wr.Close() return nil, fmt.Errorf("error receiving version packet from server: %w", err) } - sftp.clientConn.wg.Add(1) + c.clientConn.wg.Add(1) go func() { - defer sftp.clientConn.wg.Done() + defer c.clientConn.wg.Done() - if err := sftp.clientConn.recv(); err != nil { - sftp.clientConn.broadcastErr(err) + if err := c.clientConn.recv(); err != nil { + c.clientConn.broadcastErr(err) } }() - return sftp, nil + return c, nil +} + +// StderrPipe returns a reader for the standard error of the remote sftp-server command. +// You must have passed in the `UseStderr` client option or the standard error will already be set up to be discarded. +// An error returned here does not mean that the client is no longer useable, +// it only means that you won't be able to read the standard error output from the remote command. +func (c *Client) StderrPipe() (io.Reader, error) { + if c.stderr == nil { + return nil, fmt.Errorf("stderr not available") + } + + return c.stderr, nil } // Create creates the named file mode 0666 (before umask), truncating it if it diff --git a/conn.go b/conn.go index 93bc37b..a51da89 100644 --- a/conn.go +++ b/conn.go @@ -43,6 +43,8 @@ type clientConn struct { conn wg sync.WaitGroup + wait func() error // if non-nil, call this during Wait() to get a possible remote status error. + sync.Mutex // protects inflight inflight map[uint32]chan<- result // outstanding requests @@ -55,6 +57,23 @@ type clientConn struct { // goroutines. func (c *clientConn) Wait() error { <-c.closed + if c.wait != nil { + if err := c.wait(); err != nil { + + // TODO: when https://github.com/golang/go/issues/35025 is fixed, + // we can remove this if block entirely. + // Right now, it’s always going to return this, so it is not useful. + // But we have this code here so that as soon as the ssh library is updated, + // we can return a possibly more useful error. + if err.Error() == "ssh: session not started" { + return c.err + } + + // We intentionally override the c.err error here, + // it will probably be io.UnexpectedEOF in this case anyways. + return err + } + } return c.err }