mirror of https://github.com/pkg/sftp.git
fix ssh subsystem request invocation
This commit is contained in:
parent
b71b525cc4
commit
7cfa3d4785
95
client.go
95
client.go
|
@ -34,8 +34,9 @@ type clientConn struct {
|
|||
reqid atomic.Uint32
|
||||
rd io.Reader
|
||||
|
||||
resPool *sync.WorkPool[result]
|
||||
wait func() error
|
||||
|
||||
resPool *sync.WorkPool[result]
|
||||
bufPool *sync.SlicePool[[]byte, byte]
|
||||
pktPool *sync.Pool[sshfx.RawPacket]
|
||||
|
||||
|
@ -107,6 +108,26 @@ func (c *clientConn) getChan(reqid uint32) (chan<- result, bool) {
|
|||
|
||||
func (c *clientConn) Wait() error {
|
||||
<-c.closed
|
||||
|
||||
if c.wait == nil {
|
||||
return c.err
|
||||
}
|
||||
|
||||
if err := c.wait(); err != nil {
|
||||
|
||||
// TODO: when when https://github.com/golang/go/issues/35025 is fixed,
|
||||
// we can remove this if block entirely.
|
||||
// Right now, it's always going to return this, so it is not useful.
|
||||
// But we have this code here so that as soon as the ssh library is updated,
|
||||
// we can return a possibly more useful error.
|
||||
if err.Error() == "ssh: session not started" {
|
||||
return c.err
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// c.wait returned no error; so let's return something maybe more useful.
|
||||
return c.err
|
||||
}
|
||||
|
||||
|
@ -114,7 +135,8 @@ func (c *clientConn) disconnect(err error) {
|
|||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.err = err
|
||||
c.err = cmp.Or(c.err, err)
|
||||
|
||||
select {
|
||||
case <-c.closed:
|
||||
// already closed
|
||||
|
@ -365,11 +387,20 @@ func WithMaxPacketLength(length int) ClientOption {
|
|||
}
|
||||
}
|
||||
|
||||
// CopyStderrTo specifies a writer to which the standard error of the remote sftp-server command should be written.
|
||||
func CopyStderrTo(wr io.WriteCloser) ClientOption {
|
||||
return func(cl *Client) error {
|
||||
cl.stderrTo = wr
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Client represents an SFTP session on a *ssh.ClientConn SSH connection.
|
||||
// Multiple clients can be active on a single SSH connection,
|
||||
// and a client may be called concurrently from multiple goroutines.
|
||||
type Client struct {
|
||||
conn clientConn
|
||||
stderrTo io.WriteCloser
|
||||
|
||||
maxPacket uint32
|
||||
maxDataLen int
|
||||
|
@ -554,24 +585,30 @@ func NewClient(ctx context.Context, conn *ssh.Client, opts ...ClientOption) (*Cl
|
|||
return nil, err
|
||||
}
|
||||
|
||||
pw, err := s.StdinPipe()
|
||||
if err != nil {
|
||||
s.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pr, err := s.StdoutPipe()
|
||||
if err != nil {
|
||||
s.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
perr, err := s.StderrPipe()
|
||||
if err != nil {
|
||||
s.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.RequestSubsystem("sftp"); err != nil {
|
||||
s.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
w, err := s.StdinPipe()
|
||||
if err != nil {
|
||||
s.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r, err := s.StdoutPipe()
|
||||
if err != nil {
|
||||
s.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewClientPipe(ctx, r, w, opts...)
|
||||
return newClientPipe(ctx, pr, perr, pw, s.Wait, opts)
|
||||
}
|
||||
|
||||
// NewClientPipe creates a new SFTP client given a Reader and WriteCloser.
|
||||
|
@ -579,10 +616,16 @@ func NewClient(ctx context.Context, conn *ssh.Client, opts ...ClientOption) (*Cl
|
|||
//
|
||||
// The given context is only used for the negotiation of init and version packets.
|
||||
func NewClientPipe(ctx context.Context, rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Client, error) {
|
||||
return newClientPipe(ctx, rd, nil, wr, nil, opts)
|
||||
}
|
||||
|
||||
func newClientPipe(ctx context.Context, rd, stderr io.Reader, wr io.WriteCloser, wait func() error, opts []ClientOption) (*Client, error) {
|
||||
cl := &Client{
|
||||
conn: clientConn{
|
||||
rd: rd,
|
||||
wr: wr,
|
||||
|
||||
wait: wait,
|
||||
closed: make(chan struct{}),
|
||||
},
|
||||
|
||||
|
@ -597,6 +640,27 @@ func NewClientPipe(ctx context.Context, rd io.Reader, wr io.WriteCloser, opts ..
|
|||
}
|
||||
}
|
||||
|
||||
if stderr != nil {
|
||||
wr := io.Discard
|
||||
if cl.stderrTo != nil {
|
||||
wr = cl.stderrTo
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if closer, ok := wr.(io.Closer); ok {
|
||||
if err := closer.Close(); err != nil {
|
||||
cl.conn.disconnect(fmt.Errorf("error closing stderrTo: %w", err))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := io.Copy(wr, stderr); err != nil {
|
||||
cl.conn.disconnect(fmt.Errorf("error copying stderr: %w", err))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
exts, err := cl.conn.handshake(ctx, cl.maxPacket)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -605,7 +669,6 @@ func NewClientPipe(ctx context.Context, rd io.Reader, wr io.WriteCloser, opts ..
|
|||
cl.exts = exts
|
||||
|
||||
cl.conn.resPool = sync.NewWorkPool[result](cl.maxInflight)
|
||||
|
||||
cl.conn.bufPool = sync.NewSlicePool[[]byte](cl.maxInflight, int(cl.maxPacket))
|
||||
cl.conn.pktPool = sync.NewPool[sshfx.RawPacket](cl.maxInflight)
|
||||
|
||||
|
|
|
@ -213,6 +213,7 @@ func TestMain(m *testing.M) {
|
|||
|
||||
func withOpenSSHImpl(m *testing.M) error {
|
||||
sftpServerLocations := []string{
|
||||
"/usr/libexec/ssh/sftp-server",
|
||||
"/usr/libexec/sftp-server",
|
||||
"/usr/lib/openssh/sftp-server",
|
||||
"/usr/lib/ssh/sftp-server",
|
||||
|
|
Loading…
Reference in New Issue