mirror of https://github.com/pkg/sftp.git
only check OpenFileWriter interface if Read flags is true
Improve memory handler and some test case Improve nil check for Close and TransferError interfaces
This commit is contained in:
parent
ea67d57ce5
commit
9284f1d6ac
|
@ -9,6 +9,7 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
@ -47,7 +48,7 @@ func (fs *root) Fileread(r *Request) (io.ReaderAt, error) {
|
|||
return file.ReaderAt()
|
||||
}
|
||||
|
||||
func (fs *root) Filewrite(r *Request) (io.WriterAt, error) {
|
||||
func (fs *root) getFileForWrite(r *Request) (*memFile, error) {
|
||||
if fs.mockErr != nil {
|
||||
return nil, fs.mockErr
|
||||
}
|
||||
|
@ -56,29 +57,7 @@ func (fs *root) Filewrite(r *Request) (io.WriterAt, error) {
|
|||
defer fs.filesLock.Unlock()
|
||||
file, err := fs.fetch(r.Filepath)
|
||||
if err == os.ErrNotExist {
|
||||
dir, err := fs.fetch(filepath.Dir(r.Filepath))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !dir.isdir {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
file = newMemFile(r.Filepath, false)
|
||||
fs.files[r.Filepath] = file
|
||||
}
|
||||
return file.WriterAt()
|
||||
}
|
||||
|
||||
func (fs *root) OpenFile(r *Request) (WriterAtReaderAt, error) {
|
||||
if fs.mockErr != nil {
|
||||
return nil, fs.mockErr
|
||||
}
|
||||
_ = r.WithContext(r.Context()) // initialize context for deadlock testing
|
||||
fs.filesLock.Lock()
|
||||
defer fs.filesLock.Unlock()
|
||||
file, err := fs.fetch(r.Filepath)
|
||||
if err == os.ErrNotExist {
|
||||
dir, err := fs.fetch(filepath.Dir(r.Filepath))
|
||||
dir, err := fs.fetch(path.Dir(r.Filepath))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -91,6 +70,18 @@ func (fs *root) OpenFile(r *Request) (WriterAtReaderAt, error) {
|
|||
return file, nil
|
||||
}
|
||||
|
||||
func (fs *root) Filewrite(r *Request) (io.WriterAt, error) {
|
||||
file, err := fs.getFileForWrite(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return file.WriterAt()
|
||||
}
|
||||
|
||||
func (fs *root) OpenFile(r *Request) (WriterAtReaderAt, error) {
|
||||
return fs.getFileForWrite(r)
|
||||
}
|
||||
|
||||
func (fs *root) Filecmd(r *Request) error {
|
||||
if fs.mockErr != nil {
|
||||
return fs.mockErr
|
||||
|
|
|
@ -80,7 +80,7 @@ func TestRequestSplitWrite(t *testing.T) {
|
|||
p := clientRequestServerPair(t)
|
||||
defer p.Close()
|
||||
w, err := p.cli.Create("/foo")
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
p.cli.maxPacket = 3 // force it to send in small chunks
|
||||
contents := "one two three four five six seven eight nine ten"
|
||||
w.Write([]byte(contents))
|
||||
|
|
29
request.go
29
request.go
|
@ -6,6 +6,7 @@ import (
|
|||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
|
@ -143,21 +144,21 @@ func (r *Request) close() error {
|
|||
r.state.RLock()
|
||||
wr := r.state.writerAt
|
||||
rd := r.state.readerAt
|
||||
writerReader := r.state.writerReaderAt
|
||||
rw := r.state.writerReaderAt
|
||||
r.state.RUnlock()
|
||||
|
||||
var err error
|
||||
|
||||
// Close errors on a Writer are far more likely to be the important one.
|
||||
// As they can be information that there was a loss of data.
|
||||
if c, ok := wr.(io.Closer); ok {
|
||||
if c, ok := wr.(io.Closer); ok && c != nil && !reflect.ValueOf(c).IsNil() {
|
||||
if err2 := c.Close(); err == nil {
|
||||
// update error if it is still nil
|
||||
err = err2
|
||||
}
|
||||
}
|
||||
|
||||
if c, ok := writerReader.(io.Closer); ok {
|
||||
if c, ok := rw.(io.Closer); ok && c != nil && !reflect.ValueOf(c).IsNil() {
|
||||
if err2 := c.Close(); err == nil {
|
||||
// update error if it is still nil
|
||||
err = err2
|
||||
|
@ -165,7 +166,7 @@ func (r *Request) close() error {
|
|||
}
|
||||
}
|
||||
|
||||
if c, ok := rd.(io.Closer); ok {
|
||||
if c, ok := rd.(io.Closer); ok && c != nil && !reflect.ValueOf(c).IsNil() {
|
||||
if err2 := c.Close(); err == nil {
|
||||
// update error if it is still nil
|
||||
err = err2
|
||||
|
@ -184,18 +185,18 @@ func (r *Request) transferError(err error) {
|
|||
r.state.RLock()
|
||||
wr := r.state.writerAt
|
||||
rd := r.state.readerAt
|
||||
writerReader := r.state.writerReaderAt
|
||||
rw := r.state.writerReaderAt
|
||||
r.state.RUnlock()
|
||||
|
||||
if t, ok := writerReader.(TransferError); ok {
|
||||
if t, ok := wr.(TransferError); ok && t != nil && !reflect.ValueOf(t).IsNil() {
|
||||
t.TransferError(err)
|
||||
}
|
||||
|
||||
if t, ok := wr.(TransferError); ok {
|
||||
if t, ok := rw.(TransferError); ok && t != nil && !reflect.ValueOf(t).IsNil() {
|
||||
t.TransferError(err)
|
||||
}
|
||||
|
||||
if t, ok := rd.(TransferError); ok {
|
||||
if t, ok := rd.(TransferError); ok && t != nil && !reflect.ValueOf(t).IsNil() {
|
||||
t.TransferError(err)
|
||||
}
|
||||
}
|
||||
|
@ -227,10 +228,14 @@ func (r *Request) open(h Handlers, pkt requestPacket) responsePacket {
|
|||
var err error
|
||||
switch {
|
||||
case flags.Write, flags.Append, flags.Creat, flags.Trunc:
|
||||
if openFileWriter, ok := h.FilePut.(OpenFileWriter); ok && flags.Read {
|
||||
r.Method = "Open"
|
||||
r.state.writerReaderAt, err = openFileWriter.OpenFile(r)
|
||||
} else {
|
||||
if flags.Read {
|
||||
openFileWriter, ok := h.FilePut.(OpenFileWriter)
|
||||
if ok {
|
||||
r.Method = "Open"
|
||||
r.state.writerReaderAt, err = openFileWriter.OpenFile(r)
|
||||
}
|
||||
}
|
||||
if r.Method == "" {
|
||||
r.Method = "Put"
|
||||
r.state.writerAt, err = h.FilePut.Filewrite(r)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue