mirror of https://github.com/pkg/sftp.git
only check OpenFileWriter interface if Read flags is true
Improve memory handler and some test case Improve nil check for Close and TransferError interfaces
This commit is contained in:
parent
ea67d57ce5
commit
9284f1d6ac
|
@ -9,6 +9,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -47,7 +48,7 @@ func (fs *root) Fileread(r *Request) (io.ReaderAt, error) {
|
||||||
return file.ReaderAt()
|
return file.ReaderAt()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fs *root) Filewrite(r *Request) (io.WriterAt, error) {
|
func (fs *root) getFileForWrite(r *Request) (*memFile, error) {
|
||||||
if fs.mockErr != nil {
|
if fs.mockErr != nil {
|
||||||
return nil, fs.mockErr
|
return nil, fs.mockErr
|
||||||
}
|
}
|
||||||
|
@ -56,29 +57,7 @@ func (fs *root) Filewrite(r *Request) (io.WriterAt, error) {
|
||||||
defer fs.filesLock.Unlock()
|
defer fs.filesLock.Unlock()
|
||||||
file, err := fs.fetch(r.Filepath)
|
file, err := fs.fetch(r.Filepath)
|
||||||
if err == os.ErrNotExist {
|
if err == os.ErrNotExist {
|
||||||
dir, err := fs.fetch(filepath.Dir(r.Filepath))
|
dir, err := fs.fetch(path.Dir(r.Filepath))
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if !dir.isdir {
|
|
||||||
return nil, os.ErrInvalid
|
|
||||||
}
|
|
||||||
file = newMemFile(r.Filepath, false)
|
|
||||||
fs.files[r.Filepath] = file
|
|
||||||
}
|
|
||||||
return file.WriterAt()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fs *root) OpenFile(r *Request) (WriterAtReaderAt, error) {
|
|
||||||
if fs.mockErr != nil {
|
|
||||||
return nil, fs.mockErr
|
|
||||||
}
|
|
||||||
_ = r.WithContext(r.Context()) // initialize context for deadlock testing
|
|
||||||
fs.filesLock.Lock()
|
|
||||||
defer fs.filesLock.Unlock()
|
|
||||||
file, err := fs.fetch(r.Filepath)
|
|
||||||
if err == os.ErrNotExist {
|
|
||||||
dir, err := fs.fetch(filepath.Dir(r.Filepath))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -91,6 +70,18 @@ func (fs *root) OpenFile(r *Request) (WriterAtReaderAt, error) {
|
||||||
return file, nil
|
return file, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (fs *root) Filewrite(r *Request) (io.WriterAt, error) {
|
||||||
|
file, err := fs.getFileForWrite(r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return file.WriterAt()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *root) OpenFile(r *Request) (WriterAtReaderAt, error) {
|
||||||
|
return fs.getFileForWrite(r)
|
||||||
|
}
|
||||||
|
|
||||||
func (fs *root) Filecmd(r *Request) error {
|
func (fs *root) Filecmd(r *Request) error {
|
||||||
if fs.mockErr != nil {
|
if fs.mockErr != nil {
|
||||||
return fs.mockErr
|
return fs.mockErr
|
||||||
|
|
|
@ -80,7 +80,7 @@ func TestRequestSplitWrite(t *testing.T) {
|
||||||
p := clientRequestServerPair(t)
|
p := clientRequestServerPair(t)
|
||||||
defer p.Close()
|
defer p.Close()
|
||||||
w, err := p.cli.Create("/foo")
|
w, err := p.cli.Create("/foo")
|
||||||
assert.Nil(t, err)
|
require.Nil(t, err)
|
||||||
p.cli.maxPacket = 3 // force it to send in small chunks
|
p.cli.maxPacket = 3 // force it to send in small chunks
|
||||||
contents := "one two three four five six seven eight nine ten"
|
contents := "one two three four five six seven eight nine ten"
|
||||||
w.Write([]byte(contents))
|
w.Write([]byte(contents))
|
||||||
|
|
29
request.go
29
request.go
|
@ -6,6 +6,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"reflect"
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
@ -143,21 +144,21 @@ func (r *Request) close() error {
|
||||||
r.state.RLock()
|
r.state.RLock()
|
||||||
wr := r.state.writerAt
|
wr := r.state.writerAt
|
||||||
rd := r.state.readerAt
|
rd := r.state.readerAt
|
||||||
writerReader := r.state.writerReaderAt
|
rw := r.state.writerReaderAt
|
||||||
r.state.RUnlock()
|
r.state.RUnlock()
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// Close errors on a Writer are far more likely to be the important one.
|
// Close errors on a Writer are far more likely to be the important one.
|
||||||
// As they can be information that there was a loss of data.
|
// As they can be information that there was a loss of data.
|
||||||
if c, ok := wr.(io.Closer); ok {
|
if c, ok := wr.(io.Closer); ok && c != nil && !reflect.ValueOf(c).IsNil() {
|
||||||
if err2 := c.Close(); err == nil {
|
if err2 := c.Close(); err == nil {
|
||||||
// update error if it is still nil
|
// update error if it is still nil
|
||||||
err = err2
|
err = err2
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if c, ok := writerReader.(io.Closer); ok {
|
if c, ok := rw.(io.Closer); ok && c != nil && !reflect.ValueOf(c).IsNil() {
|
||||||
if err2 := c.Close(); err == nil {
|
if err2 := c.Close(); err == nil {
|
||||||
// update error if it is still nil
|
// update error if it is still nil
|
||||||
err = err2
|
err = err2
|
||||||
|
@ -165,7 +166,7 @@ func (r *Request) close() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if c, ok := rd.(io.Closer); ok {
|
if c, ok := rd.(io.Closer); ok && c != nil && !reflect.ValueOf(c).IsNil() {
|
||||||
if err2 := c.Close(); err == nil {
|
if err2 := c.Close(); err == nil {
|
||||||
// update error if it is still nil
|
// update error if it is still nil
|
||||||
err = err2
|
err = err2
|
||||||
|
@ -184,18 +185,18 @@ func (r *Request) transferError(err error) {
|
||||||
r.state.RLock()
|
r.state.RLock()
|
||||||
wr := r.state.writerAt
|
wr := r.state.writerAt
|
||||||
rd := r.state.readerAt
|
rd := r.state.readerAt
|
||||||
writerReader := r.state.writerReaderAt
|
rw := r.state.writerReaderAt
|
||||||
r.state.RUnlock()
|
r.state.RUnlock()
|
||||||
|
|
||||||
if t, ok := writerReader.(TransferError); ok {
|
if t, ok := wr.(TransferError); ok && t != nil && !reflect.ValueOf(t).IsNil() {
|
||||||
t.TransferError(err)
|
t.TransferError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if t, ok := wr.(TransferError); ok {
|
if t, ok := rw.(TransferError); ok && t != nil && !reflect.ValueOf(t).IsNil() {
|
||||||
t.TransferError(err)
|
t.TransferError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if t, ok := rd.(TransferError); ok {
|
if t, ok := rd.(TransferError); ok && t != nil && !reflect.ValueOf(t).IsNil() {
|
||||||
t.TransferError(err)
|
t.TransferError(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -227,10 +228,14 @@ func (r *Request) open(h Handlers, pkt requestPacket) responsePacket {
|
||||||
var err error
|
var err error
|
||||||
switch {
|
switch {
|
||||||
case flags.Write, flags.Append, flags.Creat, flags.Trunc:
|
case flags.Write, flags.Append, flags.Creat, flags.Trunc:
|
||||||
if openFileWriter, ok := h.FilePut.(OpenFileWriter); ok && flags.Read {
|
if flags.Read {
|
||||||
r.Method = "Open"
|
openFileWriter, ok := h.FilePut.(OpenFileWriter)
|
||||||
r.state.writerReaderAt, err = openFileWriter.OpenFile(r)
|
if ok {
|
||||||
} else {
|
r.Method = "Open"
|
||||||
|
r.state.writerReaderAt, err = openFileWriter.OpenFile(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if r.Method == "" {
|
||||||
r.Method = "Put"
|
r.Method = "Put"
|
||||||
r.state.writerAt, err = h.FilePut.Filewrite(r)
|
r.state.writerAt, err = h.FilePut.Filewrite(r)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue