mirror of https://github.com/pkg/sftp.git
make SFTP extensions configurable
This commit is contained in:
parent
515578a967
commit
429ee173fb
38
sftp.go
38
sftp.go
|
@ -85,10 +85,14 @@ const (
|
|||
sshFxfExcl = 0x00000020
|
||||
)
|
||||
|
||||
var sftpExtensions = []sshExtensionPair{
|
||||
{"hardlink@openssh.com", "1"},
|
||||
{"posix-rename@openssh.com", "1"},
|
||||
}
|
||||
var (
|
||||
// supportedSFTPExtensions defines the supported extensions
|
||||
supportedSFTPExtensions = []sshExtensionPair{
|
||||
{"hardlink@openssh.com", "1"},
|
||||
{"posix-rename@openssh.com", "1"},
|
||||
}
|
||||
sftpExtensions = supportedSFTPExtensions
|
||||
)
|
||||
|
||||
type fxp uint8
|
||||
|
||||
|
@ -227,3 +231,29 @@ func (s *StatusError) Error() string {
|
|||
func (s *StatusError) FxCode() fxerr {
|
||||
return fxerr(s.Code)
|
||||
}
|
||||
|
||||
func getSupportedExtensionByName(extensionName string) (sshExtensionPair, error) {
|
||||
for _, supportedExtension := range supportedSFTPExtensions {
|
||||
if supportedExtension.Name == extensionName {
|
||||
return supportedExtension, nil
|
||||
}
|
||||
}
|
||||
return sshExtensionPair{}, fmt.Errorf("Unsupported extension: %v", extensionName)
|
||||
}
|
||||
|
||||
// SetSFTPExtensions allows to customize the supported server extensions.
|
||||
// See the variable supportedSFTPExtensions for supported extensions.
|
||||
// This method accepts a slice of sshExtensionPair names for example 'hardlink@openssh.com'.
|
||||
// If an invalid extension is given an error will be returned and nothing will be changed
|
||||
func SetSFTPExtensions(extensions ...string) error {
|
||||
tempExtensions := []sshExtensionPair{}
|
||||
for _, extension := range extensions {
|
||||
sftpExtension, err := getSupportedExtensionByName(extension)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tempExtensions = append(tempExtensions, sftpExtension)
|
||||
}
|
||||
sftpExtensions = tempExtensions
|
||||
return nil
|
||||
}
|
||||
|
|
47
sftp_test.go
47
sftp_test.go
|
@ -26,3 +26,50 @@ func TestErrFxCode(t *testing.T) {
|
|||
assert.Equal(t, statusErr.FxCode(), tt.fx)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportedExtensions(t *testing.T) {
|
||||
for _, supportedExtension := range supportedSFTPExtensions {
|
||||
_, err := getSupportedExtensionByName(supportedExtension.Name)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
_, err := getSupportedExtensionByName("invalid@example.com")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestExtensions(t *testing.T) {
|
||||
var supportedExtensions []string
|
||||
for _, supportedExtension := range supportedSFTPExtensions {
|
||||
supportedExtensions = append(supportedExtensions, supportedExtension.Name)
|
||||
}
|
||||
|
||||
testSFTPExtensions := []string{"hardlink@openssh.com"}
|
||||
expectedSFTPExtensions := []sshExtensionPair{
|
||||
{"hardlink@openssh.com", "1"},
|
||||
}
|
||||
err := SetSFTPExtensions(testSFTPExtensions...)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedSFTPExtensions, sftpExtensions)
|
||||
|
||||
invalidSFTPExtensions := []string{"invalid@example.com"}
|
||||
err = SetSFTPExtensions(invalidSFTPExtensions...)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, expectedSFTPExtensions, sftpExtensions)
|
||||
|
||||
emptySFTPExtensions := []string{}
|
||||
expectedSFTPExtensions = []sshExtensionPair{}
|
||||
err = SetSFTPExtensions(emptySFTPExtensions...)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedSFTPExtensions, sftpExtensions)
|
||||
|
||||
// if we only have an invalid extension nothing will be modified.
|
||||
invalidSFTPExtensions = []string{
|
||||
"hardlink@openssh.com",
|
||||
"invalid@example.com",
|
||||
}
|
||||
err = SetSFTPExtensions(invalidSFTPExtensions...)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, expectedSFTPExtensions, sftpExtensions)
|
||||
|
||||
err = SetSFTPExtensions(supportedExtensions...)
|
||||
assert.Equal(t, supportedSFTPExtensions, sftpExtensions)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue