mirror of https://github.com/pkg/sftp.git
Merge pull request #392 from drakkan/transfereof
request-server: don't return EOF if there is an unexpected error
This commit is contained in:
commit
7230c61342
|
@ -167,6 +167,9 @@ func (rs *RequestServer) Serve() error {
|
||||||
// make sure all open requests are properly closed
|
// make sure all open requests are properly closed
|
||||||
// (eg. possible on dropped connections, client crashes, etc.)
|
// (eg. possible on dropped connections, client crashes, etc.)
|
||||||
for handle, req := range rs.openRequests {
|
for handle, req := range rs.openRequests {
|
||||||
|
if err == io.EOF {
|
||||||
|
err = io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
req.transferError(err)
|
req.transferError(err)
|
||||||
|
|
||||||
delete(rs.openRequests, handle)
|
delete(rs.openRequests, handle)
|
||||||
|
|
|
@ -19,6 +19,7 @@ var _ = fmt.Print
|
||||||
type csPair struct {
|
type csPair struct {
|
||||||
cli *Client
|
cli *Client
|
||||||
svr *RequestServer
|
svr *RequestServer
|
||||||
|
svrResult chan error
|
||||||
}
|
}
|
||||||
|
|
||||||
// these must be closed in order, else client.Close will hang
|
// these must be closed in order, else client.Close will hang
|
||||||
|
@ -39,6 +40,9 @@ func clientRequestServerPair(t *testing.T) *csPair {
|
||||||
skipIfPlan9(t)
|
skipIfPlan9(t)
|
||||||
ready := make(chan bool)
|
ready := make(chan bool)
|
||||||
os.Remove(sock) // either this or signal handling
|
os.Remove(sock) // either this or signal handling
|
||||||
|
pair := &csPair{
|
||||||
|
svrResult: make(chan error, 1),
|
||||||
|
}
|
||||||
var server *RequestServer
|
var server *RequestServer
|
||||||
go func() {
|
go func() {
|
||||||
l, err := net.Listen("unix", sock)
|
l, err := net.Listen("unix", sock)
|
||||||
|
@ -55,7 +59,8 @@ func clientRequestServerPair(t *testing.T) *csPair {
|
||||||
options = append(options, WithRSAllocator())
|
options = append(options, WithRSAllocator())
|
||||||
}
|
}
|
||||||
server = NewRequestServer(fd, handlers, options...)
|
server = NewRequestServer(fd, handlers, options...)
|
||||||
server.Serve()
|
err = server.Serve()
|
||||||
|
pair.svrResult <- err
|
||||||
}()
|
}()
|
||||||
<-ready
|
<-ready
|
||||||
defer os.Remove(sock)
|
defer os.Remove(sock)
|
||||||
|
@ -65,7 +70,9 @@ func clientRequestServerPair(t *testing.T) *csPair {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("%+v\n", err)
|
t.Fatalf("%+v\n", err)
|
||||||
}
|
}
|
||||||
return &csPair{client, server}
|
pair.svr = server
|
||||||
|
pair.cli = client
|
||||||
|
return pair
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkRequestServerAllocator(t *testing.T, p *csPair) {
|
func checkRequestServerAllocator(t *testing.T, p *csPair) {
|
||||||
|
@ -718,6 +725,34 @@ func TestRequestReaddir(t *testing.T) {
|
||||||
checkRequestServerAllocator(t, p)
|
checkRequestServerAllocator(t, p)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCleanDisconnect(t *testing.T) {
|
||||||
|
p := clientRequestServerPair(t)
|
||||||
|
defer p.Close()
|
||||||
|
|
||||||
|
err := p.cli.conn.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
// server must return io.EOF after a clean client close
|
||||||
|
// with no pending open requests
|
||||||
|
err = <-p.svrResult
|
||||||
|
require.EqualError(t, err, io.EOF.Error())
|
||||||
|
checkRequestServerAllocator(t, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUncleanDisconnect(t *testing.T) {
|
||||||
|
p := clientRequestServerPair(t)
|
||||||
|
defer p.Close()
|
||||||
|
|
||||||
|
foo := NewRequest("", "foo")
|
||||||
|
p.svr.nextRequest(foo)
|
||||||
|
err := p.cli.conn.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
// the foo request above is still open after the client disconnects
|
||||||
|
// so the server will convert io.EOF to io.ErrUnexpectedEOF
|
||||||
|
err = <-p.svrResult
|
||||||
|
require.EqualError(t, err, io.ErrUnexpectedEOF.Error())
|
||||||
|
checkRequestServerAllocator(t, p)
|
||||||
|
}
|
||||||
|
|
||||||
func TestCleanPath(t *testing.T) {
|
func TestCleanPath(t *testing.T) {
|
||||||
assert.Equal(t, "/", cleanPath("/"))
|
assert.Equal(t, "/", cleanPath("/"))
|
||||||
assert.Equal(t, "/", cleanPath("."))
|
assert.Equal(t, "/", cleanPath("."))
|
||||||
|
|
Loading…
Reference in New Issue