This commit is contained in:
Patrick Ohly 2025-05-30 12:56:03 -07:00 committed by GitHub
commit 3655e3cf9f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 175 additions and 89 deletions

View File

@ -60,7 +60,8 @@ func NewTunnelingHandler(upgradeHandler http.Handler) *TunnelingHandler {
// case the upstream upgrade fails, we delegate communication to the passed // case the upstream upgrade fails, we delegate communication to the passed
// in "w" ResponseWriter. // in "w" ResponseWriter.
func (h *TunnelingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { func (h *TunnelingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
klog.V(4).Infoln("TunnelingHandler ServeHTTP") logger := klog.FromContext(req.Context())
logger.V(4).Info("TunnelingHandler ServeHTTP")
spdyProtocols := spdyProtocolsFromWebsocketProtocols(req) spdyProtocols := spdyProtocolsFromWebsocketProtocols(req)
if len(spdyProtocols) == 0 { if len(spdyProtocols) == 0 {
@ -75,10 +76,12 @@ func (h *TunnelingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// and the "conn" is hijacked and used in the subsequent upgradeHandler, or // and the "conn" is hijacked and used in the subsequent upgradeHandler, or
// the upgrade failed, and "w" is the delegate used for the non-upgrade response. // the upgrade failed, and "w" is the delegate used for the non-upgrade response.
writer := &tunnelingResponseWriter{ writer := &tunnelingResponseWriter{
logger: logger,
// "w" is used in the non-upgrade error cases called in the upgradeHandler. // "w" is used in the non-upgrade error cases called in the upgradeHandler.
w: w, w: w,
// "conn" is returned in the successful upgrade case when hijacked in the upgradeHandler. // "conn" is returned in the successful upgrade case when hijacked in the upgradeHandler.
conn: &headerInterceptingConn{ conn: &headerInterceptingConn{
logger: logger,
initializableConn: &tunnelingWebsocketUpgraderConn{ initializableConn: &tunnelingWebsocketUpgraderConn{
w: w, w: w,
req: req, req: req,
@ -86,7 +89,7 @@ func (h *TunnelingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}, },
} }
klog.V(4).Infoln("Tunnel spdy through websockets using the UpgradeAwareProxy") logger.V(4).Info("Tunnel spdy through websockets using the UpgradeAwareProxy")
h.upgradeHandler.ServeHTTP(writer, spdyRequest) h.upgradeHandler.ServeHTTP(writer, spdyRequest)
} }
@ -131,6 +134,7 @@ var _ http.Hijacker = &tunnelingResponseWriter{}
// Once Write or WriteHeader is called, Hijack returns an error. // Once Write or WriteHeader is called, Hijack returns an error.
// Once Hijack is called, Write, WriteHeader, and Hijack return errors. // Once Hijack is called, Write, WriteHeader, and Hijack return errors.
type tunnelingResponseWriter struct { type tunnelingResponseWriter struct {
logger klog.Logger
// w is used to delegate Header(), WriteHeader(), and Write() calls // w is used to delegate Header(), WriteHeader(), and Write() calls
w http.ResponseWriter w http.ResponseWriter
// conn is returned from Hijack() // conn is returned from Hijack()
@ -150,15 +154,15 @@ func (w *tunnelingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error)
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
if w.written { if w.written {
klog.Errorf("Hijack called after write") w.logger.Error(nil, "Hijack called after write")
return nil, nil, errors.New("connection has already been written to") return nil, nil, errors.New("connection has already been written to")
} }
if w.hijacked { if w.hijacked {
klog.Errorf("Hijack called after hijack") w.logger.Error(nil, "Hijack called after hijack")
return nil, nil, errors.New("connection has already been hijacked") return nil, nil, errors.New("connection has already been hijacked")
} }
w.hijacked = true w.hijacked = true
klog.V(6).Infof("Hijack returning websocket tunneling net.Conn") w.logger.V(6).Info("Hijack returning websocket tunneling net.Conn")
return w.conn, nil, nil return w.conn, nil, nil
} }
@ -172,7 +176,7 @@ func (w *tunnelingResponseWriter) Write(p []byte) (int, error) {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
if w.hijacked { if w.hijacked {
klog.Errorf("Write called after hijack") w.logger.Error(nil, "Write called after hijack")
return 0, http.ErrHijacked return 0, http.ErrHijacked
} }
w.written = true w.written = true
@ -184,18 +188,18 @@ func (w *tunnelingResponseWriter) WriteHeader(statusCode int) {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
if w.written { if w.written {
klog.Errorf("WriteHeader called after write") w.logger.Error(nil, "WriteHeader called after write")
return return
} }
if w.hijacked { if w.hijacked {
klog.Errorf("WriteHeader called after hijack") w.logger.Error(nil, "WriteHeader called after hijack")
return return
} }
w.written = true w.written = true
if statusCode == http.StatusSwitchingProtocols { if statusCode == http.StatusSwitchingProtocols {
// 101 upgrade responses must come via the hijacked connection, not WriteHeader // 101 upgrade responses must come via the hijacked connection, not WriteHeader
klog.Errorf("WriteHeader called with 101 upgrade") w.logger.Error(nil, "WriteHeader called with 101 upgrade")
http.Error(w.w, "unexpected upgrade", http.StatusInternalServerError) http.Error(w.w, "unexpected upgrade", http.StatusInternalServerError)
return return
} }
@ -208,6 +212,7 @@ func (w *tunnelingResponseWriter) WriteHeader(statusCode int) {
// HTTP response status/headers from the upstream SPDY connection, then use // HTTP response status/headers from the upstream SPDY connection, then use
// that to decide how to initialize the delegate connection for writes. // that to decide how to initialize the delegate connection for writes.
type headerInterceptingConn struct { type headerInterceptingConn struct {
logger klog.Logger
// initializableConn is delegated to for all net.Conn methods. // initializableConn is delegated to for all net.Conn methods.
// initializableConn.Write() is not called until response headers have been read // initializableConn.Write() is not called until response headers have been read
// and initializableConn#InitializeWrite() has been called with the result. // and initializableConn#InitializeWrite() has been called with the result.
@ -274,7 +279,7 @@ func (h *headerInterceptingConn) Write(b []byte) (int, error) {
} }
resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(headerBytes)), nil) resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(headerBytes)), nil)
if err != nil { if err != nil {
klog.Errorf("invalid headers: %v", err) h.logger.Error(err, "Invalid headers")
h.initializeErr = err h.initializeErr = err
return len(b), err return len(b), err
} }
@ -324,11 +329,12 @@ func (u *tunnelingWebsocketUpgraderConn) InitializeWrite(backendResponse *http.R
return u.err return u.err
} }
logger := klog.FromContext(u.req.Context())
if backendResponse.StatusCode == http.StatusSwitchingProtocols { if backendResponse.StatusCode == http.StatusSwitchingProtocols {
connectionHeader := strings.ToLower(backendResponse.Header.Get(httpstream.HeaderConnection)) connectionHeader := strings.ToLower(backendResponse.Header.Get(httpstream.HeaderConnection))
upgradeHeader := strings.ToLower(backendResponse.Header.Get(httpstream.HeaderUpgrade)) upgradeHeader := strings.ToLower(backendResponse.Header.Get(httpstream.HeaderUpgrade))
if !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || !strings.Contains(upgradeHeader, strings.ToLower(spdy.HeaderSpdy31)) { if !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || !strings.Contains(upgradeHeader, strings.ToLower(spdy.HeaderSpdy31)) {
klog.Errorf("unable to upgrade: missing upgrade headers in response: %#v", backendResponse.Header) logger.Error(nil, "Unable to upgrade: missing upgrade headers in response", "headers", backendResponse.Header)
u.err = fmt.Errorf("unable to upgrade: missing upgrade headers in response") u.err = fmt.Errorf("unable to upgrade: missing upgrade headers in response")
metrics.IncStreamTunnelRequest(context.Background(), strconv.Itoa(http.StatusInternalServerError)) metrics.IncStreamTunnelRequest(context.Background(), strconv.Itoa(http.StatusInternalServerError))
http.Error(u.w, u.err.Error(), http.StatusInternalServerError) http.Error(u.w, u.err.Error(), http.StatusInternalServerError)
@ -351,26 +357,26 @@ func (u *tunnelingWebsocketUpgraderConn) InitializeWrite(backendResponse *http.R
} }
conn, err := upgrader.Upgrade(u.w, u.req, nil) conn, err := upgrader.Upgrade(u.w, u.req, nil)
if err != nil { if err != nil {
klog.Errorf("error upgrading websocket connection: %v", err) logger.Error(err, "Error upgrading websocket connection")
metrics.IncStreamTunnelRequest(context.Background(), strconv.Itoa(http.StatusInternalServerError)) metrics.IncStreamTunnelRequest(context.Background(), strconv.Itoa(http.StatusInternalServerError))
u.err = err u.err = err
return u.err return u.err
} }
klog.V(4).Infof("websocket connection created: %s", conn.Subprotocol()) logger.V(4).Info("Websocket connection created", "protocol", conn.Subprotocol())
metrics.IncStreamTunnelRequest(context.Background(), strconv.Itoa(http.StatusSwitchingProtocols)) metrics.IncStreamTunnelRequest(context.Background(), strconv.Itoa(http.StatusSwitchingProtocols))
u.conn = portforward.NewTunnelingConnection("server", conn) u.conn = portforward.NewTunnelingConnectionWithLogger(klog.LoggerWithName(logger, "server"), conn)
return nil return nil
} }
// anything other than an upgrade should pass through the backend response // anything other than an upgrade should pass through the backend response
klog.Errorf("SPDY upgrade failed: %s", backendResponse.Status) logger.Error(nil, "SPDY upgrade failed", "status", backendResponse.Status)
metrics.IncStreamTunnelRequest(context.Background(), strconv.Itoa(backendResponse.StatusCode)) metrics.IncStreamTunnelRequest(context.Background(), strconv.Itoa(backendResponse.StatusCode))
// try to hijack // try to hijack
conn, _, err = u.w.(http.Hijacker).Hijack() conn, _, err = u.w.(http.Hijacker).Hijack()
if err != nil { if err != nil {
klog.Errorf("Unable to hijack response: %v", err) logger.Error(err, "Unable to hijack response")
u.err = err u.err = err
return u.err return u.err
} }

View File

@ -43,7 +43,7 @@ func NewTranslatingHandler(delegate http.Handler, translator http.Handler, shoul
func (t *translatingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { func (t *translatingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if t.shouldTranslate(req) { if t.shouldTranslate(req) {
klog.V(4).Infof("request handled by translator proxy") klog.FromContext(req.Context()).V(4).Info("Request handled by translator proxy")
t.translator.ServeHTTP(w, req) t.translator.ServeHTTP(w, req)
return return
} }

View File

@ -209,7 +209,7 @@ type LeaderElector struct {
// before leader election loop is stopped by ctx or it has // before leader election loop is stopped by ctx or it has
// stopped holding the leader lease // stopped holding the leader lease
func (le *LeaderElector) Run(ctx context.Context) { func (le *LeaderElector) Run(ctx context.Context) {
defer runtime.HandleCrash() defer runtime.HandleCrashWithContext(ctx)
defer le.config.Callbacks.OnStoppedLeading() defer le.config.Callbacks.OnStoppedLeading()
if !le.acquire(ctx) { if !le.acquire(ctx) {
@ -254,7 +254,8 @@ func (le *LeaderElector) acquire(ctx context.Context) bool {
defer cancel() defer cancel()
succeeded := false succeeded := false
desc := le.config.Lock.Describe() desc := le.config.Lock.Describe()
klog.Infof("attempting to acquire leader lease %v...", desc) logger := klog.FromContext(ctx)
logger.Info("Attempting to acquire leader lease...", "lock", desc)
wait.JitterUntil(func() { wait.JitterUntil(func() {
if !le.config.Coordinated { if !le.config.Coordinated {
succeeded = le.tryAcquireOrRenew(ctx) succeeded = le.tryAcquireOrRenew(ctx)
@ -263,12 +264,12 @@ func (le *LeaderElector) acquire(ctx context.Context) bool {
} }
le.maybeReportTransition() le.maybeReportTransition()
if !succeeded { if !succeeded {
klog.V(4).Infof("failed to acquire lease %v", desc) logger.V(4).Info("Failed to acquire lease", "lock", desc)
return return
} }
le.config.Lock.RecordEvent("became leader") le.config.Lock.RecordEvent("became leader")
le.metrics.leaderOn(le.config.Name) le.metrics.leaderOn(le.config.Name)
klog.Infof("successfully acquired lease %v", desc) logger.Info("Successfully acquired lease", "lock", desc)
cancel() cancel()
}, le.config.RetryPeriod, JitterFactor, true, ctx.Done()) }, le.config.RetryPeriod, JitterFactor, true, ctx.Done())
return succeeded return succeeded
@ -279,6 +280,7 @@ func (le *LeaderElector) renew(ctx context.Context) {
defer le.config.Lock.RecordEvent("stopped leading") defer le.config.Lock.RecordEvent("stopped leading")
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
logger := klog.FromContext(ctx)
wait.Until(func() { wait.Until(func() {
err := wait.PollUntilContextTimeout(ctx, le.config.RetryPeriod, le.config.RenewDeadline, true, func(ctx context.Context) (done bool, err error) { err := wait.PollUntilContextTimeout(ctx, le.config.RetryPeriod, le.config.RenewDeadline, true, func(ctx context.Context) (done bool, err error) {
if !le.config.Coordinated { if !le.config.Coordinated {
@ -290,22 +292,22 @@ func (le *LeaderElector) renew(ctx context.Context) {
le.maybeReportTransition() le.maybeReportTransition()
desc := le.config.Lock.Describe() desc := le.config.Lock.Describe()
if err == nil { if err == nil {
klog.V(5).Infof("successfully renewed lease %v", desc) logger.V(5).Info("Successfully renewed lease", "lock", desc)
return return
} }
le.metrics.leaderOff(le.config.Name) le.metrics.leaderOff(le.config.Name)
klog.Infof("failed to renew lease %v: %v", desc, err) logger.Info("Failed to renew lease", "lock", desc, "err", err)
cancel() cancel()
}, le.config.RetryPeriod, ctx.Done()) }, le.config.RetryPeriod, ctx.Done())
// if we hold the lease, give it up // if we hold the lease, give it up
if le.config.ReleaseOnCancel { if le.config.ReleaseOnCancel {
le.release() le.release(logger)
} }
} }
// release attempts to release the leader lease if we have acquired it. // release attempts to release the leader lease if we have acquired it.
func (le *LeaderElector) release() bool { func (le *LeaderElector) release(logger klog.Logger) bool {
if !le.IsLeader() { if !le.IsLeader() {
return true return true
} }
@ -316,10 +318,11 @@ func (le *LeaderElector) release() bool {
RenewTime: now, RenewTime: now,
AcquireTime: now, AcquireTime: now,
} }
// This intentionally (?) ignores the context of the caller.
timeoutCtx, timeoutCancel := context.WithTimeout(context.Background(), le.config.RenewDeadline) timeoutCtx, timeoutCancel := context.WithTimeout(context.Background(), le.config.RenewDeadline)
defer timeoutCancel() defer timeoutCancel()
if err := le.config.Lock.Update(timeoutCtx, leaderElectionRecord); err != nil { if err := le.config.Lock.Update(timeoutCtx, leaderElectionRecord); err != nil {
klog.Errorf("Failed to release lock: %v", err) logger.Error(err, "Failed to release lease", "lock", le.config.Lock.Describe())
return false return false
} }
@ -331,6 +334,7 @@ func (le *LeaderElector) release() bool {
// lease if it has already been acquired. Returns true on success else returns // lease if it has already been acquired. Returns true on success else returns
// false. // false.
func (le *LeaderElector) tryCoordinatedRenew(ctx context.Context) bool { func (le *LeaderElector) tryCoordinatedRenew(ctx context.Context) bool {
logger := klog.FromContext(ctx)
now := metav1.NewTime(le.clock.Now()) now := metav1.NewTime(le.clock.Now())
leaderElectionRecord := rl.LeaderElectionRecord{ leaderElectionRecord := rl.LeaderElectionRecord{
HolderIdentity: le.config.Lock.Identity(), HolderIdentity: le.config.Lock.Identity(),
@ -343,10 +347,10 @@ func (le *LeaderElector) tryCoordinatedRenew(ctx context.Context) bool {
oldLeaderElectionRecord, oldLeaderElectionRawRecord, err := le.config.Lock.Get(ctx) oldLeaderElectionRecord, oldLeaderElectionRawRecord, err := le.config.Lock.Get(ctx)
if err != nil { if err != nil {
if !errors.IsNotFound(err) { if !errors.IsNotFound(err) {
klog.Errorf("error retrieving resource lock %v: %v", le.config.Lock.Describe(), err) logger.Error(err, "Error retrieving lease lock", "lock", le.config.Lock.Describe())
return false return false
} }
klog.Infof("lease lock not found: %v", le.config.Lock.Describe()) logger.Info("Lease lock not found", "lock", le.config.Lock.Describe(), "err", err)
return false return false
} }
@ -359,18 +363,18 @@ func (le *LeaderElector) tryCoordinatedRenew(ctx context.Context) bool {
hasExpired := le.observedTime.Add(time.Second * time.Duration(oldLeaderElectionRecord.LeaseDurationSeconds)).Before(now.Time) hasExpired := le.observedTime.Add(time.Second * time.Duration(oldLeaderElectionRecord.LeaseDurationSeconds)).Before(now.Time)
if hasExpired { if hasExpired {
klog.Infof("lock has expired: %v", le.config.Lock.Describe()) logger.Info("Lease has expired", "lock", le.config.Lock.Describe())
return false return false
} }
if !le.IsLeader() { if !le.IsLeader() {
klog.V(6).Infof("lock is held by %v and has not yet expired: %v", oldLeaderElectionRecord.HolderIdentity, le.config.Lock.Describe()) logger.V(6).Info("Lease is held and has not yet expired", "lock", le.config.Lock.Describe(), "holder", oldLeaderElectionRecord.HolderIdentity)
return false return false
} }
// 2b. If the lease has been marked as "end of term", don't renew it // 2b. If the lease has been marked as "end of term", don't renew it
if le.IsLeader() && oldLeaderElectionRecord.PreferredHolder != "" { if le.IsLeader() && oldLeaderElectionRecord.PreferredHolder != "" {
klog.V(4).Infof("lock is marked as 'end of term': %v", le.config.Lock.Describe()) logger.V(4).Info("Lease is marked as 'end of term'", "lock", le.config.Lock.Describe())
// TODO: Instead of letting lease expire, the holder may deleted it directly // TODO: Instead of letting lease expire, the holder may deleted it directly
// This will not be compatible with all controllers, so it needs to be opt-in behavior. // This will not be compatible with all controllers, so it needs to be opt-in behavior.
// We must ensure all code guarded by this lease has successfully completed // We must ensure all code guarded by this lease has successfully completed
@ -394,7 +398,7 @@ func (le *LeaderElector) tryCoordinatedRenew(ctx context.Context) bool {
// update the lock itself // update the lock itself
if err = le.config.Lock.Update(ctx, leaderElectionRecord); err != nil { if err = le.config.Lock.Update(ctx, leaderElectionRecord); err != nil {
klog.Errorf("Failed to update lock: %v", err) logger.Error(err, "Failed to update lock", "lock", le.config.Lock.Describe())
return false return false
} }
@ -406,6 +410,7 @@ func (le *LeaderElector) tryCoordinatedRenew(ctx context.Context) bool {
// else it tries to renew the lease if it has already been acquired. Returns true // else it tries to renew the lease if it has already been acquired. Returns true
// on success else returns false. // on success else returns false.
func (le *LeaderElector) tryAcquireOrRenew(ctx context.Context) bool { func (le *LeaderElector) tryAcquireOrRenew(ctx context.Context) bool {
logger := klog.FromContext(ctx)
now := metav1.NewTime(le.clock.Now()) now := metav1.NewTime(le.clock.Now())
leaderElectionRecord := rl.LeaderElectionRecord{ leaderElectionRecord := rl.LeaderElectionRecord{
HolderIdentity: le.config.Lock.Identity(), HolderIdentity: le.config.Lock.Identity(),
@ -426,18 +431,18 @@ func (le *LeaderElector) tryAcquireOrRenew(ctx context.Context) bool {
le.setObservedRecord(&leaderElectionRecord) le.setObservedRecord(&leaderElectionRecord)
return true return true
} }
klog.Errorf("Failed to update lock optimistically: %v, falling back to slow path", err) logger.Error(err, "Failed to update lease optimistically, falling back to slow path", "lock", le.config.Lock.Describe())
} }
// 2. obtain or create the ElectionRecord // 2. obtain or create the ElectionRecord
oldLeaderElectionRecord, oldLeaderElectionRawRecord, err := le.config.Lock.Get(ctx) oldLeaderElectionRecord, oldLeaderElectionRawRecord, err := le.config.Lock.Get(ctx)
if err != nil { if err != nil {
if !errors.IsNotFound(err) { if !errors.IsNotFound(err) {
klog.Errorf("error retrieving resource lock %v: %v", le.config.Lock.Describe(), err) logger.Error(err, "Error retrieving lease lock", "lock", le.config.Lock.Describe())
return false return false
} }
if err = le.config.Lock.Create(ctx, leaderElectionRecord); err != nil { if err = le.config.Lock.Create(ctx, leaderElectionRecord); err != nil {
klog.Errorf("error initially creating leader election record: %v", err) logger.Error(err, "Error initially creating lease lock", "lock", le.config.Lock.Describe())
return false return false
} }
@ -453,7 +458,7 @@ func (le *LeaderElector) tryAcquireOrRenew(ctx context.Context) bool {
le.observedRawRecord = oldLeaderElectionRawRecord le.observedRawRecord = oldLeaderElectionRawRecord
} }
if len(oldLeaderElectionRecord.HolderIdentity) > 0 && le.isLeaseValid(now.Time) && !le.IsLeader() { if len(oldLeaderElectionRecord.HolderIdentity) > 0 && le.isLeaseValid(now.Time) && !le.IsLeader() {
klog.V(4).Infof("lock is held by %v and has not yet expired", oldLeaderElectionRecord.HolderIdentity) logger.V(4).Info("Lease is held by and has not yet expired", "lock", le.config.Lock.Describe(), "holder", oldLeaderElectionRecord.HolderIdentity)
return false return false
} }
@ -469,7 +474,7 @@ func (le *LeaderElector) tryAcquireOrRenew(ctx context.Context) bool {
// update the lock itself // update the lock itself
if err = le.config.Lock.Update(ctx, leaderElectionRecord); err != nil { if err = le.config.Lock.Update(ctx, leaderElectionRecord); err != nil {
klog.Errorf("Failed to update lock: %v", err) logger.Error(err, "Failed to update lease", "lock", le.config.Lock.Describe())
return false return false
} }

View File

@ -37,6 +37,7 @@ import (
fakeclient "k8s.io/client-go/testing" fakeclient "k8s.io/client-go/testing"
rl "k8s.io/client-go/tools/leaderelection/resourcelock" rl "k8s.io/client-go/tools/leaderelection/resourcelock"
"k8s.io/client-go/tools/record" "k8s.io/client-go/tools/record"
"k8s.io/klog/v2/ktesting"
"k8s.io/utils/clock" "k8s.io/utils/clock"
) )
@ -265,6 +266,8 @@ func testTryAcquireOrRenew(t *testing.T, objectType string) {
for i := range tests { for i := range tests {
test := &tests[i] test := &tests[i]
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
_, ctx := ktesting.NewTestContext(t)
// OnNewLeader is called async so we have to wait for it. // OnNewLeader is called async so we have to wait for it.
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
@ -316,10 +319,10 @@ func testTryAcquireOrRenew(t *testing.T, objectType string) {
clock: clock, clock: clock,
metrics: globalMetricsFactory.newLeaderMetrics(), metrics: globalMetricsFactory.newLeaderMetrics(),
} }
if test.expectSuccess != le.tryAcquireOrRenew(context.Background()) { if test.expectSuccess != le.tryAcquireOrRenew(ctx) {
if test.retryAfter != 0 { if test.retryAfter != 0 {
time.Sleep(test.retryAfter) time.Sleep(test.retryAfter)
if test.expectSuccess != le.tryAcquireOrRenew(context.Background()) { if test.expectSuccess != le.tryAcquireOrRenew(ctx) {
t.Errorf("unexpected result of tryAcquireOrRenew: [succeeded=%v]", !test.expectSuccess) t.Errorf("unexpected result of tryAcquireOrRenew: [succeeded=%v]", !test.expectSuccess)
} }
} else { } else {
@ -411,6 +414,8 @@ func TestTryCoordinatedRenew(t *testing.T) {
for i := range tests { for i := range tests {
test := &tests[i] test := &tests[i]
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
_, ctx := ktesting.NewTestContext(t)
// OnNewLeader is called async so we have to wait for it. // OnNewLeader is called async so we have to wait for it.
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
@ -457,10 +462,10 @@ func TestTryCoordinatedRenew(t *testing.T) {
clock: clock, clock: clock,
metrics: globalMetricsFactory.newLeaderMetrics(), metrics: globalMetricsFactory.newLeaderMetrics(),
} }
if test.expectSuccess != le.tryCoordinatedRenew(context.Background()) { if test.expectSuccess != le.tryCoordinatedRenew(ctx) {
if test.retryAfter != 0 { if test.retryAfter != 0 {
time.Sleep(test.retryAfter) time.Sleep(test.retryAfter)
if test.expectSuccess != le.tryCoordinatedRenew(context.Background()) { if test.expectSuccess != le.tryCoordinatedRenew(ctx) {
t.Errorf("unexpected result of tryCoordinatedRenew: [succeeded=%v]", !test.expectSuccess) t.Errorf("unexpected result of tryCoordinatedRenew: [succeeded=%v]", !test.expectSuccess)
} }
} else { } else {
@ -583,6 +588,8 @@ func testReleaseLease(t *testing.T, objectType string) {
for i := range tests { for i := range tests {
test := &tests[i] test := &tests[i]
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
logger, ctx := ktesting.NewTestContext(t)
// OnNewLeader is called async so we have to wait for it. // OnNewLeader is called async so we have to wait for it.
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
@ -634,7 +641,7 @@ func testReleaseLease(t *testing.T, objectType string) {
clock: clock.RealClock{}, clock: clock.RealClock{},
metrics: globalMetricsFactory.newLeaderMetrics(), metrics: globalMetricsFactory.newLeaderMetrics(),
} }
if !le.tryAcquireOrRenew(context.Background()) { if !le.tryAcquireOrRenew(ctx) {
t.Errorf("unexpected result of tryAcquireOrRenew: [succeeded=%v]", true) t.Errorf("unexpected result of tryAcquireOrRenew: [succeeded=%v]", true)
} }
@ -644,7 +651,7 @@ func testReleaseLease(t *testing.T, objectType string) {
wg.Wait() wg.Wait()
wg.Add(1) wg.Add(1)
if test.expectSuccess != le.release() { if test.expectSuccess != le.release(logger) {
t.Errorf("unexpected result of release: [succeeded=%v]", !test.expectSuccess) t.Errorf("unexpected result of release: [succeeded=%v]", !test.expectSuccess)
} }
@ -841,6 +848,8 @@ func testReleaseOnCancellation(t *testing.T, objectType string) {
for i := range tests { for i := range tests {
test := &tests[i] test := &tests[i]
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
_, ctx := ktesting.NewTestContext(t)
wg.Add(1) wg.Add(1)
resetVars() resetVars()
@ -868,7 +877,7 @@ func testReleaseOnCancellation(t *testing.T, objectType string) {
t.Fatal("Failed to create leader elector: ", err) t.Fatal("Failed to create leader elector: ", err)
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(ctx)
go elector.Run(ctx) go elector.Run(ctx)
@ -1082,6 +1091,8 @@ func TestFastPathLeaderElection(t *testing.T) {
for i := range tests { for i := range tests {
test := &tests[i] test := &tests[i]
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
_, ctx := ktesting.NewTestContext(t)
resetVars() resetVars()
recorder := record.NewFakeRecorder(100) recorder := record.NewFakeRecorder(100)
@ -1108,7 +1119,7 @@ func TestFastPathLeaderElection(t *testing.T) {
t.Fatal("Failed to create leader elector: ", err) t.Fatal("Failed to create leader elector: ", err)
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(ctx)
cancelFunc = cancel cancelFunc = cancel
elector.Run(ctx) elector.Run(ctx)

View File

@ -120,8 +120,12 @@ func NewCandidate(clientset kubernetes.Interface,
func (c *LeaseCandidate) Run(ctx context.Context) { func (c *LeaseCandidate) Run(ctx context.Context) {
defer c.queue.ShutDown() defer c.queue.ShutDown()
logger := klog.FromContext(ctx)
logger = klog.LoggerWithName(logger, "leasecandidate")
ctx = klog.NewContext(ctx, logger)
c.informerFactory.Start(ctx.Done()) c.informerFactory.Start(ctx.Done())
if !cache.WaitForNamedCacheSync("leasecandidateclient", ctx.Done(), c.hasSynced) { if !cache.WaitForNamedCacheSyncWithContext(ctx, c.hasSynced) {
return return
} }
@ -148,7 +152,7 @@ func (c *LeaseCandidate) processNextWorkItem(ctx context.Context) bool {
return true return true
} }
utilruntime.HandleError(err) utilruntime.HandleErrorWithContext(ctx, err, "Ensuring lease failed")
c.queue.AddRateLimited(key) c.queue.AddRateLimited(key)
return true return true
@ -161,20 +165,21 @@ func (c *LeaseCandidate) enqueueLease() {
// ensureLease creates the lease if it does not exist and renew it if it exists. Returns the lease and // ensureLease creates the lease if it does not exist and renew it if it exists. Returns the lease and
// a bool (true if this call created the lease), or any error that occurs. // a bool (true if this call created the lease), or any error that occurs.
func (c *LeaseCandidate) ensureLease(ctx context.Context) error { func (c *LeaseCandidate) ensureLease(ctx context.Context) error {
logger := klog.FromContext(ctx)
lease, err := c.leaseClient.Get(ctx, c.name, metav1.GetOptions{}) lease, err := c.leaseClient.Get(ctx, c.name, metav1.GetOptions{})
if apierrors.IsNotFound(err) { if apierrors.IsNotFound(err) {
klog.V(2).Infof("Creating lease candidate") logger.V(2).Info("Creating lease candidate")
// lease does not exist, create it. // lease does not exist, create it.
leaseToCreate := c.newLeaseCandidate() leaseToCreate := c.newLeaseCandidate()
if _, err := c.leaseClient.Create(ctx, leaseToCreate, metav1.CreateOptions{}); err != nil { if _, err := c.leaseClient.Create(ctx, leaseToCreate, metav1.CreateOptions{}); err != nil {
return err return err
} }
klog.V(2).Infof("Created lease candidate") logger.V(2).Info("Created lease candidate")
return nil return nil
} else if err != nil { } else if err != nil {
return err return err
} }
klog.V(2).Infof("lease candidate exists. Renewing.") logger.V(2).Info("Lease candidate exists. Renewing.")
clone := lease.DeepCopy() clone := lease.DeepCopy()
clone.Spec.RenewTime = &metav1.MicroTime{Time: c.clock.Now()} clone.Spec.RenewTime = &metav1.MicroTime{Time: c.clock.Now()}
_, err = c.leaseClient.Update(ctx, clone, metav1.UpdateOptions{}) _, err = c.leaseClient.Update(ctx, clone, metav1.UpdateOptions{})

View File

@ -26,6 +26,7 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/wait" "k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/kubernetes/fake" "k8s.io/client-go/kubernetes/fake"
"k8s.io/klog/v2/ktesting"
) )
type testcase struct { type testcase struct {
@ -34,6 +35,7 @@ type testcase struct {
} }
func TestLeaseCandidateCreation(t *testing.T) { func TestLeaseCandidateCreation(t *testing.T) {
_, ctx := ktesting.NewTestContext(t)
tc := testcase{ tc := testcase{
candidateName: "foo", candidateName: "foo",
candidateNamespace: "default", candidateNamespace: "default",
@ -42,7 +44,7 @@ func TestLeaseCandidateCreation(t *testing.T) {
emulationVersion: "1.30.0", emulationVersion: "1.30.0",
} }
ctx, cancel := context.WithTimeout(context.Background(), time.Minute) ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel() defer cancel()
client := fake.NewSimpleClientset() client := fake.NewSimpleClientset()
@ -67,6 +69,8 @@ func TestLeaseCandidateCreation(t *testing.T) {
} }
func TestLeaseCandidateAck(t *testing.T) { func TestLeaseCandidateAck(t *testing.T) {
_, ctx := ktesting.NewTestContext(t)
tc := testcase{ tc := testcase{
candidateName: "foo", candidateName: "foo",
candidateNamespace: "default", candidateNamespace: "default",
@ -75,7 +79,7 @@ func TestLeaseCandidateAck(t *testing.T) {
emulationVersion: "1.30.0", emulationVersion: "1.30.0",
} }
ctx, cancel := context.WithTimeout(context.Background(), time.Minute) ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel() defer cancel()
client := fake.NewSimpleClientset() client := fake.NewSimpleClientset()

View File

@ -50,6 +50,7 @@ func NewFallbackDialer(primary, secondary httpstream.Dialer, shouldFallback func
func (f *FallbackDialer) Dial(protocols ...string) (httpstream.Connection, string, error) { func (f *FallbackDialer) Dial(protocols ...string) (httpstream.Connection, string, error) {
conn, version, err := f.primary.Dial(protocols...) conn, version, err := f.primary.Dial(protocols...)
if err != nil && f.shouldFallback(err) { if err != nil && f.shouldFallback(err) {
//nolint:logcheck // This code is only used by kubectl where contextual logging is not that useful.
klog.V(4).Infof("fallback to secondary dialer from primary dialer err: %v", err) klog.V(4).Infof("fallback to secondary dialer from primary dialer err: %v", err)
return f.secondary.Dial(protocols...) return f.secondary.Dial(protocols...)
} }

View File

@ -40,7 +40,7 @@ func TestFallbackDialer(t *testing.T) {
assert.Equal(t, primaryProtocol, negotiated, "primary negotiated protocol returned") assert.Equal(t, primaryProtocol, negotiated, "primary negotiated protocol returned")
require.NoError(t, err, "error from primary dialer should be nil") require.NoError(t, err, "error from primary dialer should be nil")
// If primary dialer error is upgrade error, then fallback returning secondary dial response. // If primary dialer error is upgrade error, then fallback returning secondary dial response.
primary = &fakeDialer{dialed: false, negotiatedProtocol: primaryProtocol, err: &httpstream.UpgradeFailureError{}} primary = &fakeDialer{dialed: false, negotiatedProtocol: primaryProtocol, err: &httpstream.UpgradeFailureError{Cause: fmt.Errorf("fake error")}}
secondary = &fakeDialer{dialed: false, negotiatedProtocol: secondaryProtocol} secondary = &fakeDialer{dialed: false, negotiatedProtocol: secondaryProtocol}
fallbackDialer = NewFallbackDialer(primary, secondary, httpstream.IsUpgradeFailure) fallbackDialer = NewFallbackDialer(primary, secondary, httpstream.IsUpgradeFailure)
_, negotiated, err = fallbackDialer.Dial(protocols...) _, negotiated, err = fallbackDialer.Dial(protocols...)

View File

@ -17,6 +17,7 @@ limitations under the License.
package portforward package portforward
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -30,6 +31,8 @@ import (
v1 "k8s.io/api/core/v1" v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/util/httpstream" "k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/apimachinery/pkg/util/runtime" "k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/klog/v2"
netutils "k8s.io/utils/net" netutils "k8s.io/utils/net"
) )
@ -52,6 +55,7 @@ type PortForwarder struct {
ports []ForwardedPort ports []ForwardedPort
stopChan <-chan struct{} stopChan <-chan struct{}
logger klog.Logger
dialer httpstream.Dialer dialer httpstream.Dialer
streamConn httpstream.Connection streamConn httpstream.Connection
listeners []io.Closer listeners []io.Closer
@ -165,7 +169,14 @@ func New(dialer httpstream.Dialer, ports []string, stopChan <-chan struct{}, rea
} }
// NewOnAddresses creates a new PortForwarder with custom listen addresses. // NewOnAddresses creates a new PortForwarder with custom listen addresses.
//
// Contextual logging: NewOnAddressesWithContext should be used instead of NewOnAddresses in code which supports contextual logging.
func NewOnAddresses(dialer httpstream.Dialer, addresses []string, ports []string, stopChan <-chan struct{}, readyChan chan struct{}, out, errOut io.Writer) (*PortForwarder, error) { func NewOnAddresses(dialer httpstream.Dialer, addresses []string, ports []string, stopChan <-chan struct{}, readyChan chan struct{}, out, errOut io.Writer) (*PortForwarder, error) {
return NewOnAddressesWithContext(wait.ContextForChannel(stopChan), dialer, addresses, ports, readyChan, out, errOut)
}
// NewOnAddressesWithContext creates a new PortForwarder with custom listen addresses.
func NewOnAddressesWithContext(ctx context.Context, dialer httpstream.Dialer, addresses []string, ports []string, readyChan chan struct{}, out, errOut io.Writer) (*PortForwarder, error) {
if len(addresses) == 0 { if len(addresses) == 0 {
return nil, errors.New("you must specify at least 1 address") return nil, errors.New("you must specify at least 1 address")
} }
@ -181,10 +192,11 @@ func NewOnAddresses(dialer httpstream.Dialer, addresses []string, ports []string
return nil, err return nil, err
} }
return &PortForwarder{ return &PortForwarder{
logger: klog.FromContext(ctx),
dialer: dialer, dialer: dialer,
addresses: parsedAddresses, addresses: parsedAddresses,
ports: parsedPorts, ports: parsedPorts,
stopChan: stopChan, stopChan: ctx.Done(),
Ready: readyChan, Ready: readyChan,
out: out, out: out,
errOut: errOut, errOut: errOut,
@ -319,7 +331,7 @@ func (pf *PortForwarder) waitForConnection(listener net.Listener, port Forwarded
if err != nil { if err != nil {
// TODO consider using something like https://github.com/hydrogen18/stoppableListener? // TODO consider using something like https://github.com/hydrogen18/stoppableListener?
if !strings.Contains(strings.ToLower(err.Error()), networkClosedError) { if !strings.Contains(strings.ToLower(err.Error()), networkClosedError) {
runtime.HandleError(fmt.Errorf("error accepting connection on port %d: %v", port.Local, err)) runtime.HandleErrorWithLogger(pf.logger, err, "Error accepting connection", "localPort", port.Local)
} }
return return
} }
@ -354,21 +366,23 @@ func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
headers.Set(v1.PortForwardRequestIDHeader, strconv.Itoa(requestID)) headers.Set(v1.PortForwardRequestIDHeader, strconv.Itoa(requestID))
errorStream, err := pf.streamConn.CreateStream(headers) errorStream, err := pf.streamConn.CreateStream(headers)
if err != nil { if err != nil {
runtime.HandleError(fmt.Errorf("error creating error stream for port %d -> %d: %v", port.Local, port.Remote, err)) runtime.HandleErrorWithLogger(pf.logger, err, "Error creating error stream", "localPort", port.Local, "remotePort", port.Remote)
return return
} }
// we're not writing to this stream // we're not writing to this stream
errorStream.Close() errorStream.Close()
defer pf.streamConn.RemoveStreams(errorStream) defer pf.streamConn.RemoveStreams(errorStream)
errorChan := make(chan error) type readAllResult struct {
message []byte
err error
}
errorChan := make(chan readAllResult)
go func() { go func() {
message, err := io.ReadAll(errorStream) message, err := io.ReadAll(errorStream)
switch { errorChan <- readAllResult{
case err != nil: message: message,
errorChan <- fmt.Errorf("error reading from error stream for port %d -> %d: %v", port.Local, port.Remote, err) err: err,
case len(message) > 0:
errorChan <- fmt.Errorf("an error occurred forwarding %d -> %d: %v", port.Local, port.Remote, string(message))
} }
close(errorChan) close(errorChan)
}() }()
@ -377,7 +391,7 @@ func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
headers.Set(v1.StreamType, v1.StreamTypeData) headers.Set(v1.StreamType, v1.StreamTypeData)
dataStream, err := pf.streamConn.CreateStream(headers) dataStream, err := pf.streamConn.CreateStream(headers)
if err != nil { if err != nil {
runtime.HandleError(fmt.Errorf("error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err)) runtime.HandleErrorWithLogger(pf.logger, err, "Error creating forwarding stream", "localPort", port.Local, "remotePort", port.Remote)
return return
} }
defer pf.streamConn.RemoveStreams(dataStream) defer pf.streamConn.RemoveStreams(dataStream)
@ -388,7 +402,7 @@ func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
go func() { go func() {
// Copy from the remote side to the local port. // Copy from the remote side to the local port.
if _, err := io.Copy(conn, dataStream); err != nil && !strings.Contains(strings.ToLower(err.Error()), networkClosedError) { if _, err := io.Copy(conn, dataStream); err != nil && !strings.Contains(strings.ToLower(err.Error()), networkClosedError) {
runtime.HandleError(fmt.Errorf("error copying from remote stream to local connection: %v", err)) runtime.HandleErrorWithLogger(pf.logger, err, "Error copying from remote stream to local connection", "localPort", port.Local, "remotePort", port.Remote)
} }
// inform the select below that the remote copy is done // inform the select below that the remote copy is done
@ -401,7 +415,7 @@ func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
// Copy from the local port to the remote side. // Copy from the local port to the remote side.
if _, err := io.Copy(dataStream, conn); err != nil && !strings.Contains(strings.ToLower(err.Error()), networkClosedError) { if _, err := io.Copy(dataStream, conn); err != nil && !strings.Contains(strings.ToLower(err.Error()), networkClosedError) {
runtime.HandleError(fmt.Errorf("error copying from local connection to remote stream: %v", err)) runtime.HandleErrorWithLogger(pf.logger, err, "Error copying from local connection to remote stream", "localPort", port.Local, "remotePort", port.Remote)
// break out of the select below without waiting for the other copy to finish // break out of the select below without waiting for the other copy to finish
close(localError) close(localError)
} }
@ -418,10 +432,14 @@ func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
// the blocking data will affect errorStream and cause <-errorChan to block indefinitely. // the blocking data will affect errorStream and cause <-errorChan to block indefinitely.
_ = dataStream.Reset() _ = dataStream.Reset()
// always expect something on errorChan (it may be nil) // always expect something on errorChan (it may be empty)
err = <-errorChan errResult := <-errorChan
if err != nil { switch {
runtime.HandleError(err) case errResult.err != nil:
runtime.HandleErrorWithLogger(pf.logger, errResult.err, "Error reading from error stream", "localPort", port.Local, "remotePort", port.Remote)
pf.streamConn.Close()
case len(errResult.message) > 0:
runtime.HandleErrorWithLogger(pf.logger, errors.New(string(errResult.message)), "An error occurred forwarding", "localPort", port.Local, "remotePort", port.Remote)
pf.streamConn.Close() pf.streamConn.Close()
} }
} }
@ -431,7 +449,7 @@ func (pf *PortForwarder) Close() {
// stop all listeners // stop all listeners
for _, l := range pf.listeners { for _, l := range pf.listeners {
if err := l.Close(); err != nil { if err := l.Close(); err != nil {
runtime.HandleError(fmt.Errorf("error closing listener: %v", err)) runtime.HandleErrorWithLogger(pf.logger, err, "Error closing listener")
} }
} }
} }

View File

@ -34,7 +34,7 @@ var _ net.Conn = &TunnelingConnection{}
// TunnelingConnection implements the "httpstream.Connection" interface, wrapping // TunnelingConnection implements the "httpstream.Connection" interface, wrapping
// a websocket connection that tunnels SPDY. // a websocket connection that tunnels SPDY.
type TunnelingConnection struct { type TunnelingConnection struct {
name string logger klog.Logger
conn *gwebsocket.Conn conn *gwebsocket.Conn
inProgressMessage io.Reader inProgressMessage io.Reader
closeOnce sync.Once closeOnce sync.Once
@ -42,9 +42,20 @@ type TunnelingConnection struct {
// NewTunnelingConnection wraps the passed gorilla/websockets connection // NewTunnelingConnection wraps the passed gorilla/websockets connection
// with the TunnelingConnection struct (implementing net.Conn). // with the TunnelingConnection struct (implementing net.Conn).
// The name is added to all log entries with [klog.LoggerWithName].
//
// Contextual logging: NewTunnelingConnectionWithLogger should be used instead of NewTunnelingConnection in code which supports contextual logging.
func NewTunnelingConnection(name string, conn *gwebsocket.Conn) *TunnelingConnection { func NewTunnelingConnection(name string, conn *gwebsocket.Conn) *TunnelingConnection {
logger := klog.LoggerWithName(klog.Background(), name)
return NewTunnelingConnectionWithLogger(logger, conn)
}
// NewTunnelingConnectionWithLogger is a variant of NewTunnelingConnection where
// the caller is in control of logging. For example, [klog.LoggerWithName] can be used
// to add a common name for all log entries to identify the connection.
func NewTunnelingConnectionWithLogger(logger klog.Logger, conn *gwebsocket.Conn) *TunnelingConnection {
return &TunnelingConnection{ return &TunnelingConnection{
name: name, logger: logger,
conn: conn, conn: conn,
} }
} }
@ -52,19 +63,25 @@ func NewTunnelingConnection(name string, conn *gwebsocket.Conn) *TunnelingConnec
// Read implements "io.Reader" interface, reading from the stored connection // Read implements "io.Reader" interface, reading from the stored connection
// into the passed buffer "p". Returns the number of bytes read and an error. // into the passed buffer "p". Returns the number of bytes read and an error.
// Can keep track of the "inProgress" messsage from the tunneled connection. // Can keep track of the "inProgress" messsage from the tunneled connection.
func (c *TunnelingConnection) Read(p []byte) (int, error) { func (c *TunnelingConnection) Read(p []byte) (len int, err error) {
klog.V(7).Infof("%s: tunneling connection read...", c.name) c.logger.V(7).Info("Tunneling connection read...")
defer klog.V(7).Infof("%s: tunneling connection read...complete", c.name) defer func() {
if loggerV := c.logger.V(8); loggerV.Enabled() {
loggerV.Info("Tunneling connection read...complete", "length", len, "data", p[:len], "err", err)
} else {
c.logger.V(7).Info("Tunneling connection read...complete")
}
}()
for { for {
if c.inProgressMessage == nil { if c.inProgressMessage == nil {
klog.V(8).Infof("%s: tunneling connection read before NextReader()...", c.name) c.logger.V(8).Info("Tunneling connection read before NextReader()...")
messageType, nextReader, err := c.conn.NextReader() messageType, nextReader, err := c.conn.NextReader()
if err != nil { if err != nil {
closeError := &gwebsocket.CloseError{} closeError := &gwebsocket.CloseError{}
if errors.As(err, &closeError) && closeError.Code == gwebsocket.CloseNormalClosure { if errors.As(err, &closeError) && closeError.Code == gwebsocket.CloseNormalClosure {
return 0, io.EOF return 0, io.EOF
} }
klog.V(4).Infof("%s:tunneling connection NextReader() error: %v", c.name, err) c.logger.V(4).Info("Tunneling connection NextReader() failed", "err", err)
return 0, err return 0, err
} }
if messageType != gwebsocket.BinaryMessage { if messageType != gwebsocket.BinaryMessage {
@ -72,12 +89,11 @@ func (c *TunnelingConnection) Read(p []byte) (int, error) {
} }
c.inProgressMessage = nextReader c.inProgressMessage = nextReader
} }
klog.V(8).Infof("%s: tunneling connection read in progress message...", c.name) c.logger.V(8).Info("Tunneling connection read in progress...")
i, err := c.inProgressMessage.Read(p) i, err := c.inProgressMessage.Read(p)
if i == 0 && err == io.EOF { if i == 0 && err == io.EOF {
c.inProgressMessage = nil c.inProgressMessage = nil
} else { } else {
klog.V(8).Infof("%s: read %d bytes, error=%v, bytes=% X", c.name, i, err, p[:i])
return i, err return i, err
} }
} }
@ -87,8 +103,8 @@ func (c *TunnelingConnection) Read(p []byte) (int, error) {
// byte array "p" into the stored tunneled connection. Returns the number // byte array "p" into the stored tunneled connection. Returns the number
// of bytes written and an error. // of bytes written and an error.
func (c *TunnelingConnection) Write(p []byte) (n int, err error) { func (c *TunnelingConnection) Write(p []byte) (n int, err error) {
klog.V(7).Infof("%s: write: %d bytes, bytes=% X", c.name, len(p), p) c.logger.V(7).Info("Tunneling connection write", "length", len(p), "data", p)
defer klog.V(7).Infof("%s: tunneling connection write...complete", c.name) defer c.logger.V(7).Info("Tunneling connection write...complete")
w, err := c.conn.NextWriter(gwebsocket.BinaryMessage) w, err := c.conn.NextWriter(gwebsocket.BinaryMessage)
if err != nil { if err != nil {
return 0, err return 0, err
@ -111,7 +127,7 @@ func (c *TunnelingConnection) Write(p []byte) (n int, err error) {
func (c *TunnelingConnection) Close() error { func (c *TunnelingConnection) Close() error {
var err error var err error
c.closeOnce.Do(func() { c.closeOnce.Do(func() {
klog.V(7).Infof("%s: tunneling connection Close()...", c.name) c.logger.V(7).Info("Tunneling connection Close()...")
// Signal other endpoint that websocket connection is closing; ignore error. // Signal other endpoint that websocket connection is closing; ignore error.
normalCloseMsg := gwebsocket.FormatCloseMessage(gwebsocket.CloseNormalClosure, "") normalCloseMsg := gwebsocket.FormatCloseMessage(gwebsocket.CloseNormalClosure, "")
writeControlErr := c.conn.WriteControl(gwebsocket.CloseMessage, normalCloseMsg, time.Now().Add(time.Second)) writeControlErr := c.conn.WriteControl(gwebsocket.CloseMessage, normalCloseMsg, time.Now().Add(time.Second))

View File

@ -36,8 +36,13 @@ import (
"k8s.io/apimachinery/pkg/util/wait" "k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/rest" "k8s.io/client-go/rest"
"k8s.io/client-go/transport/websocket" "k8s.io/client-go/transport/websocket"
"k8s.io/klog/v2"
) )
func init() {
klog.InitFlags(nil)
}
func TestTunnelingConnection_ReadWriteClose(t *testing.T) { func TestTunnelingConnection_ReadWriteClose(t *testing.T) {
// Stream channel that will receive streams created on upstream SPDY server. // Stream channel that will receive streams created on upstream SPDY server.
streamChan := make(chan httpstream.Stream) streamChan := make(chan httpstream.Stream)
@ -60,7 +65,7 @@ func TestTunnelingConnection_ReadWriteClose(t *testing.T) {
t.Errorf("Not acceptable agreement Subprotocol: %v", conn.Subprotocol()) t.Errorf("Not acceptable agreement Subprotocol: %v", conn.Subprotocol())
return return
} }
tunnelingConn := NewTunnelingConnection("server", conn) tunnelingConn := NewTunnelingConnectionWithLogger(klog.LoggerWithName(klog.Background(), "server"), conn)
spdyConn, err := spdy.NewServerConnection(tunnelingConn, justQueueStream(streamChan)) spdyConn, err := spdy.NewServerConnection(tunnelingConn, justQueueStream(streamChan))
if err != nil { if err != nil {
t.Errorf("unexpected error %v", err) t.Errorf("unexpected error %v", err)
@ -73,6 +78,7 @@ func TestTunnelingConnection_ReadWriteClose(t *testing.T) {
// Dial the client tunneling connection to the tunneling server. // Dial the client tunneling connection to the tunneling server.
url, err := url.Parse(tunnelingServer.URL) url, err := url.Parse(tunnelingServer.URL)
require.NoError(t, err) require.NoError(t, err)
//nolint:logcheck // Intentionally uses the old API.
dialer, err := NewSPDYOverWebsocketDialer(url, &rest.Config{Host: url.Host}) dialer, err := NewSPDYOverWebsocketDialer(url, &rest.Config{Host: url.Host})
require.NoError(t, err) require.NoError(t, err)
spdyClient, protocol, err := dialer.Dial(constants.PortForwardV1Name) spdyClient, protocol, err := dialer.Dial(constants.PortForwardV1Name)
@ -205,6 +211,7 @@ func dialForTunnelingConnection(url *url.URL) (*TunnelingConnection, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
//nolint:logcheck // Intentionally uses the old API.
return NewTunnelingConnection("client", conn), nil return NewTunnelingConnection("client", conn), nil
} }

View File

@ -17,6 +17,7 @@ limitations under the License.
package portforward package portforward
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
@ -35,6 +36,7 @@ const PingPeriod = 10 * time.Second
// tunnelingDialer implements "httpstream.Dial" interface // tunnelingDialer implements "httpstream.Dial" interface
type tunnelingDialer struct { type tunnelingDialer struct {
logger klog.Logger
url *url.URL url *url.URL
transport http.RoundTripper transport http.RoundTripper
holder websocket.ConnectionHolder holder websocket.ConnectionHolder
@ -43,12 +45,22 @@ type tunnelingDialer struct {
// NewTunnelingDialer creates and returns the tunnelingDialer structure which implemements the "httpstream.Dialer" // NewTunnelingDialer creates and returns the tunnelingDialer structure which implemements the "httpstream.Dialer"
// interface. The dialer can upgrade a websocket request, creating a websocket connection. This function // interface. The dialer can upgrade a websocket request, creating a websocket connection. This function
// returns an error if one occurs. // returns an error if one occurs.
//
// Contextual logging: NewSPDYOverWebsocketDialerWithLogger should be used instead of NewSPDYOverWebsocketDialer in code which supports contextual logging.
func NewSPDYOverWebsocketDialer(url *url.URL, config *restclient.Config) (httpstream.Dialer, error) { func NewSPDYOverWebsocketDialer(url *url.URL, config *restclient.Config) (httpstream.Dialer, error) {
return NewSPDYOverWebsocketDialerWithLogger(klog.Background(), url, config)
}
// NewTunnelingDialer creates and returns the tunnelingDialer structure which implemements the "httpstream.Dialer"
// interface. The dialer can upgrade a websocket request, creating a websocket connection. This function
// returns an error if one occurs.
func NewSPDYOverWebsocketDialerWithLogger(logger klog.Logger, url *url.URL, config *restclient.Config) (httpstream.Dialer, error) {
transport, holder, err := websocket.RoundTripperFor(config) transport, holder, err := websocket.RoundTripperFor(config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &tunnelingDialer{ return &tunnelingDialer{
logger: logger,
url: url, url: url,
transport: transport, transport: transport,
holder: holder, holder: holder,
@ -59,9 +71,10 @@ func NewSPDYOverWebsocketDialer(url *url.URL, config *restclient.Config) (httpst
// containing a WebSockets connection (which implements "net.Conn"). Also // containing a WebSockets connection (which implements "net.Conn"). Also
// returns the protocol negotiated, or an error. // returns the protocol negotiated, or an error.
func (d *tunnelingDialer) Dial(protocols ...string) (httpstream.Connection, string, error) { func (d *tunnelingDialer) Dial(protocols ...string) (httpstream.Connection, string, error) {
// There is no passed context, so skip the context when creating request for now. // There is no passed context, so use the background context when creating request for now.
ctx := klog.NewContext(context.Background(), d.logger)
// Websockets requires "GET" method: RFC 6455 Sec. 4.1 (page 17). // Websockets requires "GET" method: RFC 6455 Sec. 4.1 (page 17).
req, err := http.NewRequest("GET", d.url.String(), nil) req, err := http.NewRequestWithContext(ctx, "GET", d.url.String(), nil)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
@ -72,7 +85,7 @@ func (d *tunnelingDialer) Dial(protocols ...string) (httpstream.Connection, stri
tunnelingProtocol := constants.WebsocketsSPDYTunnelingPrefix + protocol tunnelingProtocol := constants.WebsocketsSPDYTunnelingPrefix + protocol
tunnelingProtocols = append(tunnelingProtocols, tunnelingProtocol) tunnelingProtocols = append(tunnelingProtocols, tunnelingProtocol)
} }
klog.V(4).Infoln("Before WebSocket Upgrade Connection...") d.logger.V(4).Info("Before WebSocket Upgrade Connection...")
conn, err := websocket.Negotiate(d.transport, d.holder, req, tunnelingProtocols...) conn, err := websocket.Negotiate(d.transport, d.holder, req, tunnelingProtocols...)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
@ -82,10 +95,10 @@ func (d *tunnelingDialer) Dial(protocols ...string) (httpstream.Connection, stri
} }
protocol := conn.Subprotocol() protocol := conn.Subprotocol()
protocol = strings.TrimPrefix(protocol, constants.WebsocketsSPDYTunnelingPrefix) protocol = strings.TrimPrefix(protocol, constants.WebsocketsSPDYTunnelingPrefix)
klog.V(4).Infof("negotiated protocol: %s", protocol) d.logger.V(4).Info("Negotiation complete", "protocol", protocol)
// Wrap the websocket connection which implements "net.Conn". // Wrap the websocket connection which implements "net.Conn".
tConn := NewTunnelingConnection("client", conn) tConn := NewTunnelingConnectionWithLogger(klog.LoggerWithName(d.logger, "client"), conn)
// Create SPDY connection injecting the previously created tunneling connection. // Create SPDY connection injecting the previously created tunneling connection.
spdyConn, err := spdy.NewClientConnectionWithPings(tConn, PingPeriod) spdyConn, err := spdy.NewClientConnectionWithPings(tConn, PingPeriod)