fix ssh subsystem request invocation

This commit is contained in:
Cassondra Foesch 2025-05-23 17:34:44 +00:00
parent b71b525cc4
commit 7cfa3d4785
2 changed files with 83 additions and 19 deletions

101
client.go
View File

@ -34,8 +34,9 @@ type clientConn struct {
reqid atomic.Uint32
rd io.Reader
resPool *sync.WorkPool[result]
wait func() error
resPool *sync.WorkPool[result]
bufPool *sync.SlicePool[[]byte, byte]
pktPool *sync.Pool[sshfx.RawPacket]
@ -107,6 +108,26 @@ func (c *clientConn) getChan(reqid uint32) (chan<- result, bool) {
func (c *clientConn) Wait() error {
<-c.closed
if c.wait == nil {
return c.err
}
if err := c.wait(); err != nil {
// TODO: when when https://github.com/golang/go/issues/35025 is fixed,
// we can remove this if block entirely.
// Right now, it's always going to return this, so it is not useful.
// But we have this code here so that as soon as the ssh library is updated,
// we can return a possibly more useful error.
if err.Error() == "ssh: session not started" {
return c.err
}
return err
}
// c.wait returned no error; so let's return something maybe more useful.
return c.err
}
@ -114,7 +135,8 @@ func (c *clientConn) disconnect(err error) {
c.mu.Lock()
defer c.mu.Unlock()
c.err = err
c.err = cmp.Or(c.err, err)
select {
case <-c.closed:
// already closed
@ -365,11 +387,20 @@ func WithMaxPacketLength(length int) ClientOption {
}
}
// CopyStderrTo specifies a writer to which the standard error of the remote sftp-server command should be written.
func CopyStderrTo(wr io.WriteCloser) ClientOption {
return func(cl *Client) error {
cl.stderrTo = wr
return nil
}
}
// Client represents an SFTP session on a *ssh.ClientConn SSH connection.
// Multiple clients can be active on a single SSH connection,
// and a client may be called concurrently from multiple goroutines.
type Client struct {
conn clientConn
conn clientConn
stderrTo io.WriteCloser
maxPacket uint32
maxDataLen int
@ -554,24 +585,30 @@ func NewClient(ctx context.Context, conn *ssh.Client, opts ...ClientOption) (*Cl
return nil, err
}
pw, err := s.StdinPipe()
if err != nil {
s.Close()
return nil, err
}
pr, err := s.StdoutPipe()
if err != nil {
s.Close()
return nil, err
}
perr, err := s.StderrPipe()
if err != nil {
s.Close()
return nil, err
}
if err := s.RequestSubsystem("sftp"); err != nil {
s.Close()
return nil, err
}
w, err := s.StdinPipe()
if err != nil {
s.Close()
return nil, err
}
r, err := s.StdoutPipe()
if err != nil {
s.Close()
return nil, err
}
return NewClientPipe(ctx, r, w, opts...)
return newClientPipe(ctx, pr, perr, pw, s.Wait, opts)
}
// NewClientPipe creates a new SFTP client given a Reader and WriteCloser.
@ -579,10 +616,16 @@ func NewClient(ctx context.Context, conn *ssh.Client, opts ...ClientOption) (*Cl
//
// The given context is only used for the negotiation of init and version packets.
func NewClientPipe(ctx context.Context, rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Client, error) {
return newClientPipe(ctx, rd, nil, wr, nil, opts)
}
func newClientPipe(ctx context.Context, rd, stderr io.Reader, wr io.WriteCloser, wait func() error, opts []ClientOption) (*Client, error) {
cl := &Client{
conn: clientConn{
rd: rd,
wr: wr,
rd: rd,
wr: wr,
wait: wait,
closed: make(chan struct{}),
},
@ -597,6 +640,27 @@ func NewClientPipe(ctx context.Context, rd io.Reader, wr io.WriteCloser, opts ..
}
}
if stderr != nil {
wr := io.Discard
if cl.stderrTo != nil {
wr = cl.stderrTo
}
go func() {
defer func() {
if closer, ok := wr.(io.Closer); ok {
if err := closer.Close(); err != nil {
cl.conn.disconnect(fmt.Errorf("error closing stderrTo: %w", err))
}
}
}()
if _, err := io.Copy(wr, stderr); err != nil {
cl.conn.disconnect(fmt.Errorf("error copying stderr: %w", err))
}
}()
}
exts, err := cl.conn.handshake(ctx, cl.maxPacket)
if err != nil {
return nil, err
@ -605,7 +669,6 @@ func NewClientPipe(ctx context.Context, rd io.Reader, wr io.WriteCloser, opts ..
cl.exts = exts
cl.conn.resPool = sync.NewWorkPool[result](cl.maxInflight)
cl.conn.bufPool = sync.NewSlicePool[[]byte](cl.maxInflight, int(cl.maxPacket))
cl.conn.pktPool = sync.NewPool[sshfx.RawPacket](cl.maxInflight)

View File

@ -213,6 +213,7 @@ func TestMain(m *testing.M) {
func withOpenSSHImpl(m *testing.M) error {
sftpServerLocations := []string{
"/usr/libexec/ssh/sftp-server",
"/usr/libexec/sftp-server",
"/usr/lib/openssh/sftp-server",
"/usr/lib/ssh/sftp-server",