mirror of https://github.com/pkg/sftp.git
				
				
				
			
		
			
				
	
	
		
			192 lines
		
	
	
		
			4.0 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			192 lines
		
	
	
		
			4.0 KiB
		
	
	
	
		
			Go
		
	
	
	
| package sftp
 | |
| 
 | |
| import (
 | |
| 	"encoding"
 | |
| 	"io"
 | |
| 	"sync"
 | |
| 
 | |
| 	"github.com/pkg/errors"
 | |
| )
 | |
| 
 | |
| // conn implements a bidirectional channel on which client and server
 | |
| // connections are multiplexed.
 | |
| type conn struct {
 | |
| 	io.Reader
 | |
| 	io.WriteCloser
 | |
| 	// this is the same allocator used in packet manager
 | |
| 	alloc      *allocator
 | |
| 	sync.Mutex // used to serialise writes to sendPacket
 | |
| 	// sendPacketTest is needed to replicate packet issues in testing
 | |
| 	sendPacketTest func(w io.Writer, m encoding.BinaryMarshaler) error
 | |
| }
 | |
| 
 | |
| // the orderID is used in server mode if the allocator is enabled.
 | |
| // For the client mode just pass 0
 | |
| func (c *conn) recvPacket(orderID uint32) (uint8, []byte, error) {
 | |
| 	return recvPacket(c, c.alloc, orderID)
 | |
| }
 | |
| 
 | |
| func (c *conn) sendPacket(m encoding.BinaryMarshaler) error {
 | |
| 	c.Lock()
 | |
| 	defer c.Unlock()
 | |
| 	if c.sendPacketTest != nil {
 | |
| 		return c.sendPacketTest(c, m)
 | |
| 	}
 | |
| 	return sendPacket(c, m)
 | |
| }
 | |
| 
 | |
| func (c *conn) Close() error {
 | |
| 	c.Lock()
 | |
| 	defer c.Unlock()
 | |
| 	return c.WriteCloser.Close()
 | |
| }
 | |
| 
 | |
| type clientConn struct {
 | |
| 	conn
 | |
| 	wg sync.WaitGroup
 | |
| 
 | |
| 	sync.Mutex                          // protects inflight
 | |
| 	inflight   map[uint32]chan<- result // outstanding requests
 | |
| 
 | |
| 	closed chan struct{}
 | |
| 	err    error
 | |
| }
 | |
| 
 | |
| // Wait blocks until the conn has shut down, and return the error
 | |
| // causing the shutdown. It can be called concurrently from multiple
 | |
| // goroutines.
 | |
| func (c *clientConn) Wait() error {
 | |
| 	<-c.closed
 | |
| 	return c.err
 | |
| }
 | |
| 
 | |
| // Close closes the SFTP session.
 | |
| func (c *clientConn) Close() error {
 | |
| 	defer c.wg.Wait()
 | |
| 	return c.conn.Close()
 | |
| }
 | |
| 
 | |
| func (c *clientConn) loop() {
 | |
| 	defer c.wg.Done()
 | |
| 	err := c.recv()
 | |
| 	if err != nil {
 | |
| 		c.broadcastErr(err)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // recv continuously reads from the server and forwards responses to the
 | |
| // appropriate channel.
 | |
| func (c *clientConn) recv() error {
 | |
| 	defer c.conn.Close()
 | |
| 
 | |
| 	for {
 | |
| 		typ, data, err := c.recvPacket(0)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 		sid, _ := unmarshalUint32(data)
 | |
| 
 | |
| 		ch, ok := c.getChannel(sid)
 | |
| 		if !ok {
 | |
| 			// This is an unexpected occurrence. Send the error
 | |
| 			// back to all listeners so that they terminate
 | |
| 			// gracefully.
 | |
| 			return errors.Errorf("sid not found: %v", sid)
 | |
| 		}
 | |
| 
 | |
| 		ch <- result{typ: typ, data: data}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (c *clientConn) putChannel(ch chan<- result, sid uint32) bool {
 | |
| 	c.Lock()
 | |
| 	defer c.Unlock()
 | |
| 
 | |
| 	select {
 | |
| 	case <-c.closed:
 | |
| 		// already closed with broadcastErr, return error on chan.
 | |
| 		ch <- result{err: ErrSSHFxConnectionLost}
 | |
| 		return false
 | |
| 	default:
 | |
| 	}
 | |
| 
 | |
| 	c.inflight[sid] = ch
 | |
| 	return true
 | |
| }
 | |
| 
 | |
| func (c *clientConn) getChannel(sid uint32) (chan<- result, bool) {
 | |
| 	c.Lock()
 | |
| 	defer c.Unlock()
 | |
| 
 | |
| 	ch, ok := c.inflight[sid]
 | |
| 	delete(c.inflight, sid)
 | |
| 
 | |
| 	return ch, ok
 | |
| }
 | |
| 
 | |
// result is what gets delivered for one request. It holds either the
// response packet received from the server or an error explaining why no
// response will arrive.
type result struct {
	typ  byte   // packet type of the response
	data []byte // response payload, starting with the request id
	err  error  // set instead of typ/data when the request failed
}
 | |
| 
 | |
// idmarshaler is a request packet that can be serialised for the wire and
// that reports the request id its response will be matched against.
type idmarshaler interface {
	encoding.BinaryMarshaler
	id() uint32
}
 | |
| 
 | |
| func (c *clientConn) sendPacket(ch chan result, p idmarshaler) (byte, []byte, error) {
 | |
| 	if cap(ch) < 1 {
 | |
| 		ch = make(chan result, 1)
 | |
| 	}
 | |
| 
 | |
| 	c.dispatchRequest(ch, p)
 | |
| 	s := <-ch
 | |
| 	return s.typ, s.data, s.err
 | |
| }
 | |
| 
 | |
| // dispatchRequest should ideally only be called by race-detection tests outside of this file,
 | |
| // where you have to ensure two packets are in flight sequentially after each other.
 | |
| func (c *clientConn) dispatchRequest(ch chan<- result, p idmarshaler) {
 | |
| 	sid := p.id()
 | |
| 
 | |
| 	if !c.putChannel(ch, sid) {
 | |
| 		// already closed.
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	if err := c.conn.sendPacket(p); err != nil {
 | |
| 		if ch, ok := c.getChannel(sid); ok {
 | |
| 			ch <- result{err: err}
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // broadcastErr sends an error to all goroutines waiting for a response.
 | |
| func (c *clientConn) broadcastErr(err error) {
 | |
| 	c.Lock()
 | |
| 	defer c.Unlock()
 | |
| 
 | |
| 	bcastRes := result{err: ErrSSHFxConnectionLost}
 | |
| 	for sid, ch := range c.inflight {
 | |
| 		ch <- bcastRes
 | |
| 
 | |
| 		// Replace the chan in inflight,
 | |
| 		// we have hijacked this chan,
 | |
| 		// and this guarantees always-only-once sending.
 | |
| 		c.inflight[sid] = make(chan<- result, 1)
 | |
| 	}
 | |
| 
 | |
| 	c.err = err
 | |
| 	close(c.closed)
 | |
| }
 | |
| 
 | |
| type serverConn struct {
 | |
| 	conn
 | |
| }
 | |
| 
 | |
| func (s *serverConn) sendError(id uint32, err error) error {
 | |
| 	return s.sendPacket(statusFromError(id, err))
 | |
| }
 |