vnet: fix issues (#4771)

This commit is contained in:
fatedier
2025-04-27 15:22:28 +08:00
parent 27f66baf54
commit 3c8d648ddc
18 changed files with 271 additions and 185 deletions

View File

@@ -87,7 +87,7 @@ func (c *Controller) handlePacket(buf []byte) {
case waterutil.IsIPv4(buf):
header, err := ipv4.ParseHeader(buf)
if err != nil {
log.Warnf("parse ipv4 header error:", err)
log.Warnf("parse ipv4 header error: %v", err)
return
}
src = header.Src
@@ -98,7 +98,7 @@ func (c *Controller) handlePacket(buf []byte) {
case waterutil.IsIPv6(buf):
header, err := ipv6.ParseHeader(buf)
if err != nil {
log.Warnf("parse ipv6 header error:", err)
log.Warnf("parse ipv6 header error: %v", err)
return
}
src = header.Src
@@ -137,6 +137,12 @@ func (c *Controller) Stop() error {
// Client connection read loop
func (c *Controller) readLoopClient(ctx context.Context, conn io.ReadWriteCloser) {
xl := xlog.FromContextSafe(ctx)
defer func() {
// Remove the route when read loop ends (connection closed)
c.clientRouter.removeConnRoute(conn)
conn.Close()
}()
for {
data, err := ReadMessage(conn)
if err != nil {
@@ -181,8 +187,18 @@ func (c *Controller) readLoopClient(ctx context.Context, conn io.ReadWriteCloser
}
// Server connection read loop
func (c *Controller) readLoopServer(ctx context.Context, conn io.ReadWriteCloser) {
func (c *Controller) readLoopServer(ctx context.Context, conn io.ReadWriteCloser, onClose func()) {
xl := xlog.FromContextSafe(ctx)
defer func() {
// Clean up all IP mappings associated with this connection when it closes
c.serverRouter.cleanupConnIPs(conn)
// Call the provided callback upon closure
if onClose != nil {
onClose()
}
conn.Close()
}()
for {
data, err := ReadMessage(conn)
if err != nil {
@@ -220,27 +236,11 @@ func (c *Controller) readLoopServer(ctx context.Context, conn io.ReadWriteCloser
}
}
// RegisterClientRoute Register client route (based on destination IP CIDR)
func (c *Controller) RegisterClientRoute(ctx context.Context, name string, routes []net.IPNet, conn io.ReadWriteCloser) error {
if err := c.clientRouter.addRoute(name, routes, conn); err != nil {
return err
}
// RegisterClientRoute registers a client route (based on destination IP CIDR)
// and starts the read loop
func (c *Controller) RegisterClientRoute(ctx context.Context, name string, routes []net.IPNet, conn io.ReadWriteCloser) {
c.clientRouter.addRoute(name, routes, conn)
go c.readLoopClient(ctx, conn)
return nil
}
// RegisterServerConn Register server connection (dynamically associates with source IPs)
func (c *Controller) RegisterServerConn(ctx context.Context, name string, conn io.ReadWriteCloser) error {
if err := c.serverRouter.addConn(name, conn); err != nil {
return err
}
go c.readLoopServer(ctx, conn)
return nil
}
// UnregisterServerConn Remove server connection from routing table
func (c *Controller) UnregisterServerConn(name string) {
c.serverRouter.delConn(name)
}
// UnregisterClientRoute Remove client route from routing table
@@ -248,6 +248,12 @@ func (c *Controller) UnregisterClientRoute(name string) {
c.clientRouter.delRoute(name)
}
// StartServerConnReadLoop starts the read loop for a server connection
// (dynamically associates with source IPs)
func (c *Controller) StartServerConnReadLoop(ctx context.Context, conn io.ReadWriteCloser, onClose func()) {
go c.readLoopServer(ctx, conn, onClose)
}
// ParseRoutes Convert route strings to IPNet objects
func ParseRoutes(routeStrings []string) ([]net.IPNet, error) {
routes := make([]net.IPNet, 0, len(routeStrings))
@@ -273,7 +279,7 @@ func newClientRouter() *clientRouter {
}
}
func (r *clientRouter) addRoute(name string, routes []net.IPNet, conn io.ReadWriteCloser) error {
func (r *clientRouter) addRoute(name string, routes []net.IPNet, conn io.ReadWriteCloser) {
r.mu.Lock()
defer r.mu.Unlock()
r.routes[name] = &routeElement{
@@ -281,7 +287,6 @@ func (r *clientRouter) addRoute(name string, routes []net.IPNet, conn io.ReadWri
routes: routes,
conn: conn,
}
return nil
}
func (r *clientRouter) findConn(dst net.IP) (io.Writer, error) {
@@ -303,32 +308,29 @@ func (r *clientRouter) delRoute(name string) {
delete(r.routes, name)
}
// Server router (based on source IP routing)
func (r *clientRouter) removeConnRoute(conn io.Writer) {
r.mu.Lock()
defer r.mu.Unlock()
for name, re := range r.routes {
if re.conn == conn {
delete(r.routes, name)
return
}
}
}
// Server router (based solely on source IP routing)
type serverRouter struct {
namedConns map[string]io.ReadWriteCloser // Name to connection mapping
srcIPConns map[string]io.Writer // Source IP string to connection mapping
srcIPConns map[string]io.Writer // Source IP string to connection mapping
mu sync.RWMutex
}
func newServerRouter() *serverRouter {
return &serverRouter{
namedConns: make(map[string]io.ReadWriteCloser),
srcIPConns: make(map[string]io.Writer),
}
}
func (r *serverRouter) addConn(name string, conn io.ReadWriteCloser) error {
r.mu.Lock()
original, ok := r.namedConns[name]
r.namedConns[name] = conn
r.mu.Unlock()
if ok {
// Close the original connection if it exists
_ = original.Close()
}
return nil
}
func (r *serverRouter) findConnBySrc(src net.IP) (io.Writer, error) {
r.mu.RLock()
defer r.mu.RUnlock()
@@ -340,17 +342,41 @@ func (r *serverRouter) findConnBySrc(src net.IP) (io.Writer, error) {
}
func (r *serverRouter) registerSrcIP(src net.IP, conn io.Writer) {
key := src.String()
r.mu.RLock()
existingConn, ok := r.srcIPConns[key]
r.mu.RUnlock()
// If the entry exists and the connection is the same, no need to do anything.
if ok && existingConn == conn {
return
}
// Acquire write lock to update the map.
r.mu.Lock()
defer r.mu.Unlock()
r.srcIPConns[src.String()] = conn
// Double-check after acquiring the write lock to handle potential race conditions.
existingConn, ok = r.srcIPConns[key]
if ok && existingConn == conn {
return
}
r.srcIPConns[key] = conn
}
func (r *serverRouter) delConn(name string) {
// cleanupConnIPs removes all IP mappings associated with the specified connection
func (r *serverRouter) cleanupConnIPs(conn io.Writer) {
r.mu.Lock()
defer r.mu.Unlock()
delete(r.namedConns, name)
// Note: We don't delete mappings from srcIPConns because we don't know which source IPs are associated with this connection
// This might cause dangling references, but they will be overwritten on new connections or restart
// Find and delete all IP mappings pointing to this connection
for ip, mappedConn := range r.srcIPConns {
if mappedConn == conn {
delete(r.srcIPConns, ip)
}
}
}
type routeElement struct {

View File

@@ -33,7 +33,7 @@ func ReadMessage(r io.Reader) ([]byte, error) {
var length uint32
err := binary.Read(r, binary.LittleEndian, &length)
if err != nil {
return nil, fmt.Errorf("read message length error: %v", err)
return nil, fmt.Errorf("read message length error: %w", err)
}
// Check length to prevent DoS
@@ -48,7 +48,7 @@ func ReadMessage(r io.Reader) ([]byte, error) {
data := make([]byte, length)
_, err = io.ReadFull(r, data)
if err != nil {
return nil, fmt.Errorf("read message data error: %v", err)
return nil, fmt.Errorf("read message data error: %w", err)
}
return data, nil
@@ -68,13 +68,13 @@ func WriteMessage(w io.Writer, data []byte) error {
// Write length
err := binary.Write(w, binary.LittleEndian, length)
if err != nil {
return fmt.Errorf("write message length error: %v", err)
return fmt.Errorf("write message length error: %w", err)
}
// Write message data
_, err = w.Write(data)
if err != nil {
return fmt.Errorf("write message data error: %v", err)
return fmt.Errorf("write message data error: %w", err)
}
return nil

View File

@@ -23,7 +23,8 @@ import (
)
const (
offset = 16
offset = 16
defaultPacketSize = 1420
)
type TunDevice interface {
@@ -35,20 +36,45 @@ func OpenTun(ctx context.Context, addr string) (TunDevice, error) {
if err != nil {
return nil, err
}
return &tunDeviceWrapper{dev: td}, nil
mtu, err := td.MTU()
if err != nil {
mtu = defaultPacketSize
}
bufferSize := max(mtu, defaultPacketSize)
batchSize := td.BatchSize()
device := &tunDeviceWrapper{
dev: td,
bufferSize: bufferSize,
readBuffers: make([][]byte, batchSize),
sizeBuffer: make([]int, batchSize),
}
for i := range device.readBuffers {
device.readBuffers[i] = make([]byte, offset+bufferSize)
}
return device, nil
}
type tunDeviceWrapper struct {
dev tun.Device
dev tun.Device
bufferSize int
readBuffers [][]byte
packetBuffers [][]byte
sizeBuffer []int
}
func (d *tunDeviceWrapper) Read(p []byte) (int, error) {
buf := pool.GetBuf(len(p) + offset)
defer pool.PutBuf(buf)
if len(d.packetBuffers) > 0 {
n := copy(p, d.packetBuffers[0])
d.packetBuffers = d.packetBuffers[1:]
return n, nil
}
sz := make([]int, 1)
n, err := d.dev.Read([][]byte{buf}, sz, offset)
n, err := d.dev.Read(d.readBuffers, d.sizeBuffer, offset)
if err != nil {
return 0, err
}
@@ -56,20 +82,26 @@ func (d *tunDeviceWrapper) Read(p []byte) (int, error) {
return 0, io.EOF
}
dataSize := sz[0]
if dataSize > len(p) {
dataSize = len(p)
for i := range n {
if d.sizeBuffer[i] <= 0 {
continue
}
d.packetBuffers = append(d.packetBuffers, d.readBuffers[i][offset:offset+d.sizeBuffer[i]])
}
copy(p, buf[offset:offset+dataSize])
dataSize := copy(p, d.packetBuffers[0])
d.packetBuffers = d.packetBuffers[1:]
return dataSize, nil
}
func (d *tunDeviceWrapper) Write(p []byte) (int, error) {
buf := pool.GetBuf(len(p) + offset)
buf := pool.GetBuf(offset + d.bufferSize)
defer pool.PutBuf(buf)
copy(buf[offset:], p)
return d.dev.Write([][]byte{buf}, offset)
n := copy(buf[offset:], p)
_, err := d.dev.Write([][]byte{buf[:offset+n]}, offset)
return n, err
}
func (d *tunDeviceWrapper) Close() error {

View File

@@ -16,35 +16,44 @@ package vnet
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"net"
"strconv"
"strings"
"github.com/vishvananda/netlink"
"golang.zx2c4.com/wireguard/tun"
)
const (
defaultTunName = "utun"
defaultMTU = 1420
baseTunName = "utun"
defaultMTU = 1420
)
func openTun(_ context.Context, addr string) (tun.Device, error) {
dev, err := tun.CreateTUN(defaultTunName, defaultMTU)
name, err := findNextTunName(baseTunName)
if err != nil {
name = getFallbackTunName(baseTunName, addr)
}
tunDevice, err := tun.CreateTUN(name, defaultMTU)
if err != nil {
return nil, fmt.Errorf("failed to create TUN device '%s': %w", name, err)
}
actualName, err := tunDevice.Name()
if err != nil {
return nil, err
}
name, err := dev.Name()
ifn, err := net.InterfaceByName(actualName)
if err != nil {
return nil, err
}
ifn, err := net.InterfaceByName(name)
if err != nil {
return nil, err
}
link, err := netlink.LinkByName(name)
link, err := netlink.LinkByName(actualName)
if err != nil {
return nil, err
}
@@ -69,7 +78,34 @@ func openTun(_ context.Context, addr string) (tun.Device, error) {
if err = addRoutes(ifn, cidr); err != nil {
return nil, err
}
return dev, nil
return tunDevice, nil
}
func findNextTunName(basename string) (string, error) {
interfaces, err := net.Interfaces()
if err != nil {
return "", fmt.Errorf("failed to get network interfaces: %w", err)
}
maxSuffix := -1
for _, iface := range interfaces {
name := iface.Name
if strings.HasPrefix(name, basename) {
suffix := name[len(basename):]
if suffix == "" {
continue
}
numSuffix, err := strconv.Atoi(suffix)
if err == nil && numSuffix > maxSuffix {
maxSuffix = numSuffix
}
}
}
nextSuffix := maxSuffix + 1
name := fmt.Sprintf("%s%d", basename, nextSuffix)
return name, nil
}
func addRoutes(ifn *net.Interface, cidr *net.IPNet) error {
@@ -82,3 +118,14 @@ func addRoutes(ifn *net.Interface, cidr *net.IPNet) error {
}
return nil
}
// getFallbackTunName generates a deterministic fallback TUN device name
// based on the base name and the provided address string using a hash.
func getFallbackTunName(baseName, addr string) string {
hasher := sha256.New()
hasher.Write([]byte(addr))
hashBytes := hasher.Sum(nil)
// Use first 4 bytes -> 8 hex chars for brevity, respecting IFNAMSIZ limit.
shortHash := hex.EncodeToString(hashBytes[:4])
return fmt.Sprintf("%s%s", baseName, shortHash)
}