mirror of https://github.com/pkg/sftp.git
writeToSequential: improve tests for write errors
This commit is contained in:
parent
65f24bcee4
commit
c7fdf5e5c6
|
@ -1176,11 +1176,11 @@ func (f *File) writeToSequential(w io.Writer) (written int64, err error) {
|
|||
if n > 0 {
|
||||
f.offset += int64(n)
|
||||
|
||||
m, wErr := w.Write(b[:n])
|
||||
m, err := w.Write(b[:n])
|
||||
written += int64(m)
|
||||
|
||||
if wErr != nil {
|
||||
return written, wErr
|
||||
if err != nil {
|
||||
return written, err
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1249,55 +1249,51 @@ func TestClientReadSequential(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// this writer requires maxPacket = 3 and always returns an error for the second write call
|
||||
type lastChunkErrSequentialWriter struct {
|
||||
expected int
|
||||
written int
|
||||
writtenReturn int
|
||||
counter int
|
||||
}
|
||||
|
||||
func (w *lastChunkErrSequentialWriter) Write(b []byte) (int, error) {
|
||||
chunkSize := len(b)
|
||||
w.written += chunkSize
|
||||
if w.written == w.expected {
|
||||
return w.writtenReturn, errors.New("test error")
|
||||
w.counter++
|
||||
if w.counter == 1 {
|
||||
if len(b) != 3 {
|
||||
return 0, errors.New("this writer requires maxPacket = 3, please set MaxPacketChecked(3)")
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
return chunkSize, nil
|
||||
return 1, errors.New("this writer fails after the first write")
|
||||
}
|
||||
|
||||
func TestClientWriteSequential_WriterErr(t *testing.T) {
|
||||
sftp, cmd := testClient(t, READONLY, NODELAY)
|
||||
func TestClientWriteSequentialWriterErr(t *testing.T) {
|
||||
client, cmd := testClient(t, READONLY, NODELAY, MaxPacketChecked(3))
|
||||
defer cmd.Wait()
|
||||
defer sftp.Close()
|
||||
defer client.Close()
|
||||
|
||||
d, err := ioutil.TempDir("", "sftptest-writesequential-writeerr")
|
||||
require.NoError(t, err)
|
||||
|
||||
defer os.RemoveAll(d)
|
||||
|
||||
var (
|
||||
content = []byte("hello world")
|
||||
shortWrite = 2
|
||||
)
|
||||
w := lastChunkErrSequentialWriter{
|
||||
expected: len(content),
|
||||
writtenReturn: shortWrite,
|
||||
}
|
||||
|
||||
f, err := ioutil.TempFile(d, "write-sequential-writeerr-test")
|
||||
require.NoError(t, err)
|
||||
fname := f.Name()
|
||||
n, err := f.Write(content)
|
||||
_, err = f.Write([]byte("12345"))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, n, len(content))
|
||||
require.NoError(t, f.Close())
|
||||
|
||||
sftpFile, err := sftp.Open(fname)
|
||||
sftpFile, err := client.Open(fname)
|
||||
require.NoError(t, err)
|
||||
defer sftpFile.Close()
|
||||
|
||||
gotWritten, gotErr := sftpFile.writeToSequential(&w)
|
||||
require.NotErrorIs(t, io.EOF, gotErr)
|
||||
require.Equal(t, int64(shortWrite), gotWritten)
|
||||
w := &lastChunkErrSequentialWriter{}
|
||||
written, err := sftpFile.writeToSequential(w)
|
||||
assert.Error(t, err)
|
||||
expected := int64(4)
|
||||
if written != expected {
|
||||
t.Errorf("sftpFile.Write() = %d, but expected %d", written, expected)
|
||||
}
|
||||
assert.Equal(t, 2, w.counter)
|
||||
}
|
||||
|
||||
func TestClientReadDir(t *testing.T) {
|
||||
|
|
Loading…
Reference in New Issue