mirror of
https://github.com/fatedier/frp.git
synced 2025-07-27 07:35:07 +00:00
move dial functions into golib (#2767)
This commit is contained in:
@@ -17,18 +17,12 @@ package net
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/url"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/frp/pkg/util/xlog"
|
||||
"golang.org/x/net/websocket"
|
||||
|
||||
gnet "github.com/fatedier/golib/net"
|
||||
kcp "github.com/fatedier/kcp-go"
|
||||
)
|
||||
|
||||
type ContextGetter interface {
|
||||
@@ -189,67 +183,3 @@ func (statsConn *StatsConn) Close() (err error) {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func ConnectServer(protocol string, addr string) (c net.Conn, err error) {
|
||||
switch protocol {
|
||||
case "tcp":
|
||||
return net.Dial("tcp", addr)
|
||||
case "kcp":
|
||||
return DialKCPServer(addr)
|
||||
case "websocket":
|
||||
return DialWebsocketServer(addr)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupport protocol: %s", protocol)
|
||||
}
|
||||
}
|
||||
|
||||
func DialKCPServer(addr string) (c net.Conn, err error) {
|
||||
kcpConn, errRet := kcp.DialWithOptions(addr, nil, 10, 3)
|
||||
if errRet != nil {
|
||||
err = errRet
|
||||
return
|
||||
}
|
||||
kcpConn.SetStreamMode(true)
|
||||
kcpConn.SetWriteDelay(true)
|
||||
kcpConn.SetNoDelay(1, 20, 2, 1)
|
||||
kcpConn.SetWindowSize(128, 512)
|
||||
kcpConn.SetMtu(1350)
|
||||
kcpConn.SetACKNoDelay(false)
|
||||
kcpConn.SetReadBuffer(4194304)
|
||||
kcpConn.SetWriteBuffer(4194304)
|
||||
c = kcpConn
|
||||
return
|
||||
}
|
||||
|
||||
func ConnectServerByProxy(proxyURL string, protocol string, addr string) (c net.Conn, err error) {
|
||||
switch protocol {
|
||||
case "tcp":
|
||||
return gnet.DialTcpByProxy(proxyURL, addr)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupport protocol: %s when connecting by proxy", protocol)
|
||||
}
|
||||
}
|
||||
|
||||
// addr: domain:port
|
||||
func DialWebsocketServer(addr string) (net.Conn, error) {
|
||||
addr = "ws://" + addr + FrpWebsocketPath
|
||||
uri, err := url.Parse(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
origin := "http://" + uri.Host
|
||||
cfg, err := websocket.NewConfig(addr, origin)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cfg.Dialer = &net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
conn, err := websocket.DialConfig(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
@@ -1,89 +1,44 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"context"
|
||||
"net"
|
||||
"net/url"
|
||||
|
||||
libdial "github.com/fatedier/golib/net/dial"
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
|
||||
type dialOptions struct {
|
||||
proxyURL string
|
||||
protocol string
|
||||
tlsConfig *tls.Config
|
||||
disableCustomTLSHeadByte bool
|
||||
}
|
||||
|
||||
type DialOption interface {
|
||||
apply(*dialOptions)
|
||||
}
|
||||
|
||||
type EmptyDialOption struct{}
|
||||
|
||||
func (EmptyDialOption) apply(*dialOptions) {}
|
||||
|
||||
type funcDialOption struct {
|
||||
f func(*dialOptions)
|
||||
}
|
||||
|
||||
func (fdo *funcDialOption) apply(do *dialOptions) {
|
||||
fdo.f(do)
|
||||
}
|
||||
|
||||
func newFuncDialOption(f func(*dialOptions)) *funcDialOption {
|
||||
return &funcDialOption{
|
||||
f: f,
|
||||
func DialHookCustomTLSHeadByte(enableTLS bool, disableCustomTLSHeadByte bool) libdial.AfterHookFunc {
|
||||
return func(ctx context.Context, c net.Conn, addr string) (context.Context, net.Conn, error) {
|
||||
if enableTLS && !disableCustomTLSHeadByte {
|
||||
_, err := c.Write([]byte{byte(FRPTLSHeadByte)})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
return ctx, c, nil
|
||||
}
|
||||
}
|
||||
|
||||
func DefaultDialOptions() dialOptions {
|
||||
return dialOptions{
|
||||
protocol: "tcp",
|
||||
func DialHookWebsocket() libdial.AfterHookFunc {
|
||||
return func(ctx context.Context, c net.Conn, addr string) (context.Context, net.Conn, error) {
|
||||
addr = "ws://" + addr + FrpWebsocketPath
|
||||
uri, err := url.Parse(addr)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
origin := "http://" + uri.Host
|
||||
cfg, err := websocket.NewConfig(addr, origin)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
conn, err := websocket.NewClient(cfg, c)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return ctx, conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithProxyURL(proxyURL string) DialOption {
|
||||
return newFuncDialOption(func(do *dialOptions) {
|
||||
do.proxyURL = proxyURL
|
||||
})
|
||||
}
|
||||
|
||||
func WithTLSConfig(tlsConfig *tls.Config) DialOption {
|
||||
return newFuncDialOption(func(do *dialOptions) {
|
||||
do.tlsConfig = tlsConfig
|
||||
})
|
||||
}
|
||||
|
||||
func WithDisableCustomTLSHeadByte(disableCustomTLSHeadByte bool) DialOption {
|
||||
return newFuncDialOption(func(do *dialOptions) {
|
||||
do.disableCustomTLSHeadByte = disableCustomTLSHeadByte
|
||||
})
|
||||
}
|
||||
|
||||
func WithProtocol(protocol string) DialOption {
|
||||
return newFuncDialOption(func(do *dialOptions) {
|
||||
do.protocol = protocol
|
||||
})
|
||||
}
|
||||
|
||||
func DialWithOptions(addr string, opts ...DialOption) (c net.Conn, err error) {
|
||||
op := DefaultDialOptions()
|
||||
|
||||
for _, opt := range opts {
|
||||
opt.apply(&op)
|
||||
}
|
||||
|
||||
if op.proxyURL == "" {
|
||||
c, err = ConnectServer(op.protocol, addr)
|
||||
} else {
|
||||
c, err = ConnectServerByProxy(op.proxyURL, op.protocol, addr)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if op.tlsConfig == nil {
|
||||
return
|
||||
}
|
||||
|
||||
c = WrapTLSClientConn(c, op.tlsConfig, op.disableCustomTLSHeadByte)
|
||||
return
|
||||
}
|
||||
|
@@ -27,14 +27,6 @@ var (
|
||||
FRPTLSHeadByte = 0x17
|
||||
)
|
||||
|
||||
func WrapTLSClientConn(c net.Conn, tlsConfig *tls.Config, disableCustomTLSHeadByte bool) (out net.Conn) {
|
||||
if !disableCustomTLSHeadByte {
|
||||
c.Write([]byte{byte(FRPTLSHeadByte)})
|
||||
}
|
||||
out = tls.Client(c, tlsConfig)
|
||||
return
|
||||
}
|
||||
|
||||
func CheckAndEnableTLSServerConnWithTimeout(
|
||||
c net.Conn, tlsConfig *tls.Config, tlsOnly bool, timeout time.Duration,
|
||||
) (out net.Conn, isTLS bool, custom bool, err error) {
|
||||
|
Reference in New Issue
Block a user