mirror of https://github.com/pkg/sftp.git
New method Client.Extensions to list server extensions
This commit is contained in:
parent
fcaa492add
commit
265b8168fd
25
client.go
25
client.go
|
@ -155,6 +155,9 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie
|
|||
inflight: make(map[uint32]chan<- result),
|
||||
closed: make(chan struct{}),
|
||||
},
|
||||
|
||||
ext: make(map[string]string),
|
||||
|
||||
maxPacket: 1 << 15,
|
||||
maxConcurrentRequests: 64,
|
||||
}
|
||||
|
@ -183,6 +186,8 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie
|
|||
type Client struct {
|
||||
clientConn
|
||||
|
||||
ext map[string]string // Extensions (name -> data).
|
||||
|
||||
maxPacket int // max packet size read or written.
|
||||
maxConcurrentRequests int
|
||||
nextid uint32
|
||||
|
@ -223,14 +228,32 @@ func (c *Client) recvVersion() error {
|
|||
return &unexpectedPacketErr{sshFxpVersion, typ}
|
||||
}
|
||||
|
||||
version, _ := unmarshalUint32(data)
|
||||
version, data := unmarshalUint32(data)
|
||||
if version != sftpProtocolVersion {
|
||||
return &unexpectedVersionErr{sftpProtocolVersion, version}
|
||||
}
|
||||
|
||||
for len(data) > 0 {
|
||||
var ext extensionPair
|
||||
ext, data, err = unmarshalExtensionPair(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.ext[ext.Name] = ext.Data
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasExtension checks whether the server supports a named extension.
|
||||
//
|
||||
// The first return value is the extension data reported by the server
|
||||
// (typically a version number).
|
||||
func (c *Client) HasExtension(name string) (string, bool) {
|
||||
data, ok := c.ext[name]
|
||||
return data, ok
|
||||
}
|
||||
|
||||
// Walk returns a new Walker rooted at root.
|
||||
func (c *Client) Walk(root string) *fs.Walker {
|
||||
return fs.WalkFS(root, c)
|
||||
|
|
|
@ -3,6 +3,8 @@ package sftp
|
|||
import (
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClientStatVFS(t *testing.T) {
|
||||
|
@ -13,6 +15,9 @@ func TestClientStatVFS(t *testing.T) {
|
|||
defer cmd.Wait()
|
||||
defer sftp.Close()
|
||||
|
||||
_, ok := sftp.HasExtension("statvfs@openssh.com")
|
||||
require.True(t, ok, "server doesn't list statvfs extension")
|
||||
|
||||
vfs, err := sftp.StatVFS("/")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
|
Loading…
Reference in New Issue