Merge pull request #632 from pkg/fix-ssh-client-use
CI / Run test cases (1.23, macos-latest) (push) Has been cancelled Details
CI / Run test cases (1.23, ubuntu-latest) (push) Has been cancelled Details
CI / Run test cases (1.24, ubuntu-latest) (push) Has been cancelled Details

Fix SSH subsystemrequest usage
This commit is contained in:
Cassondra Foesch 2025-07-13 18:13:11 +00:00 committed by GitHub
commit 53c62f1551
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 120 additions and 34 deletions

View File

@ -158,6 +158,17 @@ func UseFstat(value bool) ClientOption {
} }
} }
// CopyStderrTo specifies a writer to which the standard error of the remote sftp-server command should be written.
//
// The writer passed in will not be automatically closed.
// It is the responsibility of the caller to coordinate closure of any writers.
func CopyStderrTo(wr io.Writer) ClientOption {
return func(c *Client) error {
c.stderrTo = wr
return nil
}
}
// Client represents an SFTP session on a *ssh.ClientConn SSH connection. // Client represents an SFTP session on a *ssh.ClientConn SSH connection.
// Multiple Clients can be active on a single SSH connection, and a Client // Multiple Clients can be active on a single SSH connection, and a Client
// may be called concurrently from multiple Goroutines. // may be called concurrently from multiple Goroutines.
@ -166,6 +177,8 @@ func UseFstat(value bool) ClientOption {
type Client struct { type Client struct {
clientConn clientConn
stderrTo io.Writer
ext map[string]string // Extensions (name -> data). ext map[string]string // Extensions (name -> data).
maxPacket int // max packet size read or written. maxPacket int // max packet size read or written.
@ -186,9 +199,7 @@ func NewClient(conn *ssh.Client, opts ...ClientOption) (*Client, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := s.RequestSubsystem("sftp"); err != nil {
return nil, err
}
pw, err := s.StdinPipe() pw, err := s.StdinPipe()
if err != nil { if err != nil {
return nil, err return nil, err
@ -197,15 +208,27 @@ func NewClient(conn *ssh.Client, opts ...ClientOption) (*Client, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
perr, err := s.StderrPipe()
if err != nil {
return nil, err
}
return NewClientPipe(pr, pw, opts...) if err := s.RequestSubsystem("sftp"); err != nil {
return nil, err
}
return newClientPipe(pr, perr, pw, s.Wait, opts...)
} }
// NewClientPipe creates a new SFTP client given a Reader and a WriteCloser. // NewClientPipe creates a new SFTP client given a Reader and a WriteCloser.
// This can be used for connecting to an SFTP server over TCP/TLS or by using // This can be used for connecting to an SFTP server over TCP/TLS or by using
// the system's ssh client program (e.g. via exec.Command). // the system's ssh client program (e.g. via exec.Command).
func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Client, error) { func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Client, error) {
sftp := &Client{ return newClientPipe(rd, nil, wr, nil, opts...)
}
func newClientPipe(rd, stderr io.Reader, wr io.WriteCloser, wait func() error, opts ...ClientOption) (*Client, error) {
c := &Client{
clientConn: clientConn{ clientConn: clientConn{
conn: conn{ conn: conn{
Reader: rd, Reader: rd,
@ -213,6 +236,7 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie
}, },
inflight: make(map[uint32]chan<- result), inflight: make(map[uint32]chan<- result),
closed: make(chan struct{}), closed: make(chan struct{}),
wait: wait,
}, },
ext: make(map[string]string), ext: make(map[string]string),
@ -222,32 +246,50 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie
} }
for _, opt := range opts { for _, opt := range opts {
if err := opt(sftp); err != nil { if err := opt(c); err != nil {
wr.Close() wr.Close()
return nil, err return nil, err
} }
} }
if err := sftp.sendInit(); err != nil { if stderr != nil {
wr := io.Discard
if c.stderrTo != nil {
wr = c.stderrTo
}
go func() {
// DO NOT close the writer!
// Programs may pass in `os.Stderr` to write the remote stderr to,
// and the program may continue after disconnect by reconnecting.
// But if we've closed their stderr, then we just messed everything up.
if _, err := io.Copy(wr, stderr); err != nil {
debug("error copying stderr: %v", err)
}
}()
}
if err := c.sendInit(); err != nil {
wr.Close() wr.Close()
return nil, fmt.Errorf("error sending init packet to server: %w", err) return nil, fmt.Errorf("error sending init packet to server: %w", err)
} }
if err := sftp.recvVersion(); err != nil { if err := c.recvVersion(); err != nil {
wr.Close() wr.Close()
return nil, fmt.Errorf("error receiving version packet from server: %w", err) return nil, fmt.Errorf("error receiving version packet from server: %w", err)
} }
sftp.clientConn.wg.Add(1) c.clientConn.wg.Add(1)
go func() { go func() {
defer sftp.clientConn.wg.Done() defer c.clientConn.wg.Done()
if err := sftp.clientConn.recv(); err != nil { if err := c.clientConn.recv(); err != nil {
sftp.clientConn.broadcastErr(err) c.clientConn.broadcastErr(err)
} }
}() }()
return sftp, nil return c, nil
} }
// Create creates the named file mode 0666 (before umask), truncating it if it // Create creates the named file mode 0666 (before umask), truncating it if it

29
conn.go
View File

@ -22,7 +22,7 @@ type conn struct {
// For the client mode just pass 0. // For the client mode just pass 0.
// It returns io.EOF if the connection is closed and // It returns io.EOF if the connection is closed and
// there are no more packets to read. // there are no more packets to read.
func (c *conn) recvPacket(orderID uint32) (uint8, []byte, error) { func (c *conn) recvPacket(orderID uint32) (fxp, []byte, error) {
return recvPacket(c, c.alloc, orderID) return recvPacket(c, c.alloc, orderID)
} }
@ -43,6 +43,8 @@ type clientConn struct {
conn conn
wg sync.WaitGroup wg sync.WaitGroup
wait func() error // if non-nil, call this during Wait() to get a possible remote status error.
sync.Mutex // protects inflight sync.Mutex // protects inflight
inflight map[uint32]chan<- result // outstanding requests inflight map[uint32]chan<- result // outstanding requests
@ -55,6 +57,27 @@ type clientConn struct {
// goroutines. // goroutines.
func (c *clientConn) Wait() error { func (c *clientConn) Wait() error {
<-c.closed <-c.closed
if c.wait == nil {
// Only return this error if c.wait won't return something more useful.
return c.err
}
if err := c.wait(); err != nil {
// TODO: when https://github.com/golang/go/issues/35025 is fixed,
// we can remove this if block entirely.
// Right now, its always going to return this, so it is not useful.
// But we have this code here so that as soon as the ssh library is updated,
// we can return a possibly more useful error.
if err.Error() == "ssh: session not started" {
return c.err
}
return err
}
// c.wait returned no error; so, let's return something maybe more useful.
return c.err return c.err
} }
@ -119,7 +142,7 @@ func (c *clientConn) getChannel(sid uint32) (chan<- result, bool) {
// result captures the result of receiving the a packet from the server // result captures the result of receiving the a packet from the server
type result struct { type result struct {
typ byte typ fxp
data []byte data []byte
err error err error
} }
@ -129,7 +152,7 @@ type idmarshaler interface {
encoding.BinaryMarshaler encoding.BinaryMarshaler
} }
func (c *clientConn) sendPacket(ctx context.Context, ch chan result, p idmarshaler) (byte, []byte, error) { func (c *clientConn) sendPacket(ctx context.Context, ch chan result, p idmarshaler) (fxp, []byte, error) {
if cap(ch) < 1 { if cap(ch) < 1 {
ch = make(chan result, 1) ch = make(chan result, 1)
} }

View File

@ -304,16 +304,22 @@ func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error {
return nil return nil
} }
func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (uint8, []byte, error) { func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (fxp, []byte, error) {
var b []byte var b []byte
if alloc != nil { if alloc != nil {
b = alloc.GetPage(orderID) b = alloc.GetPage(orderID)
} else { } else {
b = make([]byte, 4) b = make([]byte, 4)
} }
if _, err := io.ReadFull(r, b[:4]); err != nil {
if n, err := io.ReadFull(r, b[:4]); err != nil {
if err == io.EOF {
return 0, nil, err return 0, nil, err
} }
return 0, nil, fmt.Errorf("error reading packet length: %d of 4: %w", n, err)
}
length, _ := unmarshalUint32(b) length, _ := unmarshalUint32(b)
if length > maxMsgLength { if length > maxMsgLength {
debug("recv packet %d bytes too long", length) debug("recv packet %d bytes too long", length)
@ -323,24 +329,39 @@ func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (uint8, []byte, e
debug("recv packet of 0 bytes too short") debug("recv packet of 0 bytes too short")
return 0, nil, errShortPacket return 0, nil, errShortPacket
} }
if alloc == nil { if alloc == nil {
b = make([]byte, length) b = make([]byte, length)
} }
if _, err := io.ReadFull(r, b[:length]); err != nil {
n, err := io.ReadFull(r, b[:length])
b = b[:n]
if err != nil {
debug("recv packet error: %d of %d bytes: %x", n, length, b)
// ReadFull only returns EOF if it has read no bytes. // ReadFull only returns EOF if it has read no bytes.
// In this case, that means a partial packet, and thus unexpected. // In this case, that means a partial packet, and thus unexpected.
if err == io.EOF { if err == io.EOF {
err = io.ErrUnexpectedEOF err = io.ErrUnexpectedEOF
} }
debug("recv packet %d bytes: err %v", length, err)
return 0, nil, err if n == 0 {
return 0, nil, fmt.Errorf("error reading packet body: %d of %d: %w", n, length, err)
} }
return 0, nil, fmt.Errorf("error reading packet body: %d of %d: (%s) %w", n, length, fxp(b[0]), err)
}
typ, payload := fxp(b[0]), b[1:n]
if debugDumpRxPacketBytes { if debugDumpRxPacketBytes {
debug("recv packet: %s %d bytes %x", fxp(b[0]), length, b[1:length]) debug("recv packet: %s %d bytes %x", typ, length, payload)
} else if debugDumpRxPacket { } else if debugDumpRxPacket {
debug("recv packet: %s %d bytes", fxp(b[0]), length) debug("recv packet: %s %d bytes", typ, length)
} }
return b[0], b[1:length], nil
return typ, payload, nil
} }
type extensionPair struct { type extensionPair struct {

View File

@ -468,7 +468,7 @@ func TestRecvPacket(t *testing.T) {
var recvPacketTests = []struct { var recvPacketTests = []struct {
b []byte b []byte
want uint8 want fxp
body []byte body []byte
wantErr error wantErr error
}{ }{

View File

@ -148,7 +148,7 @@ func (rs *RequestServer) serveLoop(pktChan chan<- orderedRequest) error {
var err error var err error
var pkt requestPacket var pkt requestPacket
var pktType uint8 var pktType fxp
var pktBytes []byte var pktBytes []byte
for { for {
@ -158,7 +158,7 @@ func (rs *RequestServer) serveLoop(pktChan chan<- orderedRequest) error {
return err return err
} }
pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes}) pkt, err = makePacket(rxPacket{pktType, pktBytes})
if err != nil { if err != nil {
switch { switch {
case errors.Is(err, errUnknownExtendedPacket): case errors.Is(err, errUnknownExtendedPacket):

View File

@ -390,7 +390,7 @@ func (svr *Server) Serve() error {
var err error var err error
var pkt requestPacket var pkt requestPacket
var pktType uint8 var pktType fxp
var pktBytes []byte var pktBytes []byte
for { for {
pktType, pktBytes, err = svr.serverConn.recvPacket(svr.pktMgr.getNextOrderID()) pktType, pktBytes, err = svr.serverConn.recvPacket(svr.pktMgr.getNextOrderID())
@ -403,7 +403,7 @@ func (svr *Server) Serve() error {
break break
} }
pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes}) pkt, err = makePacket(rxPacket{pktType, pktBytes})
if err != nil { if err != nil {
switch { switch {
case errors.Is(err, errUnknownExtendedPacket): case errors.Is(err, errUnknownExtendedPacket):

View File

@ -184,15 +184,15 @@ func (f fx) String() string {
} }
type unexpectedPacketErr struct { type unexpectedPacketErr struct {
want, got uint8 want, got fxp
} }
func (u *unexpectedPacketErr) Error() string { func (u *unexpectedPacketErr) Error() string {
return fmt.Sprintf("sftp: unexpected packet: want %v, got %v", fxp(u.want), fxp(u.got)) return fmt.Sprintf("sftp: unexpected packet: want %v, got %v", u.want, u.got)
} }
func unimplementedPacketErr(u uint8) error { func unimplementedPacketErr(u fxp) error {
return fmt.Errorf("sftp: unimplemented packet type: got %v", fxp(u)) return fmt.Errorf("sftp: unimplemented packet type: got %v", u)
} }
type unexpectedIDErr struct{ want, got uint32 } type unexpectedIDErr struct{ want, got uint32 }