mirror of
https://github.com/fatedier/frp.git
synced 2025-07-27 07:35:07 +00:00
refactor the code related to xtcp (#3449)
This commit is contained in:
@@ -661,6 +661,9 @@ func Test_LoadClientBasicConf(t *testing.T) {
|
||||
BindAddr: "127.0.0.1",
|
||||
BindPort: 9001,
|
||||
},
|
||||
Protocol: "quic",
|
||||
MaxRetriesAnHour: 8,
|
||||
MinRetryInterval: 90,
|
||||
},
|
||||
}
|
||||
|
||||
|
@@ -1078,7 +1078,6 @@ func (cfg *XTCPProxyConf) Compare(cmp ProxyConf) bool {
|
||||
cfg.Sk != cmpConf.Sk {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -1092,7 +1091,6 @@ func (cfg *XTCPProxyConf) UnmarshalFromIni(prefix string, name string, section *
|
||||
if cfg.Role == "" {
|
||||
cfg.Role = "server"
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1120,7 +1118,6 @@ func (cfg *XTCPProxyConf) CheckForCli() (err error) {
|
||||
if cfg.Role != "server" {
|
||||
return fmt.Errorf("role should be 'server'")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@@ -196,35 +196,38 @@ type ServerCommonConf struct {
|
||||
// Enable golang pprof handlers in dashboard listener.
|
||||
// Dashboard port must be set first.
|
||||
PprofEnable bool `ini:"pprof_enable" json:"pprof_enable"`
|
||||
// NatHoleAnalysisDataReserveHours specifies the hours to reserve nat hole analysis data.
|
||||
NatHoleAnalysisDataReserveHours int64 `ini:"nat_hole_analysis_data_reserve_hours" json:"nat_hole_analysis_data_reserve_hours"`
|
||||
}
|
||||
|
||||
// GetDefaultServerConf returns a server configuration with reasonable
|
||||
// defaults.
|
||||
func GetDefaultServerConf() ServerCommonConf {
|
||||
return ServerCommonConf{
|
||||
ServerConfig: auth.GetDefaultServerConf(),
|
||||
BindAddr: "0.0.0.0",
|
||||
BindPort: 7000,
|
||||
QUICKeepalivePeriod: 10,
|
||||
QUICMaxIdleTimeout: 30,
|
||||
QUICMaxIncomingStreams: 100000,
|
||||
VhostHTTPTimeout: 60,
|
||||
DashboardAddr: "0.0.0.0",
|
||||
LogFile: "console",
|
||||
LogWay: "console",
|
||||
LogLevel: "info",
|
||||
LogMaxDays: 3,
|
||||
DetailedErrorsToClient: true,
|
||||
TCPMux: true,
|
||||
TCPMuxKeepaliveInterval: 60,
|
||||
TCPKeepAlive: 7200,
|
||||
AllowPorts: make(map[int]struct{}),
|
||||
MaxPoolCount: 5,
|
||||
MaxPortsPerClient: 0,
|
||||
HeartbeatTimeout: 90,
|
||||
UserConnTimeout: 10,
|
||||
HTTPPlugins: make(map[string]plugin.HTTPPluginOptions),
|
||||
UDPPacketSize: 1500,
|
||||
ServerConfig: auth.GetDefaultServerConf(),
|
||||
BindAddr: "0.0.0.0",
|
||||
BindPort: 7000,
|
||||
QUICKeepalivePeriod: 10,
|
||||
QUICMaxIdleTimeout: 30,
|
||||
QUICMaxIncomingStreams: 100000,
|
||||
VhostHTTPTimeout: 60,
|
||||
DashboardAddr: "0.0.0.0",
|
||||
LogFile: "console",
|
||||
LogWay: "console",
|
||||
LogLevel: "info",
|
||||
LogMaxDays: 3,
|
||||
DetailedErrorsToClient: true,
|
||||
TCPMux: true,
|
||||
TCPMuxKeepaliveInterval: 60,
|
||||
TCPKeepAlive: 7200,
|
||||
AllowPorts: make(map[int]struct{}),
|
||||
MaxPoolCount: 5,
|
||||
MaxPortsPerClient: 0,
|
||||
HeartbeatTimeout: 90,
|
||||
UserConnTimeout: 10,
|
||||
HTTPPlugins: make(map[string]plugin.HTTPPluginOptions),
|
||||
UDPPacketSize: 1500,
|
||||
NatHoleAnalysisDataReserveHours: 7 * 24,
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -134,18 +134,19 @@ func Test_LoadServerCommonConf(t *testing.T) {
|
||||
12: {},
|
||||
99: {},
|
||||
},
|
||||
AllowPortsStr: "10-12,99",
|
||||
MaxPoolCount: 59,
|
||||
MaxPortsPerClient: 9,
|
||||
TLSOnly: true,
|
||||
TLSCertFile: "server.crt",
|
||||
TLSKeyFile: "server.key",
|
||||
TLSTrustedCaFile: "ca.crt",
|
||||
SubDomainHost: "frps.com",
|
||||
TCPMux: true,
|
||||
TCPMuxKeepaliveInterval: 60,
|
||||
TCPKeepAlive: 7200,
|
||||
UDPPacketSize: 1509,
|
||||
AllowPortsStr: "10-12,99",
|
||||
MaxPoolCount: 59,
|
||||
MaxPortsPerClient: 9,
|
||||
TLSOnly: true,
|
||||
TLSCertFile: "server.crt",
|
||||
TLSKeyFile: "server.key",
|
||||
TLSTrustedCaFile: "ca.crt",
|
||||
SubDomainHost: "frps.com",
|
||||
TCPMux: true,
|
||||
TCPMuxKeepaliveInterval: 60,
|
||||
TCPKeepAlive: 7200,
|
||||
UDPPacketSize: 1509,
|
||||
NatHoleAnalysisDataReserveHours: 7 * 24,
|
||||
|
||||
HTTPPlugins: map[string]plugin.HTTPPluginOptions{
|
||||
"user-manager": {
|
||||
@@ -180,32 +181,33 @@ func Test_LoadServerCommonConf(t *testing.T) {
|
||||
AuthenticateNewWorkConns: false,
|
||||
},
|
||||
},
|
||||
BindAddr: "0.0.0.9",
|
||||
BindPort: 7009,
|
||||
BindUDPPort: 7008,
|
||||
QUICKeepalivePeriod: 10,
|
||||
QUICMaxIdleTimeout: 30,
|
||||
QUICMaxIncomingStreams: 100000,
|
||||
ProxyBindAddr: "0.0.0.9",
|
||||
VhostHTTPTimeout: 60,
|
||||
DashboardAddr: "0.0.0.0",
|
||||
DashboardUser: "",
|
||||
DashboardPwd: "",
|
||||
EnablePrometheus: false,
|
||||
LogFile: "console",
|
||||
LogWay: "console",
|
||||
LogLevel: "info",
|
||||
LogMaxDays: 3,
|
||||
DetailedErrorsToClient: true,
|
||||
TCPMux: true,
|
||||
TCPMuxKeepaliveInterval: 60,
|
||||
TCPKeepAlive: 7200,
|
||||
AllowPorts: make(map[int]struct{}),
|
||||
MaxPoolCount: 5,
|
||||
HeartbeatTimeout: 90,
|
||||
UserConnTimeout: 10,
|
||||
HTTPPlugins: make(map[string]plugin.HTTPPluginOptions),
|
||||
UDPPacketSize: 1500,
|
||||
BindAddr: "0.0.0.9",
|
||||
BindPort: 7009,
|
||||
BindUDPPort: 7008,
|
||||
QUICKeepalivePeriod: 10,
|
||||
QUICMaxIdleTimeout: 30,
|
||||
QUICMaxIncomingStreams: 100000,
|
||||
ProxyBindAddr: "0.0.0.9",
|
||||
VhostHTTPTimeout: 60,
|
||||
DashboardAddr: "0.0.0.0",
|
||||
DashboardUser: "",
|
||||
DashboardPwd: "",
|
||||
EnablePrometheus: false,
|
||||
LogFile: "console",
|
||||
LogWay: "console",
|
||||
LogLevel: "info",
|
||||
LogMaxDays: 3,
|
||||
DetailedErrorsToClient: true,
|
||||
TCPMux: true,
|
||||
TCPMuxKeepaliveInterval: 60,
|
||||
TCPKeepAlive: 7200,
|
||||
AllowPorts: make(map[int]struct{}),
|
||||
MaxPoolCount: 5,
|
||||
HeartbeatTimeout: 90,
|
||||
UserConnTimeout: 10,
|
||||
HTTPPlugins: make(map[string]plugin.HTTPPluginOptions),
|
||||
UDPPacketSize: 1500,
|
||||
NatHoleAnalysisDataReserveHours: 7 * 24,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@@ -18,6 +18,7 @@ import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/samber/lo"
|
||||
"gopkg.in/ini.v1"
|
||||
|
||||
"github.com/fatedier/frp/pkg/consts"
|
||||
@@ -61,6 +62,11 @@ type STCPVisitorConf struct {
|
||||
|
||||
type XTCPVisitorConf struct {
|
||||
BaseVisitorConf `ini:",extends"`
|
||||
|
||||
Protocol string `ini:"protocol" json:"protocol,omitempty"`
|
||||
KeepTunnelOpen bool `ini:"keep_tunnel_open" json:"keep_tunnel_open,omitempty"`
|
||||
MaxRetriesAnHour int `ini:"max_retries_an_hour" json:"max_retries_an_hour,omitempty"`
|
||||
MinRetryInterval int `ini:"min_retry_interval" json:"min_retry_interval,omitempty"`
|
||||
}
|
||||
|
||||
// DefaultVisitorConf creates a empty VisitorConf object by visitorType.
|
||||
@@ -259,7 +265,12 @@ func (cfg *XTCPVisitorConf) Compare(cmp VisitorConf) bool {
|
||||
}
|
||||
|
||||
// Add custom login equal, if exists
|
||||
|
||||
if cfg.Protocol != cmpConf.Protocol ||
|
||||
cfg.KeepTunnelOpen != cmpConf.KeepTunnelOpen ||
|
||||
cfg.MaxRetriesAnHour != cmpConf.MaxRetriesAnHour ||
|
||||
cfg.MinRetryInterval != cmpConf.MinRetryInterval {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -270,7 +281,15 @@ func (cfg *XTCPVisitorConf) UnmarshalFromIni(prefix string, name string, section
|
||||
}
|
||||
|
||||
// Add custom logic unmarshal, if exists
|
||||
|
||||
if cfg.Protocol == "" {
|
||||
cfg.Protocol = "quic"
|
||||
}
|
||||
if cfg.MaxRetriesAnHour <= 0 {
|
||||
cfg.MaxRetriesAnHour = 8
|
||||
}
|
||||
if cfg.MinRetryInterval <= 0 {
|
||||
cfg.MinRetryInterval = 90
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -280,6 +299,8 @@ func (cfg *XTCPVisitorConf) Check() (err error) {
|
||||
}
|
||||
|
||||
// Add custom logic validate, if exists
|
||||
|
||||
if !lo.Contains([]string{"", "kcp", "quic"}, cfg.Protocol) {
|
||||
return fmt.Errorf("protocol should be 'kcp' or 'quic'")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@@ -87,6 +87,9 @@ func Test_Visitor_UnmarshalFromIni(t *testing.T) {
|
||||
BindAddr: "127.0.0.1",
|
||||
BindPort: 9001,
|
||||
},
|
||||
Protocol: "quic",
|
||||
MaxRetriesAnHour: 8,
|
||||
MinRetryInterval: 90,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@@ -60,25 +60,30 @@ func (m *serverMetrics) run() {
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(12 * time.Hour)
|
||||
log.Debug("start to clear useless proxy statistics data...")
|
||||
m.clearUselessInfo()
|
||||
log.Debug("finish to clear useless proxy statistics data")
|
||||
start := time.Now()
|
||||
count, total := m.clearUselessInfo()
|
||||
log.Debug("clear useless proxy statistics data count %d/%d, cost %v", count, total, time.Since(start))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (m *serverMetrics) clearUselessInfo() {
|
||||
func (m *serverMetrics) clearUselessInfo() (int, int) {
|
||||
count := 0
|
||||
total := 0
|
||||
// To check if there are proxies that closed than 7 days and drop them.
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
total = len(m.info.ProxyStatistics)
|
||||
for name, data := range m.info.ProxyStatistics {
|
||||
if !data.LastCloseTime.IsZero() &&
|
||||
data.LastStartTime.Before(data.LastCloseTime) &&
|
||||
time.Since(data.LastCloseTime) > time.Duration(7*24)*time.Hour {
|
||||
delete(m.info.ProxyStatistics, name)
|
||||
count++
|
||||
log.Trace("clear proxy [%s]'s statistics data, lastCloseTime: [%s]", name, data.LastCloseTime.String())
|
||||
}
|
||||
}
|
||||
return count, total
|
||||
}
|
||||
|
||||
func (m *serverMetrics) NewClient() {
|
||||
|
142
pkg/msg/msg.go
142
pkg/msg/msg.go
@@ -16,54 +16,53 @@ package msg
|
||||
|
||||
import (
|
||||
"net"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
const (
|
||||
TypeLogin = 'o'
|
||||
TypeLoginResp = '1'
|
||||
TypeNewProxy = 'p'
|
||||
TypeNewProxyResp = '2'
|
||||
TypeCloseProxy = 'c'
|
||||
TypeNewWorkConn = 'w'
|
||||
TypeReqWorkConn = 'r'
|
||||
TypeStartWorkConn = 's'
|
||||
TypeNewVisitorConn = 'v'
|
||||
TypeNewVisitorConnResp = '3'
|
||||
TypePing = 'h'
|
||||
TypePong = '4'
|
||||
TypeUDPPacket = 'u'
|
||||
TypeNatHoleVisitor = 'i'
|
||||
TypeNatHoleClient = 'n'
|
||||
TypeNatHoleResp = 'm'
|
||||
TypeNatHoleClientDetectOK = 'd'
|
||||
TypeNatHoleSid = '5'
|
||||
TypeNatHoleBinding = 'b'
|
||||
TypeNatHoleBindingResp = '6'
|
||||
TypeLogin = 'o'
|
||||
TypeLoginResp = '1'
|
||||
TypeNewProxy = 'p'
|
||||
TypeNewProxyResp = '2'
|
||||
TypeCloseProxy = 'c'
|
||||
TypeNewWorkConn = 'w'
|
||||
TypeReqWorkConn = 'r'
|
||||
TypeStartWorkConn = 's'
|
||||
TypeNewVisitorConn = 'v'
|
||||
TypeNewVisitorConnResp = '3'
|
||||
TypePing = 'h'
|
||||
TypePong = '4'
|
||||
TypeUDPPacket = 'u'
|
||||
TypeNatHoleVisitor = 'i'
|
||||
TypeNatHoleClient = 'n'
|
||||
TypeNatHoleResp = 'm'
|
||||
TypeNatHoleSid = '5'
|
||||
TypeNatHoleReport = '6'
|
||||
)
|
||||
|
||||
var msgTypeMap = map[byte]interface{}{
|
||||
TypeLogin: Login{},
|
||||
TypeLoginResp: LoginResp{},
|
||||
TypeNewProxy: NewProxy{},
|
||||
TypeNewProxyResp: NewProxyResp{},
|
||||
TypeCloseProxy: CloseProxy{},
|
||||
TypeNewWorkConn: NewWorkConn{},
|
||||
TypeReqWorkConn: ReqWorkConn{},
|
||||
TypeStartWorkConn: StartWorkConn{},
|
||||
TypeNewVisitorConn: NewVisitorConn{},
|
||||
TypeNewVisitorConnResp: NewVisitorConnResp{},
|
||||
TypePing: Ping{},
|
||||
TypePong: Pong{},
|
||||
TypeUDPPacket: UDPPacket{},
|
||||
TypeNatHoleVisitor: NatHoleVisitor{},
|
||||
TypeNatHoleClient: NatHoleClient{},
|
||||
TypeNatHoleResp: NatHoleResp{},
|
||||
TypeNatHoleClientDetectOK: NatHoleClientDetectOK{},
|
||||
TypeNatHoleSid: NatHoleSid{},
|
||||
TypeNatHoleBinding: NatHoleBinding{},
|
||||
TypeNatHoleBindingResp: NatHoleBindingResp{},
|
||||
TypeLogin: Login{},
|
||||
TypeLoginResp: LoginResp{},
|
||||
TypeNewProxy: NewProxy{},
|
||||
TypeNewProxyResp: NewProxyResp{},
|
||||
TypeCloseProxy: CloseProxy{},
|
||||
TypeNewWorkConn: NewWorkConn{},
|
||||
TypeReqWorkConn: ReqWorkConn{},
|
||||
TypeStartWorkConn: StartWorkConn{},
|
||||
TypeNewVisitorConn: NewVisitorConn{},
|
||||
TypeNewVisitorConnResp: NewVisitorConnResp{},
|
||||
TypePing: Ping{},
|
||||
TypePong: Pong{},
|
||||
TypeUDPPacket: UDPPacket{},
|
||||
TypeNatHoleVisitor: NatHoleVisitor{},
|
||||
TypeNatHoleClient: NatHoleClient{},
|
||||
TypeNatHoleResp: NatHoleResp{},
|
||||
TypeNatHoleSid: NatHoleSid{},
|
||||
TypeNatHoleReport: NatHoleReport{},
|
||||
}
|
||||
|
||||
var TypeNameNatHoleResp = reflect.TypeOf(&NatHoleResp{}).Elem().Name()
|
||||
|
||||
// When frpc start, client send this message to login to server.
|
||||
type Login struct {
|
||||
Version string `json:"version,omitempty"`
|
||||
@@ -175,35 +174,58 @@ type UDPPacket struct {
|
||||
}
|
||||
|
||||
type NatHoleVisitor struct {
|
||||
ProxyName string `json:"proxy_name,omitempty"`
|
||||
SignKey string `json:"sign_key,omitempty"`
|
||||
Timestamp int64 `json:"timestamp,omitempty"`
|
||||
TransactionID string `json:"transaction_id,omitempty"`
|
||||
ProxyName string `json:"proxy_name,omitempty"`
|
||||
PreCheck bool `json:"pre_check,omitempty"`
|
||||
Protocol string `json:"protocol,omitempty"`
|
||||
SignKey string `json:"sign_key,omitempty"`
|
||||
Timestamp int64 `json:"timestamp,omitempty"`
|
||||
MappedAddrs []string `json:"mapped_addrs,omitempty"`
|
||||
AssistedAddrs []string `json:"assisted_addrs,omitempty"`
|
||||
}
|
||||
|
||||
type NatHoleClient struct {
|
||||
ProxyName string `json:"proxy_name,omitempty"`
|
||||
Sid string `json:"sid,omitempty"`
|
||||
TransactionID string `json:"transaction_id,omitempty"`
|
||||
ProxyName string `json:"proxy_name,omitempty"`
|
||||
Sid string `json:"sid,omitempty"`
|
||||
MappedAddrs []string `json:"mapped_addrs,omitempty"`
|
||||
AssistedAddrs []string `json:"assisted_addrs,omitempty"`
|
||||
}
|
||||
|
||||
type PortsRange struct {
|
||||
From int `json:"from,omitempty"`
|
||||
To int `json:"to,omitempty"`
|
||||
}
|
||||
|
||||
type NatHoleDetectBehavior struct {
|
||||
Role string `json:"role,omitempty"` // sender or receiver
|
||||
Mode int `json:"mode,omitempty"` // 0, 1, 2...
|
||||
TTL int `json:"ttl,omitempty"`
|
||||
SendDelayMs int `json:"send_delay_ms,omitempty"`
|
||||
ReadTimeoutMs int `json:"read_timeout,omitempty"`
|
||||
CandidatePorts []PortsRange `json:"candidate_ports,omitempty"`
|
||||
SendRandomPorts int `json:"send_random_ports,omitempty"`
|
||||
ListenRandomPorts int `json:"listen_random_ports,omitempty"`
|
||||
}
|
||||
|
||||
type NatHoleResp struct {
|
||||
Sid string `json:"sid,omitempty"`
|
||||
VisitorAddr string `json:"visitor_addr,omitempty"`
|
||||
ClientAddr string `json:"client_addr,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
TransactionID string `json:"transaction_id,omitempty"`
|
||||
Sid string `json:"sid,omitempty"`
|
||||
Protocol string `json:"protocol,omitempty"`
|
||||
CandidateAddrs []string `json:"candidate_addrs,omitempty"`
|
||||
AssistedAddrs []string `json:"assisted_addrs,omitempty"`
|
||||
DetectBehavior NatHoleDetectBehavior `json:"detect_behavior,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type NatHoleClientDetectOK struct{}
|
||||
|
||||
type NatHoleSid struct {
|
||||
Sid string `json:"sid,omitempty"`
|
||||
TransactionID string `json:"transaction_id,omitempty"`
|
||||
Sid string `json:"sid,omitempty"`
|
||||
Response bool `json:"response,omitempty"`
|
||||
Nonce string `json:"nonce,omitempty"`
|
||||
}
|
||||
|
||||
type NatHoleBinding struct {
|
||||
TransactionID string `json:"transaction_id,omitempty"`
|
||||
}
|
||||
|
||||
type NatHoleBindingResp struct {
|
||||
TransactionID string `json:"transaction_id,omitempty"`
|
||||
Address string `json:"address,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
type NatHoleReport struct {
|
||||
Sid string `json:"sid,omitempty"`
|
||||
Success bool `json:"success,omitempty"`
|
||||
}
|
||||
|
328
pkg/nathole/analysis.go
Normal file
328
pkg/nathole/analysis.go
Normal file
@@ -0,0 +1,328 @@
|
||||
// Copyright 2023 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package nathole
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
var (
|
||||
// mode 0, both EasyNAT, PublicNetwork is always receiver
|
||||
// sender | receiver, ttl 7
|
||||
// receiver, ttl 7 | sender
|
||||
// sender | receiver, ttl 4
|
||||
// receiver, ttl 4 | sender
|
||||
// sender | receiver
|
||||
// receiver | sender
|
||||
// sender, sendDelayMs 5000 | receiver
|
||||
// sender, sendDelayMs 10000 | receiver
|
||||
// receiver | sender, sendDelayMs 5000
|
||||
// receiver | sender, sendDelayMs 10000
|
||||
mode0Behaviors = []lo.Tuple2[RecommandBehavior, RecommandBehavior]{
|
||||
lo.T2(RecommandBehavior{Role: DetectRoleSender}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 7}),
|
||||
lo.T2(RecommandBehavior{Role: DetectRoleReceiver, TTL: 7}, RecommandBehavior{Role: DetectRoleSender}),
|
||||
lo.T2(RecommandBehavior{Role: DetectRoleSender}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 4}),
|
||||
lo.T2(RecommandBehavior{Role: DetectRoleReceiver, TTL: 4}, RecommandBehavior{Role: DetectRoleSender}),
|
||||
lo.T2(RecommandBehavior{Role: DetectRoleSender}, RecommandBehavior{Role: DetectRoleReceiver}),
|
||||
lo.T2(RecommandBehavior{Role: DetectRoleReceiver}, RecommandBehavior{Role: DetectRoleSender}),
|
||||
lo.T2(RecommandBehavior{Role: DetectRoleSender, SendDelayMs: 5000}, RecommandBehavior{Role: DetectRoleReceiver}),
|
||||
lo.T2(RecommandBehavior{Role: DetectRoleSender, SendDelayMs: 10000}, RecommandBehavior{Role: DetectRoleReceiver}),
|
||||
lo.T2(RecommandBehavior{Role: DetectRoleReceiver}, RecommandBehavior{Role: DetectRoleSender, SendDelayMs: 5000}),
|
||||
lo.T2(RecommandBehavior{Role: DetectRoleReceiver}, RecommandBehavior{Role: DetectRoleSender, SendDelayMs: 10000}),
|
||||
}
|
||||
|
||||
// mode 1, HardNAT is sender, EasyNAT is receiver, port changes is regular
|
||||
// sender | receiver, ttl 7, portsRangeNumber max 10
|
||||
// sender, sendDelayMs 2000 | receiver, ttl 7, portsRangeNumber max 10
|
||||
// sender | receiver, ttl 4, portsRangeNumber max 10
|
||||
// sender, sendDelayMs 2000 | receiver, ttl 4, portsRangeNumber max 10
|
||||
// sender | receiver, portsRangeNumber max 10
|
||||
// sender, sendDelayMs 2000 | receiver, portsRangeNumber max 10
|
||||
mode1Behaviors = []lo.Tuple2[RecommandBehavior, RecommandBehavior]{
|
||||
lo.T2(RecommandBehavior{Role: DetectRoleSender}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 7, PortsRangeNumber: 10}),
|
||||
lo.T2(RecommandBehavior{Role: DetectRoleSender, SendDelayMs: 2000}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 7, PortsRangeNumber: 10}),
|
||||
lo.T2(RecommandBehavior{Role: DetectRoleSender}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 4, PortsRangeNumber: 10}),
|
||||
lo.T2(RecommandBehavior{Role: DetectRoleSender, SendDelayMs: 2000}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 4, PortsRangeNumber: 10}),
|
||||
lo.T2(RecommandBehavior{Role: DetectRoleSender}, RecommandBehavior{Role: DetectRoleReceiver, PortsRangeNumber: 10}),
|
||||
lo.T2(RecommandBehavior{Role: DetectRoleSender, SendDelayMs: 2000}, RecommandBehavior{Role: DetectRoleReceiver, PortsRangeNumber: 10}),
|
||||
}
|
||||
|
||||
// mode 2, HardNAT is receiver, EasyNAT is sender
|
||||
// sender, portsRandomNumber 1000, sendDelayMs 2000 | receiver, listen 256 ports, ttl 7
|
||||
// sender, portsRandomNumber 1000, sendDelayMs 2000 | receiver, listen 256 ports, ttl 4
|
||||
// sender, portsRandomNumber 1000, sendDelayMs 2000 | receiver, listen 256 ports
|
||||
mode2Behaviors = []lo.Tuple2[RecommandBehavior, RecommandBehavior]{
|
||||
lo.T2(
|
||||
RecommandBehavior{Role: DetectRoleSender, PortsRandomNumber: 1000, SendDelayMs: 2000},
|
||||
RecommandBehavior{Role: DetectRoleReceiver, ListenRandomPorts: 256, TTL: 7},
|
||||
),
|
||||
lo.T2(
|
||||
RecommandBehavior{Role: DetectRoleSender, PortsRandomNumber: 1000, SendDelayMs: 2000},
|
||||
RecommandBehavior{Role: DetectRoleReceiver, ListenRandomPorts: 256, TTL: 4},
|
||||
),
|
||||
lo.T2(
|
||||
RecommandBehavior{Role: DetectRoleSender, PortsRandomNumber: 1000, SendDelayMs: 2000},
|
||||
RecommandBehavior{Role: DetectRoleReceiver, ListenRandomPorts: 256},
|
||||
),
|
||||
}
|
||||
|
||||
// mode 3, For HardNAT & HardNAT, both changes in the ports are regular
|
||||
// sender, portsRangeNumber 10 | receiver, ttl 7, portsRangeNumber 10
|
||||
// sender, portsRangeNumber 10 | receiver, ttl 4, portsRangeNumber 10
|
||||
// sender, portsRangeNumber 10 | receiver, portsRangeNumber 10
|
||||
// receiver, ttl 7, portsRangeNumber 10 | sender, portsRangeNumber 10
|
||||
// receiver, ttl 4, portsRangeNumber 10 | sender, portsRangeNumber 10
|
||||
// receiver, portsRangeNumber 10 | sender, portsRangeNumber 10
|
||||
mode3Behaviors = []lo.Tuple2[RecommandBehavior, RecommandBehavior]{
|
||||
lo.T2(RecommandBehavior{Role: DetectRoleSender, PortsRangeNumber: 10}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 7, PortsRangeNumber: 10}),
|
||||
lo.T2(RecommandBehavior{Role: DetectRoleSender, PortsRangeNumber: 10}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 4, PortsRangeNumber: 10}),
|
||||
lo.T2(RecommandBehavior{Role: DetectRoleSender, PortsRangeNumber: 10}, RecommandBehavior{Role: DetectRoleReceiver, PortsRangeNumber: 10}),
|
||||
lo.T2(RecommandBehavior{Role: DetectRoleReceiver, TTL: 7, PortsRangeNumber: 10}, RecommandBehavior{Role: DetectRoleSender, PortsRangeNumber: 10}),
|
||||
lo.T2(RecommandBehavior{Role: DetectRoleReceiver, TTL: 4, PortsRangeNumber: 10}, RecommandBehavior{Role: DetectRoleSender, PortsRangeNumber: 10}),
|
||||
lo.T2(RecommandBehavior{Role: DetectRoleReceiver, PortsRangeNumber: 10}, RecommandBehavior{Role: DetectRoleSender, PortsRangeNumber: 10}),
|
||||
}
|
||||
|
||||
// mode 4, Regular ports changes are usually the sender.
|
||||
// sender, portsRandomNumber 1000, sendDelayMs: 2000 | receiver, listen 256 ports, ttl 7, portsRangeNumber 10
|
||||
// sender, portsRandomNumber 1000, sendDelayMs: 2000 | receiver, listen 256 ports, ttl 4, portsRangeNumber 10
|
||||
// sender, portsRandomNumber 1000, SendDelayMs: 2000 | receiver, listen 256 ports, portsRangeNumber 10
|
||||
mode4Behaviors = []lo.Tuple2[RecommandBehavior, RecommandBehavior]{
|
||||
lo.T2(
|
||||
RecommandBehavior{Role: DetectRoleSender, PortsRandomNumber: 1000, SendDelayMs: 2000},
|
||||
RecommandBehavior{Role: DetectRoleReceiver, ListenRandomPorts: 256, TTL: 7, PortsRangeNumber: 10},
|
||||
),
|
||||
lo.T2(
|
||||
RecommandBehavior{Role: DetectRoleSender, PortsRandomNumber: 1000, SendDelayMs: 2000},
|
||||
RecommandBehavior{Role: DetectRoleReceiver, ListenRandomPorts: 256, TTL: 4, PortsRangeNumber: 10},
|
||||
),
|
||||
lo.T2(
|
||||
RecommandBehavior{Role: DetectRoleSender, PortsRandomNumber: 1000, SendDelayMs: 2000},
|
||||
RecommandBehavior{Role: DetectRoleReceiver, ListenRandomPorts: 256, PortsRangeNumber: 10},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
func getBehaviorByMode(mode int) []lo.Tuple2[RecommandBehavior, RecommandBehavior] {
|
||||
switch mode {
|
||||
case 0:
|
||||
return mode0Behaviors
|
||||
case 1:
|
||||
return mode1Behaviors
|
||||
case 2:
|
||||
return mode2Behaviors
|
||||
case 3:
|
||||
return mode3Behaviors
|
||||
case 4:
|
||||
return mode4Behaviors
|
||||
}
|
||||
// default
|
||||
return mode0Behaviors
|
||||
}
|
||||
|
||||
func getBehaviorByModeAndIndex(mode int, index int) (RecommandBehavior, RecommandBehavior) {
|
||||
behaviors := getBehaviorByMode(mode)
|
||||
if index >= len(behaviors) {
|
||||
return RecommandBehavior{}, RecommandBehavior{}
|
||||
}
|
||||
return behaviors[index].A, behaviors[index].B
|
||||
}
|
||||
|
||||
func getBehaviorScoresByMode(mode int, defaultScore int) []*BehaviorScore {
|
||||
return getBehaviorScoresByMode2(mode, defaultScore, defaultScore)
|
||||
}
|
||||
|
||||
func getBehaviorScoresByMode2(mode int, senderScore, receiverScore int) []*BehaviorScore {
|
||||
behaviors := getBehaviorByMode(mode)
|
||||
scores := make([]*BehaviorScore, 0, len(behaviors))
|
||||
for i := 0; i < len(behaviors); i++ {
|
||||
score := receiverScore
|
||||
if behaviors[i].A.Role == DetectRoleSender {
|
||||
score = senderScore
|
||||
}
|
||||
scores = append(scores, &BehaviorScore{Mode: mode, Index: i, Score: score})
|
||||
}
|
||||
return scores
|
||||
}
|
||||
|
||||
type RecommandBehavior struct {
|
||||
Role string
|
||||
TTL int
|
||||
SendDelayMs int
|
||||
PortsRangeNumber int
|
||||
PortsRandomNumber int
|
||||
ListenRandomPorts int
|
||||
}
|
||||
|
||||
type MakeHoleRecords struct {
|
||||
mu sync.Mutex
|
||||
scores []*BehaviorScore
|
||||
LastUpdateTime time.Time
|
||||
}
|
||||
|
||||
func NewMakeHoleRecords(c, v *NatFeature) *MakeHoleRecords {
|
||||
scores := []*BehaviorScore{}
|
||||
easyCount, hardCount, portsChangedRegularCount := ClassifyFeatureCount([]*NatFeature{c, v})
|
||||
appendMode0 := func() {
|
||||
switch {
|
||||
case c.PublicNetwork:
|
||||
scores = append(scores, getBehaviorScoresByMode2(DetectMode0, 0, 1)...)
|
||||
case v.PublicNetwork:
|
||||
scores = append(scores, getBehaviorScoresByMode2(DetectMode0, 1, 0)...)
|
||||
default:
|
||||
scores = append(scores, getBehaviorScoresByMode(DetectMode0, 0)...)
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case easyCount == 2:
|
||||
appendMode0()
|
||||
case hardCount == 1 && portsChangedRegularCount == 1:
|
||||
scores = append(scores, getBehaviorScoresByMode(DetectMode1, 0)...)
|
||||
scores = append(scores, getBehaviorScoresByMode(DetectMode2, 0)...)
|
||||
appendMode0()
|
||||
case hardCount == 1 && portsChangedRegularCount == 0:
|
||||
scores = append(scores, getBehaviorScoresByMode(DetectMode2, 0)...)
|
||||
scores = append(scores, getBehaviorScoresByMode(DetectMode1, 0)...)
|
||||
appendMode0()
|
||||
case hardCount == 2 && portsChangedRegularCount == 2:
|
||||
scores = append(scores, getBehaviorScoresByMode(DetectMode3, 0)...)
|
||||
scores = append(scores, getBehaviorScoresByMode(DetectMode4, 0)...)
|
||||
case hardCount == 2 && portsChangedRegularCount == 1:
|
||||
scores = append(scores, getBehaviorScoresByMode(DetectMode4, 0)...)
|
||||
default:
|
||||
// hard to make hole, just trying it out.
|
||||
scores = append(scores, getBehaviorScoresByMode(DetectMode0, 1)...)
|
||||
scores = append(scores, getBehaviorScoresByMode(DetectMode1, 1)...)
|
||||
scores = append(scores, getBehaviorScoresByMode(DetectMode3, 1)...)
|
||||
}
|
||||
return &MakeHoleRecords{scores: scores, LastUpdateTime: time.Now()}
|
||||
}
|
||||
|
||||
func (mhr *MakeHoleRecords) ReportSuccess(mode int, index int) {
|
||||
mhr.mu.Lock()
|
||||
defer mhr.mu.Unlock()
|
||||
mhr.LastUpdateTime = time.Now()
|
||||
for i := range mhr.scores {
|
||||
score := mhr.scores[i]
|
||||
if score.Mode != mode || score.Index != index {
|
||||
continue
|
||||
}
|
||||
|
||||
score.Score += 2
|
||||
score.Score = lo.Min([]int{score.Score, 10})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (mhr *MakeHoleRecords) Recommand() (mode, index int) {
|
||||
mhr.mu.Lock()
|
||||
defer mhr.mu.Unlock()
|
||||
|
||||
maxScore := lo.MaxBy(mhr.scores, func(item, max *BehaviorScore) bool {
|
||||
return item.Score > max.Score
|
||||
})
|
||||
if maxScore == nil {
|
||||
return 0, 0
|
||||
}
|
||||
maxScore.Score--
|
||||
mhr.LastUpdateTime = time.Now()
|
||||
return maxScore.Mode, maxScore.Index
|
||||
}
|
||||
|
||||
type BehaviorScore struct {
|
||||
Mode int
|
||||
Index int
|
||||
// between -10 and 10
|
||||
Score int
|
||||
}
|
||||
|
||||
type Analyzer struct {
|
||||
// key is client ip + visitor ip
|
||||
records map[string]*MakeHoleRecords
|
||||
dataReserveDuration time.Duration
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewAnalyzer(dataReserveDuration time.Duration) *Analyzer {
|
||||
return &Analyzer{
|
||||
records: make(map[string]*MakeHoleRecords),
|
||||
dataReserveDuration: dataReserveDuration,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Analyzer) GetRecommandBehaviors(key string, c, v *NatFeature) (mode, index int, _ RecommandBehavior, _ RecommandBehavior) {
|
||||
a.mu.Lock()
|
||||
records, ok := a.records[key]
|
||||
if !ok {
|
||||
records = NewMakeHoleRecords(c, v)
|
||||
a.records[key] = records
|
||||
}
|
||||
a.mu.Unlock()
|
||||
|
||||
mode, index = records.Recommand()
|
||||
cBehavior, vBehavior := getBehaviorByModeAndIndex(mode, index)
|
||||
|
||||
switch mode {
|
||||
case DetectMode1:
|
||||
// HardNAT is always the sender
|
||||
if c.NatType == EasyNAT {
|
||||
cBehavior, vBehavior = vBehavior, cBehavior
|
||||
}
|
||||
case DetectMode2:
|
||||
// HardNAT is always the receiver
|
||||
if c.NatType == HardNAT {
|
||||
cBehavior, vBehavior = vBehavior, cBehavior
|
||||
}
|
||||
case DetectMode4:
|
||||
// Regular ports changes is always the sender
|
||||
if !c.RegularPortsChange {
|
||||
cBehavior, vBehavior = vBehavior, cBehavior
|
||||
}
|
||||
}
|
||||
return mode, index, cBehavior, vBehavior
|
||||
}
|
||||
|
||||
func (a *Analyzer) ReportSuccess(key string, mode, index int) {
|
||||
a.mu.Lock()
|
||||
records, ok := a.records[key]
|
||||
a.mu.Unlock()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
records.ReportSuccess(mode, index)
|
||||
}
|
||||
|
||||
func (a *Analyzer) Clean() (int, int) {
|
||||
now := time.Now()
|
||||
total := 0
|
||||
count := 0
|
||||
|
||||
// cleanup 10w records may take 5ms
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
total = len(a.records)
|
||||
// clean up records that have not been used for a period of time.
|
||||
for key, records := range a.records {
|
||||
if now.Sub(records.LastUpdateTime) > a.dataReserveDuration {
|
||||
delete(a.records, key)
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count, total
|
||||
}
|
@@ -17,6 +17,9 @@ package nathole
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -29,46 +32,96 @@ const (
|
||||
BehaviorBothChanged = "BehaviorBothChanged"
|
||||
)
|
||||
|
||||
// ClassifyNATType classify NAT type by given addresses.
|
||||
func ClassifyNATType(addresses []string) (string, string, error) {
|
||||
type NatFeature struct {
|
||||
NatType string
|
||||
Behavior string
|
||||
PortsDifference int
|
||||
RegularPortsChange bool
|
||||
PublicNetwork bool
|
||||
}
|
||||
|
||||
func ClassifyNATFeature(addresses []string, localIPs []string) (*NatFeature, error) {
|
||||
if len(addresses) <= 1 {
|
||||
return "", "", fmt.Errorf("not enough addresses")
|
||||
return nil, fmt.Errorf("not enough addresses")
|
||||
}
|
||||
natFeature := &NatFeature{}
|
||||
ipChanged := false
|
||||
portChanged := false
|
||||
|
||||
var baseIP, basePort string
|
||||
var portMax, portMin int
|
||||
for _, addr := range addresses {
|
||||
ip, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
return nil, err
|
||||
}
|
||||
portNum, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lo.Contains(localIPs, ip) {
|
||||
natFeature.PublicNetwork = true
|
||||
}
|
||||
|
||||
if baseIP == "" {
|
||||
baseIP = ip
|
||||
basePort = port
|
||||
portMax = portNum
|
||||
portMin = portNum
|
||||
continue
|
||||
}
|
||||
|
||||
if portNum > portMax {
|
||||
portMax = portNum
|
||||
}
|
||||
if portNum < portMin {
|
||||
portMin = portNum
|
||||
}
|
||||
if baseIP != ip {
|
||||
ipChanged = true
|
||||
}
|
||||
if basePort != port {
|
||||
portChanged = true
|
||||
}
|
||||
}
|
||||
|
||||
if ipChanged && portChanged {
|
||||
break
|
||||
}
|
||||
natFeature.PortsDifference = portMax - portMin
|
||||
if natFeature.PortsDifference <= 10 && natFeature.PortsDifference >= 1 {
|
||||
natFeature.RegularPortsChange = true
|
||||
}
|
||||
|
||||
switch {
|
||||
case ipChanged && portChanged:
|
||||
return HardNAT, BehaviorBothChanged, nil
|
||||
natFeature.NatType = HardNAT
|
||||
natFeature.Behavior = BehaviorBothChanged
|
||||
case ipChanged:
|
||||
return HardNAT, BehaviorIPChanged, nil
|
||||
natFeature.NatType = HardNAT
|
||||
natFeature.Behavior = BehaviorIPChanged
|
||||
case portChanged:
|
||||
return HardNAT, BehaviorPortChanged, nil
|
||||
natFeature.NatType = HardNAT
|
||||
natFeature.Behavior = BehaviorPortChanged
|
||||
default:
|
||||
return EasyNAT, BehaviorNoChange, nil
|
||||
natFeature.NatType = EasyNAT
|
||||
natFeature.Behavior = BehaviorNoChange
|
||||
}
|
||||
return natFeature, nil
|
||||
}
|
||||
|
||||
func ClassifyFeatureCount(features []*NatFeature) (int, int, int) {
|
||||
easyCount := 0
|
||||
hardCount := 0
|
||||
// for HardNAT
|
||||
portsChangedRegularCount := 0
|
||||
for _, feature := range features {
|
||||
if feature.NatType == EasyNAT {
|
||||
easyCount++
|
||||
continue
|
||||
}
|
||||
|
||||
hardCount++
|
||||
if feature.RegularPortsChange {
|
||||
portsChangedRegularCount++
|
||||
}
|
||||
}
|
||||
return easyCount, hardCount, portsChangedRegularCount
|
||||
}
|
||||
|
382
pkg/nathole/controller.go
Normal file
382
pkg/nathole/controller.go
Normal file
@@ -0,0 +1,382 @@
|
||||
// Copyright 2023 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package nathole
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/golib/errors"
|
||||
"github.com/samber/lo"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/transport"
|
||||
"github.com/fatedier/frp/pkg/util/log"
|
||||
"github.com/fatedier/frp/pkg/util/util"
|
||||
)
|
||||
|
||||
// NatHoleTimeout seconds.
|
||||
var NatHoleTimeout int64 = 10
|
||||
|
||||
func NewTransactionID() string {
|
||||
id, _ := util.RandID()
|
||||
return fmt.Sprintf("%d%s", time.Now().Unix(), id)
|
||||
}
|
||||
|
||||
type ClientCfg struct {
|
||||
name string
|
||||
sk string
|
||||
sidCh chan string
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
sid string
|
||||
analysisKey string
|
||||
recommandMode int
|
||||
recommandIndex int
|
||||
|
||||
visitorMsg *msg.NatHoleVisitor
|
||||
visitorTransporter transport.MessageTransporter
|
||||
vResp *msg.NatHoleResp
|
||||
vNatFeature *NatFeature
|
||||
vBehavior RecommandBehavior
|
||||
|
||||
clientMsg *msg.NatHoleClient
|
||||
clientTransporter transport.MessageTransporter
|
||||
cResp *msg.NatHoleResp
|
||||
cNatFeature *NatFeature
|
||||
cBehavior RecommandBehavior
|
||||
|
||||
notifyCh chan struct{}
|
||||
}
|
||||
|
||||
func (s *Session) genAnalysisKey() {
|
||||
hash := md5.New()
|
||||
vIPs := lo.Uniq(parseIPs(s.visitorMsg.MappedAddrs))
|
||||
if len(vIPs) > 0 {
|
||||
hash.Write([]byte(vIPs[0]))
|
||||
}
|
||||
hash.Write([]byte(s.vNatFeature.NatType))
|
||||
hash.Write([]byte(s.vNatFeature.Behavior))
|
||||
hash.Write([]byte(strconv.FormatBool(s.vNatFeature.RegularPortsChange)))
|
||||
|
||||
cIPs := lo.Uniq(parseIPs(s.clientMsg.MappedAddrs))
|
||||
if len(cIPs) > 0 {
|
||||
hash.Write([]byte(cIPs[0]))
|
||||
}
|
||||
hash.Write([]byte(s.cNatFeature.NatType))
|
||||
hash.Write([]byte(s.cNatFeature.Behavior))
|
||||
hash.Write([]byte(strconv.FormatBool(s.cNatFeature.RegularPortsChange)))
|
||||
s.analysisKey = hex.EncodeToString(hash.Sum(nil))
|
||||
}
|
||||
|
||||
type Controller struct {
|
||||
clientCfgs map[string]*ClientCfg
|
||||
sessions map[string]*Session
|
||||
analyzer *Analyzer
|
||||
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewController(analysisDataReserveDuration time.Duration) (*Controller, error) {
|
||||
return &Controller{
|
||||
clientCfgs: make(map[string]*ClientCfg),
|
||||
sessions: make(map[string]*Session),
|
||||
analyzer: NewAnalyzer(analysisDataReserveDuration),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Controller) CleanWorker(ctx context.Context) {
|
||||
ticker := time.NewTicker(time.Hour)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
start := time.Now()
|
||||
count, total := c.analyzer.Clean()
|
||||
log.Trace("clean %d/%d nathole analysis data, cost %v", count, total, time.Since(start))
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) ListenClient(name string, sk string) chan string {
|
||||
cfg := &ClientCfg{
|
||||
name: name,
|
||||
sk: sk,
|
||||
sidCh: make(chan string),
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.clientCfgs[name] = cfg
|
||||
return cfg.sidCh
|
||||
}
|
||||
|
||||
func (c *Controller) CloseClient(name string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
delete(c.clientCfgs, name)
|
||||
}
|
||||
|
||||
func (c *Controller) GenSid() string {
|
||||
t := time.Now().Unix()
|
||||
id, _ := util.RandID()
|
||||
return fmt.Sprintf("%d%s", t, id)
|
||||
}
|
||||
|
||||
func (c *Controller) HandleVisitor(m *msg.NatHoleVisitor, transporter transport.MessageTransporter) {
|
||||
if m.PreCheck {
|
||||
_, ok := c.clientCfgs[m.ProxyName]
|
||||
if !ok {
|
||||
_ = transporter.Send(c.GenNatHoleResponse(m.TransactionID, nil, fmt.Sprintf("xtcp server for [%s] doesn't exist", m.ProxyName)))
|
||||
} else {
|
||||
_ = transporter.Send(c.GenNatHoleResponse(m.TransactionID, nil, ""))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
sid := c.GenSid()
|
||||
session := &Session{
|
||||
sid: sid,
|
||||
visitorMsg: m,
|
||||
visitorTransporter: transporter,
|
||||
notifyCh: make(chan struct{}, 1),
|
||||
}
|
||||
var (
|
||||
clientCfg *ClientCfg
|
||||
ok bool
|
||||
)
|
||||
err := func() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
clientCfg, ok = c.clientCfgs[m.ProxyName]
|
||||
if !ok {
|
||||
return fmt.Errorf("xtcp server for [%s] doesn't exist", m.ProxyName)
|
||||
}
|
||||
if m.SignKey != util.GetAuthKey(clientCfg.sk, m.Timestamp) {
|
||||
return fmt.Errorf("xtcp connection of [%s] auth failed", m.ProxyName)
|
||||
}
|
||||
c.sessions[sid] = session
|
||||
return nil
|
||||
}()
|
||||
if err != nil {
|
||||
log.Warn("handle visitorMsg error: %v", err)
|
||||
_ = transporter.Send(c.GenNatHoleResponse(m.TransactionID, nil, err.Error()))
|
||||
return
|
||||
}
|
||||
log.Trace("handle visitor message, sid [%s]", sid)
|
||||
|
||||
defer func() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
delete(c.sessions, sid)
|
||||
}()
|
||||
|
||||
if err := errors.PanicToError(func() {
|
||||
clientCfg.sidCh <- sid
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// wait for NatHoleClient message
|
||||
select {
|
||||
case <-session.notifyCh:
|
||||
case <-time.After(time.Duration(NatHoleTimeout) * time.Second):
|
||||
log.Debug("wait for NatHoleClient message timeout, sid [%s]", sid)
|
||||
return
|
||||
}
|
||||
|
||||
// Make hole-punching decisions based on the NAT information of the client and visitor.
|
||||
vResp, cResp, err := c.analysis(session)
|
||||
if err != nil {
|
||||
log.Debug("sid [%s] analysis error: %v", err)
|
||||
vResp = c.GenNatHoleResponse(session.visitorMsg.TransactionID, nil, err.Error())
|
||||
cResp = c.GenNatHoleResponse(session.clientMsg.TransactionID, nil, err.Error())
|
||||
}
|
||||
session.cResp = cResp
|
||||
session.vResp = vResp
|
||||
|
||||
// send response to visitor and client
|
||||
var g errgroup.Group
|
||||
g.Go(func() error {
|
||||
// if it's sender, wait for a while to make sure the client has send the detect messages
|
||||
if vResp.DetectBehavior.Role == "sender" {
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
_ = session.visitorTransporter.Send(vResp)
|
||||
return nil
|
||||
})
|
||||
g.Go(func() error {
|
||||
// if it's sender, wait for a while to make sure the client has send the detect messages
|
||||
if cResp.DetectBehavior.Role == "sender" {
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
_ = session.clientTransporter.Send(cResp)
|
||||
return nil
|
||||
})
|
||||
_ = g.Wait()
|
||||
|
||||
time.Sleep(time.Duration(cResp.DetectBehavior.ReadTimeoutMs+30000) * time.Millisecond)
|
||||
}
|
||||
|
||||
func (c *Controller) HandleClient(m *msg.NatHoleClient, transporter transport.MessageTransporter) {
|
||||
c.mu.RLock()
|
||||
session, ok := c.sessions[m.Sid]
|
||||
c.mu.RUnlock()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
log.Trace("handle client message, sid [%s]", session.sid)
|
||||
session.clientMsg = m
|
||||
session.clientTransporter = transporter
|
||||
select {
|
||||
case session.notifyCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) HandleReport(m *msg.NatHoleReport) {
|
||||
c.mu.RLock()
|
||||
session, ok := c.sessions[m.Sid]
|
||||
c.mu.RUnlock()
|
||||
if !ok {
|
||||
log.Trace("sid [%s] report make hole success: %v, but session not found", m.Sid, m.Success)
|
||||
return
|
||||
}
|
||||
if m.Success {
|
||||
c.analyzer.ReportSuccess(session.analysisKey, session.recommandMode, session.recommandIndex)
|
||||
}
|
||||
log.Info("sid [%s] report make hole success: %v, mode %v, index %v",
|
||||
m.Sid, m.Success, session.recommandMode, session.recommandIndex)
|
||||
}
|
||||
|
||||
func (c *Controller) GenNatHoleResponse(transactionID string, session *Session, errInfo string) *msg.NatHoleResp {
|
||||
var sid string
|
||||
if session != nil {
|
||||
sid = session.sid
|
||||
}
|
||||
return &msg.NatHoleResp{
|
||||
TransactionID: transactionID,
|
||||
Sid: sid,
|
||||
Error: errInfo,
|
||||
}
|
||||
}
|
||||
|
||||
// analysis analyzes the NAT type and behavior of the visitor and client, then makes hole-punching decisions.
|
||||
// return the response to the visitor and client.
|
||||
func (c *Controller) analysis(session *Session) (*msg.NatHoleResp, *msg.NatHoleResp, error) {
|
||||
cm := session.clientMsg
|
||||
vm := session.visitorMsg
|
||||
|
||||
cNatFeature, err := ClassifyNATFeature(cm.MappedAddrs, parseIPs(cm.AssistedAddrs))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("classify client nat feature error: %v", err)
|
||||
}
|
||||
|
||||
vNatFeature, err := ClassifyNATFeature(vm.MappedAddrs, parseIPs(vm.AssistedAddrs))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("classify visitor nat feature error: %v", err)
|
||||
}
|
||||
session.cNatFeature = cNatFeature
|
||||
session.vNatFeature = vNatFeature
|
||||
session.genAnalysisKey()
|
||||
|
||||
mode, index, cBehavior, vBehavior := c.analyzer.GetRecommandBehaviors(session.analysisKey, cNatFeature, vNatFeature)
|
||||
session.recommandMode = mode
|
||||
session.recommandIndex = index
|
||||
session.cBehavior = cBehavior
|
||||
session.vBehavior = vBehavior
|
||||
|
||||
timeoutMs := lo.Max([]int{cBehavior.SendDelayMs, vBehavior.SendDelayMs}) + 5000
|
||||
if cBehavior.ListenRandomPorts > 0 || vBehavior.ListenRandomPorts > 0 {
|
||||
timeoutMs += 30000
|
||||
}
|
||||
|
||||
protocol := vm.Protocol
|
||||
vResp := &msg.NatHoleResp{
|
||||
TransactionID: vm.TransactionID,
|
||||
Sid: session.sid,
|
||||
Protocol: protocol,
|
||||
CandidateAddrs: lo.Uniq(cm.MappedAddrs),
|
||||
AssistedAddrs: lo.Uniq(cm.AssistedAddrs),
|
||||
DetectBehavior: msg.NatHoleDetectBehavior{
|
||||
Mode: mode,
|
||||
Role: vBehavior.Role,
|
||||
TTL: vBehavior.TTL,
|
||||
SendDelayMs: vBehavior.SendDelayMs,
|
||||
ReadTimeoutMs: timeoutMs - vBehavior.SendDelayMs,
|
||||
SendRandomPorts: vBehavior.PortsRandomNumber,
|
||||
ListenRandomPorts: vBehavior.ListenRandomPorts,
|
||||
CandidatePorts: getRangePorts(cm.MappedAddrs, cNatFeature.PortsDifference, vBehavior.PortsRangeNumber),
|
||||
},
|
||||
}
|
||||
cResp := &msg.NatHoleResp{
|
||||
TransactionID: cm.TransactionID,
|
||||
Sid: session.sid,
|
||||
Protocol: protocol,
|
||||
CandidateAddrs: lo.Uniq(vm.MappedAddrs),
|
||||
AssistedAddrs: lo.Uniq(vm.AssistedAddrs),
|
||||
DetectBehavior: msg.NatHoleDetectBehavior{
|
||||
Mode: mode,
|
||||
Role: cBehavior.Role,
|
||||
TTL: cBehavior.TTL,
|
||||
SendDelayMs: cBehavior.SendDelayMs,
|
||||
ReadTimeoutMs: timeoutMs - cBehavior.SendDelayMs,
|
||||
SendRandomPorts: cBehavior.PortsRandomNumber,
|
||||
ListenRandomPorts: cBehavior.ListenRandomPorts,
|
||||
CandidatePorts: getRangePorts(vm.MappedAddrs, vNatFeature.PortsDifference, cBehavior.PortsRangeNumber),
|
||||
},
|
||||
}
|
||||
|
||||
log.Debug("sid [%s] visitor nat: %+v, candidateAddrs: %v; client nat: %+v, candidateAddrs: %v, protocol: %s",
|
||||
session.sid, *vNatFeature, vm.MappedAddrs, *cNatFeature, cm.MappedAddrs, protocol)
|
||||
log.Debug("sid [%s] visitor detect behavior: %+v", session.sid, vResp.DetectBehavior)
|
||||
log.Debug("sid [%s] client detect behavior: %+v", session.sid, cResp.DetectBehavior)
|
||||
return vResp, cResp, nil
|
||||
}
|
||||
|
||||
func getRangePorts(addrs []string, difference, maxNumber int) []msg.PortsRange {
|
||||
if maxNumber <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
addr, err := lo.Last(addrs)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var ports []msg.PortsRange
|
||||
_, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
ports = append(ports, msg.PortsRange{
|
||||
From: lo.Max([]int{port - difference - 5, port - maxNumber, 1}),
|
||||
To: lo.Min([]int{port + difference + 5, port + maxNumber, 65535}),
|
||||
})
|
||||
return ports
|
||||
}
|
@@ -20,8 +20,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/pion/stun"
|
||||
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
)
|
||||
|
||||
var responseTimeout = 3 * time.Second
|
||||
@@ -31,35 +29,27 @@ type Message struct {
|
||||
Addr string
|
||||
}
|
||||
|
||||
func Discover(serverAddress string, stunServers []string, key []byte) ([]string, error) {
|
||||
// If the localAddr is empty, it will listen on a random port.
|
||||
func Discover(stunServers []string, localAddr string) ([]string, net.Addr, error) {
|
||||
// create a discoverConn and get response from messageChan
|
||||
discoverConn, err := listen()
|
||||
discoverConn, err := listen(localAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
defer discoverConn.Close()
|
||||
|
||||
go discoverConn.readLoop()
|
||||
|
||||
addresses := make([]string, 0, len(stunServers)+1)
|
||||
if serverAddress != "" {
|
||||
// get external address from frp server
|
||||
externalAddr, err := discoverConn.discoverFromServer(serverAddress, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
addresses = append(addresses, externalAddr)
|
||||
}
|
||||
|
||||
addresses := make([]string, 0, len(stunServers))
|
||||
for _, addr := range stunServers {
|
||||
// get external address from stun server
|
||||
externalAddrs, err := discoverConn.discoverFromStunServer(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
addresses = append(addresses, externalAddrs...)
|
||||
}
|
||||
return addresses, nil
|
||||
return addresses, discoverConn.localAddr, nil
|
||||
}
|
||||
|
||||
type stunResponse struct {
|
||||
@@ -74,8 +64,16 @@ type discoverConn struct {
|
||||
messageChan chan *Message
|
||||
}
|
||||
|
||||
func listen() (*discoverConn, error) {
|
||||
conn, err := net.ListenUDP("udp4", nil)
|
||||
func listen(localAddr string) (*discoverConn, error) {
|
||||
var local *net.UDPAddr
|
||||
if localAddr != "" {
|
||||
addr, err := net.ResolveUDPAddr("udp4", localAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
local = addr
|
||||
}
|
||||
conn, err := net.ListenUDP("udp4", local)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -159,43 +157,6 @@ func (c *discoverConn) doSTUNRequest(addr string) (*stunResponse, error) {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *discoverConn) discoverFromServer(serverAddress string, key []byte) (string, error) {
|
||||
addr, err := net.ResolveUDPAddr("udp4", serverAddress)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
m := &msg.NatHoleBinding{
|
||||
TransactionID: NewTransactionID(),
|
||||
}
|
||||
|
||||
buf, err := EncodeMessage(m, key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if _, err := c.conn.WriteTo(buf, addr); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var respMsg msg.NatHoleBindingResp
|
||||
select {
|
||||
case rawMsg := <-c.messageChan:
|
||||
if err := DecodeMessageInto(rawMsg.Body, key, &respMsg); err != nil {
|
||||
return "", err
|
||||
}
|
||||
case <-time.After(responseTimeout):
|
||||
return "", fmt.Errorf("wait response from frp server timeout")
|
||||
}
|
||||
|
||||
if respMsg.TransactionID == "" {
|
||||
return "", fmt.Errorf("error format: no transaction id found")
|
||||
}
|
||||
if respMsg.Error != "" {
|
||||
return "", fmt.Errorf("get externalAddr from frp server error: %s", respMsg.Error)
|
||||
}
|
||||
return respMsg.Address, nil
|
||||
}
|
||||
|
||||
func (c *discoverConn) discoverFromStunServer(addr string) ([]string, error) {
|
||||
resp, err := c.doSTUNRequest(addr)
|
||||
if err != nil {
|
||||
|
@@ -15,249 +15,426 @@
|
||||
package nathole
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/golib/crypto"
|
||||
"github.com/fatedier/golib/errors"
|
||||
"github.com/fatedier/golib/pool"
|
||||
"github.com/samber/lo"
|
||||
"golang.org/x/net/ipv4"
|
||||
"k8s.io/apimachinery/pkg/util/sets"
|
||||
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/util/log"
|
||||
"github.com/fatedier/frp/pkg/util/util"
|
||||
"github.com/fatedier/frp/pkg/transport"
|
||||
"github.com/fatedier/frp/pkg/util/xlog"
|
||||
)
|
||||
|
||||
// NatHoleTimeout seconds.
|
||||
var NatHoleTimeout int64 = 10
|
||||
var (
|
||||
// mode 0: simple detect mode, usually for both EasyNAT or HardNAT & EasyNAT(Public Network)
|
||||
// a. receiver sends detect message with low TTL
|
||||
// b. sender sends normal detect message to receiver
|
||||
// c. receiver receives detect message and sends back a message to sender
|
||||
//
|
||||
// mode 1: For HardNAT & EasyNAT, send detect messages to multiple guessed ports.
|
||||
// Usually applicable to scenarios where port changes are regular.
|
||||
// Most of the steps are the same as mode 0, but EasyNAT is fixed as the receiver and will send detect messages
|
||||
// with low TTL to multiple guessed ports of the sender.
|
||||
//
|
||||
// mode 2: For HardNAT & EasyNAT, ports changes are not regular.
|
||||
// a. HardNAT machine will listen on multiple ports and send detect messages with low TTL to EasyNAT machine
|
||||
// b. EasyNAT machine will send detect messages to random ports of HardNAT machine.
|
||||
//
|
||||
// mode 3: For HardNAT & HardNAT, both changes in the ports are regular.
|
||||
// Most of the steps are the same as mode 1, but the sender also needs to send detect messages to multiple guessed
|
||||
// ports of the receiver.
|
||||
//
|
||||
// mode 4: For HardNAT & HardNAT, one of the changes in the ports is regular.
|
||||
// Regular port changes are usually on the sender side.
|
||||
// a. Receiver listens on multiple ports and sends detect messages with low TTL to the sender's guessed range ports.
|
||||
// b. Sender sends detect messages to random ports of the receiver.
|
||||
SupportedModes = []int{DetectMode0, DetectMode1, DetectMode2, DetectMode3, DetectMode4}
|
||||
SupportedRoles = []string{DetectRoleSender, DetectRoleReceiver}
|
||||
|
||||
func NewTransactionID() string {
|
||||
id, _ := util.RandID()
|
||||
return fmt.Sprintf("%d%s", time.Now().Unix(), id)
|
||||
DetectMode0 = 0
|
||||
DetectMode1 = 1
|
||||
DetectMode2 = 2
|
||||
DetectMode3 = 3
|
||||
DetectMode4 = 4
|
||||
DetectRoleSender = "sender"
|
||||
DetectRoleReceiver = "receiver"
|
||||
)
|
||||
|
||||
type PrepareResult struct {
|
||||
Addrs []string
|
||||
AssistedAddrs []string
|
||||
ListenConn *net.UDPConn
|
||||
NatType string
|
||||
Behavior string
|
||||
}
|
||||
|
||||
type SidRequest struct {
|
||||
Sid string
|
||||
NotifyCh chan struct{}
|
||||
}
|
||||
// PreCheck is used to check if the proxy is ready for penetration.
|
||||
// Call this function before calling Prepare to avoid unnecessary preparation work.
|
||||
func PreCheck(
|
||||
ctx context.Context, transporter transport.MessageTransporter,
|
||||
proxyName string, timeout time.Duration,
|
||||
) error {
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
type Controller struct {
|
||||
listener *net.UDPConn
|
||||
|
||||
clientCfgs map[string]*ClientCfg
|
||||
sessions map[string]*Session
|
||||
|
||||
encryptionKey []byte
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewController(udpBindAddr string, encryptionKey []byte) (nc *Controller, err error) {
|
||||
addr, err := net.ResolveUDPAddr("udp", udpBindAddr)
|
||||
var natHoleRespMsg *msg.NatHoleResp
|
||||
transactionID := NewTransactionID()
|
||||
m, err := transporter.Do(timeoutCtx, &msg.NatHoleVisitor{
|
||||
TransactionID: transactionID,
|
||||
ProxyName: proxyName,
|
||||
PreCheck: true,
|
||||
}, transactionID, msg.TypeNameNatHoleResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return fmt.Errorf("get natHoleRespMsg error: %v", err)
|
||||
}
|
||||
lconn, err := net.ListenUDP("udp", addr)
|
||||
mm, ok := m.(*msg.NatHoleResp)
|
||||
if !ok {
|
||||
return fmt.Errorf("get natHoleRespMsg error: invalid message type")
|
||||
}
|
||||
natHoleRespMsg = mm
|
||||
|
||||
if natHoleRespMsg.Error != "" {
|
||||
return fmt.Errorf("%s", natHoleRespMsg.Error)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Prepare is used to do some preparation work before penetration.
|
||||
func Prepare(stunServers []string) (*PrepareResult, error) {
|
||||
// discover for Nat type
|
||||
addrs, localAddr, err := Discover(stunServers, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("discover error: %v", err)
|
||||
}
|
||||
nc = &Controller{
|
||||
listener: lconn,
|
||||
clientCfgs: make(map[string]*ClientCfg),
|
||||
sessions: make(map[string]*Session),
|
||||
encryptionKey: encryptionKey,
|
||||
if len(addrs) < 2 {
|
||||
return nil, fmt.Errorf("discover error: not enough addresses")
|
||||
}
|
||||
return nc, nil
|
||||
|
||||
localIPs, _ := ListLocalIPsForNatHole(10)
|
||||
natFeature, err := ClassifyNATFeature(addrs, localIPs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("classify nat feature error: %v", err)
|
||||
}
|
||||
|
||||
laddr, err := net.ResolveUDPAddr("udp4", localAddr.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolve local udp addr error: %v", err)
|
||||
}
|
||||
listenConn, err := net.ListenUDP("udp4", laddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen local udp addr error: %v", err)
|
||||
}
|
||||
|
||||
assistedAddrs := make([]string, 0, len(localIPs))
|
||||
for _, ip := range localIPs {
|
||||
assistedAddrs = append(assistedAddrs, net.JoinHostPort(ip, strconv.Itoa(laddr.Port)))
|
||||
}
|
||||
return &PrepareResult{
|
||||
Addrs: addrs,
|
||||
AssistedAddrs: assistedAddrs,
|
||||
ListenConn: listenConn,
|
||||
NatType: natFeature.NatType,
|
||||
Behavior: natFeature.Behavior,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (nc *Controller) ListenClient(name string, sk string) (sidCh chan *SidRequest) {
|
||||
clientCfg := &ClientCfg{
|
||||
Name: name,
|
||||
Sk: sk,
|
||||
SidCh: make(chan *SidRequest),
|
||||
// ExchangeInfo is used to exchange information between client and visitor.
|
||||
// 1. Send input message to server by msgTransporter.
|
||||
// 2. Server will gather information from client and visitor and analyze it. Then send back a NatHoleResp message to them to tell them how to do next.
|
||||
// 3. Receive NatHoleResp message from server.
|
||||
func ExchangeInfo(
|
||||
ctx context.Context, transporter transport.MessageTransporter,
|
||||
laneKey string, m msg.Message, timeout time.Duration,
|
||||
) (*msg.NatHoleResp, error) {
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
var natHoleRespMsg *msg.NatHoleResp
|
||||
m, err := transporter.Do(timeoutCtx, m, laneKey, msg.TypeNameNatHoleResp)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get natHoleRespMsg error: %v", err)
|
||||
}
|
||||
nc.mu.Lock()
|
||||
nc.clientCfgs[name] = clientCfg
|
||||
nc.mu.Unlock()
|
||||
return clientCfg.SidCh
|
||||
mm, ok := m.(*msg.NatHoleResp)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("get natHoleRespMsg error: invalid message type")
|
||||
}
|
||||
natHoleRespMsg = mm
|
||||
|
||||
if natHoleRespMsg.Error != "" {
|
||||
return nil, fmt.Errorf("natHoleRespMsg get error info: %s", natHoleRespMsg.Error)
|
||||
}
|
||||
if len(natHoleRespMsg.CandidateAddrs) == 0 {
|
||||
return nil, fmt.Errorf("natHoleRespMsg get empty candidate addresses")
|
||||
}
|
||||
return natHoleRespMsg, nil
|
||||
}
|
||||
|
||||
func (nc *Controller) CloseClient(name string) {
|
||||
nc.mu.Lock()
|
||||
defer nc.mu.Unlock()
|
||||
delete(nc.clientCfgs, name)
|
||||
// MakeHole is used to make a NAT hole between client and visitor.
|
||||
func MakeHole(ctx context.Context, listenConn *net.UDPConn, m *msg.NatHoleResp, key []byte) (*net.UDPConn, *net.UDPAddr, error) {
|
||||
xl := xlog.FromContextSafe(ctx)
|
||||
transactionID := NewTransactionID()
|
||||
sendToRangePortsFunc := func(conn *net.UDPConn, addr string) error {
|
||||
return sendSidMessage(ctx, conn, m.Sid, transactionID, addr, key, m.DetectBehavior.TTL)
|
||||
}
|
||||
|
||||
listenConns := []*net.UDPConn{listenConn}
|
||||
var detectAddrs []string
|
||||
if m.DetectBehavior.Role == DetectRoleSender {
|
||||
// sender
|
||||
if m.DetectBehavior.SendDelayMs > 0 {
|
||||
time.Sleep(time.Duration(m.DetectBehavior.SendDelayMs) * time.Millisecond)
|
||||
}
|
||||
detectAddrs = m.AssistedAddrs
|
||||
detectAddrs = append(detectAddrs, m.CandidateAddrs...)
|
||||
} else {
|
||||
// receiver
|
||||
if len(m.DetectBehavior.CandidatePorts) == 0 {
|
||||
detectAddrs = m.CandidateAddrs
|
||||
}
|
||||
|
||||
if m.DetectBehavior.ListenRandomPorts > 0 {
|
||||
for i := 0; i < m.DetectBehavior.ListenRandomPorts; i++ {
|
||||
tmpConn, err := net.ListenUDP("udp4", nil)
|
||||
if err != nil {
|
||||
xl.Warn("listen random udp addr error: %v", err)
|
||||
continue
|
||||
}
|
||||
listenConns = append(listenConns, tmpConn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
detectAddrs = lo.Uniq(detectAddrs)
|
||||
for _, detectAddr := range detectAddrs {
|
||||
for _, conn := range listenConns {
|
||||
if err := sendSidMessage(ctx, conn, m.Sid, transactionID, detectAddr, key, m.DetectBehavior.TTL); err != nil {
|
||||
xl.Trace("send sid message from %s to %s error: %v", conn.LocalAddr(), detectAddr, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(m.DetectBehavior.CandidatePorts) > 0 {
|
||||
for _, conn := range listenConns {
|
||||
sendSidMessageToRangePorts(ctx, conn, m.CandidateAddrs, m.DetectBehavior.CandidatePorts, sendToRangePortsFunc)
|
||||
}
|
||||
}
|
||||
if m.DetectBehavior.SendRandomPorts > 0 {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
for i := range listenConns {
|
||||
go sendSidMessageToRandomPorts(ctx, listenConns[i], m.CandidateAddrs, m.DetectBehavior.SendRandomPorts, sendToRangePortsFunc)
|
||||
}
|
||||
}
|
||||
|
||||
timeout := 5 * time.Second
|
||||
if m.DetectBehavior.ReadTimeoutMs > 0 {
|
||||
timeout = time.Duration(m.DetectBehavior.ReadTimeoutMs) * time.Millisecond
|
||||
}
|
||||
|
||||
if len(listenConns) == 1 {
|
||||
raddr, err := waitDetectMessage(ctx, listenConns[0], m.Sid, key, timeout, m.DetectBehavior.Role)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("wait detect message error: %v", err)
|
||||
}
|
||||
return listenConns[0], raddr, nil
|
||||
}
|
||||
|
||||
type result struct {
|
||||
lConn *net.UDPConn
|
||||
raddr *net.UDPAddr
|
||||
}
|
||||
resultCh := make(chan result)
|
||||
for _, conn := range listenConns {
|
||||
go func(lConn *net.UDPConn) {
|
||||
addr, err := waitDetectMessage(ctx, lConn, m.Sid, key, timeout, m.DetectBehavior.Role)
|
||||
if err != nil {
|
||||
lConn.Close()
|
||||
return
|
||||
}
|
||||
select {
|
||||
case resultCh <- result{lConn: lConn, raddr: addr}:
|
||||
default:
|
||||
lConn.Close()
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
return result.lConn, result.raddr, nil
|
||||
case <-time.After(timeout):
|
||||
return nil, nil, fmt.Errorf("wait detect message timeout")
|
||||
case <-ctx.Done():
|
||||
return nil, nil, fmt.Errorf("wait detect message canceled")
|
||||
}
|
||||
}
|
||||
|
||||
func (nc *Controller) Run() {
|
||||
func waitDetectMessage(
|
||||
ctx context.Context, conn *net.UDPConn, sid string, key []byte,
|
||||
timeout time.Duration, role string,
|
||||
) (*net.UDPAddr, error) {
|
||||
xl := xlog.FromContextSafe(ctx)
|
||||
for {
|
||||
buf := pool.GetBuf(1024)
|
||||
n, raddr, err := nc.listener.ReadFromUDP(buf)
|
||||
_ = conn.SetReadDeadline(time.Now().Add(timeout))
|
||||
n, raddr, err := conn.ReadFromUDP(buf)
|
||||
_ = conn.SetReadDeadline(time.Time{})
|
||||
if err != nil {
|
||||
log.Warn("nat hole listener read from udp error: %v", err)
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
plain, err := crypto.Decode(buf[:n], nc.encryptionKey)
|
||||
if err != nil {
|
||||
log.Warn("nathole listener decode from %s error: %v", raddr.String(), err)
|
||||
continue
|
||||
}
|
||||
|
||||
rawMsg, err := msg.ReadMsg(bytes.NewReader(plain))
|
||||
if err != nil {
|
||||
log.Warn("read nat hole message error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
switch m := rawMsg.(type) {
|
||||
case *msg.NatHoleBinding:
|
||||
go nc.HandleBinding(m, raddr)
|
||||
case *msg.NatHoleVisitor:
|
||||
go nc.HandleVisitor(m, raddr)
|
||||
case *msg.NatHoleClient:
|
||||
go nc.HandleClient(m, raddr)
|
||||
default:
|
||||
log.Trace("unknown nat hole message type")
|
||||
xl.Debug("get udp message local %s, from %s", conn.LocalAddr(), raddr)
|
||||
var m msg.NatHoleSid
|
||||
if err := DecodeMessageInto(buf[:n], key, &m); err != nil {
|
||||
xl.Warn("decode sid message error: %v", err)
|
||||
continue
|
||||
}
|
||||
pool.PutBuf(buf)
|
||||
}
|
||||
}
|
||||
|
||||
func (nc *Controller) GenSid() string {
|
||||
t := time.Now().Unix()
|
||||
id, _ := util.RandID()
|
||||
return fmt.Sprintf("%d%s", t, id)
|
||||
}
|
||||
|
||||
func (nc *Controller) HandleBinding(m *msg.NatHoleBinding, raddr *net.UDPAddr) {
|
||||
log.Trace("handle binding message from %s", raddr.String())
|
||||
resp := &msg.NatHoleBindingResp{
|
||||
TransactionID: m.TransactionID,
|
||||
Address: raddr.String(),
|
||||
}
|
||||
plain, err := msg.Pack(resp)
|
||||
if err != nil {
|
||||
log.Error("pack nat hole binding response error: %v", err)
|
||||
return
|
||||
}
|
||||
buf, err := crypto.Encode(plain, nc.encryptionKey)
|
||||
if err != nil {
|
||||
log.Error("encode nat hole binding response error: %v", err)
|
||||
return
|
||||
}
|
||||
_, err = nc.listener.WriteToUDP(buf, raddr)
|
||||
if err != nil {
|
||||
log.Error("write nat hole binding response to %s error: %v", raddr.String(), err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (nc *Controller) HandleVisitor(m *msg.NatHoleVisitor, raddr *net.UDPAddr) {
|
||||
sid := nc.GenSid()
|
||||
session := &Session{
|
||||
Sid: sid,
|
||||
VisitorAddr: raddr,
|
||||
NotifyCh: make(chan struct{}),
|
||||
}
|
||||
nc.mu.Lock()
|
||||
clientCfg, ok := nc.clientCfgs[m.ProxyName]
|
||||
if !ok {
|
||||
nc.mu.Unlock()
|
||||
errInfo := fmt.Sprintf("xtcp server for [%s] doesn't exist", m.ProxyName)
|
||||
log.Debug(errInfo)
|
||||
_, _ = nc.listener.WriteToUDP(nc.GenNatHoleResponse(nil, errInfo), raddr)
|
||||
return
|
||||
}
|
||||
if m.SignKey != util.GetAuthKey(clientCfg.Sk, m.Timestamp) {
|
||||
nc.mu.Unlock()
|
||||
errInfo := fmt.Sprintf("xtcp connection of [%s] auth failed", m.ProxyName)
|
||||
log.Debug(errInfo)
|
||||
_, _ = nc.listener.WriteToUDP(nc.GenNatHoleResponse(nil, errInfo), raddr)
|
||||
return
|
||||
}
|
||||
|
||||
nc.sessions[sid] = session
|
||||
nc.mu.Unlock()
|
||||
log.Trace("handle visitor message, sid [%s]", sid)
|
||||
|
||||
defer func() {
|
||||
nc.mu.Lock()
|
||||
delete(nc.sessions, sid)
|
||||
nc.mu.Unlock()
|
||||
}()
|
||||
|
||||
err := errors.PanicToError(func() {
|
||||
clientCfg.SidCh <- &SidRequest{
|
||||
Sid: sid,
|
||||
NotifyCh: session.NotifyCh,
|
||||
if m.Sid != sid {
|
||||
xl.Warn("get sid message with wrong sid: %s, expect: %s", m.Sid, sid)
|
||||
continue
|
||||
}
|
||||
})
|
||||
|
||||
if !m.Response {
|
||||
// only wait for response messages if we are a sender
|
||||
if role == DetectRoleSender {
|
||||
continue
|
||||
}
|
||||
|
||||
m.Response = true
|
||||
buf2, err := EncodeMessage(&m, key)
|
||||
if err != nil {
|
||||
xl.Warn("encode sid message error: %v", err)
|
||||
continue
|
||||
}
|
||||
_, _ = conn.WriteToUDP(buf2, raddr)
|
||||
}
|
||||
return raddr, nil
|
||||
}
|
||||
}
|
||||
|
||||
func sendSidMessage(
|
||||
ctx context.Context, conn *net.UDPConn,
|
||||
sid string, transactionID string, addr string, key []byte, ttl int,
|
||||
) error {
|
||||
xl := xlog.FromContextSafe(ctx)
|
||||
ttlStr := ""
|
||||
if ttl > 0 {
|
||||
ttlStr = fmt.Sprintf(" with ttl %d", ttl)
|
||||
}
|
||||
xl.Trace("send sid message from %s to %s%s", conn.LocalAddr(), addr, ttlStr)
|
||||
raddr, err := net.ResolveUDPAddr("udp4", addr)
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
// Wait client connections.
|
||||
select {
|
||||
case <-session.NotifyCh:
|
||||
resp := nc.GenNatHoleResponse(session, "")
|
||||
log.Trace("send nat hole response to visitor")
|
||||
_, _ = nc.listener.WriteToUDP(resp, raddr)
|
||||
case <-time.After(time.Duration(NatHoleTimeout) * time.Second):
|
||||
return
|
||||
if transactionID == "" {
|
||||
transactionID = NewTransactionID()
|
||||
}
|
||||
}
|
||||
|
||||
func (nc *Controller) HandleClient(m *msg.NatHoleClient, raddr *net.UDPAddr) {
|
||||
nc.mu.RLock()
|
||||
session, ok := nc.sessions[m.Sid]
|
||||
nc.mu.RUnlock()
|
||||
if !ok {
|
||||
return
|
||||
m := &msg.NatHoleSid{
|
||||
TransactionID: transactionID,
|
||||
Sid: sid,
|
||||
Response: false,
|
||||
Nonce: strings.Repeat("0", rand.Intn(20)),
|
||||
}
|
||||
log.Trace("handle client message, sid [%s]", session.Sid)
|
||||
session.ClientAddr = raddr
|
||||
|
||||
resp := nc.GenNatHoleResponse(session, "")
|
||||
log.Trace("send nat hole response to client")
|
||||
_, _ = nc.listener.WriteToUDP(resp, raddr)
|
||||
}
|
||||
|
||||
func (nc *Controller) GenNatHoleResponse(session *Session, errInfo string) []byte {
|
||||
var (
|
||||
sid string
|
||||
visitorAddr string
|
||||
clientAddr string
|
||||
)
|
||||
if session != nil {
|
||||
sid = session.Sid
|
||||
visitorAddr = session.VisitorAddr.String()
|
||||
clientAddr = session.ClientAddr.String()
|
||||
}
|
||||
m := &msg.NatHoleResp{
|
||||
Sid: sid,
|
||||
VisitorAddr: visitorAddr,
|
||||
ClientAddr: clientAddr,
|
||||
Error: errInfo,
|
||||
}
|
||||
b := bytes.NewBuffer(nil)
|
||||
err := msg.WriteMsg(b, m)
|
||||
buf, err := EncodeMessage(m, key)
|
||||
if err != nil {
|
||||
return []byte("")
|
||||
return err
|
||||
}
|
||||
return b.Bytes()
|
||||
if ttl > 0 {
|
||||
uConn := ipv4.NewConn(conn)
|
||||
original, err := uConn.TTL()
|
||||
if err != nil {
|
||||
xl.Trace("get ttl error %v", err)
|
||||
return err
|
||||
}
|
||||
xl.Trace("original ttl %d", original)
|
||||
|
||||
err = uConn.SetTTL(ttl)
|
||||
if err != nil {
|
||||
xl.Trace("set ttl error %v", err)
|
||||
} else {
|
||||
defer func() {
|
||||
_ = uConn.SetTTL(original)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := conn.WriteToUDP(buf, raddr); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
Sid string
|
||||
VisitorAddr *net.UDPAddr
|
||||
ClientAddr *net.UDPAddr
|
||||
|
||||
NotifyCh chan struct{}
|
||||
func sendSidMessageToRangePorts(
|
||||
ctx context.Context, conn *net.UDPConn, addrs []string, ports []msg.PortsRange,
|
||||
sendFunc func(*net.UDPConn, string) error,
|
||||
) {
|
||||
xl := xlog.FromContextSafe(ctx)
|
||||
for _, ip := range lo.Uniq(parseIPs(addrs)) {
|
||||
for _, portsRange := range ports {
|
||||
for i := portsRange.From; i <= portsRange.To; i++ {
|
||||
detectAddr := net.JoinHostPort(ip, strconv.Itoa(i))
|
||||
if err := sendFunc(conn, detectAddr); err != nil {
|
||||
xl.Trace("send sid message from %s to %s error: %v", conn.LocalAddr(), detectAddr, err)
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type ClientCfg struct {
|
||||
Name string
|
||||
Sk string
|
||||
SidCh chan *SidRequest
|
||||
func sendSidMessageToRandomPorts(
|
||||
ctx context.Context, conn *net.UDPConn, addrs []string, count int,
|
||||
sendFunc func(*net.UDPConn, string) error,
|
||||
) {
|
||||
xl := xlog.FromContextSafe(ctx)
|
||||
used := sets.New[int]()
|
||||
getUnusedPort := func() int {
|
||||
for i := 0; i < 10; i++ {
|
||||
port := rand.Intn(65535-1024) + 1024
|
||||
if !used.Has(port) {
|
||||
used.Insert(port)
|
||||
return port
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
port := getUnusedPort()
|
||||
if port == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, ip := range lo.Uniq(parseIPs(addrs)) {
|
||||
detectAddr := net.JoinHostPort(ip, strconv.Itoa(port))
|
||||
if err := sendFunc(conn, detectAddr); err != nil {
|
||||
xl.Trace("send sid message from %s to %s error: %v", conn.LocalAddr(), detectAddr, err)
|
||||
}
|
||||
time.Sleep(time.Millisecond * 15)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseIPs(addrs []string) []string {
|
||||
var ips []string
|
||||
for _, addr := range addrs {
|
||||
if ip, _, err := net.SplitHostPort(addr); err == nil {
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
}
|
||||
return ips
|
||||
}
|
||||
|
@@ -16,6 +16,7 @@ package nathole
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
|
||||
@@ -63,3 +64,49 @@ func (s *ChangedAddress) GetFrom(m *stun.Message) error {
|
||||
func (s *ChangedAddress) String() string {
|
||||
return net.JoinHostPort(s.IP.String(), strconv.Itoa(s.Port))
|
||||
}
|
||||
|
||||
func ListAllLocalIPs() ([]net.IP, error) {
|
||||
addrs, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ips := make([]net.IP, 0, len(addrs))
|
||||
for _, addr := range addrs {
|
||||
ip, _, err := net.ParseCIDR(addr.String())
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
func ListLocalIPsForNatHole(max int) ([]string, error) {
|
||||
if max <= 0 {
|
||||
return nil, fmt.Errorf("max must be greater than 0")
|
||||
}
|
||||
|
||||
ips, err := ListAllLocalIPs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
filtered := make([]string, 0, max)
|
||||
for _, ip := range ips {
|
||||
if len(filtered) >= max {
|
||||
break
|
||||
}
|
||||
|
||||
// ignore ipv6 address
|
||||
if ip.To4() == nil {
|
||||
continue
|
||||
}
|
||||
// ignore localhost IP
|
||||
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
||||
continue
|
||||
}
|
||||
|
||||
filtered = append(filtered, ip.String())
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
|
119
pkg/transport/message.go
Normal file
119
pkg/transport/message.go
Normal file
@@ -0,0 +1,119 @@
|
||||
// Copyright 2023 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"github.com/fatedier/golib/errors"
|
||||
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
)
|
||||
|
||||
type MessageTransporter interface {
|
||||
Send(msg.Message) error
|
||||
// Recv(ctx context.Context, laneKey string, msgType string) (Message, error)
|
||||
// Do will first send msg, then recv msg with the same laneKey and specified msgType.
|
||||
Do(ctx context.Context, req msg.Message, laneKey, recvMsgType string) (msg.Message, error)
|
||||
Dispatch(m msg.Message, laneKey string) bool
|
||||
DispatchWithType(m msg.Message, msgType, laneKey string) bool
|
||||
}
|
||||
|
||||
func NewMessageTransporter(sendCh chan msg.Message) MessageTransporter {
|
||||
return &transporterImpl{
|
||||
sendCh: sendCh,
|
||||
registry: make(map[string]map[string]chan msg.Message),
|
||||
}
|
||||
}
|
||||
|
||||
type transporterImpl struct {
|
||||
sendCh chan msg.Message
|
||||
|
||||
// First key is message type and second key is lane key.
|
||||
// Dispatch will dispatch message to releated channel by its message type
|
||||
// and lane key.
|
||||
registry map[string]map[string]chan msg.Message
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func (impl *transporterImpl) Send(m msg.Message) error {
|
||||
return errors.PanicToError(func() {
|
||||
impl.sendCh <- m
|
||||
})
|
||||
}
|
||||
|
||||
func (impl *transporterImpl) Do(ctx context.Context, req msg.Message, laneKey, recvMsgType string) (msg.Message, error) {
|
||||
ch := make(chan msg.Message, 1)
|
||||
defer close(ch)
|
||||
unregisterFn := impl.registerMsgChan(ch, laneKey, recvMsgType)
|
||||
defer unregisterFn()
|
||||
|
||||
if err := impl.Send(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case resp := <-ch:
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (impl *transporterImpl) DispatchWithType(m msg.Message, msgType, laneKey string) bool {
|
||||
var ch chan msg.Message
|
||||
impl.mu.RLock()
|
||||
byLaneKey, ok := impl.registry[msgType]
|
||||
if ok {
|
||||
ch = byLaneKey[laneKey]
|
||||
}
|
||||
impl.mu.RUnlock()
|
||||
|
||||
if ch == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := errors.PanicToError(func() {
|
||||
ch <- m
|
||||
}); err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (impl *transporterImpl) Dispatch(m msg.Message, laneKey string) bool {
|
||||
msgType := reflect.TypeOf(m).Elem().Name()
|
||||
return impl.DispatchWithType(m, msgType, laneKey)
|
||||
}
|
||||
|
||||
func (impl *transporterImpl) registerMsgChan(recvCh chan msg.Message, laneKey string, msgType string) (unregister func()) {
|
||||
impl.mu.Lock()
|
||||
byLaneKey, ok := impl.registry[msgType]
|
||||
if !ok {
|
||||
byLaneKey = make(map[string]chan msg.Message)
|
||||
impl.registry[msgType] = byLaneKey
|
||||
}
|
||||
byLaneKey[laneKey] = recvCh
|
||||
impl.mu.Unlock()
|
||||
|
||||
unregister = func() {
|
||||
impl.mu.Lock()
|
||||
delete(byLaneKey, laneKey)
|
||||
impl.mu.Unlock()
|
||||
}
|
||||
return
|
||||
}
|
@@ -1,3 +1,17 @@
|
||||
// Copyright 2023 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
|
@@ -256,3 +256,11 @@ func (l *UDPListener) Close() error {
|
||||
func (l *UDPListener) Addr() net.Addr {
|
||||
return l.addr
|
||||
}
|
||||
|
||||
// ConnectedUDPConn is a wrapper for net.UDPConn which converts WriteTo syscalls
|
||||
// to Write syscalls that are 4 times faster on some OS'es. This should only be
|
||||
// used for connections that were produced by a net.Dial* call.
|
||||
type ConnectedUDPConn struct{ *net.UDPConn }
|
||||
|
||||
// WriteTo redirects all writes to the Write syscall, which is 4 times faster.
|
||||
func (c *ConnectedUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { return c.Write(b) }
|
||||
|
@@ -1,25 +0,0 @@
|
||||
package util
|
||||
|
||||
func InSlice[T comparable](v T, s []T) bool {
|
||||
for _, vv := range s {
|
||||
if v == vv {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func InSliceAny[T any](v T, s []T, equalFn func(a, b T) bool) bool {
|
||||
for _, vv := range s {
|
||||
if equalFn(v, vv) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func InSliceAnyFunc[T any](equalFn func(a, b T) bool) func(v T, s []T) bool {
|
||||
return func(v T, s []T) bool {
|
||||
return InSliceAny(v, s, equalFn)
|
||||
}
|
||||
}
|
@@ -1,49 +0,0 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestInSlice(t *testing.T) {
|
||||
require := require.New(t)
|
||||
require.True(InSlice(1, []int{1, 2, 3}))
|
||||
require.False(InSlice(0, []int{1, 2, 3}))
|
||||
require.True(InSlice("foo", []string{"foo", "bar"}))
|
||||
require.False(InSlice("not exist", []string{"foo", "bar"}))
|
||||
}
|
||||
|
||||
type testStructA struct {
|
||||
Name string
|
||||
Age int
|
||||
}
|
||||
|
||||
func TestInSliceAny(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
a := testStructA{Name: "foo", Age: 20}
|
||||
b := testStructA{Name: "foo", Age: 30}
|
||||
c := testStructA{Name: "bar", Age: 20}
|
||||
|
||||
equalFn := func(o, p testStructA) bool {
|
||||
return o.Name == p.Name
|
||||
}
|
||||
require.True(InSliceAny(a, []testStructA{b, c}, equalFn))
|
||||
require.False(InSliceAny(c, []testStructA{a, b}, equalFn))
|
||||
}
|
||||
|
||||
func TestInSliceAnyFunc(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
a := testStructA{Name: "foo", Age: 20}
|
||||
b := testStructA{Name: "foo", Age: 30}
|
||||
c := testStructA{Name: "bar", Age: 20}
|
||||
|
||||
equalFn := func(o, p testStructA) bool {
|
||||
return o.Name == p.Name
|
||||
}
|
||||
testStructAInSlice := InSliceAnyFunc(equalFn)
|
||||
require.True(testStructAInSlice(a, []testStructA{b, c}))
|
||||
require.False(testStructAInSlice(c, []testStructA{a, b}))
|
||||
}
|
@@ -28,19 +28,32 @@ import (
|
||||
|
||||
// RandID return a rand string used in frp.
|
||||
func RandID() (id string, err error) {
|
||||
return RandIDWithLen(8)
|
||||
return RandIDWithLen(16)
|
||||
}
|
||||
|
||||
// RandIDWithLen return a rand string with idLen length.
|
||||
func RandIDWithLen(idLen int) (id string, err error) {
|
||||
b := make([]byte, idLen)
|
||||
if idLen <= 0 {
|
||||
return "", nil
|
||||
}
|
||||
b := make([]byte, idLen/2+1)
|
||||
_, err = rand.Read(b)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
id = fmt.Sprintf("%x", b)
|
||||
return
|
||||
return id[:idLen], nil
|
||||
}
|
||||
|
||||
// RandIDWithRandLen return a rand string with length between [start, end).
|
||||
func RandIDWithRandLen(start, end int) (id string, err error) {
|
||||
if start >= end {
|
||||
err = fmt.Errorf("start should be less than end")
|
||||
return
|
||||
}
|
||||
idLen := mathrand.Intn(end-start) + start
|
||||
return RandIDWithLen(idLen)
|
||||
}
|
||||
|
||||
func GetAuthKey(token string, timestamp int64) (key string) {
|
||||
|
@@ -14,10 +14,51 @@ func TestRandId(t *testing.T) {
|
||||
assert.Equal(16, len(id))
|
||||
}
|
||||
|
||||
func TestRandIDWithRandLen(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
start int
|
||||
end int
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "start and end are equal",
|
||||
start: 5,
|
||||
end: 5,
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "start is less than end",
|
||||
start: 5,
|
||||
end: 10,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "start is greater than end",
|
||||
start: 10,
|
||||
end: 5,
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
id, err := RandIDWithRandLen(tt.start, tt.end)
|
||||
if tt.expectErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
assert.GreaterOrEqual(len(id), tt.start)
|
||||
assert.Less(len(id), tt.end)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAuthKey(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
key := GetAuthKey("1234", 1488720000)
|
||||
t.Log(key)
|
||||
assert.Equal("6df41a43725f0c770fd56379e12acf8c", key)
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user