| 
									
										
										
										
											2016-03-14 11:18:24 +08:00
										 |  |  | // Copyright 2016 fatedier, fatedier@gmail.com
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | // Licensed under the Apache License, Version 2.0 (the "License");
 | 
					
						
							|  |  |  | // you may not use this file except in compliance with the License.
 | 
					
						
							|  |  |  | // You may obtain a copy of the License at
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | //     http://www.apache.org/licenses/LICENSE-2.0
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | // Unless required by applicable law or agreed to in writing, software
 | 
					
						
							|  |  |  | // distributed under the License is distributed on an "AS IS" BASIS,
 | 
					
						
							|  |  |  | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
					
						
							|  |  |  | // See the License for the specific language governing permissions and
 | 
					
						
							|  |  |  | // limitations under the License.
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2017-03-09 02:03:47 +08:00
										 |  |  | package net | 
					
						
							| 
									
										
										
										
											2016-02-18 16:56:55 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2017-03-09 02:03:47 +08:00
										 |  |  | import ( | 
					
						
							| 
									
										
										
										
											2019-10-12 20:13:12 +08:00
										 |  |  | 	"context" | 
					
						
							| 
									
										
										
										
											2017-06-09 01:33:57 +08:00
										 |  |  | 	"errors" | 
					
						
							| 
									
										
										
										
											2017-05-26 14:17:46 +08:00
										 |  |  | 	"io" | 
					
						
							| 
									
										
										
										
											2017-03-09 02:03:47 +08:00
										 |  |  | 	"net" | 
					
						
							| 
									
										
										
										
											2018-02-01 11:15:35 +08:00
										 |  |  | 	"sync/atomic" | 
					
						
							| 
									
										
										
										
											2017-05-26 14:17:46 +08:00
										 |  |  | 	"time" | 
					
						
							| 
									
										
										
										
											2016-02-18 16:56:55 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-06 10:51:48 +08:00
										 |  |  | 	"github.com/fatedier/golib/crypto" | 
					
						
							| 
									
										
										
										
											2023-02-02 20:20:17 +08:00
										 |  |  | 	quic "github.com/quic-go/quic-go" | 
					
						
							| 
									
										
										
										
											2022-12-12 11:04:10 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-23 13:49:14 +08:00
										 |  |  | 	"github.com/fatedier/frp/pkg/util/xlog" | 
					
						
							| 
									
										
										
										
											2016-07-17 21:42:21 +08:00
										 |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-10-12 20:13:12 +08:00
										 type (
	// ContextGetter is implemented by connections that expose an associated
	// context.Context (ContextConn is one such type).
	ContextGetter interface {
		Context() context.Context
	}

	// ContextSetter is implemented by connections whose associated
	// context.Context can be replaced after construction.
	ContextSetter interface {
		WithContext(ctx context.Context)
	}
)
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func NewLogFromConn(conn net.Conn) *xlog.Logger { | 
					
						
							|  |  |  | 	if c, ok := conn.(ContextGetter); ok { | 
					
						
							|  |  |  | 		return xlog.FromContextSafe(c.Context()) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return xlog.New() | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func NewContextFromConn(conn net.Conn) context.Context { | 
					
						
							|  |  |  | 	if c, ok := conn.(ContextGetter); ok { | 
					
						
							|  |  |  | 		return c.Context() | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return context.Background() | 
					
						
							| 
									
										
										
										
											2017-03-09 02:03:47 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-10-12 20:13:12 +08:00
										 // ContextConn is a net.Conn with an attached context.Context. Its Context and
// WithContext methods give it the method sets of ContextGetter and
// ContextSetter.
type ContextConn struct {
	net.Conn

	ctx context.Context
}

// NewContextConn wraps c so that Context() reports ctx.
func NewContextConn(ctx context.Context, c net.Conn) *ContextConn {
	cc := new(ContextConn)
	cc.Conn = c
	cc.ctx = ctx
	return cc
}

// WithContext replaces the context attached to the connection.
func (cc *ContextConn) WithContext(ctx context.Context) {
	cc.ctx = ctx
}

// Context returns the context currently attached to the connection.
func (cc *ContextConn) Context() context.Context {
	return cc.ctx
}
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2017-05-26 14:17:46 +08:00
										 // WrapReadWriteCloserConn adapts an io.ReadWriteCloser to net.Conn. Read,
// Write and Close go to the embedded ReadWriteCloser; address and deadline
// methods are forwarded to underConn when it is non-nil.
type WrapReadWriteCloserConn struct {
	io.ReadWriteCloser

	// underConn, when non-nil, supplies addresses and deadlines.
	underConn net.Conn

	// remoteAddr, when set through SetRemoteAddr, takes precedence over
	// underConn's remote address.
	remoteAddr net.Addr
}

// WrapReadWriteCloserToConn wraps rwc as a net.Conn. underConn may be nil: the
// address methods then return a nil *net.TCPAddr and the deadline methods
// return an error.
func WrapReadWriteCloserToConn(rwc io.ReadWriteCloser, underConn net.Conn) *WrapReadWriteCloserConn {
	wrapped := new(WrapReadWriteCloserConn)
	wrapped.ReadWriteCloser = rwc
	wrapped.underConn = underConn
	return wrapped
}

// LocalAddr returns underConn's local address, or a nil *net.TCPAddr when no
// underlying connection was given.
func (conn *WrapReadWriteCloserConn) LocalAddr() net.Addr {
	if conn.underConn == nil {
		return (*net.TCPAddr)(nil)
	}
	return conn.underConn.LocalAddr()
}

// SetRemoteAddr overrides the address reported by RemoteAddr.
func (conn *WrapReadWriteCloserConn) SetRemoteAddr(addr net.Addr) {
	conn.remoteAddr = addr
}

// RemoteAddr returns, in order of preference: the address recorded by
// SetRemoteAddr, underConn's remote address, or a nil *net.TCPAddr.
func (conn *WrapReadWriteCloserConn) RemoteAddr() net.Addr {
	switch {
	case conn.remoteAddr != nil:
		return conn.remoteAddr
	case conn.underConn != nil:
		return conn.underConn.RemoteAddr()
	default:
		return (*net.TCPAddr)(nil)
	}
}

// SetDeadline forwards to underConn; without one it reports that deadlines are
// not supported.
func (conn *WrapReadWriteCloserConn) SetDeadline(t time.Time) error {
	if conn.underConn == nil {
		return &net.OpError{Op: "set", Net: "wrap", Err: errors.New("deadline not supported")}
	}
	return conn.underConn.SetDeadline(t)
}

// SetReadDeadline forwards to underConn; without one it reports that deadlines
// are not supported.
func (conn *WrapReadWriteCloserConn) SetReadDeadline(t time.Time) error {
	if conn.underConn == nil {
		return &net.OpError{Op: "set", Net: "wrap", Err: errors.New("deadline not supported")}
	}
	return conn.underConn.SetReadDeadline(t)
}

// SetWriteDeadline forwards to underConn; without one it reports that
// deadlines are not supported.
func (conn *WrapReadWriteCloserConn) SetWriteDeadline(t time.Time) error {
	if conn.underConn == nil {
		return &net.OpError{Op: "set", Net: "wrap", Err: errors.New("deadline not supported")}
	}
	return conn.underConn.SetWriteDeadline(t)
}
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-08-10 11:43:08 +08:00
										 // CloseNotifyConn wraps a net.Conn and runs a callback the first time it is
// closed.
type CloseNotifyConn struct {
	net.Conn

	// 1 means closed; set atomically by the first Close call.
	closeFlag int32

	closeFn func()
}

// WrapCloseNotifyConn wraps c so that closeFn (if non-nil) is called only
// once, right after the first Close closes c. Later Close calls are no-ops.
func WrapCloseNotifyConn(c net.Conn, closeFn func()) net.Conn {
	return &CloseNotifyConn{
		Conn:    c,
		closeFn: closeFn,
	}
}

// Close closes the wrapped connection and then invokes closeFn, but only on
// the first call; subsequent calls return nil without doing anything. closeFn
// runs even if closing the wrapped connection fails, and that error is
// returned.
func (cc *CloseNotifyConn) Close() (err error) {
	if atomic.SwapInt32(&cc.closeFlag, 1) != 0 {
		return nil
	}
	// Close the embedded connection explicitly. Calling cc.Close() here would
	// re-enter this method, observe closeFlag == 1, return nil and leave the
	// underlying connection open.
	err = cc.Conn.Close()
	if cc.closeFn != nil {
		cc.closeFn()
	}
	return err
}
					
						
							| 
									
										
										
										
											2017-06-09 01:33:57 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-01-17 23:17:15 +08:00
										 // StatsConn wraps a net.Conn, counts the bytes read from and written to it,
// and reports the totals to statsFunc when the connection is first closed.
type StatsConn struct {
	net.Conn

	closed     int64 // 1 means closed
	totalRead  int64 // updated/read atomically
	totalWrite int64 // updated/read atomically
	statsFunc  func(totalRead, totalWrite int64)
}

// WrapStatsConn wraps conn. statsFunc, if non-nil, is called once with the
// byte totals when the returned connection is first closed.
func WrapStatsConn(conn net.Conn, statsFunc func(totalRead, totalWrite int64)) *StatsConn {
	return &StatsConn{
		Conn:      conn,
		statsFunc: statsFunc,
	}
}

// Read reads from the wrapped connection and adds the bytes read to the
// running total.
func (statsConn *StatsConn) Read(p []byte) (n int, err error) {
	n, err = statsConn.Conn.Read(p)
	// Atomic: net.Conn allows concurrent method calls, and Close reads this
	// counter possibly from another goroutine.
	atomic.AddInt64(&statsConn.totalRead, int64(n))
	return n, err
}

// Write writes to the wrapped connection and adds the bytes written to the
// running total.
func (statsConn *StatsConn) Write(p []byte) (n int, err error) {
	n, err = statsConn.Conn.Write(p)
	atomic.AddInt64(&statsConn.totalWrite, int64(n))
	return n, err
}

// Close closes the wrapped connection and, on the first call only, reports the
// byte totals to statsFunc. Subsequent calls return nil.
func (statsConn *StatsConn) Close() (err error) {
	if atomic.SwapInt64(&statsConn.closed, 1) == 1 {
		return nil
	}
	err = statsConn.Conn.Close()
	if statsConn.statsFunc != nil {
		statsConn.statsFunc(
			atomic.LoadInt64(&statsConn.totalRead),
			atomic.LoadInt64(&statsConn.totalWrite),
		)
	}
	return err
}
					
						
							| 
									
										
										
										
											2022-12-12 11:04:10 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | type wrapQuicStream struct { | 
					
						
							|  |  |  | 	quic.Stream | 
					
						
							|  |  |  | 	c quic.Connection | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func QuicStreamToNetConn(s quic.Stream, c quic.Connection) net.Conn { | 
					
						
							|  |  |  | 	return &wrapQuicStream{ | 
					
						
							|  |  |  | 		Stream: s, | 
					
						
							|  |  |  | 		c:      c, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (conn *wrapQuicStream) LocalAddr() net.Addr { | 
					
						
							|  |  |  | 	if conn.c != nil { | 
					
						
							|  |  |  | 		return conn.c.LocalAddr() | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return (*net.TCPAddr)(nil) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (conn *wrapQuicStream) RemoteAddr() net.Addr { | 
					
						
							|  |  |  | 	if conn.c != nil { | 
					
						
							|  |  |  | 		return conn.c.RemoteAddr() | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return (*net.TCPAddr)(nil) | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2023-01-10 10:19:37 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | func (conn *wrapQuicStream) Close() error { | 
					
						
							| 
									
										
										
										
											2025-05-27 16:46:15 +08:00
										 |  |  | 	conn.CancelRead(0) | 
					
						
							| 
									
										
										
										
											2023-01-10 10:19:37 +08:00
										 |  |  | 	return conn.Stream.Close() | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2023-11-06 10:51:48 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | func NewCryptoReadWriter(rw io.ReadWriter, key []byte) (io.ReadWriter, error) { | 
					
						
							|  |  |  | 	encReader := crypto.NewReader(rw, key) | 
					
						
							|  |  |  | 	encWriter, err := crypto.NewWriter(rw, key) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return struct { | 
					
						
							|  |  |  | 		io.Reader | 
					
						
							|  |  |  | 		io.Writer | 
					
						
							|  |  |  | 	}{ | 
					
						
							|  |  |  | 		Reader: encReader, | 
					
						
							|  |  |  | 		Writer: encWriter, | 
					
						
							|  |  |  | 	}, nil | 
					
						
							|  |  |  | } |