mirror of https://github.com/pkg/sftp.git
Don't crash when the packet length is zero
This commit is contained in:
parent
7f43671909
commit
cb1556337d
|
@ -1,6 +1,7 @@
|
|||
package sftp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
|
@ -208,3 +209,22 @@ func TestUseFstatChecked(t *testing.T) {
|
|||
testFstatOption(t, UseFstat(true), true)
|
||||
testFstatOption(t, UseFstat(false), false)
|
||||
}
|
||||
|
||||
type sink struct{}
|
||||
|
||||
func (*sink) Close() error { return nil }
|
||||
func (*sink) Write(p []byte) (int, error) { return len(p), nil }
|
||||
|
||||
func TestClientZeroLengthPacket(t *testing.T) {
|
||||
// Packet length zero (never valid). This used to crash the client.
|
||||
packet := []byte{0, 0, 0, 0}
|
||||
|
||||
r := bytes.NewReader(packet)
|
||||
c, err := NewClientPipe(r, &sink{})
|
||||
if err == nil {
|
||||
t.Error("expected an error, got nil")
|
||||
}
|
||||
if c != nil {
|
||||
c.Close()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
// +build gofuzz
|
||||
|
||||
package sftp
|
||||
|
||||
import "bytes"
|
||||
|
||||
type sink struct{}
|
||||
|
||||
func (*sink) Close() error { return nil }
|
||||
func (*sink) Write(p []byte) (int, error) { return len(p), nil }
|
||||
|
||||
var devnull = &sink{}
|
||||
|
||||
// To run: go-fuzz-build && go-fuzz
|
||||
func Fuzz(data []byte) int {
|
||||
c, err := NewClientPipe(bytes.NewReader(data), devnull)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
c.Close()
|
||||
return 1
|
||||
}
|
|
@ -154,6 +154,10 @@ func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (uint8, []byte, e
|
|||
debug("recv packet %d bytes too long", length)
|
||||
return 0, nil, errLongPacket
|
||||
}
|
||||
if length == 0 {
|
||||
debug("recv packet of 0 bytes too short")
|
||||
return 0, nil, errShortPacket
|
||||
}
|
||||
if alloc == nil {
|
||||
b = make([]byte, length)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package sftp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"os"
|
||||
"path"
|
||||
|
@ -13,6 +14,7 @@ import (
|
|||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -365,3 +367,31 @@ func TestStatNonExistent(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerWithBrokenClient(t *testing.T) {
|
||||
validInit := sp(sshFxInitPacket{Version: 3})
|
||||
brokenOpen := sp(sshFxpOpenPacket{Path: "foo"})
|
||||
brokenOpen = brokenOpen[:len(brokenOpen)-2]
|
||||
|
||||
for _, clientInput := range [][]byte{
|
||||
// Packet length zero (never valid). This used to crash the server.
|
||||
{0, 0, 0, 0},
|
||||
append(validInit, 0, 0, 0, 0),
|
||||
|
||||
// Client hangs up mid-packet.
|
||||
append(validInit, brokenOpen...),
|
||||
} {
|
||||
srv, err := NewServer(struct {
|
||||
io.Reader
|
||||
io.WriteCloser
|
||||
}{
|
||||
bytes.NewReader(clientInput),
|
||||
&sink{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = srv.Serve()
|
||||
assert.Error(t, err)
|
||||
srv.Close()
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue