WriteTo better, but not best, version

This commit is contained in:
Cassondra Foesch 2021-01-22 16:45:57 +00:00
parent 29c556e3a6
commit 64bc1f82e3
1 changed files with 158 additions and 156 deletions

314
client.go
View File

@ -4,6 +4,7 @@ import (
"bytes"
"encoding/binary"
"io"
"math"
"os"
"path"
"sync"
@ -995,7 +996,7 @@ func (f *File) ReadAt(b []byte, off int64) (int, error) {
go func() {
defer wg.Done()
ch := make(chan result, 1) // reuse channel per mapper.
ch := make(chan result, 1) // reusable channel per mapper.
for packet := range workCh {
n, err := f.readChunkAt(ch, packet.b, packet.off)
@ -1015,9 +1016,9 @@ func (f *File) ReadAt(b []byte, off int64) (int, error) {
}()
// collect all the results into a relevant return: the earliest offset to return an error.
firstErr := rwErr{-1, nil}
firstErr := rwErr{math.MaxInt64, nil}
for rwErr := range errCh {
if firstErr.off < 0 || rwErr.off < firstErr.off {
if rwErr.off <= firstErr.off {
firstErr = rwErr
}
@ -1038,6 +1039,38 @@ func (f *File) ReadAt(b []byte, off int64) (int, error) {
return len(b), nil
}
// writeToSimple implements WriteTo, but works sequentially with no parallelism.
func (f *File) writeToSequential(w io.Writer) (written int64, err error) {
b := make([]byte, f.c.maxPacket)
ch := make(chan result, 1) // reusable channel
for { // Still, do this in a loop, just to be sure we read everything to io.EOF.
n, err := f.readChunkAt(ch, b, int64(f.offset))
if n < 0 {
panic("sftp.File: returned negative count from readChunkAt")
}
if n > 0 {
f.offset += uint64(n)
m, err2 := w.Write(b[:n])
written += int64(m)
if err == nil {
err = err2
}
}
if err != nil {
if err == io.EOF {
return written, nil // return nil explicitly.
}
return written, err
}
}
}
// WriteTo writes the file to w. The return value is the number of bytes
// written. Any error encountered during the write is also returned.
//
@ -1060,130 +1093,96 @@ func (f *File) WriteTo(w io.Writer) (int64, error) {
fileSize = uint64(fi.Size())
}
inFlight := 0
desiredInFlight := 1
offset := f.offset
writeOffset := offset
// see comment on same line in Read() above
ch := make(chan result, f.c.maxConcurrentRequests+1)
type inflightRead struct {
b []byte
offset uint64
}
reqs := make(map[uint32]inflightRead)
pendingWrites := make(map[uint64][]byte)
type offsetErr struct {
offset uint64
err error
}
var firstErr offsetErr
f.mu.Lock()
defer f.mu.Unlock()
sendReq := func(b []byte, offset uint64) {
reqID := f.c.nextID()
f.c.dispatchRequest(ch, sshFxpReadPacket{
ID: reqID,
Handle: f.handle,
Offset: offset,
Len: uint32(len(b)),
})
inFlight++
reqs[reqID] = inflightRead{b: b, offset: offset}
if fileSize <= uint64(f.c.maxPacket) {
// We should be able to handle this in one Read.
return f.writeToSequential(w)
}
var copied int64
for firstErr.err == nil || inFlight > 0 {
if firstErr.err == nil {
for inFlight+len(pendingWrites) < desiredInFlight {
b := make([]byte, f.c.maxPacket)
sendReq(b, offset)
offset += uint64(f.c.maxPacket)
if offset > fileSize {
desiredInFlight = 1
}
}
}
concurrency := int(fileSize/uint64(f.c.maxPacket) + 1) // bad guess, but better than no guess
if concurrency > f.c.maxConcurrentRequests {
concurrency = f.c.maxConcurrentRequests
}
if inFlight == 0 {
if firstErr.err == nil && len(pendingWrites) > 0 {
return copied, ErrInternalInconsistency
// if the writing Reduce phase has ended, then all work unconditionally needs to be thrown out.
cancel := make(chan struct{})
defer close(cancel)
type writeWork struct {
b []byte
n int
next chan writeWork
}
writeCh := make(chan writeWork, 1)
errCh := make(chan error, 1)
go func() {
// We should be able to handle this in one Read.
ch := make(chan result, 1) // reusable channel
cur := writeCh
for { // Still, do this in a loop, just to be sure we read everything to io.EOF.
b := make([]byte, f.c.maxPacket)
n, err := f.readChunkAt(ch, b, int64(f.offset))
if n < 0 {
panic("sftp.File: returned negative count from readChunkAt")
}
break
}
res := <-ch
inFlight--
if res.err != nil {
firstErr = offsetErr{offset: 0, err: res.err}
continue
}
reqID, data := unmarshalUint32(res.data)
req, ok := reqs[reqID]
if !ok {
firstErr = offsetErr{offset: 0, err: errors.Errorf("sid: %v not found", reqID)}
continue
}
delete(reqs, reqID)
switch res.typ {
case sshFxpStatus:
if firstErr.err == nil || req.offset < firstErr.offset {
firstErr = offsetErr{offset: req.offset, err: normaliseError(unmarshalStatus(reqID, res.data))}
if n > 0 {
f.offset += uint64(n)
next := make(chan writeWork, 1)
cur <- writeWork{
b: b,
n: n,
next: next,
}
cur = next
}
case sshFxpData:
l, data := unmarshalUint32(data)
if req.offset == writeOffset {
nbytes, err := w.Write(data)
copied += int64(nbytes)
if err != nil {
// We will never receive another DATA with offset==writeOffset, so
// the loop will drain inFlight and then exit.
firstErr = offsetErr{offset: req.offset + uint64(nbytes), err: err}
break
if err != nil {
if err == io.EOF {
// Do not send io.EOF in this codepath,
// it could erroneously end writes early.
close(cur)
return
}
if nbytes < int(l) {
firstErr = offsetErr{offset: req.offset + uint64(nbytes), err: io.ErrShortWrite}
break
}
switch {
case offset > fileSize:
desiredInFlight = 1
case desiredInFlight < f.c.maxConcurrentRequests:
desiredInFlight++
}
writeOffset += uint64(nbytes)
for {
pendingData, ok := pendingWrites[writeOffset]
if !ok {
break
}
// Give go a chance to free the memory.
delete(pendingWrites, writeOffset)
nbytes, err := w.Write(pendingData)
// Do not move writeOffset on error so subsequent iterations won't trigger
// any writes.
if err != nil {
firstErr = offsetErr{offset: writeOffset + uint64(nbytes), err: err}
break
}
if nbytes < len(pendingData) {
firstErr = offsetErr{offset: writeOffset + uint64(nbytes), err: io.ErrShortWrite}
break
}
writeOffset += uint64(nbytes)
}
} else {
// Don't write the data yet because
// this response came in out of order
// and we need to wait for responses
// for earlier segments of the file.
pendingWrites[req.offset] = data
// Do not close(cur) in this codepath,
// it could be erroneously interpreted as the EOF signal.
errCh <- err
return
}
default:
firstErr = offsetErr{offset: 0, err: unimplementedPacketErr(res.typ)}
}
}()
var written int64
cur := writeCh
for {
select {
case packet, ok := <-cur:
if !ok {
return written, nil
}
n, err := w.Write(packet.b[:packet.n])
written += int64(n)
if err != nil {
close(cancel)
return written, err
}
cur = packet.next
case err := <-errCh:
return written, err
}
}
if firstErr.err != io.EOF {
return copied, firstErr.err
}
return copied, nil
}
// Stat returns the FileInfo structure describing file. If there is an
@ -1303,7 +1302,7 @@ func (f *File) WriteAt(b []byte, off int64) (int, error) {
go func() {
defer wg.Done()
ch := make(chan result, 1) // reuse channel per mapper.
ch := make(chan result, 1) // reusable channel per mapper.
for packet := range workCh {
n, err := f.writeChunkAt(ch, packet.b, packet.off)
@ -1322,9 +1321,9 @@ func (f *File) WriteAt(b []byte, off int64) (int, error) {
}()
// collect all the results into a relevant return: the earliest offset to return an error.
firstErr := rwErr{-1, nil}
firstErr := rwErr{math.MaxInt64, nil}
for rwErr := range errCh {
if firstErr.off < 0 || rwErr.off < firstErr.off {
if rwErr.off <= firstErr.off {
firstErr = rwErr
}
@ -1344,6 +1343,38 @@ func (f *File) WriteAt(b []byte, off int64) (int, error) {
return len(b), nil
}
// readFromSimple implements WriteTo, but works sequentially with no parallelism.
func (f *File) readFromSequential(r io.Reader) (read int64, err error) {
b := make([]byte, f.c.maxPacket)
ch := make(chan result, 1) // reusable channel
for { // Still, do this in a loop, just to be sure we read everything to io.EOF.
n, err := r.Read(b)
if n < 0 {
panic("sftp.File: reader returned negative count from Read")
}
if n > 0 {
read += int64(n)
m, err2 := f.writeChunkAt(ch, b[:n], int64(f.offset))
f.offset += uint64(m)
if err == nil {
err = err2
}
}
if err != nil {
if err == io.EOF {
return read, nil // return nil explicitly.
}
return read, err
}
}
}
// ReadFrom reads data from r until EOF and writes it to the file. The return
// value is the number of bytes read. Any error except io.EOF encountered
// during the read is also returned.
@ -1371,34 +1402,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
if remain <= int64(f.c.maxPacket) {
// Try and spot cases where we likely will only need to read and write once.
b := make([]byte, f.c.maxPacket)
ch := make(chan result, 1) // reuse channel
for { // Still, do this in a loop, just to be sure we read everything to io.EOF.
n, err := r.Read(b)
if n < 0 {
panic("sftp.File: reader returned negative count from Read")
}
if n > 0 {
read += int64(n)
m, err2 := f.writeChunkAt(ch, b[:n], int64(f.offset))
f.offset += uint64(m)
if err == nil {
err = err2
}
}
if err != nil {
if err == io.EOF {
return read, nil // return nil explicitly.
}
return read, err
}
}
return f.readFromSequential(r)
}
// Split the write into multiple maxPacket sized concurrent writes
@ -1428,11 +1432,9 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
}
// Slice: cut up the Read into any number of buffers of length <= f.c.maxPacket, and at appropriate offsets.
go func() {
go func(offset int64) {
defer close(workCh)
offset := int64(f.offset)
for {
b := pool.Get().([]byte)
n, err := r.Read(b)
@ -1457,7 +1459,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
return
}
}
}()
}(int64(f.offset))
concurrency := int(remain/int64(f.c.maxPacket) + 1) // bad guess, but better than no guess
if concurrency > f.c.maxConcurrentRequests {
@ -1471,7 +1473,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
go func() {
defer wg.Done()
ch := make(chan result, 1) // reuse channel per mapper.
ch := make(chan result, 1) // reusable channel per mapper.
for packet := range workCh {
n, err := f.writeChunkAt(ch, packet.b[:packet.n], packet.off)
@ -1491,16 +1493,16 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
}()
// collect all the results into a relevant return: the earliest offset to return an error.
firstErr := rwErr{-1, nil}
firstErr := rwErr{math.MaxInt64, nil}
for rwErr := range errCh {
if firstErr.off < 0 || rwErr.off < firstErr.off {
if rwErr.off <= firstErr.off {
firstErr = rwErr
}
select {
case <-cancel:
default:
// stop any more work from being distributed. (Just in case.)
// stop any more work from being distributed.
close(cancel)
}
}