diff --git a/server.go b/server.go index c0665ed..a09741c 100644 --- a/server.go +++ b/server.go @@ -615,13 +615,15 @@ func statusFromError(id uint32, err error) *sshFxpStatusPacket { return ret } - switch e := err.(type) { - case fxerr: + if errors.Is(err, io.EOF) { + ret.StatusError.Code = sshFxEOF + return ret + } + + var e fxerr + if errors.As(err, &e) { ret.StatusError.Code = uint32(e) - default: - if e == io.EOF { - ret.StatusError.Code = sshFxEOF - } + return ret } return ret diff --git a/sftp_test.go b/sftp_test.go index 487b84d..18eed5e 100644 --- a/sftp_test.go +++ b/sftp_test.go @@ -2,6 +2,7 @@ package sftp import ( "errors" + "fmt" "io" "syscall" "testing" @@ -19,6 +20,8 @@ func TestErrFxCode(t *testing.T) { {err: syscall.ENOENT, fx: ErrSSHFxNoSuchFile}, {err: syscall.EPERM, fx: ErrSSHFxPermissionDenied}, {err: io.EOF, fx: ErrSSHFxEOF}, + {err: fmt.Errorf("wrapped permission denied error: %w", ErrSSHFxPermissionDenied), fx: ErrSSHFxPermissionDenied}, + {err: fmt.Errorf("wrapped op unsupported error: %w", ErrSSHFxOpUnsupported), fx: ErrSSHFxOpUnsupported}, } for _, tt := range table { statusErr := statusFromError(1, tt.err).StatusError