2016-06-15 09:50:02 +08:00
|
|
|
|
package sftp
|
|
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
|
"encoding"
|
|
|
|
|
|
"io"
|
|
|
|
|
|
"sync"
|
2021-04-21 21:36:45 +08:00
|
|
|
|
"sync/atomic"
|
2016-06-15 16:07:14 +08:00
|
|
|
|
|
|
|
|
|
|
"github.com/pkg/errors"
|
2021-04-24 20:28:08 +08:00
|
|
|
|
sshfx "github.com/pkg/sftp/internal/encoding/ssh/filexfer"
|
2016-06-15 09:50:02 +08:00
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
// conn implements a bidirectional channel on which client and server
|
2016-06-15 16:07:14 +08:00
|
|
|
|
// connections are multiplexed.
|
2016-06-15 09:50:02 +08:00
|
|
|
|
type conn struct {
|
2016-06-15 18:19:51 +08:00
|
|
|
|
io.Reader
|
2021-04-21 21:36:45 +08:00
|
|
|
|
|
2020-03-18 16:36:07 +08:00
|
|
|
|
// this is the same allocator used in packet manager
|
2021-04-21 21:36:45 +08:00
|
|
|
|
alloc *allocator
|
|
|
|
|
|
|
2021-04-24 20:28:08 +08:00
|
|
|
|
sync.Mutex // used to serialise writes, and closes.
|
|
|
|
|
|
io.Writer
|
|
|
|
|
|
io.Closer
|
2016-06-15 09:50:02 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2020-03-18 16:36:07 +08:00
|
|
|
|
// the orderID is used in server mode if the allocator is enabled.
|
|
|
|
|
|
// For the client mode just pass 0
|
|
|
|
|
|
func (c *conn) recvPacket(orderID uint32) (uint8, []byte, error) {
|
|
|
|
|
|
return recvPacket(c, c.alloc, orderID)
|
2016-06-15 09:50:02 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2021-04-24 20:28:08 +08:00
|
|
|
|
func (c *conn) writeBinary(m encoding.BinaryMarshaler) error {
|
2016-06-15 09:50:02 +08:00
|
|
|
|
c.Lock()
|
|
|
|
|
|
defer c.Unlock()
|
2021-03-17 19:03:24 +08:00
|
|
|
|
|
2021-04-24 20:28:08 +08:00
|
|
|
|
return sendPacket(c.Writer, m)
|
2021-04-21 21:36:45 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2021-04-24 20:28:08 +08:00
|
|
|
|
func (c *conn) writePacket(id uint32, p sshfx.PacketMarshaller, b []byte) error {
|
|
|
|
|
|
header, payload, err := p.MarshalPacket(id, b)
|
2021-04-21 21:36:45 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
return errors.WithStack(err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
c.Lock()
|
|
|
|
|
|
defer c.Unlock()
|
|
|
|
|
|
|
|
|
|
|
|
if _, err := c.Write(header); err != nil {
|
|
|
|
|
|
return errors.WithStack(err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if len(payload) > 0 {
|
|
|
|
|
|
if _, err := c.Write(payload); err != nil {
|
|
|
|
|
|
return errors.WithStack(err)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return nil
|
2016-06-15 09:50:02 +08:00
|
|
|
|
}
|
2016-06-15 16:07:14 +08:00
|
|
|
|
|
2020-03-08 07:05:46 +08:00
|
|
|
|
func (c *conn) Close() error {
|
|
|
|
|
|
c.Lock()
|
|
|
|
|
|
defer c.Unlock()
|
2021-04-21 21:36:45 +08:00
|
|
|
|
|
2021-04-24 20:28:08 +08:00
|
|
|
|
return c.Closer.Close()
|
2020-03-08 07:05:46 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2016-06-15 16:07:14 +08:00
|
|
|
|
type clientConn struct {
|
2021-04-24 20:28:08 +08:00
|
|
|
|
*conn
|
2020-10-29 06:20:19 +08:00
|
|
|
|
wg sync.WaitGroup
|
|
|
|
|
|
|
2021-04-24 20:28:08 +08:00
|
|
|
|
nextid uint32
|
|
|
|
|
|
resPool resChanPool
|
|
|
|
|
|
bufPool *bufPool
|
2021-04-21 21:36:45 +08:00
|
|
|
|
|
2016-06-15 16:07:14 +08:00
|
|
|
|
sync.Mutex // protects inflight
|
|
|
|
|
|
inflight map[uint32]chan<- result // outstanding requests
|
2018-12-05 14:47:03 +08:00
|
|
|
|
|
2018-12-05 16:30:09 +08:00
|
|
|
|
closed chan struct{}
|
|
|
|
|
|
err error
|
2018-12-03 17:46:45 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2021-04-24 20:28:08 +08:00
|
|
|
|
func newClientConn(rd io.Reader, wr io.WriteCloser) *clientConn {
|
|
|
|
|
|
return &clientConn{
|
|
|
|
|
|
conn: &conn{
|
|
|
|
|
|
Reader: rd,
|
|
|
|
|
|
Writer: wr,
|
|
|
|
|
|
Closer: wr,
|
|
|
|
|
|
},
|
|
|
|
|
|
inflight: make(map[uint32]chan<- result),
|
|
|
|
|
|
closed: make(chan struct{}),
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2021-04-21 21:36:45 +08:00
|
|
|
|
// returns the next value of c.nextid
|
|
|
|
|
|
func (c *clientConn) nextID() uint32 {
|
|
|
|
|
|
return atomic.AddUint32(&c.nextid, 1)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2018-12-03 17:46:45 +08:00
|
|
|
|
// Wait blocks until the conn has shut down, and return the error
|
2018-12-05 14:47:03 +08:00
|
|
|
|
// causing the shutdown. It can be called concurrently from multiple
|
|
|
|
|
|
// goroutines.
|
2018-12-03 17:46:45 +08:00
|
|
|
|
func (c *clientConn) Wait() error {
|
2018-12-05 16:30:09 +08:00
|
|
|
|
<-c.closed
|
2018-12-05 14:47:03 +08:00
|
|
|
|
return c.err
|
2016-06-15 16:07:14 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2016-06-15 16:23:51 +08:00
|
|
|
|
// Close closes the SFTP session.
|
|
|
|
|
|
func (c *clientConn) Close() error {
|
2016-06-15 18:04:25 +08:00
|
|
|
|
defer c.wg.Wait()
|
|
|
|
|
|
return c.conn.Close()
|
2016-06-15 16:23:51 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (c *clientConn) loop() {
|
|
|
|
|
|
defer c.wg.Done()
|
2016-06-15 16:07:14 +08:00
|
|
|
|
err := c.recv()
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
c.broadcastErr(err)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2021-04-21 21:36:45 +08:00
|
|
|
|
// result captures the result of receiving the a packet from the server
|
|
|
|
|
|
type result struct {
|
2021-04-24 20:28:08 +08:00
|
|
|
|
pkt sshfx.RawPacket
|
|
|
|
|
|
buf []byte // return it after you’re done with it.
|
|
|
|
|
|
err error
|
2021-04-21 21:36:45 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2016-06-15 16:07:14 +08:00
|
|
|
|
// recv continuously reads from the server and forwards responses to the
|
|
|
|
|
|
// appropriate channel.
|
|
|
|
|
|
func (c *clientConn) recv() error {
|
2020-10-29 06:20:19 +08:00
|
|
|
|
defer c.conn.Close()
|
|
|
|
|
|
|
2016-06-15 16:07:14 +08:00
|
|
|
|
for {
|
2021-04-24 20:28:08 +08:00
|
|
|
|
var pkt sshfx.RawPacket
|
|
|
|
|
|
buf := c.bufPool.Get()
|
2021-04-21 21:36:45 +08:00
|
|
|
|
|
2021-04-24 20:28:08 +08:00
|
|
|
|
if err := pkt.ReadFrom(c.conn.Reader, buf, 64*1024); err != nil {
|
2021-03-16 00:53:09 +08:00
|
|
|
|
return err
|
|
|
|
|
|
}
|
2020-10-29 06:20:19 +08:00
|
|
|
|
|
2021-04-24 20:28:08 +08:00
|
|
|
|
ch, ok := c.getChannel(pkt.RequestID)
|
2016-06-15 16:07:14 +08:00
|
|
|
|
if !ok {
|
|
|
|
|
|
// This is an unexpected occurrence. Send the error
|
|
|
|
|
|
// back to all listeners so that they terminate
|
|
|
|
|
|
// gracefully.
|
2021-04-24 20:28:08 +08:00
|
|
|
|
return errors.Errorf("sid not found: %d", pkt.RequestID)
|
2016-06-15 16:07:14 +08:00
|
|
|
|
}
|
2020-10-29 06:20:19 +08:00
|
|
|
|
|
2021-04-24 20:28:08 +08:00
|
|
|
|
ch <- result{pkt: pkt, buf: buf}
|
2016-06-15 16:07:14 +08:00
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2020-10-29 06:20:19 +08:00
|
|
|
|
func (c *clientConn) putChannel(ch chan<- result, sid uint32) bool {
|
|
|
|
|
|
c.Lock()
|
|
|
|
|
|
defer c.Unlock()
|
|
|
|
|
|
|
|
|
|
|
|
select {
|
|
|
|
|
|
case <-c.closed:
|
|
|
|
|
|
// already closed with broadcastErr, return error on chan.
|
2020-12-08 23:59:24 +08:00
|
|
|
|
ch <- result{err: ErrSSHFxConnectionLost}
|
2020-10-29 06:20:19 +08:00
|
|
|
|
return false
|
|
|
|
|
|
default:
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
c.inflight[sid] = ch
|
|
|
|
|
|
return true
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (c *clientConn) getChannel(sid uint32) (chan<- result, bool) {
|
|
|
|
|
|
c.Lock()
|
|
|
|
|
|
defer c.Unlock()
|
|
|
|
|
|
|
|
|
|
|
|
ch, ok := c.inflight[sid]
|
|
|
|
|
|
delete(c.inflight, sid)
|
|
|
|
|
|
|
|
|
|
|
|
return ch, ok
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2016-06-15 16:30:05 +08:00
|
|
|
|
// idmarshaler is a value that can be marshalled to binary and that also
// reports the request ID it belongs to.
type idmarshaler interface {
	encoding.BinaryMarshaler

	id() uint32
}
|
|
|
|
|
|
|
2021-04-24 20:28:08 +08:00
|
|
|
|
func (c *clientConn) sendPacket(req sshfx.PacketMarshaller, resp sshfx.Packet) error {
|
|
|
|
|
|
id := c.nextID()
|
2020-10-29 06:20:19 +08:00
|
|
|
|
|
2021-04-24 20:28:08 +08:00
|
|
|
|
ch := c.resPool.Get()
|
|
|
|
|
|
defer c.resPool.Put(ch)
|
2016-06-15 16:30:05 +08:00
|
|
|
|
|
2021-04-24 20:28:08 +08:00
|
|
|
|
c.dispatchPacket(ch, id, req)
|
2021-04-21 21:36:45 +08:00
|
|
|
|
|
2021-04-24 20:28:08 +08:00
|
|
|
|
r := <-ch
|
|
|
|
|
|
if r.err != nil {
|
|
|
|
|
|
// sendPacket should never return an io.EOF except through a StatusError.
|
|
|
|
|
|
if errors.Is(r.err, io.EOF) {
|
|
|
|
|
|
return ErrSSHFxConnectionLost
|
|
|
|
|
|
}
|
2021-04-21 21:36:45 +08:00
|
|
|
|
|
2021-04-24 20:28:08 +08:00
|
|
|
|
return r.err
|
2021-04-21 21:36:45 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2021-04-24 20:28:08 +08:00
|
|
|
|
// Because DataPacket shall not alias r.pkt.Buffer,
|
|
|
|
|
|
// we are safe to return this buffer to the pool in all cases.
|
|
|
|
|
|
defer c.bufPool.Put(r.buf)
|
|
|
|
|
|
|
|
|
|
|
|
if r.pkt.RequestID != id {
|
2021-04-21 21:36:45 +08:00
|
|
|
|
return &unexpectedIDErr{
|
|
|
|
|
|
want: id,
|
2021-04-24 20:28:08 +08:00
|
|
|
|
got: r.pkt.RequestID,
|
2021-04-21 21:36:45 +08:00
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2021-04-24 20:28:08 +08:00
|
|
|
|
if r.pkt.PacketType == sshfx.PacketTypeStatus {
|
|
|
|
|
|
var status sshfx.StatusPacket
|
2021-04-21 21:36:45 +08:00
|
|
|
|
|
2021-04-24 20:28:08 +08:00
|
|
|
|
if err := status.UnmarshalPacketBody(&r.pkt.Data); err != nil {
|
2021-04-21 21:36:45 +08:00
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return &StatusError{
|
|
|
|
|
|
Code: uint32(status.StatusCode),
|
|
|
|
|
|
msg: status.ErrorMessage,
|
|
|
|
|
|
lang: status.LanguageTag,
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if resp == nil {
|
|
|
|
|
|
return &unexpectedPacketErr{
|
2021-04-24 20:28:08 +08:00
|
|
|
|
want: uint8(sshfx.PacketTypeStatus),
|
|
|
|
|
|
got: uint8(r.pkt.PacketType),
|
2021-04-21 21:36:45 +08:00
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2021-04-24 20:28:08 +08:00
|
|
|
|
if r.pkt.PacketType != resp.Type() {
|
2021-04-21 21:36:45 +08:00
|
|
|
|
return &unexpectedPacketErr{
|
|
|
|
|
|
want: uint8(resp.Type()),
|
2021-04-24 20:28:08 +08:00
|
|
|
|
got: uint8(r.pkt.PacketType),
|
2021-04-21 21:36:45 +08:00
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2021-04-24 20:28:08 +08:00
|
|
|
|
return resp.UnmarshalPacketBody(&r.pkt.Data)
|
2021-04-21 21:36:45 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2021-04-24 20:28:08 +08:00
|
|
|
|
func (c *clientConn) dispatchPacket(ch chan<- result, id uint32, req sshfx.PacketMarshaller) {
|
2021-04-21 21:36:45 +08:00
|
|
|
|
if !c.putChannel(ch, id) {
|
|
|
|
|
|
// already closed.
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2021-04-24 20:28:08 +08:00
|
|
|
|
buf := c.bufPool.Get()
|
|
|
|
|
|
defer c.bufPool.Put(buf)
|
2020-10-29 06:20:19 +08:00
|
|
|
|
|
2021-04-24 20:28:08 +08:00
|
|
|
|
if err := c.conn.writePacket(id, req, buf); err != nil {
|
|
|
|
|
|
if ch, ok := c.getChannel(id); ok {
|
2020-10-29 06:20:19 +08:00
|
|
|
|
ch <- result{err: err}
|
|
|
|
|
|
}
|
2016-06-15 16:07:14 +08:00
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// broadcastErr sends an error to all goroutines waiting for a response.
|
|
|
|
|
|
func (c *clientConn) broadcastErr(err error) {
|
|
|
|
|
|
c.Lock()
|
2020-10-29 06:20:19 +08:00
|
|
|
|
defer c.Unlock()
|
|
|
|
|
|
|
2020-12-08 23:59:24 +08:00
|
|
|
|
bcastRes := result{err: ErrSSHFxConnectionLost}
|
2020-10-29 06:20:19 +08:00
|
|
|
|
for sid, ch := range c.inflight {
|
2020-11-02 21:02:47 +08:00
|
|
|
|
ch <- bcastRes
|
2020-10-29 06:20:19 +08:00
|
|
|
|
|
|
|
|
|
|
// Replace the chan in inflight,
|
|
|
|
|
|
// we have hijacked this chan,
|
|
|
|
|
|
// and this guarantees always-only-once sending.
|
|
|
|
|
|
c.inflight[sid] = make(chan<- result, 1)
|
2016-06-15 16:07:14 +08:00
|
|
|
|
}
|
2020-10-29 06:20:19 +08:00
|
|
|
|
|
2018-12-05 16:30:09 +08:00
|
|
|
|
c.err = err
|
|
|
|
|
|
close(c.closed)
|
2016-06-15 16:07:14 +08:00
|
|
|
|
}
|
2016-06-15 19:08:29 +08:00
|
|
|
|
|
|
|
|
|
|
type serverConn struct {
|
|
|
|
|
|
conn
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2021-02-22 20:00:27 +08:00
|
|
|
|
func (s *serverConn) sendError(id uint32, err error) error {
|
2021-04-24 20:28:08 +08:00
|
|
|
|
return s.writeBinary(statusFromError(id, err))
|
2016-06-15 19:08:29 +08:00
|
|
|
|
}
|