use std slices package (#4008)

This commit is contained in:
fatedier
2024-02-20 12:01:41 +08:00
committed by GitHub
parent b6361fb143
commit 3e0c78233a
22 changed files with 75 additions and 84 deletions

View File

@@ -15,6 +15,8 @@
package nathole
import (
"cmp"
"slices"
"sync"
"time"
@@ -233,12 +235,12 @@ func (mhr *MakeHoleRecords) Recommand() (mode, index int) {
mhr.mu.Lock()
defer mhr.mu.Unlock()
maxScore := lo.MaxBy(mhr.scores, func(item, max *BehaviorScore) bool {
return item.Score > max.Score
})
if maxScore == nil {
if len(mhr.scores) == 0 {
return 0, 0
}
maxScore := slices.MaxFunc(mhr.scores, func(a, b *BehaviorScore) int {
return cmp.Compare(a.Score, b.Score)
})
maxScore.Score--
mhr.LastUpdateTime = time.Now()
return maxScore.Mode, maxScore.Index

View File

@@ -17,9 +17,8 @@ package nathole
import (
"fmt"
"net"
"slices"
"strconv"
"github.com/samber/lo"
)
const (
@@ -59,7 +58,7 @@ func ClassifyNATFeature(addresses []string, localIPs []string) (*NatFeature, err
if err != nil {
return nil, err
}
if lo.Contains(localIPs, ip) {
if slices.Contains(localIPs, ip) {
natFeature.PublicNetwork = true
}

View File

@@ -20,6 +20,7 @@ import (
"encoding/hex"
"fmt"
"net"
"slices"
"strconv"
"sync"
"time"
@@ -72,7 +73,7 @@ type Session struct {
func (s *Session) genAnalysisKey() {
hash := md5.New()
vIPs := lo.Uniq(parseIPs(s.visitorMsg.MappedAddrs))
vIPs := slices.Compact(parseIPs(s.visitorMsg.MappedAddrs))
if len(vIPs) > 0 {
hash.Write([]byte(vIPs[0]))
}
@@ -80,7 +81,7 @@ func (s *Session) genAnalysisKey() {
hash.Write([]byte(s.vNatFeature.Behavior))
hash.Write([]byte(strconv.FormatBool(s.vNatFeature.RegularPortsChange)))
cIPs := lo.Uniq(parseIPs(s.clientMsg.MappedAddrs))
cIPs := slices.Compact(parseIPs(s.clientMsg.MappedAddrs))
if len(cIPs) > 0 {
hash.Write([]byte(cIPs[0]))
}
@@ -156,7 +157,7 @@ func (c *Controller) HandleVisitor(m *msg.NatHoleVisitor, transporter transport.
_ = transporter.Send(c.GenNatHoleResponse(m.TransactionID, nil, fmt.Sprintf("xtcp server for [%s] doesn't exist", m.ProxyName)))
return
}
if !lo.Contains(cfg.allowUsers, visitorUser) && !lo.Contains(cfg.allowUsers, "*") {
if !slices.Contains(cfg.allowUsers, visitorUser) && !slices.Contains(cfg.allowUsers, "*") {
_ = transporter.Send(c.GenNatHoleResponse(m.TransactionID, nil, fmt.Sprintf("xtcp visitor user [%s] not allowed for [%s]", visitorUser, m.ProxyName)))
return
}
@@ -327,8 +328,8 @@ func (c *Controller) analysis(session *Session) (*msg.NatHoleResp, *msg.NatHoleR
TransactionID: vm.TransactionID,
Sid: session.sid,
Protocol: protocol,
CandidateAddrs: lo.Uniq(cm.MappedAddrs),
AssistedAddrs: lo.Uniq(cm.AssistedAddrs),
CandidateAddrs: slices.Compact(cm.MappedAddrs),
AssistedAddrs: slices.Compact(cm.AssistedAddrs),
DetectBehavior: msg.NatHoleDetectBehavior{
Mode: mode,
Role: vBehavior.Role,
@@ -344,8 +345,8 @@ func (c *Controller) analysis(session *Session) (*msg.NatHoleResp, *msg.NatHoleR
TransactionID: cm.TransactionID,
Sid: session.sid,
Protocol: protocol,
CandidateAddrs: lo.Uniq(vm.MappedAddrs),
AssistedAddrs: lo.Uniq(vm.AssistedAddrs),
CandidateAddrs: slices.Compact(vm.MappedAddrs),
AssistedAddrs: slices.Compact(vm.AssistedAddrs),
DetectBehavior: msg.NatHoleDetectBehavior{
Mode: mode,
Role: cBehavior.Role,

View File

@@ -19,12 +19,12 @@ import (
"fmt"
"math/rand"
"net"
"slices"
"strconv"
"strings"
"time"
"github.com/fatedier/golib/pool"
"github.com/samber/lo"
"golang.org/x/net/ipv4"
"k8s.io/apimachinery/pkg/util/sets"
@@ -212,7 +212,7 @@ func MakeHole(ctx context.Context, listenConn *net.UDPConn, m *msg.NatHoleResp,
}
}
detectAddrs = lo.Uniq(detectAddrs)
detectAddrs = slices.Compact(detectAddrs)
for _, detectAddr := range detectAddrs {
for _, conn := range listenConns {
if err := sendSidMessage(ctx, conn, m.Sid, transactionID, detectAddr, key, m.DetectBehavior.TTL); err != nil {
@@ -377,7 +377,7 @@ func sendSidMessageToRangePorts(
sendFunc func(*net.UDPConn, string) error,
) {
xl := xlog.FromContextSafe(ctx)
for _, ip := range lo.Uniq(parseIPs(addrs)) {
for _, ip := range slices.Compact(parseIPs(addrs)) {
for _, portsRange := range ports {
for i := portsRange.From; i <= portsRange.To; i++ {
detectAddr := net.JoinHostPort(ip, strconv.Itoa(i))
@@ -419,7 +419,7 @@ func sendSidMessageToRandomPorts(
continue
}
for _, ip := range lo.Uniq(parseIPs(addrs)) {
for _, ip := range slices.Compact(parseIPs(addrs)) {
detectAddr := net.JoinHostPort(ip, strconv.Itoa(port))
if err := sendFunc(conn, detectAddr); err != nil {
xl.Trace("send sid message from %s to %s error: %v", conn.LocalAddr(), detectAddr, err)