make SFTP extensions configurable

This commit is contained in:
Nicola Murino 2019-11-11 18:04:56 +01:00
parent 515578a967
commit 429ee173fb
2 changed files with 81 additions and 4 deletions

34
sftp.go
View File

@ -85,10 +85,14 @@ const (
sshFxfExcl = 0x00000020
)
var sftpExtensions = []sshExtensionPair{
var (
// supportedSFTPExtensions defines the supported extensions
supportedSFTPExtensions = []sshExtensionPair{
{"hardlink@openssh.com", "1"},
{"posix-rename@openssh.com", "1"},
}
}
sftpExtensions = supportedSFTPExtensions
)
type fxp uint8
@ -227,3 +231,29 @@ func (s *StatusError) Error() string {
func (s *StatusError) FxCode() fxerr {
return fxerr(s.Code)
}
func getSupportedExtensionByName(extensionName string) (sshExtensionPair, error) {
for _, supportedExtension := range supportedSFTPExtensions {
if supportedExtension.Name == extensionName {
return supportedExtension, nil
}
}
return sshExtensionPair{}, fmt.Errorf("Unsupported extension: %v", extensionName)
}
// SetSFTPExtensions allows to customize the supported server extensions.
// See the variable supportedSFTPExtensions for supported extensions.
// This method accepts a slice of sshExtensionPair names for example 'hardlink@openssh.com'.
// If an invalid extension is given an error will be returned and nothing will be changed
func SetSFTPExtensions(extensions ...string) error {
tempExtensions := []sshExtensionPair{}
for _, extension := range extensions {
sftpExtension, err := getSupportedExtensionByName(extension)
if err != nil {
return err
}
tempExtensions = append(tempExtensions, sftpExtension)
}
sftpExtensions = tempExtensions
return nil
}

View File

@ -26,3 +26,50 @@ func TestErrFxCode(t *testing.T) {
assert.Equal(t, statusErr.FxCode(), tt.fx)
}
}
func TestSupportedExtensions(t *testing.T) {
for _, supportedExtension := range supportedSFTPExtensions {
_, err := getSupportedExtensionByName(supportedExtension.Name)
assert.NoError(t, err)
}
_, err := getSupportedExtensionByName("invalid@example.com")
assert.Error(t, err)
}
func TestExtensions(t *testing.T) {
var supportedExtensions []string
for _, supportedExtension := range supportedSFTPExtensions {
supportedExtensions = append(supportedExtensions, supportedExtension.Name)
}
testSFTPExtensions := []string{"hardlink@openssh.com"}
expectedSFTPExtensions := []sshExtensionPair{
{"hardlink@openssh.com", "1"},
}
err := SetSFTPExtensions(testSFTPExtensions...)
assert.NoError(t, err)
assert.Equal(t, expectedSFTPExtensions, sftpExtensions)
invalidSFTPExtensions := []string{"invalid@example.com"}
err = SetSFTPExtensions(invalidSFTPExtensions...)
assert.Error(t, err)
assert.Equal(t, expectedSFTPExtensions, sftpExtensions)
emptySFTPExtensions := []string{}
expectedSFTPExtensions = []sshExtensionPair{}
err = SetSFTPExtensions(emptySFTPExtensions...)
assert.NoError(t, err)
assert.Equal(t, expectedSFTPExtensions, sftpExtensions)
// if we only have an invalid extension nothing will be modified.
invalidSFTPExtensions = []string{
"hardlink@openssh.com",
"invalid@example.com",
}
err = SetSFTPExtensions(invalidSFTPExtensions...)
assert.Error(t, err)
assert.Equal(t, expectedSFTPExtensions, sftpExtensions)
err = SetSFTPExtensions(supportedExtensions...)
assert.Equal(t, supportedSFTPExtensions, sftpExtensions)
}