mirror of https://github.com/pkg/sftp.git
symlink loop testing
This commit is contained in:
parent
d696bdb2ff
commit
c46216738b
|
@ -16,6 +16,10 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
const maxSymlinkFollows = 5
|
||||
|
||||
var errTooManySymlinks = errors.New("too many symbolic links")
|
||||
|
||||
// InMemHandler returns a Hanlders object with the test handlers.
|
||||
func InMemHandler() Handlers {
|
||||
root := &root{
|
||||
|
@ -87,6 +91,7 @@ func (fs *root) openfile(pathname string, flags uint32) (*memFile, error) {
|
|||
return nil, os.ErrNotExist
|
||||
}
|
||||
|
||||
var count int
|
||||
// You can create files through dangling symlinks.
|
||||
link, err := fs.lfetch(pathname)
|
||||
for err == nil && link.symlink != "" {
|
||||
|
@ -95,6 +100,10 @@ func (fs *root) openfile(pathname string, flags uint32) (*memFile, error) {
|
|||
return nil, os.ErrInvalid
|
||||
}
|
||||
|
||||
if count++; count > maxSymlinkFollows {
|
||||
return nil, errTooManySymlinks
|
||||
}
|
||||
|
||||
pathname = link.symlink
|
||||
link, err = fs.lfetch(pathname)
|
||||
}
|
||||
|
@ -485,18 +494,11 @@ func (fs *root) lfetch(path string) (*memFile, error) {
|
|||
func (fs *root) canonName(pathname string) (string, error) {
|
||||
dirname, filename := path.Dir(pathname), path.Base(pathname)
|
||||
|
||||
dir, err := fs.lfetch(dirname)
|
||||
dir, err := fs.fetch(dirname)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for dir.symlink != "" {
|
||||
dir, err = fs.lfetch(dir.symlink)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
if !dir.IsDir() {
|
||||
return "", syscall.ENOTDIR
|
||||
}
|
||||
|
@ -521,7 +523,12 @@ func (fs *root) fetch(path string) (*memFile, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
var count int
|
||||
for file.symlink != "" {
|
||||
if count++; count > maxSymlinkFollows {
|
||||
return nil, errTooManySymlinks
|
||||
}
|
||||
|
||||
file, err = fs.lfetch(file.symlink)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -562,6 +563,44 @@ func TestRequestSymlink(t *testing.T) {
|
|||
checkRequestServerAllocator(t, p)
|
||||
}
|
||||
|
||||
func TestRequestSymlinkLoop(t *testing.T) {
|
||||
p := clientRequestServerPair(t)
|
||||
defer p.Close()
|
||||
|
||||
err := p.cli.Symlink("/foo", "/bar")
|
||||
require.NoError(t, err)
|
||||
err = p.cli.Symlink("/bar", "/baz")
|
||||
require.NoError(t, err)
|
||||
err = p.cli.Symlink("/baz", "/foo")
|
||||
require.NoError(t, err)
|
||||
|
||||
// test should fail if we reach this point
|
||||
timer := time.NewTimer(1 * time.Second)
|
||||
defer timer.Stop()
|
||||
|
||||
var content []byte
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
|
||||
content, err = getTestFile(p.cli, "/bar")
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-timer.C:
|
||||
t.Fatal("symlink loop following timed out")
|
||||
return // just to let the compiler be absolutely sure
|
||||
|
||||
case <-done:
|
||||
}
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Len(t, content, 0)
|
||||
|
||||
checkRequestServerAllocator(t, p)
|
||||
}
|
||||
|
||||
func TestRequestSymlinkDanglingFiles(t *testing.T) {
|
||||
p := clientRequestServerPair(t)
|
||||
defer p.Close()
|
||||
|
|
Loading…
Reference in New Issue