add proxy protocol support for UDP proxies (#4810)

This commit is contained in:
fatedier
2025-05-23 21:39:47 +08:00
parent 8eb525a648
commit 3fa76b72f3
9 changed files with 299 additions and 25 deletions

View File

@@ -24,6 +24,7 @@ import (
"github.com/fatedier/golib/pool"
"github.com/fatedier/frp/pkg/msg"
netpkg "github.com/fatedier/frp/pkg/util/net"
)
func NewUDPPacket(buf []byte, laddr, raddr *net.UDPAddr) *msg.UDPPacket {
@@ -69,7 +70,7 @@ func ForwardUserConn(udpConn *net.UDPConn, readCh <-chan *msg.UDPPacket, sendCh
}
}
func Forwarder(dstAddr *net.UDPAddr, readCh <-chan *msg.UDPPacket, sendCh chan<- msg.Message, bufSize int) {
func Forwarder(dstAddr *net.UDPAddr, readCh <-chan *msg.UDPPacket, sendCh chan<- msg.Message, bufSize int, proxyProtocolVersion string) {
var mu sync.RWMutex
udpConnMap := make(map[string]*net.UDPConn)
@@ -110,6 +111,7 @@ func Forwarder(dstAddr *net.UDPAddr, readCh <-chan *msg.UDPPacket, sendCh chan<-
if err != nil {
continue
}
mu.Lock()
udpConn, ok := udpConnMap[udpMsg.RemoteAddr.String()]
if !ok {
@@ -122,6 +124,18 @@ func Forwarder(dstAddr *net.UDPAddr, readCh <-chan *msg.UDPPacket, sendCh chan<-
}
mu.Unlock()
// Add proxy protocol header if configured
if proxyProtocolVersion != "" && udpMsg.RemoteAddr != nil {
ppBuf, err := netpkg.BuildProxyProtocolHeader(udpMsg.RemoteAddr, dstAddr, proxyProtocolVersion)
if err == nil {
// Prepend proxy protocol header to the UDP payload
finalBuf := make([]byte, len(ppBuf)+len(buf))
copy(finalBuf, ppBuf)
copy(finalBuf[len(ppBuf):], buf)
buf = finalBuf
}
}
_, err = udpConn.Write(buf)
if err != nil {
udpConn.Close()

View File

@@ -0,0 +1,45 @@
// Copyright 2025 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package net
import (
"bytes"
"fmt"
"net"
pp "github.com/pires/go-proxyproto"
)
func BuildProxyProtocolHeaderStruct(srcAddr, dstAddr net.Addr, version string) *pp.Header {
var versionByte byte
if version == "v1" {
versionByte = 1
} else {
versionByte = 2 // default to v2
}
return pp.HeaderProxyFromAddrs(versionByte, srcAddr, dstAddr)
}
func BuildProxyProtocolHeader(srcAddr, dstAddr net.Addr, version string) ([]byte, error) {
h := BuildProxyProtocolHeaderStruct(srcAddr, dstAddr, version)
// Convert header to bytes using a buffer
var buf bytes.Buffer
_, err := h.WriteTo(&buf)
if err != nil {
return nil, fmt.Errorf("failed to write proxy protocol header: %v", err)
}
return buf.Bytes(), nil
}

View File

@@ -0,0 +1,178 @@
package net
import (
"net"
"testing"
pp "github.com/pires/go-proxyproto"
"github.com/stretchr/testify/require"
)
func TestBuildProxyProtocolHeader(t *testing.T) {
require := require.New(t)
tests := []struct {
name string
srcAddr net.Addr
dstAddr net.Addr
version string
expectError bool
}{
{
name: "UDP IPv4 v2",
srcAddr: &net.UDPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345},
dstAddr: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 3306},
version: "v2",
expectError: false,
},
{
name: "TCP IPv4 v1",
srcAddr: &net.TCPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345},
dstAddr: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 80},
version: "v1",
expectError: false,
},
{
name: "UDP IPv6 v2",
srcAddr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 12345},
dstAddr: &net.UDPAddr{IP: net.ParseIP("::1"), Port: 3306},
version: "v2",
expectError: false,
},
{
name: "TCP IPv6 v1",
srcAddr: &net.TCPAddr{IP: net.ParseIP("::1"), Port: 12345},
dstAddr: &net.TCPAddr{IP: net.ParseIP("2001:db8::1"), Port: 80},
version: "v1",
expectError: false,
},
{
name: "nil source address",
srcAddr: nil,
dstAddr: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 3306},
version: "v2",
expectError: false,
},
{
name: "nil destination address",
srcAddr: &net.TCPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345},
dstAddr: nil,
version: "v2",
expectError: false,
},
{
name: "unsupported address type",
srcAddr: &net.UnixAddr{Name: "/tmp/test.sock", Net: "unix"},
dstAddr: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 3306},
version: "v2",
expectError: false,
},
}
for _, tt := range tests {
header, err := BuildProxyProtocolHeader(tt.srcAddr, tt.dstAddr, tt.version)
if tt.expectError {
require.Error(err, "test case: %s", tt.name)
continue
}
require.NoError(err, "test case: %s", tt.name)
require.NotEmpty(header, "test case: %s", tt.name)
}
}
func TestBuildProxyProtocolHeaderStruct(t *testing.T) {
require := require.New(t)
tests := []struct {
name string
srcAddr net.Addr
dstAddr net.Addr
version string
expectedProtocol pp.AddressFamilyAndProtocol
expectedVersion byte
expectedCommand pp.ProtocolVersionAndCommand
expectedSourceAddr net.Addr
expectedDestAddr net.Addr
}{
{
name: "TCP IPv4 v2",
srcAddr: &net.TCPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345},
dstAddr: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 80},
version: "v2",
expectedProtocol: pp.TCPv4,
expectedVersion: 2,
expectedCommand: pp.PROXY,
expectedSourceAddr: &net.TCPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345},
expectedDestAddr: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 80},
},
{
name: "UDP IPv6 v1",
srcAddr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 12345},
dstAddr: &net.UDPAddr{IP: net.ParseIP("::1"), Port: 3306},
version: "v1",
expectedProtocol: pp.UDPv6,
expectedVersion: 1,
expectedCommand: pp.PROXY,
expectedSourceAddr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 12345},
expectedDestAddr: &net.UDPAddr{IP: net.ParseIP("::1"), Port: 3306},
},
{
name: "TCP IPv6 default version",
srcAddr: &net.TCPAddr{IP: net.ParseIP("::1"), Port: 12345},
dstAddr: &net.TCPAddr{IP: net.ParseIP("2001:db8::1"), Port: 80},
version: "",
expectedProtocol: pp.TCPv6,
expectedVersion: 2, // default to v2
expectedCommand: pp.PROXY,
expectedSourceAddr: &net.TCPAddr{IP: net.ParseIP("::1"), Port: 12345},
expectedDestAddr: &net.TCPAddr{IP: net.ParseIP("2001:db8::1"), Port: 80},
},
{
name: "nil source address",
srcAddr: nil,
dstAddr: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 3306},
version: "v2",
expectedProtocol: pp.UNSPEC,
expectedVersion: 2,
expectedCommand: pp.LOCAL,
expectedSourceAddr: nil, // go-proxyproto sets both to nil when srcAddr is nil
expectedDestAddr: nil,
},
{
name: "nil destination address",
srcAddr: &net.TCPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345},
dstAddr: nil,
version: "v2",
expectedProtocol: pp.UNSPEC,
expectedVersion: 2,
expectedCommand: pp.LOCAL,
expectedSourceAddr: nil, // go-proxyproto sets both to nil when dstAddr is nil
expectedDestAddr: nil,
},
{
name: "unsupported address type",
srcAddr: &net.UnixAddr{Name: "/tmp/test.sock", Net: "unix"},
dstAddr: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 3306},
version: "v2",
expectedProtocol: pp.UNSPEC,
expectedVersion: 2,
expectedCommand: pp.LOCAL,
expectedSourceAddr: nil, // go-proxyproto sets both to nil for unsupported types
expectedDestAddr: nil,
},
}
for _, tt := range tests {
header := BuildProxyProtocolHeaderStruct(tt.srcAddr, tt.dstAddr, tt.version)
require.NotNil(header, "test case: %s", tt.name)
require.Equal(tt.expectedCommand, header.Command, "test case: %s", tt.name)
require.Equal(tt.expectedSourceAddr, header.SourceAddr, "test case: %s", tt.name)
require.Equal(tt.expectedDestAddr, header.DestinationAddr, "test case: %s", tt.name)
require.Equal(tt.expectedProtocol, header.TransportProtocol, "test case: %s", tt.name)
require.Equal(tt.expectedVersion, header.Version, "test case: %s", tt.name)
}
}