diff --git a/client.go b/client.go index e21eccb..4899b89 100644 --- a/client.go +++ b/client.go @@ -57,7 +57,7 @@ func (c *Client) Close() error { return c.w.Close() } // it already exists. If successful, methods on the returned File can be // used for I/O; the associated file descriptor has mode O_RDWR. func (c *Client) Create(path string) (*File, error) { - return c.open(path, ssh_FXF_READ|ssh_FXF_WRITE|ssh_FXF_CREAT|ssh_FXF_TRUNC) + return c.open(path, flags(os.O_RDWR|os.O_CREATE|os.O_TRUNC)) } func (c *Client) sendInit() error { @@ -228,7 +228,14 @@ func (c *Client) Lstat(p string) (os.FileInfo, error) { // returned file can be used for reading; the associated file descriptor // has mode O_RDONLY. func (c *Client) Open(path string) (*File, error) { - return c.open(path, ssh_FXF_READ) + return c.open(path, flags(os.O_RDONLY)) +} + +// OpenFile is the generalized open call; most users will use Open or +// Create instead. It opens the named file with specified flag (O_RDONLY +// etc.). If successful, methods on the returned File can be used for I/O. +func (c *Client) OpenFile(path string, f int) (*File, error) { + return c.open(path, flags(f)) } func (c *Client) open(path string, pflags uint32) (*File, error) { @@ -566,3 +573,31 @@ func unmarshalStatus(id uint32, data []byte) error { lang: lang, } } + +// flags converts the flags passed to OpenFile into ssh flags. +// Unsupported flags are ignored. +func flags(f int) uint32 { + var out uint32 + switch f & os.O_WRONLY { + case os.O_WRONLY: + out |= ssh_FXF_WRITE + case os.O_RDONLY: + out |= ssh_FXF_READ + } + if f&os.O_RDWR == os.O_RDWR { + out |= ssh_FXF_READ | ssh_FXF_WRITE + } + if f&os.O_APPEND == os.O_APPEND { + out |= ssh_FXF_APPEND + } + if f&os.O_CREATE == os.O_CREATE { + out |= ssh_FXF_CREAT + } + if f&os.O_TRUNC == os.O_TRUNC { + out |= ssh_FXF_TRUNC + } + if f&os.O_EXCL == os.O_EXCL { + out |= ssh_FXF_EXCL + } + return out +} diff --git a/client_integration_test.go b/client_integration_test.go index 69f2929..cc46f72 100644 --- a/client_integration_test.go +++ b/client_integration_test.go @@ -233,6 +233,25 @@ func TestClientCreate(t *testing.T) { defer f2.Close() } +func TestClientAppend(t *testing.T) { + sftp, cmd := testClient(t, READWRITE) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + defer f.Close() + defer os.Remove(f.Name()) + + f2, err := sftp.OpenFile(f.Name(), os.O_RDWR|os.O_APPEND) + if err != nil { + t.Fatal(err) + } + defer f2.Close() +} + func TestClientCreateFailed(t *testing.T) { sftp, cmd := testClient(t, READONLY) defer cmd.Wait() diff --git a/client_test.go b/client_test.go index bbdf65f..9ade6d1 100644 --- a/client_test.go +++ b/client_test.go @@ -2,6 +2,7 @@ package sftp import ( "io" + "os" "testing" "github.com/kr/fs" @@ -10,6 +11,9 @@ import ( // assert that *Client implements fs.FileSystem var _ fs.FileSystem = new(Client) +// assert that *File implements io.ReadWriteCloser +var _ io.ReadWriteCloser = new(File) + var ok = &StatusError{Code: ssh_FX_OK} var eof = &StatusError{Code: ssh_FX_EOF} var fail = &StatusError{Code: ssh_FX_FAILURE} @@ -49,3 +53,23 @@ func TestOkOrErr(t *testing.T) { } } } + +var flagsTests = []struct { + flags int + want uint32 +}{ + {os.O_RDONLY, ssh_FXF_READ}, + {os.O_WRONLY, ssh_FXF_WRITE}, + {os.O_RDWR, ssh_FXF_READ | ssh_FXF_WRITE}, + {os.O_RDWR | os.O_CREATE | os.O_TRUNC, ssh_FXF_READ | ssh_FXF_WRITE | ssh_FXF_CREAT | ssh_FXF_TRUNC}, + {os.O_WRONLY | os.O_APPEND, ssh_FXF_WRITE | ssh_FXF_APPEND}, +} + +func TestFlags(t *testing.T) { + for i, tt := range flagsTests { + got := flags(tt.flags) + if got != tt.want { + t.Errorf("test %v: flags(%x): want: %x, got: %x", i, tt.flags, tt.want, got) + } + } +} diff --git a/sftp.go b/sftp.go index ceff232..7f560fc 100644 --- a/sftp.go +++ b/sftp.go @@ -155,10 +155,8 @@ func (u *unexpectedPacketErr) Error() string { return fmt.Sprintf("sftp: unexpected packet: want %v, got %v", fxp(u.want), fxp(u.got)) } -type unimplementedPacketErr uint8 - -func (u unimplementedPacketErr) Error() string { - return fmt.Sprintf("sftp: unimplemented packet type: got %v", fxp(u)) +func unimplementedPacketErr(u uint8) error { + return fmt.Errorf("sftp: unimplemented packet type: got %v", fxp(u)) } type unexpectedIdErr struct{ want, got uint32 }