| 
									
										
										
										
											2016-06-15 09:50:02 +08:00
										 |  |  | package sftp | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import ( | 
					
						
							|  |  |  | 	"encoding" | 
					
						
							|  |  |  | 	"io" | 
					
						
							|  |  |  | 	"sync" | 
					
						
							| 
									
										
										
										
											2016-06-15 16:07:14 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	"github.com/pkg/errors" | 
					
						
							| 
									
										
										
										
											2016-06-15 09:50:02 +08:00
										 |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // conn implements a bidirectional channel on which client and server
 | 
					
						
							| 
									
										
										
										
											2016-06-15 16:07:14 +08:00
										 |  |  | // connections are multiplexed.
 | 
					
						
							| 
									
										
										
										
											2016-06-15 09:50:02 +08:00
										 |  |  | type conn struct { | 
					
						
							| 
									
										
										
										
											2016-06-15 18:19:51 +08:00
										 |  |  | 	io.Reader | 
					
						
							|  |  |  | 	io.WriteCloser | 
					
						
							| 
									
										
										
										
											2020-03-18 16:36:07 +08:00
										 |  |  | 	// this is the same allocator used in packet manager
 | 
					
						
							|  |  |  | 	alloc      *allocator | 
					
						
							| 
									
										
										
										
											2016-06-15 09:50:02 +08:00
										 |  |  | 	sync.Mutex // used to serialise writes to sendPacket
 | 
					
						
							| 
									
										
										
										
											2017-02-13 15:18:54 +08:00
										 |  |  | 	// sendPacketTest is needed to replicate packet issues in testing
 | 
					
						
							|  |  |  | 	sendPacketTest func(w io.Writer, m encoding.BinaryMarshaler) error | 
					
						
							| 
									
										
										
										
											2016-06-15 09:50:02 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-18 16:36:07 +08:00
										 |  |  | // the orderID is used in server mode if the allocator is enabled.
 | 
					
						
							|  |  |  | // For the client mode just pass 0
 | 
					
						
							|  |  |  | func (c *conn) recvPacket(orderID uint32) (uint8, []byte, error) { | 
					
						
							|  |  |  | 	return recvPacket(c, c.alloc, orderID) | 
					
						
							| 
									
										
										
										
											2016-06-15 09:50:02 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (c *conn) sendPacket(m encoding.BinaryMarshaler) error { | 
					
						
							|  |  |  | 	c.Lock() | 
					
						
							|  |  |  | 	defer c.Unlock() | 
					
						
							| 
									
										
										
										
											2017-02-13 15:18:54 +08:00
										 |  |  | 	if c.sendPacketTest != nil { | 
					
						
							|  |  |  | 		return c.sendPacketTest(c, m) | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2016-06-15 09:50:02 +08:00
										 |  |  | 	return sendPacket(c, m) | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2016-06-15 16:07:14 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-08 07:05:46 +08:00
										 |  |  | func (c *conn) Close() error { | 
					
						
							|  |  |  | 	c.Lock() | 
					
						
							|  |  |  | 	defer c.Unlock() | 
					
						
							|  |  |  | 	return c.WriteCloser.Close() | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-06-15 16:07:14 +08:00
										 |  |  | type clientConn struct { | 
					
						
							|  |  |  | 	conn | 
					
						
							| 
									
										
										
										
											2016-06-15 16:23:51 +08:00
										 |  |  | 	wg         sync.WaitGroup | 
					
						
							| 
									
										
										
										
											2016-06-15 16:07:14 +08:00
										 |  |  | 	sync.Mutex                          // protects inflight
 | 
					
						
							|  |  |  | 	inflight   map[uint32]chan<- result // outstanding requests
 | 
					
						
							| 
									
										
										
										
											2018-12-05 14:47:03 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-12-05 16:30:09 +08:00
										 |  |  | 	closed chan struct{} | 
					
						
							|  |  |  | 	err    error | 
					
						
							| 
									
										
										
										
											2018-12-03 17:46:45 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Wait blocks until the conn has shut down, and return the error
 | 
					
						
							| 
									
										
										
										
											2018-12-05 14:47:03 +08:00
										 |  |  | // causing the shutdown. It can be called concurrently from multiple
 | 
					
						
							|  |  |  | // goroutines.
 | 
					
						
							| 
									
										
										
										
											2018-12-03 17:46:45 +08:00
										 |  |  | func (c *clientConn) Wait() error { | 
					
						
							| 
									
										
										
										
											2018-12-05 16:30:09 +08:00
										 |  |  | 	<-c.closed | 
					
						
							| 
									
										
										
										
											2018-12-05 14:47:03 +08:00
										 |  |  | 	return c.err | 
					
						
							| 
									
										
										
										
											2016-06-15 16:07:14 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-06-15 16:23:51 +08:00
										 |  |  | // Close closes the SFTP session.
 | 
					
						
							|  |  |  | func (c *clientConn) Close() error { | 
					
						
							| 
									
										
										
										
											2016-06-15 18:04:25 +08:00
										 |  |  | 	defer c.wg.Wait() | 
					
						
							|  |  |  | 	return c.conn.Close() | 
					
						
							| 
									
										
										
										
											2016-06-15 16:23:51 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (c *clientConn) loop() { | 
					
						
							|  |  |  | 	defer c.wg.Done() | 
					
						
							| 
									
										
										
										
											2016-06-15 16:07:14 +08:00
										 |  |  | 	err := c.recv() | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		c.broadcastErr(err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // recv continuously reads from the server and forwards responses to the
 | 
					
						
							|  |  |  | // appropriate channel.
 | 
					
						
							|  |  |  | func (c *clientConn) recv() error { | 
					
						
							| 
									
										
										
										
											2017-02-01 08:24:31 +08:00
										 |  |  | 	defer func() { | 
					
						
							|  |  |  | 		c.conn.Close() | 
					
						
							|  |  |  | 	}() | 
					
						
							| 
									
										
										
										
											2016-06-15 16:07:14 +08:00
										 |  |  | 	for { | 
					
						
							| 
									
										
										
										
											2020-03-18 16:36:07 +08:00
										 |  |  | 		typ, data, err := c.recvPacket(0) | 
					
						
							| 
									
										
										
										
											2016-06-15 16:07:14 +08:00
										 |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return err | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		sid, _ := unmarshalUint32(data) | 
					
						
							|  |  |  | 		c.Lock() | 
					
						
							|  |  |  | 		ch, ok := c.inflight[sid] | 
					
						
							|  |  |  | 		delete(c.inflight, sid) | 
					
						
							|  |  |  | 		c.Unlock() | 
					
						
							|  |  |  | 		if !ok { | 
					
						
							|  |  |  | 			// This is an unexpected occurrence. Send the error
 | 
					
						
							|  |  |  | 			// back to all listeners so that they terminate
 | 
					
						
							|  |  |  | 			// gracefully.
 | 
					
						
							|  |  |  | 			return errors.Errorf("sid: %v not fond", sid) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		ch <- result{typ: typ, data: data} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-06-15 16:30:05 +08:00
										 |  |  | // result captures the result of receiving the a packet from the server
 | 
					
						
							|  |  |  | type result struct { | 
					
						
							|  |  |  | 	typ  byte | 
					
						
							|  |  |  | 	data []byte | 
					
						
							|  |  |  | 	err  error | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | type idmarshaler interface { | 
					
						
							|  |  |  | 	id() uint32 | 
					
						
							|  |  |  | 	encoding.BinaryMarshaler | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-06-15 16:57:05 +08:00
										 |  |  | func (c *clientConn) sendPacket(p idmarshaler) (byte, []byte, error) { | 
					
						
							| 
									
										
										
										
											2018-02-07 08:43:44 +08:00
										 |  |  | 	ch := make(chan result, 2) | 
					
						
							| 
									
										
										
										
											2016-06-15 16:30:05 +08:00
										 |  |  | 	c.dispatchRequest(ch, p) | 
					
						
							|  |  |  | 	s := <-ch | 
					
						
							|  |  |  | 	return s.typ, s.data, s.err | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-06-15 16:07:14 +08:00
										 |  |  | func (c *clientConn) dispatchRequest(ch chan<- result, p idmarshaler) { | 
					
						
							|  |  |  | 	c.Lock() | 
					
						
							|  |  |  | 	c.inflight[p.id()] = ch | 
					
						
							| 
									
										
										
										
											2017-04-05 17:53:22 +08:00
										 |  |  | 	c.Unlock() | 
					
						
							| 
									
										
										
										
											2016-06-15 16:57:05 +08:00
										 |  |  | 	if err := c.conn.sendPacket(p); err != nil { | 
					
						
							| 
									
										
										
										
											2017-04-05 17:53:22 +08:00
										 |  |  | 		c.Lock() | 
					
						
							| 
									
										
										
										
											2016-06-15 16:07:14 +08:00
										 |  |  | 		delete(c.inflight, p.id()) | 
					
						
							| 
									
										
										
										
											2017-04-05 17:53:22 +08:00
										 |  |  | 		c.Unlock() | 
					
						
							| 
									
										
										
										
											2016-06-15 16:07:14 +08:00
										 |  |  | 		ch <- result{err: err} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // broadcastErr sends an error to all goroutines waiting for a response.
 | 
					
						
							|  |  |  | func (c *clientConn) broadcastErr(err error) { | 
					
						
							|  |  |  | 	c.Lock() | 
					
						
							|  |  |  | 	listeners := make([]chan<- result, 0, len(c.inflight)) | 
					
						
							|  |  |  | 	for _, ch := range c.inflight { | 
					
						
							|  |  |  | 		listeners = append(listeners, ch) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	c.Unlock() | 
					
						
							|  |  |  | 	for _, ch := range listeners { | 
					
						
							|  |  |  | 		ch <- result{err: err} | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2018-12-05 16:30:09 +08:00
										 |  |  | 	c.err = err | 
					
						
							|  |  |  | 	close(c.closed) | 
					
						
							| 
									
										
										
										
											2016-06-15 16:07:14 +08:00
										 |  |  | } | 
					
						
							| 
									
										
										
										
											2016-06-15 19:08:29 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | type serverConn struct { | 
					
						
							|  |  |  | 	conn | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2017-03-14 08:52:53 +08:00
										 |  |  | func (s *serverConn) sendError(p ider, err error) error { | 
					
						
							| 
									
										
										
										
											2016-06-15 19:08:29 +08:00
										 |  |  | 	return s.sendPacket(statusFromError(p, err)) | 
					
						
							|  |  |  | } |