Compare commits

...

3 Commits

Author SHA1 Message Date
Cassondra Foesch f1b135a6f5 invert if-blocks to reduce indention levels 2025-05-23 10:25:23 +00:00
Cassondra Foesch 32bfbbb6c0 I think I prefer this API design 2025-05-23 10:18:29 +00:00
Cassondra Foesch 18192955ef fix ssh subsystem request invocation 2025-05-23 09:21:31 +00:00
2 changed files with 78 additions and 13 deletions

View File

@ -158,6 +158,14 @@ func UseFstat(value bool) ClientOption {
} }
} }
// CopyStderrTo specifies a writer to which the standard error of the remote sftp-server command should be written.
func CopyStderrTo(wr io.WriteCloser) ClientOption {
return func(c *Client) error {
c.stderrTo = wr
return nil
}
}
// Client represents an SFTP session on a *ssh.ClientConn SSH connection. // Client represents an SFTP session on a *ssh.ClientConn SSH connection.
// Multiple Clients can be active on a single SSH connection, and a Client // Multiple Clients can be active on a single SSH connection, and a Client
// may be called concurrently from multiple Goroutines. // may be called concurrently from multiple Goroutines.
@ -166,6 +174,8 @@ func UseFstat(value bool) ClientOption {
type Client struct { type Client struct {
clientConn clientConn
stderrTo io.WriteCloser
ext map[string]string // Extensions (name -> data). ext map[string]string // Extensions (name -> data).
maxPacket int // max packet size read or written. maxPacket int // max packet size read or written.
@ -186,9 +196,7 @@ func NewClient(conn *ssh.Client, opts ...ClientOption) (*Client, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := s.RequestSubsystem("sftp"); err != nil {
return nil, err
}
pw, err := s.StdinPipe() pw, err := s.StdinPipe()
if err != nil { if err != nil {
return nil, err return nil, err
@ -197,15 +205,27 @@ func NewClient(conn *ssh.Client, opts ...ClientOption) (*Client, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
perr, err := s.StderrPipe()
if err != nil {
return nil, err
}
return NewClientPipe(pr, pw, opts...) if err := s.RequestSubsystem("sftp"); err != nil {
return nil, err
}
return newClientPipe(pr, pw, perr, s.Wait, opts...)
} }
// NewClientPipe creates a new SFTP client given a Reader and a WriteCloser. // NewClientPipe creates a new SFTP client given a Reader and a WriteCloser.
// This can be used for connecting to an SFTP server over TCP/TLS or by using // This can be used for connecting to an SFTP server over TCP/TLS or by using
// the system's ssh client program (e.g. via exec.Command). // the system's ssh client program (e.g. via exec.Command).
func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Client, error) { func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Client, error) {
sftp := &Client{ return newClientPipe(rd, wr, nil, nil, opts...)
}
func newClientPipe(rd io.Reader, wr io.WriteCloser, stderr io.Reader, wait func() error, opts ...ClientOption) (*Client, error) {
c := &Client{
clientConn: clientConn{ clientConn: clientConn{
conn: conn{ conn: conn{
Reader: rd, Reader: rd,
@ -213,6 +233,7 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie
}, },
inflight: make(map[uint32]chan<- result), inflight: make(map[uint32]chan<- result),
closed: make(chan struct{}), closed: make(chan struct{}),
wait: wait,
}, },
ext: make(map[string]string), ext: make(map[string]string),
@ -222,32 +243,53 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie
} }
for _, opt := range opts { for _, opt := range opts {
if err := opt(sftp); err != nil { if err := opt(c); err != nil {
wr.Close() wr.Close()
return nil, err return nil, err
} }
} }
if err := sftp.sendInit(); err != nil { if stderr != nil {
wr := io.Discard
if c.stderrTo != nil {
wr = c.stderrTo
}
go func() {
defer func() {
if closer, ok := wr.(io.Closer); ok {
if err := closer.Close(); err != nil {
debug("error closing stderrTo: %v", err)
}
}
}()
if _, err := io.Copy(wr, stderr); err != nil {
debug("error copying stderr: %v", err)
}
}()
}
if err := c.sendInit(); err != nil {
wr.Close() wr.Close()
return nil, fmt.Errorf("error sending init packet to server: %w", err) return nil, fmt.Errorf("error sending init packet to server: %w", err)
} }
if err := sftp.recvVersion(); err != nil { if err := c.recvVersion(); err != nil {
wr.Close() wr.Close()
return nil, fmt.Errorf("error receiving version packet from server: %w", err) return nil, fmt.Errorf("error receiving version packet from server: %w", err)
} }
sftp.clientConn.wg.Add(1) c.clientConn.wg.Add(1)
go func() { go func() {
defer sftp.clientConn.wg.Done() defer c.clientConn.wg.Done()
if err := sftp.clientConn.recv(); err != nil { if err := c.clientConn.recv(); err != nil {
sftp.clientConn.broadcastErr(err) c.clientConn.broadcastErr(err)
} }
}() }()
return sftp, nil return c, nil
} }
// Create creates the named file mode 0666 (before umask), truncating it if it // Create creates the named file mode 0666 (before umask), truncating it if it

23
conn.go
View File

@ -43,6 +43,8 @@ type clientConn struct {
conn conn
wg sync.WaitGroup wg sync.WaitGroup
wait func() error // if non-nil, call this during Wait() to get a possible remote status error.
sync.Mutex // protects inflight sync.Mutex // protects inflight
inflight map[uint32]chan<- result // outstanding requests inflight map[uint32]chan<- result // outstanding requests
@ -55,6 +57,27 @@ type clientConn struct {
// goroutines. // goroutines.
func (c *clientConn) Wait() error { func (c *clientConn) Wait() error {
<-c.closed <-c.closed
if c.wait == nil {
// Only return this error if c.wait won't return something more useful.
return c.err
}
if err := c.wait(); err != nil {
// TODO: when https://github.com/golang/go/issues/35025 is fixed,
// we can remove this if block entirely.
// Right now, its always going to return this, so it is not useful.
// But we have this code here so that as soon as the ssh library is updated,
// we can return a possibly more useful error.
if err.Error() == "ssh: session not started" {
return c.err
}
return err
}
// c.wait returned no error; so, let's return something maybe more useful.
return c.err return c.err
} }