| 
									
										
										
										
											2015-05-28 07:32:43 +08:00
										 |  |  | /* | 
					
						
							|  |  |  | Copyright 2015 The Kubernetes Authors All rights reserved. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | Licensed under the Apache License, Version 2.0 (the "License"); | 
					
						
							|  |  |  | you may not use this file except in compliance with the License. | 
					
						
							|  |  |  | You may obtain a copy of the License at | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     http://www.apache.org/licenses/LICENSE-2.0
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  | distributed under the License is distributed on an "AS IS" BASIS, | 
					
						
							|  |  |  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
					
						
							|  |  |  | See the License for the specific language governing permissions and | 
					
						
							|  |  |  | limitations under the License. | 
					
						
							|  |  |  | */ | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-02-04 08:06:02 +08:00
										 |  |  | package ssh | 
					
						
							| 
									
										
										
										
											2015-05-28 07:32:43 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | import ( | 
					
						
							|  |  |  | 	"bytes" | 
					
						
							| 
									
										
										
										
											2015-05-29 02:45:08 +08:00
										 |  |  | 	"crypto/rand" | 
					
						
							|  |  |  | 	"crypto/rsa" | 
					
						
							| 
									
										
										
										
											2015-12-05 10:01:29 +08:00
										 |  |  | 	"crypto/tls" | 
					
						
							| 
									
										
										
										
											2015-05-29 02:45:08 +08:00
										 |  |  | 	"crypto/x509" | 
					
						
							|  |  |  | 	"encoding/pem" | 
					
						
							| 
									
										
										
										
											2015-06-09 06:45:20 +08:00
										 |  |  | 	"errors" | 
					
						
							| 
									
										
										
										
											2015-05-28 07:32:43 +08:00
										 |  |  | 	"fmt" | 
					
						
							|  |  |  | 	"io" | 
					
						
							|  |  |  | 	"io/ioutil" | 
					
						
							| 
									
										
										
										
											2015-05-29 02:45:08 +08:00
										 |  |  | 	mathrand "math/rand" | 
					
						
							| 
									
										
										
										
											2015-05-28 07:32:43 +08:00
										 |  |  | 	"net" | 
					
						
							| 
									
										
										
										
											2015-12-05 10:01:29 +08:00
										 |  |  | 	"net/http" | 
					
						
							|  |  |  | 	"net/url" | 
					
						
							| 
									
										
										
										
											2015-05-28 07:32:43 +08:00
										 |  |  | 	"os" | 
					
						
							| 
									
										
										
										
											2015-12-05 10:01:29 +08:00
										 |  |  | 	"sync" | 
					
						
							| 
									
										
										
										
											2015-06-03 00:52:35 +08:00
										 |  |  | 	"time" | 
					
						
							| 
									
										
										
										
											2015-05-28 07:32:43 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	"github.com/golang/glog" | 
					
						
							| 
									
										
										
										
											2015-06-16 08:13:11 +08:00
										 |  |  | 	"github.com/prometheus/client_golang/prometheus" | 
					
						
							| 
									
										
										
										
											2015-05-28 07:32:43 +08:00
										 |  |  | 	"golang.org/x/crypto/ssh" | 
					
						
							| 
									
										
										
										
											2015-12-05 10:01:29 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	utilnet "k8s.io/kubernetes/pkg/util/net" | 
					
						
							| 
									
										
										
										
											2016-01-15 15:32:10 +08:00
										 |  |  | 	"k8s.io/kubernetes/pkg/util/runtime" | 
					
						
							| 
									
										
										
										
											2016-02-04 08:06:02 +08:00
										 |  |  | 	"k8s.io/kubernetes/pkg/util/wait" | 
					
						
							| 
									
										
										
										
											2015-05-28 07:32:43 +08:00
										 |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-06-16 08:13:11 +08:00
										 |  |  | var ( | 
					
						
							|  |  |  | 	tunnelOpenCounter = prometheus.NewCounter( | 
					
						
							|  |  |  | 		prometheus.CounterOpts{ | 
					
						
							|  |  |  | 			Name: "ssh_tunnel_open_count", | 
					
						
							|  |  |  | 			Help: "Counter of ssh tunnel total open attempts", | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 	) | 
					
						
							|  |  |  | 	tunnelOpenFailCounter = prometheus.NewCounter( | 
					
						
							|  |  |  | 		prometheus.CounterOpts{ | 
					
						
							|  |  |  | 			Name: "ssh_tunnel_open_fail_count", | 
					
						
							|  |  |  | 			Help: "Counter of ssh tunnel failed open attempts", | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 	) | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func init() { | 
					
						
							|  |  |  | 	prometheus.MustRegister(tunnelOpenCounter) | 
					
						
							|  |  |  | 	prometheus.MustRegister(tunnelOpenFailCounter) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-05-28 07:32:43 +08:00
// SSHTunnel is a single SSH connection to Host:SSHPort, authenticated with
// Config, through which TCP connections can be dialed once Open has succeeded.
//
// TODO: Unit tests for this code, we can spin up a test SSH server with instructions here:
// https://godoc.org/golang.org/x/crypto/ssh#ServerConn
type SSHTunnel struct {
	Config  *ssh.ClientConfig // client config (user + auth) used by Open
	Host    string            // SSH server host, joined with SSHPort when dialing
	SSHPort string            // SSH server port
	running bool              // NOTE(review): not read or written by any code visible here
	sock    net.Listener      // NOTE(review): not read or written by any code visible here
	client  *ssh.Client       // set by Open; nil means the tunnel has not been opened
}
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (s *SSHTunnel) copyBytes(out io.Writer, in io.Reader) { | 
					
						
							|  |  |  | 	if _, err := io.Copy(out, in); err != nil { | 
					
						
							|  |  |  | 		glog.Errorf("Error in SSH tunnel: %v", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-05-28 12:38:21 +08:00
										 |  |  | func NewSSHTunnel(user, keyfile, host string) (*SSHTunnel, error) { | 
					
						
							| 
									
										
										
										
											2015-06-09 06:45:20 +08:00
										 |  |  | 	signer, err := MakePrivateKeySignerFromFile(keyfile) | 
					
						
							| 
									
										
										
										
											2015-05-28 07:32:43 +08:00
										 |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, err | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2015-06-09 06:45:20 +08:00
										 |  |  | 	return makeSSHTunnel(user, signer, host) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-06-19 07:31:54 +08:00
										 |  |  | func NewSSHTunnelFromBytes(user string, privateKey []byte, host string) (*SSHTunnel, error) { | 
					
						
							|  |  |  | 	signer, err := MakePrivateKeySignerFromBytes(privateKey) | 
					
						
							| 
									
										
										
										
											2015-06-09 06:45:20 +08:00
										 |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return makeSSHTunnel(user, signer, host) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func makeSSHTunnel(user string, signer ssh.Signer, host string) (*SSHTunnel, error) { | 
					
						
							| 
									
										
										
										
											2015-05-28 07:32:43 +08:00
										 |  |  | 	config := ssh.ClientConfig{ | 
					
						
							|  |  |  | 		User: user, | 
					
						
							|  |  |  | 		Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)}, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return &SSHTunnel{ | 
					
						
							| 
									
										
										
										
											2015-05-28 12:38:21 +08:00
										 |  |  | 		Config:  &config, | 
					
						
							|  |  |  | 		Host:    host, | 
					
						
							|  |  |  | 		SSHPort: "22", | 
					
						
							| 
									
										
										
										
											2015-05-28 07:32:43 +08:00
										 |  |  | 	}, nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (s *SSHTunnel) Open() error { | 
					
						
							|  |  |  | 	var err error | 
					
						
							| 
									
										
										
										
											2016-04-05 06:51:49 +08:00
										 |  |  | 	s.client, err = realTimeoutDialer.Dial("tcp", net.JoinHostPort(s.Host, s.SSHPort), s.Config) | 
					
						
							| 
									
										
										
										
											2015-06-16 08:13:11 +08:00
										 |  |  | 	tunnelOpenCounter.Inc() | 
					
						
							| 
									
										
										
										
											2015-05-28 07:32:43 +08:00
										 |  |  | 	if err != nil { | 
					
						
							| 
									
										
										
										
											2015-06-16 08:13:11 +08:00
										 |  |  | 		tunnelOpenFailCounter.Inc() | 
					
						
							| 
									
										
										
										
											2015-05-28 07:32:43 +08:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2015-12-05 10:01:29 +08:00
										 |  |  | 	return err | 
					
						
							| 
									
										
										
										
											2015-05-28 12:38:21 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (s *SSHTunnel) Dial(network, address string) (net.Conn, error) { | 
					
						
							| 
									
										
										
										
											2015-06-09 06:45:20 +08:00
										 |  |  | 	if s.client == nil { | 
					
						
							|  |  |  | 		return nil, errors.New("tunnel is not opened.") | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2015-05-28 12:38:21 +08:00
										 |  |  | 	return s.client.Dial(network, address) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (s *SSHTunnel) tunnel(conn net.Conn, remoteHost, remotePort string) error { | 
					
						
							| 
									
										
										
										
											2015-06-16 03:38:14 +08:00
										 |  |  | 	if s.client == nil { | 
					
						
							|  |  |  | 		return errors.New("tunnel is not opened.") | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2015-05-28 12:38:21 +08:00
										 |  |  | 	tunnel, err := s.client.Dial("tcp", net.JoinHostPort(remoteHost, remotePort)) | 
					
						
							| 
									
										
										
										
											2015-05-28 07:32:43 +08:00
										 |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	go s.copyBytes(tunnel, conn) | 
					
						
							|  |  |  | 	go s.copyBytes(conn, tunnel) | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-05-28 12:38:21 +08:00
										 |  |  | func (s *SSHTunnel) Close() error { | 
					
						
							| 
									
										
										
										
											2015-06-16 03:38:14 +08:00
										 |  |  | 	if s.client == nil { | 
					
						
							|  |  |  | 		return errors.New("Cannot close tunnel. Tunnel was not opened.") | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2015-05-28 07:32:43 +08:00
										 |  |  | 	if err := s.client.Close(); err != nil { | 
					
						
							|  |  |  | 		return err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-06-16 19:12:25 +08:00
										 |  |  | // Interface to allow mocking of ssh.Dial, for testing SSH
 | 
					
						
							|  |  |  | type sshDialer interface { | 
					
						
							|  |  |  | 	Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Real implementation of sshDialer
 | 
					
						
							|  |  |  | type realSSHDialer struct{} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-06-20 04:46:56 +08:00
										 |  |  | var _ sshDialer = &realSSHDialer{} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-06-16 19:12:25 +08:00
										 |  |  | func (d *realSSHDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { | 
					
						
							|  |  |  | 	return ssh.Dial(network, addr, config) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-04-05 06:51:49 +08:00
										 |  |  | // timeoutDialer wraps an sshDialer with a timeout around Dial(). The golang
 | 
					
						
							|  |  |  | // ssh library can hang indefinitely inside the Dial() call (see issue #23835).
 | 
					
						
							|  |  |  | // Wrapping all Dial() calls with a conservative timeout provides safety against
 | 
					
						
							|  |  |  | // getting stuck on that.
 | 
					
						
							|  |  |  | type timeoutDialer struct { | 
					
						
							|  |  |  | 	dialer  sshDialer | 
					
						
							|  |  |  | 	timeout time.Duration | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // 150 seconds is longer than the underlying default TCP backoff delay (127
 | 
					
						
							|  |  |  | // seconds). This timeout is only intended to catch otherwise uncaught hangs.
 | 
					
						
							|  |  |  | const sshDialTimeout = 150 * time.Second | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | var realTimeoutDialer sshDialer = &timeoutDialer{&realSSHDialer{}, sshDialTimeout} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (d *timeoutDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { | 
					
						
							|  |  |  | 	var client *ssh.Client | 
					
						
							|  |  |  | 	errCh := make(chan error, 1) | 
					
						
							|  |  |  | 	go func() { | 
					
						
							|  |  |  | 		defer runtime.HandleCrash() | 
					
						
							|  |  |  | 		var err error | 
					
						
							|  |  |  | 		client, err = d.dialer.Dial(network, addr, config) | 
					
						
							|  |  |  | 		errCh <- err | 
					
						
							|  |  |  | 	}() | 
					
						
							|  |  |  | 	select { | 
					
						
							|  |  |  | 	case err := <-errCh: | 
					
						
							|  |  |  | 		return client, err | 
					
						
							|  |  |  | 	case <-time.After(d.timeout): | 
					
						
							|  |  |  | 		return nil, fmt.Errorf("timed out dialing %s:%s", network, addr) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-06-17 15:13:26 +08:00
										 |  |  | // RunSSHCommand returns the stdout, stderr, and exit code from running cmd on
 | 
					
						
							| 
									
										
										
										
											2015-06-16 19:12:25 +08:00
										 |  |  | // host as specific user, along with any SSH-level error.
 | 
					
						
							|  |  |  | // If user=="", it will default (like SSH) to os.Getenv("USER")
 | 
					
						
							|  |  |  | func RunSSHCommand(cmd, user, host string, signer ssh.Signer) (string, string, int, error) { | 
					
						
							| 
									
										
										
										
											2016-04-05 06:51:49 +08:00
										 |  |  | 	return runSSHCommand(realTimeoutDialer, cmd, user, host, signer, true) | 
					
						
							| 
									
										
										
										
											2015-06-16 19:12:25 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Internal implementation of runSSHCommand, for testing
 | 
					
						
							| 
									
										
										
										
											2016-02-04 08:06:02 +08:00
										 |  |  | func runSSHCommand(dialer sshDialer, cmd, user, host string, signer ssh.Signer, retry bool) (string, string, int, error) { | 
					
						
							| 
									
										
										
										
											2015-06-16 19:12:25 +08:00
										 |  |  | 	if user == "" { | 
					
						
							|  |  |  | 		user = os.Getenv("USER") | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2015-05-28 07:32:43 +08:00
										 |  |  | 	// Setup the config, dial the server, and open a session.
 | 
					
						
							|  |  |  | 	config := &ssh.ClientConfig{ | 
					
						
							| 
									
										
										
										
											2015-06-16 19:12:25 +08:00
										 |  |  | 		User: user, | 
					
						
							| 
									
										
										
										
											2015-05-28 07:32:43 +08:00
										 |  |  | 		Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)}, | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2015-06-16 19:12:25 +08:00
										 |  |  | 	client, err := dialer.Dial("tcp", host, config) | 
					
						
							| 
									
										
										
										
											2016-02-04 08:06:02 +08:00
										 |  |  | 	if err != nil && retry { | 
					
						
							|  |  |  | 		err = wait.Poll(5*time.Second, 20*time.Second, func() (bool, error) { | 
					
						
							|  |  |  | 			fmt.Printf("error dialing %s@%s: '%v', retrying\n", user, host, err) | 
					
						
							|  |  |  | 			if client, err = dialer.Dial("tcp", host, config); err != nil { | 
					
						
							|  |  |  | 				return false, nil | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 			return true, nil | 
					
						
							|  |  |  | 		}) | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2015-05-28 07:32:43 +08:00
										 |  |  | 	if err != nil { | 
					
						
							| 
									
										
										
										
											2015-06-16 19:12:25 +08:00
										 |  |  | 		return "", "", 0, fmt.Errorf("error getting SSH client to %s@%s: '%v'", user, host, err) | 
					
						
							| 
									
										
										
										
											2015-05-28 07:32:43 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 	session, err := client.NewSession() | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							| 
									
										
										
										
											2015-06-16 19:12:25 +08:00
										 |  |  | 		return "", "", 0, fmt.Errorf("error creating session to %s@%s: '%v'", user, host, err) | 
					
						
							| 
									
										
										
										
											2015-05-28 07:32:43 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 	defer session.Close() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Run the command.
 | 
					
						
							|  |  |  | 	code := 0 | 
					
						
							|  |  |  | 	var bout, berr bytes.Buffer | 
					
						
							|  |  |  | 	session.Stdout, session.Stderr = &bout, &berr | 
					
						
							|  |  |  | 	if err = session.Run(cmd); err != nil { | 
					
						
							|  |  |  | 		// Check whether the command failed to run or didn't complete.
 | 
					
						
							|  |  |  | 		if exiterr, ok := err.(*ssh.ExitError); ok { | 
					
						
							|  |  |  | 			// If we got an ExitError and the exit code is nonzero, we'll
 | 
					
						
							|  |  |  | 			// consider the SSH itself successful (just that the command run
 | 
					
						
							|  |  |  | 			// errored on the host).
 | 
					
						
							|  |  |  | 			if code = exiterr.ExitStatus(); code != 0 { | 
					
						
							|  |  |  | 				err = nil | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 		} else { | 
					
						
							|  |  |  | 			// Some other kind of error happened (e.g. an IOError); consider the
 | 
					
						
							|  |  |  | 			// SSH unsuccessful.
 | 
					
						
							| 
									
										
										
										
											2015-06-16 19:12:25 +08:00
										 |  |  | 			err = fmt.Errorf("failed running `%s` on %s@%s: '%v'", cmd, user, host, err) | 
					
						
							| 
									
										
										
										
											2015-05-28 07:32:43 +08:00
										 |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return bout.String(), berr.String(), code, err | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-06-09 06:45:20 +08:00
										 |  |  | func MakePrivateKeySignerFromFile(key string) (ssh.Signer, error) { | 
					
						
							| 
									
										
										
										
											2015-05-28 07:32:43 +08:00
										 |  |  | 	// Create an actual signer.
 | 
					
						
							| 
									
										
										
										
											2015-06-18 02:49:13 +08:00
										 |  |  | 	buffer, err := ioutil.ReadFile(key) | 
					
						
							| 
									
										
										
										
											2015-05-28 07:32:43 +08:00
										 |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, fmt.Errorf("error reading SSH key %s: '%v'", key, err) | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2015-06-09 06:45:20 +08:00
										 |  |  | 	return MakePrivateKeySignerFromBytes(buffer) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func MakePrivateKeySignerFromBytes(buffer []byte) (ssh.Signer, error) { | 
					
						
							| 
									
										
										
										
											2015-05-28 07:32:43 +08:00
										 |  |  | 	signer, err := ssh.ParsePrivateKey(buffer) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							| 
									
										
										
										
											2015-06-09 06:45:20 +08:00
										 |  |  | 		return nil, fmt.Errorf("error parsing SSH key %s: '%v'", buffer, err) | 
					
						
							| 
									
										
										
										
											2015-05-28 07:32:43 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 	return signer, nil | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2015-05-28 12:38:21 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-06-18 02:49:13 +08:00
										 |  |  | func ParsePublicKeyFromFile(keyFile string) (*rsa.PublicKey, error) { | 
					
						
							|  |  |  | 	buffer, err := ioutil.ReadFile(keyFile) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, fmt.Errorf("error reading SSH key %s: '%v'", keyFile, err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	keyBlock, _ := pem.Decode(buffer) | 
					
						
							|  |  |  | 	key, err := x509.ParsePKIXPublicKey(keyBlock.Bytes) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, fmt.Errorf("error parsing SSH key %s: '%v'", keyFile, err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	rsaKey, ok := key.(*rsa.PublicKey) | 
					
						
							|  |  |  | 	if !ok { | 
					
						
							|  |  |  | 		return nil, fmt.Errorf("SSH key could not be parsed as rsa public key") | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return rsaKey, nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-12-05 10:01:29 +08:00
// tunnel is the set of *SSHTunnel operations that SSHTunnelList relies on.
// It is presumably an interface so that tests can substitute fakes.
type tunnel interface {
	// Open establishes the tunnel's underlying connection.
	Open() error
	// Close tears down the tunnel's underlying connection.
	Close() error
	// Dial opens a connection to address through the tunnel.
	Dial(network, address string) (net.Conn, error)
}

// sshTunnelEntry pairs a tunnel with the address it was created for. The
// address is reused to re-create the tunnel when it fails a health check.
type sshTunnelEntry struct {
	Address string
	Tunnel  tunnel
}
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | type sshTunnelCreator interface { | 
					
						
							|  |  |  | 	NewSSHTunnel(user, keyFile, healthCheckURL string) (tunnel, error) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | type realTunnelCreator struct{} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (*realTunnelCreator) NewSSHTunnel(user, keyFile, healthCheckURL string) (tunnel, error) { | 
					
						
							|  |  |  | 	return NewSSHTunnel(user, keyFile, healthCheckURL) | 
					
						
							| 
									
										
										
										
											2015-05-28 12:38:21 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-06-17 01:36:38 +08:00
// SSHTunnelList holds a set of SSH tunnels. Its constructor starts a
// periodic health check, and tunnels that fail it are removed and re-created.
type SSHTunnelList struct {
	entries       []sshTunnelEntry  // live tunnels; accessed under tunnelsLock
	adding        map[string]bool   // addresses whose tunnel is being (re)created; set under tunnelsLock
	tunnelCreator sshTunnelCreator  // factory for new tunnels
	tunnelsLock   sync.Mutex        // guards entries and adding

	// Settings used when creating and health-checking tunnels.
	user           string
	keyfile        string
	healthCheckURL *url.URL
}
					
						
							| 
									
										
										
										
											2015-05-28 12:38:21 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-12-05 10:01:29 +08:00
										 |  |  | func NewSSHTunnelList(user, keyfile string, healthCheckURL *url.URL, stopChan chan struct{}) *SSHTunnelList { | 
					
						
							|  |  |  | 	l := &SSHTunnelList{ | 
					
						
							|  |  |  | 		adding:         make(map[string]bool), | 
					
						
							|  |  |  | 		tunnelCreator:  &realTunnelCreator{}, | 
					
						
							|  |  |  | 		user:           user, | 
					
						
							|  |  |  | 		keyfile:        keyfile, | 
					
						
							|  |  |  | 		healthCheckURL: healthCheckURL, | 
					
						
							| 
									
										
										
										
											2015-05-28 12:38:21 +08:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2015-12-05 10:01:29 +08:00
										 |  |  | 	healthCheckPoll := 1 * time.Minute | 
					
						
							| 
									
										
										
										
											2016-02-02 18:57:06 +08:00
										 |  |  | 	go wait.Until(func() { | 
					
						
							| 
									
										
										
										
											2015-12-05 10:01:29 +08:00
										 |  |  | 		l.tunnelsLock.Lock() | 
					
						
							|  |  |  | 		defer l.tunnelsLock.Unlock() | 
					
						
							|  |  |  | 		// Healthcheck each tunnel every minute
 | 
					
						
							|  |  |  | 		numTunnels := len(l.entries) | 
					
						
							|  |  |  | 		for i, entry := range l.entries { | 
					
						
							|  |  |  | 			// Stagger healthchecks evenly across duration of healthCheckPoll.
 | 
					
						
							|  |  |  | 			delay := healthCheckPoll * time.Duration(i) / time.Duration(numTunnels) | 
					
						
							|  |  |  | 			l.delayedHealthCheck(entry, delay) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	}, healthCheckPoll, stopChan) | 
					
						
							|  |  |  | 	return l | 
					
						
							| 
									
										
										
										
											2015-05-28 12:38:21 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-12-05 10:01:29 +08:00
										 |  |  | func (l *SSHTunnelList) delayedHealthCheck(e sshTunnelEntry, delay time.Duration) { | 
					
						
							|  |  |  | 	go func() { | 
					
						
							|  |  |  | 		defer runtime.HandleCrash() | 
					
						
							|  |  |  | 		time.Sleep(delay) | 
					
						
							|  |  |  | 		if err := l.healthCheck(e); err != nil { | 
					
						
							|  |  |  | 			glog.Errorf("Healthcheck failed for tunnel to %q: %v", e.Address, err) | 
					
						
							|  |  |  | 			glog.Infof("Attempting once to re-establish tunnel to %q", e.Address) | 
					
						
							|  |  |  | 			l.removeAndReAdd(e) | 
					
						
							| 
									
										
										
										
											2015-05-28 12:38:21 +08:00
										 |  |  | 		} | 
					
						
							| 
									
										
										
										
											2015-12-05 10:01:29 +08:00
										 |  |  | 	}() | 
					
						
							| 
									
										
										
										
											2015-05-28 12:38:21 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-12-05 10:01:29 +08:00
										 // healthCheck issues an HTTP GET for l.healthCheckURL, dialing through the
// tunnel in e, and returns any error encountered while performing the request.
// Only transport-level failures are reported; the response status code is not
// inspected.
func (l *SSHTunnelList) healthCheck(e sshTunnelEntry) error {
	// GET the healthcheck path using the provided tunnel's dial function.
	transport := utilnet.SetTransportDefaults(&http.Transport{
		Dial: e.Tunnel.Dial,
		// TODO(cjcullen): Plumb real TLS options through.
		TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
		// This transport serves exactly one request and is then dropped, so a
		// kept-alive idle connection would never be reused or closed.
		DisableKeepAlives: true,
	})
	client := &http.Client{Transport: transport}
	resp, err := client.Get(l.healthCheckURL.String())
	if err != nil {
		return err
	}
	// The response body must always be closed, otherwise the underlying
	// connection (and the tunnel channel it rides on) is leaked.
	resp.Body.Close()
	return nil
}
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (l *SSHTunnelList) removeAndReAdd(e sshTunnelEntry) { | 
					
						
							|  |  |  | 	// Find the entry to replace.
 | 
					
						
							|  |  |  | 	l.tunnelsLock.Lock() | 
					
						
							|  |  |  | 	defer l.tunnelsLock.Unlock() | 
					
						
							|  |  |  | 	for i, entry := range l.entries { | 
					
						
							|  |  |  | 		if entry.Tunnel == e.Tunnel { | 
					
						
							|  |  |  | 			l.entries = append(l.entries[:i], l.entries[i+1:]...) | 
					
						
							|  |  |  | 			l.adding[e.Address] = true | 
					
						
							|  |  |  | 			go l.createAndAddTunnel(e.Address) | 
					
						
							|  |  |  | 			return | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2015-05-28 12:38:21 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-12-05 10:01:29 +08:00
										 |  |  | func (l *SSHTunnelList) Dial(net, addr string) (net.Conn, error) { | 
					
						
							|  |  |  | 	start := time.Now() | 
					
						
							|  |  |  | 	id := mathrand.Int63() // So you can match begins/ends in the log.
 | 
					
						
							|  |  |  | 	glog.Infof("[%x: %v] Dialing...", id, addr) | 
					
						
							|  |  |  | 	defer func() { | 
					
						
							|  |  |  | 		glog.Infof("[%x: %v] Dialed in %v.", id, addr, time.Now().Sub(start)) | 
					
						
							|  |  |  | 	}() | 
					
						
							|  |  |  | 	tunnel, err := l.pickRandomTunnel() | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, err | 
					
						
							| 
									
										
										
										
											2015-05-28 12:38:21 +08:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2015-12-05 10:01:29 +08:00
										 |  |  | 	return tunnel.Dial(net, addr) | 
					
						
							| 
									
										
										
										
											2015-06-19 07:31:54 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-12-05 10:01:29 +08:00
										 |  |  | func (l *SSHTunnelList) pickRandomTunnel() (tunnel, error) { | 
					
						
							|  |  |  | 	l.tunnelsLock.Lock() | 
					
						
							|  |  |  | 	defer l.tunnelsLock.Unlock() | 
					
						
							| 
									
										
										
										
											2015-06-19 07:31:54 +08:00
										 |  |  | 	if len(l.entries) == 0 { | 
					
						
							| 
									
										
										
										
											2015-12-05 10:01:29 +08:00
										 |  |  | 		return nil, fmt.Errorf("No SSH tunnels currently open. Were the targets able to accept an ssh-key for user %q?", l.user) | 
					
						
							| 
									
										
										
										
											2015-06-19 07:31:54 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 	n := mathrand.Intn(len(l.entries)) | 
					
						
							| 
									
										
										
										
											2015-12-05 10:01:29 +08:00
										 |  |  | 	return l.entries[n].Tunnel, nil | 
					
						
							| 
									
										
										
										
											2015-05-28 12:38:21 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-12-05 10:01:29 +08:00
										 |  |  | // Update reconciles the list's entries with the specified addresses. Existing
 | 
					
						
							|  |  |  | // tunnels that are not in addresses are removed from entries and closed in a
 | 
					
						
							|  |  |  | // background goroutine. New tunnels specified in addresses are opened in a
 | 
					
						
							|  |  |  | // background goroutine and then added to entries.
 | 
					
						
							|  |  |  | func (l *SSHTunnelList) Update(addrs []string) { | 
					
						
							|  |  |  | 	haveAddrsMap := make(map[string]bool) | 
					
						
							|  |  |  | 	wantAddrsMap := make(map[string]bool) | 
					
						
							|  |  |  | 	func() { | 
					
						
							|  |  |  | 		l.tunnelsLock.Lock() | 
					
						
							|  |  |  | 		defer l.tunnelsLock.Unlock() | 
					
						
							|  |  |  | 		// Build a map of what we currently have.
 | 
					
						
							|  |  |  | 		for i := range l.entries { | 
					
						
							|  |  |  | 			haveAddrsMap[l.entries[i].Address] = true | 
					
						
							| 
									
										
										
										
											2015-05-28 12:38:21 +08:00
										 |  |  | 		} | 
					
						
							| 
									
										
										
										
											2015-12-05 10:01:29 +08:00
										 |  |  | 		// Determine any necessary additions.
 | 
					
						
							|  |  |  | 		for i := range addrs { | 
					
						
							|  |  |  | 			// Add tunnel if it is not in l.entries or l.adding
 | 
					
						
							|  |  |  | 			if _, ok := haveAddrsMap[addrs[i]]; !ok { | 
					
						
							|  |  |  | 				if _, ok := l.adding[addrs[i]]; !ok { | 
					
						
							|  |  |  | 					l.adding[addrs[i]] = true | 
					
						
							|  |  |  | 					addr := addrs[i] | 
					
						
							|  |  |  | 					go func() { | 
					
						
							|  |  |  | 						defer runtime.HandleCrash() | 
					
						
							|  |  |  | 						// Actually adding tunnel to list will block until lock
 | 
					
						
							|  |  |  | 						// is released after deletions.
 | 
					
						
							|  |  |  | 						l.createAndAddTunnel(addr) | 
					
						
							|  |  |  | 					}() | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 			wantAddrsMap[addrs[i]] = true | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		// Determine any necessary deletions.
 | 
					
						
							|  |  |  | 		var newEntries []sshTunnelEntry | 
					
						
							|  |  |  | 		for i := range l.entries { | 
					
						
							|  |  |  | 			if _, ok := wantAddrsMap[l.entries[i].Address]; !ok { | 
					
						
							|  |  |  | 				tunnelEntry := l.entries[i] | 
					
						
							|  |  |  | 				glog.Infof("Removing tunnel to deleted node at %q", tunnelEntry.Address) | 
					
						
							|  |  |  | 				go func() { | 
					
						
							|  |  |  | 					defer runtime.HandleCrash() | 
					
						
							|  |  |  | 					if err := tunnelEntry.Tunnel.Close(); err != nil { | 
					
						
							|  |  |  | 						glog.Errorf("Failed to close tunnel to %q: %v", tunnelEntry.Address, err) | 
					
						
							|  |  |  | 					} | 
					
						
							|  |  |  | 				}() | 
					
						
							|  |  |  | 			} else { | 
					
						
							|  |  |  | 				newEntries = append(newEntries, l.entries[i]) | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		l.entries = newEntries | 
					
						
							|  |  |  | 	}() | 
					
						
							| 
									
										
										
										
											2015-05-28 12:38:21 +08:00
										 |  |  | } | 
					
						
							| 
									
										
										
										
											2015-05-29 02:45:08 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-12-05 10:01:29 +08:00
										 |  |  | func (l *SSHTunnelList) createAndAddTunnel(addr string) { | 
					
						
							|  |  |  | 	glog.Infof("Trying to add tunnel to %q", addr) | 
					
						
							|  |  |  | 	tunnel, err := l.tunnelCreator.NewSSHTunnel(l.user, l.keyfile, addr) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		glog.Errorf("Failed to create tunnel for %q: %v", addr, err) | 
					
						
							|  |  |  | 		return | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	if err := tunnel.Open(); err != nil { | 
					
						
							|  |  |  | 		glog.Errorf("Failed to open tunnel to %q: %v", addr, err) | 
					
						
							|  |  |  | 		l.tunnelsLock.Lock() | 
					
						
							|  |  |  | 		delete(l.adding, addr) | 
					
						
							|  |  |  | 		l.tunnelsLock.Unlock() | 
					
						
							|  |  |  | 		return | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	l.tunnelsLock.Lock() | 
					
						
							|  |  |  | 	l.entries = append(l.entries, sshTunnelEntry{addr, tunnel}) | 
					
						
							|  |  |  | 	delete(l.adding, addr) | 
					
						
							|  |  |  | 	l.tunnelsLock.Unlock() | 
					
						
							|  |  |  | 	glog.Infof("Successfully added tunnel for %q", addr) | 
					
						
							| 
									
										
										
										
											2015-06-17 01:36:38 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-05-29 02:45:08 +08:00
										 // EncodePrivateKey returns private serialized as a PEM block of type
// "RSA PRIVATE KEY" whose payload is the key's PKCS #1 DER encoding.
func EncodePrivateKey(private *rsa.PrivateKey) []byte {
	der := x509.MarshalPKCS1PrivateKey(private)
	block := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: der}
	return pem.EncodeToMemory(block)
}
					
						
							|  |  |  | 
 | 
					
						
							// EncodePublicKey returns public serialized as a PEM block of type
// "PUBLIC KEY" holding its PKIX (SubjectPublicKeyInfo) DER encoding. Any error
// from the PKIX marshalling is returned unchanged.
func EncodePublicKey(public *rsa.PublicKey) ([]byte, error) {
	der, err := x509.MarshalPKIXPublicKey(public)
	if err != nil {
		return nil, err
	}
	block := pem.Block{Type: "PUBLIC KEY", Bytes: der}
	return pem.EncodeToMemory(&block), nil
}
					
						
							|  |  |  | 
 | 
					
						
							// EncodeSSHKey converts public to an SSH public key and returns it in the
// authorized_keys line format produced by ssh.MarshalAuthorizedKey. Errors from
// the conversion are returned unchanged.
func EncodeSSHKey(public *rsa.PublicKey) ([]byte, error) {
	sshPub, err := ssh.NewPublicKey(public)
	if err != nil {
		return nil, err
	}
	line := ssh.MarshalAuthorizedKey(sshPub)
	return line, nil
}
					
						
							|  |  |  | 
 | 
					
						
							// GenerateKey creates a new RSA key of the given size using crypto/rand.
// The returned public key points at the private key's embedded PublicKey field.
func GenerateKey(bits int) (*rsa.PrivateKey, *rsa.PublicKey, error) {
	key, err := rsa.GenerateKey(rand.Reader, bits)
	if err != nil {
		return nil, nil, err
	}
	pub := &key.PublicKey
	return key, pub, nil
}