move dial functions into golib (#2767)

This commit is contained in:
fatedier
2022-01-20 20:03:07 +08:00
committed by GitHub
parent 293003fcdb
commit 70f4caac23
9 changed files with 114 additions and 195 deletions

View File

@@ -1,89 +1,44 @@
package net
import (
"crypto/tls"
"context"
"net"
"net/url"
libdial "github.com/fatedier/golib/net/dial"
"golang.org/x/net/websocket"
)
type dialOptions struct {
proxyURL string
protocol string
tlsConfig *tls.Config
disableCustomTLSHeadByte bool
}
type DialOption interface {
apply(*dialOptions)
}
type EmptyDialOption struct{}
func (EmptyDialOption) apply(*dialOptions) {}
type funcDialOption struct {
f func(*dialOptions)
}
func (fdo *funcDialOption) apply(do *dialOptions) {
fdo.f(do)
}
func newFuncDialOption(f func(*dialOptions)) *funcDialOption {
return &funcDialOption{
f: f,
func DialHookCustomTLSHeadByte(enableTLS bool, disableCustomTLSHeadByte bool) libdial.AfterHookFunc {
return func(ctx context.Context, c net.Conn, addr string) (context.Context, net.Conn, error) {
if enableTLS && !disableCustomTLSHeadByte {
_, err := c.Write([]byte{byte(FRPTLSHeadByte)})
if err != nil {
return nil, nil, err
}
}
return ctx, c, nil
}
}
func DefaultDialOptions() dialOptions {
return dialOptions{
protocol: "tcp",
func DialHookWebsocket() libdial.AfterHookFunc {
return func(ctx context.Context, c net.Conn, addr string) (context.Context, net.Conn, error) {
addr = "ws://" + addr + FrpWebsocketPath
uri, err := url.Parse(addr)
if err != nil {
return nil, nil, err
}
origin := "http://" + uri.Host
cfg, err := websocket.NewConfig(addr, origin)
if err != nil {
return nil, nil, err
}
conn, err := websocket.NewClient(cfg, c)
if err != nil {
return nil, nil, err
}
return ctx, conn, nil
}
}
func WithProxyURL(proxyURL string) DialOption {
return newFuncDialOption(func(do *dialOptions) {
do.proxyURL = proxyURL
})
}
func WithTLSConfig(tlsConfig *tls.Config) DialOption {
return newFuncDialOption(func(do *dialOptions) {
do.tlsConfig = tlsConfig
})
}
func WithDisableCustomTLSHeadByte(disableCustomTLSHeadByte bool) DialOption {
return newFuncDialOption(func(do *dialOptions) {
do.disableCustomTLSHeadByte = disableCustomTLSHeadByte
})
}
func WithProtocol(protocol string) DialOption {
return newFuncDialOption(func(do *dialOptions) {
do.protocol = protocol
})
}
func DialWithOptions(addr string, opts ...DialOption) (c net.Conn, err error) {
op := DefaultDialOptions()
for _, opt := range opts {
opt.apply(&op)
}
if op.proxyURL == "" {
c, err = ConnectServer(op.protocol, addr)
} else {
c, err = ConnectServerByProxy(op.proxyURL, op.protocol, addr)
}
if err != nil {
return nil, err
}
if op.tlsConfig == nil {
return
}
c = WrapTLSClientConn(c, op.tlsConfig, op.disableCustomTLSHeadByte)
return
}