refactor the code related to xtcp (#3449)

This commit is contained in:
fatedier
2023-05-28 16:50:43 +08:00
committed by GitHub
parent 9f029e3248
commit c71efde303
44 changed files with 3305 additions and 1699 deletions

328
pkg/nathole/analysis.go Normal file
View File

@@ -0,0 +1,328 @@
// Copyright 2023 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package nathole
import (
"sync"
"time"
"github.com/samber/lo"
)
var (
// mode 0, both EasyNAT, PublicNetwork is always receiver
// sender | receiver, ttl 7
// receiver, ttl 7 | sender
// sender | receiver, ttl 4
// receiver, ttl 4 | sender
// sender | receiver
// receiver | sender
// sender, sendDelayMs 5000 | receiver
// sender, sendDelayMs 10000 | receiver
// receiver | sender, sendDelayMs 5000
// receiver | sender, sendDelayMs 10000
mode0Behaviors = []lo.Tuple2[RecommandBehavior, RecommandBehavior]{
lo.T2(RecommandBehavior{Role: DetectRoleSender}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 7}),
lo.T2(RecommandBehavior{Role: DetectRoleReceiver, TTL: 7}, RecommandBehavior{Role: DetectRoleSender}),
lo.T2(RecommandBehavior{Role: DetectRoleSender}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 4}),
lo.T2(RecommandBehavior{Role: DetectRoleReceiver, TTL: 4}, RecommandBehavior{Role: DetectRoleSender}),
lo.T2(RecommandBehavior{Role: DetectRoleSender}, RecommandBehavior{Role: DetectRoleReceiver}),
lo.T2(RecommandBehavior{Role: DetectRoleReceiver}, RecommandBehavior{Role: DetectRoleSender}),
lo.T2(RecommandBehavior{Role: DetectRoleSender, SendDelayMs: 5000}, RecommandBehavior{Role: DetectRoleReceiver}),
lo.T2(RecommandBehavior{Role: DetectRoleSender, SendDelayMs: 10000}, RecommandBehavior{Role: DetectRoleReceiver}),
lo.T2(RecommandBehavior{Role: DetectRoleReceiver}, RecommandBehavior{Role: DetectRoleSender, SendDelayMs: 5000}),
lo.T2(RecommandBehavior{Role: DetectRoleReceiver}, RecommandBehavior{Role: DetectRoleSender, SendDelayMs: 10000}),
}
// mode 1, HardNAT is sender, EasyNAT is receiver, port changes is regular
// sender | receiver, ttl 7, portsRangeNumber max 10
// sender, sendDelayMs 2000 | receiver, ttl 7, portsRangeNumber max 10
// sender | receiver, ttl 4, portsRangeNumber max 10
// sender, sendDelayMs 2000 | receiver, ttl 4, portsRangeNumber max 10
// sender | receiver, portsRangeNumber max 10
// sender, sendDelayMs 2000 | receiver, portsRangeNumber max 10
mode1Behaviors = []lo.Tuple2[RecommandBehavior, RecommandBehavior]{
lo.T2(RecommandBehavior{Role: DetectRoleSender}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 7, PortsRangeNumber: 10}),
lo.T2(RecommandBehavior{Role: DetectRoleSender, SendDelayMs: 2000}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 7, PortsRangeNumber: 10}),
lo.T2(RecommandBehavior{Role: DetectRoleSender}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 4, PortsRangeNumber: 10}),
lo.T2(RecommandBehavior{Role: DetectRoleSender, SendDelayMs: 2000}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 4, PortsRangeNumber: 10}),
lo.T2(RecommandBehavior{Role: DetectRoleSender}, RecommandBehavior{Role: DetectRoleReceiver, PortsRangeNumber: 10}),
lo.T2(RecommandBehavior{Role: DetectRoleSender, SendDelayMs: 2000}, RecommandBehavior{Role: DetectRoleReceiver, PortsRangeNumber: 10}),
}
// mode 2, HardNAT is receiver, EasyNAT is sender
// sender, portsRandomNumber 1000, sendDelayMs 2000 | receiver, listen 256 ports, ttl 7
// sender, portsRandomNumber 1000, sendDelayMs 2000 | receiver, listen 256 ports, ttl 4
// sender, portsRandomNumber 1000, sendDelayMs 2000 | receiver, listen 256 ports
mode2Behaviors = []lo.Tuple2[RecommandBehavior, RecommandBehavior]{
lo.T2(
RecommandBehavior{Role: DetectRoleSender, PortsRandomNumber: 1000, SendDelayMs: 2000},
RecommandBehavior{Role: DetectRoleReceiver, ListenRandomPorts: 256, TTL: 7},
),
lo.T2(
RecommandBehavior{Role: DetectRoleSender, PortsRandomNumber: 1000, SendDelayMs: 2000},
RecommandBehavior{Role: DetectRoleReceiver, ListenRandomPorts: 256, TTL: 4},
),
lo.T2(
RecommandBehavior{Role: DetectRoleSender, PortsRandomNumber: 1000, SendDelayMs: 2000},
RecommandBehavior{Role: DetectRoleReceiver, ListenRandomPorts: 256},
),
}
// mode 3, For HardNAT & HardNAT, both changes in the ports are regular
// sender, portsRangeNumber 10 | receiver, ttl 7, portsRangeNumber 10
// sender, portsRangeNumber 10 | receiver, ttl 4, portsRangeNumber 10
// sender, portsRangeNumber 10 | receiver, portsRangeNumber 10
// receiver, ttl 7, portsRangeNumber 10 | sender, portsRangeNumber 10
// receiver, ttl 4, portsRangeNumber 10 | sender, portsRangeNumber 10
// receiver, portsRangeNumber 10 | sender, portsRangeNumber 10
mode3Behaviors = []lo.Tuple2[RecommandBehavior, RecommandBehavior]{
lo.T2(RecommandBehavior{Role: DetectRoleSender, PortsRangeNumber: 10}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 7, PortsRangeNumber: 10}),
lo.T2(RecommandBehavior{Role: DetectRoleSender, PortsRangeNumber: 10}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 4, PortsRangeNumber: 10}),
lo.T2(RecommandBehavior{Role: DetectRoleSender, PortsRangeNumber: 10}, RecommandBehavior{Role: DetectRoleReceiver, PortsRangeNumber: 10}),
lo.T2(RecommandBehavior{Role: DetectRoleReceiver, TTL: 7, PortsRangeNumber: 10}, RecommandBehavior{Role: DetectRoleSender, PortsRangeNumber: 10}),
lo.T2(RecommandBehavior{Role: DetectRoleReceiver, TTL: 4, PortsRangeNumber: 10}, RecommandBehavior{Role: DetectRoleSender, PortsRangeNumber: 10}),
lo.T2(RecommandBehavior{Role: DetectRoleReceiver, PortsRangeNumber: 10}, RecommandBehavior{Role: DetectRoleSender, PortsRangeNumber: 10}),
}
// mode 4, Regular ports changes are usually the sender.
// sender, portsRandomNumber 1000, sendDelayMs: 2000 | receiver, listen 256 ports, ttl 7, portsRangeNumber 10
// sender, portsRandomNumber 1000, sendDelayMs: 2000 | receiver, listen 256 ports, ttl 4, portsRangeNumber 10
// sender, portsRandomNumber 1000, SendDelayMs: 2000 | receiver, listen 256 ports, portsRangeNumber 10
mode4Behaviors = []lo.Tuple2[RecommandBehavior, RecommandBehavior]{
lo.T2(
RecommandBehavior{Role: DetectRoleSender, PortsRandomNumber: 1000, SendDelayMs: 2000},
RecommandBehavior{Role: DetectRoleReceiver, ListenRandomPorts: 256, TTL: 7, PortsRangeNumber: 10},
),
lo.T2(
RecommandBehavior{Role: DetectRoleSender, PortsRandomNumber: 1000, SendDelayMs: 2000},
RecommandBehavior{Role: DetectRoleReceiver, ListenRandomPorts: 256, TTL: 4, PortsRangeNumber: 10},
),
lo.T2(
RecommandBehavior{Role: DetectRoleSender, PortsRandomNumber: 1000, SendDelayMs: 2000},
RecommandBehavior{Role: DetectRoleReceiver, ListenRandomPorts: 256, PortsRangeNumber: 10},
),
}
)
func getBehaviorByMode(mode int) []lo.Tuple2[RecommandBehavior, RecommandBehavior] {
switch mode {
case 0:
return mode0Behaviors
case 1:
return mode1Behaviors
case 2:
return mode2Behaviors
case 3:
return mode3Behaviors
case 4:
return mode4Behaviors
}
// default
return mode0Behaviors
}
func getBehaviorByModeAndIndex(mode int, index int) (RecommandBehavior, RecommandBehavior) {
behaviors := getBehaviorByMode(mode)
if index >= len(behaviors) {
return RecommandBehavior{}, RecommandBehavior{}
}
return behaviors[index].A, behaviors[index].B
}
func getBehaviorScoresByMode(mode int, defaultScore int) []*BehaviorScore {
return getBehaviorScoresByMode2(mode, defaultScore, defaultScore)
}
func getBehaviorScoresByMode2(mode int, senderScore, receiverScore int) []*BehaviorScore {
behaviors := getBehaviorByMode(mode)
scores := make([]*BehaviorScore, 0, len(behaviors))
for i := 0; i < len(behaviors); i++ {
score := receiverScore
if behaviors[i].A.Role == DetectRoleSender {
score = senderScore
}
scores = append(scores, &BehaviorScore{Mode: mode, Index: i, Score: score})
}
return scores
}
type RecommandBehavior struct {
Role string
TTL int
SendDelayMs int
PortsRangeNumber int
PortsRandomNumber int
ListenRandomPorts int
}
type MakeHoleRecords struct {
mu sync.Mutex
scores []*BehaviorScore
LastUpdateTime time.Time
}
func NewMakeHoleRecords(c, v *NatFeature) *MakeHoleRecords {
scores := []*BehaviorScore{}
easyCount, hardCount, portsChangedRegularCount := ClassifyFeatureCount([]*NatFeature{c, v})
appendMode0 := func() {
switch {
case c.PublicNetwork:
scores = append(scores, getBehaviorScoresByMode2(DetectMode0, 0, 1)...)
case v.PublicNetwork:
scores = append(scores, getBehaviorScoresByMode2(DetectMode0, 1, 0)...)
default:
scores = append(scores, getBehaviorScoresByMode(DetectMode0, 0)...)
}
}
switch {
case easyCount == 2:
appendMode0()
case hardCount == 1 && portsChangedRegularCount == 1:
scores = append(scores, getBehaviorScoresByMode(DetectMode1, 0)...)
scores = append(scores, getBehaviorScoresByMode(DetectMode2, 0)...)
appendMode0()
case hardCount == 1 && portsChangedRegularCount == 0:
scores = append(scores, getBehaviorScoresByMode(DetectMode2, 0)...)
scores = append(scores, getBehaviorScoresByMode(DetectMode1, 0)...)
appendMode0()
case hardCount == 2 && portsChangedRegularCount == 2:
scores = append(scores, getBehaviorScoresByMode(DetectMode3, 0)...)
scores = append(scores, getBehaviorScoresByMode(DetectMode4, 0)...)
case hardCount == 2 && portsChangedRegularCount == 1:
scores = append(scores, getBehaviorScoresByMode(DetectMode4, 0)...)
default:
// hard to make hole, just trying it out.
scores = append(scores, getBehaviorScoresByMode(DetectMode0, 1)...)
scores = append(scores, getBehaviorScoresByMode(DetectMode1, 1)...)
scores = append(scores, getBehaviorScoresByMode(DetectMode3, 1)...)
}
return &MakeHoleRecords{scores: scores, LastUpdateTime: time.Now()}
}
func (mhr *MakeHoleRecords) ReportSuccess(mode int, index int) {
mhr.mu.Lock()
defer mhr.mu.Unlock()
mhr.LastUpdateTime = time.Now()
for i := range mhr.scores {
score := mhr.scores[i]
if score.Mode != mode || score.Index != index {
continue
}
score.Score += 2
score.Score = lo.Min([]int{score.Score, 10})
return
}
}
func (mhr *MakeHoleRecords) Recommand() (mode, index int) {
mhr.mu.Lock()
defer mhr.mu.Unlock()
maxScore := lo.MaxBy(mhr.scores, func(item, max *BehaviorScore) bool {
return item.Score > max.Score
})
if maxScore == nil {
return 0, 0
}
maxScore.Score--
mhr.LastUpdateTime = time.Now()
return maxScore.Mode, maxScore.Index
}
type BehaviorScore struct {
Mode int
Index int
// between -10 and 10
Score int
}
type Analyzer struct {
// key is client ip + visitor ip
records map[string]*MakeHoleRecords
dataReserveDuration time.Duration
mu sync.Mutex
}
func NewAnalyzer(dataReserveDuration time.Duration) *Analyzer {
return &Analyzer{
records: make(map[string]*MakeHoleRecords),
dataReserveDuration: dataReserveDuration,
}
}
func (a *Analyzer) GetRecommandBehaviors(key string, c, v *NatFeature) (mode, index int, _ RecommandBehavior, _ RecommandBehavior) {
a.mu.Lock()
records, ok := a.records[key]
if !ok {
records = NewMakeHoleRecords(c, v)
a.records[key] = records
}
a.mu.Unlock()
mode, index = records.Recommand()
cBehavior, vBehavior := getBehaviorByModeAndIndex(mode, index)
switch mode {
case DetectMode1:
// HardNAT is always the sender
if c.NatType == EasyNAT {
cBehavior, vBehavior = vBehavior, cBehavior
}
case DetectMode2:
// HardNAT is always the receiver
if c.NatType == HardNAT {
cBehavior, vBehavior = vBehavior, cBehavior
}
case DetectMode4:
// Regular ports changes is always the sender
if !c.RegularPortsChange {
cBehavior, vBehavior = vBehavior, cBehavior
}
}
return mode, index, cBehavior, vBehavior
}
func (a *Analyzer) ReportSuccess(key string, mode, index int) {
a.mu.Lock()
records, ok := a.records[key]
a.mu.Unlock()
if !ok {
return
}
records.ReportSuccess(mode, index)
}
func (a *Analyzer) Clean() (int, int) {
now := time.Now()
total := 0
count := 0
// cleanup 10w records may take 5ms
a.mu.Lock()
defer a.mu.Unlock()
total = len(a.records)
// clean up records that have not been used for a period of time.
for key, records := range a.records {
if now.Sub(records.LastUpdateTime) > a.dataReserveDuration {
delete(a.records, key)
count++
}
}
return count, total
}

