mirror of https://github.com/pkg/sftp.git
Merge pull request #632 from pkg/fix-ssh-client-use
Fix SSH subsystemrequest usage
This commit is contained in:
commit
53c62f1551
68
client.go
68
client.go
|
@ -158,6 +158,17 @@ func UseFstat(value bool) ClientOption {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CopyStderrTo specifies a writer to which the standard error of the remote sftp-server command should be written.
|
||||||
|
//
|
||||||
|
// The writer passed in will not be automatically closed.
|
||||||
|
// It is the responsibility of the caller to coordinate closure of any writers.
|
||||||
|
func CopyStderrTo(wr io.Writer) ClientOption {
|
||||||
|
return func(c *Client) error {
|
||||||
|
c.stderrTo = wr
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Client represents an SFTP session on a *ssh.ClientConn SSH connection.
|
// Client represents an SFTP session on a *ssh.ClientConn SSH connection.
|
||||||
// Multiple Clients can be active on a single SSH connection, and a Client
|
// Multiple Clients can be active on a single SSH connection, and a Client
|
||||||
// may be called concurrently from multiple Goroutines.
|
// may be called concurrently from multiple Goroutines.
|
||||||
|
@ -166,6 +177,8 @@ func UseFstat(value bool) ClientOption {
|
||||||
type Client struct {
|
type Client struct {
|
||||||
clientConn
|
clientConn
|
||||||
|
|
||||||
|
stderrTo io.Writer
|
||||||
|
|
||||||
ext map[string]string // Extensions (name -> data).
|
ext map[string]string // Extensions (name -> data).
|
||||||
|
|
||||||
maxPacket int // max packet size read or written.
|
maxPacket int // max packet size read or written.
|
||||||
|
@ -186,9 +199,7 @@ func NewClient(conn *ssh.Client, opts ...ClientOption) (*Client, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := s.RequestSubsystem("sftp"); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
pw, err := s.StdinPipe()
|
pw, err := s.StdinPipe()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -197,15 +208,27 @@ func NewClient(conn *ssh.Client, opts ...ClientOption) (*Client, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
perr, err := s.StderrPipe()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return NewClientPipe(pr, pw, opts...)
|
if err := s.RequestSubsystem("sftp"); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return newClientPipe(pr, perr, pw, s.Wait, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClientPipe creates a new SFTP client given a Reader and a WriteCloser.
|
// NewClientPipe creates a new SFTP client given a Reader and a WriteCloser.
|
||||||
// This can be used for connecting to an SFTP server over TCP/TLS or by using
|
// This can be used for connecting to an SFTP server over TCP/TLS or by using
|
||||||
// the system's ssh client program (e.g. via exec.Command).
|
// the system's ssh client program (e.g. via exec.Command).
|
||||||
func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Client, error) {
|
func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Client, error) {
|
||||||
sftp := &Client{
|
return newClientPipe(rd, nil, wr, nil, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newClientPipe(rd, stderr io.Reader, wr io.WriteCloser, wait func() error, opts ...ClientOption) (*Client, error) {
|
||||||
|
c := &Client{
|
||||||
clientConn: clientConn{
|
clientConn: clientConn{
|
||||||
conn: conn{
|
conn: conn{
|
||||||
Reader: rd,
|
Reader: rd,
|
||||||
|
@ -213,6 +236,7 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie
|
||||||
},
|
},
|
||||||
inflight: make(map[uint32]chan<- result),
|
inflight: make(map[uint32]chan<- result),
|
||||||
closed: make(chan struct{}),
|
closed: make(chan struct{}),
|
||||||
|
wait: wait,
|
||||||
},
|
},
|
||||||
|
|
||||||
ext: make(map[string]string),
|
ext: make(map[string]string),
|
||||||
|
@ -222,32 +246,50 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
if err := opt(sftp); err != nil {
|
if err := opt(c); err != nil {
|
||||||
wr.Close()
|
wr.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := sftp.sendInit(); err != nil {
|
if stderr != nil {
|
||||||
|
wr := io.Discard
|
||||||
|
if c.stderrTo != nil {
|
||||||
|
wr = c.stderrTo
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
// DO NOT close the writer!
|
||||||
|
// Programs may pass in `os.Stderr` to write the remote stderr to,
|
||||||
|
// and the program may continue after disconnect by reconnecting.
|
||||||
|
// But if we've closed their stderr, then we just messed everything up.
|
||||||
|
|
||||||
|
if _, err := io.Copy(wr, stderr); err != nil {
|
||||||
|
debug("error copying stderr: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.sendInit(); err != nil {
|
||||||
wr.Close()
|
wr.Close()
|
||||||
return nil, fmt.Errorf("error sending init packet to server: %w", err)
|
return nil, fmt.Errorf("error sending init packet to server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := sftp.recvVersion(); err != nil {
|
if err := c.recvVersion(); err != nil {
|
||||||
wr.Close()
|
wr.Close()
|
||||||
return nil, fmt.Errorf("error receiving version packet from server: %w", err)
|
return nil, fmt.Errorf("error receiving version packet from server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
sftp.clientConn.wg.Add(1)
|
c.clientConn.wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer sftp.clientConn.wg.Done()
|
defer c.clientConn.wg.Done()
|
||||||
|
|
||||||
if err := sftp.clientConn.recv(); err != nil {
|
if err := c.clientConn.recv(); err != nil {
|
||||||
sftp.clientConn.broadcastErr(err)
|
c.clientConn.broadcastErr(err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return sftp, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create creates the named file mode 0666 (before umask), truncating it if it
|
// Create creates the named file mode 0666 (before umask), truncating it if it
|
||||||
|
|
29
conn.go
29
conn.go
|
@ -22,7 +22,7 @@ type conn struct {
|
||||||
// For the client mode just pass 0.
|
// For the client mode just pass 0.
|
||||||
// It returns io.EOF if the connection is closed and
|
// It returns io.EOF if the connection is closed and
|
||||||
// there are no more packets to read.
|
// there are no more packets to read.
|
||||||
func (c *conn) recvPacket(orderID uint32) (uint8, []byte, error) {
|
func (c *conn) recvPacket(orderID uint32) (fxp, []byte, error) {
|
||||||
return recvPacket(c, c.alloc, orderID)
|
return recvPacket(c, c.alloc, orderID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -43,6 +43,8 @@ type clientConn struct {
|
||||||
conn
|
conn
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
|
|
||||||
|
wait func() error // if non-nil, call this during Wait() to get a possible remote status error.
|
||||||
|
|
||||||
sync.Mutex // protects inflight
|
sync.Mutex // protects inflight
|
||||||
inflight map[uint32]chan<- result // outstanding requests
|
inflight map[uint32]chan<- result // outstanding requests
|
||||||
|
|
||||||
|
@ -55,6 +57,27 @@ type clientConn struct {
|
||||||
// goroutines.
|
// goroutines.
|
||||||
func (c *clientConn) Wait() error {
|
func (c *clientConn) Wait() error {
|
||||||
<-c.closed
|
<-c.closed
|
||||||
|
|
||||||
|
if c.wait == nil {
|
||||||
|
// Only return this error if c.wait won't return something more useful.
|
||||||
|
return c.err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.wait(); err != nil {
|
||||||
|
|
||||||
|
// TODO: when https://github.com/golang/go/issues/35025 is fixed,
|
||||||
|
// we can remove this if block entirely.
|
||||||
|
// Right now, it’s always going to return this, so it is not useful.
|
||||||
|
// But we have this code here so that as soon as the ssh library is updated,
|
||||||
|
// we can return a possibly more useful error.
|
||||||
|
if err.Error() == "ssh: session not started" {
|
||||||
|
return c.err
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// c.wait returned no error; so, let's return something maybe more useful.
|
||||||
return c.err
|
return c.err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -119,7 +142,7 @@ func (c *clientConn) getChannel(sid uint32) (chan<- result, bool) {
|
||||||
|
|
||||||
// result captures the result of receiving the a packet from the server
|
// result captures the result of receiving the a packet from the server
|
||||||
type result struct {
|
type result struct {
|
||||||
typ byte
|
typ fxp
|
||||||
data []byte
|
data []byte
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
@ -129,7 +152,7 @@ type idmarshaler interface {
|
||||||
encoding.BinaryMarshaler
|
encoding.BinaryMarshaler
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientConn) sendPacket(ctx context.Context, ch chan result, p idmarshaler) (byte, []byte, error) {
|
func (c *clientConn) sendPacket(ctx context.Context, ch chan result, p idmarshaler) (fxp, []byte, error) {
|
||||||
if cap(ch) < 1 {
|
if cap(ch) < 1 {
|
||||||
ch = make(chan result, 1)
|
ch = make(chan result, 1)
|
||||||
}
|
}
|
||||||
|
|
37
packet.go
37
packet.go
|
@ -304,16 +304,22 @@ func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (uint8, []byte, error) {
|
func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (fxp, []byte, error) {
|
||||||
var b []byte
|
var b []byte
|
||||||
if alloc != nil {
|
if alloc != nil {
|
||||||
b = alloc.GetPage(orderID)
|
b = alloc.GetPage(orderID)
|
||||||
} else {
|
} else {
|
||||||
b = make([]byte, 4)
|
b = make([]byte, 4)
|
||||||
}
|
}
|
||||||
if _, err := io.ReadFull(r, b[:4]); err != nil {
|
|
||||||
|
if n, err := io.ReadFull(r, b[:4]); err != nil {
|
||||||
|
if err == io.EOF {
|
||||||
return 0, nil, err
|
return 0, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return 0, nil, fmt.Errorf("error reading packet length: %d of 4: %w", n, err)
|
||||||
|
}
|
||||||
|
|
||||||
length, _ := unmarshalUint32(b)
|
length, _ := unmarshalUint32(b)
|
||||||
if length > maxMsgLength {
|
if length > maxMsgLength {
|
||||||
debug("recv packet %d bytes too long", length)
|
debug("recv packet %d bytes too long", length)
|
||||||
|
@ -323,24 +329,39 @@ func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (uint8, []byte, e
|
||||||
debug("recv packet of 0 bytes too short")
|
debug("recv packet of 0 bytes too short")
|
||||||
return 0, nil, errShortPacket
|
return 0, nil, errShortPacket
|
||||||
}
|
}
|
||||||
|
|
||||||
if alloc == nil {
|
if alloc == nil {
|
||||||
b = make([]byte, length)
|
b = make([]byte, length)
|
||||||
}
|
}
|
||||||
if _, err := io.ReadFull(r, b[:length]); err != nil {
|
|
||||||
|
n, err := io.ReadFull(r, b[:length])
|
||||||
|
b = b[:n]
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
debug("recv packet error: %d of %d bytes: %x", n, length, b)
|
||||||
|
|
||||||
// ReadFull only returns EOF if it has read no bytes.
|
// ReadFull only returns EOF if it has read no bytes.
|
||||||
// In this case, that means a partial packet, and thus unexpected.
|
// In this case, that means a partial packet, and thus unexpected.
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
err = io.ErrUnexpectedEOF
|
err = io.ErrUnexpectedEOF
|
||||||
}
|
}
|
||||||
debug("recv packet %d bytes: err %v", length, err)
|
|
||||||
return 0, nil, err
|
if n == 0 {
|
||||||
|
return 0, nil, fmt.Errorf("error reading packet body: %d of %d: %w", n, length, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return 0, nil, fmt.Errorf("error reading packet body: %d of %d: (%s) %w", n, length, fxp(b[0]), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
typ, payload := fxp(b[0]), b[1:n]
|
||||||
|
|
||||||
if debugDumpRxPacketBytes {
|
if debugDumpRxPacketBytes {
|
||||||
debug("recv packet: %s %d bytes %x", fxp(b[0]), length, b[1:length])
|
debug("recv packet: %s %d bytes %x", typ, length, payload)
|
||||||
} else if debugDumpRxPacket {
|
} else if debugDumpRxPacket {
|
||||||
debug("recv packet: %s %d bytes", fxp(b[0]), length)
|
debug("recv packet: %s %d bytes", typ, length)
|
||||||
}
|
}
|
||||||
return b[0], b[1:length], nil
|
|
||||||
|
return typ, payload, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type extensionPair struct {
|
type extensionPair struct {
|
||||||
|
|
|
@ -468,7 +468,7 @@ func TestRecvPacket(t *testing.T) {
|
||||||
var recvPacketTests = []struct {
|
var recvPacketTests = []struct {
|
||||||
b []byte
|
b []byte
|
||||||
|
|
||||||
want uint8
|
want fxp
|
||||||
body []byte
|
body []byte
|
||||||
wantErr error
|
wantErr error
|
||||||
}{
|
}{
|
||||||
|
|
|
@ -148,7 +148,7 @@ func (rs *RequestServer) serveLoop(pktChan chan<- orderedRequest) error {
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
var pkt requestPacket
|
var pkt requestPacket
|
||||||
var pktType uint8
|
var pktType fxp
|
||||||
var pktBytes []byte
|
var pktBytes []byte
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
@ -158,7 +158,7 @@ func (rs *RequestServer) serveLoop(pktChan chan<- orderedRequest) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes})
|
pkt, err = makePacket(rxPacket{pktType, pktBytes})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(err, errUnknownExtendedPacket):
|
case errors.Is(err, errUnknownExtendedPacket):
|
||||||
|
|
|
@ -390,7 +390,7 @@ func (svr *Server) Serve() error {
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
var pkt requestPacket
|
var pkt requestPacket
|
||||||
var pktType uint8
|
var pktType fxp
|
||||||
var pktBytes []byte
|
var pktBytes []byte
|
||||||
for {
|
for {
|
||||||
pktType, pktBytes, err = svr.serverConn.recvPacket(svr.pktMgr.getNextOrderID())
|
pktType, pktBytes, err = svr.serverConn.recvPacket(svr.pktMgr.getNextOrderID())
|
||||||
|
@ -403,7 +403,7 @@ func (svr *Server) Serve() error {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes})
|
pkt, err = makePacket(rxPacket{pktType, pktBytes})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(err, errUnknownExtendedPacket):
|
case errors.Is(err, errUnknownExtendedPacket):
|
||||||
|
|
8
sftp.go
8
sftp.go
|
@ -184,15 +184,15 @@ func (f fx) String() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
type unexpectedPacketErr struct {
|
type unexpectedPacketErr struct {
|
||||||
want, got uint8
|
want, got fxp
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *unexpectedPacketErr) Error() string {
|
func (u *unexpectedPacketErr) Error() string {
|
||||||
return fmt.Sprintf("sftp: unexpected packet: want %v, got %v", fxp(u.want), fxp(u.got))
|
return fmt.Sprintf("sftp: unexpected packet: want %v, got %v", u.want, u.got)
|
||||||
}
|
}
|
||||||
|
|
||||||
func unimplementedPacketErr(u uint8) error {
|
func unimplementedPacketErr(u fxp) error {
|
||||||
return fmt.Errorf("sftp: unimplemented packet type: got %v", fxp(u))
|
return fmt.Errorf("sftp: unimplemented packet type: got %v", u)
|
||||||
}
|
}
|
||||||
|
|
||||||
type unexpectedIDErr struct{ want, got uint32 }
|
type unexpectedIDErr struct{ want, got uint32 }
|
||||||
|
|
Loading…
Reference in New Issue