mirror of https://github.com/pkg/sftp.git

commit 64bc1f82e3 (parent 29c556e3a6)

    WriteTo better, but not best, version

client.go (314 changed lines)
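The heart of this commit is the rewritten WriteTo below: instead of tracking request IDs, in-flight counts, and a map of out-of-order pending writes, a reader goroutine links each chunk to the channel that will carry the next one, so the writer drains chunks strictly in file order. A minimal, standalone sketch of that chained-channel pattern (all names here are illustrative, not the library's):

```go
package main

import (
	"fmt"
	"strings"
)

// chunk mirrors the diff's writeWork: a payload plus the channel that
// will deliver the chunk which follows it.
type chunk struct {
	b    []byte
	next chan chunk
}

func main() {
	head := make(chan chunk, 1)

	// Producer: read fixed-size pieces and link them into a chain.
	go func() {
		src := strings.NewReader("hello, ordered world")
		cur := head
		for {
			b := make([]byte, 8)
			n, err := src.Read(b)
			if n > 0 {
				next := make(chan chunk, 1)
				cur <- chunk{b: b[:n], next: next}
				cur = next
			}
			if err != nil {
				close(cur) // closing the tail is the EOF signal.
				return
			}
		}
	}()

	// Consumer: follow the chain; arrival order is the chain order.
	cur := head
	for {
		c, ok := <-cur
		if !ok {
			fmt.Println() // chain closed: clean EOF.
			return
		}
		fmt.Print(string(c.b))
		cur = c.next
	}
}
```

Because each chunk's delivery channel is created before the next read begins, ordering is a structural property of the chain; no offsets or reordering maps are needed.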
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"encoding/binary"
 	"io"
+	"math"
 	"os"
 	"path"
 	"sync"
@@ -995,7 +996,7 @@ func (f *File) ReadAt(b []byte, off int64) (int, error) {
 		go func() {
 			defer wg.Done()
 
-			ch := make(chan result, 1) // reuse channel per mapper.
+			ch := make(chan result, 1) // reusable channel per mapper.
 
 			for packet := range workCh {
 				n, err := f.readChunkAt(ch, packet.b, packet.off)
@@ -1015,9 +1016,9 @@ func (f *File) ReadAt(b []byte, off int64) (int, error) {
 	}()
 
 	// collect all the results into a relevant return: the earliest offset to return an error.
-	firstErr := rwErr{-1, nil}
+	firstErr := rwErr{math.MaxInt64, nil}
 	for rwErr := range errCh {
-		if firstErr.off < 0 || rwErr.off < firstErr.off {
+		if rwErr.off <= firstErr.off {
 			firstErr = rwErr
 		}
 
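Replacing the -1 sentinel with math.MaxInt64 simplifies this reduction (here and in the twin loops further down): every real offset is <= the sentinel, so the first error always displaces it, and no separate "have we seen an error yet?" check is needed. A small self-contained illustration of that fold (rwErr's shape matches the diff; the helper name is made up):

```go
package main

import (
	"errors"
	"fmt"
	"math"
)

// rwErr matches the shape used in the diff: an offset and the error
// reported there.
type rwErr struct {
	off int64
	err error
}

// earliest folds a stream of errors down to the one at the lowest offset.
func earliest(errs []rwErr) rwErr {
	// Start at the largest possible offset: every real error offset
	// satisfies off <= math.MaxInt64, so no negative sentinel or
	// "unset" flag is required.
	first := rwErr{math.MaxInt64, nil}
	for _, e := range errs {
		if e.off <= first.off {
			first = e
		}
	}
	return first
}

func main() {
	got := earliest([]rwErr{
		{4096, errors.New("late failure")},
		{1024, errors.New("early failure")},
		{8192, errors.New("later failure")},
	})
	fmt.Println(got.off, got.err) // 1024 early failure
}
```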
@@ -1038,6 +1039,38 @@ func (f *File) ReadAt(b []byte, off int64) (int, error) {
 	return len(b), nil
 }
 
+// writeToSequential implements WriteTo, but works sequentially with no parallelism.
+func (f *File) writeToSequential(w io.Writer) (written int64, err error) {
+	b := make([]byte, f.c.maxPacket)
+	ch := make(chan result, 1) // reusable channel
+
+	for { // Still, do this in a loop, just to be sure we read everything to io.EOF.
+		n, err := f.readChunkAt(ch, b, int64(f.offset))
+		if n < 0 {
+			panic("sftp.File: returned negative count from readChunkAt")
+		}
+
+		if n > 0 {
+			f.offset += uint64(n)
+
+			m, err2 := w.Write(b[:n])
+			written += int64(m)
+
+			if err == nil {
+				err = err2
+			}
+		}
+
+		if err != nil {
+			if err == io.EOF {
+				return written, nil // return nil explicitly.
+			}
+
+			return written, err
+		}
+	}
+}
+
 // WriteTo writes the file to w. The return value is the number of bytes
 // written. Any error encountered during the write is also returned.
 //
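writeToSequential, added above, is the low-overhead fallback used when the file fits in one packet: read a chunk at the current offset, write it, repeat until EOF, and convert io.EOF into a nil return, as the "return nil explicitly" comment spells out. A sketch of the same loop shape against a plain io.ReaderAt, with the SFTP plumbing stubbed out (the writeTo helper here is illustrative only):

```go
package main

import (
	"fmt"
	"io"
	"os"
	"strings"
)

// writeTo runs a writeToSequential-style loop: read one chunk at the
// current offset, write it, and stop cleanly on io.EOF.
func writeTo(w io.Writer, r io.ReaderAt, chunk int) (written int64, err error) {
	b := make([]byte, chunk)
	var off int64

	for {
		n, err := r.ReadAt(b, off)
		if n > 0 {
			off += int64(n)

			m, err2 := w.Write(b[:n])
			written += int64(m)
			if err == nil {
				err = err2
			}
		}
		if err != nil {
			if err == io.EOF {
				return written, nil // source exhausted: not an error.
			}
			return written, err
		}
	}
}

func main() {
	n, err := writeTo(os.Stdout, strings.NewReader("chunked copy\n"), 4)
	fmt.Println(n, err) // 13 <nil>
}
```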
@@ -1060,130 +1093,96 @@ func (f *File) WriteTo(w io.Writer) (int64, error) {
 		fileSize = uint64(fi.Size())
 	}
 
-	inFlight := 0
-	desiredInFlight := 1
-	offset := f.offset
-	writeOffset := offset
-	// see comment on same line in Read() above
-	ch := make(chan result, f.c.maxConcurrentRequests+1)
-	type inflightRead struct {
-		b      []byte
-		offset uint64
-	}
-	reqs := make(map[uint32]inflightRead)
-	pendingWrites := make(map[uint64][]byte)
-	type offsetErr struct {
-		offset uint64
-		err    error
-	}
-	var firstErr offsetErr
-	f.mu.Lock()
-	defer f.mu.Unlock()
-
-	sendReq := func(b []byte, offset uint64) {
-		reqID := f.c.nextID()
-		f.c.dispatchRequest(ch, sshFxpReadPacket{
-			ID:     reqID,
-			Handle: f.handle,
-			Offset: offset,
-			Len:    uint32(len(b)),
-		})
-		inFlight++
-		reqs[reqID] = inflightRead{b: b, offset: offset}
-	}
-
-	var copied int64
-	for firstErr.err == nil || inFlight > 0 {
-		if firstErr.err == nil {
-			for inFlight+len(pendingWrites) < desiredInFlight {
-				b := make([]byte, f.c.maxPacket)
-				sendReq(b, offset)
-				offset += uint64(f.c.maxPacket)
-				if offset > fileSize {
-					desiredInFlight = 1
-				}
-			}
-		}
-
-		if inFlight == 0 {
-			if firstErr.err == nil && len(pendingWrites) > 0 {
-				return copied, ErrInternalInconsistency
-			}
-			break
-		}
-		res := <-ch
-		inFlight--
-		if res.err != nil {
-			firstErr = offsetErr{offset: 0, err: res.err}
-			continue
-		}
-		reqID, data := unmarshalUint32(res.data)
-		req, ok := reqs[reqID]
-		if !ok {
-			firstErr = offsetErr{offset: 0, err: errors.Errorf("sid: %v not found", reqID)}
-			continue
-		}
-		delete(reqs, reqID)
-		switch res.typ {
-		case sshFxpStatus:
-			if firstErr.err == nil || req.offset < firstErr.offset {
-				firstErr = offsetErr{offset: req.offset, err: normaliseError(unmarshalStatus(reqID, res.data))}
-			}
-		case sshFxpData:
-			l, data := unmarshalUint32(data)
-			if req.offset == writeOffset {
-				nbytes, err := w.Write(data)
-				copied += int64(nbytes)
-				if err != nil {
-					// We will never receive another DATA with offset==writeOffset, so
-					// the loop will drain inFlight and then exit.
-					firstErr = offsetErr{offset: req.offset + uint64(nbytes), err: err}
-					break
-				}
-				if nbytes < int(l) {
-					firstErr = offsetErr{offset: req.offset + uint64(nbytes), err: io.ErrShortWrite}
-					break
-				}
-				switch {
-				case offset > fileSize:
-					desiredInFlight = 1
-				case desiredInFlight < f.c.maxConcurrentRequests:
-					desiredInFlight++
-				}
-				writeOffset += uint64(nbytes)
-				for {
-					pendingData, ok := pendingWrites[writeOffset]
-					if !ok {
-						break
-					}
-					// Give go a chance to free the memory.
-					delete(pendingWrites, writeOffset)
-					nbytes, err := w.Write(pendingData)
-					// Do not move writeOffset on error so subsequent iterations won't trigger
-					// any writes.
-					if err != nil {
-						firstErr = offsetErr{offset: writeOffset + uint64(nbytes), err: err}
-						break
-					}
-					if nbytes < len(pendingData) {
-						firstErr = offsetErr{offset: writeOffset + uint64(nbytes), err: io.ErrShortWrite}
-						break
-					}
-					writeOffset += uint64(nbytes)
-				}
-			} else {
-				// Don't write the data yet because
-				// this response came in out of order
-				// and we need to wait for responses
-				// for earlier segments of the file.
-				pendingWrites[req.offset] = data
-			}
-		default:
-			firstErr = offsetErr{offset: 0, err: unimplementedPacketErr(res.typ)}
-		}
-	}
-	if firstErr.err != io.EOF {
-		return copied, firstErr.err
-	}
-	return copied, nil
+	if fileSize <= uint64(f.c.maxPacket) {
+		// We should be able to handle this in one Read.
+		return f.writeToSequential(w)
+	}
+
+	concurrency := int(fileSize/uint64(f.c.maxPacket) + 1) // bad guess, but better than no guess
+	if concurrency > f.c.maxConcurrentRequests {
+		concurrency = f.c.maxConcurrentRequests
+	}
+
+	// if the writing Reduce phase has ended, then all work unconditionally needs to be thrown out.
+	cancel := make(chan struct{})
+	defer close(cancel)
+
+	type writeWork struct {
+		b    []byte
+		n    int
+		next chan writeWork
+	}
+	writeCh := make(chan writeWork, 1)
+	errCh := make(chan error, 1)
+
+	go func() {
+		ch := make(chan result, 1) // reusable channel
+
+		cur := writeCh
+
+		for { // Still, do this in a loop, just to be sure we read everything to io.EOF.
+			b := make([]byte, f.c.maxPacket)
+
+			n, err := f.readChunkAt(ch, b, int64(f.offset))
+			if n < 0 {
+				panic("sftp.File: returned negative count from readChunkAt")
+			}
+
+			if n > 0 {
+				f.offset += uint64(n)
+
+				next := make(chan writeWork, 1)
+				cur <- writeWork{
+					b:    b,
+					n:    n,
+					next: next,
+				}
+				cur = next
+			}
+
+			if err != nil {
+				if err == io.EOF {
+					// Do not send io.EOF in this codepath,
+					// it could erroneously end writes early.
+					close(cur)
+					return
+				}
+
+				// Do not close(cur) in this codepath,
+				// it could be erroneously interpreted as the EOF signal.
+				errCh <- err
+				return
+			}
+		}
+	}()
+
+	var written int64
+
+	cur := writeCh
+	for {
+		select {
+		case packet, ok := <-cur:
+			if !ok {
+				return written, nil
+			}
+
+			n, err := w.Write(packet.b[:packet.n])
+			written += int64(n)
+			if err != nil {
+				close(cancel)
+				return written, err
+			}
+
+			cur = packet.next
+
+		case err := <-errCh:
+			return written, err
+		}
+	}
 }
 
 // Stat returns the FileInfo structure describing file. If there is an
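Worth noting in the new WriteTo above: the reader goroutine signals a clean end-of-file only by closing the tail channel and never sends io.EOF on errCh, while a real error goes to errCh without closing anything, so the consumer cannot mistake a failure for EOF. A toy version of that two-signal convention (names here are illustrative):

```go
package main

import (
	"errors"
	"fmt"
	"io"
)

// finish routes a reader's terminal condition onto the right signal:
// close(done) means "clean EOF"; errCh carries everything else.
// Sending io.EOF on errCh, or closing done on a real error, would let
// the consumer confuse a failure with a normal end of stream.
func finish(err error, done chan<- struct{}, errCh chan<- error) {
	if err == io.EOF {
		close(done) // EOF: close, do not send io.EOF.
		return
	}
	errCh <- err // real error: send, do not close.
}

func main() {
	done := make(chan struct{})
	errCh := make(chan error, 1)

	go finish(errors.New("connection lost"), done, errCh)

	select {
	case <-done:
		fmt.Println("clean EOF")
	case err := <-errCh:
		fmt.Println("failed:", err)
	}
}
```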
@@ -1303,7 +1302,7 @@ func (f *File) WriteAt(b []byte, off int64) (int, error) {
 		go func() {
 			defer wg.Done()
 
-			ch := make(chan result, 1) // reuse channel per mapper.
+			ch := make(chan result, 1) // reusable channel per mapper.
 
 			for packet := range workCh {
 				n, err := f.writeChunkAt(ch, packet.b, packet.off)
@@ -1322,9 +1321,9 @@ func (f *File) WriteAt(b []byte, off int64) (int, error) {
 	}()
 
 	// collect all the results into a relevant return: the earliest offset to return an error.
-	firstErr := rwErr{-1, nil}
+	firstErr := rwErr{math.MaxInt64, nil}
 	for rwErr := range errCh {
-		if firstErr.off < 0 || rwErr.off < firstErr.off {
+		if rwErr.off <= firstErr.off {
 			firstErr = rwErr
 		}
 
@@ -1344,6 +1343,38 @@ func (f *File) WriteAt(b []byte, off int64) (int, error) {
 	return len(b), nil
 }
 
+// readFromSequential implements ReadFrom, but works sequentially with no parallelism.
+func (f *File) readFromSequential(r io.Reader) (read int64, err error) {
+	b := make([]byte, f.c.maxPacket)
+	ch := make(chan result, 1) // reusable channel
+
+	for { // Still, do this in a loop, just to be sure we read everything to io.EOF.
+		n, err := r.Read(b)
+		if n < 0 {
+			panic("sftp.File: reader returned negative count from Read")
+		}
+
+		if n > 0 {
+			read += int64(n)
+
+			m, err2 := f.writeChunkAt(ch, b[:n], int64(f.offset))
+			f.offset += uint64(m)
+
+			if err == nil {
+				err = err2
+			}
+		}
+
+		if err != nil {
+			if err == io.EOF {
+				return read, nil // return nil explicitly.
+			}
+
+			return read, err
+		}
+	}
+}
+
 // ReadFrom reads data from r until EOF and writes it to the file. The return
 // value is the number of bytes read. Any error except io.EOF encountered
 // during the read is also returned.
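These WriteTo/ReadFrom implementations pay off through io.Copy, which prefers the source's io.WriterTo, then the destination's io.ReaderFrom, before falling back to its generic buffer loop; copies to or from an *sftp.File therefore pick up this code automatically. A quick demonstration of that dispatch order:

```go
package main

import (
	"fmt"
	"io"
	"strings"
)

// loudWriter stands in for a destination like *sftp.File: because it
// implements io.ReaderFrom, io.Copy hands it the whole source reader.
type loudWriter struct{ n int64 }

func (w *loudWriter) Write(p []byte) (int, error) {
	w.n += int64(len(p))
	return len(p), nil
}

func (w *loudWriter) ReadFrom(r io.Reader) (int64, error) {
	fmt.Println("io.Copy dispatched to ReadFrom")
	// Strip the ReadFrom method before delegating, to avoid recursion.
	return io.Copy(struct{ io.Writer }{w}, r)
}

func main() {
	// Wrap the source so it does not expose WriteTo; otherwise io.Copy
	// would prefer the source's io.WriterTo over our io.ReaderFrom.
	src := struct{ io.Reader }{strings.NewReader("payload")}

	var w loudWriter
	n, err := io.Copy(&w, src)
	fmt.Println(n, err) // 7 <nil>
}
```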
@@ -1371,34 +1402,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
 
 	if remain <= int64(f.c.maxPacket) {
 		// Try and spot cases where we likely will only need to read and write once.
-		b := make([]byte, f.c.maxPacket)
-		ch := make(chan result, 1) // reuse channel
-
-		for { // Still, do this in a loop, just to be sure we read everything to io.EOF.
-			n, err := r.Read(b)
-			if n < 0 {
-				panic("sftp.File: reader returned negative count from Read")
-			}
-
-			if n > 0 {
-				read += int64(n)
-
-				m, err2 := f.writeChunkAt(ch, b[:n], int64(f.offset))
-				f.offset += uint64(m)
-
-				if err == nil {
-					err = err2
-				}
-			}
-
-			if err != nil {
-				if err == io.EOF {
-					return read, nil // return nil explicitly.
-				}
-
-				return read, err
-			}
-		}
+		return f.readFromSequential(r)
 	}
 
 	// Split the write into multiple maxPacket sized concurrent writes
@@ -1428,11 +1432,9 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
 	}
 
 	// Slice: cut up the Read into any number of buffers of length <= f.c.maxPacket, and at appropriate offsets.
-	go func() {
+	go func(offset int64) {
 		defer close(workCh)
 
-		offset := int64(f.offset)
-
 		for {
 			b := pool.Get().([]byte)
 			n, err := r.Read(b)
@@ -1457,7 +1459,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
 				return
 			}
 		}
-	}()
+	}(int64(f.offset))
 
 	concurrency := int(remain/int64(f.c.maxPacket) + 1) // bad guess, but better than no guess
 	if concurrency > f.c.maxConcurrentRequests {
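The two hunks above change the slicer goroutine from reading f.offset inside its body to receiving it as an argument: go func(offset int64){...}(int64(f.offset)) snapshots the offset at spawn time, so later updates to f.offset cannot race with the goroutine. The idiom in miniature:

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	var wg sync.WaitGroup

	shared := int64(100)

	// Pass the current value as an argument: the goroutine gets a
	// snapshot, immune to later writes to the shared variable.
	wg.Add(1)
	go func(offset int64) {
		defer wg.Done()
		fmt.Println("snapshot offset:", offset) // always 100
	}(shared)

	shared = 999 // no effect on the goroutine's copy

	wg.Wait()
}
```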
@@ -1471,7 +1473,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
 		go func() {
 			defer wg.Done()
 
-			ch := make(chan result, 1) // reuse channel per mapper.
+			ch := make(chan result, 1) // reusable channel per mapper.
 
 			for packet := range workCh {
 				n, err := f.writeChunkAt(ch, packet.b[:packet.n], packet.off)
@@ -1491,16 +1493,16 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
 	}()
 
 	// collect all the results into a relevant return: the earliest offset to return an error.
-	firstErr := rwErr{-1, nil}
+	firstErr := rwErr{math.MaxInt64, nil}
 	for rwErr := range errCh {
-		if firstErr.off < 0 || rwErr.off < firstErr.off {
+		if rwErr.off <= firstErr.off {
 			firstErr = rwErr
 		}
 
 		select {
 		case <-cancel:
 		default:
-			// stop any more work from being distributed. (Just in case.)
+			// stop any more work from being distributed.
 			close(cancel)
 		}
 	}
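The error-collection loop above closes cancel at most once by polling it first: if the receive succeeds the channel is already closed, and the default branch, which holds the close, is skipped. A standalone sketch of that close-once guard (safe only because a single goroutine does the closing):

```go
package main

import "fmt"

// closeOnce closes cancel only if it is not already closed, using a
// non-blocking receive as the guard. This is race-free here because
// only this function, called from one goroutine, ever closes it.
func closeOnce(cancel chan struct{}) {
	select {
	case <-cancel:
		// already closed: a receive on a closed channel returns immediately.
	default:
		// stop any more work from being distributed.
		close(cancel)
	}
}

func main() {
	cancel := make(chan struct{})
	closeOnce(cancel)
	closeOnce(cancel) // second call is a no-op instead of a panic
	fmt.Println("closed exactly once")
}
```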
|