View File

@@ -17,6 +17,9 @@ package nathole
import (
"fmt"
"net"
"strconv"
"github.com/samber/lo"
)
const (
@@ -29,46 +32,96 @@ const (
BehaviorBothChanged = "BehaviorBothChanged"
)
// ClassifyNATType classify NAT type by given addresses.
func ClassifyNATType(addresses []string) (string, string, error) {
type NatFeature struct {
NatType string
Behavior string
PortsDifference int
RegularPortsChange bool
PublicNetwork bool
}
func ClassifyNATFeature(addresses []string, localIPs []string) (*NatFeature, error) {
if len(addresses) <= 1 {
return "", "", fmt.Errorf("not enough addresses")
return nil, fmt.Errorf("not enough addresses")
}
natFeature := &NatFeature{}
ipChanged := false
portChanged := false
var baseIP, basePort string
var portMax, portMin int
for _, addr := range addresses {
ip, port, err := net.SplitHostPort(addr)
if err != nil {
return "", "", err
return nil, err
}
portNum, err := strconv.Atoi(port)
if err != nil {
return nil, err
}
if lo.Contains(localIPs, ip) {
natFeature.PublicNetwork = true
}
if baseIP == "" {
baseIP = ip
basePort = port
portMax = portNum
portMin = portNum
continue
}
if portNum > portMax {
portMax = portNum
}
if portNum < portMin {
portMin = portNum
}
if baseIP != ip {
ipChanged = true
}
if basePort != port {
portChanged = true
}
}
if ipChanged && portChanged {
break
}
natFeature.PortsDifference = portMax - portMin
if natFeature.PortsDifference <= 10 && natFeature.PortsDifference >= 1 {
natFeature.RegularPortsChange = true
}
switch {
case ipChanged && portChanged:
return HardNAT, BehaviorBothChanged, nil
natFeature.NatType = HardNAT
natFeature.Behavior = BehaviorBothChanged
case ipChanged:
return HardNAT, BehaviorIPChanged, nil
natFeature.NatType = HardNAT
natFeature.Behavior = BehaviorIPChanged
case portChanged:
return HardNAT, BehaviorPortChanged, nil
natFeature.NatType = HardNAT
natFeature.Behavior = BehaviorPortChanged
default:
return EasyNAT, BehaviorNoChange, nil
natFeature.NatType = EasyNAT
natFeature.Behavior = BehaviorNoChange
}
return natFeature, nil
}
func ClassifyFeatureCount(features []*NatFeature) (int, int, int) {
easyCount := 0
hardCount := 0
// for HardNAT
portsChangedRegularCount := 0
for _, feature := range features {
if feature.NatType == EasyNAT {
easyCount++
continue
}
hardCount++
if feature.RegularPortsChange {
portsChangedRegularCount++
}
}
return easyCount, hardCount, portsChangedRegularCount
}

382
pkg/nathole/controller.go Normal file
View File

@@ -0,0 +1,382 @@
// Copyright 2023 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package nathole
import (
"context"
"crypto/md5"
"encoding/hex"
"fmt"
"net"
"strconv"
"sync"
"time"
"github.com/fatedier/golib/errors"
"github.com/samber/lo"
"golang.org/x/sync/errgroup"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/transport"
"github.com/fatedier/frp/pkg/util/log"
"github.com/fatedier/frp/pkg/util/util"
)
// NatHoleTimeout seconds.
var NatHoleTimeout int64 = 10
func NewTransactionID() string {
id, _ := util.RandID()
return fmt.Sprintf("%d%s", time.Now().Unix(), id)
}
type ClientCfg struct {
name string
sk string
sidCh chan string
}
type Session struct {
sid string
analysisKey string
recommandMode int
recommandIndex int
visitorMsg *msg.NatHoleVisitor
visitorTransporter transport.MessageTransporter
vResp *msg.NatHoleResp
vNatFeature *NatFeature
vBehavior RecommandBehavior
clientMsg *msg.NatHoleClient
clientTransporter transport.MessageTransporter
cResp *msg.NatHoleResp
cNatFeature *NatFeature
cBehavior RecommandBehavior
notifyCh chan struct{}
}
func (s *Session) genAnalysisKey() {
hash := md5.New()
vIPs := lo.Uniq(parseIPs(s.visitorMsg.MappedAddrs))
if len(vIPs) > 0 {
hash.Write([]byte(vIPs[0]))
}
hash.Write([]byte(s.vNatFeature.NatType))
hash.Write([]byte(s.vNatFeature.Behavior))
hash.Write([]byte(strconv.FormatBool(s.vNatFeature.RegularPortsChange)))
cIPs := lo.Uniq(parseIPs(s.clientMsg.MappedAddrs))
if len(cIPs) > 0 {
hash.Write([]byte(cIPs[0]))
}
hash.Write([]byte(s.cNatFeature.NatType))
hash.Write([]byte(s.cNatFeature.Behavior))
hash.Write([]byte(strconv.FormatBool(s.cNatFeature.RegularPortsChange)))
s.analysisKey = hex.EncodeToString(hash.Sum(nil))
}
type Controller struct {
clientCfgs map[string]*ClientCfg
sessions map[string]*Session
analyzer *Analyzer
mu sync.RWMutex
}
func NewController(analysisDataReserveDuration time.Duration) (*Controller, error) {
return &Controller{
clientCfgs: make(map[string]*ClientCfg),
sessions: make(map[string]*Session),
analyzer: NewAnalyzer(analysisDataReserveDuration),
}, nil
}
func (c *Controller) CleanWorker(ctx context.Context) {
ticker := time.NewTicker(time.Hour)
defer ticker.Stop()
for {
select {
case <-ticker.C:
start := time.Now()
count, total := c.analyzer.Clean()
log.Trace("clean %d/%d nathole analysis data, cost %v", count, total, time.Since(start))
case <-ctx.Done():
return
}
}
}
func (c *Controller) ListenClient(name string, sk string) chan string {
cfg := &ClientCfg{
name: name,
sk: sk,
sidCh: make(chan string),
}
c.mu.Lock()
defer c.mu.Unlock()
c.clientCfgs[name] = cfg
return cfg.sidCh
}
func (c *Controller) CloseClient(name string) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.clientCfgs, name)
}
func (c *Controller) GenSid() string {
t := time.Now().Unix()
id, _ := util.RandID()
return fmt.Sprintf("%d%s", t, id)
}
func (c *Controller) HandleVisitor(m *msg.NatHoleVisitor, transporter transport.MessageTransporter) {
if m.PreCheck {
_, ok := c.clientCfgs[m.ProxyName]
if !ok {
_ = transporter.Send(c.GenNatHoleResponse(m.TransactionID, nil, fmt.Sprintf("xtcp server for [%s] doesn't exist", m.ProxyName)))
} else {
_ = transporter.Send(c.GenNatHoleResponse(m.TransactionID, nil, ""))
}
return
}
sid := c.GenSid()
session := &Session{
sid: sid,
visitorMsg: m,
visitorTransporter: transporter,
notifyCh: make(chan struct{}, 1),
}
var (
clientCfg *ClientCfg
ok bool
)
err := func() error {
c.mu.Lock()
defer c.mu.Unlock()
clientCfg, ok = c.clientCfgs[m.ProxyName]
if !ok {
return fmt.Errorf("xtcp server for [%s] doesn't exist", m.ProxyName)
}
if m.SignKey != util.GetAuthKey(clientCfg.sk, m.Timestamp) {
return fmt.Errorf("xtcp connection of [%s] auth failed", m.ProxyName)
}
c.sessions[sid] = session
return nil
}()
if err != nil {
log.Warn("handle visitorMsg error: %v", err)
_ = transporter.Send(c.GenNatHoleResponse(m.TransactionID, nil, err.Error()))
return
}
log.Trace("handle visitor message, sid [%s]", sid)
defer func() {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.sessions, sid)
}()
if err := errors.PanicToError(func() {
clientCfg.sidCh <- sid
}); err != nil {
return
}
// wait for NatHoleClient message
select {
case <-session.notifyCh:
case <-time.After(time.Duration(NatHoleTimeout) * time.Second):
log.Debug("wait for NatHoleClient message timeout, sid [%s]", sid)
return
}
// Make hole-punching decisions based on the NAT information of the client and visitor.
vResp, cResp, err := c.analysis(session)
if err != nil {
log.Debug("sid [%s] analysis error: %v", err)
vResp = c.GenNatHoleResponse(session.visitorMsg.TransactionID, nil, err.Error())
cResp = c.GenNatHoleResponse(session.clientMsg.TransactionID, nil, err.Error())
}
session.cResp = cResp
session.vResp = vResp
// send response to visitor and client
var g errgroup.Group
g.Go(func() error {
// if it's sender, wait for a while to make sure the client has send the detect messages
if vResp.DetectBehavior.Role == "sender" {
time.Sleep(1 * time.Second)
}
_ = session.visitorTransporter.Send(vResp)
return nil
})
g.Go(func() error {
// if it's sender, wait for a while to make sure the client has send the detect messages
if cResp.DetectBehavior.Role == "sender" {
time.Sleep(1 * time.Second)
}
_ = session.clientTransporter.Send(cResp)
return nil
})
_ = g.Wait()
time.Sleep(time.Duration(cResp.DetectBehavior.ReadTimeoutMs+30000) * time.Millisecond)
}
func (c *Controller) HandleClient(m *msg.NatHoleClient, transporter transport.MessageTransporter) {
c.mu.RLock()
session, ok := c.sessions[m.Sid]
c.mu.RUnlock()
if !ok {
return
}
log.Trace("handle client message, sid [%s]", session.sid)
session.clientMsg = m
session.clientTransporter = transporter
select {
case session.notifyCh <- struct{}{}:
default:
}
}
func (c *Controller) HandleReport(m *msg.NatHoleReport) {
c.mu.RLock()
session, ok := c.sessions[m.Sid]
c.mu.RUnlock()
if !ok {
log.Trace("sid [%s] report make hole success: %v, but session not found", m.Sid, m.Success)
return
}
if m.Success {
c.analyzer.ReportSuccess(session.analysisKey, session.recommandMode, session.recommandIndex)
}
log.Info("sid [%s] report make hole success: %v, mode %v, index %v",
m.Sid, m.Success, session.recommandMode, session.recommandIndex)
}
func (c *Controller) GenNatHoleResponse(transactionID string, session *Session, errInfo string) *msg.NatHoleResp {
var sid string
if session != nil {
sid = session.sid
}
return &msg.NatHoleResp{
TransactionID: transactionID,
Sid: sid,
Error: errInfo,
}
}
// analysis analyzes the NAT type and behavior of the visitor and client, then makes hole-punching decisions.
// return the response to the visitor and client.
func (c *Controller) analysis(session *Session) (*msg.NatHoleResp, *msg.NatHoleResp, error) {
cm := session.clientMsg
vm := session.visitorMsg
cNatFeature, err := ClassifyNATFeature(cm.MappedAddrs, parseIPs(cm.AssistedAddrs))
if err != nil {
return nil, nil, fmt.Errorf("classify client nat feature error: %v", err)
}
vNatFeature, err := ClassifyNATFeature(vm.MappedAddrs, parseIPs(vm.AssistedAddrs))
if err != nil {
return nil, nil, fmt.Errorf("classify visitor nat feature error: %v", err)
}
session.cNatFeature = cNatFeature
session.vNatFeature = vNatFeature
session.genAnalysisKey()
mode, index, cBehavior, vBehavior := c.analyzer.GetRecommandBehaviors(session.analysisKey, cNatFeature, vNatFeature)
session.recommandMode = mode
session.recommandIndex = index
session.cBehavior = cBehavior
session.vBehavior = vBehavior
timeoutMs := lo.Max([]int{cBehavior.SendDelayMs, vBehavior.SendDelayMs}) + 5000
if cBehavior.ListenRandomPorts > 0 || vBehavior.ListenRandomPorts > 0 {
timeoutMs += 30000
}
protocol := vm.Protocol
vResp := &msg.NatHoleResp{
TransactionID: vm.TransactionID,
Sid: session.sid,
Protocol: protocol,
CandidateAddrs: lo.Uniq(cm.MappedAddrs),
AssistedAddrs: lo.Uniq(cm.AssistedAddrs),
DetectBehavior: msg.NatHoleDetectBehavior{
Mode: mode,
Role: vBehavior.Role,
TTL: vBehavior.TTL,
SendDelayMs: vBehavior.SendDelayMs,
ReadTimeoutMs: timeoutMs - vBehavior.SendDelayMs,
SendRandomPorts: vBehavior.PortsRandomNumber,
ListenRandomPorts: vBehavior.ListenRandomPorts,
CandidatePorts: getRangePorts(cm.MappedAddrs, cNatFeature.PortsDifference, vBehavior.PortsRangeNumber),
},
}
cResp := &msg.NatHoleResp{
TransactionID: cm.TransactionID,
Sid: session.sid,
Protocol: protocol,
CandidateAddrs: lo.Uniq(vm.MappedAddrs),
AssistedAddrs: lo.Uniq(vm.AssistedAddrs),
DetectBehavior: msg.NatHoleDetectBehavior{
Mode: mode,
Role: cBehavior.Role,
TTL: cBehavior.TTL,
SendDelayMs: cBehavior.SendDelayMs,
ReadTimeoutMs: timeoutMs - cBehavior.SendDelayMs,
SendRandomPorts: cBehavior.PortsRandomNumber,
ListenRandomPorts: cBehavior.ListenRandomPorts,
CandidatePorts: getRangePorts(vm.MappedAddrs, vNatFeature.PortsDifference, cBehavior.PortsRangeNumber),
},
}
log.Debug("sid [%s] visitor nat: %+v, candidateAddrs: %v; client nat: %+v, candidateAddrs: %v, protocol: %s",
session.sid, *vNatFeature, vm.MappedAddrs, *cNatFeature, cm.MappedAddrs, protocol)
log.Debug("sid [%s] visitor detect behavior: %+v", session.sid, vResp.DetectBehavior)
log.Debug("sid [%s] client detect behavior: %+v", session.sid, cResp.DetectBehavior)
return vResp, cResp, nil
}
func getRangePorts(addrs []string, difference, maxNumber int) []msg.PortsRange {
if maxNumber <= 0 {
return nil
}
addr, err := lo.Last(addrs)
if err != nil {
return nil
}
var ports []msg.PortsRange
_, portStr, err := net.SplitHostPort(addr)
if err != nil {
return nil
}
port, err := strconv.Atoi(portStr)
if err != nil {
return nil
}
ports = append(ports, msg.PortsRange{
From: lo.Max([]int{port - difference - 5, port - maxNumber, 1}),
To: lo.Min([]int{port + difference + 5, port + maxNumber, 65535}),
})
return ports
}

