// frp/pkg/vnet/controller.go
// 2025-04-09 02:07:29 +08:00
//
// 357 lines
// 8.9 KiB
// Go

// Copyright 2025 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package vnet
import (
"encoding/base64"
"fmt"
"io"
"net"
"sync"
"github.com/songgao/water/waterutil"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/util/log"
)
const (
// maxPacketSize is the size, in bytes, of the buffer Run allocates for each
// read from the TUN device.
maxPacketSize = 1420 // Maximum TUN packet size
)
// Controller owns the TUN device of the virtual network and moves packets
// between that device and the connections registered with its two routers.
type Controller struct {
// addr is the address passed to createTun by Init.
addr string
// tun is the TUN device handle; NewController leaves it nil and Init sets it.
tun io.ReadWriteCloser
clientRouter *clientRouter // Route based on destination IP (client mode)
serverRouter *serverRouter // Route based on source IP (server mode)
}
// NewController builds a Controller for the address in cfg with empty client
// and server routing tables. The TUN device is not created here; call Init
// before Run.
func NewController(cfg v1.VirtualNetConfig) *Controller {
	ctrl := &Controller{addr: cfg.Address}
	ctrl.clientRouter = newClientRouter()
	ctrl.serverRouter = newServerRouter()
	return ctrl
}
// Init creates the TUN device for the configured address and stores its
// handle on the controller. An error from createTun is returned unchanged and
// leaves the controller without a device.
func (c *Controller) Init() error {
	tunDev, _, _, tunErr := createTun(c.addr)
	if tunErr != nil {
		return tunErr
	}

	c.tun = tunDev
	return nil
}
// Run reads packets from the TUN device in a loop and forwards each one to a
// registered connection. It blocks until a read from the TUN device fails and
// then returns that error.
//
// For every packet the IP header is parsed to obtain source and destination
// addresses; packets that are neither IPv4 nor IPv6, or whose header cannot be
// parsed, are logged and dropped. Routing is attempted in two steps:
//
//  1. the destination IP is matched against the client router's CIDR routes;
//  2. otherwise the destination IP is looked up in the server router's table
//     of source IPs learned by readLoopServer, so replies go back over the
//     connection the peer's packets arrived on.
//
// Write errors on the selected connection are logged and do not stop the
// loop. Packets without any matching route are dropped.
func (c *Controller) Run() error {
	conn := c.tun
	for {
		// NOTE(review): a fresh buffer is allocated per packet; hoisting it is
		// only safe if WriteMessage never retains the slice — confirm first.
		buf := make([]byte, maxPacketSize)
		n, err := conn.Read(buf)
		if err != nil {
			log.Warnf("vnet read from tun error: %v", err)
			return err
		}
		log.Tracef("vnet read from tun [%d]: %s", n, base64.StdEncoding.EncodeToString(buf[:n]))

		var src, dst net.IP
		switch {
		case waterutil.IsIPv4(buf[:n]):
			header, err := ipv4.ParseHeader(buf[:n])
			if err != nil {
				// The format verb was missing here, so the error was rendered
				// as "%!(EXTRA ...)" noise instead of a readable message.
				log.Warnf("parse ipv4 header error: %v", err)
				continue
			}
			src = header.Src
			dst = header.Dst
			log.Tracef("%s >> %s %d/%-4d %-4x %d",
				header.Src, header.Dst,
				header.Len, header.TotalLen, header.ID, header.Flags)
		case waterutil.IsIPv6(buf[:n]):
			header, err := ipv6.ParseHeader(buf[:n])
			if err != nil {
				log.Warnf("parse ipv6 header error: %v", err)
				continue
			}
			src = header.Src
			dst = header.Dst
			log.Tracef("%s >> %s %d %d",
				header.Src, header.Dst,
				header.PayloadLen, header.TrafficClass)
		default:
			log.Warnf("unknown packet, discarded(%d)", n)
			continue
		}

		// 1. First try to route based on destination IP (client mode).
		targetConn, err := c.clientRouter.findConn(dst)
		if err == nil {
			if err := WriteMessage(targetConn, buf[:n]); err != nil {
				log.Warnf("write to client target conn error: %v", err)
			}
			continue
		}

		// 2. Otherwise try the server router. Its table is keyed by the source
		// IPs seen on server connections, so the packet's destination is the
		// lookup key for traffic heading back to such a peer.
		targetConn, err = c.serverRouter.findConnBySrc(dst)
		if err == nil {
			if err := WriteMessage(targetConn, buf[:n]); err != nil {
				log.Warnf("write to server target conn error: %v", err)
			}
			continue
		}

		// 3. No matching route found; the packet is dropped.
		log.Tracef("no route found for packet from %s to %s", src, dst)
	}
}
// Stop closes the TUN device and returns the error from Close. When no device
// has been created (Init was not called, or it failed) there is nothing to
// close, so Stop returns nil instead of panicking on the nil handle.
func (c *Controller) Stop() error {
	if c.tun == nil {
		return nil
	}
	return c.tun.Close()
}
// readLoopClient copies framed packets from a client-side connection into the
// TUN device. It returns when ReadMessage reports an error.
//
// Empty messages are skipped. Packets that are neither IPv4 nor IPv6, or whose
// header fails to parse, are logged and dropped. A failed write to the TUN
// device is logged and the loop carries on with the next message.
func (c *Controller) readLoopClient(conn io.ReadWriteCloser) {
	for {
		// Read one framed message from the connection.
		packet, readErr := ReadMessage(conn)
		if readErr != nil {
			log.Warnf("client read error: %v", readErr)
			return
		}
		size := len(packet)
		if size == 0 {
			continue
		}

		// Validate the IP header and trace it before forwarding.
		if waterutil.IsIPv4(packet) {
			hdr, parseErr := ipv4.ParseHeader(packet)
			if parseErr != nil {
				log.Warnf("parse ipv4 header error: %v", parseErr)
				continue
			}
			log.Tracef("%s >> %s %d/%-4d %-4x %d",
				hdr.Src, hdr.Dst,
				hdr.Len, hdr.TotalLen, hdr.ID, hdr.Flags)
		} else if waterutil.IsIPv6(packet) {
			hdr, parseErr := ipv6.ParseHeader(packet)
			if parseErr != nil {
				log.Warnf("parse ipv6 header error: %v", parseErr)
				continue
			}
			log.Tracef("%s >> %s %d %d",
				hdr.Src, hdr.Dst,
				hdr.PayloadLen, hdr.TrafficClass)
		} else {
			log.Warnf("unknown packet, discarded(%d)", size)
			continue
		}

		// Hand the packet to the TUN device.
		log.Tracef("vnet write to tun (client) [%d]: %s", size, base64.StdEncoding.EncodeToString(packet))
		if _, writeErr := c.tun.Write(packet); writeErr != nil {
			log.Warnf("client write tun error: %v", writeErr)
		}
	}
}
// readLoopServer copies framed packets from a server-side connection into the
// TUN device. It returns when ReadMessage reports an error.
//
// Whenever a packet's IPv4 or IPv6 header parses, its source IP is registered
// in the server router as reachable through conn. Unlike the client loop,
// every non-empty message is written to the TUN device, including ones that
// are not recognised as IP or whose header fails to parse. A failed TUN write
// is logged and the loop continues.
func (c *Controller) readLoopServer(conn io.ReadWriteCloser) {
	for {
		// Read one framed message from the connection.
		packet, readErr := ReadMessage(conn)
		if readErr != nil {
			log.Warnf("server read error: %v", readErr)
			return
		}
		if len(packet) == 0 {
			continue
		}

		// Learn which connection the packet's source IP lives behind.
		switch {
		case waterutil.IsIPv4(packet):
			if hdr, parseErr := ipv4.ParseHeader(packet); parseErr == nil {
				c.serverRouter.registerSrcIP(hdr.Src, conn)
			}
		case waterutil.IsIPv6(packet):
			if hdr, parseErr := ipv6.ParseHeader(packet); parseErr == nil {
				c.serverRouter.registerSrcIP(hdr.Src, conn)
			}
		}

		// Hand the packet to the TUN device.
		log.Tracef("vnet write to tun (server) [%d]: %s", len(packet), base64.StdEncoding.EncodeToString(packet))
		if _, writeErr := c.tun.Write(packet); writeErr != nil {
			log.Warnf("server write tun error: %v", writeErr)
		}
	}
}
// RegisterClientRoute stores the destination CIDR routes for conn under name
// in the client router and, once stored, starts a goroutine that copies
// packets arriving on conn into the TUN device. If the router rejects the
// route, its error is returned and no goroutine is started.
func (c *Controller) RegisterClientRoute(name string, routes []net.IPNet, conn io.ReadWriteCloser) error {
	addErr := c.clientRouter.addRoute(name, routes, conn)
	if addErr != nil {
		return addErr
	}

	go c.readLoopClient(conn)
	return nil
}
// RegisterServerConn stores conn under name in the server router and, once
// stored, starts a goroutine that copies packets arriving on conn into the
// TUN device; source IPs are associated with conn as packets are read. If the
// router rejects the connection, its error is returned and no goroutine is
// started.
func (c *Controller) RegisterServerConn(name string, conn io.ReadWriteCloser) error {
	addErr := c.serverRouter.addConn(name, conn)
	if addErr == nil {
		go c.readLoopServer(conn)
	}
	return addErr
}
// UnregisterServerConn removes the connection registered under name from the
// server router. It does not close the connection.
func (c *Controller) UnregisterServerConn(name string) {
	router := c.serverRouter
	router.delConn(name)
}
// UnregisterClientRoute removes the routes registered under name from the
// client router. It does not close the associated connection.
func (c *Controller) UnregisterClientRoute(name string) {
	router := c.clientRouter
	router.delRoute(name)
}
// ParseRoutes converts CIDR strings such as "10.0.0.0/8" into net.IPNet
// values, in input order. Each result is the network for the CIDR, so host
// bits in the input are masked off ("192.168.1.5/24" yields 192.168.1.0/24).
//
// On the first string that is not valid CIDR notation it returns nil and an
// error naming that string. The underlying parse error is wrapped, so callers
// can inspect it with errors.Is / errors.As.
func ParseRoutes(routeStrings []string) ([]net.IPNet, error) {
	routes := make([]net.IPNet, 0, len(routeStrings))
	for _, r := range routeStrings {
		_, ipNet, err := net.ParseCIDR(r)
		if err != nil {
			// %w (rather than %v) keeps the *net.ParseError in the error chain.
			return nil, fmt.Errorf("parse route %s error: %w", r, err)
		}
		routes = append(routes, *ipNet)
	}
	return routes, nil
}
// clientRouter maps destination IPs to connections: each named entry carries
// a list of CIDR routes and the connection that packets for those networks
// are written to.
type clientRouter struct {
// routes holds the registered entries, keyed by the name given to addRoute.
routes map[string]*routeElement
// mu guards routes.
mu sync.RWMutex
}
// newClientRouter returns a clientRouter with an empty, ready-to-use route
// table.
func newClientRouter() *clientRouter {
	table := make(map[string]*routeElement)
	return &clientRouter{routes: table}
}
// addRoute registers routes and conn under name, replacing any entry already
// stored under that name. A replaced entry's connection is not closed. The
// returned error is always nil.
func (r *clientRouter) addRoute(name string, routes []net.IPNet, conn io.ReadWriteCloser) error {
	entry := &routeElement{name: name, routes: routes, conn: conn}

	r.mu.Lock()
	defer r.mu.Unlock()

	r.routes[name] = entry
	return nil
}
// findConn returns the connection of the first registered entry that has a
// route containing dst, or an error when no route contains it. Entries are
// visited in map iteration order, which Go leaves unspecified, so with
// overlapping routes the chosen entry is not deterministic.
func (r *clientRouter) findConn(dst net.IP) (io.Writer, error) {
	r.mu.RLock()
	defer r.mu.RUnlock()

	for _, elem := range r.routes {
		for i := range elem.routes {
			if !elem.routes[i].Contains(dst) {
				continue
			}
			return elem.conn, nil
		}
	}
	return nil, fmt.Errorf("no route found for destination %s", dst)
}
// delRoute drops the entry registered under name; it is a no-op when no such
// entry exists. The entry's connection is not closed.
func (r *clientRouter) delRoute(name string) {
	r.mu.Lock()
	delete(r.routes, name)
	r.mu.Unlock()
}
// serverRouter tracks server-side connections by name and remembers, for each
// source IP seen on those connections, which connection it arrived on so that
// packets addressed to that IP can be written back to it.
type serverRouter struct {
namedConns map[string]io.ReadWriteCloser // Name to connection mapping
srcIPConns map[string]io.Writer // Source IP string to connection mapping
// mu guards both maps.
mu sync.RWMutex
}
// newServerRouter returns a serverRouter whose name and source-IP tables are
// both empty and ready to use.
func newServerRouter() *serverRouter {
	router := new(serverRouter)
	router.namedConns = make(map[string]io.ReadWriteCloser)
	router.srcIPConns = make(map[string]io.Writer)
	return router
}
// addConn stores conn under name. If another connection was already stored
// under that name it is replaced and then closed, with the close error
// ignored. The returned error is always nil.
func (r *serverRouter) addConn(name string, conn io.ReadWriteCloser) error {
	r.mu.Lock()
	previous, replaced := r.namedConns[name]
	r.namedConns[name] = conn
	r.mu.Unlock()

	if !replaced {
		return nil
	}

	// Close the superseded connection outside the lock; its error is
	// deliberately discarded because the connection is being abandoned.
	_ = previous.Close()
	return nil
}
// findConnBySrc returns the connection recorded for the IP src by
// registerSrcIP, or an error when that IP has no recorded connection. The
// table is keyed by the IP's string form.
func (r *serverRouter) findConnBySrc(src net.IP) (io.Writer, error) {
	key := src.String()

	r.mu.RLock()
	defer r.mu.RUnlock()

	if conn, ok := r.srcIPConns[key]; ok {
		return conn, nil
	}
	return nil, fmt.Errorf("no route found for source %s", src)
}
// registerSrcIP records conn as the connection for the IP src, overwriting
// any connection previously recorded for that IP.
func (r *serverRouter) registerSrcIP(src net.IP, conn io.Writer) {
	key := src.String()

	r.mu.Lock()
	defer r.mu.Unlock()

	r.srcIPConns[key] = conn
}
// delConn removes the connection registered under name together with every
// source-IP mapping that still points at that connection. It is a no-op when
// no connection is registered under name. The connection is not closed.
//
// Purging the source-IP table keeps packets from being routed to a connection
// that has been unregistered and stops the table from growing with entries
// that would otherwise linger until overwritten.
func (r *serverRouter) delConn(name string) {
	r.mu.Lock()
	defer r.mu.Unlock()

	conn, ok := r.namedConns[name]
	if !ok {
		return
	}
	delete(r.namedConns, name)

	// The table stores no reverse index, so scan it for entries that refer to
	// the removed connection. Deleting during range is safe in Go.
	// NOTE(review): interface equality assumes the connection's dynamic type
	// is comparable (e.g. a pointer) — confirm for all callers.
	target := io.Writer(conn)
	for ip, mapped := range r.srcIPConns {
		if mapped == target {
			delete(r.srcIPConns, ip)
		}
	}
}
// routeElement is one entry of the client router: a named set of destination
// networks and the connection used to reach them.
type routeElement struct {
// name is the key the entry is registered under.
name string
// routes lists the destination networks served by conn.
routes []net.IPNet
// conn is the connection returned by findConn for matching destinations.
conn io.ReadWriteCloser
}