mirror of https://github.com/pkg/sftp.git
Use channel to implement a simple way to wait
This commit is contained in:
parent
3a53acc96b
commit
5cd7f324f9
|
@ -6,7 +6,6 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
@ -123,7 +122,7 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie
|
||||||
WriteCloser: wr,
|
WriteCloser: wr,
|
||||||
},
|
},
|
||||||
inflight: make(map[uint32]chan<- result),
|
inflight: make(map[uint32]chan<- result),
|
||||||
errCond: sync.NewCond(new(sync.Mutex)),
|
closed: make(chan struct{}),
|
||||||
},
|
},
|
||||||
maxPacket: 1 << 15,
|
maxPacket: 1 << 15,
|
||||||
maxConcurrentRequests: 64,
|
maxConcurrentRequests: 64,
|
||||||
|
|
16
conn.go
16
conn.go
|
@ -37,19 +37,15 @@ type clientConn struct {
|
||||||
sync.Mutex // protects inflight
|
sync.Mutex // protects inflight
|
||||||
inflight map[uint32]chan<- result // outstanding requests
|
inflight map[uint32]chan<- result // outstanding requests
|
||||||
|
|
||||||
errCond *sync.Cond
|
closed chan struct{}
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait blocks until the conn has shut down, and return the error
|
// Wait blocks until the conn has shut down, and return the error
|
||||||
// causing the shutdown. It can be called concurrently from multiple
|
// causing the shutdown. It can be called concurrently from multiple
|
||||||
// goroutines.
|
// goroutines.
|
||||||
func (c *clientConn) Wait() error {
|
func (c *clientConn) Wait() error {
|
||||||
c.errCond.L.Lock()
|
<-c.closed
|
||||||
defer c.errCond.L.Unlock()
|
|
||||||
for c.err == nil {
|
|
||||||
c.errCond.Wait()
|
|
||||||
}
|
|
||||||
return c.err
|
return c.err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -64,10 +60,6 @@ func (c *clientConn) loop() {
|
||||||
err := c.recv()
|
err := c.recv()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.broadcastErr(err)
|
c.broadcastErr(err)
|
||||||
c.errCond.L.Lock()
|
|
||||||
c.err = err
|
|
||||||
c.errCond.Broadcast()
|
|
||||||
c.errCond.L.Unlock()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -141,6 +133,8 @@ func (c *clientConn) broadcastErr(err error) {
|
||||||
for _, ch := range listeners {
|
for _, ch := range listeners {
|
||||||
ch <- result{err: err}
|
ch <- result{err: err}
|
||||||
}
|
}
|
||||||
|
c.err = err
|
||||||
|
close(c.closed)
|
||||||
}
|
}
|
||||||
|
|
||||||
type serverConn struct {
|
type serverConn struct {
|
||||||
|
|
Loading…
Reference in New Issue