Merge pull request #455 from pkg/cleanup-request-mutex-usage

Cleanup Request mutex usage
This commit is contained in:
Cassondra Foesch 2021-08-03 12:24:18 +00:00 committed by GitHub
commit 792ae58b7e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 229 additions and 130 deletions

View File

@ -6,8 +6,6 @@ import (
"fmt"
"io"
"os"
"path"
"path/filepath"
"strings"
"sync"
"syscall"
@ -16,6 +14,113 @@ import (
// MaxFilelist is the max number of files to return in a readdir batch.
var MaxFilelist int64 = 100
// state encapsulates the reader/writer/readdir from handlers.
type state struct {
mu sync.RWMutex
writerAt io.WriterAt
readerAt io.ReaderAt
writerAtReaderAt WriterAtReaderAt
listerAt ListerAt
lsoffset int64
}
// copy returns a shallow copy the state.
// This is broken out to specific fields,
// because we have to copy around the mutex in state.
func (s *state) copy() state {
s.mu.RLock()
defer s.mu.RUnlock()
return state{
writerAt: s.writerAt,
readerAt: s.readerAt,
writerAtReaderAt: s.writerAtReaderAt,
listerAt: s.listerAt,
lsoffset: s.lsoffset,
}
}
func (s *state) setReaderAt(rd io.ReaderAt) {
s.mu.Lock()
defer s.mu.Unlock()
s.readerAt = rd
}
func (s *state) getReaderAt() io.ReaderAt {
s.mu.RLock()
defer s.mu.RUnlock()
return s.readerAt
}
func (s *state) setWriterAt(rd io.WriterAt) {
s.mu.Lock()
defer s.mu.Unlock()
s.writerAt = rd
}
func (s *state) getWriterAt() io.WriterAt {
s.mu.RLock()
defer s.mu.RUnlock()
return s.writerAt
}
func (s *state) setWriterAtReaderAt(rw WriterAtReaderAt) {
s.mu.Lock()
defer s.mu.Unlock()
s.writerAtReaderAt = rw
}
func (s *state) getWriterAtReaderAt() WriterAtReaderAt {
s.mu.RLock()
defer s.mu.RUnlock()
return s.writerAtReaderAt
}
func (s *state) getAllReaderWriters() (io.ReaderAt, io.WriterAt, WriterAtReaderAt) {
s.mu.RLock()
defer s.mu.RUnlock()
return s.readerAt, s.writerAt, s.writerAtReaderAt
}
// Returns current offset for file list
func (s *state) lsNext() int64 {
s.mu.RLock()
defer s.mu.RUnlock()
return s.lsoffset
}
// Increases next offset
func (s *state) lsInc(offset int64) {
s.mu.Lock()
defer s.mu.Unlock()
s.lsoffset += offset
}
// manage file read/write state
func (s *state) setListerAt(la ListerAt) {
s.mu.Lock()
defer s.mu.Unlock()
s.listerAt = la
}
func (s *state) getListerAt() ListerAt {
s.mu.RLock()
defer s.mu.RUnlock()
return s.listerAt
}
// Request contains the data and state for the incoming service request.
type Request struct {
// Get, Put, Setstat, Stat, Rename, Remove
@ -26,20 +131,40 @@ type Request struct {
Attrs []byte // convert to sub-struct
Target string // for renames and sym-links
handle string
// reader/writer/readdir from handlers
state state
state
// context lasts duration of request
ctx context.Context
cancelCtx context.CancelFunc
}
type state struct {
*sync.RWMutex
writerAt io.WriterAt
readerAt io.ReaderAt
writerReaderAt WriterAtReaderAt
listerAt ListerAt
lsoffset int64
// NewRequest creates a new Request object.
func NewRequest(method, path string) *Request {
return &Request{
Method: method,
Filepath: cleanPath(path),
}
}
// copy returns a shallow copy of existing request.
// This is broken out to specific fields,
// because we have to copy around the mutex in state.
func (r *Request) copy() *Request {
return &Request{
Method: r.Method,
Filepath: r.Filepath,
Flags: r.Flags,
Attrs: r.Attrs,
Target: r.Target,
handle: r.handle,
state: r.state.copy(),
ctx: r.ctx,
cancelCtx: r.cancelCtx,
}
}
// New Request initialized based on packet data
@ -66,21 +191,6 @@ func requestFromPacket(ctx context.Context, pkt hasPath) *Request {
return request
}
// NewRequest creates a new Request object.
func NewRequest(method, path string) *Request {
return &Request{Method: method, Filepath: cleanPath(path),
state: state{RWMutex: new(sync.RWMutex)}}
}
// shallow copy of existing request
func (r *Request) copy() *Request {
r.state.Lock()
defer r.state.Unlock()
r2 := new(Request)
*r2 = *r
return r2
}
// Context returns the request's context. To change the context,
// use WithContext.
//
@ -108,33 +218,6 @@ func (r *Request) WithContext(ctx context.Context) *Request {
return r2
}
// Returns current offset for file list
func (r *Request) lsNext() int64 {
r.state.RLock()
defer r.state.RUnlock()
return r.state.lsoffset
}
// Increases next offset
func (r *Request) lsInc(offset int64) {
r.state.Lock()
defer r.state.Unlock()
r.state.lsoffset = r.state.lsoffset + offset
}
// manage file read/write state
func (r *Request) setListerState(la ListerAt) {
r.state.Lock()
defer r.state.Unlock()
r.state.listerAt = la
}
func (r *Request) getLister() ListerAt {
r.state.RLock()
defer r.state.RUnlock()
return r.state.listerAt
}
// Close reader/writer if possible
func (r *Request) close() error {
defer func() {
@ -143,11 +226,7 @@ func (r *Request) close() error {
}
}()
r.state.RLock()
wr := r.state.writerAt
rd := r.state.readerAt
rw := r.state.writerReaderAt
r.state.RUnlock()
rd, wr, rw := r.getAllReaderWriters()
var err error
@ -164,7 +243,8 @@ func (r *Request) close() error {
if err2 := c.Close(); err == nil {
// update error if it is still nil
err = err2
r.state.writerReaderAt = nil
r.setWriterAtReaderAt(nil)
}
}
@ -184,11 +264,7 @@ func (r *Request) transferError(err error) {
return
}
r.state.RLock()
wr := r.state.writerAt
rd := r.state.readerAt
rw := r.state.writerReaderAt
r.state.RUnlock()
rd, wr, rw := r.getAllReaderWriters()
if t, ok := wr.(TransferError); ok {
t.TransferError(err)
@ -219,8 +295,7 @@ func (r *Request) call(handlers Handlers, pkt requestPacket, alloc *allocator, o
case "Stat", "Lstat", "Readlink":
return filestat(handlers.FileList, r, pkt)
default:
return statusFromError(pkt.id(),
fmt.Errorf("unexpected method: %s", r.Method))
return statusFromError(pkt.id(), fmt.Errorf("unexpected method: %s", r.Method))
}
}
@ -239,8 +314,13 @@ func (r *Request) open(h Handlers, pkt requestPacket) responsePacket {
if err != nil {
return statusFromError(id, err)
}
r.state.writerReaderAt = rw
return &sshFxpHandlePacket{ID: id, Handle: r.handle}
r.setWriterAtReaderAt(rw)
return &sshFxpHandlePacket{
ID: id,
Handle: r.handle,
}
}
}
@ -249,18 +329,26 @@ func (r *Request) open(h Handlers, pkt requestPacket) responsePacket {
if err != nil {
return statusFromError(id, err)
}
r.state.writerAt = wr
r.setWriterAt(wr)
case flags.Read:
r.Method = "Get"
rd, err := h.FileGet.Fileread(r)
if err != nil {
return statusFromError(id, err)
}
r.state.readerAt = rd
r.setReaderAt(rd)
default:
return statusFromError(id, errors.New("bad file flags"))
}
return &sshFxpHandlePacket{ID: id, Handle: r.handle}
return &sshFxpHandlePacket{
ID: id,
Handle: r.handle,
}
}
func (r *Request) opendir(h Handlers, pkt requestPacket) responsePacket {
@ -269,25 +357,30 @@ func (r *Request) opendir(h Handlers, pkt requestPacket) responsePacket {
if err != nil {
return statusFromError(pkt.id(), wrapPathError(r.Filepath, err))
}
r.state.listerAt = la
return &sshFxpHandlePacket{ID: pkt.id(), Handle: r.handle}
r.setListerAt(la)
return &sshFxpHandlePacket{
ID: pkt.id(),
Handle: r.handle,
}
}
// wrap FileReader handler
func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
r.state.RLock()
reader := r.state.readerAt
r.state.RUnlock()
if reader == nil {
rd := r.getReaderAt()
if rd == nil {
return statusFromError(pkt.id(), errors.New("unexpected read packet"))
}
data, offset, _ := packetData(pkt, alloc, orderID)
n, err := reader.ReadAt(data, offset)
n, err := rd.ReadAt(data, offset)
// only return EOF error if no data left to read
if err != nil && (err != io.EOF || n == 0) {
return statusFromError(pkt.id(), err)
}
return &sshFxpDataPacket{
ID: pkt.id(),
Length: uint32(n),
@ -297,43 +390,46 @@ func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orde
// wrap FileWriter handler
func fileput(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
r.state.RLock()
writer := r.state.writerAt
r.state.RUnlock()
if writer == nil {
wr := r.getWriterAt()
if wr == nil {
return statusFromError(pkt.id(), errors.New("unexpected write packet"))
}
data, offset, _ := packetData(pkt, alloc, orderID)
_, err := writer.WriteAt(data, offset)
_, err := wr.WriteAt(data, offset)
return statusFromError(pkt.id(), err)
}
// wrap OpenFileWriter handler
func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
r.state.RLock()
writerReader := r.state.writerReaderAt
r.state.RUnlock()
if writerReader == nil {
rw := r.getWriterAtReaderAt()
if rw == nil {
return statusFromError(pkt.id(), errors.New("unexpected write and read packet"))
}
switch p := pkt.(type) {
case *sshFxpReadPacket:
data, offset := p.getDataSlice(alloc, orderID), int64(p.Offset)
n, err := writerReader.ReadAt(data, offset)
n, err := rw.ReadAt(data, offset)
// only return EOF error if no data left to read
if err != nil && (err != io.EOF || n == 0) {
return statusFromError(pkt.id(), err)
}
return &sshFxpDataPacket{
ID: pkt.id(),
Length: uint32(n),
Data: data[:n],
}
case *sshFxpWritePacket:
data, offset := p.Data, int64(p.Offset)
_, err := writerReader.WriteAt(data, offset)
_, err := rw.WriteAt(data, offset)
return statusFromError(pkt.id(), err)
default:
return statusFromError(pkt.id(), errors.New("unexpected packet type for read or write"))
}
@ -358,7 +454,8 @@ func filecmd(h FileCmder, r *Request, pkt requestPacket) responsePacket {
r.Attrs = p.Attrs.([]byte)
}
if r.Method == "PosixRename" {
switch r.Method {
case "PosixRename":
if posixRenamer, ok := h.(PosixRenameFileCmder); ok {
err := posixRenamer.PosixRename(r)
return statusFromError(pkt.id(), err)
@ -368,9 +465,8 @@ func filecmd(h FileCmder, r *Request, pkt requestPacket) responsePacket {
r.Method = "Rename"
err := h.Filecmd(r)
return statusFromError(pkt.id(), err)
}
if r.Method == "StatVFS" {
case "StatVFS":
if statVFSCmdr, ok := h.(StatVFSFileCmder); ok {
stat, err := statVFSCmdr.StatVFS(r)
if err != nil {
@ -389,8 +485,7 @@ func filecmd(h FileCmder, r *Request, pkt requestPacket) responsePacket {
// wrap FileLister handler
func filelist(h FileLister, r *Request, pkt requestPacket) responsePacket {
var err error
lister := r.getLister()
lister := r.getListerAt()
if lister == nil {
return statusFromError(pkt.id(), errors.New("unexpected dir packet"))
}
@ -404,23 +499,25 @@ func filelist(h FileLister, r *Request, pkt requestPacket) responsePacket {
switch r.Method {
case "List":
if err != nil && err != io.EOF {
if err != nil && (err != io.EOF || n == 0) {
return statusFromError(pkt.id(), err)
}
if err == io.EOF && n == 0 {
return statusFromError(pkt.id(), io.EOF)
}
dirname := filepath.ToSlash(path.Base(r.Filepath))
ret := &sshFxpNamePacket{ID: pkt.id()}
nameAttrs := make([]*sshFxpNameAttr, 0, len(finfo))
for _, fi := range finfo {
ret.NameAttrs = append(ret.NameAttrs, &sshFxpNameAttr{
nameAttrs = append(nameAttrs, &sshFxpNameAttr{
Name: fi.Name(),
LongName: runLs(dirname, fi),
LongName: runLs(fi),
Attrs: []interface{}{fi},
})
}
return ret
return &sshFxpNamePacket{
ID: pkt.id(),
NameAttrs: nameAttrs,
}
default:
err = fmt.Errorf("unexpected method: %s", r.Method)
return statusFromError(pkt.id(), err)
@ -455,8 +552,11 @@ func filestat(h FileLister, r *Request, pkt requestPacket) responsePacket {
return statusFromError(pkt.id(), err)
}
if n == 0 {
err = &os.PathError{Op: strings.ToLower(r.Method), Path: r.Filepath,
Err: syscall.ENOENT}
err = &os.PathError{
Op: strings.ToLower(r.Method),
Path: r.Filepath,
Err: syscall.ENOENT,
}
return statusFromError(pkt.id(), err)
}
return &sshFxpStatResponse{
@ -468,8 +568,11 @@ func filestat(h FileLister, r *Request, pkt requestPacket) responsePacket {
return statusFromError(pkt.id(), err)
}
if n == 0 {
err = &os.PathError{Op: "readlink", Path: r.Filepath,
Err: syscall.ENOENT}
err = &os.PathError{
Op: "readlink",
Path: r.Filepath,
Err: syscall.ENOENT,
}
return statusFromError(pkt.id(), err)
}
filename := finfo[0].Name()

View File

@ -1,15 +1,13 @@
package sftp
import (
"sync"
"github.com/stretchr/testify/assert"
"bytes"
"errors"
"io"
"os"
"testing"
"github.com/stretchr/testify/assert"
)
type testHandler struct {
@ -75,7 +73,6 @@ func testRequest(method string) *Request {
Attrs: []byte("foo"),
Flags: flags,
Target: "foo",
state: state{RWMutex: new(sync.RWMutex)},
}
return request
}

View File

@ -461,7 +461,6 @@ func (p *sshFxpReaddirPacket) respond(svr *Server) responsePacket {
return statusFromError(p.ID, EBADF)
}
dirname := f.Name()
dirents, err := f.Readdir(128)
if err != nil {
return statusFromError(p.ID, err)
@ -471,7 +470,7 @@ func (p *sshFxpReaddirPacket) respond(svr *Server) responsePacket {
for _, dirent := range dirents {
ret.NameAttrs = append(ret.NameAttrs, &sshFxpNameAttr{
Name: dirent.Name(),
LongName: runLs(dirname, dirent),
LongName: runLs(dirent),
Attrs: []interface{}{dirent},
})
}

View File

@ -8,7 +8,7 @@ import (
"time"
)
func runLs(dirname string, dirent os.FileInfo) string {
func runLs(dirent os.FileInfo) string {
typeword := runLsTypeWord(dirent)
numLinks := 1
if dirent.IsDir() {

View File

@ -25,14 +25,14 @@ const (
func TestRunLsWithExamplesDirectory(t *testing.T) {
path := "examples"
item, _ := os.Stat(path)
result := runLs(path, item)
result := runLs(item)
runLsTestHelper(t, result, typeDirectory, path)
}
func TestRunLsWithLicensesFile(t *testing.T) {
path := "LICENSE"
item, _ := os.Stat(path)
result := runLs(path, item)
result := runLs(item)
runLsTestHelper(t, result, typeFile, path)
}
@ -79,61 +79,61 @@ func runLsTestHelper(t *testing.T, result, expectedType, path string) {
// permissions (len 10, "drwxr-xr-x")
got := result[0:10]
if ok, err := regexp.MatchString("^"+expectedType+"[rwx-]{9}$", got); !ok {
t.Errorf("runLs(%#v, *FileInfo): permission field mismatch, expected dir, got: %#v, err: %#v", path, got, err)
t.Errorf("runLs(*FileInfo): permission field mismatch, expected dir, got: %#v, err: %#v", got, err)
}
// space
got = result[10:11]
if ok, err := regexp.MatchString("^\\s$", got); !ok {
t.Errorf("runLs(%#v, *FileInfo): spacer 1 mismatch, expected whitespace, got: %#v, err: %#v", path, got, err)
t.Errorf("runLs(*FileInfo): spacer 1 mismatch, expected whitespace, got: %#v, err: %#v", got, err)
}
// link count (len 3, number)
got = result[12:15]
if ok, err := regexp.MatchString("^\\s*[0-9]+$", got); !ok {
t.Errorf("runLs(%#v, *FileInfo): link count field mismatch, got: %#v, err: %#v", path, got, err)
t.Errorf("runLs(*FileInfo): link count field mismatch, got: %#v, err: %#v", got, err)
}
// spacer
got = result[15:16]
if ok, err := regexp.MatchString("^\\s$", got); !ok {
t.Errorf("runLs(%#v, *FileInfo): spacer 2 mismatch, expected whitespace, got: %#v, err: %#v", path, got, err)
t.Errorf("runLs(*FileInfo): spacer 2 mismatch, expected whitespace, got: %#v, err: %#v", got, err)
}
// username / uid (len 8, number or string)
got = result[16:24]
if ok, err := regexp.MatchString("^[^\\s]{1,8}\\s*$", got); !ok {
t.Errorf("runLs(%#v, *FileInfo): username / uid mismatch, expected user, got: %#v, err: %#v", path, got, err)
t.Errorf("runLs(*FileInfo): username / uid mismatch, expected user, got: %#v, err: %#v", got, err)
}
// spacer
got = result[24:25]
if ok, err := regexp.MatchString("^\\s$", got); !ok {
t.Errorf("runLs(%#v, *FileInfo): spacer 3 mismatch, expected whitespace, got: %#v, err: %#v", path, got, err)
t.Errorf("runLs(*FileInfo): spacer 3 mismatch, expected whitespace, got: %#v, err: %#v", got, err)
}
// groupname / gid (len 8, number or string)
got = result[25:33]
if ok, err := regexp.MatchString("^[^\\s]{1,8}\\s*$", got); !ok {
t.Errorf("runLs(%#v, *FileInfo): groupname / gid mismatch, expected group, got: %#v, err: %#v", path, got, err)
t.Errorf("runLs(*FileInfo): groupname / gid mismatch, expected group, got: %#v, err: %#v", got, err)
}
// spacer
got = result[33:34]
if ok, err := regexp.MatchString("^\\s$", got); !ok {
t.Errorf("runLs(%#v, *FileInfo): spacer 4 mismatch, expected whitespace, got: %#v, err: %#v", path, got, err)
t.Errorf("runLs(*FileInfo): spacer 4 mismatch, expected whitespace, got: %#v, err: %#v", got, err)
}
// filesize (len 8)
got = result[34:42]
if ok, err := regexp.MatchString("^\\s*[0-9]+$", got); !ok {
t.Errorf("runLs(%#v, *FileInfo): filesize field mismatch, expected size in bytes, got: %#v, err: %#v", path, got, err)
t.Errorf("runLs(*FileInfo): filesize field mismatch, expected size in bytes, got: %#v, err: %#v", got, err)
}
// spacer
got = result[42:43]
if ok, err := regexp.MatchString("^\\s$", got); !ok {
t.Errorf("runLs(%#v, *FileInfo): spacer 5 mismatch, expected whitespace, got: %#v, err: %#v", path, got, err)
t.Errorf("runLs(*FileInfo): spacer 5 mismatch, expected whitespace, got: %#v, err: %#v", got, err)
}
// mod time (len 12, e.g. Aug 9 19:46)
@ -146,19 +146,19 @@ func runLsTestHelper(t *testing.T, result, expectedType, path string) {
_, err = time.Parse(layout, got)
}
if err != nil {
t.Errorf("runLs(%#v, *FileInfo): mod time field mismatch, expected date layout %s, got: %#v, err: %#v", path, layout, got, err)
t.Errorf("runLs(*FileInfo): mod time field mismatch, expected date layout %s, got: %#v, err: %#v", layout, got, err)
}
// spacer
got = result[55:56]
if ok, err := regexp.MatchString("^\\s$", got); !ok {
t.Errorf("runLs(%#v, *FileInfo): spacer 6 mismatch, expected whitespace, got: %#v, err: %#v", path, got, err)
t.Errorf("runLs(*FileInfo): spacer 6 mismatch, expected whitespace, got: %#v, err: %#v", got, err)
}
// filename
got = result[56:]
if ok, err := regexp.MatchString("^"+path+"$", got); !ok {
t.Errorf("runLs(%#v, *FileInfo): name field mismatch, expected examples, got: %#v, err: %#v", path, got, err)
t.Errorf("runLs(*FileInfo): name field mismatch, expected examples, got: %#v, err: %#v", got, err)
}
}

View File

@ -12,7 +12,7 @@ import (
// ls -l style output for a file, which is in the 'long output' section of a readdir response packet
// this is a very simple (lazy) implementation, just enough to look almost like openssh in a few basic cases
func runLs(dirname string, dirent os.FileInfo) string {
func runLs(dirent os.FileInfo) string {
// example from openssh sftp server:
// crw-rw-rw- 1 root wheel 0 Jul 31 20:52 ttyvd
// format: