From 64bc1f82e3deb974b3d906cb09e290beb64ac247 Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Fri, 22 Jan 2021 16:45:57 +0000 Subject: [PATCH] WriteTo better, but not best, version --- client.go | 314 +++++++++++++++++++++++++++--------------------------- 1 file changed, 158 insertions(+), 156 deletions(-) diff --git a/client.go b/client.go index e37779d..c2d2618 100644 --- a/client.go +++ b/client.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "io" + "math" "os" "path" "sync" @@ -995,7 +996,7 @@ func (f *File) ReadAt(b []byte, off int64) (int, error) { go func() { defer wg.Done() - ch := make(chan result, 1) // reuse channel per mapper. + ch := make(chan result, 1) // reusable channel per mapper. for packet := range workCh { n, err := f.readChunkAt(ch, packet.b, packet.off) @@ -1015,9 +1016,9 @@ func (f *File) ReadAt(b []byte, off int64) (int, error) { }() // collect all the results into a relevant return: the earliest offset to return an error. - firstErr := rwErr{-1, nil} + firstErr := rwErr{math.MaxInt64, nil} for rwErr := range errCh { - if firstErr.off < 0 || rwErr.off < firstErr.off { + if rwErr.off <= firstErr.off { firstErr = rwErr } @@ -1038,6 +1039,38 @@ func (f *File) ReadAt(b []byte, off int64) (int, error) { return len(b), nil } +// writeToSequential implements WriteTo, but works sequentially with no parallelism. +func (f *File) writeToSequential(w io.Writer) (written int64, err error) { + b := make([]byte, f.c.maxPacket) + ch := make(chan result, 1) // reusable channel + + for { // Still, do this in a loop, just to be sure we read everything to io.EOF. + n, err := f.readChunkAt(ch, b, int64(f.offset)) + if n < 0 { + panic("sftp.File: returned negative count from readChunkAt") + } + + if n > 0 { + f.offset += uint64(n) + + m, err2 := w.Write(b[:n]) + written += int64(m) + + if err == nil { + err = err2 + } + } + + if err != nil { + if err == io.EOF { + return written, nil // return nil explicitly. 
+ } + + return written, err + } + } +} + // WriteTo writes the file to w. The return value is the number of bytes // written. Any error encountered during the write is also returned. // @@ -1060,130 +1093,96 @@ func (f *File) WriteTo(w io.Writer) (int64, error) { fileSize = uint64(fi.Size()) } - inFlight := 0 - desiredInFlight := 1 - offset := f.offset - writeOffset := offset - // see comment on same line in Read() above - ch := make(chan result, f.c.maxConcurrentRequests+1) - type inflightRead struct { - b []byte - offset uint64 - } - reqs := make(map[uint32]inflightRead) - pendingWrites := make(map[uint64][]byte) - type offsetErr struct { - offset uint64 - err error - } - var firstErr offsetErr + f.mu.Lock() + defer f.mu.Unlock() - sendReq := func(b []byte, offset uint64) { - reqID := f.c.nextID() - f.c.dispatchRequest(ch, sshFxpReadPacket{ - ID: reqID, - Handle: f.handle, - Offset: offset, - Len: uint32(len(b)), - }) - inFlight++ - reqs[reqID] = inflightRead{b: b, offset: offset} + if fileSize <= uint64(f.c.maxPacket) { + // We should be able to handle this in one Read. + return f.writeToSequential(w) } - var copied int64 - for firstErr.err == nil || inFlight > 0 { - if firstErr.err == nil { - for inFlight+len(pendingWrites) < desiredInFlight { - b := make([]byte, f.c.maxPacket) - sendReq(b, offset) - offset += uint64(f.c.maxPacket) - if offset > fileSize { - desiredInFlight = 1 - } - } - } + concurrency := int(fileSize/uint64(f.c.maxPacket) + 1) // bad guess, but better than no guess + if concurrency > f.c.maxConcurrentRequests { + concurrency = f.c.maxConcurrentRequests + } - if inFlight == 0 { - if firstErr.err == nil && len(pendingWrites) > 0 { - return copied, ErrInternalInconsistency + // if the writing Reduce phase has ended, then all work unconditionally needs to be thrown out. 
+ cancel := make(chan struct{}) + defer close(cancel) + + type writeWork struct { + b []byte + n int + next chan writeWork + } + writeCh := make(chan writeWork, 1) + errCh := make(chan error, 1) + + go func() { + // We should be able to handle this in one Read. + ch := make(chan result, 1) // reusable channel + + cur := writeCh + + for { // Still, do this in a loop, just to be sure we read everything to io.EOF. + b := make([]byte, f.c.maxPacket) + + n, err := f.readChunkAt(ch, b, int64(f.offset)) + if n < 0 { + panic("sftp.File: returned negative count from readChunkAt") } - break - } - res := <-ch - inFlight-- - if res.err != nil { - firstErr = offsetErr{offset: 0, err: res.err} - continue - } - reqID, data := unmarshalUint32(res.data) - req, ok := reqs[reqID] - if !ok { - firstErr = offsetErr{offset: 0, err: errors.Errorf("sid: %v not found", reqID)} - continue - } - delete(reqs, reqID) - switch res.typ { - case sshFxpStatus: - if firstErr.err == nil || req.offset < firstErr.offset { - firstErr = offsetErr{offset: req.offset, err: normaliseError(unmarshalStatus(reqID, res.data))} + + if n > 0 { + f.offset += uint64(n) + + next := make(chan writeWork, 1) + cur <- writeWork{ + b: b, + n: n, + next: next, + } + cur = next } - case sshFxpData: - l, data := unmarshalUint32(data) - if req.offset == writeOffset { - nbytes, err := w.Write(data) - copied += int64(nbytes) - if err != nil { - // We will never receive another DATA with offset==writeOffset, so - // the loop will drain inFlight and then exit. - firstErr = offsetErr{offset: req.offset + uint64(nbytes), err: err} - break + + if err != nil { + if err == io.EOF { + // Do not send io.EOF in this codepath, + // it could erroneously end writes early. 
+ close(cur) + return } - if nbytes < int(l) { - firstErr = offsetErr{offset: req.offset + uint64(nbytes), err: io.ErrShortWrite} - break - } - switch { - case offset > fileSize: - desiredInFlight = 1 - case desiredInFlight < f.c.maxConcurrentRequests: - desiredInFlight++ - } - writeOffset += uint64(nbytes) - for { - pendingData, ok := pendingWrites[writeOffset] - if !ok { - break - } - // Give go a chance to free the memory. - delete(pendingWrites, writeOffset) - nbytes, err := w.Write(pendingData) - // Do not move writeOffset on error so subsequent iterations won't trigger - // any writes. - if err != nil { - firstErr = offsetErr{offset: writeOffset + uint64(nbytes), err: err} - break - } - if nbytes < len(pendingData) { - firstErr = offsetErr{offset: writeOffset + uint64(nbytes), err: io.ErrShortWrite} - break - } - writeOffset += uint64(nbytes) - } - } else { - // Don't write the data yet because - // this response came in out of order - // and we need to wait for responses - // for earlier segments of the file. - pendingWrites[req.offset] = data + + // Do not close(cur) in this codepath, + // it could be erroneously interpreted as the EOF signal. + errCh <- err + return } - default: - firstErr = offsetErr{offset: 0, err: unimplementedPacketErr(res.typ)} + } + }() + + var written int64 + + cur := writeCh + for { + select { + case packet, ok := <-cur: + if !ok { + return written, nil + } + + n, err := w.Write(packet.b[:packet.n]) + written += int64(n) + if err != nil { + close(cancel) + return written, err + } + + cur = packet.next + + case err := <-errCh: + return written, err } } - if firstErr.err != io.EOF { - return copied, firstErr.err - } - return copied, nil } // Stat returns the FileInfo structure describing file. If there is an @@ -1303,7 +1302,7 @@ func (f *File) WriteAt(b []byte, off int64) (int, error) { go func() { defer wg.Done() - ch := make(chan result, 1) // reuse channel per mapper. + ch := make(chan result, 1) // reusable channel per mapper. 
for packet := range workCh { n, err := f.writeChunkAt(ch, packet.b, packet.off) @@ -1322,9 +1321,9 @@ func (f *File) WriteAt(b []byte, off int64) (int, error) { }() // collect all the results into a relevant return: the earliest offset to return an error. - firstErr := rwErr{-1, nil} + firstErr := rwErr{math.MaxInt64, nil} for rwErr := range errCh { - if firstErr.off < 0 || rwErr.off < firstErr.off { + if rwErr.off <= firstErr.off { firstErr = rwErr } @@ -1344,6 +1343,38 @@ func (f *File) WriteAt(b []byte, off int64) (int, error) { return len(b), nil } +// readFromSequential implements ReadFrom, but works sequentially with no parallelism. +func (f *File) readFromSequential(r io.Reader) (read int64, err error) { + b := make([]byte, f.c.maxPacket) + ch := make(chan result, 1) // reusable channel + + for { // Still, do this in a loop, just to be sure we read everything to io.EOF. + n, err := r.Read(b) + if n < 0 { + panic("sftp.File: reader returned negative count from Read") + } + + if n > 0 { + read += int64(n) + + m, err2 := f.writeChunkAt(ch, b[:n], int64(f.offset)) + f.offset += uint64(m) + + if err == nil { + err = err2 + } + } + + if err != nil { + if err == io.EOF { + return read, nil // return nil explicitly. + } + + return read, err + } + } +} + // ReadFrom reads data from r until EOF and writes it to the file. The return // value is the number of bytes read. Any error except io.EOF encountered // during the read is also returned. @@ -1371,34 +1402,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) { if remain <= int64(f.c.maxPacket) { // Try and spot cases where we likely will only need to read and write once. - b := make([]byte, f.c.maxPacket) - ch := make(chan result, 1) // reuse channel - - for { // Still, do this in a loop, just to be sure we read everything to io.EOF. 
- n, err := r.Read(b) - if n < 0 { - panic("sftp.File: reader returned negative count from Read") - } - - if n > 0 { - read += int64(n) - - m, err2 := f.writeChunkAt(ch, b[:n], int64(f.offset)) - f.offset += uint64(m) - - if err == nil { - err = err2 - } - } - - if err != nil { - if err == io.EOF { - return read, nil // return nil explicitly. - } - - return read, err - } - } + return f.readFromSequential(r) } // Split the write into multiple maxPacket sized concurrent writes @@ -1428,11 +1432,9 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) { } // Slice: cut up the Read into any number of buffers of length <= f.c.maxPacket, and at appropriate offsets. - go func() { + go func(offset int64) { defer close(workCh) - offset := int64(f.offset) - for { b := pool.Get().([]byte) n, err := r.Read(b) @@ -1457,7 +1459,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) { return } } - }() + }(int64(f.offset)) concurrency := int(remain/int64(f.c.maxPacket) + 1) // bad guess, but better than no guess if concurrency > f.c.maxConcurrentRequests { @@ -1471,7 +1473,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) { go func() { defer wg.Done() - ch := make(chan result, 1) // reuse channel per mapper. + ch := make(chan result, 1) // reusable channel per mapper. for packet := range workCh { n, err := f.writeChunkAt(ch, packet.b[:packet.n], packet.off) @@ -1491,16 +1493,16 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) { }() // collect all the results into a relevant return: the earliest offset to return an error. - firstErr := rwErr{-1, nil} + firstErr := rwErr{math.MaxInt64, nil} for rwErr := range errCh { - if firstErr.off < 0 || rwErr.off < firstErr.off { + if rwErr.off <= firstErr.off { firstErr = rwErr } select { case <-cancel: default: - // stop any more work from being distributed. (Just in case.) + // stop any more work from being distributed. close(cancel) } }