Merge pull request #72 from mdlayher/master

client: make Client.{Lstat,Open,Stat} satisfy os.IsNotExist
This commit is contained in:
Dave Cheney 2016-01-04 12:14:20 +01:00
commit 9e66bf3ae2
3 changed files with 104 additions and 64 deletions

View File

@ -224,7 +224,7 @@ func (c *Client) ReadDir(p string) ([]os.FileInfo, error) {
}
case ssh_FXP_STATUS:
// TODO(dfc) scope warning!
err = eofOrErr(unmarshalStatus(id, data))
err = normaliseError(unmarshalStatus(id, data))
done = true
default:
return nil, unimplementedPacketErr(typ)
@ -280,7 +280,7 @@ func (c *Client) Stat(p string) (os.FileInfo, error) {
attr, _ := unmarshalAttrs(data)
return fileInfoFromStat(attr, path.Base(p)), nil
case ssh_FXP_STATUS:
return nil, unmarshalStatus(id, data)
return nil, normaliseError(unmarshalStatus(id, data))
default:
return nil, unimplementedPacketErr(typ)
}
@ -306,7 +306,7 @@ func (c *Client) Lstat(p string) (os.FileInfo, error) {
attr, _ := unmarshalAttrs(data)
return fileInfoFromStat(attr, path.Base(p)), nil
case ssh_FXP_STATUS:
return nil, unmarshalStatus(id, data)
return nil, normaliseError(unmarshalStatus(id, data))
default:
return nil, unimplementedPacketErr(typ)
}
@ -354,7 +354,7 @@ func (c *Client) Symlink(oldname, newname string) error {
}
switch typ {
case ssh_FXP_STATUS:
return okOrErr(unmarshalStatus(id, data))
return normaliseError(unmarshalStatus(id, data))
default:
return unimplementedPacketErr(typ)
}
@ -374,7 +374,7 @@ func (c *Client) setstat(path string, flags uint32, attrs interface{}) error {
}
switch typ {
case ssh_FXP_STATUS:
return okOrErr(unmarshalStatus(id, data))
return normaliseError(unmarshalStatus(id, data))
default:
return unimplementedPacketErr(typ)
}
@ -446,7 +446,7 @@ func (c *Client) open(path string, pflags uint32) (*File, error) {
handle, _ := unmarshalString(data)
return &File{c: c, path: path, handle: handle}, nil
case ssh_FXP_STATUS:
return nil, unmarshalStatus(id, data)
return nil, normaliseError(unmarshalStatus(id, data))
default:
return nil, unimplementedPacketErr(typ)
}
@ -466,7 +466,7 @@ func (c *Client) close(handle string) error {
}
switch typ {
case ssh_FXP_STATUS:
return okOrErr(unmarshalStatus(id, data))
return normaliseError(unmarshalStatus(id, data))
default:
return unimplementedPacketErr(typ)
}
@ -567,7 +567,7 @@ func (c *Client) removeFile(path string) error {
}
switch typ {
case ssh_FXP_STATUS:
return okOrErr(unmarshalStatus(id, data))
return normaliseError(unmarshalStatus(id, data))
default:
return unimplementedPacketErr(typ)
}
@ -584,7 +584,7 @@ func (c *Client) removeDirectory(path string) error {
}
switch typ {
case ssh_FXP_STATUS:
return okOrErr(unmarshalStatus(id, data))
return normaliseError(unmarshalStatus(id, data))
default:
return unimplementedPacketErr(typ)
}
@ -603,7 +603,7 @@ func (c *Client) Rename(oldname, newname string) error {
}
switch typ {
case ssh_FXP_STATUS:
return okOrErr(unmarshalStatus(id, data))
return normaliseError(unmarshalStatus(id, data))
default:
return unimplementedPacketErr(typ)
}
@ -631,7 +631,7 @@ func (c *Client) realpath(path string) (string, error) {
filename, _ := unmarshalString(data) // ignore attributes
return filename, nil
case ssh_FXP_STATUS:
return "", okOrErr(unmarshalStatus(id, data))
return "", normaliseError(unmarshalStatus(id, data))
default:
return "", unimplementedPacketErr(typ)
}
@ -686,7 +686,7 @@ func (c *Client) Mkdir(path string) error {
}
switch typ {
case ssh_FXP_STATUS:
return okOrErr(unmarshalStatus(id, data))
return normaliseError(unmarshalStatus(id, data))
default:
return unimplementedPacketErr(typ)
}
@ -789,7 +789,10 @@ func (f *File) Read(b []byte) (int, error) {
switch res.typ {
case ssh_FXP_STATUS:
if firstErr.err == nil || req.offset < firstErr.offset {
firstErr = offsetErr{offset: req.offset, err: eofOrErr(unmarshalStatus(reqId, res.data))}
firstErr = offsetErr{
offset: req.offset,
err: normaliseError(unmarshalStatus(reqId, res.data)),
}
break
}
case ssh_FXP_DATA:
@ -887,7 +890,7 @@ func (f *File) WriteTo(w io.Writer) (int64, error) {
switch res.typ {
case ssh_FXP_STATUS:
if firstErr.err == nil || req.offset < firstErr.offset {
firstErr = offsetErr{offset: req.offset, err: eofOrErr(unmarshalStatus(reqId, res.data))}
firstErr = offsetErr{offset: req.offset, err: normaliseError(unmarshalStatus(reqId, res.data))}
break
}
case ssh_FXP_DATA:
@ -996,7 +999,7 @@ func (f *File) Write(b []byte) (int, error) {
switch res.typ {
case ssh_FXP_STATUS:
id, _ := unmarshalUint32(res.data)
err := okOrErr(unmarshalStatus(id, res.data))
err := normaliseError(unmarshalStatus(id, res.data))
if err != nil && firstErr == nil {
firstErr = err
break
@ -1062,7 +1065,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
switch res.typ {
case ssh_FXP_STATUS:
id, _ := unmarshalUint32(res.data)
err := okOrErr(unmarshalStatus(id, res.data))
err := normaliseError(unmarshalStatus(id, res.data))
if err != nil && firstErr == nil {
firstErr = err
break
@ -1135,20 +1138,25 @@ func min(a, b int) int {
return a
}
// okOrErr returns nil if Err.Code is SSH_FX_OK, otherwise it returns the error.
func okOrErr(err error) error {
if err, ok := err.(*StatusError); ok && err.Code == ssh_FX_OK {
return nil
}
return err
}
func eofOrErr(err error) error {
if err, ok := err.(*StatusError); ok && err.Code == ssh_FX_EOF {
// normaliseError normalises an error into a more standard form that can be
// checked against stdlib errors like io.EOF or os.ErrNotExist.
func normaliseError(err error) error {
switch err := err.(type) {
case *StatusError:
switch err.Code {
case ssh_FX_EOF:
return io.EOF
}
case ssh_FX_NO_SUCH_FILE:
return os.ErrNotExist
case ssh_FX_OK:
return nil
default:
return err
}
default:
return err
}
}
func unmarshalStatus(id uint32, data []byte) error {
sid, data := unmarshalUint32(data)

View File

@ -188,7 +188,7 @@ func TestClientLstat(t *testing.T) {
}
}
func TestClientLstatMissing(t *testing.T) {
func TestClientLstatIsNotExist(t *testing.T) {
sftp, cmd := testClient(t, READONLY, NO_DELAY)
defer cmd.Wait()
defer sftp.Close()
@ -199,9 +199,8 @@ func TestClientLstatMissing(t *testing.T) {
}
os.Remove(f.Name())
_, err = sftp.Lstat(f.Name())
if err1, ok := err.(*StatusError); !ok || err1.Code != ssh_FX_NO_SUCH_FILE {
t.Fatalf("Lstat: want: %v, got %#v", ssh_FX_NO_SUCH_FILE, err)
if _, err := sftp.Lstat(f.Name()); !os.IsNotExist(err) {
t.Errorf("os.IsNotExist(%v) = false, want true", err)
}
}
@ -243,6 +242,26 @@ func TestClientOpen(t *testing.T) {
}
}
func TestClientOpenIsNotExist(t *testing.T) {
sftp, cmd := testClient(t, READONLY, NO_DELAY)
defer cmd.Wait()
defer sftp.Close()
if _, err := sftp.Open("/doesnt/exist/"); !os.IsNotExist(err) {
t.Errorf("os.IsNotExist(%v) = false, want true", err)
}
}
func TestClientStatIsNotExist(t *testing.T) {
sftp, cmd := testClient(t, READONLY, NO_DELAY)
defer cmd.Wait()
defer sftp.Close()
if _, err := sftp.Stat("/doesnt/exist/"); !os.IsNotExist(err) {
t.Errorf("os.IsNotExist(%v) = false, want true", err)
}
}
const seekBytes = 128 * 1024
type seek struct {

View File

@ -1,6 +1,7 @@
package sftp
import (
"errors"
"io"
"os"
"testing"
@ -14,42 +15,54 @@ var _ fs.FileSystem = new(Client)
// assert that *File implements io.ReadWriteCloser
var _ io.ReadWriteCloser = new(File)
var ok = &StatusError{Code: ssh_FX_OK}
var eof = &StatusError{Code: ssh_FX_EOF}
var fail = &StatusError{Code: ssh_FX_FAILURE}
func TestNormaliseError(t *testing.T) {
var (
ok = &StatusError{Code: ssh_FX_OK}
eof = &StatusError{Code: ssh_FX_EOF}
fail = &StatusError{Code: ssh_FX_FAILURE}
noSuchFile = &StatusError{Code: ssh_FX_NO_SUCH_FILE}
foo = errors.New("foo")
)
var eofOrErrTests = []struct {
err, want error
var tests = []struct {
desc string
err error
want error
}{
{nil, nil},
{eof, io.EOF},
{ok, ok},
{io.EOF, io.EOF},
{
desc: "nil error",
},
{
desc: "not *StatusError",
err: foo,
want: foo,
},
{
desc: "*StatusError with ssh_FX_EOF",
err: eof,
want: io.EOF,
},
{
desc: "*StatusError with ssh_FX_NO_SUCH_FILE",
err: noSuchFile,
want: os.ErrNotExist,
},
{
desc: "*StatusError with ssh_FX_OK",
err: ok,
},
{
desc: "*StatusError with ssh_FX_FAILURE",
err: fail,
want: fail,
},
}
func TestEofOrErr(t *testing.T) {
for _, tt := range eofOrErrTests {
got := eofOrErr(tt.err)
for _, tt := range tests {
got := normaliseError(tt.err)
if got != tt.want {
t.Errorf("eofOrErr(%#v): want: %#v, got: %#v", tt.err, tt.want, got)
}
}
}
var okOrErrTests = []struct {
err, want error
}{
{nil, nil},
{eof, eof},
{ok, nil},
{io.EOF, io.EOF},
}
func TestOkOrErr(t *testing.T) {
for _, tt := range okOrErrTests {
got := okOrErr(tt.err)
if got != tt.want {
t.Errorf("okOrErr(%#v): want: %#v, got: %#v", tt.err, tt.want, got)
t.Errorf("normaliseError(%#v), test %q\n- want: %#v\n- got: %#v",
tt.err, tt.desc, tt.want, got)
}
}
}