View File

@@ -20,8 +20,6 @@ import (
"time"
"github.com/pion/stun"
"github.com/fatedier/frp/pkg/msg"
)
var responseTimeout = 3 * time.Second
@@ -31,35 +29,27 @@ type Message struct {
Addr string
}
func Discover(serverAddress string, stunServers []string, key []byte) ([]string, error) {
// If the localAddr is empty, it will listen on a random port.
func Discover(stunServers []string, localAddr string) ([]string, net.Addr, error) {
// create a discoverConn and get response from messageChan
discoverConn, err := listen()
discoverConn, err := listen(localAddr)
if err != nil {
return nil, err
return nil, nil, err
}
defer discoverConn.Close()
go discoverConn.readLoop()
addresses := make([]string, 0, len(stunServers)+1)
if serverAddress != "" {
// get external address from frp server
externalAddr, err := discoverConn.discoverFromServer(serverAddress, key)
if err != nil {
return nil, err
}
addresses = append(addresses, externalAddr)
}
addresses := make([]string, 0, len(stunServers))
for _, addr := range stunServers {
// get external address from stun server
externalAddrs, err := discoverConn.discoverFromStunServer(addr)
if err != nil {
return nil, err
return nil, nil, err
}
addresses = append(addresses, externalAddrs...)
}
return addresses, nil
return addresses, discoverConn.localAddr, nil
}
type stunResponse struct {
@@ -74,8 +64,16 @@ type discoverConn struct {
messageChan chan *Message
}
func listen() (*discoverConn, error) {
conn, err := net.ListenUDP("udp4", nil)
func listen(localAddr string) (*discoverConn, error) {
var local *net.UDPAddr
if localAddr != "" {
addr, err := net.ResolveUDPAddr("udp4", localAddr)
if err != nil {
return nil, err
}
local = addr
}
conn, err := net.ListenUDP("udp4", local)
if err != nil {
return nil, err
}
@@ -159,43 +157,6 @@ func (c *discoverConn) doSTUNRequest(addr string) (*stunResponse, error) {
return resp, nil
}
func (c *discoverConn) discoverFromServer(serverAddress string, key []byte) (string, error) {
addr, err := net.ResolveUDPAddr("udp4", serverAddress)
if err != nil {
return "", err
}
m := &msg.NatHoleBinding{
TransactionID: NewTransactionID(),
}
buf, err := EncodeMessage(m, key)
if err != nil {
return "", err
}
if _, err := c.conn.WriteTo(buf, addr); err != nil {
return "", err
}
var respMsg msg.NatHoleBindingResp
select {
case rawMsg := <-c.messageChan:
if err := DecodeMessageInto(rawMsg.Body, key, &respMsg); err != nil {
return "", err
}
case <-time.After(responseTimeout):
return "", fmt.Errorf("wait response from frp server timeout")
}
if respMsg.TransactionID == "" {
return "", fmt.Errorf("error format: no transaction id found")
}
if respMsg.Error != "" {
return "", fmt.Errorf("get externalAddr from frp server error: %s", respMsg.Error)
}
return respMsg.Address, nil
}
func (c *discoverConn) discoverFromStunServer(addr string) ([]string, error) {
resp, err := c.doSTUNRequest(addr)
if err != nil {

View File

@@ -15,249 +15,426 @@
package nathole
import (
"bytes"
"context"
"fmt"
"math/rand"
"net"
"sync"
"strconv"
"strings"
"time"
"github.com/fatedier/golib/crypto"
"github.com/fatedier/golib/errors"
"github.com/fatedier/golib/pool"
"github.com/samber/lo"
"golang.org/x/net/ipv4"
"k8s.io/apimachinery/pkg/util/sets"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/util/log"
"github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/transport"
"github.com/fatedier/frp/pkg/util/xlog"
)
// NatHoleTimeout seconds.
var NatHoleTimeout int64 = 10
var (
// mode 0: simple detect mode, usually for both EasyNAT or HardNAT & EasyNAT(Public Network)
// a. receiver sends detect message with low TTL
// b. sender sends normal detect message to receiver
// c. receiver receives detect message and sends back a message to sender
//
// mode 1: For HardNAT & EasyNAT, send detect messages to multiple guessed ports.
// Usually applicable to scenarios where port changes are regular.
// Most of the steps are the same as mode 0, but EasyNAT is fixed as the receiver and will send detect messages
// with low TTL to multiple guessed ports of the sender.
//
// mode 2: For HardNAT & EasyNAT, ports changes are not regular.
// a. HardNAT machine will listen on multiple ports and send detect messages with low TTL to EasyNAT machine
// b. EasyNAT machine will send detect messages to random ports of HardNAT machine.
//
// mode 3: For HardNAT & HardNAT, both changes in the ports are regular.
// Most of the steps are the same as mode 1, but the sender also needs to send detect messages to multiple guessed
// ports of the receiver.
//
// mode 4: For HardNAT & HardNAT, one of the changes in the ports is regular.
// Regular port changes are usually on the sender side.
// a. Receiver listens on multiple ports and sends detect messages with low TTL to the sender's guessed range ports.
// b. Sender sends detect messages to random ports of the receiver.
SupportedModes = []int{DetectMode0, DetectMode1, DetectMode2, DetectMode3, DetectMode4}
SupportedRoles = []string{DetectRoleSender, DetectRoleReceiver}
func NewTransactionID() string {
id, _ := util.RandID()
return fmt.Sprintf("%d%s", time.Now().Unix(), id)
DetectMode0 = 0
DetectMode1 = 1
DetectMode2 = 2
DetectMode3 = 3
DetectMode4 = 4
DetectRoleSender = "sender"
DetectRoleReceiver = "receiver"
)
type PrepareResult struct {
Addrs []string
AssistedAddrs []string
ListenConn *net.UDPConn
NatType string
Behavior string
}
type SidRequest struct {
Sid string
NotifyCh chan struct{}
}
// PreCheck is used to check if the proxy is ready for penetration.
// Call this function before calling Prepare to avoid unnecessary preparation work.
func PreCheck(
ctx context.Context, transporter transport.MessageTransporter,
proxyName string, timeout time.Duration,
) error {
timeoutCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
type Controller struct {
listener *net.UDPConn
clientCfgs map[string]*ClientCfg
sessions map[string]*Session
encryptionKey []byte
mu sync.RWMutex
}
func NewController(udpBindAddr string, encryptionKey []byte) (nc *Controller, err error) {
addr, err := net.ResolveUDPAddr("udp", udpBindAddr)
var natHoleRespMsg *msg.NatHoleResp
transactionID := NewTransactionID()
m, err := transporter.Do(timeoutCtx, &msg.NatHoleVisitor{
TransactionID: transactionID,
ProxyName: proxyName,
PreCheck: true,
}, transactionID, msg.TypeNameNatHoleResp)
if err != nil {
return nil, err
return fmt.Errorf("get natHoleRespMsg error: %v", err)
}
lconn, err := net.ListenUDP("udp", addr)
mm, ok := m.(*msg.NatHoleResp)
if !ok {
return fmt.Errorf("get natHoleRespMsg error: invalid message type")
}
natHoleRespMsg = mm
if natHoleRespMsg.Error != "" {
return fmt.Errorf("%s", natHoleRespMsg.Error)
}
return nil
}
// Prepare is used to do some preparation work before penetration.
func Prepare(stunServers []string) (*PrepareResult, error) {
// discover for Nat type
addrs, localAddr, err := Discover(stunServers, "")
if err != nil {
return nil, err
return nil, fmt.Errorf("discover error: %v", err)
}
nc = &Controller{
listener: lconn,
clientCfgs: make(map[string]*ClientCfg),
sessions: make(map[string]*Session),
encryptionKey: encryptionKey,
if len(addrs) < 2 {
return nil, fmt.Errorf("discover error: not enough addresses")
}
return nc, nil
localIPs, _ := ListLocalIPsForNatHole(10)
natFeature, err := ClassifyNATFeature(addrs, localIPs)
if err != nil {
return nil, fmt.Errorf("classify nat feature error: %v", err)
}
laddr, err := net.ResolveUDPAddr("udp4", localAddr.String())
if err != nil {
return nil, fmt.Errorf("resolve local udp addr error: %v", err)
}
listenConn, err := net.ListenUDP("udp4", laddr)
if err != nil {
return nil, fmt.Errorf("listen local udp addr error: %v", err)
}
assistedAddrs := make([]string, 0, len(localIPs))
for _, ip := range localIPs {
assistedAddrs = append(assistedAddrs, net.JoinHostPort(ip, strconv.Itoa(laddr.Port)))
}
return &PrepareResult{
Addrs: addrs,
AssistedAddrs: assistedAddrs,
ListenConn: listenConn,
NatType: natFeature.NatType,
Behavior: natFeature.Behavior,
}, nil
}
func (nc *Controller) ListenClient(name string, sk string) (sidCh chan *SidRequest) {
clientCfg := &ClientCfg{
Name: name,
Sk: sk,
SidCh: make(chan *SidRequest),
// ExchangeInfo is used to exchange information between client and visitor.
// 1. Send input message to server by msgTransporter.
// 2. Server will gather information from client and visitor and analyze it. Then send back a NatHoleResp message to them to tell them how to do next.
// 3. Receive NatHoleResp message from server.
func ExchangeInfo(
ctx context.Context, transporter transport.MessageTransporter,
laneKey string, m msg.Message, timeout time.Duration,
) (*msg.NatHoleResp, error) {
timeoutCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
var natHoleRespMsg *msg.NatHoleResp
m, err := transporter.Do(timeoutCtx, m, laneKey, msg.TypeNameNatHoleResp)
if err != nil {
return nil, fmt.Errorf("get natHoleRespMsg error: %v", err)
}
nc.mu.Lock()
nc.clientCfgs[name] = clientCfg
nc.mu.Unlock()
return clientCfg.SidCh
mm, ok := m.(*msg.NatHoleResp)
if !ok {
return nil, fmt.Errorf("get natHoleRespMsg error: invalid message type")
}
natHoleRespMsg = mm
if natHoleRespMsg.Error != "" {
return nil, fmt.Errorf("natHoleRespMsg get error info: %s", natHoleRespMsg.Error)
}
if len(natHoleRespMsg.CandidateAddrs) == 0 {
return nil, fmt.Errorf("natHoleRespMsg get empty candidate addresses")
}
return natHoleRespMsg, nil
}
func (nc *Controller) CloseClient(name string) {
nc.mu.Lock()
defer nc.mu.Unlock()
delete(nc.clientCfgs, name)
// MakeHole is used to make a NAT hole between client and visitor.
func MakeHole(ctx context.Context, listenConn *net.UDPConn, m *msg.NatHoleResp, key []byte) (*net.UDPConn, *net.UDPAddr, error) {
xl := xlog.FromContextSafe(ctx)
transactionID := NewTransactionID()
sendToRangePortsFunc := func(conn *net.UDPConn, addr string) error {
return sendSidMessage(ctx, conn, m.Sid, transactionID, addr, key, m.DetectBehavior.TTL)
}
listenConns := []*net.UDPConn{listenConn}
var detectAddrs []string
if m.DetectBehavior.Role == DetectRoleSender {
// sender
if m.DetectBehavior.SendDelayMs > 0 {
time.Sleep(time.Duration(m.DetectBehavior.SendDelayMs) * time.Millisecond)
}
detectAddrs = m.AssistedAddrs
detectAddrs = append(detectAddrs, m.CandidateAddrs...)
} else {
// receiver
if len(m.DetectBehavior.CandidatePorts) == 0 {
detectAddrs = m.CandidateAddrs
}
if m.DetectBehavior.ListenRandomPorts > 0 {
for i := 0; i < m.DetectBehavior.ListenRandomPorts; i++ {
tmpConn, err := net.ListenUDP("udp4", nil)
if err != nil {
xl.Warn("listen random udp addr error: %v", err)
continue
}
listenConns = append(listenConns, tmpConn)
}
}
}
detectAddrs = lo.Uniq(detectAddrs)
for _, detectAddr := range detectAddrs {
for _, conn := range listenConns {
if err := sendSidMessage(ctx, conn, m.Sid, transactionID, detectAddr, key, m.DetectBehavior.TTL); err != nil {
xl.Trace("send sid message from %s to %s error: %v", conn.LocalAddr(), detectAddr, err)
}
}
}
if len(m.DetectBehavior.CandidatePorts) > 0 {
for _, conn := range listenConns {
sendSidMessageToRangePorts(ctx, conn, m.CandidateAddrs, m.DetectBehavior.CandidatePorts, sendToRangePortsFunc)
}
}
if m.DetectBehavior.SendRandomPorts > 0 {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
for i := range listenConns {
go sendSidMessageToRandomPorts(ctx, listenConns[i], m.CandidateAddrs, m.DetectBehavior.SendRandomPorts, sendToRangePortsFunc)
}
}
timeout := 5 * time.Second
if m.DetectBehavior.ReadTimeoutMs > 0 {
timeout = time.Duration(m.DetectBehavior.ReadTimeoutMs) * time.Millisecond
}
if len(listenConns) == 1 {
raddr, err := waitDetectMessage(ctx, listenConns[0], m.Sid, key, timeout, m.DetectBehavior.Role)
if err != nil {
return nil, nil, fmt.Errorf("wait detect message error: %v", err)
}
return listenConns[0], raddr, nil
}
type result struct {
lConn *net.UDPConn
raddr *net.UDPAddr
}
resultCh := make(chan result)
for _, conn := range listenConns {
go func(lConn *net.UDPConn) {
addr, err := waitDetectMessage(ctx, lConn, m.Sid, key, timeout, m.DetectBehavior.Role)
if err != nil {
lConn.Close()
return
}
select {
case resultCh <- result{lConn: lConn, raddr: addr}:
default:
lConn.Close()
}
}(conn)
}
select {
case result := <-resultCh:
return result.lConn, result.raddr, nil
case <-time.After(timeout):
return nil, nil, fmt.Errorf("wait detect message timeout")
case <-ctx.Done():
return nil, nil, fmt.Errorf("wait detect message canceled")
}
}
func (nc *Controller) Run() {
func waitDetectMessage(
ctx context.Context, conn *net.UDPConn, sid string, key []byte,
timeout time.Duration, role string,
) (*net.UDPAddr, error) {
xl := xlog.FromContextSafe(ctx)
for {
buf := pool.GetBuf(1024)
n, raddr, err := nc.listener.ReadFromUDP(buf)
_ = conn.SetReadDeadline(time.Now().Add(timeout))
n, raddr, err := conn.ReadFromUDP(buf)
_ = conn.SetReadDeadline(time.Time{})
if err != nil {
log.Warn("nat hole listener read from udp error: %v", err)
return
return nil, err
}
plain, err := crypto.Decode(buf[:n], nc.encryptionKey)
if err != nil {
log.Warn("nathole listener decode from %s error: %v", raddr.String(), err)
continue
}
rawMsg, err := msg.ReadMsg(bytes.NewReader(plain))
if err != nil {
log.Warn("read nat hole message error: %v", err)
continue
}
switch m := rawMsg.(type) {
case *msg.NatHoleBinding:
go nc.HandleBinding(m, raddr)
case *msg.NatHoleVisitor:
go nc.HandleVisitor(m, raddr)
case *msg.NatHoleClient:
go nc.HandleClient(m, raddr)
default:
log.Trace("unknown nat hole message type")
xl.Debug("get udp message local %s, from %s", conn.LocalAddr(), raddr)
var m msg.NatHoleSid
if err := DecodeMessageInto(buf[:n], key, &m); err != nil {
xl.Warn("decode sid message error: %v", err)
continue
}
pool.PutBuf(buf)
}
}
func (nc *Controller) GenSid() string {
t := time.Now().Unix()
id, _ := util.RandID()
return fmt.Sprintf("%d%s", t, id)
}
func (nc *Controller) HandleBinding(m *msg.NatHoleBinding, raddr *net.UDPAddr) {
log.Trace("handle binding message from %s", raddr.String())
resp := &msg.NatHoleBindingResp{
TransactionID: m.TransactionID,
Address: raddr.String(),
}
plain, err := msg.Pack(resp)
if err != nil {
log.Error("pack nat hole binding response error: %v", err)
return
}
buf, err := crypto.Encode(plain, nc.encryptionKey)
if err != nil {
log.Error("encode nat hole binding response error: %v", err)
return
}
_, err = nc.listener.WriteToUDP(buf, raddr)
if err != nil {
log.Error("write nat hole binding response to %s error: %v", raddr.String(), err)
return
}
}
func (nc *Controller) HandleVisitor(m *msg.NatHoleVisitor, raddr *net.UDPAddr) {
sid := nc.GenSid()
session := &Session{
Sid: sid,
VisitorAddr: raddr,
NotifyCh: make(chan struct{}),
}
nc.mu.Lock()
clientCfg, ok := nc.clientCfgs[m.ProxyName]
if !ok {
nc.mu.Unlock()
errInfo := fmt.Sprintf("xtcp server for [%s] doesn't exist", m.ProxyName)
log.Debug(errInfo)
_, _ = nc.listener.WriteToUDP(nc.GenNatHoleResponse(nil, errInfo), raddr)
return
}
if m.SignKey != util.GetAuthKey(clientCfg.Sk, m.Timestamp) {
nc.mu.Unlock()
errInfo := fmt.Sprintf("xtcp connection of [%s] auth failed", m.ProxyName)
log.Debug(errInfo)
_, _ = nc.listener.WriteToUDP(nc.GenNatHoleResponse(nil, errInfo), raddr)
return
}
nc.sessions[sid] = session
nc.mu.Unlock()
log.Trace("handle visitor message, sid [%s]", sid)
defer func() {
nc.mu.Lock()
delete(nc.sessions, sid)
nc.mu.Unlock()
}()
err := errors.PanicToError(func() {
clientCfg.SidCh <- &SidRequest{
Sid: sid,
NotifyCh: session.NotifyCh,
if m.Sid != sid {
xl.Warn("get sid message with wrong sid: %s, expect: %s", m.Sid, sid)
continue
}
})
if !m.Response {
// only wait for response messages if we are a sender
if role == DetectRoleSender {
continue
}
m.Response = true
buf2, err := EncodeMessage(&m, key)
if err != nil {
xl.Warn("encode sid message error: %v", err)
continue
}
_, _ = conn.WriteToUDP(buf2, raddr)
}
return raddr, nil
}
}
func sendSidMessage(
ctx context.Context, conn *net.UDPConn,
sid string, transactionID string, addr string, key []byte, ttl int,
) error {
xl := xlog.FromContextSafe(ctx)
ttlStr := ""
if ttl > 0 {
ttlStr = fmt.Sprintf(" with ttl %d", ttl)
}
xl.Trace("send sid message from %s to %s%s", conn.LocalAddr(), addr, ttlStr)
raddr, err := net.ResolveUDPAddr("udp4", addr)
if err != nil {
return
return err
}
// Wait client connections.
select {
case <-session.NotifyCh:
resp := nc.GenNatHoleResponse(session, "")
log.Trace("send nat hole response to visitor")
_, _ = nc.listener.WriteToUDP(resp, raddr)
case <-time.After(time.Duration(NatHoleTimeout) * time.Second):
return
if transactionID == "" {
transactionID = NewTransactionID()
}
}
func (nc *Controller) HandleClient(m *msg.NatHoleClient, raddr *net.UDPAddr) {
nc.mu.RLock()
session, ok := nc.sessions[m.Sid]
nc.mu.RUnlock()
if !ok {
return
m := &msg.NatHoleSid{
TransactionID: transactionID,
Sid: sid,
Response: false,
Nonce: strings.Repeat("0", rand.Intn(20)),
}
log.Trace("handle client message, sid [%s]", session.Sid)
session.ClientAddr = raddr
resp := nc.GenNatHoleResponse(session, "")
log.Trace("send nat hole response to client")
_, _ = nc.listener.WriteToUDP(resp, raddr)
}
func (nc *Controller) GenNatHoleResponse(session *Session, errInfo string) []byte {
var (
sid string
visitorAddr string
clientAddr string
)
if session != nil {
sid = session.Sid
visitorAddr = session.VisitorAddr.String()
clientAddr = session.ClientAddr.String()
}
m := &msg.NatHoleResp{
Sid: sid,
VisitorAddr: visitorAddr,
ClientAddr: clientAddr,
Error: errInfo,
}
b := bytes.NewBuffer(nil)
err := msg.WriteMsg(b, m)
buf, err := EncodeMessage(m, key)
if err != nil {
return []byte("")
return err
}
return b.Bytes()
if ttl > 0 {
uConn := ipv4.NewConn(conn)
original, err := uConn.TTL()
if err != nil {
xl.Trace("get ttl error %v", err)
return err
}
xl.Trace("original ttl %d", original)
err = uConn.SetTTL(ttl)
if err != nil {
xl.Trace("set ttl error %v", err)
} else {
defer func() {
_ = uConn.SetTTL(original)
}()
}
}
if _, err := conn.WriteToUDP(buf, raddr); err != nil {
return err
}
return nil
}
type Session struct {
Sid string
VisitorAddr *net.UDPAddr
ClientAddr *net.UDPAddr
NotifyCh chan struct{}
func sendSidMessageToRangePorts(
ctx context.Context, conn *net.UDPConn, addrs []string, ports []msg.PortsRange,
sendFunc func(*net.UDPConn, string) error,
) {
xl := xlog.FromContextSafe(ctx)
for _, ip := range lo.Uniq(parseIPs(addrs)) {
for _, portsRange := range ports {
for i := portsRange.From; i <= portsRange.To; i++ {
detectAddr := net.JoinHostPort(ip, strconv.Itoa(i))
if err := sendFunc(conn, detectAddr); err != nil {
xl.Trace("send sid message from %s to %s error: %v", conn.LocalAddr(), detectAddr, err)
}
time.Sleep(5 * time.Millisecond)
}
}
}
}
type ClientCfg struct {
Name string
Sk string
SidCh chan *SidRequest
func sendSidMessageToRandomPorts(
ctx context.Context, conn *net.UDPConn, addrs []string, count int,
sendFunc func(*net.UDPConn, string) error,
) {
xl := xlog.FromContextSafe(ctx)
used := sets.New[int]()
getUnusedPort := func() int {
for i := 0; i < 10; i++ {
port := rand.Intn(65535-1024) + 1024
if !used.Has(port) {
used.Insert(port)
return port
}
}
return 0
}
for i := 0; i < count; i++ {
select {
case <-ctx.Done():
return
default:
}
port := getUnusedPort()
if port == 0 {
continue
}
for _, ip := range lo.Uniq(parseIPs(addrs)) {
detectAddr := net.JoinHostPort(ip, strconv.Itoa(port))
if err := sendFunc(conn, detectAddr); err != nil {
xl.Trace("send sid message from %s to %s error: %v", conn.LocalAddr(), detectAddr, err)
}
time.Sleep(time.Millisecond * 15)
}
}
}
func parseIPs(addrs []string) []string {
var ips []string
for _, addr := range addrs {
if ip, _, err := net.SplitHostPort(addr); err == nil {
ips = append(ips, ip)
}
}
return ips
}

View File

@@ -16,6 +16,7 @@ package nathole
import (
"bytes"
"fmt"
"net"
"strconv"
@@ -63,3 +64,49 @@ func (s *ChangedAddress) GetFrom(m *stun.Message) error {
func (s *ChangedAddress) String() string {
return net.JoinHostPort(s.IP.String(), strconv.Itoa(s.Port))
}
func ListAllLocalIPs() ([]net.IP, error) {
addrs, err := net.InterfaceAddrs()
if err != nil {
return nil, err
}
ips := make([]net.IP, 0, len(addrs))
for _, addr := range addrs {
ip, _, err := net.ParseCIDR(addr.String())
if err != nil {
continue
}
ips = append(ips, ip)
}
return ips, nil
}
func ListLocalIPsForNatHole(max int) ([]string, error) {
if max <= 0 {
return nil, fmt.Errorf("max must be greater than 0")
}
ips, err := ListAllLocalIPs()
if err != nil {
return nil, err
}
filtered := make([]string, 0, max)
for _, ip := range ips {
if len(filtered) >= max {
break
}
// ignore ipv6 address
if ip.To4() == nil {
continue
}
// ignore localhost IP
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
continue
}
filtered = append(filtered, ip.String())
}
return filtered, nil
}