Compare commits

...

4 Commits

Author SHA1 Message Date
Cassondra Foesch c7176b3c6e rework recv debug messages to be context-added errors 2025-05-30 14:00:15 +00:00
Cassondra Foesch d9ce3caa72 convert uses of uint8 instead of fxp to fxp 2025-05-30 13:30:46 +00:00
Cassondra Foesch 9ae47f4170 better debug info 2025-05-30 13:23:15 +00:00
Cassondra Foesch 8a0fc6568b DO NOT close the CopyStderrTo writer 2025-05-30 12:50:06 +00:00
7 changed files with 54 additions and 33 deletions

View File

@ -159,7 +159,10 @@ func UseFstat(value bool) ClientOption {
}
// CopyStderrTo specifies a writer to which the standard error of the remote sftp-server command should be written.
func CopyStderrTo(wr io.WriteCloser) ClientOption {
//
// The writer passed in will not be automatically closed.
// It is the responsibility of the caller to coordinate closure of any writers.
func CopyStderrTo(wr io.Writer) ClientOption {
return func(c *Client) error {
c.stderrTo = wr
return nil
@ -174,7 +177,7 @@ func CopyStderrTo(wr io.WriteCloser) ClientOption {
type Client struct {
clientConn
stderrTo io.WriteCloser
stderrTo io.Writer
ext map[string]string // Extensions (name -> data).
@ -214,17 +217,17 @@ func NewClient(conn *ssh.Client, opts ...ClientOption) (*Client, error) {
return nil, err
}
return newClientPipe(pr, pw, perr, s.Wait, opts...)
return newClientPipe(pr, perr, pw, s.Wait, opts...)
}
// NewClientPipe creates a new SFTP client given a Reader and a WriteCloser.
// This can be used for connecting to an SFTP server over TCP/TLS or by using
// the system's ssh client program (e.g. via exec.Command).
func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Client, error) {
return newClientPipe(rd, wr, nil, nil, opts...)
return newClientPipe(rd, nil, wr, nil, opts...)
}
func newClientPipe(rd io.Reader, wr io.WriteCloser, stderr io.Reader, wait func() error, opts ...ClientOption) (*Client, error) {
func newClientPipe(rd, stderr io.Reader, wr io.WriteCloser, wait func() error, opts ...ClientOption) (*Client, error) {
c := &Client{
clientConn: clientConn{
conn: conn{
@ -256,13 +259,10 @@ func newClientPipe(rd io.Reader, wr io.WriteCloser, stderr io.Reader, wait func(
}
go func() {
defer func() {
if closer, ok := wr.(io.Closer); ok {
if err := closer.Close(); err != nil {
debug("error closing stderrTo: %v", err)
}
}
}()
// DO NOT close the writer!
// Programs may pass in `os.Stderr` to write the remote stderr to,
// and the program may continue after disconnect by reconnecting.
// But if we've closed their stderr, then we just messed everything up.
if _, err := io.Copy(wr, stderr); err != nil {
debug("error copying stderr: %v", err)

View File

@ -22,7 +22,7 @@ type conn struct {
// For the client mode just pass 0.
// It returns io.EOF if the connection is closed and
// there are no more packets to read.
func (c *conn) recvPacket(orderID uint32) (uint8, []byte, error) {
func (c *conn) recvPacket(orderID uint32) (fxp, []byte, error) {
return recvPacket(c, c.alloc, orderID)
}
@ -142,7 +142,7 @@ func (c *clientConn) getChannel(sid uint32) (chan<- result, bool) {
// result captures the result of receiving the a packet from the server
type result struct {
typ byte
typ fxp
data []byte
err error
}
@ -152,7 +152,7 @@ type idmarshaler interface {
encoding.BinaryMarshaler
}
func (c *clientConn) sendPacket(ctx context.Context, ch chan result, p idmarshaler) (byte, []byte, error) {
func (c *clientConn) sendPacket(ctx context.Context, ch chan result, p idmarshaler) (fxp, []byte, error) {
if cap(ch) < 1 {
ch = make(chan result, 1)
}

View File

@ -304,16 +304,22 @@ func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error {
return nil
}
func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (uint8, []byte, error) {
func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (fxp, []byte, error) {
var b []byte
if alloc != nil {
b = alloc.GetPage(orderID)
} else {
b = make([]byte, 4)
}
if _, err := io.ReadFull(r, b[:4]); err != nil {
return 0, nil, err
if n, err := io.ReadFull(r, b[:4]); err != nil {
if err == io.EOF {
return 0, nil, err
}
return 0, nil, fmt.Errorf("error reading packet length: %d of 4: %w", n, err)
}
length, _ := unmarshalUint32(b)
if length > maxMsgLength {
debug("recv packet %d bytes too long", length)
@ -323,24 +329,39 @@ func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (uint8, []byte, e
debug("recv packet of 0 bytes too short")
return 0, nil, errShortPacket
}
if alloc == nil {
b = make([]byte, length)
}
if _, err := io.ReadFull(r, b[:length]); err != nil {
n, err := io.ReadFull(r, b[:length])
b = b[:n]
if err != nil {
debug("recv packet error: %d of %d bytes: %x", n, length, b)
// ReadFull only returns EOF if it has read no bytes.
// In this case, that means a partial packet, and thus unexpected.
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
debug("recv packet %d bytes: err %v", length, err)
return 0, nil, err
if n == 0 {
return 0, nil, fmt.Errorf("error reading packet body: %d of %d: %w", n, length, err)
}
return 0, nil, fmt.Errorf("error reading packet body: %d of %d: (%s) %w", n, length, fxp(b[0]), err)
}
typ, payload := fxp(b[0]), b[1:n]
if debugDumpRxPacketBytes {
debug("recv packet: %s %d bytes %x", fxp(b[0]), length, b[1:length])
debug("recv packet: %s %d bytes %x", typ, length, payload)
} else if debugDumpRxPacket {
debug("recv packet: %s %d bytes", fxp(b[0]), length)
debug("recv packet: %s %d bytes", typ, length)
}
return b[0], b[1:length], nil
return typ, payload, nil
}
type extensionPair struct {

View File

@ -468,7 +468,7 @@ func TestRecvPacket(t *testing.T) {
var recvPacketTests = []struct {
b []byte
want uint8
want fxp
body []byte
wantErr error
}{

View File

@ -148,7 +148,7 @@ func (rs *RequestServer) serveLoop(pktChan chan<- orderedRequest) error {
var err error
var pkt requestPacket
var pktType uint8
var pktType fxp
var pktBytes []byte
for {
@ -158,7 +158,7 @@ func (rs *RequestServer) serveLoop(pktChan chan<- orderedRequest) error {
return err
}
pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes})
pkt, err = makePacket(rxPacket{pktType, pktBytes})
if err != nil {
switch {
case errors.Is(err, errUnknownExtendedPacket):

View File

@ -390,7 +390,7 @@ func (svr *Server) Serve() error {
var err error
var pkt requestPacket
var pktType uint8
var pktType fxp
var pktBytes []byte
for {
pktType, pktBytes, err = svr.serverConn.recvPacket(svr.pktMgr.getNextOrderID())
@ -403,7 +403,7 @@ func (svr *Server) Serve() error {
break
}
pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes})
pkt, err = makePacket(rxPacket{pktType, pktBytes})
if err != nil {
switch {
case errors.Is(err, errUnknownExtendedPacket):

View File

@ -184,15 +184,15 @@ func (f fx) String() string {
}
type unexpectedPacketErr struct {
want, got uint8
want, got fxp
}
func (u *unexpectedPacketErr) Error() string {
return fmt.Sprintf("sftp: unexpected packet: want %v, got %v", fxp(u.want), fxp(u.got))
return fmt.Sprintf("sftp: unexpected packet: want %v, got %v", u.want, u.got)
}
func unimplementedPacketErr(u uint8) error {
return fmt.Errorf("sftp: unimplemented packet type: got %v", fxp(u))
func unimplementedPacketErr(u fxp) error {
return fmt.Errorf("sftp: unimplemented packet type: got %v", u)
}
type unexpectedIDErr struct{ want, got uint32 }