diff --git a/client.go b/client.go index 0d4d9a9..307a35e 100644 --- a/client.go +++ b/client.go @@ -158,6 +158,17 @@ func UseFstat(value bool) ClientOption { } } +// CopyStderrTo specifies a writer to which the standard error of the remote sftp-server command should be written. +// +// The writer passed in will not be automatically closed. +// It is the responsibility of the caller to coordinate closure of any writers. +func CopyStderrTo(wr io.Writer) ClientOption { + return func(c *Client) error { + c.stderrTo = wr + return nil + } +} + // Client represents an SFTP session on a *ssh.ClientConn SSH connection. // Multiple Clients can be active on a single SSH connection, and a Client // may be called concurrently from multiple Goroutines. @@ -166,6 +177,8 @@ func UseFstat(value bool) ClientOption { type Client struct { clientConn + stderrTo io.Writer + ext map[string]string // Extensions (name -> data). maxPacket int // max packet size read or written. @@ -186,9 +199,7 @@ func NewClient(conn *ssh.Client, opts ...ClientOption) (*Client, error) { if err != nil { return nil, err } - if err := s.RequestSubsystem("sftp"); err != nil { - return nil, err - } + pw, err := s.StdinPipe() if err != nil { return nil, err @@ -197,15 +208,27 @@ func NewClient(conn *ssh.Client, opts ...ClientOption) (*Client, error) { if err != nil { return nil, err } + perr, err := s.StderrPipe() + if err != nil { + return nil, err + } - return NewClientPipe(pr, pw, opts...) + if err := s.RequestSubsystem("sftp"); err != nil { + return nil, err + } + + return newClientPipe(pr, perr, pw, s.Wait, opts...) } // NewClientPipe creates a new SFTP client given a Reader and a WriteCloser. // This can be used for connecting to an SFTP server over TCP/TLS or by using // the system's ssh client program (e.g. via exec.Command). func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Client, error) { - sftp := &Client{ + return newClientPipe(rd, nil, wr, nil, opts...) +} + +func newClientPipe(rd, stderr io.Reader, wr io.WriteCloser, wait func() error, opts ...ClientOption) (*Client, error) { + c := &Client{ clientConn: clientConn{ conn: conn{ Reader: rd, @@ -213,6 +236,7 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie }, inflight: make(map[uint32]chan<- result), closed: make(chan struct{}), + wait: wait, }, ext: make(map[string]string), @@ -222,32 +246,50 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie } for _, opt := range opts { - if err := opt(sftp); err != nil { + if err := opt(c); err != nil { wr.Close() return nil, err } } - if err := sftp.sendInit(); err != nil { + if stderr != nil { + wr := io.Discard + if c.stderrTo != nil { + wr = c.stderrTo + } + + go func() { + // DO NOT close the writer! + // Programs may pass in `os.Stderr` to write the remote stderr to, + // and the program may continue after disconnect by reconnecting. + // But if we've closed their stderr, then we just messed everything up. + + if _, err := io.Copy(wr, stderr); err != nil { + debug("error copying stderr: %v", err) + } + }() + } + + if err := c.sendInit(); err != nil { wr.Close() return nil, fmt.Errorf("error sending init packet to server: %w", err) } - if err := sftp.recvVersion(); err != nil { + if err := c.recvVersion(); err != nil { wr.Close() return nil, fmt.Errorf("error receiving version packet from server: %w", err) } - sftp.clientConn.wg.Add(1) + c.clientConn.wg.Add(1) go func() { - defer sftp.clientConn.wg.Done() + defer c.clientConn.wg.Done() - if err := sftp.clientConn.recv(); err != nil { - sftp.clientConn.broadcastErr(err) + if err := c.clientConn.recv(); err != nil { + c.clientConn.broadcastErr(err) } }() - return sftp, nil + return c, nil } // Create creates the named file mode 0666 (before umask), truncating it if it diff --git a/conn.go b/conn.go index 93bc37b..e68a2bd 100644 --- a/conn.go +++ b/conn.go @@ -22,7 +22,7 @@ type conn struct { // For the client mode just pass 0. // It returns io.EOF if the connection is closed and // there are no more packets to read. -func (c *conn) recvPacket(orderID uint32) (uint8, []byte, error) { +func (c *conn) recvPacket(orderID uint32) (fxp, []byte, error) { return recvPacket(c, c.alloc, orderID) } @@ -43,6 +43,8 @@ type clientConn struct { conn wg sync.WaitGroup + wait func() error // if non-nil, call this during Wait() to get a possible remote status error. + sync.Mutex // protects inflight inflight map[uint32]chan<- result // outstanding requests @@ -55,6 +57,27 @@ type clientConn struct { // goroutines. func (c *clientConn) Wait() error { <-c.closed + + if c.wait == nil { + // Only return this error if c.wait won't return something more useful. + return c.err + } + + if err := c.wait(); err != nil { + + // TODO: when https://github.com/golang/go/issues/35025 is fixed, + // we can remove this if block entirely. + // Right now, it’s always going to return this, so it is not useful. + // But we have this code here so that as soon as the ssh library is updated, + // we can return a possibly more useful error. + if err.Error() == "ssh: session not started" { + return c.err + } + + return err + } + + // c.wait returned no error; so, let's return something maybe more useful. return c.err } @@ -119,7 +142,7 @@ func (c *clientConn) getChannel(sid uint32) (chan<- result, bool) { // result captures the result of receiving the a packet from the server type result struct { - typ byte + typ fxp data []byte err error } @@ -129,7 +152,7 @@ type idmarshaler interface { encoding.BinaryMarshaler } -func (c *clientConn) sendPacket(ctx context.Context, ch chan result, p idmarshaler) (byte, []byte, error) { +func (c *clientConn) sendPacket(ctx context.Context, ch chan result, p idmarshaler) (fxp, []byte, error) { if cap(ch) < 1 { ch = make(chan result, 1) } diff --git a/packet.go b/packet.go index bfe6a3c..3836ab6 100644 --- a/packet.go +++ b/packet.go @@ -304,16 +304,22 @@ func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error { return nil } -func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (uint8, []byte, error) { +func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (fxp, []byte, error) { var b []byte if alloc != nil { b = alloc.GetPage(orderID) } else { b = make([]byte, 4) } - if _, err := io.ReadFull(r, b[:4]); err != nil { - return 0, nil, err + + if n, err := io.ReadFull(r, b[:4]); err != nil { + if err == io.EOF { + return 0, nil, err + } + + return 0, nil, fmt.Errorf("error reading packet length: %d of 4: %w", n, err) } + length, _ := unmarshalUint32(b) if length > maxMsgLength { debug("recv packet %d bytes too long", length) @@ -323,24 +329,39 @@ func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (uint8, []byte, e debug("recv packet of 0 bytes too short") return 0, nil, errShortPacket } + if alloc == nil { b = make([]byte, length) } - if _, err := io.ReadFull(r, b[:length]); err != nil { + + n, err := io.ReadFull(r, b[:length]) + b = b[:n] + + if err != nil { + debug("recv packet error: %d of %d bytes: %x", n, length, b) + // ReadFull only returns EOF if it has read no bytes. // In this case, that means a partial packet, and thus unexpected. if err == io.EOF { err = io.ErrUnexpectedEOF } - debug("recv packet %d bytes: err %v", length, err) - return 0, nil, err + + if n == 0 { + return 0, nil, fmt.Errorf("error reading packet body: %d of %d: %w", n, length, err) + } + + return 0, nil, fmt.Errorf("error reading packet body: %d of %d: (%s) %w", n, length, fxp(b[0]), err) } + + typ, payload := fxp(b[0]), b[1:n] + if debugDumpRxPacketBytes { - debug("recv packet: %s %d bytes %x", fxp(b[0]), length, b[1:length]) + debug("recv packet: %s %d bytes %x", typ, length, payload) } else if debugDumpRxPacket { - debug("recv packet: %s %d bytes", fxp(b[0]), length) + debug("recv packet: %s %d bytes", typ, length) } - return b[0], b[1:length], nil + + return typ, payload, nil } type extensionPair struct { diff --git a/packet_test.go b/packet_test.go index 98455ab..59a1b21 100644 --- a/packet_test.go +++ b/packet_test.go @@ -468,7 +468,7 @@ func TestRecvPacket(t *testing.T) { var recvPacketTests = []struct { b []byte - want uint8 + want fxp body []byte wantErr error }{ diff --git a/request-server.go b/request-server.go index 11047e6..08c24d7 100644 --- a/request-server.go +++ b/request-server.go @@ -148,7 +148,7 @@ func (rs *RequestServer) serveLoop(pktChan chan<- orderedRequest) error { var err error var pkt requestPacket - var pktType uint8 + var pktType fxp var pktBytes []byte for { @@ -158,7 +158,7 @@ func (rs *RequestServer) serveLoop(pktChan chan<- orderedRequest) error { return err } - pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes}) + pkt, err = makePacket(rxPacket{pktType, pktBytes}) if err != nil { switch { case errors.Is(err, errUnknownExtendedPacket): diff --git a/server.go b/server.go index cd656d8..7735c42 100644 --- a/server.go +++ b/server.go @@ -390,7 +390,7 @@ func (svr *Server) Serve() error { var err error var pkt requestPacket - var pktType uint8 + var pktType fxp var pktBytes []byte for { pktType, pktBytes, err = svr.serverConn.recvPacket(svr.pktMgr.getNextOrderID()) @@ -403,7 +403,7 @@ func (svr *Server) Serve() error { break } - pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes}) + pkt, err = makePacket(rxPacket{pktType, pktBytes}) if err != nil { switch { case errors.Is(err, errUnknownExtendedPacket): diff --git a/sftp.go b/sftp.go index 778c8f3..1e698bb 100644 --- a/sftp.go +++ b/sftp.go @@ -184,15 +184,15 @@ func (f fx) String() string { } type unexpectedPacketErr struct { - want, got uint8 + want, got fxp } func (u *unexpectedPacketErr) Error() string { - return fmt.Sprintf("sftp: unexpected packet: want %v, got %v", fxp(u.want), fxp(u.got)) + return fmt.Sprintf("sftp: unexpected packet: want %v, got %v", u.want, u.got) } -func unimplementedPacketErr(u uint8) error { - return fmt.Errorf("sftp: unimplemented packet type: got %v", fxp(u)) +func unimplementedPacketErr(u fxp) error { + return fmt.Errorf("sftp: unimplemented packet type: got %v", u) } type unexpectedIDErr struct{ want, got uint32 }