Merge pull request #571 from powellnorma/win-root

server.go: "/" for windows
This commit is contained in:
Cassondra Foesch 2025-01-03 17:17:45 +00:00 committed by GitHub
commit 088878ba50
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 214 additions and 8 deletions

View File

@ -20,10 +20,13 @@ func main() {
var ( var (
readOnly bool readOnly bool
debugStderr bool debugStderr bool
winRoot bool
) )
flag.BoolVar(&readOnly, "R", false, "read-only server") flag.BoolVar(&readOnly, "R", false, "read-only server")
flag.BoolVar(&debugStderr, "e", false, "debug to stderr") flag.BoolVar(&debugStderr, "e", false, "debug to stderr")
flag.BoolVar(&winRoot, "wr", false, "windows root")
flag.Parse() flag.Parse()
debugStream := io.Discard debugStream := io.Discard
@ -128,6 +131,11 @@ func main() {
fmt.Fprintf(debugStream, "Read write server\n") fmt.Fprintf(debugStream, "Read write server\n")
} }
if winRoot {
serverOptions = append(serverOptions, sftp.WindowsRootEnumeratesDrives())
fmt.Fprintf(debugStream, "Windows root enabled\n")
}
server, err := sftp.NewServer( server, err := sftp.NewServer(
channel, channel,
serverOptions..., serverOptions...,

1
go.mod
View File

@ -6,4 +6,5 @@ require (
github.com/kr/fs v0.1.0 github.com/kr/fs v0.1.0
github.com/stretchr/testify v1.8.0 github.com/stretchr/testify v1.8.0
golang.org/x/crypto v0.31.0 golang.org/x/crypto v0.31.0
golang.org/x/sys v0.28.0 // indirect
) )

View File

@ -7,6 +7,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/fs"
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
@ -21,6 +22,18 @@ const (
SftpServerWorkerCount = 8 SftpServerWorkerCount = 8
) )
type file interface {
Stat() (os.FileInfo, error)
ReadAt(b []byte, off int64) (int, error)
WriteAt(b []byte, off int64) (int, error)
Readdir(int) ([]os.FileInfo, error)
Name() string
Truncate(int64) error
Chmod(mode fs.FileMode) error
Chown(uid, gid int) error
Close() error
}
// Server is an SSH File Transfer Protocol (sftp) server. // Server is an SSH File Transfer Protocol (sftp) server.
// This is intended to provide the sftp subsystem to an ssh server daemon. // This is intended to provide the sftp subsystem to an ssh server daemon.
// This implementation currently supports most of sftp server protocol version 3, // This implementation currently supports most of sftp server protocol version 3,
@ -30,14 +43,15 @@ type Server struct {
debugStream io.Writer debugStream io.Writer
readOnly bool readOnly bool
pktMgr *packetManager pktMgr *packetManager
openFiles map[string]*os.File openFiles map[string]file
openFilesLock sync.RWMutex openFilesLock sync.RWMutex
handleCount int handleCount int
workDir string workDir string
winRoot bool
maxTxPacket uint32 maxTxPacket uint32
} }
func (svr *Server) nextHandle(f *os.File) string { func (svr *Server) nextHandle(f file) string {
svr.openFilesLock.Lock() svr.openFilesLock.Lock()
defer svr.openFilesLock.Unlock() defer svr.openFilesLock.Unlock()
svr.handleCount++ svr.handleCount++
@ -57,7 +71,7 @@ func (svr *Server) closeHandle(handle string) error {
return EBADF return EBADF
} }
func (svr *Server) getHandle(handle string) (*os.File, bool) { func (svr *Server) getHandle(handle string) (file, bool) {
svr.openFilesLock.RLock() svr.openFilesLock.RLock()
defer svr.openFilesLock.RUnlock() defer svr.openFilesLock.RUnlock()
f, ok := svr.openFiles[handle] f, ok := svr.openFiles[handle]
@ -86,7 +100,7 @@ func NewServer(rwc io.ReadWriteCloser, options ...ServerOption) (*Server, error)
serverConn: svrConn, serverConn: svrConn,
debugStream: ioutil.Discard, debugStream: ioutil.Discard,
pktMgr: newPktMgr(svrConn), pktMgr: newPktMgr(svrConn),
openFiles: make(map[string]*os.File), openFiles: make(map[string]file),
maxTxPacket: defaultMaxTxPacket, maxTxPacket: defaultMaxTxPacket,
} }
@ -118,6 +132,14 @@ func ReadOnly() ServerOption {
} }
} }
// WindowsRootEnumeratesDrives configures a Server to serve a virtual '/' for windows that lists all drives
func WindowsRootEnumeratesDrives() ServerOption {
return func(s *Server) error {
s.winRoot = true
return nil
}
}
// WithAllocator enable the allocator. // WithAllocator enable the allocator.
// After processing a packet we keep in memory the allocated slices // After processing a packet we keep in memory the allocated slices
// and we reuse them for new packets. // and we reuse them for new packets.
@ -215,7 +237,7 @@ func handlePacket(s *Server, p orderedRequest) error {
} }
case *sshFxpLstatPacket: case *sshFxpLstatPacket:
// stat the requested file // stat the requested file
info, err := os.Lstat(s.toLocalPath(p.Path)) info, err := s.lstat(s.toLocalPath(p.Path))
rpkt = &sshFxpStatResponse{ rpkt = &sshFxpStatResponse{
ID: p.ID, ID: p.ID,
info: info, info: info,
@ -289,7 +311,7 @@ func handlePacket(s *Server, p orderedRequest) error {
case *sshFxpOpendirPacket: case *sshFxpOpendirPacket:
lp := s.toLocalPath(p.Path) lp := s.toLocalPath(p.Path)
if stat, err := os.Stat(lp); err != nil { if stat, err := s.stat(lp); err != nil {
rpkt = statusFromError(p.ID, err) rpkt = statusFromError(p.ID, err)
} else if !stat.IsDir() { } else if !stat.IsDir() {
rpkt = statusFromError(p.ID, &os.PathError{ rpkt = statusFromError(p.ID, &os.PathError{
@ -493,7 +515,7 @@ func (p *sshFxpOpenPacket) respond(svr *Server) responsePacket {
mode = fs.FileMode() & os.ModePerm mode = fs.FileMode() & os.ModePerm
} }
f, err := os.OpenFile(svr.toLocalPath(p.Path), osFlags, mode) f, err := svr.openfile(svr.toLocalPath(p.Path), osFlags, mode)
if err != nil { if err != nil {
return statusFromError(p.ID, err) return statusFromError(p.ID, err)
} }

21
server_posix.go Normal file
View File

@ -0,0 +1,21 @@
//go:build !windows
// +build !windows
package sftp
import (
"io/fs"
"os"
)
func (s *Server) openfile(path string, flag int, mode fs.FileMode) (file, error) {
return os.OpenFile(path, flag, mode)
}
func (s *Server) lstat(name string) (os.FileInfo, error) {
return os.Lstat(name)
}
func (s *Server) stat(name string) (os.FileInfo, error) {
return os.Stat(name)
}

View File

@ -1,8 +1,15 @@
package sftp package sftp
import ( import (
"fmt"
"io"
"io/fs"
"os"
"path" "path"
"path/filepath" "path/filepath"
"time"
"golang.org/x/sys/windows"
) )
func (s *Server) toLocalPath(p string) string { func (s *Server) toLocalPath(p string) string {
@ -12,7 +19,11 @@ func (s *Server) toLocalPath(p string) string {
lp := filepath.FromSlash(p) lp := filepath.FromSlash(p)
if path.IsAbs(p) { if path.IsAbs(p) { // starts with '/'
if len(p) == 1 && s.winRoot {
return `\\.\` // for openfile
}
tmp := lp tmp := lp
for len(tmp) > 0 && tmp[0] == '\\' { for len(tmp) > 0 && tmp[0] == '\\' {
tmp = tmp[1:] tmp = tmp[1:]
@ -33,7 +44,150 @@ func (s *Server) toLocalPath(p string) string {
// e.g. "/C:" to "C:\\" // e.g. "/C:" to "C:\\"
return tmp return tmp
} }
if s.winRoot {
// Make it so that "/Windows" is not found, and "/c:/Windows" has to be used
return `\\.\` + tmp
}
} }
return lp return lp
} }
func bitsToDrives(bitmap uint32) []string {
var drive rune = 'a'
var drives []string
for bitmap != 0 && drive <= 'z' {
if bitmap&1 == 1 {
drives = append(drives, string(drive)+":")
}
drive++
bitmap >>= 1
}
return drives
}
func getDrives() ([]string, error) {
mask, err := windows.GetLogicalDrives()
if err != nil {
return nil, fmt.Errorf("GetLogicalDrives: %w", err)
}
return bitsToDrives(mask), nil
}
type driveInfo struct {
fs.FileInfo
name string
}
func (i *driveInfo) Name() string {
return i.name // since the Name() returned from a os.Stat("C:\\") is "\\"
}
type winRoot struct {
drives []string
}
func newWinRoot() (*winRoot, error) {
drives, err := getDrives()
if err != nil {
return nil, err
}
return &winRoot{
drives: drives,
}, nil
}
func (f *winRoot) Readdir(n int) ([]os.FileInfo, error) {
drives := f.drives
if n > 0 && len(drives) > n {
drives = drives[:n]
}
f.drives = f.drives[len(drives):]
if len(drives) == 0 {
return nil, io.EOF
}
var infos []os.FileInfo
for _, drive := range drives {
fi, err := os.Stat(drive + `\`)
if err != nil {
return nil, err
}
di := &driveInfo{
FileInfo: fi,
name: drive,
}
infos = append(infos, di)
}
return infos, nil
}
func (f *winRoot) Stat() (os.FileInfo, error) {
return rootFileInfo, nil
}
func (f *winRoot) ReadAt(b []byte, off int64) (int, error) {
return 0, os.ErrPermission
}
func (f *winRoot) WriteAt(b []byte, off int64) (int, error) {
return 0, os.ErrPermission
}
func (f *winRoot) Name() string {
return "/"
}
func (f *winRoot) Truncate(int64) error {
return os.ErrPermission
}
func (f *winRoot) Chmod(mode fs.FileMode) error {
return os.ErrPermission
}
func (f *winRoot) Chown(uid, gid int) error {
return os.ErrPermission
}
func (f *winRoot) Close() error {
f.drives = nil
return nil
}
func (s *Server) openfile(path string, flag int, mode fs.FileMode) (file, error) {
if path == `\\.\` && s.winRoot {
return newWinRoot()
}
return os.OpenFile(path, flag, mode)
}
type winRootFileInfo struct {
name string
modTime time.Time
}
func (w *winRootFileInfo) Name() string { return w.name }
func (w *winRootFileInfo) Size() int64 { return 0 }
func (w *winRootFileInfo) Mode() fs.FileMode { return fs.ModeDir | 0555 } // read+execute for all
func (w *winRootFileInfo) ModTime() time.Time { return w.modTime }
func (w *winRootFileInfo) IsDir() bool { return true }
func (w *winRootFileInfo) Sys() interface{} { return nil }
// Create a new root FileInfo
var rootFileInfo = &winRootFileInfo{
name: "/",
modTime: time.Now(),
}
func (s *Server) lstat(name string) (os.FileInfo, error) {
if name == `\\.\` && s.winRoot {
return rootFileInfo, nil
}
return os.Lstat(name)
}
func (s *Server) stat(name string) (os.FileInfo, error) {
if name == `\\.\` && s.winRoot {
return rootFileInfo, nil
}
return os.Stat(name)
}