mirror of https://github.com/pkg/sftp.git
Merge pull request #571 from powellnorma/win-root
server.go: "/" for windows
This commit is contained in:
commit
088878ba50
|
@ -20,10 +20,13 @@ func main() {
|
||||||
var (
|
var (
|
||||||
readOnly bool
|
readOnly bool
|
||||||
debugStderr bool
|
debugStderr bool
|
||||||
|
winRoot bool
|
||||||
)
|
)
|
||||||
|
|
||||||
flag.BoolVar(&readOnly, "R", false, "read-only server")
|
flag.BoolVar(&readOnly, "R", false, "read-only server")
|
||||||
flag.BoolVar(&debugStderr, "e", false, "debug to stderr")
|
flag.BoolVar(&debugStderr, "e", false, "debug to stderr")
|
||||||
|
flag.BoolVar(&winRoot, "wr", false, "windows root")
|
||||||
|
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
debugStream := io.Discard
|
debugStream := io.Discard
|
||||||
|
@ -128,6 +131,11 @@ func main() {
|
||||||
fmt.Fprintf(debugStream, "Read write server\n")
|
fmt.Fprintf(debugStream, "Read write server\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if winRoot {
|
||||||
|
serverOptions = append(serverOptions, sftp.WindowsRootEnumeratesDrives())
|
||||||
|
fmt.Fprintf(debugStream, "Windows root enabled\n")
|
||||||
|
}
|
||||||
|
|
||||||
server, err := sftp.NewServer(
|
server, err := sftp.NewServer(
|
||||||
channel,
|
channel,
|
||||||
serverOptions...,
|
serverOptions...,
|
||||||
|
|
1
go.mod
1
go.mod
|
@ -6,4 +6,5 @@ require (
|
||||||
github.com/kr/fs v0.1.0
|
github.com/kr/fs v0.1.0
|
||||||
github.com/stretchr/testify v1.8.0
|
github.com/stretchr/testify v1.8.0
|
||||||
golang.org/x/crypto v0.31.0
|
golang.org/x/crypto v0.31.0
|
||||||
|
golang.org/x/sys v0.28.0 // indirect
|
||||||
)
|
)
|
||||||
|
|
36
server.go
36
server.go
|
@ -7,6 +7,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"io/fs"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
@ -21,6 +22,18 @@ const (
|
||||||
SftpServerWorkerCount = 8
|
SftpServerWorkerCount = 8
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type file interface {
|
||||||
|
Stat() (os.FileInfo, error)
|
||||||
|
ReadAt(b []byte, off int64) (int, error)
|
||||||
|
WriteAt(b []byte, off int64) (int, error)
|
||||||
|
Readdir(int) ([]os.FileInfo, error)
|
||||||
|
Name() string
|
||||||
|
Truncate(int64) error
|
||||||
|
Chmod(mode fs.FileMode) error
|
||||||
|
Chown(uid, gid int) error
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
// Server is an SSH File Transfer Protocol (sftp) server.
|
// Server is an SSH File Transfer Protocol (sftp) server.
|
||||||
// This is intended to provide the sftp subsystem to an ssh server daemon.
|
// This is intended to provide the sftp subsystem to an ssh server daemon.
|
||||||
// This implementation currently supports most of sftp server protocol version 3,
|
// This implementation currently supports most of sftp server protocol version 3,
|
||||||
|
@ -30,14 +43,15 @@ type Server struct {
|
||||||
debugStream io.Writer
|
debugStream io.Writer
|
||||||
readOnly bool
|
readOnly bool
|
||||||
pktMgr *packetManager
|
pktMgr *packetManager
|
||||||
openFiles map[string]*os.File
|
openFiles map[string]file
|
||||||
openFilesLock sync.RWMutex
|
openFilesLock sync.RWMutex
|
||||||
handleCount int
|
handleCount int
|
||||||
workDir string
|
workDir string
|
||||||
|
winRoot bool
|
||||||
maxTxPacket uint32
|
maxTxPacket uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (svr *Server) nextHandle(f *os.File) string {
|
func (svr *Server) nextHandle(f file) string {
|
||||||
svr.openFilesLock.Lock()
|
svr.openFilesLock.Lock()
|
||||||
defer svr.openFilesLock.Unlock()
|
defer svr.openFilesLock.Unlock()
|
||||||
svr.handleCount++
|
svr.handleCount++
|
||||||
|
@ -57,7 +71,7 @@ func (svr *Server) closeHandle(handle string) error {
|
||||||
return EBADF
|
return EBADF
|
||||||
}
|
}
|
||||||
|
|
||||||
func (svr *Server) getHandle(handle string) (*os.File, bool) {
|
func (svr *Server) getHandle(handle string) (file, bool) {
|
||||||
svr.openFilesLock.RLock()
|
svr.openFilesLock.RLock()
|
||||||
defer svr.openFilesLock.RUnlock()
|
defer svr.openFilesLock.RUnlock()
|
||||||
f, ok := svr.openFiles[handle]
|
f, ok := svr.openFiles[handle]
|
||||||
|
@ -86,7 +100,7 @@ func NewServer(rwc io.ReadWriteCloser, options ...ServerOption) (*Server, error)
|
||||||
serverConn: svrConn,
|
serverConn: svrConn,
|
||||||
debugStream: ioutil.Discard,
|
debugStream: ioutil.Discard,
|
||||||
pktMgr: newPktMgr(svrConn),
|
pktMgr: newPktMgr(svrConn),
|
||||||
openFiles: make(map[string]*os.File),
|
openFiles: make(map[string]file),
|
||||||
maxTxPacket: defaultMaxTxPacket,
|
maxTxPacket: defaultMaxTxPacket,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -118,6 +132,14 @@ func ReadOnly() ServerOption {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WindowsRootEnumeratesDrives configures a Server to serve a virtual '/' for windows that lists all drives
|
||||||
|
func WindowsRootEnumeratesDrives() ServerOption {
|
||||||
|
return func(s *Server) error {
|
||||||
|
s.winRoot = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithAllocator enable the allocator.
|
// WithAllocator enable the allocator.
|
||||||
// After processing a packet we keep in memory the allocated slices
|
// After processing a packet we keep in memory the allocated slices
|
||||||
// and we reuse them for new packets.
|
// and we reuse them for new packets.
|
||||||
|
@ -215,7 +237,7 @@ func handlePacket(s *Server, p orderedRequest) error {
|
||||||
}
|
}
|
||||||
case *sshFxpLstatPacket:
|
case *sshFxpLstatPacket:
|
||||||
// stat the requested file
|
// stat the requested file
|
||||||
info, err := os.Lstat(s.toLocalPath(p.Path))
|
info, err := s.lstat(s.toLocalPath(p.Path))
|
||||||
rpkt = &sshFxpStatResponse{
|
rpkt = &sshFxpStatResponse{
|
||||||
ID: p.ID,
|
ID: p.ID,
|
||||||
info: info,
|
info: info,
|
||||||
|
@ -289,7 +311,7 @@ func handlePacket(s *Server, p orderedRequest) error {
|
||||||
case *sshFxpOpendirPacket:
|
case *sshFxpOpendirPacket:
|
||||||
lp := s.toLocalPath(p.Path)
|
lp := s.toLocalPath(p.Path)
|
||||||
|
|
||||||
if stat, err := os.Stat(lp); err != nil {
|
if stat, err := s.stat(lp); err != nil {
|
||||||
rpkt = statusFromError(p.ID, err)
|
rpkt = statusFromError(p.ID, err)
|
||||||
} else if !stat.IsDir() {
|
} else if !stat.IsDir() {
|
||||||
rpkt = statusFromError(p.ID, &os.PathError{
|
rpkt = statusFromError(p.ID, &os.PathError{
|
||||||
|
@ -493,7 +515,7 @@ func (p *sshFxpOpenPacket) respond(svr *Server) responsePacket {
|
||||||
mode = fs.FileMode() & os.ModePerm
|
mode = fs.FileMode() & os.ModePerm
|
||||||
}
|
}
|
||||||
|
|
||||||
f, err := os.OpenFile(svr.toLocalPath(p.Path), osFlags, mode)
|
f, err := svr.openfile(svr.toLocalPath(p.Path), osFlags, mode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return statusFromError(p.ID, err)
|
return statusFromError(p.ID, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
//go:build !windows
|
||||||
|
// +build !windows
|
||||||
|
|
||||||
|
package sftp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/fs"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Server) openfile(path string, flag int, mode fs.FileMode) (file, error) {
|
||||||
|
return os.OpenFile(path, flag, mode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) lstat(name string) (os.FileInfo, error) {
|
||||||
|
return os.Lstat(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) stat(name string) (os.FileInfo, error) {
|
||||||
|
return os.Stat(name)
|
||||||
|
}
|
|
@ -1,8 +1,15 @@
|
||||||
package sftp
|
package sftp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/fs"
|
||||||
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) toLocalPath(p string) string {
|
func (s *Server) toLocalPath(p string) string {
|
||||||
|
@ -12,7 +19,11 @@ func (s *Server) toLocalPath(p string) string {
|
||||||
|
|
||||||
lp := filepath.FromSlash(p)
|
lp := filepath.FromSlash(p)
|
||||||
|
|
||||||
if path.IsAbs(p) {
|
if path.IsAbs(p) { // starts with '/'
|
||||||
|
if len(p) == 1 && s.winRoot {
|
||||||
|
return `\\.\` // for openfile
|
||||||
|
}
|
||||||
|
|
||||||
tmp := lp
|
tmp := lp
|
||||||
for len(tmp) > 0 && tmp[0] == '\\' {
|
for len(tmp) > 0 && tmp[0] == '\\' {
|
||||||
tmp = tmp[1:]
|
tmp = tmp[1:]
|
||||||
|
@ -33,7 +44,150 @@ func (s *Server) toLocalPath(p string) string {
|
||||||
// e.g. "/C:" to "C:\\"
|
// e.g. "/C:" to "C:\\"
|
||||||
return tmp
|
return tmp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.winRoot {
|
||||||
|
// Make it so that "/Windows" is not found, and "/c:/Windows" has to be used
|
||||||
|
return `\\.\` + tmp
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return lp
|
return lp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func bitsToDrives(bitmap uint32) []string {
|
||||||
|
var drive rune = 'a'
|
||||||
|
var drives []string
|
||||||
|
|
||||||
|
for bitmap != 0 && drive <= 'z' {
|
||||||
|
if bitmap&1 == 1 {
|
||||||
|
drives = append(drives, string(drive)+":")
|
||||||
|
}
|
||||||
|
drive++
|
||||||
|
bitmap >>= 1
|
||||||
|
}
|
||||||
|
|
||||||
|
return drives
|
||||||
|
}
|
||||||
|
|
||||||
|
func getDrives() ([]string, error) {
|
||||||
|
mask, err := windows.GetLogicalDrives()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("GetLogicalDrives: %w", err)
|
||||||
|
}
|
||||||
|
return bitsToDrives(mask), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type driveInfo struct {
|
||||||
|
fs.FileInfo
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *driveInfo) Name() string {
|
||||||
|
return i.name // since the Name() returned from a os.Stat("C:\\") is "\\"
|
||||||
|
}
|
||||||
|
|
||||||
|
type winRoot struct {
|
||||||
|
drives []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newWinRoot() (*winRoot, error) {
|
||||||
|
drives, err := getDrives()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &winRoot{
|
||||||
|
drives: drives,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *winRoot) Readdir(n int) ([]os.FileInfo, error) {
|
||||||
|
drives := f.drives
|
||||||
|
if n > 0 && len(drives) > n {
|
||||||
|
drives = drives[:n]
|
||||||
|
}
|
||||||
|
f.drives = f.drives[len(drives):]
|
||||||
|
if len(drives) == 0 {
|
||||||
|
return nil, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
var infos []os.FileInfo
|
||||||
|
for _, drive := range drives {
|
||||||
|
fi, err := os.Stat(drive + `\`)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
di := &driveInfo{
|
||||||
|
FileInfo: fi,
|
||||||
|
name: drive,
|
||||||
|
}
|
||||||
|
infos = append(infos, di)
|
||||||
|
}
|
||||||
|
|
||||||
|
return infos, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *winRoot) Stat() (os.FileInfo, error) {
|
||||||
|
return rootFileInfo, nil
|
||||||
|
}
|
||||||
|
func (f *winRoot) ReadAt(b []byte, off int64) (int, error) {
|
||||||
|
return 0, os.ErrPermission
|
||||||
|
}
|
||||||
|
func (f *winRoot) WriteAt(b []byte, off int64) (int, error) {
|
||||||
|
return 0, os.ErrPermission
|
||||||
|
}
|
||||||
|
func (f *winRoot) Name() string {
|
||||||
|
return "/"
|
||||||
|
}
|
||||||
|
func (f *winRoot) Truncate(int64) error {
|
||||||
|
return os.ErrPermission
|
||||||
|
}
|
||||||
|
func (f *winRoot) Chmod(mode fs.FileMode) error {
|
||||||
|
return os.ErrPermission
|
||||||
|
}
|
||||||
|
func (f *winRoot) Chown(uid, gid int) error {
|
||||||
|
return os.ErrPermission
|
||||||
|
}
|
||||||
|
func (f *winRoot) Close() error {
|
||||||
|
f.drives = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) openfile(path string, flag int, mode fs.FileMode) (file, error) {
|
||||||
|
if path == `\\.\` && s.winRoot {
|
||||||
|
return newWinRoot()
|
||||||
|
}
|
||||||
|
return os.OpenFile(path, flag, mode)
|
||||||
|
}
|
||||||
|
|
||||||
|
type winRootFileInfo struct {
|
||||||
|
name string
|
||||||
|
modTime time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *winRootFileInfo) Name() string { return w.name }
|
||||||
|
func (w *winRootFileInfo) Size() int64 { return 0 }
|
||||||
|
func (w *winRootFileInfo) Mode() fs.FileMode { return fs.ModeDir | 0555 } // read+execute for all
|
||||||
|
func (w *winRootFileInfo) ModTime() time.Time { return w.modTime }
|
||||||
|
func (w *winRootFileInfo) IsDir() bool { return true }
|
||||||
|
func (w *winRootFileInfo) Sys() interface{} { return nil }
|
||||||
|
|
||||||
|
// Create a new root FileInfo
|
||||||
|
var rootFileInfo = &winRootFileInfo{
|
||||||
|
name: "/",
|
||||||
|
modTime: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) lstat(name string) (os.FileInfo, error) {
|
||||||
|
if name == `\\.\` && s.winRoot {
|
||||||
|
return rootFileInfo, nil
|
||||||
|
}
|
||||||
|
return os.Lstat(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) stat(name string) (os.FileInfo, error) {
|
||||||
|
if name == `\\.\` && s.winRoot {
|
||||||
|
return rootFileInfo, nil
|
||||||
|
}
|
||||||
|
return os.Stat(name)
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue