mirror of https://github.com/minio/minio.git
				
				
				
			return an error in CopyAligned upon premature EOF (#18110)
add a unit-test to capture this corner case
This commit is contained in:
		
							parent
							
								
									cdeab19673
								
							
						
					
					
						commit
						d9f1df01eb
					
				| 
						 | 
					@ -22,6 +22,7 @@ package ioutil
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"bytes"
 | 
						"bytes"
 | 
				
			||||||
	"context"
 | 
						"context"
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"os"
 | 
						"os"
 | 
				
			||||||
	"sync"
 | 
						"sync"
 | 
				
			||||||
| 
						 | 
					@ -348,6 +349,10 @@ const DirectioAlignSize = 4096
 | 
				
			||||||
// input writer *os.File not a generic io.Writer. Make sure to have
 | 
					// input writer *os.File not a generic io.Writer. Make sure to have
 | 
				
			||||||
// the file opened for writes with syscall.O_DIRECT flag.
 | 
					// the file opened for writes with syscall.O_DIRECT flag.
 | 
				
			||||||
func CopyAligned(w io.Writer, r io.Reader, alignedBuf []byte, totalSize int64, file *os.File) (int64, error) {
 | 
					func CopyAligned(w io.Writer, r io.Reader, alignedBuf []byte, totalSize int64, file *os.File) (int64, error) {
 | 
				
			||||||
 | 
						if totalSize == 0 {
 | 
				
			||||||
 | 
							return 0, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Writes remaining bytes in the buffer.
 | 
						// Writes remaining bytes in the buffer.
 | 
				
			||||||
	writeUnaligned := func(w io.Writer, buf []byte) (remainingWritten int64, err error) {
 | 
						writeUnaligned := func(w io.Writer, buf []byte) (remainingWritten int64, err error) {
 | 
				
			||||||
		// Disable O_DIRECT on fd's on unaligned buffer
 | 
							// Disable O_DIRECT on fd's on unaligned buffer
 | 
				
			||||||
| 
						 | 
					@ -364,17 +369,19 @@ func CopyAligned(w io.Writer, r io.Reader, alignedBuf []byte, totalSize int64, f
 | 
				
			||||||
	var written int64
 | 
						var written int64
 | 
				
			||||||
	for {
 | 
						for {
 | 
				
			||||||
		buf := alignedBuf
 | 
							buf := alignedBuf
 | 
				
			||||||
		if totalSize != -1 {
 | 
							if totalSize > 0 {
 | 
				
			||||||
			remaining := totalSize - written
 | 
								remaining := totalSize - written
 | 
				
			||||||
			if remaining < int64(len(buf)) {
 | 
								if remaining < int64(len(buf)) {
 | 
				
			||||||
				buf = buf[:remaining]
 | 
									buf = buf[:remaining]
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		nr, err := io.ReadFull(r, buf)
 | 
							nr, err := io.ReadFull(r, buf)
 | 
				
			||||||
		eof := err == io.EOF || err == io.ErrUnexpectedEOF
 | 
							eof := errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)
 | 
				
			||||||
		if err != nil && !eof {
 | 
							if err != nil && !eof {
 | 
				
			||||||
			return written, err
 | 
								return written, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		buf = buf[:nr]
 | 
							buf = buf[:nr]
 | 
				
			||||||
		var nw int64
 | 
							var nw int64
 | 
				
			||||||
		if len(buf)%DirectioAlignSize == 0 {
 | 
							if len(buf)%DirectioAlignSize == 0 {
 | 
				
			||||||
| 
						 | 
					@ -386,22 +393,30 @@ func CopyAligned(w io.Writer, r io.Reader, alignedBuf []byte, totalSize int64, f
 | 
				
			||||||
			// buf is not aligned, hence use writeUnaligned()
 | 
								// buf is not aligned, hence use writeUnaligned()
 | 
				
			||||||
			nw, err = writeUnaligned(w, buf)
 | 
								nw, err = writeUnaligned(w, buf)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if nw > 0 {
 | 
							if nw > 0 {
 | 
				
			||||||
			written += nw
 | 
								written += nw
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return written, err
 | 
								return written, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if nw != int64(len(buf)) {
 | 
							if nw != int64(len(buf)) {
 | 
				
			||||||
			return written, io.ErrShortWrite
 | 
								return written, io.ErrShortWrite
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if totalSize != -1 {
 | 
							if totalSize > 0 && written == totalSize {
 | 
				
			||||||
			if written == totalSize {
 | 
								// we have written the entire stream, return right here.
 | 
				
			||||||
				return written, nil
 | 
								return written, nil
 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if eof {
 | 
							if eof {
 | 
				
			||||||
 | 
								// We reached EOF prematurely but we did not write everything
 | 
				
			||||||
 | 
								// that we promised that we would write.
 | 
				
			||||||
 | 
								if totalSize > 0 && written != totalSize {
 | 
				
			||||||
 | 
									return written, io.ErrUnexpectedEOF
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
			return written, nil
 | 
								return written, nil
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
// Copyright (c) 2015-2021 MinIO, Inc.
 | 
					// Copyright (c) 2015-2023 MinIO, Inc.
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
// This file is part of MinIO Object Storage stack
 | 
					// This file is part of MinIO Object Storage stack
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
| 
						 | 
					@ -20,8 +20,10 @@ package ioutil
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"bytes"
 | 
						"bytes"
 | 
				
			||||||
	"context"
 | 
						"context"
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"os"
 | 
						"os"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
| 
						 | 
					@ -205,3 +207,36 @@ func TestSameFile(t *testing.T) {
 | 
				
			||||||
		t.Fatal("Expected the files not to be same")
 | 
							t.Fatal("Expected the files not to be same")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestCopyAligned(t *testing.T) {
 | 
				
			||||||
 | 
						f, err := os.CreateTemp("", "")
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Errorf("Error creating tmp file: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer f.Close()
 | 
				
			||||||
 | 
						defer os.Remove(f.Name())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						r := strings.NewReader("hello world")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						bufp := ODirectPoolSmall.Get().(*[]byte)
 | 
				
			||||||
 | 
						defer ODirectPoolSmall.Put(bufp)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						written, err := CopyAligned(f, io.LimitReader(r, 5), *bufp, r.Size(), f)
 | 
				
			||||||
 | 
						if !errors.Is(err, io.ErrUnexpectedEOF) {
 | 
				
			||||||
 | 
							t.Errorf("Expected io.ErrUnexpectedEOF, but got %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if written != 5 {
 | 
				
			||||||
 | 
							t.Errorf("Expected written to be '5', but got %v", written)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						f.Seek(0, io.SeekStart)
 | 
				
			||||||
 | 
						r.Seek(0, io.SeekStart)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						written, err = CopyAligned(f, r, *bufp, r.Size(), f)
 | 
				
			||||||
 | 
						if !errors.Is(err, nil) {
 | 
				
			||||||
 | 
							t.Errorf("Expected nil, but got %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if written != r.Size() {
 | 
				
			||||||
 | 
							t.Errorf("Expected written to be '%v', but got %v", r.Size(), written)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue