mirror of https://github.com/pkg/sftp.git
fix ssh subsystem request invocation
This commit is contained in:
parent
36ce9cfd5c
commit
18192955ef
78
client.go
78
client.go
|
@ -158,6 +158,17 @@ func UseFstat(value bool) ClientOption {
|
|||
}
|
||||
}
|
||||
|
||||
// UseStderr is used to indicate that you intend to read from the standard error of the remote sftp-server command.
|
||||
// This does not actually get or set the standard error,
|
||||
// instead this simply prevents the standard error from being discarded.
|
||||
// You will still need to call [Client.StderrPipe] to get the reader.
|
||||
func UseStderr() ClientOption {
|
||||
return func(c *Client) error {
|
||||
c.useStderr = true
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Client represents an SFTP session on a *ssh.ClientConn SSH connection.
|
||||
// Multiple Clients can be active on a single SSH connection, and a Client
|
||||
// may be called concurrently from multiple Goroutines.
|
||||
|
@ -166,6 +177,9 @@ func UseFstat(value bool) ClientOption {
|
|||
type Client struct {
|
||||
clientConn
|
||||
|
||||
stderr io.Reader
|
||||
useStderr bool
|
||||
|
||||
ext map[string]string // Extensions (name -> data).
|
||||
|
||||
maxPacket int // max packet size read or written.
|
||||
|
@ -186,9 +200,7 @@ func NewClient(conn *ssh.Client, opts ...ClientOption) (*Client, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.RequestSubsystem("sftp"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pw, err := s.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -197,15 +209,27 @@ func NewClient(conn *ssh.Client, opts ...ClientOption) (*Client, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
perr, err := s.StderrPipe()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewClientPipe(pr, pw, opts...)
|
||||
if err := s.RequestSubsystem("sftp"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newClientPipe(pr, pw, perr, s.Wait, opts...)
|
||||
}
|
||||
|
||||
// NewClientPipe creates a new SFTP client given a Reader and a WriteCloser.
|
||||
// This can be used for connecting to an SFTP server over TCP/TLS or by using
|
||||
// the system's ssh client program (e.g. via exec.Command).
|
||||
func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Client, error) {
|
||||
sftp := &Client{
|
||||
return newClientPipe(rd, wr, nil, nil, opts...)
|
||||
}
|
||||
|
||||
func newClientPipe(rd io.Reader, wr io.WriteCloser, stderr io.Reader, wait func() error, opts ...ClientOption) (*Client, error) {
|
||||
c := &Client{
|
||||
clientConn: clientConn{
|
||||
conn: conn{
|
||||
Reader: rd,
|
||||
|
@ -213,6 +237,7 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie
|
|||
},
|
||||
inflight: make(map[uint32]chan<- result),
|
||||
closed: make(chan struct{}),
|
||||
wait: wait,
|
||||
},
|
||||
|
||||
ext: make(map[string]string),
|
||||
|
@ -222,32 +247,59 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie
|
|||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
if err := opt(sftp); err != nil {
|
||||
if err := opt(c); err != nil {
|
||||
wr.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := sftp.sendInit(); err != nil {
|
||||
if stderr != nil {
|
||||
if !c.useStderr {
|
||||
go func() {
|
||||
_, err := io.Copy(io.Discard, stderr)
|
||||
if err != nil {
|
||||
debug("error discarding stderr: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
} else {
|
||||
// Only set c.stderr when we're not discarding it.
|
||||
c.stderr = stderr
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.sendInit(); err != nil {
|
||||
wr.Close()
|
||||
return nil, fmt.Errorf("error sending init packet to server: %w", err)
|
||||
}
|
||||
|
||||
if err := sftp.recvVersion(); err != nil {
|
||||
if err := c.recvVersion(); err != nil {
|
||||
wr.Close()
|
||||
return nil, fmt.Errorf("error receiving version packet from server: %w", err)
|
||||
}
|
||||
|
||||
sftp.clientConn.wg.Add(1)
|
||||
c.clientConn.wg.Add(1)
|
||||
go func() {
|
||||
defer sftp.clientConn.wg.Done()
|
||||
defer c.clientConn.wg.Done()
|
||||
|
||||
if err := sftp.clientConn.recv(); err != nil {
|
||||
sftp.clientConn.broadcastErr(err)
|
||||
if err := c.clientConn.recv(); err != nil {
|
||||
c.clientConn.broadcastErr(err)
|
||||
}
|
||||
}()
|
||||
|
||||
return sftp, nil
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// StderrPipe returns a reader for the standard error of the remote sftp-server command.
|
||||
// You must have passed in the `UseStderr` client option or the standard error will already be set up to be discarded.
|
||||
// An error returned here does not mean that the client is no longer useable,
|
||||
// it only means that you won't be able to read the standard error output from the remote command.
|
||||
func (c *Client) StderrPipe() (io.Reader, error) {
|
||||
if c.stderr == nil {
|
||||
return nil, fmt.Errorf("stderr not available")
|
||||
}
|
||||
|
||||
return c.stderr, nil
|
||||
}
|
||||
|
||||
// Create creates the named file mode 0666 (before umask), truncating it if it
|
||||
|
|
19
conn.go
19
conn.go
|
@ -43,6 +43,8 @@ type clientConn struct {
|
|||
conn
|
||||
wg sync.WaitGroup
|
||||
|
||||
wait func() error // if non-nil, call this during Wait() to get a possible remote status error.
|
||||
|
||||
sync.Mutex // protects inflight
|
||||
inflight map[uint32]chan<- result // outstanding requests
|
||||
|
||||
|
@ -55,6 +57,23 @@ type clientConn struct {
|
|||
// goroutines.
|
||||
func (c *clientConn) Wait() error {
|
||||
<-c.closed
|
||||
if c.wait != nil {
|
||||
if err := c.wait(); err != nil {
|
||||
|
||||
// TODO: when https://github.com/golang/go/issues/35025 is fixed,
|
||||
// we can remove this if block entirely.
|
||||
// Right now, it’s always going to return this, so it is not useful.
|
||||
// But we have this code here so that as soon as the ssh library is updated,
|
||||
// we can return a possibly more useful error.
|
||||
if err.Error() == "ssh: session not started" {
|
||||
return c.err
|
||||
}
|
||||
|
||||
// We intentionally override the c.err error here,
|
||||
// it will probably be io.UnexpectedEOF in this case anyways.
|
||||
return err
|
||||
}
|
||||
}
|
||||
return c.err
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue