From 18192955ef8b3b9553923e85ba7edeac63929857 Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Fri, 23 May 2025 03:04:57 +0000 Subject: [PATCH 1/7] fix ssh subsystem request invocation --- client.go | 78 +++++++++++++++++++++++++++++++++++++++++++++---------- conn.go | 19 ++++++++++++++ 2 files changed, 84 insertions(+), 13 deletions(-) diff --git a/client.go b/client.go index 0d4d9a9..430344e 100644 --- a/client.go +++ b/client.go @@ -158,6 +158,17 @@ func UseFstat(value bool) ClientOption { } } +// UseStderr is used to indicate that you intend to read from the standard error of the remote sftp-server command. +// This does not actually get or set the standard error, +// instead this simply prevents the standard error from being discarded. +// You will still need to call [Client.StderrPipe] to get the reader. +func UseStderr() ClientOption { + return func(c *Client) error { + c.useStderr = true + return nil + } +} + // Client represents an SFTP session on a *ssh.ClientConn SSH connection. // Multiple Clients can be active on a single SSH connection, and a Client // may be called concurrently from multiple Goroutines. @@ -166,6 +177,9 @@ func UseFstat(value bool) ClientOption { type Client struct { clientConn + stderr io.Reader + useStderr bool + ext map[string]string // Extensions (name -> data). maxPacket int // max packet size read or written. @@ -186,9 +200,7 @@ func NewClient(conn *ssh.Client, opts ...ClientOption) (*Client, error) { if err != nil { return nil, err } - if err := s.RequestSubsystem("sftp"); err != nil { - return nil, err - } + pw, err := s.StdinPipe() if err != nil { return nil, err @@ -197,15 +209,27 @@ func NewClient(conn *ssh.Client, opts ...ClientOption) (*Client, error) { if err != nil { return nil, err } + perr, err := s.StderrPipe() + if err != nil { + return nil, err + } - return NewClientPipe(pr, pw, opts...) + if err := s.RequestSubsystem("sftp"); err != nil { + return nil, err + } + + return newClientPipe(pr, pw, perr, s.Wait, opts...) } // NewClientPipe creates a new SFTP client given a Reader and a WriteCloser. // This can be used for connecting to an SFTP server over TCP/TLS or by using // the system's ssh client program (e.g. via exec.Command). func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Client, error) { - sftp := &Client{ + return newClientPipe(rd, wr, nil, nil, opts...) +} + +func newClientPipe(rd io.Reader, wr io.WriteCloser, stderr io.Reader, wait func() error, opts ...ClientOption) (*Client, error) { + c := &Client{ clientConn: clientConn{ conn: conn{ Reader: rd, @@ -213,6 +237,7 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie }, inflight: make(map[uint32]chan<- result), closed: make(chan struct{}), + wait: wait, }, ext: make(map[string]string), @@ -222,32 +247,59 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie } for _, opt := range opts { - if err := opt(sftp); err != nil { + if err := opt(c); err != nil { wr.Close() return nil, err } } - if err := sftp.sendInit(); err != nil { + if stderr != nil { + if !c.useStderr { + go func() { + _, err := io.Copy(io.Discard, stderr) + if err != nil { + debug("error discarding stderr: %v", err) + } + }() + + } else { + // Only set c.stderr when we're not discarding it. + c.stderr = stderr + } + } + + if err := c.sendInit(); err != nil { wr.Close() return nil, fmt.Errorf("error sending init packet to server: %w", err) } - if err := sftp.recvVersion(); err != nil { + if err := c.recvVersion(); err != nil { wr.Close() return nil, fmt.Errorf("error receiving version packet from server: %w", err) } - sftp.clientConn.wg.Add(1) + c.clientConn.wg.Add(1) go func() { - defer sftp.clientConn.wg.Done() + defer c.clientConn.wg.Done() - if err := sftp.clientConn.recv(); err != nil { - sftp.clientConn.broadcastErr(err) + if err := c.clientConn.recv(); err != nil { + c.clientConn.broadcastErr(err) } }() - return sftp, nil + return c, nil +} + +// StderrPipe returns a reader for the standard error of the remote sftp-server command. +// You must have passed in the `UseStderr` client option or the standard error will already be set up to be discarded. +// An error returned here does not mean that the client is no longer useable, +// it only means that you won't be able to read the standard error output from the remote command. +func (c *Client) StderrPipe() (io.Reader, error) { + if c.stderr == nil { + return nil, fmt.Errorf("stderr not available") + } + + return c.stderr, nil } // Create creates the named file mode 0666 (before umask), truncating it if it diff --git a/conn.go b/conn.go index 93bc37b..a51da89 100644 --- a/conn.go +++ b/conn.go @@ -43,6 +43,8 @@ type clientConn struct { conn wg sync.WaitGroup + wait func() error // if non-nil, call this during Wait() to get a possible remote status error. + sync.Mutex // protects inflight inflight map[uint32]chan<- result // outstanding requests @@ -55,6 +57,23 @@ type clientConn struct { // goroutines. func (c *clientConn) Wait() error { <-c.closed + if c.wait != nil { + if err := c.wait(); err != nil { + + // TODO: when https://github.com/golang/go/issues/35025 is fixed, + // we can remove this if block entirely. + // Right now, it’s always going to return this, so it is not useful. + // But we have this code here so that as soon as the ssh library is updated, + // we can return a possibly more useful error. + if err.Error() == "ssh: session not started" { + return c.err + } + + // We intentionally override the c.err error here, + // it will probably be io.UnexpectedEOF in this case anyways. + return err + } + } return c.err } From 32bfbbb6c0bcf82bb0452bff2c2d0bd5c93a40b8 Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Fri, 23 May 2025 10:18:29 +0000 Subject: [PATCH 2/7] I think I prefer this API design --- client.go | 48 +++++++++++++++++++----------------------------- 1 file changed, 19 insertions(+), 29 deletions(-) diff --git a/client.go b/client.go index 430344e..dd8d72c 100644 --- a/client.go +++ b/client.go @@ -158,13 +158,10 @@ func UseFstat(value bool) ClientOption { } } -// UseStderr is used to indicate that you intend to read from the standard error of the remote sftp-server command. -// This does not actually get or set the standard error, -// instead this simply prevents the standard error from being discarded. -// You will still need to call [Client.StderrPipe] to get the reader. -func UseStderr() ClientOption { +// CopyStderrTo specifies a writer to which the standard error of the remote sftp-server command should be written. +func CopyStderrTo(wr io.WriteCloser) ClientOption { return func(c *Client) error { - c.useStderr = true + c.stderrTo = wr return nil } } @@ -177,8 +174,7 @@ func UseStderr() ClientOption { type Client struct { clientConn - stderr io.Reader - useStderr bool + stderrTo io.WriteCloser ext map[string]string // Extensions (name -> data). @@ -254,18 +250,24 @@ func newClientPipe(rd io.Reader, wr io.WriteCloser, stderr io.Reader, wait func( } if stderr != nil { - if !c.useStderr { - go func() { - _, err := io.Copy(io.Discard, stderr) - if err != nil { - debug("error discarding stderr: %v", err) + wr := io.Discard + if c.stderrTo != nil { + wr = c.stderrTo + } + + go func() { + defer func() { + if closer, ok := wr.(io.Closer); ok { + if err := closer.Close(); err != nil { + debug("error closing stderrTo: %v", err) + } } }() - } else { - // Only set c.stderr when we're not discarding it. - c.stderr = stderr - } + if _, err := io.Copy(wr, stderr); err != nil { + debug("error copying stderr: %v", err) + } + }() } if err := c.sendInit(); err != nil { @@ -290,18 +292,6 @@ func newClientPipe(rd io.Reader, wr io.WriteCloser, stderr io.Reader, wait func( return c, nil } -// StderrPipe returns a reader for the standard error of the remote sftp-server command. -// You must have passed in the `UseStderr` client option or the standard error will already be set up to be discarded. -// An error returned here does not mean that the client is no longer useable, -// it only means that you won't be able to read the standard error output from the remote command. -func (c *Client) StderrPipe() (io.Reader, error) { - if c.stderr == nil { - return nil, fmt.Errorf("stderr not available") - } - - return c.stderr, nil -} - // Create creates the named file mode 0666 (before umask), truncating it if it // already exists. If successful, methods on the returned File can be used for // I/O; the associated file descriptor has mode O_RDWR. If you need more From f1b135a6f5c7d0e273eadc2e7ebc7513d55b2b13 Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Fri, 23 May 2025 10:25:23 +0000 Subject: [PATCH 3/7] invert if-blocks to reduce indention levels --- conn.go | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/conn.go b/conn.go index a51da89..911e330 100644 --- a/conn.go +++ b/conn.go @@ -57,23 +57,27 @@ type clientConn struct { // goroutines. func (c *clientConn) Wait() error { <-c.closed - if c.wait != nil { - if err := c.wait(); err != nil { - // TODO: when https://github.com/golang/go/issues/35025 is fixed, - // we can remove this if block entirely. - // Right now, it’s always going to return this, so it is not useful. - // But we have this code here so that as soon as the ssh library is updated, - // we can return a possibly more useful error. - if err.Error() == "ssh: session not started" { - return c.err - } - - // We intentionally override the c.err error here, - // it will probably be io.UnexpectedEOF in this case anyways. - return err - } + if c.wait == nil { + // Only return this error if c.wait won't return something more useful. + return c.err } + + if err := c.wait(); err != nil { + + // TODO: when https://github.com/golang/go/issues/35025 is fixed, + // we can remove this if block entirely. + // Right now, it’s always going to return this, so it is not useful. + // But we have this code here so that as soon as the ssh library is updated, + // we can return a possibly more useful error. + if err.Error() == "ssh: session not started" { + return c.err + } + + return err + } + + // c.wait returned no error; so, let's return something maybe more useful. return c.err } From 8a0fc6568b76bf34d030a4ac5a6e15c50d337124 Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Fri, 30 May 2025 12:50:06 +0000 Subject: [PATCH 4/7] DO NOT close the CopyStderrTo writer --- client.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/client.go b/client.go index dd8d72c..307a35e 100644 --- a/client.go +++ b/client.go @@ -159,7 +159,10 @@ func UseFstat(value bool) ClientOption { } // CopyStderrTo specifies a writer to which the standard error of the remote sftp-server command should be written. -func CopyStderrTo(wr io.WriteCloser) ClientOption { +// +// The writer passed in will not be automatically closed. +// It is the responsibility of the caller to coordinate closure of any writers. +func CopyStderrTo(wr io.Writer) ClientOption { return func(c *Client) error { c.stderrTo = wr return nil @@ -174,7 +177,7 @@ func CopyStderrTo(wr io.WriteCloser) ClientOption { type Client struct { clientConn - stderrTo io.WriteCloser + stderrTo io.Writer ext map[string]string // Extensions (name -> data). @@ -214,17 +217,17 @@ func NewClient(conn *ssh.Client, opts ...ClientOption) (*Client, error) { return nil, err } - return newClientPipe(pr, pw, perr, s.Wait, opts...) + return newClientPipe(pr, perr, pw, s.Wait, opts...) } // NewClientPipe creates a new SFTP client given a Reader and a WriteCloser. // This can be used for connecting to an SFTP server over TCP/TLS or by using // the system's ssh client program (e.g. via exec.Command). func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Client, error) { - return newClientPipe(rd, wr, nil, nil, opts...) + return newClientPipe(rd, nil, wr, nil, opts...) } -func newClientPipe(rd io.Reader, wr io.WriteCloser, stderr io.Reader, wait func() error, opts ...ClientOption) (*Client, error) { +func newClientPipe(rd, stderr io.Reader, wr io.WriteCloser, wait func() error, opts ...ClientOption) (*Client, error) { c := &Client{ clientConn: clientConn{ conn: conn{ @@ -256,13 +259,10 @@ func newClientPipe(rd io.Reader, wr io.WriteCloser, stderr io.Reader, wait func( } go func() { - defer func() { - if closer, ok := wr.(io.Closer); ok { - if err := closer.Close(); err != nil { - debug("error closing stderrTo: %v", err) - } - } - }() + // DO NOT close the writer! + // Programs may pass in `os.Stderr` to write the remote stderr to, + // and the program may continue after disconnect by reconnecting. + // But if we've closed their stderr, then we just messed everything up. if _, err := io.Copy(wr, stderr); err != nil { debug("error copying stderr: %v", err) From 9ae47f41703d52f79a148606705c430f4989d1f5 Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Fri, 30 May 2025 13:23:15 +0000 Subject: [PATCH 5/7] better debug info --- packet.go | 33 +++++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/packet.go b/packet.go index bfe6a3c..480d0e7 100644 --- a/packet.go +++ b/packet.go @@ -311,9 +311,16 @@ func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (uint8, []byte, e } else { b = make([]byte, 4) } - if _, err := io.ReadFull(r, b[:4]); err != nil { + + if n, err := io.ReadFull(r, b[:4]); err != nil { + debug("recv length %d of %d bytes: err %v", n, 4, err) + if n > 0 { + debug("recv length error: bytes %x", b[:n]) + } + return 0, nil, err } + length, _ := unmarshalUint32(b) if length > maxMsgLength { debug("recv packet %d bytes too long", length) @@ -323,24 +330,38 @@ func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (uint8, []byte, e debug("recv packet of 0 bytes too short") return 0, nil, errShortPacket } + if alloc == nil { b = make([]byte, length) } - if _, err := io.ReadFull(r, b[:length]); err != nil { + + if n, err := io.ReadFull(r, b[:length]); err != nil { + // Log this error message _before_ we potentially override it. + debug("recv packet %d of %d bytes: err %v", n, length, err) + // ReadFull only returns EOF if it has read no bytes. // In this case, that means a partial packet, and thus unexpected. if err == io.EOF { err = io.ErrUnexpectedEOF } - debug("recv packet %d bytes: err %v", length, err) + + if n > 0 { + n := min(32, n) // limit bytes dump to 32-bytes. + debug("recv packet error: bytes %x", b[:n]) + } + return 0, nil, err } + + typ, payload := fxp(b[0]), b[1:length] + if debugDumpRxPacketBytes { - debug("recv packet: %s %d bytes %x", fxp(b[0]), length, b[1:length]) + debug("recv packet: %s %d bytes %x", typ, length, payload) } else if debugDumpRxPacket { - debug("recv packet: %s %d bytes", fxp(b[0]), length) + debug("recv packet: %s %d bytes", typ, length) } - return b[0], b[1:length], nil + + return typ, payload, nil } type extensionPair struct { From d9ce3caa7238eac4673a96c43428c38cd69e014c Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Fri, 30 May 2025 13:29:21 +0000 Subject: [PATCH 6/7] convert uses of uint8 instead of fxp to fxp --- conn.go | 6 +++--- packet.go | 2 +- packet_test.go | 2 +- request-server.go | 4 ++-- server.go | 4 ++-- sftp.go | 8 ++++---- 6 files changed, 13 insertions(+), 13 deletions(-) diff --git a/conn.go b/conn.go index 911e330..e68a2bd 100644 --- a/conn.go +++ b/conn.go @@ -22,7 +22,7 @@ type conn struct { // For the client mode just pass 0. // It returns io.EOF if the connection is closed and // there are no more packets to read. -func (c *conn) recvPacket(orderID uint32) (uint8, []byte, error) { +func (c *conn) recvPacket(orderID uint32) (fxp, []byte, error) { return recvPacket(c, c.alloc, orderID) } @@ -142,7 +142,7 @@ func (c *clientConn) getChannel(sid uint32) (chan<- result, bool) { // result captures the result of receiving the a packet from the server type result struct { - typ byte + typ fxp data []byte err error } @@ -152,7 +152,7 @@ type idmarshaler interface { encoding.BinaryMarshaler } -func (c *clientConn) sendPacket(ctx context.Context, ch chan result, p idmarshaler) (byte, []byte, error) { +func (c *clientConn) sendPacket(ctx context.Context, ch chan result, p idmarshaler) (fxp, []byte, error) { if cap(ch) < 1 { ch = make(chan result, 1) } diff --git a/packet.go b/packet.go index 480d0e7..93fb8a4 100644 --- a/packet.go +++ b/packet.go @@ -304,7 +304,7 @@ func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error { return nil } -func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (uint8, []byte, error) { +func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (fxp, []byte, error) { var b []byte if alloc != nil { b = alloc.GetPage(orderID) diff --git a/packet_test.go b/packet_test.go index 98455ab..59a1b21 100644 --- a/packet_test.go +++ b/packet_test.go @@ -468,7 +468,7 @@ func TestRecvPacket(t *testing.T) { var recvPacketTests = []struct { b []byte - want uint8 + want fxp body []byte wantErr error }{ diff --git a/request-server.go b/request-server.go index 11047e6..08c24d7 100644 --- a/request-server.go +++ b/request-server.go @@ -148,7 +148,7 @@ func (rs *RequestServer) serveLoop(pktChan chan<- orderedRequest) error { var err error var pkt requestPacket - var pktType uint8 + var pktType fxp var pktBytes []byte for { @@ -158,7 +158,7 @@ func (rs *RequestServer) serveLoop(pktChan chan<- orderedRequest) error { return err } - pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes}) + pkt, err = makePacket(rxPacket{pktType, pktBytes}) if err != nil { switch { case errors.Is(err, errUnknownExtendedPacket): diff --git a/server.go b/server.go index cd656d8..7735c42 100644 --- a/server.go +++ b/server.go @@ -390,7 +390,7 @@ func (svr *Server) Serve() error { var err error var pkt requestPacket - var pktType uint8 + var pktType fxp var pktBytes []byte for { pktType, pktBytes, err = svr.serverConn.recvPacket(svr.pktMgr.getNextOrderID()) @@ -403,7 +403,7 @@ func (svr *Server) Serve() error { break } - pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes}) + pkt, err = makePacket(rxPacket{pktType, pktBytes}) if err != nil { switch { case errors.Is(err, errUnknownExtendedPacket): diff --git a/sftp.go b/sftp.go index 778c8f3..1e698bb 100644 --- a/sftp.go +++ b/sftp.go @@ -184,15 +184,15 @@ func (f fx) String() string { } type unexpectedPacketErr struct { - want, got uint8 + want, got fxp } func (u *unexpectedPacketErr) Error() string { - return fmt.Sprintf("sftp: unexpected packet: want %v, got %v", fxp(u.want), fxp(u.got)) + return fmt.Sprintf("sftp: unexpected packet: want %v, got %v", u.want, u.got) } -func unimplementedPacketErr(u uint8) error { - return fmt.Errorf("sftp: unimplemented packet type: got %v", fxp(u)) +func unimplementedPacketErr(u fxp) error { + return fmt.Errorf("sftp: unimplemented packet type: got %v", u) } type unexpectedIDErr struct{ want, got uint32 } From c7176b3c6ee097f642f9c9a2dfb65a8df2ad5ad2 Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Fri, 30 May 2025 14:00:15 +0000 Subject: [PATCH 7/7] rework recv debug messages to be context-added errors --- packet.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/packet.go b/packet.go index 93fb8a4..3836ab6 100644 --- a/packet.go +++ b/packet.go @@ -313,12 +313,11 @@ func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (fxp, []byte, err } if n, err := io.ReadFull(r, b[:4]); err != nil { - debug("recv length %d of %d bytes: err %v", n, 4, err) - if n > 0 { - debug("recv length error: bytes %x", b[:n]) + if err == io.EOF { + return 0, nil, err } - return 0, nil, err + return 0, nil, fmt.Errorf("error reading packet length: %d of 4: %w", n, err) } length, _ := unmarshalUint32(b) @@ -335,9 +334,11 @@ func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (fxp, []byte, err b = make([]byte, length) } - if n, err := io.ReadFull(r, b[:length]); err != nil { - // Log this error message _before_ we potentially override it. - debug("recv packet %d of %d bytes: err %v", n, length, err) + n, err := io.ReadFull(r, b[:length]) + b = b[:n] + + if err != nil { + debug("recv packet error: %d of %d bytes: %x", n, length, b) // ReadFull only returns EOF if it has read no bytes. // In this case, that means a partial packet, and thus unexpected. @@ -345,15 +346,14 @@ func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (fxp, []byte, err err = io.ErrUnexpectedEOF } - if n > 0 { - n := min(32, n) // limit bytes dump to 32-bytes. - debug("recv packet error: bytes %x", b[:n]) + if n == 0 { + return 0, nil, fmt.Errorf("error reading packet body: %d of %d: %w", n, length, err) } - return 0, nil, err + return 0, nil, fmt.Errorf("error reading packet body: %d of %d: (%s) %w", n, length, fxp(b[0]), err) } - typ, payload := fxp(b[0]), b[1:length] + typ, payload := fxp(b[0]), b[1:n] if debugDumpRxPacketBytes { debug("recv packet: %s %d bytes %x", typ, length, payload)