diff --git a/client/admin_api.go b/client/admin_api.go index 708b2cbd..f161d588 100644 --- a/client/admin_api.go +++ b/client/admin_api.go @@ -165,9 +165,9 @@ func (svr *Service) apiStatus(w http.ResponseWriter, _ *http.Request) { res StatusResp = make(map[string][]ProxyStatusResp) ) - log.Infof("Http request [/api/status]") + log.Infof("http request [/api/status]") defer func() { - log.Infof("Http response [/api/status]") + log.Infof("http response [/api/status]") buf, _ = json.Marshal(&res) _, _ = w.Write(buf) }() @@ -198,9 +198,9 @@ func (svr *Service) apiStatus(w http.ResponseWriter, _ *http.Request) { func (svr *Service) apiGetConfig(w http.ResponseWriter, _ *http.Request) { res := GeneralResponse{Code: 200} - log.Infof("Http get request [/api/config]") + log.Infof("http get request [/api/config]") defer func() { - log.Infof("Http get response [/api/config], code [%d]", res.Code) + log.Infof("http get response [/api/config], code [%d]", res.Code) w.WriteHeader(res.Code) if len(res.Msg) > 0 { _, _ = w.Write([]byte(res.Msg)) @@ -228,9 +228,9 @@ func (svr *Service) apiGetConfig(w http.ResponseWriter, _ *http.Request) { func (svr *Service) apiPutConfig(w http.ResponseWriter, r *http.Request) { res := GeneralResponse{Code: 200} - log.Infof("Http put request [/api/config]") + log.Infof("http put request [/api/config]") defer func() { - log.Infof("Http put response [/api/config], code [%d]", res.Code) + log.Infof("http put response [/api/config], code [%d]", res.Code) w.WriteHeader(res.Code) if len(res.Msg) > 0 { _, _ = w.Write([]byte(res.Msg)) diff --git a/client/control.go b/client/control.go index 157b4aef..0dd70b8c 100644 --- a/client/control.go +++ b/client/control.go @@ -189,7 +189,7 @@ func (ctl *Control) handlePong(m msg.Message) { inMsg := m.(*msg.Pong) if inMsg.Error != "" { - xl.Errorf("Pong message contains error: %s", inMsg.Error) + xl.Errorf("pong message contains error: %s", inMsg.Error) ctl.closeSession() return } diff --git a/client/service.go b/client/service.go index 57eb4835..e163cac4 100644 --- a/client/service.go +++ b/client/service.go @@ -341,7 +341,7 @@ func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginE ctl, err := NewControl(svr.ctx, sessionCtx) if err != nil { conn.Close() - xl.Errorf("NewControl error: %v", err) + xl.Errorf("new control error: %v", err) return false, err } ctl.SetInWorkConnCallback(svr.handleWorkConnCb) diff --git a/pkg/plugin/client/virtual_net.go b/pkg/plugin/client/virtual_net.go index e1b29fdc..53570035 100644 --- a/pkg/plugin/client/virtual_net.go +++ b/pkg/plugin/client/virtual_net.go @@ -18,9 +18,10 @@ package client import ( "context" + "io" + "sync" v1 "github.com/fatedier/frp/pkg/config/v1" - "github.com/fatedier/frp/pkg/util/xlog" ) func init() { @@ -30,6 +31,8 @@ func init() { type VirtualNetPlugin struct { pluginCtx PluginContext opts *v1.VirtualNetPluginOptions + mu sync.Mutex + conns map[io.ReadWriteCloser]struct{} } func NewVirtualNetPlugin(pluginCtx PluginContext, options v1.ClientPluginOptions) (Plugin, error) { @@ -43,19 +46,32 @@ func NewVirtualNetPlugin(pluginCtx PluginContext, options v1.ClientPluginOptions } func (p *VirtualNetPlugin) Handle(ctx context.Context, connInfo *ConnectionInfo) { - xl := xlog.FromContextSafe(ctx) - // Verify if virtual network controller is available if p.pluginCtx.VnetController == nil { return } - // Register the connection with the controller - routeName := p.pluginCtx.Name - err := p.pluginCtx.VnetController.RegisterServerConn(ctx, routeName, connInfo.Conn) - if err != nil { - xl.Errorf("virtual net failed to register server connection: %v", err) - return + // Add the connection before starting the read loop to avoid race condition + // where RemoveConn might be called before the connection is added. + p.mu.Lock() + if p.conns == nil { + p.conns = make(map[io.ReadWriteCloser]struct{}) + } + p.conns[connInfo.Conn] = struct{}{} + p.mu.Unlock() + + // Register the connection with the controller and pass the cleanup function + p.pluginCtx.VnetController.StartServerConnReadLoop(ctx, connInfo.Conn, func() { + p.RemoveConn(connInfo.Conn) + }) +} + +func (p *VirtualNetPlugin) RemoveConn(conn io.ReadWriteCloser) { + p.mu.Lock() + defer p.mu.Unlock() + // Check if the map exists, as Close might have set it to nil concurrently + if p.conns != nil { + delete(p.conns, conn) } } @@ -64,8 +80,13 @@ func (p *VirtualNetPlugin) Name() string { } func (p *VirtualNetPlugin) Close() error { - if p.pluginCtx.VnetController != nil { - p.pluginCtx.VnetController.UnregisterServerConn(p.pluginCtx.Name) + p.mu.Lock() + defer p.mu.Unlock() + + // Close any remaining connections + for conn := range p.conns { + _ = conn.Close() } + p.conns = nil return nil } diff --git a/pkg/plugin/visitor/virtual_net.go b/pkg/plugin/visitor/virtual_net.go index e452e14f..f660c0c8 100644 --- a/pkg/plugin/visitor/virtual_net.go +++ b/pkg/plugin/visitor/virtual_net.go @@ -60,7 +60,7 @@ func NewVirtualNetPlugin(pluginCtx PluginContext, options v1.VisitorPluginOption return nil, errors.New("destinationIP is required") } - // Parse DestinationIP as a single IP and create a host route + // Parse DestinationIP and create a host route. ip := net.ParseIP(opts.DestinationIP) if ip == nil { return nil, fmt.Errorf("invalid destination IP address [%s]", opts.DestinationIP) @@ -91,7 +91,7 @@ func (p *VirtualNetPlugin) Start() { if len(p.routes) > 0 { routeStr = p.routes[0].String() } - xl.Infof("Starting VirtualNetPlugin for visitor [%s], attempting to register routes for %s", p.pluginCtx.Name, routeStr) + xl.Infof("starting VirtualNetPlugin for visitor [%s], attempting to register routes for %s", p.pluginCtx.Name, routeStr) go p.run() } @@ -101,10 +101,8 @@ func (p *VirtualNetPlugin) run() { reconnectDelay := 10 * time.Second for { - // Create a signal channel for this connection attempt currentCloseSignal := make(chan struct{}) - // Store the signal channel under lock p.mu.Lock() p.closeSignal = currentCloseSignal p.mu.Unlock() @@ -112,7 +110,6 @@ func (p *VirtualNetPlugin) run() { select { case <-p.ctx.Done(): xl.Infof("VirtualNetPlugin run loop for visitor [%s] stopping (context cancelled before pipe creation).", p.pluginCtx.Name) - // Ensure controllerConn from previous loop is cleaned up if necessary p.cleanupControllerConn(xl) return default: @@ -120,65 +117,43 @@ func (p *VirtualNetPlugin) run() { controllerConn, pluginConn := net.Pipe() - // Store controllerConn under lock for cleanup purposes p.mu.Lock() p.controllerConn = controllerConn p.mu.Unlock() - // Wrap pluginConn using CloseNotifyConn pluginNotifyConn := netutil.WrapCloseNotifyConn(pluginConn, func() { - close(currentCloseSignal) // Signal the run loop + close(currentCloseSignal) // Signal the run loop on close. }) - xl.Infof("Attempting to register client route for visitor [%s]", p.pluginCtx.Name) - err := p.pluginCtx.VnetController.RegisterClientRoute(p.ctx, p.pluginCtx.Name, p.routes, controllerConn) - if err != nil { - xl.Errorf("Failed to register client route for visitor [%s]: %v. Retrying after %v", p.pluginCtx.Name, err, reconnectDelay) - p.cleanupPipePair(xl, controllerConn, pluginConn) // Close both ends on registration failure - - // Wait before retrying registration, unless context is cancelled - select { - case <-time.After(reconnectDelay): - continue // Retry the loop - case <-p.ctx.Done(): - xl.Infof("VirtualNetPlugin registration retry wait interrupted for visitor [%s]", p.pluginCtx.Name) - return // Exit loop if context is cancelled during wait - } - } - - xl.Infof("Successfully registered client route for visitor [%s]. Starting connection handler with CloseNotifyConn.", p.pluginCtx.Name) + xl.Infof("attempting to register client route for visitor [%s]", p.pluginCtx.Name) + p.pluginCtx.VnetController.RegisterClientRoute(p.ctx, p.pluginCtx.Name, p.routes, controllerConn) + xl.Infof("successfully registered client route for visitor [%s]. Starting connection handler with CloseNotifyConn.", p.pluginCtx.Name) // Pass the CloseNotifyConn to HandleConn. // HandleConn is responsible for calling Close() on pluginNotifyConn. p.pluginCtx.HandleConn(pluginNotifyConn) - // Wait for either the plugin context to be cancelled or the wrapper's Close() to be called via the signal channel. + // Wait for context cancellation or connection close. select { case <-p.ctx.Done(): xl.Infof("VirtualNetPlugin run loop stopping for visitor [%s] (context cancelled while waiting).", p.pluginCtx.Name) - // Context cancelled, ensure controller side is closed if HandleConn didn't close its side yet. p.cleanupControllerConn(xl) return case <-currentCloseSignal: - xl.Infof("Detected connection closed via CloseNotifyConn for visitor [%s].", p.pluginCtx.Name) - // HandleConn closed the plugin side (pluginNotifyConn). The closeFn was called, closing currentCloseSignal. - // We still need to close the controller side. + xl.Infof("detected connection closed via CloseNotifyConn for visitor [%s].", p.pluginCtx.Name) + // HandleConn closed the plugin side. Close the controller side. p.cleanupControllerConn(xl) - // Add a delay before attempting to reconnect, respecting context cancellation. - xl.Infof("Waiting %v before attempting reconnection for visitor [%s]...", reconnectDelay, p.pluginCtx.Name) + xl.Infof("waiting %v before attempting reconnection for visitor [%s]...", reconnectDelay, p.pluginCtx.Name) select { case <-time.After(reconnectDelay): - // Delay completed, loop will continue. case <-p.ctx.Done(): xl.Infof("VirtualNetPlugin reconnection delay interrupted for visitor [%s]", p.pluginCtx.Name) - return // Exit loop if context is cancelled during wait + return } - // Loop will continue to reconnect. } - // Loop will restart, context check at the beginning of the loop is sufficient. - xl.Infof("Re-establishing virtual connection for visitor [%s]...", p.pluginCtx.Name) + xl.Infof("re-establishing virtual connection for visitor [%s]...", p.pluginCtx.Name) } } @@ -187,46 +162,31 @@ func (p *VirtualNetPlugin) cleanupControllerConn(xl *xlog.Logger) { p.mu.Lock() defer p.mu.Unlock() if p.controllerConn != nil { - xl.Debugf("Cleaning up controllerConn for visitor [%s]", p.pluginCtx.Name) + xl.Debugf("cleaning up controllerConn for visitor [%s]", p.pluginCtx.Name) p.controllerConn.Close() p.controllerConn = nil } - // Also clear the closeSignal reference for the completed/cancelled connection attempt p.closeSignal = nil } -// cleanupPipePair closes both ends of a pipe, used typically when registration fails. -func (p *VirtualNetPlugin) cleanupPipePair(xl *xlog.Logger, controllerConn, pluginConn net.Conn) { - xl.Debugf("Cleaning up pipe pair for visitor [%s] after registration failure", p.pluginCtx.Name) - controllerConn.Close() - pluginConn.Close() - p.mu.Lock() - p.controllerConn = nil // Ensure field is nil if it was briefly set - p.closeSignal = nil // Ensure field is nil if it was briefly set - p.mu.Unlock() -} - // Close initiates the plugin shutdown. func (p *VirtualNetPlugin) Close() error { - xl := xlog.FromContextSafe(p.pluginCtx.Ctx) // Use base context for close logging - xl.Infof("Closing VirtualNetPlugin for visitor [%s]", p.pluginCtx.Name) + xl := xlog.FromContextSafe(p.pluginCtx.Ctx) + xl.Infof("closing VirtualNetPlugin for visitor [%s]", p.pluginCtx.Name) - // 1. Signal the run loop goroutine to stop via context cancellation. + // Signal the run loop goroutine to stop. p.cancel() - // 2. Unregister the route from the controller. - // This might implicitly cause the VnetController to close its end of the pipe (controllerConn). + // Unregister the route from the controller. if p.pluginCtx.VnetController != nil { p.pluginCtx.VnetController.UnregisterClientRoute(p.pluginCtx.Name) - xl.Infof("Unregistered client route for visitor [%s]", p.pluginCtx.Name) - } else { - xl.Warnf("VnetController is nil during close for visitor [%s], cannot unregister route", p.pluginCtx.Name) + xl.Infof("unregistered client route for visitor [%s]", p.pluginCtx.Name) } - // 3. Explicitly close the controller side of the pipe managed by this plugin. + // Explicitly close the controller side of the pipe. // This ensures the pipe is broken even if the run loop is stuck or HandleConn hasn't closed its end. p.cleanupControllerConn(xl) - xl.Infof("Finished cleaning up connections during close for visitor [%s]", p.pluginCtx.Name) + xl.Infof("finished cleaning up connections during close for visitor [%s]", p.pluginCtx.Name) return nil } diff --git a/pkg/util/vhost/http.go b/pkg/util/vhost/http.go index bc458a5c..46ebc3c7 100644 --- a/pkg/util/vhost/http.go +++ b/pkg/util/vhost/http.go @@ -162,7 +162,7 @@ func (rp *HTTPReverseProxy) UnRegister(routeCfg RouteConfig) { func (rp *HTTPReverseProxy) GetRouteConfig(domain, location, routeByHTTPUser string) *RouteConfig { vr, ok := rp.getVhost(domain, location, routeByHTTPUser) if ok { - log.Debugf("get new HTTP request host [%s] path [%s] httpuser [%s]", domain, location, routeByHTTPUser) + log.Debugf("get new http request host [%s] path [%s] httpuser [%s]", domain, location, routeByHTTPUser) return vr.payload.(*RouteConfig) } return nil diff --git a/pkg/util/vhost/vhost.go b/pkg/util/vhost/vhost.go index 75b472b6..66dd577d 100644 --- a/pkg/util/vhost/vhost.go +++ b/pkg/util/vhost/vhost.go @@ -275,7 +275,7 @@ func (l *Listener) Accept() (net.Conn, error) { xl := xlog.FromContextSafe(l.ctx) conn, ok := <-l.accept if !ok { - return nil, fmt.Errorf("Listener closed") + return nil, fmt.Errorf("listener closed") } // if rewriteHost func is exist diff --git a/pkg/vnet/controller.go b/pkg/vnet/controller.go index d43147a3..ca71a8c3 100644 --- a/pkg/vnet/controller.go +++ b/pkg/vnet/controller.go @@ -87,7 +87,7 @@ func (c *Controller) handlePacket(buf []byte) { case waterutil.IsIPv4(buf): header, err := ipv4.ParseHeader(buf) if err != nil { - log.Warnf("parse ipv4 header error:", err) + log.Warnf("parse ipv4 header error: %v", err) return } src = header.Src @@ -98,7 +98,7 @@ func (c *Controller) handlePacket(buf []byte) { case waterutil.IsIPv6(buf): header, err := ipv6.ParseHeader(buf) if err != nil { - log.Warnf("parse ipv6 header error:", err) + log.Warnf("parse ipv6 header error: %v", err) return } src = header.Src @@ -137,6 +137,12 @@ func (c *Controller) Stop() error { // Client connection read loop func (c *Controller) readLoopClient(ctx context.Context, conn io.ReadWriteCloser) { xl := xlog.FromContextSafe(ctx) + defer func() { + // Remove the route when read loop ends (connection closed) + c.clientRouter.removeConnRoute(conn) + conn.Close() + }() + for { data, err := ReadMessage(conn) if err != nil { @@ -181,8 +187,18 @@ func (c *Controller) readLoopClient(ctx context.Context, conn io.ReadWriteCloser } // Server connection read loop -func (c *Controller) readLoopServer(ctx context.Context, conn io.ReadWriteCloser) { +func (c *Controller) readLoopServer(ctx context.Context, conn io.ReadWriteCloser, onClose func()) { xl := xlog.FromContextSafe(ctx) + defer func() { + // Clean up all IP mappings associated with this connection when it closes + c.serverRouter.cleanupConnIPs(conn) + // Call the provided callback upon closure + if onClose != nil { + onClose() + } + conn.Close() + }() + for { data, err := ReadMessage(conn) if err != nil { @@ -220,27 +236,11 @@ func (c *Controller) readLoopServer(ctx context.Context, conn io.ReadWriteCloser } } -// RegisterClientRoute Register client route (based on destination IP CIDR) -func (c *Controller) RegisterClientRoute(ctx context.Context, name string, routes []net.IPNet, conn io.ReadWriteCloser) error { - if err := c.clientRouter.addRoute(name, routes, conn); err != nil { - return err - } +// RegisterClientRoute registers a client route (based on destination IP CIDR) +// and starts the read loop +func (c *Controller) RegisterClientRoute(ctx context.Context, name string, routes []net.IPNet, conn io.ReadWriteCloser) { + c.clientRouter.addRoute(name, routes, conn) go c.readLoopClient(ctx, conn) - return nil -} - -// RegisterServerConn Register server connection (dynamically associates with source IPs) -func (c *Controller) RegisterServerConn(ctx context.Context, name string, conn io.ReadWriteCloser) error { - if err := c.serverRouter.addConn(name, conn); err != nil { - return err - } - go c.readLoopServer(ctx, conn) - return nil -} - -// UnregisterServerConn Remove server connection from routing table -func (c *Controller) UnregisterServerConn(name string) { - c.serverRouter.delConn(name) } // UnregisterClientRoute Remove client route from routing table @@ -248,6 +248,12 @@ func (c *Controller) UnregisterClientRoute(name string) { c.clientRouter.delRoute(name) } +// StartServerConnReadLoop starts the read loop for a server connection +// (dynamically associates with source IPs) +func (c *Controller) StartServerConnReadLoop(ctx context.Context, conn io.ReadWriteCloser, onClose func()) { + go c.readLoopServer(ctx, conn, onClose) +} + // ParseRoutes Convert route strings to IPNet objects func ParseRoutes(routeStrings []string) ([]net.IPNet, error) { routes := make([]net.IPNet, 0, len(routeStrings)) @@ -273,7 +279,7 @@ func newClientRouter() *clientRouter { } } -func (r *clientRouter) addRoute(name string, routes []net.IPNet, conn io.ReadWriteCloser) error { +func (r *clientRouter) addRoute(name string, routes []net.IPNet, conn io.ReadWriteCloser) { r.mu.Lock() defer r.mu.Unlock() r.routes[name] = &routeElement{ @@ -281,7 +287,6 @@ func (r *clientRouter) addRoute(name string, routes []net.IPNet, conn io.ReadWri routes: routes, conn: conn, } - return nil } func (r *clientRouter) findConn(dst net.IP) (io.Writer, error) { @@ -303,32 +308,29 @@ func (r *clientRouter) delRoute(name string) { delete(r.routes, name) } -// Server router (based on source IP routing) +func (r *clientRouter) removeConnRoute(conn io.Writer) { + r.mu.Lock() + defer r.mu.Unlock() + for name, re := range r.routes { + if re.conn == conn { + delete(r.routes, name) + return + } + } +} + +// Server router (based solely on source IP routing) type serverRouter struct { - namedConns map[string]io.ReadWriteCloser // Name to connection mapping - srcIPConns map[string]io.Writer // Source IP string to connection mapping + srcIPConns map[string]io.Writer // Source IP string to connection mapping mu sync.RWMutex } func newServerRouter() *serverRouter { return &serverRouter{ - namedConns: make(map[string]io.ReadWriteCloser), srcIPConns: make(map[string]io.Writer), } } -func (r *serverRouter) addConn(name string, conn io.ReadWriteCloser) error { - r.mu.Lock() - original, ok := r.namedConns[name] - r.namedConns[name] = conn - r.mu.Unlock() - if ok { - // Close the original connection if it exists - _ = original.Close() - } - return nil -} - func (r *serverRouter) findConnBySrc(src net.IP) (io.Writer, error) { r.mu.RLock() defer r.mu.RUnlock() @@ -340,17 +342,41 @@ func (r *serverRouter) findConnBySrc(src net.IP) (io.Writer, error) { } func (r *serverRouter) registerSrcIP(src net.IP, conn io.Writer) { + key := src.String() + + r.mu.RLock() + existingConn, ok := r.srcIPConns[key] + r.mu.RUnlock() + + // If the entry exists and the connection is the same, no need to do anything. + if ok && existingConn == conn { + return + } + + // Acquire write lock to update the map. r.mu.Lock() defer r.mu.Unlock() - r.srcIPConns[src.String()] = conn + + // Double-check after acquiring the write lock to handle potential race conditions. + existingConn, ok = r.srcIPConns[key] + if ok && existingConn == conn { + return + } + + r.srcIPConns[key] = conn } -func (r *serverRouter) delConn(name string) { +// cleanupConnIPs removes all IP mappings associated with the specified connection +func (r *serverRouter) cleanupConnIPs(conn io.Writer) { r.mu.Lock() defer r.mu.Unlock() - delete(r.namedConns, name) - // Note: We don't delete mappings from srcIPConns because we don't know which source IPs are associated with this connection - // This might cause dangling references, but they will be overwritten on new connections or restart + + // Find and delete all IP mappings pointing to this connection + for ip, mappedConn := range r.srcIPConns { + if mappedConn == conn { + delete(r.srcIPConns, ip) + } + } } type routeElement struct { diff --git a/pkg/vnet/message.go b/pkg/vnet/message.go index 002b090a..68ac7704 100644 --- a/pkg/vnet/message.go +++ b/pkg/vnet/message.go @@ -33,7 +33,7 @@ func ReadMessage(r io.Reader) ([]byte, error) { var length uint32 err := binary.Read(r, binary.LittleEndian, &length) if err != nil { - return nil, fmt.Errorf("read message length error: %v", err) + return nil, fmt.Errorf("read message length error: %w", err) } // Check length to prevent DoS @@ -48,7 +48,7 @@ func ReadMessage(r io.Reader) ([]byte, error) { data := make([]byte, length) _, err = io.ReadFull(r, data) if err != nil { - return nil, fmt.Errorf("read message data error: %v", err) + return nil, fmt.Errorf("read message data error: %w", err) } return data, nil @@ -68,13 +68,13 @@ func WriteMessage(w io.Writer, data []byte) error { // Write length err := binary.Write(w, binary.LittleEndian, length) if err != nil { - return fmt.Errorf("write message length error: %v", err) + return fmt.Errorf("write message length error: %w", err) } // Write message data _, err = w.Write(data) if err != nil { - return fmt.Errorf("write message data error: %v", err) + return fmt.Errorf("write message data error: %w", err) } return nil diff --git a/pkg/vnet/tun.go b/pkg/vnet/tun.go index bafc6392..d26314d0 100644 --- a/pkg/vnet/tun.go +++ b/pkg/vnet/tun.go @@ -23,7 +23,8 @@ import ( ) const ( - offset = 16 + offset = 16 + defaultPacketSize = 1420 ) type TunDevice interface { @@ -35,20 +36,45 @@ func OpenTun(ctx context.Context, addr string) (TunDevice, error) { if err != nil { return nil, err } - return &tunDeviceWrapper{dev: td}, nil + + mtu, err := td.MTU() + if err != nil { + mtu = defaultPacketSize + } + + bufferSize := max(mtu, defaultPacketSize) + batchSize := td.BatchSize() + + device := &tunDeviceWrapper{ + dev: td, + bufferSize: bufferSize, + readBuffers: make([][]byte, batchSize), + sizeBuffer: make([]int, batchSize), + } + + for i := range device.readBuffers { + device.readBuffers[i] = make([]byte, offset+bufferSize) + } + + return device, nil } type tunDeviceWrapper struct { - dev tun.Device + dev tun.Device + bufferSize int + readBuffers [][]byte + packetBuffers [][]byte + sizeBuffer []int } func (d *tunDeviceWrapper) Read(p []byte) (int, error) { - buf := pool.GetBuf(len(p) + offset) - defer pool.PutBuf(buf) + if len(d.packetBuffers) > 0 { + n := copy(p, d.packetBuffers[0]) + d.packetBuffers = d.packetBuffers[1:] + return n, nil + } - sz := make([]int, 1) - - n, err := d.dev.Read([][]byte{buf}, sz, offset) + n, err := d.dev.Read(d.readBuffers, d.sizeBuffer, offset) if err != nil { return 0, err } @@ -56,20 +82,26 @@ func (d *tunDeviceWrapper) Read(p []byte) (int, error) { return 0, io.EOF } - dataSize := sz[0] - if dataSize > len(p) { - dataSize = len(p) + for i := range n { + if d.sizeBuffer[i] <= 0 { + continue + } + d.packetBuffers = append(d.packetBuffers, d.readBuffers[i][offset:offset+d.sizeBuffer[i]]) } - copy(p, buf[offset:offset+dataSize]) + + dataSize := copy(p, d.packetBuffers[0]) + d.packetBuffers = d.packetBuffers[1:] + return dataSize, nil } func (d *tunDeviceWrapper) Write(p []byte) (int, error) { - buf := pool.GetBuf(len(p) + offset) + buf := pool.GetBuf(offset + d.bufferSize) defer pool.PutBuf(buf) - copy(buf[offset:], p) - return d.dev.Write([][]byte{buf}, offset) + n := copy(buf[offset:], p) + _, err := d.dev.Write([][]byte{buf[:offset+n]}, offset) + return n, err } func (d *tunDeviceWrapper) Close() error { diff --git a/pkg/vnet/tun_linux.go b/pkg/vnet/tun_linux.go index 7c9c684c..2e0cc56b 100644 --- a/pkg/vnet/tun_linux.go +++ b/pkg/vnet/tun_linux.go @@ -16,35 +16,44 @@ package vnet import ( "context" + "crypto/sha256" + "encoding/hex" "fmt" "net" + "strconv" + "strings" "github.com/vishvananda/netlink" "golang.zx2c4.com/wireguard/tun" ) const ( - defaultTunName = "utun" - defaultMTU = 1420 + baseTunName = "utun" + defaultMTU = 1420 ) func openTun(_ context.Context, addr string) (tun.Device, error) { - dev, err := tun.CreateTUN(defaultTunName, defaultMTU) + name, err := findNextTunName(baseTunName) + if err != nil { + name = getFallbackTunName(baseTunName, addr) + } + + tunDevice, err := tun.CreateTUN(name, defaultMTU) + if err != nil { + return nil, fmt.Errorf("failed to create TUN device '%s': %w", name, err) + } + + actualName, err := tunDevice.Name() if err != nil { return nil, err } - name, err := dev.Name() + ifn, err := net.InterfaceByName(actualName) if err != nil { return nil, err } - ifn, err := net.InterfaceByName(name) - if err != nil { - return nil, err - } - - link, err := netlink.LinkByName(name) + link, err := netlink.LinkByName(actualName) if err != nil { return nil, err } @@ -69,7 +78,34 @@ func openTun(_ context.Context, addr string) (tun.Device, error) { if err = addRoutes(ifn, cidr); err != nil { return nil, err } - return dev, nil + return tunDevice, nil +} + +func findNextTunName(basename string) (string, error) { + interfaces, err := net.Interfaces() + if err != nil { + return "", fmt.Errorf("failed to get network interfaces: %w", err) + } + maxSuffix := -1 + + for _, iface := range interfaces { + name := iface.Name + if strings.HasPrefix(name, basename) { + suffix := name[len(basename):] + if suffix == "" { + continue + } + + numSuffix, err := strconv.Atoi(suffix) + if err == nil && numSuffix > maxSuffix { + maxSuffix = numSuffix + } + } + } + + nextSuffix := maxSuffix + 1 + name := fmt.Sprintf("%s%d", basename, nextSuffix) + return name, nil } func addRoutes(ifn *net.Interface, cidr *net.IPNet) error { @@ -82,3 +118,14 @@ func addRoutes(ifn *net.Interface, cidr *net.IPNet) error { } return nil } + +// getFallbackTunName generates a deterministic fallback TUN device name +// based on the base name and the provided address string using a hash. +func getFallbackTunName(baseName, addr string) string { + hasher := sha256.New() + hasher.Write([]byte(addr)) + hashBytes := hasher.Sum(nil) + // Use first 4 bytes -> 8 hex chars for brevity, respecting IFNAMSIZ limit. + shortHash := hex.EncodeToString(hashBytes[:4]) + return fmt.Sprintf("%s%s", baseName, shortHash) +} diff --git a/server/control.go b/server/control.go index 0b6b3174..b70d8d12 100644 --- a/server/control.go +++ b/server/control.go @@ -224,7 +224,7 @@ func (ctl *Control) Close() error { func (ctl *Control) Replaced(newCtl *Control) { xl := ctl.xl - xl.Infof("Replaced by client [%s]", newCtl.runID) + xl.Infof("replaced by client [%s]", newCtl.runID) ctl.runID = "" ctl.conn.Close() } diff --git a/server/dashboard_api.go b/server/dashboard_api.go index a29433a6..54e5d9e9 100644 --- a/server/dashboard_api.go +++ b/server/dashboard_api.go @@ -97,14 +97,14 @@ func (svr *Service) healthz(w http.ResponseWriter, _ *http.Request) { func (svr *Service) apiServerInfo(w http.ResponseWriter, r *http.Request) { res := GeneralResponse{Code: 200} defer func() { - log.Infof("Http response [%s]: code [%d]", r.URL.Path, res.Code) + log.Infof("http response [%s]: code [%d]", r.URL.Path, res.Code) w.WriteHeader(res.Code) if len(res.Msg) > 0 { _, _ = w.Write([]byte(res.Msg)) } }() - log.Infof("Http request: [%s]", r.URL.Path) + log.Infof("http request: [%s]", r.URL.Path) serverStats := mem.StatsCollector.GetServer() svrResp := serverInfoResp{ Version: version.Full(), @@ -218,13 +218,13 @@ func (svr *Service) apiProxyByType(w http.ResponseWriter, r *http.Request) { proxyType := params["type"] defer func() { - log.Infof("Http response [%s]: code [%d]", r.URL.Path, res.Code) + log.Infof("http response [%s]: code [%d]", r.URL.Path, res.Code) w.WriteHeader(res.Code) if len(res.Msg) > 0 { _, _ = w.Write([]byte(res.Msg)) } }() - log.Infof("Http request: [%s]", r.URL.Path) + log.Infof("http request: [%s]", r.URL.Path) proxyInfoResp := GetProxyInfoResp{} proxyInfoResp.Proxies = svr.getProxyStatsByType(proxyType) @@ -290,13 +290,13 @@ func (svr *Service) apiProxyByTypeAndName(w http.ResponseWriter, r *http.Request name := params["name"] defer func() { - log.Infof("Http response [%s]: code [%d]", r.URL.Path, res.Code) + log.Infof("http response [%s]: code [%d]", r.URL.Path, res.Code) w.WriteHeader(res.Code) if len(res.Msg) > 0 { _, _ = w.Write([]byte(res.Msg)) } }() - log.Infof("Http request: [%s]", r.URL.Path) + log.Infof("http request: [%s]", r.URL.Path) var proxyStatsResp GetProxyStatsResp proxyStatsResp, res.Code, res.Msg = svr.getProxyStatsByTypeAndName(proxyType, name) @@ -358,13 +358,13 @@ func (svr *Service) apiProxyTraffic(w http.ResponseWriter, r *http.Request) { name := params["name"] defer func() { - log.Infof("Http response [%s]: code [%d]", r.URL.Path, res.Code) + log.Infof("http response [%s]: code [%d]", r.URL.Path, res.Code) w.WriteHeader(res.Code) if len(res.Msg) > 0 { _, _ = w.Write([]byte(res.Msg)) } }() - log.Infof("Http request: [%s]", r.URL.Path) + log.Infof("http request: [%s]", r.URL.Path) trafficResp := GetProxyTrafficResp{} trafficResp.Name = name @@ -386,9 +386,9 @@ func (svr *Service) apiProxyTraffic(w http.ResponseWriter, r *http.Request) { func (svr *Service) deleteProxies(w http.ResponseWriter, r *http.Request) { res := GeneralResponse{Code: 200} - log.Infof("Http request: [%s]", r.URL.Path) + log.Infof("http request: [%s]", r.URL.Path) defer func() { - log.Infof("Http response [%s]: code [%d]", r.URL.Path, res.Code) + log.Infof("http response [%s]: code [%d]", r.URL.Path, res.Code) w.WriteHeader(res.Code) if len(res.Msg) > 0 { _, _ = w.Write([]byte(res.Msg)) diff --git a/server/service.go b/server/service.go index d1dd68a1..b9abaa80 100644 --- a/server/service.go +++ b/server/service.go @@ -427,7 +427,7 @@ func (svr *Service) handleConnection(ctx context.Context, conn net.Conn, interna _ = conn.SetReadDeadline(time.Now().Add(connReadTimeout)) if rawMsg, err = msg.ReadMsg(conn); err != nil { - log.Tracef("Failed to read message: %v", err) + log.Tracef("failed to read message: %v", err) conn.Close() return } @@ -475,7 +475,7 @@ func (svr *Service) handleConnection(ctx context.Context, conn net.Conn, interna }) } default: - log.Warnf("Error message type for the new connection [%s]", conn.RemoteAddr().String()) + log.Warnf("error message type for the new connection [%s]", conn.RemoteAddr().String()) conn.Close() } } @@ -488,7 +488,7 @@ func (svr *Service) HandleListener(l net.Listener, internal bool) { for { c, err := l.Accept() if err != nil { - log.Warnf("Listener for incoming connections from client closed") + log.Warnf("listener for incoming connections from client closed") return } // inject xlog object into net.Conn context @@ -504,7 +504,7 @@ func (svr *Service) HandleListener(l net.Listener, internal bool) { var isTLS, custom bool c, isTLS, custom, err = netpkg.CheckAndEnableTLSServerConnWithTimeout(c, svr.tlsConfig, forceTLS, connReadTimeout) if err != nil { - log.Warnf("CheckAndEnableTLSServerConnWithTimeout error: %v", err) + log.Warnf("checkAndEnableTLSServerConnWithTimeout error: %v", err) originConn.Close() continue } @@ -520,7 +520,7 @@ func (svr *Service) HandleListener(l net.Listener, internal bool) { fmuxCfg.MaxStreamWindowSize = 6 * 1024 * 1024 session, err := fmux.Server(frpConn, fmuxCfg) if err != nil { - log.Warnf("Failed to create mux connection: %v", err) + log.Warnf("failed to create mux connection: %v", err) frpConn.Close() return } @@ -528,7 +528,7 @@ func (svr *Service) HandleListener(l net.Listener, internal bool) { for { stream, err := session.AcceptStream() if err != nil { - log.Debugf("Accept new mux stream error: %v", err) + log.Debugf("accept new mux stream error: %v", err) session.Close() return } @@ -546,7 +546,7 @@ func (svr *Service) HandleQUICListener(l *quic.Listener) { for { c, err := l.Accept(context.Background()) if err != nil { - log.Warnf("QUICListener for incoming connections from client closed") + log.Warnf("quic listener for incoming connections from client closed") return } // Start a new goroutine to handle connection. @@ -554,7 +554,7 @@ func (svr *Service) HandleQUICListener(l *quic.Listener) { for { stream, err := frpConn.AcceptStream(context.Background()) if err != nil { - log.Debugf("Accept new quic mux stream error: %v", err) + log.Debugf("accept new quic mux stream error: %v", err) _ = frpConn.CloseWithError(0, "") return } @@ -620,7 +620,7 @@ func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn) xl := netpkg.NewLogFromConn(workConn) ctl, exist := svr.ctlManager.GetByID(newMsg.RunID) if !exist { - xl.Warnf("No client control found for run id [%s]", newMsg.RunID) + xl.Warnf("no client control found for run id [%s]", newMsg.RunID) return fmt.Errorf("no client control found for run id [%s]", newMsg.RunID) } // server plugin hook diff --git a/test/e2e/e2e.go b/test/e2e/e2e.go index 994ac59d..ee00e103 100644 --- a/test/e2e/e2e.go +++ b/test/e2e/e2e.go @@ -38,7 +38,7 @@ func RunE2ETests(t *testing.T) { // Randomize specs as well as suites suiteConfig.RandomizeAllSpecs = true - log.Infof("Starting e2e run %q on Ginkgo node %d of total %d", + log.Infof("starting e2e run %q on Ginkgo node %d of total %d", framework.RunID, suiteConfig.ParallelProcess, suiteConfig.ParallelTotal) ginkgo.RunSpecs(t, "frp e2e suite", suiteConfig, reporterConfig) } diff --git a/test/e2e/framework/request.go b/test/e2e/framework/request.go index f56fc973..599ff11b 100644 --- a/test/e2e/framework/request.go +++ b/test/e2e/framework/request.go @@ -20,7 +20,7 @@ func ExpectResponseCode(code int) EnsureFunc { if resp.Code == code { return true } - flog.Warnf("Expect code %d, but got %d", code, resp.Code) + flog.Warnf("expect code %d, but got %d", code, resp.Code) return false } } @@ -111,14 +111,14 @@ func (e *RequestExpect) Ensure(fns ...EnsureFunc) { if len(fns) == 0 { if !bytes.Equal(e.expectResp, ret.Content) { - flog.Tracef("Response info: %+v", ret) + flog.Tracef("response info: %+v", ret) } ExpectEqualValuesWithOffset(1, string(ret.Content), string(e.expectResp), e.explain...) } else { for _, fn := range fns { ok := fn(ret) if !ok { - flog.Tracef("Response info: %+v", ret) + flog.Tracef("response info: %+v", ret) } ExpectTrueWithOffset(1, ok, e.explain...) } diff --git a/test/e2e/legacy/features/real_ip.go b/test/e2e/legacy/features/real_ip.go index f74c62d2..a79afb45 100644 --- a/test/e2e/legacy/features/real_ip.go +++ b/test/e2e/legacy/features/real_ip.go @@ -93,7 +93,7 @@ var _ = ginkgo.Describe("[Feature: Real IP]", func() { f.RunProcesses([]string{serverConf}, []string{clientConf}) framework.NewRequestExpect(f).Port(remotePort).Ensure(func(resp *request.Response) bool { - log.Tracef("ProxyProtocol get SourceAddr: %s", string(resp.Content)) + log.Tracef("proxy protocol get SourceAddr: %s", string(resp.Content)) addr, err := net.ResolveTCPAddr("tcp", string(resp.Content)) if err != nil { return false @@ -142,7 +142,7 @@ var _ = ginkgo.Describe("[Feature: Real IP]", func() { r.HTTP().HTTPHost("normal.example.com") }).Ensure(framework.ExpectResponseCode(404)) - log.Tracef("ProxyProtocol get SourceAddr: %s", srcAddrRecord) + log.Tracef("proxy protocol get SourceAddr: %s", srcAddrRecord) addr, err := net.ResolveTCPAddr("tcp", srcAddrRecord) framework.ExpectNoError(err, srcAddrRecord) framework.ExpectEqualValues("127.0.0.1", addr.IP.String()) diff --git a/test/e2e/v1/features/real_ip.go b/test/e2e/v1/features/real_ip.go index 94508f7d..216f531d 100644 --- a/test/e2e/v1/features/real_ip.go +++ b/test/e2e/v1/features/real_ip.go @@ -215,7 +215,7 @@ var _ = ginkgo.Describe("[Feature: Real IP]", func() { f.RunProcesses([]string{serverConf}, []string{clientConf}) framework.NewRequestExpect(f).Port(remotePort).Ensure(func(resp *request.Response) bool { - log.Tracef("ProxyProtocol get SourceAddr: %s", string(resp.Content)) + log.Tracef("proxy protocol get SourceAddr: %s", string(resp.Content)) addr, err := net.ResolveTCPAddr("tcp", string(resp.Content)) if err != nil { return false @@ -265,7 +265,7 @@ var _ = ginkgo.Describe("[Feature: Real IP]", func() { r.HTTP().HTTPHost("normal.example.com") }).Ensure(framework.ExpectResponseCode(404)) - log.Tracef("ProxyProtocol get SourceAddr: %s", srcAddrRecord) + log.Tracef("proxy protocol get SourceAddr: %s", srcAddrRecord) addr, err := net.ResolveTCPAddr("tcp", srcAddrRecord) framework.ExpectNoError(err, srcAddrRecord) framework.ExpectEqualValues("127.0.0.1", addr.IP.String())