Compare commits

..

No commits in common. "cceab7e1b199bd3cf6c806030dfb99482239b3a1" and "2225a1781fc2774517b5d164c98675f09ba156e3" have entirely different histories.

93 changed files with 2543 additions and 4021 deletions

View File

@ -19,9 +19,6 @@ fmt:
fmt-more: fmt-more:
gofumpt -l -w . gofumpt -l -w .
gci:
gci write -s standard -s default -s "prefix(github.com/fatedier/frp/)" ./
vet: vet:
go vet ./... go vet ./...

View File

@ -1,3 +1,4 @@
# frp # frp
[![Build Status](https://circleci.com/gh/fatedier/frp.svg?style=shield)](https://circleci.com/gh/fatedier/frp) [![Build Status](https://circleci.com/gh/fatedier/frp.svg?style=shield)](https://circleci.com/gh/fatedier/frp)
@ -11,7 +12,12 @@
<a href="https://workos.com/?utm_campaign=github_repo&utm_medium=referral&utm_content=frp&utm_source=github" target="_blank"> <a href="https://workos.com/?utm_campaign=github_repo&utm_medium=referral&utm_content=frp&utm_source=github" target="_blank">
<img width="350px" src="https://raw.githubusercontent.com/fatedier/frp/dev/doc/pic/sponsor_workos.png"> <img width="350px" src="https://raw.githubusercontent.com/fatedier/frp/dev/doc/pic/sponsor_workos.png">
</a> </a>
<a>&nbsp</a>
<a href="https://asocks.com/c/vDu6Dk" target="_blank">
<img width="350px" src="https://raw.githubusercontent.com/fatedier/frp/dev/doc/pic/sponsor_asocks.jpg">
</a>
</p> </p>
<!--gold sponsors end--> <!--gold sponsors end-->
## What is frp? ## What is frp?
@ -343,15 +349,20 @@ Configure `frps` same as above.
Note that it may not work with all types of NAT devices. You might want to fallback to stcp if xtcp doesn't work. Note that it may not work with all types of NAT devices. You might want to fallback to stcp if xtcp doesn't work.
1. Start `frpc` on machine B, and expose the SSH port. Note that the `remote_port` field is removed: 1. In `frps.ini` configure a UDP port for xtcp:
```ini
# frps.ini
bind_udp_port = 7001
```
2. Start `frpc` on machine B, and expose the SSH port. Note that the `remote_port` field is removed:
```ini ```ini
# frpc.ini # frpc.ini
[common] [common]
server_addr = x.x.x.x server_addr = x.x.x.x
server_port = 7000 server_port = 7000
# set up a new stun server if the default one is not available.
# nat_hole_stun_server = xxx
[p2p_ssh] [p2p_ssh]
type = xtcp type = xtcp
@ -360,15 +371,13 @@ Note that it may not work with all types of NAT devices. You might want to fallb
local_port = 22 local_port = 22
``` ```
2. Start another `frpc` (typically on another machine C) with the configuration to connect to SSH using P2P mode: 3. Start another `frpc` (typically on another machine C) with the configuration to connect to SSH using P2P mode:
```ini ```ini
# frpc.ini # frpc.ini
[common] [common]
server_addr = x.x.x.x server_addr = x.x.x.x
server_port = 7000 server_port = 7000
# set up a new stun server if the default one is not available.
# nat_hole_stun_server = xxx
[p2p_ssh_visitor] [p2p_ssh_visitor]
type = xtcp type = xtcp
@ -377,11 +386,9 @@ Note that it may not work with all types of NAT devices. You might want to fallb
sk = abcdefg sk = abcdefg
bind_addr = 127.0.0.1 bind_addr = 127.0.0.1
bind_port = 6000 bind_port = 6000
# when automatic tunnel persistence is required, set it to true
keep_tunnel_open = false
``` ```
3. On machine C, connect to SSH on machine B, using this command: 4. On machine C, connect to SSH on machine B, using this command:
`ssh -oPort=6000 127.0.0.1` `ssh -oPort=6000 127.0.0.1`

View File

@ -1,6 +1,6 @@
# frp # frp
[![Build Status](https://circleci.com/gh/fatedier/frp.svg?style=shield)](https://circleci.com/gh/fatedier/frp) [![Build Status](https://travis-ci.org/fatedier/frp.svg?branch=master)](https://travis-ci.org/fatedier/frp)
[![GitHub release](https://img.shields.io/github/tag/fatedier/frp.svg?label=release)](https://github.com/fatedier/frp/releases) [![GitHub release](https://img.shields.io/github/tag/fatedier/frp.svg?label=release)](https://github.com/fatedier/frp/releases)
[README](README.md) | [中文文档](README_zh.md) [README](README.md) | [中文文档](README_zh.md)
@ -13,6 +13,10 @@ frp 是一个专注于内网穿透的高性能的反向代理应用,支持 TCP
<a href="https://workos.com/?utm_campaign=github_repo&utm_medium=referral&utm_content=frp&utm_source=github" target="_blank"> <a href="https://workos.com/?utm_campaign=github_repo&utm_medium=referral&utm_content=frp&utm_source=github" target="_blank">
<img width="350px" src="https://raw.githubusercontent.com/fatedier/frp/dev/doc/pic/sponsor_workos.png"> <img width="350px" src="https://raw.githubusercontent.com/fatedier/frp/dev/doc/pic/sponsor_workos.png">
</a> </a>
<a>&nbsp</a>
<a href="https://asocks.com/c/vDu6Dk" target="_blank">
<img width="350px" src="https://raw.githubusercontent.com/fatedier/frp/dev/doc/pic/sponsor_asocks.jpg">
</a>
</p> </p>
<!--gold sponsors end--> <!--gold sponsors end-->

View File

@ -1,19 +1,8 @@
## Notes
We have thoroughly refactored xtcp in this version to improve its penetration rate and stability.
In this version, different penetration strategies can be attempted by retrying connections multiple times. Once a hole is successfully punched, the strategy will be recorded in the server cache for future reuse. When new users connect, the successfully penetrated tunnel can be reused instead of punching a new hole.
**Due to a significant refactor of xtcp, this version is not compatible with previous versions of xtcp.**
**To use features related to xtcp, both frpc and frps need to be updated to the latest version.**
### New ### New
* The frpc has added the `nathole discover` command for testing the NAT type of the current network. * The `httpconnect` type in `tcpmux` now supports authentication through the parameters `http_user` and `http_pwd`.
* `XTCP` has been refactored, resulting in a significant improvement in the success rate of penetration.
* When verifying passwords, use `subtle.ConstantTimeCompare` and introduce a certain delay when the password is incorrect.
### Fix ### Improved
* Fix the problem of lagging when opening multiple table entries in the frps dashboard. * The web framework has been upgraded to vue3 + element-plus, and the dashboard has added some information display and supports dark mode.
* The e2e testing has been switched to ginkgo v2.

File diff suppressed because one or more lines are too long

View File

@ -4,7 +4,7 @@
<head> <head>
<meta charset="utf-8"> <meta charset="utf-8">
<title>frps dashboard</title> <title>frps dashboard</title>
<script type="module" crossorigin src="./index-ea3edf22.js"></script> <script type="module" crossorigin src="./index-93e38bbf.js"></script>
<link rel="stylesheet" href="./index-1e0c7400.css"> <link rel="stylesheet" href="./index-1e0c7400.css">
</head> </head>

View File

@ -23,7 +23,7 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/fatedier/frp/assets" "github.com/fatedier/frp/assets"
utilnet "github.com/fatedier/frp/pkg/util/net" frpNet "github.com/fatedier/frp/pkg/util/net"
) )
var ( var (
@ -48,7 +48,7 @@ func (svr *Service) RunAdminServer(address string) (err error) {
subRouter := router.NewRoute().Subrouter() subRouter := router.NewRoute().Subrouter()
user, passwd := svr.cfg.AdminUser, svr.cfg.AdminPwd user, passwd := svr.cfg.AdminUser, svr.cfg.AdminPwd
subRouter.Use(utilnet.NewHTTPAuthMiddleware(user, passwd).SetAuthFailDelay(200 * time.Millisecond).Middleware) subRouter.Use(frpNet.NewHTTPAuthMiddleware(user, passwd).Middleware)
// api, see admin_api.go // api, see admin_api.go
subRouter.HandleFunc("/api/reload", svr.apiReload).Methods("GET") subRouter.HandleFunc("/api/reload", svr.apiReload).Methods("GET")
@ -58,7 +58,7 @@ func (svr *Service) RunAdminServer(address string) (err error) {
// view // view
subRouter.Handle("/favicon.ico", http.FileServer(assets.FileSystem)).Methods("GET") subRouter.Handle("/favicon.ico", http.FileServer(assets.FileSystem)).Methods("GET")
subRouter.PathPrefix("/static/").Handler(utilnet.MakeHTTPGzipHandler(http.StripPrefix("/static/", http.FileServer(assets.FileSystem)))).Methods("GET") subRouter.PathPrefix("/static/").Handler(frpNet.MakeHTTPGzipHandler(http.StripPrefix("/static/", http.FileServer(assets.FileSystem)))).Methods("GET")
subRouter.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { subRouter.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/static/", http.StatusMovedPermanently) http.Redirect(w, r, "/static/", http.StatusMovedPermanently)
}) })

View File

@ -25,11 +25,10 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/samber/lo"
"github.com/fatedier/frp/client/proxy" "github.com/fatedier/frp/client/proxy"
"github.com/fatedier/frp/pkg/config" "github.com/fatedier/frp/pkg/config"
"github.com/fatedier/frp/pkg/util/log" "github.com/fatedier/frp/pkg/util/log"
"github.com/fatedier/frp/pkg/util/util"
) )
type GeneralResponse struct { type GeneralResponse struct {
@ -91,7 +90,7 @@ func NewProxyStatusResp(status *proxy.WorkingStatus, serverAddr string) ProxySta
Status: status.Phase, Status: status.Phase,
Err: status.Err, Err: status.Err,
} }
baseCfg := status.Cfg.GetBaseConfig() baseCfg := status.Cfg.GetBaseInfo()
if baseCfg.LocalPort != 0 { if baseCfg.LocalPort != 0 {
psr.LocalAddr = net.JoinHostPort(baseCfg.LocalIP, strconv.Itoa(baseCfg.LocalPort)) psr.LocalAddr = net.JoinHostPort(baseCfg.LocalIP, strconv.Itoa(baseCfg.LocalPort))
} }
@ -99,7 +98,7 @@ func NewProxyStatusResp(status *proxy.WorkingStatus, serverAddr string) ProxySta
if status.Err == "" { if status.Err == "" {
psr.RemoteAddr = status.RemoteAddr psr.RemoteAddr = status.RemoteAddr
if lo.Contains([]string{"tcp", "udp"}, status.Type) { if util.InSlice(status.Type, []string{"tcp", "udp"}) {
psr.RemoteAddr = serverAddr + psr.RemoteAddr psr.RemoteAddr = serverAddr + psr.RemoteAddr
} }
} }

View File

@ -25,21 +25,14 @@ import (
"github.com/fatedier/golib/crypto" "github.com/fatedier/golib/crypto"
"github.com/fatedier/frp/client/proxy" "github.com/fatedier/frp/client/proxy"
"github.com/fatedier/frp/client/visitor"
"github.com/fatedier/frp/pkg/auth" "github.com/fatedier/frp/pkg/auth"
"github.com/fatedier/frp/pkg/config" "github.com/fatedier/frp/pkg/config"
"github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/transport"
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/xlog"
) )
type Control struct { type Control struct {
// service context // uniq id got from frps, attach it in loginMsg
ctx context.Context
xl *xlog.Logger
// Unique ID obtained from frps.
// It should be attached to the login message when reconnecting.
runID string runID string
// manage all proxies // manage all proxies
@ -47,7 +40,7 @@ type Control struct {
pm *proxy.Manager pm *proxy.Manager
// manage all visitors // manage all visitors
vm *visitor.Manager vm *VisitorManager
// control connection // control connection
conn net.Conn conn net.Conn
@ -75,10 +68,16 @@ type Control struct {
writerShutdown *shutdown.Shutdown writerShutdown *shutdown.Shutdown
msgHandlerShutdown *shutdown.Shutdown msgHandlerShutdown *shutdown.Shutdown
// The UDP port that the server is listening on
serverUDPPort int
xl *xlog.Logger
// service context
ctx context.Context
// sets authentication based on selected method // sets authentication based on selected method
authSetter auth.Setter authSetter auth.Setter
msgTransporter transport.MessageTransporter
} }
func NewControl( func NewControl(
@ -86,12 +85,11 @@ func NewControl(
clientCfg config.ClientCommonConf, clientCfg config.ClientCommonConf,
pxyCfgs map[string]config.ProxyConf, pxyCfgs map[string]config.ProxyConf,
visitorCfgs map[string]config.VisitorConf, visitorCfgs map[string]config.VisitorConf,
serverUDPPort int,
authSetter auth.Setter, authSetter auth.Setter,
) *Control { ) *Control {
// new xlog instance // new xlog instance
ctl := &Control{ ctl := &Control{
ctx: ctx,
xl: xlog.FromContextSafe(ctx),
runID: runID, runID: runID,
conn: conn, conn: conn,
cm: cm, cm: cm,
@ -104,12 +102,14 @@ func NewControl(
readerShutdown: shutdown.New(), readerShutdown: shutdown.New(),
writerShutdown: shutdown.New(), writerShutdown: shutdown.New(),
msgHandlerShutdown: shutdown.New(), msgHandlerShutdown: shutdown.New(),
serverUDPPort: serverUDPPort,
xl: xlog.FromContextSafe(ctx),
ctx: ctx,
authSetter: authSetter, authSetter: authSetter,
} }
ctl.msgTransporter = transport.NewMessageTransporter(ctl.sendCh) ctl.pm = proxy.NewManager(ctl.ctx, ctl.sendCh, clientCfg, serverUDPPort)
ctl.pm = proxy.NewManager(ctl.ctx, clientCfg, ctl.msgTransporter)
ctl.vm = visitor.NewManager(ctl.ctx, ctl.runID, ctl.clientCfg, ctl.connectServer, ctl.msgTransporter) ctl.vm = NewVisitorManager(ctl.ctx, ctl)
ctl.vm.Reload(visitorCfgs) ctl.vm.Reload(visitorCfgs)
return ctl return ctl
} }
@ -173,16 +173,6 @@ func (ctl *Control) HandleNewProxyResp(inMsg *msg.NewProxyResp) {
} }
} }
func (ctl *Control) HandleNatHoleResp(inMsg *msg.NatHoleResp) {
xl := ctl.xl
// Dispatch the NatHoleResp message to the related proxy.
ok := ctl.msgTransporter.DispatchWithType(inMsg, msg.TypeNameNatHoleResp, inMsg.TransactionID)
if !ok {
xl.Trace("dispatch NatHoleResp message to related proxy error")
}
}
func (ctl *Control) Close() error { func (ctl *Control) Close() error {
return ctl.GracefulClose(0) return ctl.GracefulClose(0)
} }
@ -198,7 +188,7 @@ func (ctl *Control) GracefulClose(d time.Duration) error {
return nil return nil
} }
// ClosedDoneCh returns a channel that will be closed after all resources are released // ClosedDoneCh returns a channel which will be closed after all resources are released
func (ctl *Control) ClosedDoneCh() <-chan struct{} { func (ctl *Control) ClosedDoneCh() <-chan struct{} {
return ctl.closedDoneCh return ctl.closedDoneCh
} }
@ -260,7 +250,7 @@ func (ctl *Control) writer() {
} }
} }
// msgHandler handles all channel events and performs corresponding operations. // msgHandler handles all channel events and do corresponding operations.
func (ctl *Control) msgHandler() { func (ctl *Control) msgHandler() {
xl := ctl.xl xl := ctl.xl
defer func() { defer func() {
@ -317,8 +307,6 @@ func (ctl *Control) msgHandler() {
go ctl.HandleReqWorkConn(m) go ctl.HandleReqWorkConn(m)
case *msg.NewProxyResp: case *msg.NewProxyResp:
ctl.HandleNewProxyResp(m) ctl.HandleNewProxyResp(m)
case *msg.NatHoleResp:
ctl.HandleNatHoleResp(m)
case *msg.Pong: case *msg.Pong:
if m.Error != "" { if m.Error != "" {
xl.Error("Pong contains error: %s", m.Error) xl.Error("Pong contains error: %s", m.Error)

View File

@ -1,47 +0,0 @@
// Copyright 2023 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
import (
"reflect"
"github.com/fatedier/frp/pkg/config"
)
func init() {
pxyConfs := []config.ProxyConf{
&config.TCPProxyConf{},
&config.HTTPProxyConf{},
&config.HTTPSProxyConf{},
&config.STCPProxyConf{},
&config.TCPMuxProxyConf{},
}
for _, cfg := range pxyConfs {
RegisterProxyFactory(reflect.TypeOf(cfg), NewGeneralTCPProxy)
}
}
// GeneralTCPProxy is a general implementation of Proxy interface for TCP protocol.
// If the default GeneralTCPProxy cannot meet the requirements, you can customize
// the implementation of the Proxy interface.
type GeneralTCPProxy struct {
*BaseProxy
}
func NewGeneralTCPProxy(baseProxy *BaseProxy, cfg config.ProxyConf) Proxy {
return &GeneralTCPProxy{
BaseProxy: baseProxy,
}
}

View File

@ -19,31 +19,28 @@ import (
"context" "context"
"io" "io"
"net" "net"
"reflect"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
libio "github.com/fatedier/golib/io" "github.com/fatedier/golib/errors"
frpIo "github.com/fatedier/golib/io"
libdial "github.com/fatedier/golib/net/dial" libdial "github.com/fatedier/golib/net/dial"
"github.com/fatedier/golib/pool"
fmux "github.com/hashicorp/yamux"
pp "github.com/pires/go-proxyproto" pp "github.com/pires/go-proxyproto"
"golang.org/x/time/rate" "golang.org/x/time/rate"
"github.com/fatedier/frp/pkg/config" "github.com/fatedier/frp/pkg/config"
"github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/msg"
plugin "github.com/fatedier/frp/pkg/plugin/client" plugin "github.com/fatedier/frp/pkg/plugin/client"
"github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/proto/udp"
"github.com/fatedier/frp/pkg/util/limit" "github.com/fatedier/frp/pkg/util/limit"
frpNet "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/xlog"
) )
var proxyFactoryRegistry = map[reflect.Type]func(*BaseProxy, config.ProxyConf) Proxy{}
func RegisterProxyFactory(proxyConfType reflect.Type, factory func(*BaseProxy, config.ProxyConf) Proxy) {
proxyFactoryRegistry[proxyConfType] = factory
}
// Proxy defines how to handle work connections for different proxy type. // Proxy defines how to handle work connections for different proxy type.
type Proxy interface { type Proxy interface {
Run() error Run() error
@ -54,101 +51,715 @@ type Proxy interface {
Close() Close()
} }
func NewProxy( func NewProxy(ctx context.Context, pxyConf config.ProxyConf, clientCfg config.ClientCommonConf, serverUDPPort int) (pxy Proxy) {
ctx context.Context,
pxyConf config.ProxyConf,
clientCfg config.ClientCommonConf,
msgTransporter transport.MessageTransporter,
) (pxy Proxy) {
var limiter *rate.Limiter var limiter *rate.Limiter
limitBytes := pxyConf.GetBaseConfig().BandwidthLimit.Bytes() limitBytes := pxyConf.GetBaseInfo().BandwidthLimit.Bytes()
if limitBytes > 0 && pxyConf.GetBaseConfig().BandwidthLimitMode == config.BandwidthLimitModeClient { if limitBytes > 0 && pxyConf.GetBaseInfo().BandwidthLimitMode == config.BandwidthLimitModeClient {
limiter = rate.NewLimiter(rate.Limit(float64(limitBytes)), int(limitBytes)) limiter = rate.NewLimiter(rate.Limit(float64(limitBytes)), int(limitBytes))
} }
baseProxy := BaseProxy{ baseProxy := BaseProxy{
baseProxyConfig: pxyConf.GetBaseConfig(), clientCfg: clientCfg,
clientCfg: clientCfg, serverUDPPort: serverUDPPort,
limiter: limiter, limiter: limiter,
msgTransporter: msgTransporter, xl: xlog.FromContextSafe(ctx),
xl: xlog.FromContextSafe(ctx), ctx: ctx,
ctx: ctx,
} }
switch cfg := pxyConf.(type) {
factory := proxyFactoryRegistry[reflect.TypeOf(pxyConf)] case *config.TCPProxyConf:
if factory == nil { pxy = &TCPProxy{
return nil BaseProxy: &baseProxy,
cfg: cfg,
}
case *config.TCPMuxProxyConf:
pxy = &TCPMuxProxy{
BaseProxy: &baseProxy,
cfg: cfg,
}
case *config.UDPProxyConf:
pxy = &UDPProxy{
BaseProxy: &baseProxy,
cfg: cfg,
}
case *config.HTTPProxyConf:
pxy = &HTTPProxy{
BaseProxy: &baseProxy,
cfg: cfg,
}
case *config.HTTPSProxyConf:
pxy = &HTTPSProxy{
BaseProxy: &baseProxy,
cfg: cfg,
}
case *config.STCPProxyConf:
pxy = &STCPProxy{
BaseProxy: &baseProxy,
cfg: cfg,
}
case *config.XTCPProxyConf:
pxy = &XTCPProxy{
BaseProxy: &baseProxy,
cfg: cfg,
}
case *config.SUDPProxyConf:
pxy = &SUDPProxy{
BaseProxy: &baseProxy,
cfg: cfg,
closeCh: make(chan struct{}),
}
} }
return factory(&baseProxy, pxyConf) return
} }
type BaseProxy struct { type BaseProxy struct {
baseProxyConfig *config.BaseProxyConf closed bool
clientCfg config.ClientCommonConf clientCfg config.ClientCommonConf
msgTransporter transport.MessageTransporter serverUDPPort int
limiter *rate.Limiter limiter *rate.Limiter
// proxyPlugin is used to handle connections instead of dialing to local service.
// It's only validate for TCP protocol now.
proxyPlugin plugin.Plugin
mu sync.RWMutex mu sync.RWMutex
xl *xlog.Logger xl *xlog.Logger
ctx context.Context ctx context.Context
} }
func (pxy *BaseProxy) Run() error { // TCP
if pxy.baseProxyConfig.Plugin != "" { type TCPProxy struct {
p, err := plugin.Create(pxy.baseProxyConfig.Plugin, pxy.baseProxyConfig.PluginParams) *BaseProxy
if err != nil {
return err cfg *config.TCPProxyConf
} proxyPlugin plugin.Plugin
pxy.proxyPlugin = p
}
return nil
} }
func (pxy *BaseProxy) Close() { func (pxy *TCPProxy) Run() (err error) {
if pxy.cfg.Plugin != "" {
pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams)
if err != nil {
return
}
}
return
}
func (pxy *TCPProxy) Close() {
if pxy.proxyPlugin != nil { if pxy.proxyPlugin != nil {
pxy.proxyPlugin.Close() pxy.proxyPlugin.Close()
} }
} }
func (pxy *BaseProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { func (pxy *TCPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
pxy.HandleTCPWorkConnection(conn, m, []byte(pxy.clientCfg.Token)) HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseInfo(), pxy.limiter,
conn, []byte(pxy.clientCfg.Token), m)
}
// TCP Multiplexer
type TCPMuxProxy struct {
*BaseProxy
cfg *config.TCPMuxProxyConf
proxyPlugin plugin.Plugin
}
func (pxy *TCPMuxProxy) Run() (err error) {
if pxy.cfg.Plugin != "" {
pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams)
if err != nil {
return
}
}
return
}
func (pxy *TCPMuxProxy) Close() {
if pxy.proxyPlugin != nil {
pxy.proxyPlugin.Close()
}
}
func (pxy *TCPMuxProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseInfo(), pxy.limiter,
conn, []byte(pxy.clientCfg.Token), m)
}
// HTTP
type HTTPProxy struct {
*BaseProxy
cfg *config.HTTPProxyConf
proxyPlugin plugin.Plugin
}
func (pxy *HTTPProxy) Run() (err error) {
if pxy.cfg.Plugin != "" {
pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams)
if err != nil {
return
}
}
return
}
func (pxy *HTTPProxy) Close() {
if pxy.proxyPlugin != nil {
pxy.proxyPlugin.Close()
}
}
func (pxy *HTTPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseInfo(), pxy.limiter,
conn, []byte(pxy.clientCfg.Token), m)
}
// HTTPS
type HTTPSProxy struct {
*BaseProxy
cfg *config.HTTPSProxyConf
proxyPlugin plugin.Plugin
}
func (pxy *HTTPSProxy) Run() (err error) {
if pxy.cfg.Plugin != "" {
pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams)
if err != nil {
return
}
}
return
}
func (pxy *HTTPSProxy) Close() {
if pxy.proxyPlugin != nil {
pxy.proxyPlugin.Close()
}
}
func (pxy *HTTPSProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseInfo(), pxy.limiter,
conn, []byte(pxy.clientCfg.Token), m)
}
// STCP
type STCPProxy struct {
*BaseProxy
cfg *config.STCPProxyConf
proxyPlugin plugin.Plugin
}
func (pxy *STCPProxy) Run() (err error) {
if pxy.cfg.Plugin != "" {
pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams)
if err != nil {
return
}
}
return
}
func (pxy *STCPProxy) Close() {
if pxy.proxyPlugin != nil {
pxy.proxyPlugin.Close()
}
}
func (pxy *STCPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseInfo(), pxy.limiter,
conn, []byte(pxy.clientCfg.Token), m)
}
// XTCP
type XTCPProxy struct {
*BaseProxy
cfg *config.XTCPProxyConf
proxyPlugin plugin.Plugin
}
func (pxy *XTCPProxy) Run() (err error) {
if pxy.cfg.Plugin != "" {
pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams)
if err != nil {
return
}
}
return
}
func (pxy *XTCPProxy) Close() {
if pxy.proxyPlugin != nil {
pxy.proxyPlugin.Close()
}
}
func (pxy *XTCPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
xl := pxy.xl
defer conn.Close()
var natHoleSidMsg msg.NatHoleSid
err := msg.ReadMsgInto(conn, &natHoleSidMsg)
if err != nil {
xl.Error("xtcp read from workConn error: %v", err)
return
}
natHoleClientMsg := &msg.NatHoleClient{
ProxyName: pxy.cfg.ProxyName,
Sid: natHoleSidMsg.Sid,
}
serverAddr := pxy.clientCfg.NatHoleServerAddr
if serverAddr == "" {
serverAddr = pxy.clientCfg.ServerAddr
}
raddr, _ := net.ResolveUDPAddr("udp",
net.JoinHostPort(serverAddr, strconv.Itoa(pxy.serverUDPPort)))
clientConn, err := net.DialUDP("udp", nil, raddr)
if err != nil {
xl.Error("dial server udp addr error: %v", err)
return
}
defer clientConn.Close()
err = msg.WriteMsg(clientConn, natHoleClientMsg)
if err != nil {
xl.Error("send natHoleClientMsg to server error: %v", err)
return
}
// Wait for client address at most 5 seconds.
var natHoleRespMsg msg.NatHoleResp
_ = clientConn.SetReadDeadline(time.Now().Add(5 * time.Second))
buf := pool.GetBuf(1024)
n, err := clientConn.Read(buf)
if err != nil {
xl.Error("get natHoleRespMsg error: %v", err)
return
}
err = msg.ReadMsgInto(bytes.NewReader(buf[:n]), &natHoleRespMsg)
if err != nil {
xl.Error("get natHoleRespMsg error: %v", err)
return
}
_ = clientConn.SetReadDeadline(time.Time{})
_ = clientConn.Close()
if natHoleRespMsg.Error != "" {
xl.Error("natHoleRespMsg get error info: %s", natHoleRespMsg.Error)
return
}
xl.Trace("get natHoleRespMsg, sid [%s], client address [%s] visitor address [%s]", natHoleRespMsg.Sid, natHoleRespMsg.ClientAddr, natHoleRespMsg.VisitorAddr)
// Send detect message
host, portStr, err := net.SplitHostPort(natHoleRespMsg.VisitorAddr)
if err != nil {
xl.Error("get NatHoleResp visitor address [%s] error: %v", natHoleRespMsg.VisitorAddr, err)
}
laddr, _ := net.ResolveUDPAddr("udp", clientConn.LocalAddr().String())
port, err := strconv.ParseInt(portStr, 10, 64)
if err != nil {
xl.Error("get natHoleResp visitor address error: %v", natHoleRespMsg.VisitorAddr)
return
}
_ = pxy.sendDetectMsg(host, int(port), laddr, []byte(natHoleRespMsg.Sid))
xl.Trace("send all detect msg done")
if err := msg.WriteMsg(conn, &msg.NatHoleClientDetectOK{}); err != nil {
xl.Error("write message error: %v", err)
return
}
// Listen for clientConn's address and wait for visitor connection
lConn, err := net.ListenUDP("udp", laddr)
if err != nil {
xl.Error("listen on visitorConn's local address error: %v", err)
return
}
defer lConn.Close()
_ = lConn.SetReadDeadline(time.Now().Add(8 * time.Second))
sidBuf := pool.GetBuf(1024)
var uAddr *net.UDPAddr
n, uAddr, err = lConn.ReadFromUDP(sidBuf)
if err != nil {
xl.Warn("get sid from visitor error: %v", err)
return
}
_ = lConn.SetReadDeadline(time.Time{})
if string(sidBuf[:n]) != natHoleRespMsg.Sid {
xl.Warn("incorrect sid from visitor")
return
}
pool.PutBuf(sidBuf)
xl.Info("nat hole connection make success, sid [%s]", natHoleRespMsg.Sid)
if _, err := lConn.WriteToUDP(sidBuf[:n], uAddr); err != nil {
xl.Error("write uaddr error: %v", err)
return
}
kcpConn, err := frpNet.NewKCPConnFromUDP(lConn, false, uAddr.String())
if err != nil {
xl.Error("create kcp connection from udp connection error: %v", err)
return
}
fmuxCfg := fmux.DefaultConfig()
fmuxCfg.KeepAliveInterval = 5 * time.Second
fmuxCfg.LogOutput = io.Discard
sess, err := fmux.Server(kcpConn, fmuxCfg)
if err != nil {
xl.Error("create yamux server from kcp connection error: %v", err)
return
}
defer sess.Close()
muxConn, err := sess.Accept()
if err != nil {
xl.Error("accept for yamux connection error: %v", err)
return
}
HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseInfo(), pxy.limiter,
muxConn, []byte(pxy.cfg.Sk), m)
}
func (pxy *XTCPProxy) sendDetectMsg(addr string, port int, laddr *net.UDPAddr, content []byte) (err error) {
daddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(addr, strconv.Itoa(port)))
if err != nil {
return err
}
tConn, err := net.DialUDP("udp", laddr, daddr)
if err != nil {
return err
}
// uConn := ipv4.NewConn(tConn)
// uConn.SetTTL(3)
if _, err := tConn.Write(content); err != nil {
return err
}
return tConn.Close()
}
// UDP
type UDPProxy struct {
*BaseProxy
cfg *config.UDPProxyConf
localAddr *net.UDPAddr
readCh chan *msg.UDPPacket
// include msg.UDPPacket and msg.Ping
sendCh chan msg.Message
workConn net.Conn
}
func (pxy *UDPProxy) Run() (err error) {
pxy.localAddr, err = net.ResolveUDPAddr("udp", net.JoinHostPort(pxy.cfg.LocalIP, strconv.Itoa(pxy.cfg.LocalPort)))
if err != nil {
return
}
return
}
func (pxy *UDPProxy) Close() {
pxy.mu.Lock()
defer pxy.mu.Unlock()
if !pxy.closed {
pxy.closed = true
if pxy.workConn != nil {
pxy.workConn.Close()
}
if pxy.readCh != nil {
close(pxy.readCh)
}
if pxy.sendCh != nil {
close(pxy.sendCh)
}
}
}
func (pxy *UDPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
xl := pxy.xl
xl.Info("incoming a new work connection for udp proxy, %s", conn.RemoteAddr().String())
// close resources releated with old workConn
pxy.Close()
var rwc io.ReadWriteCloser = conn
var err error
if pxy.limiter != nil {
rwc = frpIo.WrapReadWriteCloser(limit.NewReader(conn, pxy.limiter), limit.NewWriter(conn, pxy.limiter), func() error {
return conn.Close()
})
}
if pxy.cfg.UseEncryption {
rwc, err = frpIo.WithEncryption(rwc, []byte(pxy.clientCfg.Token))
if err != nil {
conn.Close()
xl.Error("create encryption stream error: %v", err)
return
}
}
if pxy.cfg.UseCompression {
rwc = frpIo.WithCompression(rwc)
}
conn = frpNet.WrapReadWriteCloserToConn(rwc, conn)
pxy.mu.Lock()
pxy.workConn = conn
pxy.readCh = make(chan *msg.UDPPacket, 1024)
pxy.sendCh = make(chan msg.Message, 1024)
pxy.closed = false
pxy.mu.Unlock()
workConnReaderFn := func(conn net.Conn, readCh chan *msg.UDPPacket) {
for {
var udpMsg msg.UDPPacket
if errRet := msg.ReadMsgInto(conn, &udpMsg); errRet != nil {
xl.Warn("read from workConn for udp error: %v", errRet)
return
}
if errRet := errors.PanicToError(func() {
xl.Trace("get udp package from workConn: %s", udpMsg.Content)
readCh <- &udpMsg
}); errRet != nil {
xl.Info("reader goroutine for udp work connection closed: %v", errRet)
return
}
}
}
workConnSenderFn := func(conn net.Conn, sendCh chan msg.Message) {
defer func() {
xl.Info("writer goroutine for udp work connection closed")
}()
var errRet error
for rawMsg := range sendCh {
switch m := rawMsg.(type) {
case *msg.UDPPacket:
xl.Trace("send udp package to workConn: %s", m.Content)
case *msg.Ping:
xl.Trace("send ping message to udp workConn")
}
if errRet = msg.WriteMsg(conn, rawMsg); errRet != nil {
xl.Error("udp work write error: %v", errRet)
return
}
}
}
heartbeatFn := func(sendCh chan msg.Message) {
var errRet error
for {
time.Sleep(time.Duration(30) * time.Second)
if errRet = errors.PanicToError(func() {
sendCh <- &msg.Ping{}
}); errRet != nil {
xl.Trace("heartbeat goroutine for udp work connection closed")
break
}
}
}
go workConnSenderFn(pxy.workConn, pxy.sendCh)
go workConnReaderFn(pxy.workConn, pxy.readCh)
go heartbeatFn(pxy.sendCh)
udp.Forwarder(pxy.localAddr, pxy.readCh, pxy.sendCh, int(pxy.clientCfg.UDPPacketSize))
}
type SUDPProxy struct {
*BaseProxy
cfg *config.SUDPProxyConf
localAddr *net.UDPAddr
closeCh chan struct{}
}
func (pxy *SUDPProxy) Run() (err error) {
pxy.localAddr, err = net.ResolveUDPAddr("udp", net.JoinHostPort(pxy.cfg.LocalIP, strconv.Itoa(pxy.cfg.LocalPort)))
if err != nil {
return
}
return
}
func (pxy *SUDPProxy) Close() {
pxy.mu.Lock()
defer pxy.mu.Unlock()
select {
case <-pxy.closeCh:
return
default:
close(pxy.closeCh)
}
}
func (pxy *SUDPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
xl := pxy.xl
xl.Info("incoming a new work connection for sudp proxy, %s", conn.RemoteAddr().String())
var rwc io.ReadWriteCloser = conn
var err error
if pxy.limiter != nil {
rwc = frpIo.WrapReadWriteCloser(limit.NewReader(conn, pxy.limiter), limit.NewWriter(conn, pxy.limiter), func() error {
return conn.Close()
})
}
if pxy.cfg.UseEncryption {
rwc, err = frpIo.WithEncryption(rwc, []byte(pxy.clientCfg.Token))
if err != nil {
conn.Close()
xl.Error("create encryption stream error: %v", err)
return
}
}
if pxy.cfg.UseCompression {
rwc = frpIo.WithCompression(rwc)
}
conn = frpNet.WrapReadWriteCloserToConn(rwc, conn)
workConn := conn
readCh := make(chan *msg.UDPPacket, 1024)
sendCh := make(chan msg.Message, 1024)
isClose := false
mu := &sync.Mutex{}
closeFn := func() {
mu.Lock()
defer mu.Unlock()
if isClose {
return
}
isClose = true
if workConn != nil {
workConn.Close()
}
close(readCh)
close(sendCh)
}
// udp service <- frpc <- frps <- frpc visitor <- user
workConnReaderFn := func(conn net.Conn, readCh chan *msg.UDPPacket) {
defer closeFn()
for {
// first to check sudp proxy is closed or not
select {
case <-pxy.closeCh:
xl.Trace("frpc sudp proxy is closed")
return
default:
}
var udpMsg msg.UDPPacket
if errRet := msg.ReadMsgInto(conn, &udpMsg); errRet != nil {
xl.Warn("read from workConn for sudp error: %v", errRet)
return
}
if errRet := errors.PanicToError(func() {
readCh <- &udpMsg
}); errRet != nil {
xl.Warn("reader goroutine for sudp work connection closed: %v", errRet)
return
}
}
}
// udp service -> frpc -> frps -> frpc visitor -> user
workConnSenderFn := func(conn net.Conn, sendCh chan msg.Message) {
defer func() {
closeFn()
xl.Info("writer goroutine for sudp work connection closed")
}()
var errRet error
for rawMsg := range sendCh {
switch m := rawMsg.(type) {
case *msg.UDPPacket:
xl.Trace("frpc send udp package to frpc visitor, [udp local: %v, remote: %v], [tcp work conn local: %v, remote: %v]",
m.LocalAddr.String(), m.RemoteAddr.String(), conn.LocalAddr().String(), conn.RemoteAddr().String())
case *msg.Ping:
xl.Trace("frpc send ping message to frpc visitor")
}
if errRet = msg.WriteMsg(conn, rawMsg); errRet != nil {
xl.Error("sudp work write error: %v", errRet)
return
}
}
}
heartbeatFn := func(sendCh chan msg.Message) {
ticker := time.NewTicker(30 * time.Second)
defer func() {
ticker.Stop()
closeFn()
}()
var errRet error
for {
select {
case <-ticker.C:
if errRet = errors.PanicToError(func() {
sendCh <- &msg.Ping{}
}); errRet != nil {
xl.Warn("heartbeat goroutine for sudp work connection closed")
return
}
case <-pxy.closeCh:
xl.Trace("frpc sudp proxy is closed")
return
}
}
}
go workConnSenderFn(workConn, sendCh)
go workConnReaderFn(workConn, readCh)
go heartbeatFn(sendCh)
udp.Forwarder(pxy.localAddr, readCh, sendCh, int(pxy.clientCfg.UDPPacketSize))
} }
// Common handler for tcp work connections. // Common handler for tcp work connections.
func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWorkConn, encKey []byte) { func HandleTCPWorkConnection(ctx context.Context, localInfo *config.LocalSvrConf, proxyPlugin plugin.Plugin,
xl := pxy.xl baseInfo *config.BaseProxyConf, limiter *rate.Limiter, workConn net.Conn, encKey []byte, m *msg.StartWorkConn,
baseConfig := pxy.baseProxyConfig ) {
xl := xlog.FromContextSafe(ctx)
var ( var (
remote io.ReadWriteCloser remote io.ReadWriteCloser
err error err error
) )
remote = workConn remote = workConn
if pxy.limiter != nil { if limiter != nil {
remote = libio.WrapReadWriteCloser(limit.NewReader(workConn, pxy.limiter), limit.NewWriter(workConn, pxy.limiter), func() error { remote = frpIo.WrapReadWriteCloser(limit.NewReader(workConn, limiter), limit.NewWriter(workConn, limiter), func() error {
return workConn.Close() return workConn.Close()
}) })
} }
xl.Trace("handle tcp work connection, use_encryption: %t, use_compression: %t", xl.Trace("handle tcp work connection, use_encryption: %t, use_compression: %t",
baseConfig.UseEncryption, baseConfig.UseCompression) baseInfo.UseEncryption, baseInfo.UseCompression)
if baseConfig.UseEncryption { if baseInfo.UseEncryption {
remote, err = libio.WithEncryption(remote, encKey) remote, err = frpIo.WithEncryption(remote, encKey)
if err != nil { if err != nil {
workConn.Close() workConn.Close()
xl.Error("create encryption stream error: %v", err) xl.Error("create encryption stream error: %v", err)
return return
} }
} }
if baseConfig.UseCompression { if baseInfo.UseCompression {
remote = libio.WithCompression(remote) remote = frpIo.WithCompression(remote)
} }
// check if we need to send proxy protocol info // check if we need to send proxy protocol info
var extraInfo []byte var extraInfo []byte
if baseConfig.ProxyProtocolVersion != "" { if baseInfo.ProxyProtocolVersion != "" {
if m.SrcAddr != "" && m.SrcPort != 0 { if m.SrcAddr != "" && m.SrcPort != 0 {
if m.DstAddr == "" { if m.DstAddr == "" {
m.DstAddr = "127.0.0.1" m.DstAddr = "127.0.0.1"
@ -167,9 +778,9 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor
h.TransportProtocol = pp.TCPv6 h.TransportProtocol = pp.TCPv6
} }
if baseConfig.ProxyProtocolVersion == "v1" { if baseInfo.ProxyProtocolVersion == "v1" {
h.Version = 1 h.Version = 1
} else if baseConfig.ProxyProtocolVersion == "v2" { } else if baseInfo.ProxyProtocolVersion == "v2" {
h.Version = 2 h.Version = 2
} }
@ -179,21 +790,21 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor
} }
} }
if pxy.proxyPlugin != nil { if proxyPlugin != nil {
// if plugin is set, let plugin handle connection first // if plugin is set, let plugin handle connections first
xl.Debug("handle by plugin: %s", pxy.proxyPlugin.Name()) xl.Debug("handle by plugin: %s", proxyPlugin.Name())
pxy.proxyPlugin.Handle(remote, workConn, extraInfo) proxyPlugin.Handle(remote, workConn, extraInfo)
xl.Debug("handle by plugin finished") xl.Debug("handle by plugin finished")
return return
} }
localConn, err := libdial.Dial( localConn, err := libdial.Dial(
net.JoinHostPort(baseConfig.LocalIP, strconv.Itoa(baseConfig.LocalPort)), net.JoinHostPort(localInfo.LocalIP, strconv.Itoa(localInfo.LocalPort)),
libdial.WithTimeout(10*time.Second), libdial.WithTimeout(10*time.Second),
) )
if err != nil { if err != nil {
workConn.Close() workConn.Close()
xl.Error("connect to local service [%s:%d] error: %v", baseConfig.LocalIP, baseConfig.LocalPort, err) xl.Error("connect to local service [%s:%d] error: %v", localInfo.LocalIP, localInfo.LocalPort, err)
return return
} }
@ -208,7 +819,7 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor
} }
} }
_, _, errs := libio.Join(localConn, remote) _, _, errs := frpIo.Join(localConn, remote)
xl.Debug("join connections closed") xl.Debug("join connections closed")
if len(errs) > 0 { if len(errs) > 0 {
xl.Trace("join connections errors: %v", errs) xl.Trace("join connections errors: %v", errs)

View File

@ -1,56 +1,42 @@
// Copyright 2023 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy package proxy
import ( import (
"context" "context"
"fmt" "fmt"
"net" "net"
"reflect"
"sync" "sync"
"github.com/fatedier/golib/errors"
"github.com/fatedier/frp/client/event" "github.com/fatedier/frp/client/event"
"github.com/fatedier/frp/pkg/config" "github.com/fatedier/frp/pkg/config"
"github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/transport"
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/xlog"
) )
type Manager struct { type Manager struct {
proxies map[string]*Wrapper sendCh chan (msg.Message)
msgTransporter transport.MessageTransporter proxies map[string]*Wrapper
closed bool closed bool
mu sync.RWMutex mu sync.RWMutex
clientCfg config.ClientCommonConf clientCfg config.ClientCommonConf
// The UDP port that the server is listening on
serverUDPPort int
ctx context.Context ctx context.Context
} }
func NewManager( func NewManager(ctx context.Context, msgSendCh chan (msg.Message), clientCfg config.ClientCommonConf, serverUDPPort int) *Manager {
ctx context.Context,
clientCfg config.ClientCommonConf,
msgTransporter transport.MessageTransporter,
) *Manager {
return &Manager{ return &Manager{
proxies: make(map[string]*Wrapper), sendCh: msgSendCh,
msgTransporter: msgTransporter, proxies: make(map[string]*Wrapper),
closed: false, closed: false,
clientCfg: clientCfg, clientCfg: clientCfg,
ctx: ctx, serverUDPPort: serverUDPPort,
ctx: ctx,
} }
} }
@ -100,7 +86,10 @@ func (pm *Manager) HandleEvent(payload interface{}) error {
return event.ErrPayloadType return event.ErrPayloadType
} }
return pm.msgTransporter.Send(m) err := errors.PanicToError(func() {
pm.sendCh <- m
})
return err
} }
func (pm *Manager) GetAllProxyStatus() []*WorkingStatus { func (pm *Manager) GetAllProxyStatus() []*WorkingStatus {
@ -122,24 +111,27 @@ func (pm *Manager) Reload(pxyCfgs map[string]config.ProxyConf) {
for name, pxy := range pm.proxies { for name, pxy := range pm.proxies {
del := false del := false
cfg, ok := pxyCfgs[name] cfg, ok := pxyCfgs[name]
if !ok || !reflect.DeepEqual(pxy.Cfg, cfg) { if !ok {
del = true
} else if !pxy.Cfg.Compare(cfg) {
del = true del = true
} }
if del { if del {
delPxyNames = append(delPxyNames, name) delPxyNames = append(delPxyNames, name)
delete(pm.proxies, name) delete(pm.proxies, name)
pxy.Stop() pxy.Stop()
} }
} }
if len(delPxyNames) > 0 { if len(delPxyNames) > 0 {
xl.Info("proxy removed: %s", delPxyNames) xl.Info("proxy removed: %v", delPxyNames)
} }
addPxyNames := make([]string, 0) addPxyNames := make([]string, 0)
for name, cfg := range pxyCfgs { for name, cfg := range pxyCfgs {
if _, ok := pm.proxies[name]; !ok { if _, ok := pm.proxies[name]; !ok {
pxy := NewWrapper(pm.ctx, cfg, pm.clientCfg, pm.HandleEvent, pm.msgTransporter) pxy := NewWrapper(pm.ctx, cfg, pm.clientCfg, pm.HandleEvent, pm.serverUDPPort)
pm.proxies[name] = pxy pm.proxies[name] = pxy
addPxyNames = append(addPxyNames, name) addPxyNames = append(addPxyNames, name)
@ -147,6 +139,6 @@ func (pm *Manager) Reload(pxyCfgs map[string]config.ProxyConf) {
} }
} }
if len(addPxyNames) > 0 { if len(addPxyNames) > 0 {
xl.Info("proxy added: %s", addPxyNames) xl.Info("proxy added: %v", addPxyNames)
} }
} }

View File

@ -1,17 +1,3 @@
// Copyright 2023 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy package proxy
import ( import (
@ -28,7 +14,6 @@ import (
"github.com/fatedier/frp/client/health" "github.com/fatedier/frp/client/health"
"github.com/fatedier/frp/pkg/config" "github.com/fatedier/frp/pkg/config"
"github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/transport"
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/xlog"
) )
@ -71,8 +56,6 @@ type Wrapper struct {
// event handler // event handler
handler event.Handler handler event.Handler
msgTransporter transport.MessageTransporter
health uint32 health uint32
lastSendStartMsg time.Time lastSendStartMsg time.Time
lastStartErr time.Time lastStartErr time.Time
@ -84,14 +67,8 @@ type Wrapper struct {
ctx context.Context ctx context.Context
} }
func NewWrapper( func NewWrapper(ctx context.Context, cfg config.ProxyConf, clientCfg config.ClientCommonConf, eventHandler event.Handler, serverUDPPort int) *Wrapper {
ctx context.Context, baseInfo := cfg.GetBaseInfo()
cfg config.ProxyConf,
clientCfg config.ClientCommonConf,
eventHandler event.Handler,
msgTransporter transport.MessageTransporter,
) *Wrapper {
baseInfo := cfg.GetBaseConfig()
xl := xlog.FromContextSafe(ctx).Spawn().AppendPrefix(baseInfo.ProxyName) xl := xlog.FromContextSafe(ctx).Spawn().AppendPrefix(baseInfo.ProxyName)
pw := &Wrapper{ pw := &Wrapper{
WorkingStatus: WorkingStatus{ WorkingStatus: WorkingStatus{
@ -103,7 +80,6 @@ func NewWrapper(
closeCh: make(chan struct{}), closeCh: make(chan struct{}),
healthNotifyCh: make(chan struct{}), healthNotifyCh: make(chan struct{}),
handler: eventHandler, handler: eventHandler,
msgTransporter: msgTransporter,
xl: xl, xl: xl,
ctx: xlog.NewContext(ctx, xl), ctx: xlog.NewContext(ctx, xl),
} }
@ -116,7 +92,7 @@ func NewWrapper(
xl.Trace("enable health check monitor") xl.Trace("enable health check monitor")
} }
pw.pxy = NewProxy(pw.ctx, pw.Cfg, clientCfg, pw.msgTransporter) pw.pxy = NewProxy(pw.ctx, pw.Cfg, clientCfg, serverUDPPort)
return pw return pw
} }

View File

@ -1,207 +0,0 @@
// Copyright 2023 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
import (
"io"
"net"
"reflect"
"strconv"
"sync"
"time"
"github.com/fatedier/golib/errors"
libio "github.com/fatedier/golib/io"
"github.com/fatedier/frp/pkg/config"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/proto/udp"
"github.com/fatedier/frp/pkg/util/limit"
utilnet "github.com/fatedier/frp/pkg/util/net"
)
func init() {
RegisterProxyFactory(reflect.TypeOf(&config.SUDPProxyConf{}), NewSUDPProxy)
}
type SUDPProxy struct {
*BaseProxy
cfg *config.SUDPProxyConf
localAddr *net.UDPAddr
closeCh chan struct{}
}
func NewSUDPProxy(baseProxy *BaseProxy, cfg config.ProxyConf) Proxy {
unwrapped, ok := cfg.(*config.SUDPProxyConf)
if !ok {
return nil
}
return &SUDPProxy{
BaseProxy: baseProxy,
cfg: unwrapped,
closeCh: make(chan struct{}),
}
}
func (pxy *SUDPProxy) Run() (err error) {
pxy.localAddr, err = net.ResolveUDPAddr("udp", net.JoinHostPort(pxy.cfg.LocalIP, strconv.Itoa(pxy.cfg.LocalPort)))
if err != nil {
return
}
return
}
func (pxy *SUDPProxy) Close() {
pxy.mu.Lock()
defer pxy.mu.Unlock()
select {
case <-pxy.closeCh:
return
default:
close(pxy.closeCh)
}
}
func (pxy *SUDPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
xl := pxy.xl
xl.Info("incoming a new work connection for sudp proxy, %s", conn.RemoteAddr().String())
var rwc io.ReadWriteCloser = conn
var err error
if pxy.limiter != nil {
rwc = libio.WrapReadWriteCloser(limit.NewReader(conn, pxy.limiter), limit.NewWriter(conn, pxy.limiter), func() error {
return conn.Close()
})
}
if pxy.cfg.UseEncryption {
rwc, err = libio.WithEncryption(rwc, []byte(pxy.clientCfg.Token))
if err != nil {
conn.Close()
xl.Error("create encryption stream error: %v", err)
return
}
}
if pxy.cfg.UseCompression {
rwc = libio.WithCompression(rwc)
}
conn = utilnet.WrapReadWriteCloserToConn(rwc, conn)
workConn := conn
readCh := make(chan *msg.UDPPacket, 1024)
sendCh := make(chan msg.Message, 1024)
isClose := false
mu := &sync.Mutex{}
closeFn := func() {
mu.Lock()
defer mu.Unlock()
if isClose {
return
}
isClose = true
if workConn != nil {
workConn.Close()
}
close(readCh)
close(sendCh)
}
// udp service <- frpc <- frps <- frpc visitor <- user
workConnReaderFn := func(conn net.Conn, readCh chan *msg.UDPPacket) {
defer closeFn()
for {
// first to check sudp proxy is closed or not
select {
case <-pxy.closeCh:
xl.Trace("frpc sudp proxy is closed")
return
default:
}
var udpMsg msg.UDPPacket
if errRet := msg.ReadMsgInto(conn, &udpMsg); errRet != nil {
xl.Warn("read from workConn for sudp error: %v", errRet)
return
}
if errRet := errors.PanicToError(func() {
readCh <- &udpMsg
}); errRet != nil {
xl.Warn("reader goroutine for sudp work connection closed: %v", errRet)
return
}
}
}
// udp service -> frpc -> frps -> frpc visitor -> user
workConnSenderFn := func(conn net.Conn, sendCh chan msg.Message) {
defer func() {
closeFn()
xl.Info("writer goroutine for sudp work connection closed")
}()
var errRet error
for rawMsg := range sendCh {
switch m := rawMsg.(type) {
case *msg.UDPPacket:
xl.Trace("frpc send udp package to frpc visitor, [udp local: %v, remote: %v], [tcp work conn local: %v, remote: %v]",
m.LocalAddr.String(), m.RemoteAddr.String(), conn.LocalAddr().String(), conn.RemoteAddr().String())
case *msg.Ping:
xl.Trace("frpc send ping message to frpc visitor")
}
if errRet = msg.WriteMsg(conn, rawMsg); errRet != nil {
xl.Error("sudp work write error: %v", errRet)
return
}
}
}
heartbeatFn := func(sendCh chan msg.Message) {
ticker := time.NewTicker(30 * time.Second)
defer func() {
ticker.Stop()
closeFn()
}()
var errRet error
for {
select {
case <-ticker.C:
if errRet = errors.PanicToError(func() {
sendCh <- &msg.Ping{}
}); errRet != nil {
xl.Warn("heartbeat goroutine for sudp work connection closed")
return
}
case <-pxy.closeCh:
xl.Trace("frpc sudp proxy is closed")
return
}
}
}
go workConnSenderFn(workConn, sendCh)
go workConnReaderFn(workConn, readCh)
go heartbeatFn(sendCh)
udp.Forwarder(pxy.localAddr, readCh, sendCh, int(pxy.clientCfg.UDPPacketSize))
}

View File

@ -1,173 +0,0 @@
// Copyright 2023 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
import (
"io"
"net"
"reflect"
"strconv"
"time"
"github.com/fatedier/golib/errors"
libio "github.com/fatedier/golib/io"
"github.com/fatedier/frp/pkg/config"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/proto/udp"
"github.com/fatedier/frp/pkg/util/limit"
utilnet "github.com/fatedier/frp/pkg/util/net"
)
func init() {
RegisterProxyFactory(reflect.TypeOf(&config.UDPProxyConf{}), NewUDPProxy)
}
type UDPProxy struct {
*BaseProxy
cfg *config.UDPProxyConf
localAddr *net.UDPAddr
readCh chan *msg.UDPPacket
// include msg.UDPPacket and msg.Ping
sendCh chan msg.Message
workConn net.Conn
closed bool
}
func NewUDPProxy(baseProxy *BaseProxy, cfg config.ProxyConf) Proxy {
unwrapped, ok := cfg.(*config.UDPProxyConf)
if !ok {
return nil
}
return &UDPProxy{
BaseProxy: baseProxy,
cfg: unwrapped,
}
}
func (pxy *UDPProxy) Run() (err error) {
pxy.localAddr, err = net.ResolveUDPAddr("udp", net.JoinHostPort(pxy.cfg.LocalIP, strconv.Itoa(pxy.cfg.LocalPort)))
if err != nil {
return
}
return
}
func (pxy *UDPProxy) Close() {
pxy.mu.Lock()
defer pxy.mu.Unlock()
if !pxy.closed {
pxy.closed = true
if pxy.workConn != nil {
pxy.workConn.Close()
}
if pxy.readCh != nil {
close(pxy.readCh)
}
if pxy.sendCh != nil {
close(pxy.sendCh)
}
}
}
func (pxy *UDPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
xl := pxy.xl
xl.Info("incoming a new work connection for udp proxy, %s", conn.RemoteAddr().String())
// close resources releated with old workConn
pxy.Close()
var rwc io.ReadWriteCloser = conn
var err error
if pxy.limiter != nil {
rwc = libio.WrapReadWriteCloser(limit.NewReader(conn, pxy.limiter), limit.NewWriter(conn, pxy.limiter), func() error {
return conn.Close()
})
}
if pxy.cfg.UseEncryption {
rwc, err = libio.WithEncryption(rwc, []byte(pxy.clientCfg.Token))
if err != nil {
conn.Close()
xl.Error("create encryption stream error: %v", err)
return
}
}
if pxy.cfg.UseCompression {
rwc = libio.WithCompression(rwc)
}
conn = utilnet.WrapReadWriteCloserToConn(rwc, conn)
pxy.mu.Lock()
pxy.workConn = conn
pxy.readCh = make(chan *msg.UDPPacket, 1024)
pxy.sendCh = make(chan msg.Message, 1024)
pxy.closed = false
pxy.mu.Unlock()
workConnReaderFn := func(conn net.Conn, readCh chan *msg.UDPPacket) {
for {
var udpMsg msg.UDPPacket
if errRet := msg.ReadMsgInto(conn, &udpMsg); errRet != nil {
xl.Warn("read from workConn for udp error: %v", errRet)
return
}
if errRet := errors.PanicToError(func() {
xl.Trace("get udp package from workConn: %s", udpMsg.Content)
readCh <- &udpMsg
}); errRet != nil {
xl.Info("reader goroutine for udp work connection closed: %v", errRet)
return
}
}
}
workConnSenderFn := func(conn net.Conn, sendCh chan msg.Message) {
defer func() {
xl.Info("writer goroutine for udp work connection closed")
}()
var errRet error
for rawMsg := range sendCh {
switch m := rawMsg.(type) {
case *msg.UDPPacket:
xl.Trace("send udp package to workConn: %s", m.Content)
case *msg.Ping:
xl.Trace("send ping message to udp workConn")
}
if errRet = msg.WriteMsg(conn, rawMsg); errRet != nil {
xl.Error("udp work write error: %v", errRet)
return
}
}
}
heartbeatFn := func(sendCh chan msg.Message) {
var errRet error
for {
time.Sleep(time.Duration(30) * time.Second)
if errRet = errors.PanicToError(func() {
sendCh <- &msg.Ping{}
}); errRet != nil {
xl.Trace("heartbeat goroutine for udp work connection closed")
break
}
}
}
go workConnSenderFn(pxy.workConn, pxy.sendCh)
go workConnReaderFn(pxy.workConn, pxy.readCh)
go heartbeatFn(pxy.sendCh)
udp.Forwarder(pxy.localAddr, pxy.readCh, pxy.sendCh, int(pxy.clientCfg.UDPPacketSize))
}

View File

@ -1,195 +0,0 @@
// Copyright 2023 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
import (
"io"
"net"
"reflect"
"time"
fmux "github.com/hashicorp/yamux"
"github.com/quic-go/quic-go"
"github.com/fatedier/frp/pkg/config"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/nathole"
"github.com/fatedier/frp/pkg/transport"
utilnet "github.com/fatedier/frp/pkg/util/net"
)
func init() {
RegisterProxyFactory(reflect.TypeOf(&config.XTCPProxyConf{}), NewXTCPProxy)
}
type XTCPProxy struct {
*BaseProxy
cfg *config.XTCPProxyConf
}
func NewXTCPProxy(baseProxy *BaseProxy, cfg config.ProxyConf) Proxy {
unwrapped, ok := cfg.(*config.XTCPProxyConf)
if !ok {
return nil
}
return &XTCPProxy{
BaseProxy: baseProxy,
cfg: unwrapped,
}
}
func (pxy *XTCPProxy) InWorkConn(conn net.Conn, startWorkConnMsg *msg.StartWorkConn) {
xl := pxy.xl
defer conn.Close()
var natHoleSidMsg msg.NatHoleSid
err := msg.ReadMsgInto(conn, &natHoleSidMsg)
if err != nil {
xl.Error("xtcp read from workConn error: %v", err)
return
}
prepareResult, err := nathole.Prepare([]string{pxy.clientCfg.NatHoleSTUNServer})
if err != nil {
xl.Warn("nathole prepare error: %v", err)
return
}
xl.Info("nathole prepare success, nat type: %s, behavior: %s, addresses: %v, assistedAddresses: %v",
prepareResult.NatType, prepareResult.Behavior, prepareResult.Addrs, prepareResult.AssistedAddrs)
defer prepareResult.ListenConn.Close()
// send NatHoleClient msg to server
transactionID := nathole.NewTransactionID()
natHoleClientMsg := &msg.NatHoleClient{
TransactionID: transactionID,
ProxyName: pxy.cfg.ProxyName,
Sid: natHoleSidMsg.Sid,
MappedAddrs: prepareResult.Addrs,
AssistedAddrs: prepareResult.AssistedAddrs,
}
natHoleRespMsg, err := nathole.ExchangeInfo(pxy.ctx, pxy.msgTransporter, transactionID, natHoleClientMsg, 5*time.Second)
if err != nil {
xl.Warn("nathole exchange info error: %v", err)
return
}
xl.Info("get natHoleRespMsg, sid [%s], protocol [%s], candidate address %v, assisted address %v, detectBehavior: %+v",
natHoleRespMsg.Sid, natHoleRespMsg.Protocol, natHoleRespMsg.CandidateAddrs,
natHoleRespMsg.AssistedAddrs, natHoleRespMsg.DetectBehavior)
listenConn := prepareResult.ListenConn
newListenConn, raddr, err := nathole.MakeHole(pxy.ctx, listenConn, natHoleRespMsg, []byte(pxy.cfg.Sk))
if err != nil {
listenConn.Close()
xl.Warn("make hole error: %v", err)
_ = pxy.msgTransporter.Send(&msg.NatHoleReport{
Sid: natHoleRespMsg.Sid,
Success: false,
})
return
}
listenConn = newListenConn
xl.Info("establishing nat hole connection successful, sid [%s], remoteAddr [%s]", natHoleRespMsg.Sid, raddr)
_ = pxy.msgTransporter.Send(&msg.NatHoleReport{
Sid: natHoleRespMsg.Sid,
Success: true,
})
if natHoleRespMsg.Protocol == "kcp" {
pxy.listenByKCP(listenConn, raddr, startWorkConnMsg)
return
}
// default is quic
pxy.listenByQUIC(listenConn, raddr, startWorkConnMsg)
}
func (pxy *XTCPProxy) listenByKCP(listenConn *net.UDPConn, raddr *net.UDPAddr, startWorkConnMsg *msg.StartWorkConn) {
xl := pxy.xl
listenConn.Close()
laddr, _ := net.ResolveUDPAddr("udp", listenConn.LocalAddr().String())
lConn, err := net.DialUDP("udp", laddr, raddr)
if err != nil {
xl.Warn("dial udp error: %v", err)
return
}
defer lConn.Close()
remote, err := utilnet.NewKCPConnFromUDP(lConn, true, raddr.String())
if err != nil {
xl.Warn("create kcp connection from udp connection error: %v", err)
return
}
fmuxCfg := fmux.DefaultConfig()
fmuxCfg.KeepAliveInterval = 10 * time.Second
fmuxCfg.MaxStreamWindowSize = 2 * 1024 * 1024
fmuxCfg.LogOutput = io.Discard
session, err := fmux.Server(remote, fmuxCfg)
if err != nil {
xl.Error("create mux session error: %v", err)
return
}
defer session.Close()
for {
muxConn, err := session.Accept()
if err != nil {
xl.Error("accept connection error: %v", err)
return
}
go pxy.HandleTCPWorkConnection(muxConn, startWorkConnMsg, []byte(pxy.cfg.Sk))
}
}
func (pxy *XTCPProxy) listenByQUIC(listenConn *net.UDPConn, _ *net.UDPAddr, startWorkConnMsg *msg.StartWorkConn) {
xl := pxy.xl
defer listenConn.Close()
tlsConfig, err := transport.NewServerTLSConfig("", "", "")
if err != nil {
xl.Warn("create tls config error: %v", err)
return
}
tlsConfig.NextProtos = []string{"frp"}
quicListener, err := quic.Listen(listenConn, tlsConfig,
&quic.Config{
MaxIdleTimeout: time.Duration(pxy.clientCfg.QUICMaxIdleTimeout) * time.Second,
MaxIncomingStreams: int64(pxy.clientCfg.QUICMaxIncomingStreams),
KeepAlivePeriod: time.Duration(pxy.clientCfg.QUICKeepalivePeriod) * time.Second,
},
)
if err != nil {
xl.Warn("dial quic error: %v", err)
return
}
// only accept one connection from raddr
c, err := quicListener.Accept(pxy.ctx)
if err != nil {
xl.Error("quic accept connection error: %v", err)
return
}
for {
stream, err := c.AcceptStream(pxy.ctx)
if err != nil {
xl.Debug("quic accept stream error: %v", err)
_ = c.CloseWithError(0, "")
return
}
go pxy.HandleTCPWorkConnection(utilnet.QuicStreamToNetConn(stream, c), startWorkConnMsg, []byte(pxy.cfg.Sk))
}
}

View File

@ -39,7 +39,7 @@ import (
"github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/transport"
"github.com/fatedier/frp/pkg/util/log" "github.com/fatedier/frp/pkg/util/log"
utilnet "github.com/fatedier/frp/pkg/util/net" frpNet "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/util" "github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/version" "github.com/fatedier/frp/pkg/util/version"
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/xlog"
@ -72,6 +72,9 @@ type Service struct {
// string if no configuration file was used. // string if no configuration file was used.
cfgFile string cfgFile string
// This is configured by the login response from frps
serverUDPPort int
exit uint32 // 0 means not exit exit uint32 // 0 means not exit
// service context // service context
@ -138,7 +141,7 @@ func (svr *Service) Run() error {
util.RandomSleep(10*time.Second, 0.9, 1.1) util.RandomSleep(10*time.Second, 0.9, 1.1)
} else { } else {
// login success // login success
ctl := NewControl(svr.ctx, svr.runID, conn, cm, svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.authSetter) ctl := NewControl(svr.ctx, svr.runID, conn, cm, svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.serverUDPPort, svr.authSetter)
ctl.Run() ctl.Run()
svr.ctlMu.Lock() svr.ctlMu.Lock()
svr.ctl = ctl svr.ctl = ctl
@ -220,7 +223,7 @@ func (svr *Service) keepControllerWorking() {
// reconnect success, init delayTime // reconnect success, init delayTime
delayTime = time.Second delayTime = time.Second
ctl := NewControl(svr.ctx, svr.runID, conn, cm, svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.authSetter) ctl := NewControl(svr.ctx, svr.runID, conn, cm, svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.serverUDPPort, svr.authSetter)
ctl.Run() ctl.Run()
svr.ctlMu.Lock() svr.ctlMu.Lock()
if svr.ctl != nil { if svr.ctl != nil {
@ -292,7 +295,8 @@ func (svr *Service) login() (conn net.Conn, cm *ConnectionManager, err error) {
xl.ResetPrefixes() xl.ResetPrefixes()
xl.AppendPrefix(svr.runID) xl.AppendPrefix(svr.runID)
xl.Info("login to server success, get run id [%s]", loginRespMsg.RunID) svr.serverUDPPort = loginRespMsg.ServerUDPPort
xl.Info("login to server success, get run id [%s], server udp port [%d]", loginRespMsg.RunID, loginRespMsg.ServerUDPPort)
return return
} }
@ -369,8 +373,7 @@ func (cm *ConnectionManager) OpenConnection() error {
} }
tlsConfig.NextProtos = []string{"frp"} tlsConfig.NextProtos = []string{"frp"}
conn, err := quic.DialAddrContext( conn, err := quic.DialAddr(
cm.ctx,
net.JoinHostPort(cm.cfg.ServerAddr, strconv.Itoa(cm.cfg.ServerPort)), net.JoinHostPort(cm.cfg.ServerAddr, strconv.Itoa(cm.cfg.ServerPort)),
tlsConfig, &quic.Config{ tlsConfig, &quic.Config{
MaxIdleTimeout: time.Duration(cm.cfg.QUICMaxIdleTimeout) * time.Second, MaxIdleTimeout: time.Duration(cm.cfg.QUICMaxIdleTimeout) * time.Second,
@ -410,7 +413,7 @@ func (cm *ConnectionManager) Connect() (net.Conn, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return utilnet.QuicStreamToNetConn(stream, cm.quicConn), nil return frpNet.QuicStreamToNetConn(stream, cm.quicConn), nil
} else if cm.muxSession != nil { } else if cm.muxSession != nil {
stream, err := cm.muxSession.OpenStream() stream, err := cm.muxSession.OpenStream()
if err != nil { if err != nil {
@ -452,7 +455,7 @@ func (cm *ConnectionManager) realConnect() (net.Conn, error) {
protocol := cm.cfg.Protocol protocol := cm.cfg.Protocol
if protocol == "websocket" { if protocol == "websocket" {
protocol = "tcp" protocol = "tcp"
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: utilnet.DialHookWebsocket()})) dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: frpNet.DialHookWebsocket()}))
} }
if cm.cfg.ConnectServerLocalIP != "" { if cm.cfg.ConnectServerLocalIP != "" {
dialOptions = append(dialOptions, libdial.WithLocalAddr(cm.cfg.ConnectServerLocalIP)) dialOptions = append(dialOptions, libdial.WithLocalAddr(cm.cfg.ConnectServerLocalIP))
@ -465,11 +468,10 @@ func (cm *ConnectionManager) realConnect() (net.Conn, error) {
libdial.WithProxyAuth(auth), libdial.WithProxyAuth(auth),
libdial.WithTLSConfig(tlsConfig), libdial.WithTLSConfig(tlsConfig),
libdial.WithAfterHook(libdial.AfterHook{ libdial.WithAfterHook(libdial.AfterHook{
Hook: utilnet.DialHookCustomTLSHeadByte(tlsConfig != nil, cm.cfg.DisableCustomTLSFirstByte), Hook: frpNet.DialHookCustomTLSHeadByte(tlsConfig != nil, cm.cfg.DisableCustomTLSFirstByte),
}), }),
) )
conn, err := libdial.DialContext( conn, err := libdial.Dial(
cm.ctx,
net.JoinHostPort(cm.cfg.ServerAddr, strconv.Itoa(cm.cfg.ServerPort)), net.JoinHostPort(cm.cfg.ServerAddr, strconv.Itoa(cm.cfg.ServerPort)),
dialOptions..., dialOptions...,
) )

575
client/visitor.go Normal file
View File

@ -0,0 +1,575 @@
// Copyright 2017 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"bytes"
"context"
"fmt"
"io"
"net"
"strconv"
"sync"
"time"
"github.com/fatedier/golib/errors"
frpIo "github.com/fatedier/golib/io"
"github.com/fatedier/golib/pool"
fmux "github.com/hashicorp/yamux"
"github.com/fatedier/frp/pkg/config"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/proto/udp"
frpNet "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/xlog"
)
// Visitor is used for forward traffics from local port to remote service.
type Visitor interface {
Run() error
Close()
}
func NewVisitor(ctx context.Context, ctl *Control, cfg config.VisitorConf) (visitor Visitor) {
xl := xlog.FromContextSafe(ctx).Spawn().AppendPrefix(cfg.GetBaseInfo().ProxyName)
baseVisitor := BaseVisitor{
ctl: ctl,
ctx: xlog.NewContext(ctx, xl),
}
switch cfg := cfg.(type) {
case *config.STCPVisitorConf:
visitor = &STCPVisitor{
BaseVisitor: &baseVisitor,
cfg: cfg,
}
case *config.XTCPVisitorConf:
visitor = &XTCPVisitor{
BaseVisitor: &baseVisitor,
cfg: cfg,
}
case *config.SUDPVisitorConf:
visitor = &SUDPVisitor{
BaseVisitor: &baseVisitor,
cfg: cfg,
checkCloseCh: make(chan struct{}),
}
}
return
}
type BaseVisitor struct {
ctl *Control
l net.Listener
mu sync.RWMutex
ctx context.Context
}
type STCPVisitor struct {
*BaseVisitor
cfg *config.STCPVisitorConf
}
func (sv *STCPVisitor) Run() (err error) {
sv.l, err = net.Listen("tcp", net.JoinHostPort(sv.cfg.BindAddr, strconv.Itoa(sv.cfg.BindPort)))
if err != nil {
return
}
go sv.worker()
return
}
func (sv *STCPVisitor) Close() {
sv.l.Close()
}
func (sv *STCPVisitor) worker() {
xl := xlog.FromContextSafe(sv.ctx)
for {
conn, err := sv.l.Accept()
if err != nil {
xl.Warn("stcp local listener closed")
return
}
go sv.handleConn(conn)
}
}
func (sv *STCPVisitor) handleConn(userConn net.Conn) {
xl := xlog.FromContextSafe(sv.ctx)
defer userConn.Close()
xl.Debug("get a new stcp user connection")
visitorConn, err := sv.ctl.connectServer()
if err != nil {
return
}
defer visitorConn.Close()
now := time.Now().Unix()
newVisitorConnMsg := &msg.NewVisitorConn{
ProxyName: sv.cfg.ServerName,
SignKey: util.GetAuthKey(sv.cfg.Sk, now),
Timestamp: now,
UseEncryption: sv.cfg.UseEncryption,
UseCompression: sv.cfg.UseCompression,
}
err = msg.WriteMsg(visitorConn, newVisitorConnMsg)
if err != nil {
xl.Warn("send newVisitorConnMsg to server error: %v", err)
return
}
var newVisitorConnRespMsg msg.NewVisitorConnResp
_ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg)
if err != nil {
xl.Warn("get newVisitorConnRespMsg error: %v", err)
return
}
_ = visitorConn.SetReadDeadline(time.Time{})
if newVisitorConnRespMsg.Error != "" {
xl.Warn("start new visitor connection error: %s", newVisitorConnRespMsg.Error)
return
}
var remote io.ReadWriteCloser
remote = visitorConn
if sv.cfg.UseEncryption {
remote, err = frpIo.WithEncryption(remote, []byte(sv.cfg.Sk))
if err != nil {
xl.Error("create encryption stream error: %v", err)
return
}
}
if sv.cfg.UseCompression {
remote = frpIo.WithCompression(remote)
}
frpIo.Join(userConn, remote)
}
type XTCPVisitor struct {
*BaseVisitor
cfg *config.XTCPVisitorConf
}
func (sv *XTCPVisitor) Run() (err error) {
sv.l, err = net.Listen("tcp", net.JoinHostPort(sv.cfg.BindAddr, strconv.Itoa(sv.cfg.BindPort)))
if err != nil {
return
}
go sv.worker()
return
}
func (sv *XTCPVisitor) Close() {
sv.l.Close()
}
func (sv *XTCPVisitor) worker() {
xl := xlog.FromContextSafe(sv.ctx)
for {
conn, err := sv.l.Accept()
if err != nil {
xl.Warn("xtcp local listener closed")
return
}
go sv.handleConn(conn)
}
}
func (sv *XTCPVisitor) handleConn(userConn net.Conn) {
xl := xlog.FromContextSafe(sv.ctx)
defer userConn.Close()
xl.Debug("get a new xtcp user connection")
if sv.ctl.serverUDPPort == 0 {
xl.Error("xtcp is not supported by server")
return
}
serverAddr := sv.ctl.clientCfg.NatHoleServerAddr
if serverAddr == "" {
serverAddr = sv.ctl.clientCfg.ServerAddr
}
raddr, err := net.ResolveUDPAddr("udp",
net.JoinHostPort(serverAddr, strconv.Itoa(sv.ctl.serverUDPPort)))
if err != nil {
xl.Error("resolve server UDP addr error")
return
}
visitorConn, err := net.DialUDP("udp", nil, raddr)
if err != nil {
xl.Warn("dial server udp addr error: %v", err)
return
}
defer visitorConn.Close()
now := time.Now().Unix()
natHoleVisitorMsg := &msg.NatHoleVisitor{
ProxyName: sv.cfg.ServerName,
SignKey: util.GetAuthKey(sv.cfg.Sk, now),
Timestamp: now,
}
err = msg.WriteMsg(visitorConn, natHoleVisitorMsg)
if err != nil {
xl.Warn("send natHoleVisitorMsg to server error: %v", err)
return
}
// Wait for client address at most 10 seconds.
var natHoleRespMsg msg.NatHoleResp
_ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
buf := pool.GetBuf(1024)
n, err := visitorConn.Read(buf)
if err != nil {
xl.Warn("get natHoleRespMsg error: %v", err)
return
}
err = msg.ReadMsgInto(bytes.NewReader(buf[:n]), &natHoleRespMsg)
if err != nil {
xl.Warn("get natHoleRespMsg error: %v", err)
return
}
_ = visitorConn.SetReadDeadline(time.Time{})
pool.PutBuf(buf)
if natHoleRespMsg.Error != "" {
xl.Error("natHoleRespMsg get error info: %s", natHoleRespMsg.Error)
return
}
xl.Trace("get natHoleRespMsg, sid [%s], client address [%s], visitor address [%s]", natHoleRespMsg.Sid, natHoleRespMsg.ClientAddr, natHoleRespMsg.VisitorAddr)
// Close visitorConn, so we can use it's local address.
visitorConn.Close()
// send sid message to client
laddr, _ := net.ResolveUDPAddr("udp", visitorConn.LocalAddr().String())
daddr, err := net.ResolveUDPAddr("udp", natHoleRespMsg.ClientAddr)
if err != nil {
xl.Error("resolve client udp address error: %v", err)
return
}
lConn, err := net.DialUDP("udp", laddr, daddr)
if err != nil {
xl.Error("dial client udp address error: %v", err)
return
}
defer lConn.Close()
if _, err := lConn.Write([]byte(natHoleRespMsg.Sid)); err != nil {
xl.Error("write sid error: %v", err)
return
}
// read ack sid from client
sidBuf := pool.GetBuf(1024)
_ = lConn.SetReadDeadline(time.Now().Add(8 * time.Second))
n, err = lConn.Read(sidBuf)
if err != nil {
xl.Warn("get sid from client error: %v", err)
return
}
_ = lConn.SetReadDeadline(time.Time{})
if string(sidBuf[:n]) != natHoleRespMsg.Sid {
xl.Warn("incorrect sid from client")
return
}
pool.PutBuf(sidBuf)
xl.Info("nat hole connection make success, sid [%s]", natHoleRespMsg.Sid)
// wrap kcp connection
var remote io.ReadWriteCloser
remote, err = frpNet.NewKCPConnFromUDP(lConn, true, natHoleRespMsg.ClientAddr)
if err != nil {
xl.Error("create kcp connection from udp connection error: %v", err)
return
}
fmuxCfg := fmux.DefaultConfig()
fmuxCfg.KeepAliveInterval = 5 * time.Second
fmuxCfg.LogOutput = io.Discard
sess, err := fmux.Client(remote, fmuxCfg)
if err != nil {
xl.Error("create yamux session error: %v", err)
return
}
defer sess.Close()
muxConn, err := sess.Open()
if err != nil {
xl.Error("open yamux stream error: %v", err)
return
}
var muxConnRWCloser io.ReadWriteCloser = muxConn
if sv.cfg.UseEncryption {
muxConnRWCloser, err = frpIo.WithEncryption(muxConnRWCloser, []byte(sv.cfg.Sk))
if err != nil {
xl.Error("create encryption stream error: %v", err)
return
}
}
if sv.cfg.UseCompression {
muxConnRWCloser = frpIo.WithCompression(muxConnRWCloser)
}
_, _, errs := frpIo.Join(userConn, muxConnRWCloser)
xl.Debug("join connections closed")
if len(errs) > 0 {
xl.Trace("join connections errors: %v", errs)
}
}
type SUDPVisitor struct {
*BaseVisitor
checkCloseCh chan struct{}
// udpConn is the listener of udp packet
udpConn *net.UDPConn
readCh chan *msg.UDPPacket
sendCh chan *msg.UDPPacket
cfg *config.SUDPVisitorConf
}
// SUDP Run start listen a udp port
func (sv *SUDPVisitor) Run() (err error) {
xl := xlog.FromContextSafe(sv.ctx)
addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(sv.cfg.BindAddr, strconv.Itoa(sv.cfg.BindPort)))
if err != nil {
return fmt.Errorf("sudp ResolveUDPAddr error: %v", err)
}
sv.udpConn, err = net.ListenUDP("udp", addr)
if err != nil {
return fmt.Errorf("listen udp port %s error: %v", addr.String(), err)
}
sv.sendCh = make(chan *msg.UDPPacket, 1024)
sv.readCh = make(chan *msg.UDPPacket, 1024)
xl.Info("sudp start to work, listen on %s", addr)
go sv.dispatcher()
go udp.ForwardUserConn(sv.udpConn, sv.readCh, sv.sendCh, int(sv.ctl.clientCfg.UDPPacketSize))
return
}
func (sv *SUDPVisitor) dispatcher() {
xl := xlog.FromContextSafe(sv.ctx)
var (
visitorConn net.Conn
err error
firstPacket *msg.UDPPacket
)
for {
select {
case firstPacket = <-sv.sendCh:
if firstPacket == nil {
xl.Info("frpc sudp visitor proxy is closed")
return
}
case <-sv.checkCloseCh:
xl.Info("frpc sudp visitor proxy is closed")
return
}
visitorConn, err = sv.getNewVisitorConn()
if err != nil {
xl.Warn("newVisitorConn to frps error: %v, try to reconnect", err)
continue
}
// visitorConn always be closed when worker done.
sv.worker(visitorConn, firstPacket)
select {
case <-sv.checkCloseCh:
return
default:
}
}
}
func (sv *SUDPVisitor) worker(workConn net.Conn, firstPacket *msg.UDPPacket) {
xl := xlog.FromContextSafe(sv.ctx)
xl.Debug("starting sudp proxy worker")
wg := &sync.WaitGroup{}
wg.Add(2)
closeCh := make(chan struct{})
// udp service -> frpc -> frps -> frpc visitor -> user
workConnReaderFn := func(conn net.Conn) {
defer func() {
conn.Close()
close(closeCh)
wg.Done()
}()
for {
var (
rawMsg msg.Message
errRet error
)
// frpc will send heartbeat in workConn to frpc visitor for keeping alive
_ = conn.SetReadDeadline(time.Now().Add(60 * time.Second))
if rawMsg, errRet = msg.ReadMsg(conn); errRet != nil {
xl.Warn("read from workconn for user udp conn error: %v", errRet)
return
}
_ = conn.SetReadDeadline(time.Time{})
switch m := rawMsg.(type) {
case *msg.Ping:
xl.Debug("frpc visitor get ping message from frpc")
continue
case *msg.UDPPacket:
if errRet := errors.PanicToError(func() {
sv.readCh <- m
xl.Trace("frpc visitor get udp packet from workConn: %s", m.Content)
}); errRet != nil {
xl.Info("reader goroutine for udp work connection closed")
return
}
}
}
}
// udp service <- frpc <- frps <- frpc visitor <- user
workConnSenderFn := func(conn net.Conn) {
defer func() {
conn.Close()
wg.Done()
}()
var errRet error
if firstPacket != nil {
if errRet = msg.WriteMsg(conn, firstPacket); errRet != nil {
xl.Warn("sender goroutine for udp work connection closed: %v", errRet)
return
}
xl.Trace("send udp package to workConn: %s", firstPacket.Content)
}
for {
select {
case udpMsg, ok := <-sv.sendCh:
if !ok {
xl.Info("sender goroutine for udp work connection closed")
return
}
if errRet = msg.WriteMsg(conn, udpMsg); errRet != nil {
xl.Warn("sender goroutine for udp work connection closed: %v", errRet)
return
}
xl.Trace("send udp package to workConn: %s", udpMsg.Content)
case <-closeCh:
return
}
}
}
go workConnReaderFn(workConn)
go workConnSenderFn(workConn)
wg.Wait()
xl.Info("sudp worker is closed")
}
func (sv *SUDPVisitor) getNewVisitorConn() (net.Conn, error) {
xl := xlog.FromContextSafe(sv.ctx)
visitorConn, err := sv.ctl.connectServer()
if err != nil {
return nil, fmt.Errorf("frpc connect frps error: %v", err)
}
now := time.Now().Unix()
newVisitorConnMsg := &msg.NewVisitorConn{
ProxyName: sv.cfg.ServerName,
SignKey: util.GetAuthKey(sv.cfg.Sk, now),
Timestamp: now,
UseEncryption: sv.cfg.UseEncryption,
UseCompression: sv.cfg.UseCompression,
}
err = msg.WriteMsg(visitorConn, newVisitorConnMsg)
if err != nil {
return nil, fmt.Errorf("frpc send newVisitorConnMsg to frps error: %v", err)
}
var newVisitorConnRespMsg msg.NewVisitorConnResp
_ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg)
if err != nil {
return nil, fmt.Errorf("frpc read newVisitorConnRespMsg error: %v", err)
}
_ = visitorConn.SetReadDeadline(time.Time{})
if newVisitorConnRespMsg.Error != "" {
return nil, fmt.Errorf("start new visitor connection error: %s", newVisitorConnRespMsg.Error)
}
var remote io.ReadWriteCloser
remote = visitorConn
if sv.cfg.UseEncryption {
remote, err = frpIo.WithEncryption(remote, []byte(sv.cfg.Sk))
if err != nil {
xl.Error("create encryption stream error: %v", err)
return nil, err
}
}
if sv.cfg.UseCompression {
remote = frpIo.WithCompression(remote)
}
return frpNet.WrapReadWriteCloserToConn(remote, visitorConn), nil
}
func (sv *SUDPVisitor) Close() {
sv.mu.Lock()
defer sv.mu.Unlock()
select {
case <-sv.checkCloseCh:
return
default:
close(sv.checkCloseCh)
}
if sv.udpConn != nil {
sv.udpConn.Close()
}
close(sv.readCh)
close(sv.sendCh)
}

View File

@ -1,133 +0,0 @@
// Copyright 2017 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package visitor
import (
"io"
"net"
"strconv"
"time"
libio "github.com/fatedier/golib/io"
"github.com/fatedier/frp/pkg/config"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/xlog"
)
type STCPVisitor struct {
*BaseVisitor
cfg *config.STCPVisitorConf
}
func (sv *STCPVisitor) Run() (err error) {
if sv.cfg.BindPort > 0 {
sv.l, err = net.Listen("tcp", net.JoinHostPort(sv.cfg.BindAddr, strconv.Itoa(sv.cfg.BindPort)))
if err != nil {
return
}
go sv.worker()
}
go sv.internalConnWorker()
return
}
func (sv *STCPVisitor) Close() {
sv.BaseVisitor.Close()
}
func (sv *STCPVisitor) worker() {
xl := xlog.FromContextSafe(sv.ctx)
for {
conn, err := sv.l.Accept()
if err != nil {
xl.Warn("stcp local listener closed")
return
}
go sv.handleConn(conn)
}
}
func (sv *STCPVisitor) internalConnWorker() {
xl := xlog.FromContextSafe(sv.ctx)
for {
conn, err := sv.internalLn.Accept()
if err != nil {
xl.Warn("stcp internal listener closed")
return
}
go sv.handleConn(conn)
}
}
func (sv *STCPVisitor) handleConn(userConn net.Conn) {
xl := xlog.FromContextSafe(sv.ctx)
defer userConn.Close()
xl.Debug("get a new stcp user connection")
visitorConn, err := sv.helper.ConnectServer()
if err != nil {
return
}
defer visitorConn.Close()
now := time.Now().Unix()
newVisitorConnMsg := &msg.NewVisitorConn{
RunID: sv.helper.RunID(),
ProxyName: sv.cfg.ServerName,
SignKey: util.GetAuthKey(sv.cfg.Sk, now),
Timestamp: now,
UseEncryption: sv.cfg.UseEncryption,
UseCompression: sv.cfg.UseCompression,
}
err = msg.WriteMsg(visitorConn, newVisitorConnMsg)
if err != nil {
xl.Warn("send newVisitorConnMsg to server error: %v", err)
return
}
var newVisitorConnRespMsg msg.NewVisitorConnResp
_ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg)
if err != nil {
xl.Warn("get newVisitorConnRespMsg error: %v", err)
return
}
_ = visitorConn.SetReadDeadline(time.Time{})
if newVisitorConnRespMsg.Error != "" {
xl.Warn("start new visitor connection error: %s", newVisitorConnRespMsg.Error)
return
}
var remote io.ReadWriteCloser
remote = visitorConn
if sv.cfg.UseEncryption {
remote, err = libio.WithEncryption(remote, []byte(sv.cfg.Sk))
if err != nil {
xl.Error("create encryption stream error: %v", err)
return
}
}
if sv.cfg.UseCompression {
remote = libio.WithCompression(remote)
}
libio.Join(userConn, remote)
}

View File

@ -1,264 +0,0 @@
// Copyright 2017 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package visitor
import (
"fmt"
"io"
"net"
"strconv"
"sync"
"time"
"github.com/fatedier/golib/errors"
libio "github.com/fatedier/golib/io"
"github.com/fatedier/frp/pkg/config"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/proto/udp"
utilnet "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/xlog"
)
type SUDPVisitor struct {
*BaseVisitor
checkCloseCh chan struct{}
// udpConn is the listener of udp packet
udpConn *net.UDPConn
readCh chan *msg.UDPPacket
sendCh chan *msg.UDPPacket
cfg *config.SUDPVisitorConf
}
// SUDP Run start listen a udp port
func (sv *SUDPVisitor) Run() (err error) {
xl := xlog.FromContextSafe(sv.ctx)
addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(sv.cfg.BindAddr, strconv.Itoa(sv.cfg.BindPort)))
if err != nil {
return fmt.Errorf("sudp ResolveUDPAddr error: %v", err)
}
sv.udpConn, err = net.ListenUDP("udp", addr)
if err != nil {
return fmt.Errorf("listen udp port %s error: %v", addr.String(), err)
}
sv.sendCh = make(chan *msg.UDPPacket, 1024)
sv.readCh = make(chan *msg.UDPPacket, 1024)
xl.Info("sudp start to work, listen on %s", addr)
go sv.dispatcher()
go udp.ForwardUserConn(sv.udpConn, sv.readCh, sv.sendCh, int(sv.clientCfg.UDPPacketSize))
return
}
func (sv *SUDPVisitor) dispatcher() {
xl := xlog.FromContextSafe(sv.ctx)
var (
visitorConn net.Conn
err error
firstPacket *msg.UDPPacket
)
for {
select {
case firstPacket = <-sv.sendCh:
if firstPacket == nil {
xl.Info("frpc sudp visitor proxy is closed")
return
}
case <-sv.checkCloseCh:
xl.Info("frpc sudp visitor proxy is closed")
return
}
visitorConn, err = sv.getNewVisitorConn()
if err != nil {
xl.Warn("newVisitorConn to frps error: %v, try to reconnect", err)
continue
}
// visitorConn always be closed when worker done.
sv.worker(visitorConn, firstPacket)
select {
case <-sv.checkCloseCh:
return
default:
}
}
}
func (sv *SUDPVisitor) worker(workConn net.Conn, firstPacket *msg.UDPPacket) {
xl := xlog.FromContextSafe(sv.ctx)
xl.Debug("starting sudp proxy worker")
wg := &sync.WaitGroup{}
wg.Add(2)
closeCh := make(chan struct{})
// udp service -> frpc -> frps -> frpc visitor -> user
workConnReaderFn := func(conn net.Conn) {
defer func() {
conn.Close()
close(closeCh)
wg.Done()
}()
for {
var (
rawMsg msg.Message
errRet error
)
// frpc will send heartbeat in workConn to frpc visitor for keeping alive
_ = conn.SetReadDeadline(time.Now().Add(60 * time.Second))
if rawMsg, errRet = msg.ReadMsg(conn); errRet != nil {
xl.Warn("read from workconn for user udp conn error: %v", errRet)
return
}
_ = conn.SetReadDeadline(time.Time{})
switch m := rawMsg.(type) {
case *msg.Ping:
xl.Debug("frpc visitor get ping message from frpc")
continue
case *msg.UDPPacket:
if errRet := errors.PanicToError(func() {
sv.readCh <- m
xl.Trace("frpc visitor get udp packet from workConn: %s", m.Content)
}); errRet != nil {
xl.Info("reader goroutine for udp work connection closed")
return
}
}
}
}
// udp service <- frpc <- frps <- frpc visitor <- user
workConnSenderFn := func(conn net.Conn) {
defer func() {
conn.Close()
wg.Done()
}()
var errRet error
if firstPacket != nil {
if errRet = msg.WriteMsg(conn, firstPacket); errRet != nil {
xl.Warn("sender goroutine for udp work connection closed: %v", errRet)
return
}
xl.Trace("send udp package to workConn: %s", firstPacket.Content)
}
for {
select {
case udpMsg, ok := <-sv.sendCh:
if !ok {
xl.Info("sender goroutine for udp work connection closed")
return
}
if errRet = msg.WriteMsg(conn, udpMsg); errRet != nil {
xl.Warn("sender goroutine for udp work connection closed: %v", errRet)
return
}
xl.Trace("send udp package to workConn: %s", udpMsg.Content)
case <-closeCh:
return
}
}
}
go workConnReaderFn(workConn)
go workConnSenderFn(workConn)
wg.Wait()
xl.Info("sudp worker is closed")
}
func (sv *SUDPVisitor) getNewVisitorConn() (net.Conn, error) {
xl := xlog.FromContextSafe(sv.ctx)
visitorConn, err := sv.helper.ConnectServer()
if err != nil {
return nil, fmt.Errorf("frpc connect frps error: %v", err)
}
now := time.Now().Unix()
newVisitorConnMsg := &msg.NewVisitorConn{
RunID: sv.helper.RunID(),
ProxyName: sv.cfg.ServerName,
SignKey: util.GetAuthKey(sv.cfg.Sk, now),
Timestamp: now,
UseEncryption: sv.cfg.UseEncryption,
UseCompression: sv.cfg.UseCompression,
}
err = msg.WriteMsg(visitorConn, newVisitorConnMsg)
if err != nil {
return nil, fmt.Errorf("frpc send newVisitorConnMsg to frps error: %v", err)
}
var newVisitorConnRespMsg msg.NewVisitorConnResp
_ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg)
if err != nil {
return nil, fmt.Errorf("frpc read newVisitorConnRespMsg error: %v", err)
}
_ = visitorConn.SetReadDeadline(time.Time{})
if newVisitorConnRespMsg.Error != "" {
return nil, fmt.Errorf("start new visitor connection error: %s", newVisitorConnRespMsg.Error)
}
var remote io.ReadWriteCloser
remote = visitorConn
if sv.cfg.UseEncryption {
remote, err = libio.WithEncryption(remote, []byte(sv.cfg.Sk))
if err != nil {
xl.Error("create encryption stream error: %v", err)
return nil, err
}
}
if sv.cfg.UseCompression {
remote = libio.WithCompression(remote)
}
return utilnet.WrapReadWriteCloserToConn(remote, visitorConn), nil
}
func (sv *SUDPVisitor) Close() {
sv.mu.Lock()
defer sv.mu.Unlock()
select {
case <-sv.checkCloseCh:
return
default:
close(sv.checkCloseCh)
}
sv.BaseVisitor.Close()
if sv.udpConn != nil {
sv.udpConn.Close()
}
close(sv.readCh)
close(sv.sendCh)
}

View File

@ -1,104 +0,0 @@
// Copyright 2017 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package visitor
import (
"context"
"net"
"sync"
"github.com/fatedier/frp/pkg/config"
"github.com/fatedier/frp/pkg/transport"
utilnet "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/xlog"
)
// Helper wrapps some functions for visitor to use.
type Helper interface {
// ConnectServer directly connects to the frp server.
ConnectServer() (net.Conn, error)
// TransferConn transfers the connection to another visitor.
TransferConn(string, net.Conn) error
// MsgTransporter returns the message transporter that is used to send and receive messages
// to the frp server through the controller.
MsgTransporter() transport.MessageTransporter
// RunID returns the run id of current controller.
RunID() string
}
// Visitor is used for forward traffics from local port tot remote service.
type Visitor interface {
Run() error
AcceptConn(conn net.Conn) error
Close()
}
func NewVisitor(
ctx context.Context,
cfg config.VisitorConf,
clientCfg config.ClientCommonConf,
helper Helper,
) (visitor Visitor) {
xl := xlog.FromContextSafe(ctx).Spawn().AppendPrefix(cfg.GetBaseConfig().ProxyName)
baseVisitor := BaseVisitor{
clientCfg: clientCfg,
helper: helper,
ctx: xlog.NewContext(ctx, xl),
internalLn: utilnet.NewInternalListener(),
}
switch cfg := cfg.(type) {
case *config.STCPVisitorConf:
visitor = &STCPVisitor{
BaseVisitor: &baseVisitor,
cfg: cfg,
}
case *config.XTCPVisitorConf:
visitor = &XTCPVisitor{
BaseVisitor: &baseVisitor,
cfg: cfg,
startTunnelCh: make(chan struct{}),
}
case *config.SUDPVisitorConf:
visitor = &SUDPVisitor{
BaseVisitor: &baseVisitor,
cfg: cfg,
checkCloseCh: make(chan struct{}),
}
}
return
}
type BaseVisitor struct {
clientCfg config.ClientCommonConf
helper Helper
l net.Listener
internalLn *utilnet.InternalListener
mu sync.RWMutex
ctx context.Context
}
func (v *BaseVisitor) AcceptConn(conn net.Conn) error {
return v.internalLn.PutConn(conn)
}
func (v *BaseVisitor) Close() {
if v.l != nil {
v.l.Close()
}
if v.internalLn != nil {
v.internalLn.Close()
}
}

View File

@ -1,452 +0,0 @@
// Copyright 2017 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package visitor
import (
"context"
"errors"
"fmt"
"io"
"net"
"strconv"
"sync"
"time"
libio "github.com/fatedier/golib/io"
fmux "github.com/hashicorp/yamux"
quic "github.com/quic-go/quic-go"
"golang.org/x/time/rate"
"github.com/fatedier/frp/pkg/config"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/nathole"
"github.com/fatedier/frp/pkg/transport"
utilnet "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/xlog"
)
var ErrNoTunnelSession = errors.New("no tunnel session")
type XTCPVisitor struct {
*BaseVisitor
session TunnelSession
startTunnelCh chan struct{}
retryLimiter *rate.Limiter
cancel context.CancelFunc
cfg *config.XTCPVisitorConf
}
func (sv *XTCPVisitor) Run() (err error) {
sv.ctx, sv.cancel = context.WithCancel(sv.ctx)
if sv.cfg.Protocol == "kcp" {
sv.session = NewKCPTunnelSession()
} else {
sv.session = NewQUICTunnelSession(&sv.clientCfg)
}
if sv.cfg.BindPort > 0 {
sv.l, err = net.Listen("tcp", net.JoinHostPort(sv.cfg.BindAddr, strconv.Itoa(sv.cfg.BindPort)))
if err != nil {
return
}
go sv.worker()
}
go sv.internalConnWorker()
go sv.processTunnelStartEvents()
if sv.cfg.KeepTunnelOpen {
sv.retryLimiter = rate.NewLimiter(rate.Every(time.Hour/time.Duration(sv.cfg.MaxRetriesAnHour)), sv.cfg.MaxRetriesAnHour)
go sv.keepTunnelOpenWorker()
}
return
}
func (sv *XTCPVisitor) Close() {
sv.mu.Lock()
defer sv.mu.Unlock()
sv.BaseVisitor.Close()
if sv.cancel != nil {
sv.cancel()
}
if sv.session != nil {
sv.session.Close()
}
}
func (sv *XTCPVisitor) worker() {
xl := xlog.FromContextSafe(sv.ctx)
for {
conn, err := sv.l.Accept()
if err != nil {
xl.Warn("xtcp local listener closed")
return
}
go sv.handleConn(conn)
}
}
func (sv *XTCPVisitor) internalConnWorker() {
xl := xlog.FromContextSafe(sv.ctx)
for {
conn, err := sv.internalLn.Accept()
if err != nil {
xl.Warn("xtcp internal listener closed")
return
}
go sv.handleConn(conn)
}
}
func (sv *XTCPVisitor) processTunnelStartEvents() {
for {
select {
case <-sv.ctx.Done():
return
case <-sv.startTunnelCh:
start := time.Now()
sv.makeNatHole()
duration := time.Since(start)
// avoid too frequently
if duration < 10*time.Second {
time.Sleep(10*time.Second - duration)
}
}
}
}
func (sv *XTCPVisitor) keepTunnelOpenWorker() {
xl := xlog.FromContextSafe(sv.ctx)
ticker := time.NewTicker(time.Duration(sv.cfg.MinRetryInterval) * time.Second)
defer ticker.Stop()
sv.startTunnelCh <- struct{}{}
for {
select {
case <-sv.ctx.Done():
return
case <-ticker.C:
xl.Debug("keepTunnelOpenWorker try to check tunnel...")
conn, err := sv.getTunnelConn()
if err != nil {
xl.Warn("keepTunnelOpenWorker get tunnel connection error: %v", err)
_ = sv.retryLimiter.Wait(sv.ctx)
continue
}
xl.Debug("keepTunnelOpenWorker check success")
if conn != nil {
conn.Close()
}
}
}
}
func (sv *XTCPVisitor) handleConn(userConn net.Conn) {
xl := xlog.FromContextSafe(sv.ctx)
isConnTrasfered := false
defer func() {
if !isConnTrasfered {
userConn.Close()
}
}()
xl.Debug("get a new xtcp user connection")
// Open a tunnel connection to the server. If there is already a successful hole-punching connection,
// it will be reused. Otherwise, it will block and wait for a successful hole-punching connection until timeout.
ctx := context.Background()
if sv.cfg.FallbackTo != "" {
timeoutCtx, cancel := context.WithTimeout(ctx, time.Duration(sv.cfg.FallbackTimeoutMs)*time.Millisecond)
defer cancel()
ctx = timeoutCtx
}
tunnelConn, err := sv.openTunnel(ctx)
if err != nil {
xl.Error("open tunnel error: %v", err)
// no fallback, just return
if sv.cfg.FallbackTo == "" {
return
}
xl.Debug("try to transfer connection to visitor: %s", sv.cfg.FallbackTo)
if err := sv.helper.TransferConn(sv.cfg.FallbackTo, userConn); err != nil {
xl.Error("transfer connection to visitor %s error: %v", sv.cfg.FallbackTo, err)
return
}
isConnTrasfered = true
return
}
var muxConnRWCloser io.ReadWriteCloser = tunnelConn
if sv.cfg.UseEncryption {
muxConnRWCloser, err = libio.WithEncryption(muxConnRWCloser, []byte(sv.cfg.Sk))
if err != nil {
xl.Error("create encryption stream error: %v", err)
return
}
}
if sv.cfg.UseCompression {
muxConnRWCloser = libio.WithCompression(muxConnRWCloser)
}
_, _, errs := libio.Join(userConn, muxConnRWCloser)
xl.Debug("join connections closed")
if len(errs) > 0 {
xl.Trace("join connections errors: %v", errs)
}
}
// openTunnel will open a tunnel connection to the target server.
func (sv *XTCPVisitor) openTunnel(ctx context.Context) (conn net.Conn, err error) {
xl := xlog.FromContextSafe(sv.ctx)
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
timeoutC := time.After(20 * time.Second)
immediateTrigger := make(chan struct{}, 1)
defer close(immediateTrigger)
immediateTrigger <- struct{}{}
for {
select {
case <-sv.ctx.Done():
return nil, sv.ctx.Err()
case <-ctx.Done():
return nil, ctx.Err()
case <-immediateTrigger:
conn, err = sv.getTunnelConn()
case <-ticker.C:
conn, err = sv.getTunnelConn()
case <-timeoutC:
return nil, fmt.Errorf("open tunnel timeout")
}
if err != nil {
if err != ErrNoTunnelSession {
xl.Warn("get tunnel connection error: %v", err)
}
continue
}
return conn, nil
}
}
func (sv *XTCPVisitor) getTunnelConn() (net.Conn, error) {
conn, err := sv.session.OpenConn(sv.ctx)
if err == nil {
return conn, nil
}
sv.session.Close()
select {
case sv.startTunnelCh <- struct{}{}:
default:
}
return nil, err
}
// 0. PreCheck
// 1. Prepare
// 2. ExchangeInfo
// 3. MakeNATHole
// 4. Create a tunnel session using an underlying UDP connection.
func (sv *XTCPVisitor) makeNatHole() {
xl := xlog.FromContextSafe(sv.ctx)
if err := nathole.PreCheck(sv.ctx, sv.helper.MsgTransporter(), sv.cfg.ServerName, 5*time.Second); err != nil {
xl.Warn("nathole precheck error: %v", err)
return
}
prepareResult, err := nathole.Prepare([]string{sv.clientCfg.NatHoleSTUNServer})
if err != nil {
xl.Warn("nathole prepare error: %v", err)
return
}
xl.Info("nathole prepare success, nat type: %s, behavior: %s, addresses: %v, assistedAddresses: %v",
prepareResult.NatType, prepareResult.Behavior, prepareResult.Addrs, prepareResult.AssistedAddrs)
listenConn := prepareResult.ListenConn
// send NatHoleVisitor to server
now := time.Now().Unix()
transactionID := nathole.NewTransactionID()
natHoleVisitorMsg := &msg.NatHoleVisitor{
TransactionID: transactionID,
ProxyName: sv.cfg.ServerName,
Protocol: sv.cfg.Protocol,
SignKey: util.GetAuthKey(sv.cfg.Sk, now),
Timestamp: now,
MappedAddrs: prepareResult.Addrs,
AssistedAddrs: prepareResult.AssistedAddrs,
}
natHoleRespMsg, err := nathole.ExchangeInfo(sv.ctx, sv.helper.MsgTransporter(), transactionID, natHoleVisitorMsg, 5*time.Second)
if err != nil {
listenConn.Close()
xl.Warn("nathole exchange info error: %v", err)
return
}
xl.Info("get natHoleRespMsg, sid [%s], protocol [%s], candidate address %v, assisted address %v, detectBehavior: %+v",
natHoleRespMsg.Sid, natHoleRespMsg.Protocol, natHoleRespMsg.CandidateAddrs,
natHoleRespMsg.AssistedAddrs, natHoleRespMsg.DetectBehavior)
newListenConn, raddr, err := nathole.MakeHole(sv.ctx, listenConn, natHoleRespMsg, []byte(sv.cfg.Sk))
if err != nil {
listenConn.Close()
xl.Warn("make hole error: %v", err)
return
}
listenConn = newListenConn
xl.Info("establishing nat hole connection successful, sid [%s], remoteAddr [%s]", natHoleRespMsg.Sid, raddr)
if err := sv.session.Init(listenConn, raddr); err != nil {
listenConn.Close()
xl.Warn("init tunnel session error: %v", err)
return
}
}
type TunnelSession interface {
Init(listenConn *net.UDPConn, raddr *net.UDPAddr) error
OpenConn(context.Context) (net.Conn, error)
Close()
}
type KCPTunnelSession struct {
session *fmux.Session
lConn *net.UDPConn
mu sync.RWMutex
}
func NewKCPTunnelSession() TunnelSession {
return &KCPTunnelSession{}
}
func (ks *KCPTunnelSession) Init(listenConn *net.UDPConn, raddr *net.UDPAddr) error {
listenConn.Close()
laddr, _ := net.ResolveUDPAddr("udp", listenConn.LocalAddr().String())
lConn, err := net.DialUDP("udp", laddr, raddr)
if err != nil {
return fmt.Errorf("dial udp error: %v", err)
}
remote, err := utilnet.NewKCPConnFromUDP(lConn, true, raddr.String())
if err != nil {
return fmt.Errorf("create kcp connection from udp connection error: %v", err)
}
fmuxCfg := fmux.DefaultConfig()
fmuxCfg.KeepAliveInterval = 10 * time.Second
fmuxCfg.MaxStreamWindowSize = 2 * 1024 * 1024
fmuxCfg.LogOutput = io.Discard
session, err := fmux.Client(remote, fmuxCfg)
if err != nil {
remote.Close()
return fmt.Errorf("initial client session error: %v", err)
}
ks.mu.Lock()
ks.session = session
ks.lConn = lConn
ks.mu.Unlock()
return nil
}
func (ks *KCPTunnelSession) OpenConn(ctx context.Context) (net.Conn, error) {
ks.mu.RLock()
defer ks.mu.RUnlock()
session := ks.session
if session == nil {
return nil, ErrNoTunnelSession
}
return session.Open()
}
func (ks *KCPTunnelSession) Close() {
ks.mu.Lock()
defer ks.mu.Unlock()
if ks.session != nil {
_ = ks.session.Close()
ks.session = nil
}
if ks.lConn != nil {
_ = ks.lConn.Close()
ks.lConn = nil
}
}
type QUICTunnelSession struct {
session quic.Connection
listenConn *net.UDPConn
mu sync.RWMutex
clientCfg *config.ClientCommonConf
}
func NewQUICTunnelSession(clientCfg *config.ClientCommonConf) TunnelSession {
return &QUICTunnelSession{
clientCfg: clientCfg,
}
}
func (qs *QUICTunnelSession) Init(listenConn *net.UDPConn, raddr *net.UDPAddr) error {
tlsConfig, err := transport.NewClientTLSConfig("", "", "", raddr.String())
if err != nil {
return fmt.Errorf("create tls config error: %v", err)
}
tlsConfig.NextProtos = []string{"frp"}
quicConn, err := quic.Dial(listenConn, raddr, raddr.String(), tlsConfig,
&quic.Config{
MaxIdleTimeout: time.Duration(qs.clientCfg.QUICMaxIdleTimeout) * time.Second,
MaxIncomingStreams: int64(qs.clientCfg.QUICMaxIncomingStreams),
KeepAlivePeriod: time.Duration(qs.clientCfg.QUICKeepalivePeriod) * time.Second,
})
if err != nil {
return fmt.Errorf("dial quic error: %v", err)
}
qs.mu.Lock()
qs.session = quicConn
qs.listenConn = listenConn
qs.mu.Unlock()
return nil
}
func (qs *QUICTunnelSession) OpenConn(ctx context.Context) (net.Conn, error) {
qs.mu.RLock()
defer qs.mu.RUnlock()
session := qs.session
if session == nil {
return nil, ErrNoTunnelSession
}
stream, err := session.OpenStreamSync(ctx)
if err != nil {
return nil, err
}
return utilnet.QuicStreamToNetConn(stream, session), nil
}
func (qs *QUICTunnelSession) Close() {
qs.mu.Lock()
defer qs.mu.Unlock()
if qs.session != nil {
_ = qs.session.CloseWithError(0, "")
qs.session = nil
}
if qs.listenConn != nil {
_ = qs.listenConn.Close()
qs.listenConn = nil
}
}

View File

@ -12,60 +12,43 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package visitor package client
import ( import (
"context" "context"
"fmt"
"net"
"reflect"
"sync" "sync"
"time" "time"
"github.com/fatedier/frp/pkg/config" "github.com/fatedier/frp/pkg/config"
"github.com/fatedier/frp/pkg/transport"
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/xlog"
) )
type Manager struct { type VisitorManager struct {
clientCfg config.ClientCommonConf ctl *Control
cfgs map[string]config.VisitorConf
visitors map[string]Visitor cfgs map[string]config.VisitorConf
helper Helper visitors map[string]Visitor
checkInterval time.Duration checkInterval time.Duration
mu sync.RWMutex mu sync.Mutex
ctx context.Context ctx context.Context
stopCh chan struct{} stopCh chan struct{}
} }
func NewManager( func NewVisitorManager(ctx context.Context, ctl *Control) *VisitorManager {
ctx context.Context, return &VisitorManager{
runID string, ctl: ctl,
clientCfg config.ClientCommonConf,
connectServer func() (net.Conn, error),
msgTransporter transport.MessageTransporter,
) *Manager {
m := &Manager{
clientCfg: clientCfg,
cfgs: make(map[string]config.VisitorConf), cfgs: make(map[string]config.VisitorConf),
visitors: make(map[string]Visitor), visitors: make(map[string]Visitor),
checkInterval: 10 * time.Second, checkInterval: 10 * time.Second,
ctx: ctx, ctx: ctx,
stopCh: make(chan struct{}), stopCh: make(chan struct{}),
} }
m.helper = &visitorHelperImpl{
connectServerFn: connectServer,
msgTransporter: msgTransporter,
transferConnFn: m.TransferConn,
runID: runID,
}
return m
} }
func (vm *Manager) Run() { func (vm *VisitorManager) Run() {
xl := xlog.FromContextSafe(vm.ctx) xl := xlog.FromContextSafe(vm.ctx)
ticker := time.NewTicker(vm.checkInterval) ticker := time.NewTicker(vm.checkInterval)
@ -79,7 +62,7 @@ func (vm *Manager) Run() {
case <-ticker.C: case <-ticker.C:
vm.mu.Lock() vm.mu.Lock()
for _, cfg := range vm.cfgs { for _, cfg := range vm.cfgs {
name := cfg.GetBaseConfig().ProxyName name := cfg.GetBaseInfo().ProxyName
if _, exist := vm.visitors[name]; !exist { if _, exist := vm.visitors[name]; !exist {
xl.Info("try to start visitor [%s]", name) xl.Info("try to start visitor [%s]", name)
_ = vm.startVisitor(cfg) _ = vm.startVisitor(cfg)
@ -90,24 +73,11 @@ func (vm *Manager) Run() {
} }
} }
func (vm *Manager) Close() {
vm.mu.Lock()
defer vm.mu.Unlock()
for _, v := range vm.visitors {
v.Close()
}
select {
case <-vm.stopCh:
default:
close(vm.stopCh)
}
}
// Hold lock before calling this function. // Hold lock before calling this function.
func (vm *Manager) startVisitor(cfg config.VisitorConf) (err error) { func (vm *VisitorManager) startVisitor(cfg config.VisitorConf) (err error) {
xl := xlog.FromContextSafe(vm.ctx) xl := xlog.FromContextSafe(vm.ctx)
name := cfg.GetBaseConfig().ProxyName name := cfg.GetBaseInfo().ProxyName
visitor := NewVisitor(vm.ctx, cfg, vm.clientCfg, vm.helper) visitor := NewVisitor(vm.ctx, vm.ctl, cfg)
err = visitor.Run() err = visitor.Run()
if err != nil { if err != nil {
xl.Warn("start error: %v", err) xl.Warn("start error: %v", err)
@ -118,7 +88,7 @@ func (vm *Manager) startVisitor(cfg config.VisitorConf) (err error) {
return return
} }
func (vm *Manager) Reload(cfgs map[string]config.VisitorConf) { func (vm *VisitorManager) Reload(cfgs map[string]config.VisitorConf) {
xl := xlog.FromContextSafe(vm.ctx) xl := xlog.FromContextSafe(vm.ctx)
vm.mu.Lock() vm.mu.Lock()
defer vm.mu.Unlock() defer vm.mu.Unlock()
@ -127,7 +97,9 @@ func (vm *Manager) Reload(cfgs map[string]config.VisitorConf) {
for name, oldCfg := range vm.cfgs { for name, oldCfg := range vm.cfgs {
del := false del := false
cfg, ok := cfgs[name] cfg, ok := cfgs[name]
if !ok || !reflect.DeepEqual(oldCfg, cfg) { if !ok {
del = true
} else if !oldCfg.Compare(cfg) {
del = true del = true
} }
@ -157,36 +129,15 @@ func (vm *Manager) Reload(cfgs map[string]config.VisitorConf) {
} }
} }
// TransferConn transfers a connection to a visitor. func (vm *VisitorManager) Close() {
func (vm *Manager) TransferConn(name string, conn net.Conn) error { vm.mu.Lock()
vm.mu.RLock() defer vm.mu.Unlock()
defer vm.mu.RUnlock() for _, v := range vm.visitors {
v, ok := vm.visitors[name] v.Close()
if !ok { }
return fmt.Errorf("visitor [%s] not found", name) select {
case <-vm.stopCh:
default:
close(vm.stopCh)
} }
return v.AcceptConn(conn)
}
type visitorHelperImpl struct {
connectServerFn func() (net.Conn, error)
msgTransporter transport.MessageTransporter
transferConnFn func(name string, conn net.Conn) error
runID string
}
func (v *visitorHelperImpl) ConnectServer() (net.Conn, error) {
return v.connectServerFn()
}
func (v *visitorHelperImpl) TransferConn(name string, conn net.Conn) error {
return v.transferConnFn(name, conn)
}
func (v *visitorHelperImpl) MsgTransporter() transport.MessageTransporter {
return v.msgTransporter
}
func (v *visitorHelperImpl) RunID() string {
return v.runID
} }

View File

@ -79,7 +79,7 @@ var httpCmd = &cobra.Command{
} }
cfg.BandwidthLimitMode = bandwidthLimitMode cfg.BandwidthLimitMode = bandwidthLimitMode
err = cfg.ValidateForClient() err = cfg.CheckForCli()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)

View File

@ -71,7 +71,7 @@ var httpsCmd = &cobra.Command{
} }
cfg.BandwidthLimitMode = bandwidthLimitMode cfg.BandwidthLimitMode = bandwidthLimitMode
err = cfg.ValidateForClient() err = cfg.CheckForCli()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)

View File

@ -16,7 +16,9 @@ package sub
import ( import (
"fmt" "fmt"
"net"
"os" "os"
"strconv"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -26,7 +28,7 @@ import (
var ( var (
natHoleSTUNServer string natHoleSTUNServer string
natHoleLocalAddr string serverUDPPort int
) )
func init() { func init() {
@ -35,8 +37,8 @@ func init() {
rootCmd.AddCommand(natholeCmd) rootCmd.AddCommand(natholeCmd)
natholeCmd.AddCommand(natholeDiscoveryCmd) natholeCmd.AddCommand(natholeDiscoveryCmd)
natholeCmd.PersistentFlags().StringVarP(&natHoleSTUNServer, "nat_hole_stun_server", "", "", "STUN server address for nathole") natholeCmd.PersistentFlags().StringVarP(&natHoleSTUNServer, "nat_hole_stun_server", "", "stun.easyvoip.com:3478", "STUN server address for nathole")
natholeCmd.PersistentFlags().StringVarP(&natHoleLocalAddr, "nat_hole_local_addr", "l", "", "local address to connect STUN server") natholeCmd.PersistentFlags().IntVarP(&serverUDPPort, "server_udp_port", "", 0, "UDP port of frps for nathole")
} }
var natholeCmd = &cobra.Command{ var natholeCmd = &cobra.Command{
@ -46,45 +48,48 @@ var natholeCmd = &cobra.Command{
var natholeDiscoveryCmd = &cobra.Command{ var natholeDiscoveryCmd = &cobra.Command{
Use: "discover", Use: "discover",
Short: "Discover nathole information from stun server", Short: "Discover nathole information by frps and stun server",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
// ignore error here, because we can use command line pameters // ignore error here, because we can use command line pameters
cfg, _, _, err := config.ParseClientConfig(cfgFile) cfg, _, _, _ := config.ParseClientConfig(cfgFile)
if err != nil {
cfg = config.GetDefaultClientConf()
}
if natHoleSTUNServer != "" { if natHoleSTUNServer != "" {
cfg.NatHoleSTUNServer = natHoleSTUNServer cfg.NatHoleSTUNServer = natHoleSTUNServer
} }
if serverUDPPort != 0 {
cfg.ServerUDPPort = serverUDPPort
}
if err := validateForNatHoleDiscovery(cfg); err != nil { if err := validateForNatHoleDiscovery(cfg); err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)
} }
addrs, localAddr, err := nathole.Discover([]string{cfg.NatHoleSTUNServer}, natHoleLocalAddr) serverAddr := ""
if cfg.ServerUDPPort != 0 {
serverAddr = net.JoinHostPort(cfg.ServerAddr, strconv.Itoa(cfg.ServerUDPPort))
}
addresses, err := nathole.Discover(
serverAddr,
[]string{cfg.NatHoleSTUNServer},
[]byte(cfg.Token),
)
if err != nil { if err != nil {
fmt.Println("discover error:", err) fmt.Println("discover error:", err)
os.Exit(1) os.Exit(1)
} }
if len(addrs) < 2 { if len(addresses) < 2 {
fmt.Printf("discover error: can not get enough addresses, need 2, got: %v\n", addrs) fmt.Printf("discover error: can not get enough addresses, need 2, got: %v\n", addresses)
os.Exit(1) os.Exit(1)
} }
localIPs, _ := nathole.ListLocalIPsForNatHole(10) natType, behavior, err := nathole.ClassifyNATType(addresses)
natFeature, err := nathole.ClassifyNATFeature(addrs, localIPs)
if err != nil { if err != nil {
fmt.Println("classify nat feature error:", err) fmt.Println("classify nat type error:", err)
os.Exit(1) os.Exit(1)
} }
fmt.Println("STUN server:", cfg.NatHoleSTUNServer) fmt.Println("Your NAT type is:", natType)
fmt.Println("Your NAT type is:", natFeature.NatType) fmt.Println("Behavior is:", behavior)
fmt.Println("Behavior is:", natFeature.Behavior) fmt.Println("External address is:", addresses)
fmt.Println("External address is:", addrs)
fmt.Println("Local address is:", localAddr.String())
fmt.Println("Public Network:", natFeature.PublicNetwork)
return nil return nil
}, },
} }

View File

@ -53,7 +53,6 @@ var (
logFile string logFile string
logMaxDays int logMaxDays int
disableLogColor bool disableLogColor bool
dnsServer string
proxyName string proxyName string
localIP string localIP string
@ -95,7 +94,6 @@ func RegisterCommonFlags(cmd *cobra.Command) {
cmd.PersistentFlags().IntVarP(&logMaxDays, "log_max_days", "", 3, "log file reversed days") cmd.PersistentFlags().IntVarP(&logMaxDays, "log_max_days", "", 3, "log file reversed days")
cmd.PersistentFlags().BoolVarP(&disableLogColor, "disable_log_color", "", false, "disable log color in console") cmd.PersistentFlags().BoolVarP(&disableLogColor, "disable_log_color", "", false, "disable log color in console")
cmd.PersistentFlags().BoolVarP(&tlsEnable, "tls_enable", "", false, "enable frpc tls") cmd.PersistentFlags().BoolVarP(&tlsEnable, "tls_enable", "", false, "enable frpc tls")
cmd.PersistentFlags().StringVarP(&dnsServer, "dns_server", "", "", "specify dns server instead of using system default one")
} }
var rootCmd = &cobra.Command{ var rootCmd = &cobra.Command{
@ -110,40 +108,39 @@ var rootCmd = &cobra.Command{
// If cfgDir is not empty, run multiple frpc service for each config file in cfgDir. // If cfgDir is not empty, run multiple frpc service for each config file in cfgDir.
// Note that it's only designed for testing. It's not guaranteed to be stable. // Note that it's only designed for testing. It's not guaranteed to be stable.
if cfgDir != "" { if cfgDir != "" {
_ = runMultipleClients(cfgDir) var wg sync.WaitGroup
_ = filepath.WalkDir(cfgDir, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return nil
}
if d.IsDir() {
return nil
}
wg.Add(1)
time.Sleep(time.Millisecond)
go func() {
defer wg.Done()
err := runClient(path)
if err != nil {
fmt.Printf("frpc service error for config file [%s]\n", path)
}
}()
return nil
})
wg.Wait()
return nil return nil
} }
// Do not show command usage here. // Do not show command usage here.
err := runClient(cfgFile) err := runClient(cfgFile)
if err != nil { if err != nil {
fmt.Println(err)
os.Exit(1) os.Exit(1)
} }
return nil return nil
}, },
} }
func runMultipleClients(cfgDir string) error {
var wg sync.WaitGroup
err := filepath.WalkDir(cfgDir, func(path string, d fs.DirEntry, err error) error {
if err != nil || d.IsDir() {
return nil
}
wg.Add(1)
time.Sleep(time.Millisecond)
go func() {
defer wg.Done()
err := runClient(path)
if err != nil {
fmt.Printf("frpc service error for config file [%s]\n", path)
}
}()
return nil
})
wg.Wait()
return err
}
func Execute() { func Execute() {
if err := rootCmd.Execute(); err != nil { if err := rootCmd.Execute(); err != nil {
os.Exit(1) os.Exit(1)
@ -180,7 +177,6 @@ func parseClientCommonCfgFromCmd() (cfg config.ClientCommonConf, err error) {
cfg.LogFile = logFile cfg.LogFile = logFile
cfg.LogMaxDays = int64(logMaxDays) cfg.LogMaxDays = int64(logMaxDays)
cfg.DisableLogColor = disableLogColor cfg.DisableLogColor = disableLogColor
cfg.DNSServer = dnsServer
// Only token authentication is supported in cmd mode // Only token authentication is supported in cmd mode
cfg.ClientConfig = auth.GetDefaultClientConf() cfg.ClientConfig = auth.GetDefaultClientConf()
@ -198,7 +194,6 @@ func parseClientCommonCfgFromCmd() (cfg config.ClientCommonConf, err error) {
func runClient(cfgFilePath string) error { func runClient(cfgFilePath string) error {
cfg, pxyCfgs, visitorCfgs, err := config.ParseClientConfig(cfgFilePath) cfg, pxyCfgs, visitorCfgs, err := config.ParseClientConfig(cfgFilePath)
if err != nil { if err != nil {
fmt.Println(err)
return err return err
} }
return startService(cfg, pxyCfgs, visitorCfgs, cfgFilePath) return startService(cfg, pxyCfgs, visitorCfgs, cfgFilePath)
@ -214,8 +209,8 @@ func startService(
cfg.LogMaxDays, cfg.DisableLogColor) cfg.LogMaxDays, cfg.DisableLogColor)
if cfgFile != "" { if cfgFile != "" {
log.Info("start frpc service for config file [%s]", cfgFile) log.Trace("start frpc service for config file [%s]", cfgFile)
defer log.Info("frpc service for config file [%s] stopped", cfgFile) defer log.Trace("frpc service for config file [%s] stopped", cfgFile)
} }
svr, errRet := client.NewService(cfg, pxyCfgs, visitorCfgs, cfgFile) svr, errRet := client.NewService(cfg, pxyCfgs, visitorCfgs, cfgFile)
if errRet != nil { if errRet != nil {

View File

@ -78,7 +78,7 @@ var stcpCmd = &cobra.Command{
os.Exit(1) os.Exit(1)
} }
cfg.BandwidthLimitMode = bandwidthLimitMode cfg.BandwidthLimitMode = bandwidthLimitMode
err = cfg.ValidateForClient() err = cfg.CheckForCli()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)
@ -95,7 +95,7 @@ var stcpCmd = &cobra.Command{
cfg.ServerName = serverName cfg.ServerName = serverName
cfg.BindAddr = bindAddr cfg.BindAddr = bindAddr
cfg.BindPort = bindPort cfg.BindPort = bindPort
err = cfg.Validate() err = cfg.Check()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)

View File

@ -78,7 +78,7 @@ var sudpCmd = &cobra.Command{
os.Exit(1) os.Exit(1)
} }
cfg.BandwidthLimitMode = bandwidthLimitMode cfg.BandwidthLimitMode = bandwidthLimitMode
err = cfg.ValidateForClient() err = cfg.CheckForCli()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)
@ -95,7 +95,7 @@ var sudpCmd = &cobra.Command{
cfg.ServerName = serverName cfg.ServerName = serverName
cfg.BindAddr = bindAddr cfg.BindAddr = bindAddr
cfg.BindPort = bindPort cfg.BindPort = bindPort
err = cfg.Validate() err = cfg.Check()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)

View File

@ -68,7 +68,7 @@ var tcpCmd = &cobra.Command{
} }
cfg.BandwidthLimitMode = bandwidthLimitMode cfg.BandwidthLimitMode = bandwidthLimitMode
err = cfg.ValidateForClient() err = cfg.CheckForCli()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)

View File

@ -73,7 +73,7 @@ var tcpMuxCmd = &cobra.Command{
} }
cfg.BandwidthLimitMode = bandwidthLimitMode cfg.BandwidthLimitMode = bandwidthLimitMode
err = cfg.ValidateForClient() err = cfg.CheckForCli()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)

View File

@ -68,7 +68,7 @@ var udpCmd = &cobra.Command{
} }
cfg.BandwidthLimitMode = bandwidthLimitMode cfg.BandwidthLimitMode = bandwidthLimitMode
err = cfg.ValidateForClient() err = cfg.CheckForCli()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)

View File

@ -78,7 +78,7 @@ var xtcpCmd = &cobra.Command{
os.Exit(1) os.Exit(1)
} }
cfg.BandwidthLimitMode = bandwidthLimitMode cfg.BandwidthLimitMode = bandwidthLimitMode
err = cfg.ValidateForClient() err = cfg.CheckForCli()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)
@ -95,7 +95,7 @@ var xtcpCmd = &cobra.Command{
cfg.ServerName = serverName cfg.ServerName = serverName
cfg.BindAddr = bindAddr cfg.BindAddr = bindAddr
cfg.BindPort = bindPort cfg.BindPort = bindPort
err = cfg.Validate() err = cfg.Check()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)

View File

@ -39,6 +39,7 @@ var (
bindAddr string bindAddr string
bindPort int bindPort int
bindUDPPort int
kcpBindPort int kcpBindPort int
proxyBindAddr string proxyBindAddr string
vhostHTTPPort int vhostHTTPPort int
@ -69,6 +70,7 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&bindAddr, "bind_addr", "", "0.0.0.0", "bind address") rootCmd.PersistentFlags().StringVarP(&bindAddr, "bind_addr", "", "0.0.0.0", "bind address")
rootCmd.PersistentFlags().IntVarP(&bindPort, "bind_port", "p", 7000, "bind port") rootCmd.PersistentFlags().IntVarP(&bindPort, "bind_port", "p", 7000, "bind port")
rootCmd.PersistentFlags().IntVarP(&bindUDPPort, "bind_udp_port", "", 0, "bind udp port")
rootCmd.PersistentFlags().IntVarP(&kcpBindPort, "kcp_bind_port", "", 0, "kcp bind udp port") rootCmd.PersistentFlags().IntVarP(&kcpBindPort, "kcp_bind_port", "", 0, "kcp bind udp port")
rootCmd.PersistentFlags().StringVarP(&proxyBindAddr, "proxy_bind_addr", "", "0.0.0.0", "proxy bind address") rootCmd.PersistentFlags().StringVarP(&proxyBindAddr, "proxy_bind_addr", "", "0.0.0.0", "proxy bind address")
rootCmd.PersistentFlags().IntVarP(&vhostHTTPPort, "vhost_http_port", "", 0, "vhost http port") rootCmd.PersistentFlags().IntVarP(&vhostHTTPPort, "vhost_http_port", "", 0, "vhost http port")
@ -157,6 +159,7 @@ func parseServerCommonCfgFromCmd() (cfg config.ServerCommonConf, err error) {
cfg.BindAddr = bindAddr cfg.BindAddr = bindAddr
cfg.BindPort = bindPort cfg.BindPort = bindPort
cfg.BindUDPPort = bindUDPPort
cfg.KCPBindPort = kcpBindPort cfg.KCPBindPort = kcpBindPort
cfg.ProxyBindAddr = proxyBindAddr cfg.ProxyBindAddr = proxyBindAddr
cfg.VhostHTTPPort = vhostHTTPPort cfg.VhostHTTPPort = vhostHTTPPort

View File

@ -6,6 +6,14 @@
server_addr = 0.0.0.0 server_addr = 0.0.0.0
server_port = 7000 server_port = 7000
# Specify another address of the server to connect for nat hole. By default, it's same with
# server_addr.
# nat_hole_server_addr = 0.0.0.0
# ServerUDPPort specifies the server port to help penetrate NAT hole. By default, this value is 0.
# This parameter is only used when executing "nathole discover" in the command line.
# server_udp_port = 0
# STUN server to help penetrate NAT hole. # STUN server to help penetrate NAT hole.
# nat_hole_stun_server = stun.easyvoip.com:3478 # nat_hole_stun_server = stun.easyvoip.com:3478
@ -326,9 +334,6 @@ local_ip = 127.0.0.1
local_port = 22 local_port = 22
use_encryption = false use_encryption = false
use_compression = false use_compression = false
# If not empty, only visitors from specified users can connect.
# Otherwise, visitors from same user can connect. '*' means allow all users.
allow_users = *
# user of frpc should be same in both stcp server and stcp visitor # user of frpc should be same in both stcp server and stcp visitor
[secret_tcp_visitor] [secret_tcp_visitor]
@ -340,8 +345,6 @@ server_name = secret_tcp
sk = abcdefg sk = abcdefg
# connect this address to visitor stcp server # connect this address to visitor stcp server
bind_addr = 127.0.0.1 bind_addr = 127.0.0.1
# bind_port can be less than 0, it means don't bind to the port and only receive connections redirected from
# other visitors. (This is not supported for SUDP now)
bind_port = 9000 bind_port = 9000
use_encryption = false use_encryption = false
use_compression = false use_compression = false
@ -353,30 +356,16 @@ local_ip = 127.0.0.1
local_port = 22 local_port = 22
use_encryption = false use_encryption = false
use_compression = false use_compression = false
# If not empty, only visitors from specified users can connect.
# Otherwise, visitors from same user can connect. '*' means allow all users.
allow_users = user1, user2
[p2p_tcp_visitor] [p2p_tcp_visitor]
role = visitor role = visitor
type = xtcp type = xtcp
# if the server user is not set, it defaults to the current user
server_user = user1
server_name = p2p_tcp server_name = p2p_tcp
sk = abcdefg sk = abcdefg
bind_addr = 127.0.0.1 bind_addr = 127.0.0.1
# bind_port can be less than 0, it means don't bind to the port and only receive connections redirected from
# other visitors. (This is not supported for SUDP now)
bind_port = 9001 bind_port = 9001
use_encryption = false use_encryption = false
use_compression = false use_compression = false
# when automatic tunnel persistence is required, set it to true
keep_tunnel_open = false
# effective when keep_tunnel_open is set to true, the number of attempts to punch through per hour
max_retries_an_hour = 8
min_retry_interval = 90
# fallback_to = stcp_visitor
# fallback_timeout_ms = 500
[tcpmuxhttpconnect] [tcpmuxhttpconnect]
type = tcpmux type = tcpmux

View File

@ -6,6 +6,9 @@
bind_addr = 0.0.0.0 bind_addr = 0.0.0.0
bind_port = 7000 bind_port = 7000
# udp port to help make udp hole to penetrate nat
bind_udp_port = 7001
# udp port used for kcp protocol, it can be same with 'bind_port'. # udp port used for kcp protocol, it can be same with 'bind_port'.
# if not set, kcp is disabled in frps. # if not set, kcp is disabled in frps.
kcp_bind_port = 7000 kcp_bind_port = 7000
@ -154,9 +157,6 @@ udp_packet_size = 1500
# Dashboard port must be set first # Dashboard port must be set first
pprof_enable = false pprof_enable = false
# Retention time for NAT hole punching strategy data.
nat_hole_analysis_data_reserve_hours = 168
[plugin.user-manager] [plugin.user-manager]
addr = 127.0.0.1:9000 addr = 127.0.0.1:9000
path = /handler path = /handler

9
go.mod
View File

@ -18,14 +18,12 @@ require (
github.com/pion/stun v0.4.0 github.com/pion/stun v0.4.0
github.com/pires/go-proxyproto v0.6.2 github.com/pires/go-proxyproto v0.6.2
github.com/prometheus/client_golang v1.13.0 github.com/prometheus/client_golang v1.13.0
github.com/quic-go/quic-go v0.34.0 github.com/quic-go/quic-go v0.32.0
github.com/rodaine/table v1.0.1 github.com/rodaine/table v1.0.1
github.com/samber/lo v1.38.1
github.com/spf13/cobra v1.1.3 github.com/spf13/cobra v1.1.3
github.com/stretchr/testify v1.8.1 github.com/stretchr/testify v1.8.1
golang.org/x/net v0.7.0 golang.org/x/net v0.7.0
golang.org/x/oauth2 v0.3.0 golang.org/x/oauth2 v0.3.0
golang.org/x/sync v0.1.0
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 golang.org/x/time v0.0.0-20220210224613-90d013bbcef8
gopkg.in/ini.v1 v1.67.0 gopkg.in/ini.v1 v1.67.0
k8s.io/apimachinery v0.26.1 k8s.io/apimachinery v0.26.1
@ -57,8 +55,9 @@ require (
github.com/prometheus/client_model v0.2.0 // indirect github.com/prometheus/client_model v0.2.0 // indirect
github.com/prometheus/common v0.37.0 // indirect github.com/prometheus/common v0.37.0 // indirect
github.com/prometheus/procfs v0.8.0 // indirect github.com/prometheus/procfs v0.8.0 // indirect
github.com/quic-go/qtls-go1-19 v0.3.2 // indirect github.com/quic-go/qtls-go1-18 v0.2.0 // indirect
github.com/quic-go/qtls-go1-20 v0.2.2 // indirect github.com/quic-go/qtls-go1-19 v0.2.0 // indirect
github.com/quic-go/qtls-go1-20 v0.1.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/pflag v1.0.5 // indirect
github.com/templexxx/cpufeat v0.0.0-20180724012125-cef66df7f161 // indirect github.com/templexxx/cpufeat v0.0.0-20180724012125-cef66df7f161 // indirect
github.com/templexxx/xor v0.0.0-20191217153810-f85b25db303b // indirect github.com/templexxx/xor v0.0.0-20191217153810-f85b25db303b // indirect

18
go.sum
View File

@ -381,12 +381,14 @@ github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1
github.com/prometheus/procfs v0.8.0 h1:ODq8ZFEaYeCaZOJlZZdJA2AbQR98dSHSM1KW/You5mo= github.com/prometheus/procfs v0.8.0 h1:ODq8ZFEaYeCaZOJlZZdJA2AbQR98dSHSM1KW/You5mo=
github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0uaxHdg830/4= github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0uaxHdg830/4=
github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U= github.com/quic-go/qtls-go1-18 v0.2.0 h1:5ViXqBZ90wpUcZS0ge79rf029yx0dYB0McyPJwqqj7U=
github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI= github.com/quic-go/qtls-go1-18 v0.2.0/go.mod h1:moGulGHK7o6O8lSPSZNoOwcLvJKJ85vVNc7oJFD65bc=
github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E= github.com/quic-go/qtls-go1-19 v0.2.0 h1:Cvn2WdhyViFUHoOqK52i51k4nDX8EwIh5VJiVM4nttk=
github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM= github.com/quic-go/qtls-go1-19 v0.2.0/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI=
github.com/quic-go/quic-go v0.34.0 h1:OvOJ9LFjTySgwOTYUZmNoq0FzVicP8YujpV0kB7m2lU= github.com/quic-go/qtls-go1-20 v0.1.0 h1:d1PK3ErFy9t7zxKsG3NXBJXZjp/kMLoIb3y/kV54oAI=
github.com/quic-go/quic-go v0.34.0/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g= github.com/quic-go/qtls-go1-20 v0.1.0/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM=
github.com/quic-go/quic-go v0.32.0 h1:lY02md31s1JgPiiyfqJijpu/UX/Iun304FI3yUqX7tA=
github.com/quic-go/quic-go v0.32.0/go.mod h1:/fCsKANhQIeD5l76c2JFU+07gVE3KaA0FP+0zMWwfwo=
github.com/rodaine/table v1.0.1 h1:U/VwCnUxlVYxw8+NJiLIuCxA/xa6jL38MY3FYysVWWQ= github.com/rodaine/table v1.0.1 h1:U/VwCnUxlVYxw8+NJiLIuCxA/xa6jL38MY3FYysVWWQ=
github.com/rodaine/table v1.0.1/go.mod h1:UVEtfBsflpeEcD56nF4F5AocNFta0ZuolpSVdPtlmP4= github.com/rodaine/table v1.0.1/go.mod h1:UVEtfBsflpeEcD56nF4F5AocNFta0ZuolpSVdPtlmP4=
github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
@ -397,8 +399,6 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts=
github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM=
github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc=
github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
@ -602,8 +602,6 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=

View File

@ -17,4 +17,4 @@ if [ x${LOG_LEVEL} != x"" ]; then
logLevel=${LOG_LEVEL} logLevel=${LOG_LEVEL}
fi fi
ginkgo -nodes=8 --poll-progress-after=30s ${ROOT}/test/e2e -- -frpc-path=${ROOT}/bin/frpc -frps-path=${ROOT}/bin/frps -log-level=${logLevel} -debug=${debug} ginkgo -nodes=8 --poll-progress-after=20s ${ROOT}/test/e2e -- -frpc-path=${ROOT}/bin/frpc -frps-path=${ROOT}/bin/frps -log-level=${logLevel} -debug=${debug}

View File

@ -73,30 +73,30 @@ func (auth *TokenAuthSetterVerifier) SetNewWorkConn(newWorkConnMsg *msg.NewWorkC
return nil return nil
} }
func (auth *TokenAuthSetterVerifier) VerifyLogin(m *msg.Login) error { func (auth *TokenAuthSetterVerifier) VerifyLogin(loginMsg *msg.Login) error {
if !util.ConstantTimeEqString(util.GetAuthKey(auth.token, m.Timestamp), m.PrivilegeKey) { if util.GetAuthKey(auth.token, loginMsg.Timestamp) != loginMsg.PrivilegeKey {
return fmt.Errorf("token in login doesn't match token from configuration") return fmt.Errorf("token in login doesn't match token from configuration")
} }
return nil return nil
} }
func (auth *TokenAuthSetterVerifier) VerifyPing(m *msg.Ping) error { func (auth *TokenAuthSetterVerifier) VerifyPing(pingMsg *msg.Ping) error {
if !auth.AuthenticateHeartBeats { if !auth.AuthenticateHeartBeats {
return nil return nil
} }
if !util.ConstantTimeEqString(util.GetAuthKey(auth.token, m.Timestamp), m.PrivilegeKey) { if util.GetAuthKey(auth.token, pingMsg.Timestamp) != pingMsg.PrivilegeKey {
return fmt.Errorf("token in heartbeat doesn't match token from configuration") return fmt.Errorf("token in heartbeat doesn't match token from configuration")
} }
return nil return nil
} }
func (auth *TokenAuthSetterVerifier) VerifyNewWorkConn(m *msg.NewWorkConn) error { func (auth *TokenAuthSetterVerifier) VerifyNewWorkConn(newWorkConnMsg *msg.NewWorkConn) error {
if !auth.AuthenticateNewWorkConns { if !auth.AuthenticateNewWorkConns {
return nil return nil
} }
if !util.ConstantTimeEqString(util.GetAuthKey(auth.token, m.Timestamp), m.PrivilegeKey) { if util.GetAuthKey(auth.token, newWorkConnMsg.Timestamp) != newWorkConnMsg.PrivilegeKey {
return fmt.Errorf("token in NewWorkConn doesn't match token from configuration") return fmt.Errorf("token in NewWorkConn doesn't match token from configuration")
} }
return nil return nil

View File

@ -35,9 +35,15 @@ type ClientCommonConf struct {
// ServerAddr specifies the address of the server to connect to. By // ServerAddr specifies the address of the server to connect to. By
// default, this value is "0.0.0.0". // default, this value is "0.0.0.0".
ServerAddr string `ini:"server_addr" json:"server_addr"` ServerAddr string `ini:"server_addr" json:"server_addr"`
// Specify another address of the server to connect for nat hole. By default, it's same with
// ServerAddr.
NatHoleServerAddr string `ini:"nat_hole_server_addr" json:"nat_hole_server_addr"`
// ServerPort specifies the port to connect to the server on. By default, // ServerPort specifies the port to connect to the server on. By default,
// this value is 7000. // this value is 7000.
ServerPort int `ini:"server_port" json:"server_port"` ServerPort int `ini:"server_port" json:"server_port"`
// ServerUDPPort specifies the server port to help penetrate NAT hole. By default, this value is 0.
// This parameter is only used when executing "nathole discover" in the command line.
ServerUDPPort int `ini:"server_udp_port" json:"server_udp_port"`
// STUN server to help penetrate NAT hole. // STUN server to help penetrate NAT hole.
NatHoleSTUNServer string `ini:"nat_hole_stun_server" json:"nat_hole_stun_server"` NatHoleSTUNServer string `ini:"nat_hole_stun_server" json:"nat_hole_stun_server"`
// The maximum amount of time a dial to server will wait for a connect to complete. // The maximum amount of time a dial to server will wait for a connect to complete.

View File

@ -500,10 +500,8 @@ func Test_LoadClientBasicConf(t *testing.T) {
}, },
BandwidthLimitMode: BandwidthLimitModeClient, BandwidthLimitMode: BandwidthLimitModeClient,
}, },
RoleServerCommonConf: RoleServerCommonConf{ Role: "server",
Role: "server", Sk: "abcdefg",
Sk: "abcdefg",
},
}, },
testUser + ".p2p_tcp": &XTCPProxyConf{ testUser + ".p2p_tcp": &XTCPProxyConf{
BaseProxyConf: BaseProxyConf{ BaseProxyConf: BaseProxyConf{
@ -515,10 +513,8 @@ func Test_LoadClientBasicConf(t *testing.T) {
}, },
BandwidthLimitMode: BandwidthLimitModeClient, BandwidthLimitMode: BandwidthLimitModeClient,
}, },
RoleServerCommonConf: RoleServerCommonConf{ Role: "server",
Role: "server", Sk: "abcdefg",
Sk: "abcdefg",
},
}, },
testUser + ".tcpmuxhttpconnect": &TCPMuxProxyConf{ testUser + ".tcpmuxhttpconnect": &TCPMuxProxyConf{
BaseProxyConf: BaseProxyConf{ BaseProxyConf: BaseProxyConf{
@ -665,10 +661,6 @@ func Test_LoadClientBasicConf(t *testing.T) {
BindAddr: "127.0.0.1", BindAddr: "127.0.0.1",
BindPort: 9001, BindPort: 9001,
}, },
Protocol: "quic",
MaxRetriesAnHour: 8,
MinRetryInterval: 90,
FallbackTimeoutMs: 1000,
}, },
} }

View File

@ -51,23 +51,13 @@ func NewConfByType(proxyType string) ProxyConf {
} }
type ProxyConf interface { type ProxyConf interface {
// GetBaseConfig returns the BaseProxyConf for this config. GetBaseInfo() *BaseProxyConf
GetBaseConfig() *BaseProxyConf
// SetDefaultValues sets the default values for this config.
SetDefaultValues()
// UnmarshalFromMsg unmarshals a msg.NewProxy message into this config.
// This function will be called on the frps side.
UnmarshalFromMsg(*msg.NewProxy) UnmarshalFromMsg(*msg.NewProxy)
// UnmarshalFromIni unmarshals a ini.Section into this config. This function
// will be called on the frpc side.
UnmarshalFromIni(string, string, *ini.Section) error UnmarshalFromIni(string, string, *ini.Section) error
// MarshalToMsg marshals this config into a msg.NewProxy message. This
// function will be called on the frpc side.
MarshalToMsg(*msg.NewProxy) MarshalToMsg(*msg.NewProxy)
// ValidateForClient checks that the config is valid for the frpc side. CheckForCli() error
ValidateForClient() error CheckForSvr(ServerCommonConf) error
// ValidateForServer checks that the config is valid for the frps side. Compare(ProxyConf) bool
ValidateForServer(ServerCommonConf) error
} }
// LocalSvrConf configures what location the client will to, or what // LocalSvrConf configures what location the client will to, or what
@ -168,16 +158,6 @@ type DomainConf struct {
SubDomain string `ini:"subdomain" json:"subdomain"` SubDomain string `ini:"subdomain" json:"subdomain"`
} }
type RoleServerCommonConf struct {
Role string `ini:"role" json:"role"`
Sk string `ini:"sk" json:"sk"`
AllowUsers []string `ini:"allow_users" json:"allow_users"`
}
func (cfg *RoleServerCommonConf) setDefaultValues() {
cfg.Role = "server"
}
// HTTP // HTTP
type HTTPProxyConf struct { type HTTPProxyConf struct {
BaseProxyConf `ini:",extends"` BaseProxyConf `ini:",extends"`
@ -203,13 +183,6 @@ type TCPProxyConf struct {
RemotePort int `ini:"remote_port" json:"remote_port"` RemotePort int `ini:"remote_port" json:"remote_port"`
} }
// UDP
type UDPProxyConf struct {
BaseProxyConf `ini:",extends"`
RemotePort int `ini:"remote_port" json:"remote_port"`
}
// TCPMux // TCPMux
type TCPMuxProxyConf struct { type TCPMuxProxyConf struct {
BaseProxyConf `ini:",extends"` BaseProxyConf `ini:",extends"`
@ -223,30 +196,80 @@ type TCPMuxProxyConf struct {
// STCP // STCP
type STCPProxyConf struct { type STCPProxyConf struct {
BaseProxyConf `ini:",extends"` BaseProxyConf `ini:",extends"`
RoleServerCommonConf `ini:",extends"`
Role string `ini:"role" json:"role"`
Sk string `ini:"sk" json:"sk"`
} }
// XTCP // XTCP
type XTCPProxyConf struct { type XTCPProxyConf struct {
BaseProxyConf `ini:",extends"` BaseProxyConf `ini:",extends"`
RoleServerCommonConf `ini:",extends"`
Role string `ini:"role" json:"role"`
Sk string `ini:"sk" json:"sk"`
}
// UDP
type UDPProxyConf struct {
BaseProxyConf `ini:",extends"`
RemotePort int `ini:"remote_port" json:"remote_port"`
} }
// SUDP // SUDP
type SUDPProxyConf struct { type SUDPProxyConf struct {
BaseProxyConf `ini:",extends"` BaseProxyConf `ini:",extends"`
RoleServerCommonConf `ini:",extends"`
Role string `ini:"role" json:"role"`
Sk string `ini:"sk" json:"sk"`
} }
// Proxy Conf Loader // Proxy Conf Loader
// DefaultProxyConf creates a empty ProxyConf object by proxyType. // DefaultProxyConf creates a empty ProxyConf object by proxyType.
// If proxyType doesn't exist, return nil. // If proxyType doesn't exist, return nil.
func DefaultProxyConf(proxyType string) ProxyConf { func DefaultProxyConf(proxyType string) ProxyConf {
conf := NewConfByType(proxyType) var conf ProxyConf
if conf != nil { switch proxyType {
conf.SetDefaultValues() case consts.TCPProxy:
conf = &TCPProxyConf{
BaseProxyConf: defaultBaseProxyConf(proxyType),
}
case consts.TCPMuxProxy:
conf = &TCPMuxProxyConf{
BaseProxyConf: defaultBaseProxyConf(proxyType),
}
case consts.UDPProxy:
conf = &UDPProxyConf{
BaseProxyConf: defaultBaseProxyConf(proxyType),
}
case consts.HTTPProxy:
conf = &HTTPProxyConf{
BaseProxyConf: defaultBaseProxyConf(proxyType),
}
case consts.HTTPSProxy:
conf = &HTTPSProxyConf{
BaseProxyConf: defaultBaseProxyConf(proxyType),
}
case consts.STCPProxy:
conf = &STCPProxyConf{
BaseProxyConf: defaultBaseProxyConf(proxyType),
Role: "server",
}
case consts.XTCPProxy:
conf = &XTCPProxyConf{
BaseProxyConf: defaultBaseProxyConf(proxyType),
Role: "server",
}
case consts.SUDPProxy:
conf = &SUDPProxyConf{
BaseProxyConf: defaultBaseProxyConf(proxyType),
Role: "server",
}
default:
return nil
} }
return conf return conf
} }
@ -267,9 +290,10 @@ func NewProxyConfFromIni(prefix, name string, section *ini.Section) (ProxyConf,
return nil, err return nil, err
} }
if err := conf.ValidateForClient(); err != nil { if err := conf.CheckForCli(); err != nil {
return nil, err return nil, err
} }
return conf, nil return conf, nil
} }
@ -286,7 +310,7 @@ func NewProxyConfFromMsg(pMsg *msg.NewProxy, serverCfg ServerCommonConf) (ProxyC
conf.UnmarshalFromMsg(pMsg) conf.UnmarshalFromMsg(pMsg)
err := conf.ValidateForServer(serverCfg) err := conf.CheckForSvr(serverCfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -295,15 +319,42 @@ func NewProxyConfFromMsg(pMsg *msg.NewProxy, serverCfg ServerCommonConf) (ProxyC
} }
// Base // Base
func (cfg *BaseProxyConf) GetBaseConfig() *BaseProxyConf { func defaultBaseProxyConf(proxyType string) BaseProxyConf {
return BaseProxyConf{
ProxyType: proxyType,
LocalSvrConf: LocalSvrConf{
LocalIP: "127.0.0.1",
},
BandwidthLimitMode: BandwidthLimitModeClient,
}
}
func (cfg *BaseProxyConf) GetBaseInfo() *BaseProxyConf {
return cfg return cfg
} }
func (cfg *BaseProxyConf) SetDefaultValues() { func (cfg *BaseProxyConf) compare(cmp *BaseProxyConf) bool {
cfg.LocalSvrConf = LocalSvrConf{ if cfg.ProxyName != cmp.ProxyName ||
LocalIP: "127.0.0.1", cfg.ProxyType != cmp.ProxyType ||
cfg.UseEncryption != cmp.UseEncryption ||
cfg.UseCompression != cmp.UseCompression ||
cfg.Group != cmp.Group ||
cfg.GroupKey != cmp.GroupKey ||
cfg.ProxyProtocolVersion != cmp.ProxyProtocolVersion ||
!cfg.BandwidthLimit.Equal(&cmp.BandwidthLimit) ||
cfg.BandwidthLimitMode != cmp.BandwidthLimitMode ||
!reflect.DeepEqual(cfg.Metas, cmp.Metas) {
return false
} }
cfg.BandwidthLimitMode = BandwidthLimitModeClient
if !reflect.DeepEqual(cfg.LocalSvrConf, cmp.LocalSvrConf) {
return false
}
if !reflect.DeepEqual(cfg.HealthCheckConf, cmp.HealthCheckConf) {
return false
}
return true
} }
// BaseProxyConf apply custom logic changes. // BaseProxyConf apply custom logic changes.
@ -372,7 +423,7 @@ func (cfg *BaseProxyConf) unmarshalFromMsg(pMsg *msg.NewProxy) {
cfg.Metas = pMsg.Metas cfg.Metas = pMsg.Metas
} }
func (cfg *BaseProxyConf) validateForClient() (err error) { func (cfg *BaseProxyConf) checkForCli() (err error) {
if cfg.ProxyProtocolVersion != "" { if cfg.ProxyProtocolVersion != "" {
if cfg.ProxyProtocolVersion != "v1" && cfg.ProxyProtocolVersion != "v2" { if cfg.ProxyProtocolVersion != "v1" && cfg.ProxyProtocolVersion != "v2" {
return fmt.Errorf("no support proxy protocol version: %s", cfg.ProxyProtocolVersion) return fmt.Errorf("no support proxy protocol version: %s", cfg.ProxyProtocolVersion)
@ -383,16 +434,16 @@ func (cfg *BaseProxyConf) validateForClient() (err error) {
return fmt.Errorf("bandwidth_limit_mode should be client or server") return fmt.Errorf("bandwidth_limit_mode should be client or server")
} }
if err = cfg.LocalSvrConf.validateForClient(); err != nil { if err = cfg.LocalSvrConf.checkForCli(); err != nil {
return return
} }
if err = cfg.HealthCheckConf.validateForClient(); err != nil { if err = cfg.HealthCheckConf.checkForCli(); err != nil {
return return
} }
return nil return nil
} }
func (cfg *BaseProxyConf) validateForServer() (err error) { func (cfg *BaseProxyConf) checkForSvr() (err error) {
if cfg.BandwidthLimitMode != "client" && cfg.BandwidthLimitMode != "server" { if cfg.BandwidthLimitMode != "client" && cfg.BandwidthLimitMode != "server" {
return fmt.Errorf("bandwidth_limit_mode should be client or server") return fmt.Errorf("bandwidth_limit_mode should be client or server")
} }
@ -408,14 +459,14 @@ func (cfg *DomainConf) check() (err error) {
return return
} }
func (cfg *DomainConf) validateForClient() (err error) { func (cfg *DomainConf) checkForCli() (err error) {
if err = cfg.check(); err != nil { if err = cfg.check(); err != nil {
return return
} }
return return
} }
func (cfg *DomainConf) validateForServer(serverCfg ServerCommonConf) (err error) { func (cfg *DomainConf) checkForSvr(serverCfg ServerCommonConf) (err error) {
if err = cfg.check(); err != nil { if err = cfg.check(); err != nil {
return return
} }
@ -440,7 +491,7 @@ func (cfg *DomainConf) validateForServer(serverCfg ServerCommonConf) (err error)
} }
// LocalSvrConf // LocalSvrConf
func (cfg *LocalSvrConf) validateForClient() (err error) { func (cfg *LocalSvrConf) checkForCli() (err error) {
if cfg.Plugin == "" { if cfg.Plugin == "" {
if cfg.LocalIP == "" { if cfg.LocalIP == "" {
err = fmt.Errorf("local ip or plugin is required") err = fmt.Errorf("local ip or plugin is required")
@ -455,7 +506,7 @@ func (cfg *LocalSvrConf) validateForClient() (err error) {
} }
// HealthCheckConf // HealthCheckConf
func (cfg *HealthCheckConf) validateForClient() error { func (cfg *HealthCheckConf) checkForCli() error {
if cfg.HealthCheckType != "" && cfg.HealthCheckType != "tcp" && cfg.HealthCheckType != "http" { if cfg.HealthCheckType != "" && cfg.HealthCheckType != "tcp" && cfg.HealthCheckType != "http" {
return fmt.Errorf("unsupport health check type") return fmt.Errorf("unsupport health check type")
} }
@ -473,7 +524,7 @@ func preUnmarshalFromIni(cfg ProxyConf, prefix string, name string, section *ini
return err return err
} }
err = cfg.GetBaseConfig().decorate(prefix, name, section) err = cfg.GetBaseInfo().decorate(prefix, name, section)
if err != nil { if err != nil {
return err return err
} }
@ -482,6 +533,24 @@ func preUnmarshalFromIni(cfg ProxyConf, prefix string, name string, section *ini
} }
// TCP // TCP
func (cfg *TCPProxyConf) Compare(cmp ProxyConf) bool {
cmpConf, ok := cmp.(*TCPProxyConf)
if !ok {
return false
}
if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) {
return false
}
// Add custom logic equal if exists.
if cfg.RemotePort != cmpConf.RemotePort {
return false
}
return true
}
func (cfg *TCPProxyConf) UnmarshalFromMsg(pMsg *msg.NewProxy) { func (cfg *TCPProxyConf) UnmarshalFromMsg(pMsg *msg.NewProxy) {
cfg.BaseProxyConf.unmarshalFromMsg(pMsg) cfg.BaseProxyConf.unmarshalFromMsg(pMsg)
@ -507,8 +576,8 @@ func (cfg *TCPProxyConf) MarshalToMsg(pMsg *msg.NewProxy) {
pMsg.RemotePort = cfg.RemotePort pMsg.RemotePort = cfg.RemotePort
} }
func (cfg *TCPProxyConf) ValidateForClient() (err error) { func (cfg *TCPProxyConf) CheckForCli() (err error) {
if err = cfg.BaseProxyConf.validateForClient(); err != nil { if err = cfg.BaseProxyConf.checkForCli(); err != nil {
return return
} }
@ -517,14 +586,39 @@ func (cfg *TCPProxyConf) ValidateForClient() (err error) {
return return
} }
func (cfg *TCPProxyConf) ValidateForServer(serverCfg ServerCommonConf) error { func (cfg *TCPProxyConf) CheckForSvr(serverCfg ServerCommonConf) error {
if err := cfg.BaseProxyConf.validateForServer(); err != nil { if err := cfg.BaseProxyConf.checkForSvr(); err != nil {
return err return err
} }
return nil return nil
} }
// TCPMux // TCPMux
func (cfg *TCPMuxProxyConf) Compare(cmp ProxyConf) bool {
cmpConf, ok := cmp.(*TCPMuxProxyConf)
if !ok {
return false
}
if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) {
return false
}
// Add custom logic equal if exists.
if !reflect.DeepEqual(cfg.DomainConf, cmpConf.DomainConf) {
return false
}
if cfg.Multiplexer != cmpConf.Multiplexer ||
cfg.HTTPUser != cmpConf.HTTPUser ||
cfg.HTTPPwd != cmpConf.HTTPPwd ||
cfg.RouteByHTTPUser != cmpConf.RouteByHTTPUser {
return false
}
return true
}
func (cfg *TCPMuxProxyConf) UnmarshalFromIni(prefix string, name string, section *ini.Section) error { func (cfg *TCPMuxProxyConf) UnmarshalFromIni(prefix string, name string, section *ini.Section) error {
err := preUnmarshalFromIni(cfg, prefix, name, section) err := preUnmarshalFromIni(cfg, prefix, name, section)
if err != nil { if err != nil {
@ -560,13 +654,13 @@ func (cfg *TCPMuxProxyConf) MarshalToMsg(pMsg *msg.NewProxy) {
pMsg.RouteByHTTPUser = cfg.RouteByHTTPUser pMsg.RouteByHTTPUser = cfg.RouteByHTTPUser
} }
func (cfg *TCPMuxProxyConf) ValidateForClient() (err error) { func (cfg *TCPMuxProxyConf) CheckForCli() (err error) {
if err = cfg.BaseProxyConf.validateForClient(); err != nil { if err = cfg.BaseProxyConf.checkForCli(); err != nil {
return return
} }
// Add custom logic check if exists // Add custom logic check if exists
if err = cfg.DomainConf.validateForClient(); err != nil { if err = cfg.DomainConf.checkForCli(); err != nil {
return return
} }
@ -577,8 +671,8 @@ func (cfg *TCPMuxProxyConf) ValidateForClient() (err error) {
return return
} }
func (cfg *TCPMuxProxyConf) ValidateForServer(serverCfg ServerCommonConf) (err error) { func (cfg *TCPMuxProxyConf) CheckForSvr(serverCfg ServerCommonConf) (err error) {
if err := cfg.BaseProxyConf.validateForServer(); err != nil { if err := cfg.BaseProxyConf.checkForSvr(); err != nil {
return err return err
} }
@ -590,7 +684,7 @@ func (cfg *TCPMuxProxyConf) ValidateForServer(serverCfg ServerCommonConf) (err e
return fmt.Errorf("proxy [%s] type [tcpmux] with multiplexer [httpconnect] requires tcpmux_httpconnect_port configuration", cfg.ProxyName) return fmt.Errorf("proxy [%s] type [tcpmux] with multiplexer [httpconnect] requires tcpmux_httpconnect_port configuration", cfg.ProxyName)
} }
if err = cfg.DomainConf.validateForServer(serverCfg); err != nil { if err = cfg.DomainConf.checkForSvr(serverCfg); err != nil {
err = fmt.Errorf("proxy [%s] domain conf check error: %v", cfg.ProxyName, err) err = fmt.Errorf("proxy [%s] domain conf check error: %v", cfg.ProxyName, err)
return return
} }
@ -599,6 +693,24 @@ func (cfg *TCPMuxProxyConf) ValidateForServer(serverCfg ServerCommonConf) (err e
} }
// UDP // UDP
func (cfg *UDPProxyConf) Compare(cmp ProxyConf) bool {
cmpConf, ok := cmp.(*UDPProxyConf)
if !ok {
return false
}
if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) {
return false
}
// Add custom logic equal if exists.
if cfg.RemotePort != cmpConf.RemotePort {
return false
}
return true
}
func (cfg *UDPProxyConf) UnmarshalFromIni(prefix string, name string, section *ini.Section) error { func (cfg *UDPProxyConf) UnmarshalFromIni(prefix string, name string, section *ini.Section) error {
err := preUnmarshalFromIni(cfg, prefix, name, section) err := preUnmarshalFromIni(cfg, prefix, name, section)
if err != nil { if err != nil {
@ -624,8 +736,8 @@ func (cfg *UDPProxyConf) MarshalToMsg(pMsg *msg.NewProxy) {
pMsg.RemotePort = cfg.RemotePort pMsg.RemotePort = cfg.RemotePort
} }
func (cfg *UDPProxyConf) ValidateForClient() (err error) { func (cfg *UDPProxyConf) CheckForCli() (err error) {
if err = cfg.BaseProxyConf.validateForClient(); err != nil { if err = cfg.BaseProxyConf.checkForCli(); err != nil {
return return
} }
@ -634,14 +746,41 @@ func (cfg *UDPProxyConf) ValidateForClient() (err error) {
return return
} }
func (cfg *UDPProxyConf) ValidateForServer(serverCfg ServerCommonConf) error { func (cfg *UDPProxyConf) CheckForSvr(serverCfg ServerCommonConf) error {
if err := cfg.BaseProxyConf.validateForServer(); err != nil { if err := cfg.BaseProxyConf.checkForSvr(); err != nil {
return err return err
} }
return nil return nil
} }
// HTTP // HTTP
func (cfg *HTTPProxyConf) Compare(cmp ProxyConf) bool {
cmpConf, ok := cmp.(*HTTPProxyConf)
if !ok {
return false
}
if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) {
return false
}
// Add custom logic equal if exists.
if !reflect.DeepEqual(cfg.DomainConf, cmpConf.DomainConf) {
return false
}
if !reflect.DeepEqual(cfg.Locations, cmpConf.Locations) ||
cfg.HTTPUser != cmpConf.HTTPUser ||
cfg.HTTPPwd != cmpConf.HTTPPwd ||
cfg.HostHeaderRewrite != cmpConf.HostHeaderRewrite ||
cfg.RouteByHTTPUser != cmpConf.RouteByHTTPUser ||
!reflect.DeepEqual(cfg.Headers, cmpConf.Headers) {
return false
}
return true
}
func (cfg *HTTPProxyConf) UnmarshalFromIni(prefix string, name string, section *ini.Section) error { func (cfg *HTTPProxyConf) UnmarshalFromIni(prefix string, name string, section *ini.Section) error {
err := preUnmarshalFromIni(cfg, prefix, name, section) err := preUnmarshalFromIni(cfg, prefix, name, section)
if err != nil { if err != nil {
@ -650,6 +789,7 @@ func (cfg *HTTPProxyConf) UnmarshalFromIni(prefix string, name string, section *
// Add custom logic unmarshal if exists // Add custom logic unmarshal if exists
cfg.Headers = GetMapWithoutPrefix(section.KeysHash(), "header_") cfg.Headers = GetMapWithoutPrefix(section.KeysHash(), "header_")
return nil return nil
} }
@ -681,21 +821,21 @@ func (cfg *HTTPProxyConf) MarshalToMsg(pMsg *msg.NewProxy) {
pMsg.RouteByHTTPUser = cfg.RouteByHTTPUser pMsg.RouteByHTTPUser = cfg.RouteByHTTPUser
} }
func (cfg *HTTPProxyConf) ValidateForClient() (err error) { func (cfg *HTTPProxyConf) CheckForCli() (err error) {
if err = cfg.BaseProxyConf.validateForClient(); err != nil { if err = cfg.BaseProxyConf.checkForCli(); err != nil {
return return
} }
// Add custom logic check if exists // Add custom logic check if exists
if err = cfg.DomainConf.validateForClient(); err != nil { if err = cfg.DomainConf.checkForCli(); err != nil {
return return
} }
return return
} }
func (cfg *HTTPProxyConf) ValidateForServer(serverCfg ServerCommonConf) (err error) { func (cfg *HTTPProxyConf) CheckForSvr(serverCfg ServerCommonConf) (err error) {
if err := cfg.BaseProxyConf.validateForServer(); err != nil { if err := cfg.BaseProxyConf.checkForSvr(); err != nil {
return err return err
} }
@ -703,7 +843,7 @@ func (cfg *HTTPProxyConf) ValidateForServer(serverCfg ServerCommonConf) (err err
return fmt.Errorf("type [http] not support when vhost_http_port is not set") return fmt.Errorf("type [http] not support when vhost_http_port is not set")
} }
if err = cfg.DomainConf.validateForServer(serverCfg); err != nil { if err = cfg.DomainConf.checkForSvr(serverCfg); err != nil {
err = fmt.Errorf("proxy [%s] domain conf check error: %v", cfg.ProxyName, err) err = fmt.Errorf("proxy [%s] domain conf check error: %v", cfg.ProxyName, err)
return return
} }
@ -712,6 +852,24 @@ func (cfg *HTTPProxyConf) ValidateForServer(serverCfg ServerCommonConf) (err err
} }
// HTTPS // HTTPS
func (cfg *HTTPSProxyConf) Compare(cmp ProxyConf) bool {
cmpConf, ok := cmp.(*HTTPSProxyConf)
if !ok {
return false
}
if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) {
return false
}
// Add custom logic equal if exists.
if !reflect.DeepEqual(cfg.DomainConf, cmpConf.DomainConf) {
return false
}
return true
}
func (cfg *HTTPSProxyConf) UnmarshalFromIni(prefix string, name string, section *ini.Section) error { func (cfg *HTTPSProxyConf) UnmarshalFromIni(prefix string, name string, section *ini.Section) error {
err := preUnmarshalFromIni(cfg, prefix, name, section) err := preUnmarshalFromIni(cfg, prefix, name, section)
if err != nil { if err != nil {
@ -719,6 +877,7 @@ func (cfg *HTTPSProxyConf) UnmarshalFromIni(prefix string, name string, section
} }
// Add custom logic unmarshal if exists // Add custom logic unmarshal if exists
return nil return nil
} }
@ -738,20 +897,21 @@ func (cfg *HTTPSProxyConf) MarshalToMsg(pMsg *msg.NewProxy) {
pMsg.SubDomain = cfg.SubDomain pMsg.SubDomain = cfg.SubDomain
} }
func (cfg *HTTPSProxyConf) ValidateForClient() (err error) { func (cfg *HTTPSProxyConf) CheckForCli() (err error) {
if err = cfg.BaseProxyConf.validateForClient(); err != nil { if err = cfg.BaseProxyConf.checkForCli(); err != nil {
return return
} }
// Add custom logic check if exists // Add custom logic check if exists
if err = cfg.DomainConf.validateForClient(); err != nil { if err = cfg.DomainConf.checkForCli(); err != nil {
return return
} }
return return
} }
func (cfg *HTTPSProxyConf) ValidateForServer(serverCfg ServerCommonConf) (err error) { func (cfg *HTTPSProxyConf) CheckForSvr(serverCfg ServerCommonConf) (err error) {
if err := cfg.BaseProxyConf.validateForServer(); err != nil { if err := cfg.BaseProxyConf.checkForSvr(); err != nil {
return err return err
} }
@ -759,7 +919,7 @@ func (cfg *HTTPSProxyConf) ValidateForServer(serverCfg ServerCommonConf) (err er
return fmt.Errorf("type [https] not support when vhost_https_port is not set") return fmt.Errorf("type [https] not support when vhost_https_port is not set")
} }
if err = cfg.DomainConf.validateForServer(serverCfg); err != nil { if err = cfg.DomainConf.checkForSvr(serverCfg); err != nil {
err = fmt.Errorf("proxy [%s] domain conf check error: %v", cfg.ProxyName, err) err = fmt.Errorf("proxy [%s] domain conf check error: %v", cfg.ProxyName, err)
return return
} }
@ -768,9 +928,23 @@ func (cfg *HTTPSProxyConf) ValidateForServer(serverCfg ServerCommonConf) (err er
} }
// SUDP // SUDP
func (cfg *SUDPProxyConf) SetDefaultValues() { func (cfg *SUDPProxyConf) Compare(cmp ProxyConf) bool {
cfg.BaseProxyConf.SetDefaultValues() cmpConf, ok := cmp.(*SUDPProxyConf)
cfg.RoleServerCommonConf.setDefaultValues() if !ok {
return false
}
if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) {
return false
}
// Add custom logic equal if exists.
if cfg.Role != cmpConf.Role ||
cfg.Sk != cmpConf.Sk {
return false
}
return true
} }
func (cfg *SUDPProxyConf) UnmarshalFromIni(prefix string, name string, section *ini.Section) error { func (cfg *SUDPProxyConf) UnmarshalFromIni(prefix string, name string, section *ini.Section) error {
@ -780,6 +954,7 @@ func (cfg *SUDPProxyConf) UnmarshalFromIni(prefix string, name string, section *
} }
// Add custom logic unmarshal if exists // Add custom logic unmarshal if exists
return nil return nil
} }
@ -798,8 +973,8 @@ func (cfg *SUDPProxyConf) MarshalToMsg(pMsg *msg.NewProxy) {
pMsg.Sk = cfg.Sk pMsg.Sk = cfg.Sk
} }
func (cfg *SUDPProxyConf) ValidateForClient() (err error) { func (cfg *SUDPProxyConf) CheckForCli() (err error) {
if err := cfg.BaseProxyConf.validateForClient(); err != nil { if err := cfg.BaseProxyConf.checkForCli(); err != nil {
return err return err
} }
@ -811,17 +986,31 @@ func (cfg *SUDPProxyConf) ValidateForClient() (err error) {
return nil return nil
} }
func (cfg *SUDPProxyConf) ValidateForServer(serverCfg ServerCommonConf) error { func (cfg *SUDPProxyConf) CheckForSvr(serverCfg ServerCommonConf) error {
if err := cfg.BaseProxyConf.validateForServer(); err != nil { if err := cfg.BaseProxyConf.checkForSvr(); err != nil {
return err return err
} }
return nil return nil
} }
// STCP // STCP
func (cfg *STCPProxyConf) SetDefaultValues() { func (cfg *STCPProxyConf) Compare(cmp ProxyConf) bool {
cfg.BaseProxyConf.SetDefaultValues() cmpConf, ok := cmp.(*STCPProxyConf)
cfg.RoleServerCommonConf.setDefaultValues() if !ok {
return false
}
if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) {
return false
}
// Add custom logic equal if exists.
if cfg.Role != cmpConf.Role ||
cfg.Sk != cmpConf.Sk {
return false
}
return true
} }
func (cfg *STCPProxyConf) UnmarshalFromIni(prefix string, name string, section *ini.Section) error { func (cfg *STCPProxyConf) UnmarshalFromIni(prefix string, name string, section *ini.Section) error {
@ -834,6 +1023,7 @@ func (cfg *STCPProxyConf) UnmarshalFromIni(prefix string, name string, section *
if cfg.Role == "" { if cfg.Role == "" {
cfg.Role = "server" cfg.Role = "server"
} }
return nil return nil
} }
@ -852,8 +1042,8 @@ func (cfg *STCPProxyConf) MarshalToMsg(pMsg *msg.NewProxy) {
pMsg.Sk = cfg.Sk pMsg.Sk = cfg.Sk
} }
func (cfg *STCPProxyConf) ValidateForClient() (err error) { func (cfg *STCPProxyConf) CheckForCli() (err error) {
if err = cfg.BaseProxyConf.validateForClient(); err != nil { if err = cfg.BaseProxyConf.checkForCli(); err != nil {
return return
} }
@ -865,17 +1055,31 @@ func (cfg *STCPProxyConf) ValidateForClient() (err error) {
return return
} }
func (cfg *STCPProxyConf) ValidateForServer(serverCfg ServerCommonConf) error { func (cfg *STCPProxyConf) CheckForSvr(serverCfg ServerCommonConf) error {
if err := cfg.BaseProxyConf.validateForServer(); err != nil { if err := cfg.BaseProxyConf.checkForSvr(); err != nil {
return err return err
} }
return nil return nil
} }
// XTCP // XTCP
func (cfg *XTCPProxyConf) SetDefaultValues() { func (cfg *XTCPProxyConf) Compare(cmp ProxyConf) bool {
cfg.BaseProxyConf.SetDefaultValues() cmpConf, ok := cmp.(*XTCPProxyConf)
cfg.RoleServerCommonConf.setDefaultValues() if !ok {
return false
}
if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) {
return false
}
// Add custom logic equal if exists.
if cfg.Role != cmpConf.Role ||
cfg.Sk != cmpConf.Sk {
return false
}
return true
} }
func (cfg *XTCPProxyConf) UnmarshalFromIni(prefix string, name string, section *ini.Section) error { func (cfg *XTCPProxyConf) UnmarshalFromIni(prefix string, name string, section *ini.Section) error {
@ -888,6 +1092,7 @@ func (cfg *XTCPProxyConf) UnmarshalFromIni(prefix string, name string, section *
if cfg.Role == "" { if cfg.Role == "" {
cfg.Role = "server" cfg.Role = "server"
} }
return nil return nil
} }
@ -906,8 +1111,8 @@ func (cfg *XTCPProxyConf) MarshalToMsg(pMsg *msg.NewProxy) {
pMsg.Sk = cfg.Sk pMsg.Sk = cfg.Sk
} }
func (cfg *XTCPProxyConf) ValidateForClient() (err error) { func (cfg *XTCPProxyConf) CheckForCli() (err error) {
if err = cfg.BaseProxyConf.validateForClient(); err != nil { if err = cfg.BaseProxyConf.checkForCli(); err != nil {
return return
} }
@ -915,11 +1120,12 @@ func (cfg *XTCPProxyConf) ValidateForClient() (err error) {
if cfg.Role != "server" { if cfg.Role != "server" {
return fmt.Errorf("role should be 'server'") return fmt.Errorf("role should be 'server'")
} }
return return
} }
func (cfg *XTCPProxyConf) ValidateForServer(serverCfg ServerCommonConf) error { func (cfg *XTCPProxyConf) CheckForSvr(serverCfg ServerCommonConf) error {
if err := cfg.BaseProxyConf.validateForServer(); err != nil { if err := cfg.BaseProxyConf.checkForSvr(); err != nil {
return err return err
} }
return nil return nil

View File

@ -254,10 +254,8 @@ func Test_Proxy_UnmarshalFromIni(t *testing.T) {
}, },
BandwidthLimitMode: BandwidthLimitModeClient, BandwidthLimitMode: BandwidthLimitModeClient,
}, },
RoleServerCommonConf: RoleServerCommonConf{ Role: "server",
Role: "server", Sk: "abcdefg",
Sk: "abcdefg",
},
}, },
}, },
{ {
@ -281,10 +279,8 @@ func Test_Proxy_UnmarshalFromIni(t *testing.T) {
}, },
BandwidthLimitMode: BandwidthLimitModeClient, BandwidthLimitMode: BandwidthLimitModeClient,
}, },
RoleServerCommonConf: RoleServerCommonConf{ Role: "server",
Role: "server", Sk: "abcdefg",
Sk: "abcdefg",
},
}, },
}, },
{ {

View File

@ -38,6 +38,10 @@ type ServerCommonConf struct {
// BindPort specifies the port that the server listens on. By default, this // BindPort specifies the port that the server listens on. By default, this
// value is 7000. // value is 7000.
BindPort int `ini:"bind_port" json:"bind_port" validate:"gte=0,lte=65535"` BindPort int `ini:"bind_port" json:"bind_port" validate:"gte=0,lte=65535"`
// BindUDPPort specifies the UDP port that the server listens on. If this
// value is 0, the server will not listen for UDP connections. By default,
// this value is 0
BindUDPPort int `ini:"bind_udp_port" json:"bind_udp_port" validate:"gte=0,lte=65535"`
// KCPBindPort specifies the KCP port that the server listens on. If this // KCPBindPort specifies the KCP port that the server listens on. If this
// value is 0, the server will not listen for KCP connections. By default, // value is 0, the server will not listen for KCP connections. By default,
// this value is 0. // this value is 0.
@ -192,38 +196,35 @@ type ServerCommonConf struct {
// Enable golang pprof handlers in dashboard listener. // Enable golang pprof handlers in dashboard listener.
// Dashboard port must be set first. // Dashboard port must be set first.
PprofEnable bool `ini:"pprof_enable" json:"pprof_enable"` PprofEnable bool `ini:"pprof_enable" json:"pprof_enable"`
// NatHoleAnalysisDataReserveHours specifies the hours to reserve nat hole analysis data.
NatHoleAnalysisDataReserveHours int64 `ini:"nat_hole_analysis_data_reserve_hours" json:"nat_hole_analysis_data_reserve_hours"`
} }
// GetDefaultServerConf returns a server configuration with reasonable // GetDefaultServerConf returns a server configuration with reasonable
// defaults. // defaults.
func GetDefaultServerConf() ServerCommonConf { func GetDefaultServerConf() ServerCommonConf {
return ServerCommonConf{ return ServerCommonConf{
ServerConfig: auth.GetDefaultServerConf(), ServerConfig: auth.GetDefaultServerConf(),
BindAddr: "0.0.0.0", BindAddr: "0.0.0.0",
BindPort: 7000, BindPort: 7000,
QUICKeepalivePeriod: 10, QUICKeepalivePeriod: 10,
QUICMaxIdleTimeout: 30, QUICMaxIdleTimeout: 30,
QUICMaxIncomingStreams: 100000, QUICMaxIncomingStreams: 100000,
VhostHTTPTimeout: 60, VhostHTTPTimeout: 60,
DashboardAddr: "0.0.0.0", DashboardAddr: "0.0.0.0",
LogFile: "console", LogFile: "console",
LogWay: "console", LogWay: "console",
LogLevel: "info", LogLevel: "info",
LogMaxDays: 3, LogMaxDays: 3,
DetailedErrorsToClient: true, DetailedErrorsToClient: true,
TCPMux: true, TCPMux: true,
TCPMuxKeepaliveInterval: 60, TCPMuxKeepaliveInterval: 60,
TCPKeepAlive: 7200, TCPKeepAlive: 7200,
AllowPorts: make(map[int]struct{}), AllowPorts: make(map[int]struct{}),
MaxPoolCount: 5, MaxPoolCount: 5,
MaxPortsPerClient: 0, MaxPortsPerClient: 0,
HeartbeatTimeout: 90, HeartbeatTimeout: 90,
UserConnTimeout: 10, UserConnTimeout: 10,
HTTPPlugins: make(map[string]plugin.HTTPPluginOptions), HTTPPlugins: make(map[string]plugin.HTTPPluginOptions),
UDPPacketSize: 1500, UDPPacketSize: 1500,
NatHoleAnalysisDataReserveHours: 7 * 24,
} }
} }

View File

@ -36,6 +36,7 @@ func Test_LoadServerCommonConf(t *testing.T) {
[common] [common]
bind_addr = 0.0.0.9 bind_addr = 0.0.0.9
bind_port = 7009 bind_port = 7009
bind_udp_port = 7008
kcp_bind_port = 7007 kcp_bind_port = 7007
proxy_bind_addr = 127.0.0.9 proxy_bind_addr = 127.0.0.9
vhost_http_port = 89 vhost_http_port = 89
@ -103,6 +104,7 @@ func Test_LoadServerCommonConf(t *testing.T) {
}, },
BindAddr: "0.0.0.9", BindAddr: "0.0.0.9",
BindPort: 7009, BindPort: 7009,
BindUDPPort: 7008,
KCPBindPort: 7007, KCPBindPort: 7007,
QUICKeepalivePeriod: 10, QUICKeepalivePeriod: 10,
QUICMaxIdleTimeout: 30, QUICMaxIdleTimeout: 30,
@ -132,19 +134,18 @@ func Test_LoadServerCommonConf(t *testing.T) {
12: {}, 12: {},
99: {}, 99: {},
}, },
AllowPortsStr: "10-12,99", AllowPortsStr: "10-12,99",
MaxPoolCount: 59, MaxPoolCount: 59,
MaxPortsPerClient: 9, MaxPortsPerClient: 9,
TLSOnly: true, TLSOnly: true,
TLSCertFile: "server.crt", TLSCertFile: "server.crt",
TLSKeyFile: "server.key", TLSKeyFile: "server.key",
TLSTrustedCaFile: "ca.crt", TLSTrustedCaFile: "ca.crt",
SubDomainHost: "frps.com", SubDomainHost: "frps.com",
TCPMux: true, TCPMux: true,
TCPMuxKeepaliveInterval: 60, TCPMuxKeepaliveInterval: 60,
TCPKeepAlive: 7200, TCPKeepAlive: 7200,
UDPPacketSize: 1509, UDPPacketSize: 1509,
NatHoleAnalysisDataReserveHours: 7 * 24,
HTTPPlugins: map[string]plugin.HTTPPluginOptions{ HTTPPlugins: map[string]plugin.HTTPPluginOptions{
"user-manager": { "user-manager": {
@ -169,6 +170,7 @@ func Test_LoadServerCommonConf(t *testing.T) {
[common] [common]
bind_addr = 0.0.0.9 bind_addr = 0.0.0.9
bind_port = 7009 bind_port = 7009
bind_udp_port = 7008
`), `),
expected: ServerCommonConf{ expected: ServerCommonConf{
ServerConfig: auth.ServerConfig{ ServerConfig: auth.ServerConfig{
@ -178,32 +180,32 @@ func Test_LoadServerCommonConf(t *testing.T) {
AuthenticateNewWorkConns: false, AuthenticateNewWorkConns: false,
}, },
}, },
BindAddr: "0.0.0.9", BindAddr: "0.0.0.9",
BindPort: 7009, BindPort: 7009,
QUICKeepalivePeriod: 10, BindUDPPort: 7008,
QUICMaxIdleTimeout: 30, QUICKeepalivePeriod: 10,
QUICMaxIncomingStreams: 100000, QUICMaxIdleTimeout: 30,
ProxyBindAddr: "0.0.0.9", QUICMaxIncomingStreams: 100000,
VhostHTTPTimeout: 60, ProxyBindAddr: "0.0.0.9",
DashboardAddr: "0.0.0.0", VhostHTTPTimeout: 60,
DashboardUser: "", DashboardAddr: "0.0.0.0",
DashboardPwd: "", DashboardUser: "",
EnablePrometheus: false, DashboardPwd: "",
LogFile: "console", EnablePrometheus: false,
LogWay: "console", LogFile: "console",
LogLevel: "info", LogWay: "console",
LogMaxDays: 3, LogLevel: "info",
DetailedErrorsToClient: true, LogMaxDays: 3,
TCPMux: true, DetailedErrorsToClient: true,
TCPMuxKeepaliveInterval: 60, TCPMux: true,
TCPKeepAlive: 7200, TCPMuxKeepaliveInterval: 60,
AllowPorts: make(map[int]struct{}), TCPKeepAlive: 7200,
MaxPoolCount: 5, AllowPorts: make(map[int]struct{}),
HeartbeatTimeout: 90, MaxPoolCount: 5,
UserConnTimeout: 10, HeartbeatTimeout: 90,
HTTPPlugins: make(map[string]plugin.HTTPPluginOptions), UserConnTimeout: 10,
UDPPacketSize: 1500, HTTPPlugins: make(map[string]plugin.HTTPPluginOptions),
NatHoleAnalysisDataReserveHours: 7 * 24, UDPPacketSize: 1500,
}, },
}, },
} }

View File

@ -18,7 +18,6 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"github.com/samber/lo"
"gopkg.in/ini.v1" "gopkg.in/ini.v1"
"github.com/fatedier/frp/pkg/consts" "github.com/fatedier/frp/pkg/consts"
@ -34,12 +33,10 @@ var (
) )
type VisitorConf interface { type VisitorConf interface {
// GetBaseConfig returns the base config of visitor. GetBaseInfo() *BaseVisitorConf
GetBaseConfig() *BaseVisitorConf Compare(cmp VisitorConf) bool
// UnmarshalFromIni unmarshals config from ini.
UnmarshalFromIni(prefix string, name string, section *ini.Section) error UnmarshalFromIni(prefix string, name string, section *ini.Section) error
// Validate validates config. Check() error
Validate() error
} }
type BaseVisitorConf struct { type BaseVisitorConf struct {
@ -49,14 +46,9 @@ type BaseVisitorConf struct {
UseCompression bool `ini:"use_compression" json:"use_compression"` UseCompression bool `ini:"use_compression" json:"use_compression"`
Role string `ini:"role" json:"role"` Role string `ini:"role" json:"role"`
Sk string `ini:"sk" json:"sk"` Sk string `ini:"sk" json:"sk"`
// if the server user is not set, it defaults to the current user ServerName string `ini:"server_name" json:"server_name"`
ServerUser string `ini:"server_user" json:"server_user"` BindAddr string `ini:"bind_addr" json:"bind_addr"`
ServerName string `ini:"server_name" json:"server_name"` BindPort int `ini:"bind_port" json:"bind_port"`
BindAddr string `ini:"bind_addr" json:"bind_addr"`
// BindPort is the port that visitor listens on.
// It can be less than 0, it means don't bind to the port and only receive connections redirected from
// other visitors. (This is not supported for SUDP now)
BindPort int `ini:"bind_port" json:"bind_port"`
} }
type SUDPVisitorConf struct { type SUDPVisitorConf struct {
@ -69,13 +61,6 @@ type STCPVisitorConf struct {
type XTCPVisitorConf struct { type XTCPVisitorConf struct {
BaseVisitorConf `ini:",extends"` BaseVisitorConf `ini:",extends"`
Protocol string `ini:"protocol" json:"protocol,omitempty"`
KeepTunnelOpen bool `ini:"keep_tunnel_open" json:"keep_tunnel_open,omitempty"`
MaxRetriesAnHour int `ini:"max_retries_an_hour" json:"max_retries_an_hour,omitempty"`
MinRetryInterval int `ini:"min_retry_interval" json:"min_retry_interval,omitempty"`
FallbackTo string `ini:"fallback_to" json:"fallback_to,omitempty"`
FallbackTimeoutMs int `ini:"fallback_timeout_ms" json:"fallback_timeout_ms,omitempty"`
} }
// DefaultVisitorConf creates a empty VisitorConf object by visitorType. // DefaultVisitorConf creates a empty VisitorConf object by visitorType.
@ -85,6 +70,7 @@ func DefaultVisitorConf(visitorType string) VisitorConf {
if !ok { if !ok {
return nil return nil
} }
return reflect.New(v).Interface().(VisitorConf) return reflect.New(v).Interface().(VisitorConf)
} }
@ -106,7 +92,7 @@ func NewVisitorConfFromIni(prefix string, name string, section *ini.Section) (Vi
return nil, fmt.Errorf("visitor [%s] type [%s] error", name, visitorType) return nil, fmt.Errorf("visitor [%s] type [%s] error", name, visitorType)
} }
if err := conf.Validate(); err != nil { if err := conf.Check(); err != nil {
return nil, err return nil, err
} }
@ -114,11 +100,26 @@ func NewVisitorConfFromIni(prefix string, name string, section *ini.Section) (Vi
} }
// Base // Base
func (cfg *BaseVisitorConf) GetBaseConfig() *BaseVisitorConf { func (cfg *BaseVisitorConf) GetBaseInfo() *BaseVisitorConf {
return cfg return cfg
} }
func (cfg *BaseVisitorConf) validate() (err error) { func (cfg *BaseVisitorConf) compare(cmp *BaseVisitorConf) bool {
if cfg.ProxyName != cmp.ProxyName ||
cfg.ProxyType != cmp.ProxyType ||
cfg.UseEncryption != cmp.UseEncryption ||
cfg.UseCompression != cmp.UseCompression ||
cfg.Role != cmp.Role ||
cfg.Sk != cmp.Sk ||
cfg.ServerName != cmp.ServerName ||
cfg.BindAddr != cmp.BindAddr ||
cfg.BindPort != cmp.BindPort {
return false
}
return true
}
func (cfg *BaseVisitorConf) check() (err error) {
if cfg.Role != "visitor" { if cfg.Role != "visitor" {
err = fmt.Errorf("invalid role") err = fmt.Errorf("invalid role")
return return
@ -127,9 +128,7 @@ func (cfg *BaseVisitorConf) validate() (err error) {
err = fmt.Errorf("bind_addr shouldn't be empty") err = fmt.Errorf("bind_addr shouldn't be empty")
return return
} }
// BindPort can be less than 0, it means don't bind to the port and only receive connections redirected from if cfg.BindPort <= 0 {
// other visitors
if cfg.BindPort == 0 {
err = fmt.Errorf("bind_port is required") err = fmt.Errorf("bind_port is required")
return return
} }
@ -144,16 +143,13 @@ func (cfg *BaseVisitorConf) unmarshalFromIni(prefix string, name string, section
cfg.ProxyName = prefix + name cfg.ProxyName = prefix + name
// server_name // server_name
if cfg.ServerUser == "" { cfg.ServerName = prefix + cfg.ServerName
cfg.ServerName = prefix + cfg.ServerName
} else {
cfg.ServerName = cfg.ServerUser + "." + cfg.ServerName
}
// bind_addr // bind_addr
if cfg.BindAddr == "" { if cfg.BindAddr == "" {
cfg.BindAddr = "127.0.0.1" cfg.BindAddr = "127.0.0.1"
} }
return nil return nil
} }
@ -163,16 +159,32 @@ func preVisitorUnmarshalFromIni(cfg VisitorConf, prefix string, name string, sec
return err return err
} }
err = cfg.GetBaseConfig().unmarshalFromIni(prefix, name, section) err = cfg.GetBaseInfo().unmarshalFromIni(prefix, name, section)
if err != nil { if err != nil {
return err return err
} }
return nil return nil
} }
// SUDP // SUDP
var _ VisitorConf = &SUDPVisitorConf{} var _ VisitorConf = &SUDPVisitorConf{}
func (cfg *SUDPVisitorConf) Compare(cmp VisitorConf) bool {
cmpConf, ok := cmp.(*SUDPVisitorConf)
if !ok {
return false
}
if !cfg.BaseVisitorConf.compare(&cmpConf.BaseVisitorConf) {
return false
}
// Add custom login equal, if exists
return true
}
func (cfg *SUDPVisitorConf) UnmarshalFromIni(prefix string, name string, section *ini.Section) (err error) { func (cfg *SUDPVisitorConf) UnmarshalFromIni(prefix string, name string, section *ini.Section) (err error) {
err = preVisitorUnmarshalFromIni(cfg, prefix, name, section) err = preVisitorUnmarshalFromIni(cfg, prefix, name, section)
if err != nil { if err != nil {
@ -184,8 +196,8 @@ func (cfg *SUDPVisitorConf) UnmarshalFromIni(prefix string, name string, section
return return
} }
func (cfg *SUDPVisitorConf) Validate() (err error) { func (cfg *SUDPVisitorConf) Check() (err error) {
if err = cfg.BaseVisitorConf.validate(); err != nil { if err = cfg.BaseVisitorConf.check(); err != nil {
return return
} }
@ -197,6 +209,21 @@ func (cfg *SUDPVisitorConf) Validate() (err error) {
// STCP // STCP
var _ VisitorConf = &STCPVisitorConf{} var _ VisitorConf = &STCPVisitorConf{}
func (cfg *STCPVisitorConf) Compare(cmp VisitorConf) bool {
cmpConf, ok := cmp.(*STCPVisitorConf)
if !ok {
return false
}
if !cfg.BaseVisitorConf.compare(&cmpConf.BaseVisitorConf) {
return false
}
// Add custom login equal, if exists
return true
}
func (cfg *STCPVisitorConf) UnmarshalFromIni(prefix string, name string, section *ini.Section) (err error) { func (cfg *STCPVisitorConf) UnmarshalFromIni(prefix string, name string, section *ini.Section) (err error) {
err = preVisitorUnmarshalFromIni(cfg, prefix, name, section) err = preVisitorUnmarshalFromIni(cfg, prefix, name, section)
if err != nil { if err != nil {
@ -208,8 +235,8 @@ func (cfg *STCPVisitorConf) UnmarshalFromIni(prefix string, name string, section
return return
} }
func (cfg *STCPVisitorConf) Validate() (err error) { func (cfg *STCPVisitorConf) Check() (err error) {
if err = cfg.BaseVisitorConf.validate(); err != nil { if err = cfg.BaseVisitorConf.check(); err != nil {
return return
} }
@ -221,6 +248,21 @@ func (cfg *STCPVisitorConf) Validate() (err error) {
// XTCP // XTCP
var _ VisitorConf = &XTCPVisitorConf{} var _ VisitorConf = &XTCPVisitorConf{}
func (cfg *XTCPVisitorConf) Compare(cmp VisitorConf) bool {
cmpConf, ok := cmp.(*XTCPVisitorConf)
if !ok {
return false
}
if !cfg.BaseVisitorConf.compare(&cmpConf.BaseVisitorConf) {
return false
}
// Add custom login equal, if exists
return true
}
func (cfg *XTCPVisitorConf) UnmarshalFromIni(prefix string, name string, section *ini.Section) (err error) { func (cfg *XTCPVisitorConf) UnmarshalFromIni(prefix string, name string, section *ini.Section) (err error) {
err = preVisitorUnmarshalFromIni(cfg, prefix, name, section) err = preVisitorUnmarshalFromIni(cfg, prefix, name, section)
if err != nil { if err != nil {
@ -228,29 +270,16 @@ func (cfg *XTCPVisitorConf) UnmarshalFromIni(prefix string, name string, section
} }
// Add custom logic unmarshal, if exists // Add custom logic unmarshal, if exists
if cfg.Protocol == "" {
cfg.Protocol = "quic"
}
if cfg.MaxRetriesAnHour <= 0 {
cfg.MaxRetriesAnHour = 8
}
if cfg.MinRetryInterval <= 0 {
cfg.MinRetryInterval = 90
}
if cfg.FallbackTimeoutMs <= 0 {
cfg.FallbackTimeoutMs = 1000
}
return return
} }
func (cfg *XTCPVisitorConf) Validate() (err error) { func (cfg *XTCPVisitorConf) Check() (err error) {
if err = cfg.BaseVisitorConf.validate(); err != nil { if err = cfg.BaseVisitorConf.check(); err != nil {
return return
} }
// Add custom logic validate, if exists // Add custom logic validate, if exists
if !lo.Contains([]string{"", "kcp", "quic"}, cfg.Protocol) {
return fmt.Errorf("protocol should be 'kcp' or 'quic'")
}
return return
} }

View File

@ -87,10 +87,6 @@ func Test_Visitor_UnmarshalFromIni(t *testing.T) {
BindAddr: "127.0.0.1", BindAddr: "127.0.0.1",
BindPort: 9001, BindPort: 9001,
}, },
Protocol: "quic",
MaxRetriesAnHour: 8,
MinRetryInterval: 90,
FallbackTimeoutMs: 1000,
}, },
}, },
} }

View File

@ -60,30 +60,25 @@ func (m *serverMetrics) run() {
go func() { go func() {
for { for {
time.Sleep(12 * time.Hour) time.Sleep(12 * time.Hour)
start := time.Now() log.Debug("start to clear useless proxy statistics data...")
count, total := m.clearUselessInfo() m.clearUselessInfo()
log.Debug("clear useless proxy statistics data count %d/%d, cost %v", count, total, time.Since(start)) log.Debug("finish to clear useless proxy statistics data")
} }
}() }()
} }
func (m *serverMetrics) clearUselessInfo() (int, int) { func (m *serverMetrics) clearUselessInfo() {
count := 0
total := 0
// To check if there are proxies that closed than 7 days and drop them. // To check if there are proxies that closed than 7 days and drop them.
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
total = len(m.info.ProxyStatistics)
for name, data := range m.info.ProxyStatistics { for name, data := range m.info.ProxyStatistics {
if !data.LastCloseTime.IsZero() && if !data.LastCloseTime.IsZero() &&
data.LastStartTime.Before(data.LastCloseTime) && data.LastStartTime.Before(data.LastCloseTime) &&
time.Since(data.LastCloseTime) > time.Duration(7*24)*time.Hour { time.Since(data.LastCloseTime) > time.Duration(7*24)*time.Hour {
delete(m.info.ProxyStatistics, name) delete(m.info.ProxyStatistics, name)
count++
log.Trace("clear proxy [%s]'s statistics data, lastCloseTime: [%s]", name, data.LastCloseTime.String()) log.Trace("clear proxy [%s]'s statistics data, lastCloseTime: [%s]", name, data.LastCloseTime.String())
} }
} }
return count, total
} }
func (m *serverMetrics) NewClient() { func (m *serverMetrics) NewClient() {

View File

@ -16,53 +16,54 @@ package msg
import ( import (
"net" "net"
"reflect"
) )
const ( const (
TypeLogin = 'o' TypeLogin = 'o'
TypeLoginResp = '1' TypeLoginResp = '1'
TypeNewProxy = 'p' TypeNewProxy = 'p'
TypeNewProxyResp = '2' TypeNewProxyResp = '2'
TypeCloseProxy = 'c' TypeCloseProxy = 'c'
TypeNewWorkConn = 'w' TypeNewWorkConn = 'w'
TypeReqWorkConn = 'r' TypeReqWorkConn = 'r'
TypeStartWorkConn = 's' TypeStartWorkConn = 's'
TypeNewVisitorConn = 'v' TypeNewVisitorConn = 'v'
TypeNewVisitorConnResp = '3' TypeNewVisitorConnResp = '3'
TypePing = 'h' TypePing = 'h'
TypePong = '4' TypePong = '4'
TypeUDPPacket = 'u' TypeUDPPacket = 'u'
TypeNatHoleVisitor = 'i' TypeNatHoleVisitor = 'i'
TypeNatHoleClient = 'n' TypeNatHoleClient = 'n'
TypeNatHoleResp = 'm' TypeNatHoleResp = 'm'
TypeNatHoleSid = '5' TypeNatHoleClientDetectOK = 'd'
TypeNatHoleReport = '6' TypeNatHoleSid = '5'
TypeNatHoleBinding = 'b'
TypeNatHoleBindingResp = '6'
) )
var msgTypeMap = map[byte]interface{}{ var msgTypeMap = map[byte]interface{}{
TypeLogin: Login{}, TypeLogin: Login{},
TypeLoginResp: LoginResp{}, TypeLoginResp: LoginResp{},
TypeNewProxy: NewProxy{}, TypeNewProxy: NewProxy{},
TypeNewProxyResp: NewProxyResp{}, TypeNewProxyResp: NewProxyResp{},
TypeCloseProxy: CloseProxy{}, TypeCloseProxy: CloseProxy{},
TypeNewWorkConn: NewWorkConn{}, TypeNewWorkConn: NewWorkConn{},
TypeReqWorkConn: ReqWorkConn{}, TypeReqWorkConn: ReqWorkConn{},
TypeStartWorkConn: StartWorkConn{}, TypeStartWorkConn: StartWorkConn{},
TypeNewVisitorConn: NewVisitorConn{}, TypeNewVisitorConn: NewVisitorConn{},
TypeNewVisitorConnResp: NewVisitorConnResp{}, TypeNewVisitorConnResp: NewVisitorConnResp{},
TypePing: Ping{}, TypePing: Ping{},
TypePong: Pong{}, TypePong: Pong{},
TypeUDPPacket: UDPPacket{}, TypeUDPPacket: UDPPacket{},
TypeNatHoleVisitor: NatHoleVisitor{}, TypeNatHoleVisitor: NatHoleVisitor{},
TypeNatHoleClient: NatHoleClient{}, TypeNatHoleClient: NatHoleClient{},
TypeNatHoleResp: NatHoleResp{}, TypeNatHoleResp: NatHoleResp{},
TypeNatHoleSid: NatHoleSid{}, TypeNatHoleClientDetectOK: NatHoleClientDetectOK{},
TypeNatHoleReport: NatHoleReport{}, TypeNatHoleSid: NatHoleSid{},
TypeNatHoleBinding: NatHoleBinding{},
TypeNatHoleBindingResp: NatHoleBindingResp{},
} }
var TypeNameNatHoleResp = reflect.TypeOf(&NatHoleResp{}).Elem().Name()
// When frpc start, client send this message to login to server. // When frpc start, client send this message to login to server.
type Login struct { type Login struct {
Version string `json:"version,omitempty"` Version string `json:"version,omitempty"`
@ -80,9 +81,10 @@ type Login struct {
} }
type LoginResp struct { type LoginResp struct {
Version string `json:"version,omitempty"` Version string `json:"version,omitempty"`
RunID string `json:"run_id,omitempty"` RunID string `json:"run_id,omitempty"`
Error string `json:"error,omitempty"` ServerUDPPort int `json:"server_udp_port,omitempty"`
Error string `json:"error,omitempty"`
} }
// When frpc login success, send this message to frps for running a new proxy. // When frpc login success, send this message to frps for running a new proxy.
@ -145,7 +147,6 @@ type StartWorkConn struct {
} }
type NewVisitorConn struct { type NewVisitorConn struct {
RunID string `json:"run_id,omitempty"`
ProxyName string `json:"proxy_name,omitempty"` ProxyName string `json:"proxy_name,omitempty"`
SignKey string `json:"sign_key,omitempty"` SignKey string `json:"sign_key,omitempty"`
Timestamp int64 `json:"timestamp,omitempty"` Timestamp int64 `json:"timestamp,omitempty"`
@ -174,58 +175,35 @@ type UDPPacket struct {
} }
type NatHoleVisitor struct { type NatHoleVisitor struct {
TransactionID string `json:"transaction_id,omitempty"` ProxyName string `json:"proxy_name,omitempty"`
ProxyName string `json:"proxy_name,omitempty"` SignKey string `json:"sign_key,omitempty"`
PreCheck bool `json:"pre_check,omitempty"` Timestamp int64 `json:"timestamp,omitempty"`
Protocol string `json:"protocol,omitempty"`
SignKey string `json:"sign_key,omitempty"`
Timestamp int64 `json:"timestamp,omitempty"`
MappedAddrs []string `json:"mapped_addrs,omitempty"`
AssistedAddrs []string `json:"assisted_addrs,omitempty"`
} }
type NatHoleClient struct { type NatHoleClient struct {
TransactionID string `json:"transaction_id,omitempty"` ProxyName string `json:"proxy_name,omitempty"`
ProxyName string `json:"proxy_name,omitempty"` Sid string `json:"sid,omitempty"`
Sid string `json:"sid,omitempty"`
MappedAddrs []string `json:"mapped_addrs,omitempty"`
AssistedAddrs []string `json:"assisted_addrs,omitempty"`
}
type PortsRange struct {
From int `json:"from,omitempty"`
To int `json:"to,omitempty"`
}
type NatHoleDetectBehavior struct {
Role string `json:"role,omitempty"` // sender or receiver
Mode int `json:"mode,omitempty"` // 0, 1, 2...
TTL int `json:"ttl,omitempty"`
SendDelayMs int `json:"send_delay_ms,omitempty"`
ReadTimeoutMs int `json:"read_timeout,omitempty"`
CandidatePorts []PortsRange `json:"candidate_ports,omitempty"`
SendRandomPorts int `json:"send_random_ports,omitempty"`
ListenRandomPorts int `json:"listen_random_ports,omitempty"`
} }
type NatHoleResp struct { type NatHoleResp struct {
TransactionID string `json:"transaction_id,omitempty"` Sid string `json:"sid,omitempty"`
Sid string `json:"sid,omitempty"` VisitorAddr string `json:"visitor_addr,omitempty"`
Protocol string `json:"protocol,omitempty"` ClientAddr string `json:"client_addr,omitempty"`
CandidateAddrs []string `json:"candidate_addrs,omitempty"` Error string `json:"error,omitempty"`
AssistedAddrs []string `json:"assisted_addrs,omitempty"`
DetectBehavior NatHoleDetectBehavior `json:"detect_behavior,omitempty"`
Error string `json:"error,omitempty"`
} }
type NatHoleClientDetectOK struct{}
type NatHoleSid struct { type NatHoleSid struct {
TransactionID string `json:"transaction_id,omitempty"` Sid string `json:"sid,omitempty"`
Sid string `json:"sid,omitempty"`
Response bool `json:"response,omitempty"`
Nonce string `json:"nonce,omitempty"`
} }
type NatHoleReport struct { type NatHoleBinding struct {
Sid string `json:"sid,omitempty"` TransactionID string `json:"transaction_id,omitempty"`
Success bool `json:"success,omitempty"` }
type NatHoleBindingResp struct {
TransactionID string `json:"transaction_id,omitempty"`
Address string `json:"address,omitempty"`
Error string `json:"error,omitempty"`
} }

View File

@ -1,328 +0,0 @@
// Copyright 2023 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package nathole
import (
"sync"
"time"
"github.com/samber/lo"
)
var (
// mode 0, both EasyNAT, PublicNetwork is always receiver
// sender | receiver, ttl 7
// receiver, ttl 7 | sender
// sender | receiver, ttl 4
// receiver, ttl 4 | sender
// sender | receiver
// receiver | sender
// sender, sendDelayMs 5000 | receiver
// sender, sendDelayMs 10000 | receiver
// receiver | sender, sendDelayMs 5000
// receiver | sender, sendDelayMs 10000
mode0Behaviors = []lo.Tuple2[RecommandBehavior, RecommandBehavior]{
lo.T2(RecommandBehavior{Role: DetectRoleSender}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 7}),
lo.T2(RecommandBehavior{Role: DetectRoleReceiver, TTL: 7}, RecommandBehavior{Role: DetectRoleSender}),
lo.T2(RecommandBehavior{Role: DetectRoleSender}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 4}),
lo.T2(RecommandBehavior{Role: DetectRoleReceiver, TTL: 4}, RecommandBehavior{Role: DetectRoleSender}),
lo.T2(RecommandBehavior{Role: DetectRoleSender}, RecommandBehavior{Role: DetectRoleReceiver}),
lo.T2(RecommandBehavior{Role: DetectRoleReceiver}, RecommandBehavior{Role: DetectRoleSender}),
lo.T2(RecommandBehavior{Role: DetectRoleSender, SendDelayMs: 5000}, RecommandBehavior{Role: DetectRoleReceiver}),
lo.T2(RecommandBehavior{Role: DetectRoleSender, SendDelayMs: 10000}, RecommandBehavior{Role: DetectRoleReceiver}),
lo.T2(RecommandBehavior{Role: DetectRoleReceiver}, RecommandBehavior{Role: DetectRoleSender, SendDelayMs: 5000}),
lo.T2(RecommandBehavior{Role: DetectRoleReceiver}, RecommandBehavior{Role: DetectRoleSender, SendDelayMs: 10000}),
}
// mode 1, HardNAT is sender, EasyNAT is receiver, port changes is regular
// sender | receiver, ttl 7, portsRangeNumber max 10
// sender, sendDelayMs 2000 | receiver, ttl 7, portsRangeNumber max 10
// sender | receiver, ttl 4, portsRangeNumber max 10
// sender, sendDelayMs 2000 | receiver, ttl 4, portsRangeNumber max 10
// sender | receiver, portsRangeNumber max 10
// sender, sendDelayMs 2000 | receiver, portsRangeNumber max 10
mode1Behaviors = []lo.Tuple2[RecommandBehavior, RecommandBehavior]{
lo.T2(RecommandBehavior{Role: DetectRoleSender}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 7, PortsRangeNumber: 10}),
lo.T2(RecommandBehavior{Role: DetectRoleSender, SendDelayMs: 2000}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 7, PortsRangeNumber: 10}),
lo.T2(RecommandBehavior{Role: DetectRoleSender}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 4, PortsRangeNumber: 10}),
lo.T2(RecommandBehavior{Role: DetectRoleSender, SendDelayMs: 2000}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 4, PortsRangeNumber: 10}),
lo.T2(RecommandBehavior{Role: DetectRoleSender}, RecommandBehavior{Role: DetectRoleReceiver, PortsRangeNumber: 10}),
lo.T2(RecommandBehavior{Role: DetectRoleSender, SendDelayMs: 2000}, RecommandBehavior{Role: DetectRoleReceiver, PortsRangeNumber: 10}),
}
// mode 2, HardNAT is receiver, EasyNAT is sender
// sender, portsRandomNumber 1000, sendDelayMs 3000 | receiver, listen 256 ports, ttl 7
// sender, portsRandomNumber 1000, sendDelayMs 3000 | receiver, listen 256 ports, ttl 4
// sender, portsRandomNumber 1000, sendDelayMs 3000 | receiver, listen 256 ports
mode2Behaviors = []lo.Tuple2[RecommandBehavior, RecommandBehavior]{
lo.T2(
RecommandBehavior{Role: DetectRoleSender, PortsRandomNumber: 1000, SendDelayMs: 3000},
RecommandBehavior{Role: DetectRoleReceiver, ListenRandomPorts: 256, TTL: 7},
),
lo.T2(
RecommandBehavior{Role: DetectRoleSender, PortsRandomNumber: 1000, SendDelayMs: 3000},
RecommandBehavior{Role: DetectRoleReceiver, ListenRandomPorts: 256, TTL: 4},
),
lo.T2(
RecommandBehavior{Role: DetectRoleSender, PortsRandomNumber: 1000, SendDelayMs: 3000},
RecommandBehavior{Role: DetectRoleReceiver, ListenRandomPorts: 256},
),
}
// mode 3, For HardNAT & HardNAT, both changes in the ports are regular
// sender, portsRangeNumber 10 | receiver, ttl 7, portsRangeNumber 10
// sender, portsRangeNumber 10 | receiver, ttl 4, portsRangeNumber 10
// sender, portsRangeNumber 10 | receiver, portsRangeNumber 10
// receiver, ttl 7, portsRangeNumber 10 | sender, portsRangeNumber 10
// receiver, ttl 4, portsRangeNumber 10 | sender, portsRangeNumber 10
// receiver, portsRangeNumber 10 | sender, portsRangeNumber 10
mode3Behaviors = []lo.Tuple2[RecommandBehavior, RecommandBehavior]{
lo.T2(RecommandBehavior{Role: DetectRoleSender, PortsRangeNumber: 10}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 7, PortsRangeNumber: 10}),
lo.T2(RecommandBehavior{Role: DetectRoleSender, PortsRangeNumber: 10}, RecommandBehavior{Role: DetectRoleReceiver, TTL: 4, PortsRangeNumber: 10}),
lo.T2(RecommandBehavior{Role: DetectRoleSender, PortsRangeNumber: 10}, RecommandBehavior{Role: DetectRoleReceiver, PortsRangeNumber: 10}),
lo.T2(RecommandBehavior{Role: DetectRoleReceiver, TTL: 7, PortsRangeNumber: 10}, RecommandBehavior{Role: DetectRoleSender, PortsRangeNumber: 10}),
lo.T2(RecommandBehavior{Role: DetectRoleReceiver, TTL: 4, PortsRangeNumber: 10}, RecommandBehavior{Role: DetectRoleSender, PortsRangeNumber: 10}),
lo.T2(RecommandBehavior{Role: DetectRoleReceiver, PortsRangeNumber: 10}, RecommandBehavior{Role: DetectRoleSender, PortsRangeNumber: 10}),
}
// mode 4, Regular ports changes are usually the sender.
// sender, portsRandomNumber 1000, sendDelayMs: 2000 | receiver, listen 256 ports, ttl 7, portsRangeNumber 2
// sender, portsRandomNumber 1000, sendDelayMs: 2000 | receiver, listen 256 ports, ttl 4, portsRangeNumber 2
// sender, portsRandomNumber 1000, SendDelayMs: 2000 | receiver, listen 256 ports, portsRangeNumber 2
mode4Behaviors = []lo.Tuple2[RecommandBehavior, RecommandBehavior]{
lo.T2(
RecommandBehavior{Role: DetectRoleSender, PortsRandomNumber: 1000, SendDelayMs: 3000},
RecommandBehavior{Role: DetectRoleReceiver, ListenRandomPorts: 256, TTL: 7, PortsRangeNumber: 2},
),
lo.T2(
RecommandBehavior{Role: DetectRoleSender, PortsRandomNumber: 1000, SendDelayMs: 3000},
RecommandBehavior{Role: DetectRoleReceiver, ListenRandomPorts: 256, TTL: 4, PortsRangeNumber: 2},
),
lo.T2(
RecommandBehavior{Role: DetectRoleSender, PortsRandomNumber: 1000, SendDelayMs: 3000},
RecommandBehavior{Role: DetectRoleReceiver, ListenRandomPorts: 256, PortsRangeNumber: 2},
),
}
)
func getBehaviorByMode(mode int) []lo.Tuple2[RecommandBehavior, RecommandBehavior] {
switch mode {
case 0:
return mode0Behaviors
case 1:
return mode1Behaviors
case 2:
return mode2Behaviors
case 3:
return mode3Behaviors
case 4:
return mode4Behaviors
}
// default
return mode0Behaviors
}
func getBehaviorByModeAndIndex(mode int, index int) (RecommandBehavior, RecommandBehavior) {
behaviors := getBehaviorByMode(mode)
if index >= len(behaviors) {
return RecommandBehavior{}, RecommandBehavior{}
}
return behaviors[index].A, behaviors[index].B
}
func getBehaviorScoresByMode(mode int, defaultScore int) []*BehaviorScore {
return getBehaviorScoresByMode2(mode, defaultScore, defaultScore)
}
func getBehaviorScoresByMode2(mode int, senderScore, receiverScore int) []*BehaviorScore {
behaviors := getBehaviorByMode(mode)
scores := make([]*BehaviorScore, 0, len(behaviors))
for i := 0; i < len(behaviors); i++ {
score := receiverScore
if behaviors[i].A.Role == DetectRoleSender {
score = senderScore
}
scores = append(scores, &BehaviorScore{Mode: mode, Index: i, Score: score})
}
return scores
}
type RecommandBehavior struct {
Role string
TTL int
SendDelayMs int
PortsRangeNumber int
PortsRandomNumber int
ListenRandomPorts int
}
type MakeHoleRecords struct {
mu sync.Mutex
scores []*BehaviorScore
LastUpdateTime time.Time
}
func NewMakeHoleRecords(c, v *NatFeature) *MakeHoleRecords {
scores := []*BehaviorScore{}
easyCount, hardCount, portsChangedRegularCount := ClassifyFeatureCount([]*NatFeature{c, v})
appendMode0 := func() {
switch {
case c.PublicNetwork:
scores = append(scores, getBehaviorScoresByMode2(DetectMode0, 0, 1)...)
case v.PublicNetwork:
scores = append(scores, getBehaviorScoresByMode2(DetectMode0, 1, 0)...)
default:
scores = append(scores, getBehaviorScoresByMode(DetectMode0, 0)...)
}
}
switch {
case easyCount == 2:
appendMode0()
case hardCount == 1 && portsChangedRegularCount == 1:
scores = append(scores, getBehaviorScoresByMode(DetectMode1, 0)...)
scores = append(scores, getBehaviorScoresByMode(DetectMode2, 0)...)
appendMode0()
case hardCount == 1 && portsChangedRegularCount == 0:
scores = append(scores, getBehaviorScoresByMode(DetectMode2, 0)...)
scores = append(scores, getBehaviorScoresByMode(DetectMode1, 0)...)
appendMode0()
case hardCount == 2 && portsChangedRegularCount == 2:
scores = append(scores, getBehaviorScoresByMode(DetectMode3, 0)...)
scores = append(scores, getBehaviorScoresByMode(DetectMode4, 0)...)
case hardCount == 2 && portsChangedRegularCount == 1:
scores = append(scores, getBehaviorScoresByMode(DetectMode4, 0)...)
default:
// hard to make hole, just trying it out.
scores = append(scores, getBehaviorScoresByMode(DetectMode0, 1)...)
scores = append(scores, getBehaviorScoresByMode(DetectMode1, 1)...)
scores = append(scores, getBehaviorScoresByMode(DetectMode3, 1)...)
}
return &MakeHoleRecords{scores: scores, LastUpdateTime: time.Now()}
}
func (mhr *MakeHoleRecords) ReportSuccess(mode int, index int) {
mhr.mu.Lock()
defer mhr.mu.Unlock()
mhr.LastUpdateTime = time.Now()
for i := range mhr.scores {
score := mhr.scores[i]
if score.Mode != mode || score.Index != index {
continue
}
score.Score += 2
score.Score = lo.Min([]int{score.Score, 10})
return
}
}
func (mhr *MakeHoleRecords) Recommand() (mode, index int) {
mhr.mu.Lock()
defer mhr.mu.Unlock()
maxScore := lo.MaxBy(mhr.scores, func(item, max *BehaviorScore) bool {
return item.Score > max.Score
})
if maxScore == nil {
return 0, 0
}
maxScore.Score--
mhr.LastUpdateTime = time.Now()
return maxScore.Mode, maxScore.Index
}
type BehaviorScore struct {
Mode int
Index int
// between -10 and 10
Score int
}
type Analyzer struct {
// key is client ip + visitor ip
records map[string]*MakeHoleRecords
dataReserveDuration time.Duration
mu sync.Mutex
}
func NewAnalyzer(dataReserveDuration time.Duration) *Analyzer {
return &Analyzer{
records: make(map[string]*MakeHoleRecords),
dataReserveDuration: dataReserveDuration,
}
}
func (a *Analyzer) GetRecommandBehaviors(key string, c, v *NatFeature) (mode, index int, _ RecommandBehavior, _ RecommandBehavior) {
a.mu.Lock()
records, ok := a.records[key]
if !ok {
records = NewMakeHoleRecords(c, v)
a.records[key] = records
}
a.mu.Unlock()
mode, index = records.Recommand()
cBehavior, vBehavior := getBehaviorByModeAndIndex(mode, index)
switch mode {
case DetectMode1:
// HardNAT is always the sender
if c.NatType == EasyNAT {
cBehavior, vBehavior = vBehavior, cBehavior
}
case DetectMode2:
// HardNAT is always the receiver
if c.NatType == HardNAT {
cBehavior, vBehavior = vBehavior, cBehavior
}
case DetectMode4:
// Regular ports changes is always the sender
if !c.RegularPortsChange {
cBehavior, vBehavior = vBehavior, cBehavior
}
}
return mode, index, cBehavior, vBehavior
}
func (a *Analyzer) ReportSuccess(key string, mode, index int) {
a.mu.Lock()
records, ok := a.records[key]
a.mu.Unlock()
if !ok {
return
}
records.ReportSuccess(mode, index)
}
func (a *Analyzer) Clean() (int, int) {
now := time.Now()
total := 0
count := 0
// cleanup 10w records may take 5ms
a.mu.Lock()
defer a.mu.Unlock()
total = len(a.records)
// clean up records that have not been used for a period of time.
for key, records := range a.records {
if now.Sub(records.LastUpdateTime) > a.dataReserveDuration {
delete(a.records, key)
count++
}
}
return count, total
}

View File

@ -17,9 +17,6 @@ package nathole
import ( import (
"fmt" "fmt"
"net" "net"
"strconv"
"github.com/samber/lo"
) )
const ( const (
@ -32,97 +29,46 @@ const (
BehaviorBothChanged = "BehaviorBothChanged" BehaviorBothChanged = "BehaviorBothChanged"
) )
type NatFeature struct { // ClassifyNATType classify NAT type by given addresses.
NatType string func ClassifyNATType(addresses []string) (string, string, error) {
Behavior string
PortsDifference int
RegularPortsChange bool
PublicNetwork bool
}
func ClassifyNATFeature(addresses []string, localIPs []string) (*NatFeature, error) {
if len(addresses) <= 1 { if len(addresses) <= 1 {
return nil, fmt.Errorf("not enough addresses") return "", "", fmt.Errorf("not enough addresses")
} }
natFeature := &NatFeature{}
ipChanged := false ipChanged := false
portChanged := false portChanged := false
var baseIP, basePort string var baseIP, basePort string
var portMax, portMin int
for _, addr := range addresses { for _, addr := range addresses {
ip, port, err := net.SplitHostPort(addr) ip, port, err := net.SplitHostPort(addr)
if err != nil { if err != nil {
return nil, err return "", "", err
} }
portNum, err := strconv.Atoi(port)
if err != nil {
return nil, err
}
if lo.Contains(localIPs, ip) {
natFeature.PublicNetwork = true
}
if baseIP == "" { if baseIP == "" {
baseIP = ip baseIP = ip
basePort = port basePort = port
portMax = portNum
portMin = portNum
continue continue
} }
if portNum > portMax {
portMax = portNum
}
if portNum < portMin {
portMin = portNum
}
if baseIP != ip { if baseIP != ip {
ipChanged = true ipChanged = true
} }
if basePort != port { if basePort != port {
portChanged = true portChanged = true
} }
if ipChanged && portChanged {
break
}
} }
switch { switch {
case ipChanged && portChanged: case ipChanged && portChanged:
natFeature.NatType = HardNAT return HardNAT, BehaviorBothChanged, nil
natFeature.Behavior = BehaviorBothChanged
case ipChanged: case ipChanged:
natFeature.NatType = HardNAT return HardNAT, BehaviorIPChanged, nil
natFeature.Behavior = BehaviorIPChanged
case portChanged: case portChanged:
natFeature.NatType = HardNAT return HardNAT, BehaviorPortChanged, nil
natFeature.Behavior = BehaviorPortChanged
default: default:
natFeature.NatType = EasyNAT return EasyNAT, BehaviorNoChange, nil
natFeature.Behavior = BehaviorNoChange
} }
if natFeature.Behavior == BehaviorPortChanged {
natFeature.PortsDifference = portMax - portMin
if natFeature.PortsDifference <= 5 && natFeature.PortsDifference >= 1 {
natFeature.RegularPortsChange = true
}
}
return natFeature, nil
}
func ClassifyFeatureCount(features []*NatFeature) (int, int, int) {
easyCount := 0
hardCount := 0
// for HardNAT
portsChangedRegularCount := 0
for _, feature := range features {
if feature.NatType == EasyNAT {
easyCount++
continue
}
hardCount++
if feature.RegularPortsChange {
portsChangedRegularCount++
}
}
return easyCount, hardCount, portsChangedRegularCount
} }

View File

@ -1,382 +0,0 @@
// Copyright 2023 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package nathole
import (
"context"
"crypto/md5"
"encoding/hex"
"fmt"
"net"
"strconv"
"sync"
"time"
"github.com/fatedier/golib/errors"
"github.com/samber/lo"
"golang.org/x/sync/errgroup"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/transport"
"github.com/fatedier/frp/pkg/util/log"
"github.com/fatedier/frp/pkg/util/util"
)
// NatHoleTimeout seconds.
var NatHoleTimeout int64 = 10
func NewTransactionID() string {
id, _ := util.RandID()
return fmt.Sprintf("%d%s", time.Now().Unix(), id)
}
type ClientCfg struct {
name string
sk string
sidCh chan string
}
type Session struct {
sid string
analysisKey string
recommandMode int
recommandIndex int
visitorMsg *msg.NatHoleVisitor
visitorTransporter transport.MessageTransporter
vResp *msg.NatHoleResp
vNatFeature *NatFeature
vBehavior RecommandBehavior
clientMsg *msg.NatHoleClient
clientTransporter transport.MessageTransporter
cResp *msg.NatHoleResp
cNatFeature *NatFeature
cBehavior RecommandBehavior
notifyCh chan struct{}
}
func (s *Session) genAnalysisKey() {
hash := md5.New()
vIPs := lo.Uniq(parseIPs(s.visitorMsg.MappedAddrs))
if len(vIPs) > 0 {
hash.Write([]byte(vIPs[0]))
}
hash.Write([]byte(s.vNatFeature.NatType))
hash.Write([]byte(s.vNatFeature.Behavior))
hash.Write([]byte(strconv.FormatBool(s.vNatFeature.RegularPortsChange)))
cIPs := lo.Uniq(parseIPs(s.clientMsg.MappedAddrs))
if len(cIPs) > 0 {
hash.Write([]byte(cIPs[0]))
}
hash.Write([]byte(s.cNatFeature.NatType))
hash.Write([]byte(s.cNatFeature.Behavior))
hash.Write([]byte(strconv.FormatBool(s.cNatFeature.RegularPortsChange)))
s.analysisKey = hex.EncodeToString(hash.Sum(nil))
}
type Controller struct {
clientCfgs map[string]*ClientCfg
sessions map[string]*Session
analyzer *Analyzer
mu sync.RWMutex
}
func NewController(analysisDataReserveDuration time.Duration) (*Controller, error) {
return &Controller{
clientCfgs: make(map[string]*ClientCfg),
sessions: make(map[string]*Session),
analyzer: NewAnalyzer(analysisDataReserveDuration),
}, nil
}
func (c *Controller) CleanWorker(ctx context.Context) {
ticker := time.NewTicker(time.Hour)
defer ticker.Stop()
for {
select {
case <-ticker.C:
start := time.Now()
count, total := c.analyzer.Clean()
log.Trace("clean %d/%d nathole analysis data, cost %v", count, total, time.Since(start))
case <-ctx.Done():
return
}
}
}
func (c *Controller) ListenClient(name string, sk string) chan string {
cfg := &ClientCfg{
name: name,
sk: sk,
sidCh: make(chan string),
}
c.mu.Lock()
defer c.mu.Unlock()
c.clientCfgs[name] = cfg
return cfg.sidCh
}
func (c *Controller) CloseClient(name string) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.clientCfgs, name)
}
func (c *Controller) GenSid() string {
t := time.Now().Unix()
id, _ := util.RandID()
return fmt.Sprintf("%d%s", t, id)
}
func (c *Controller) HandleVisitor(m *msg.NatHoleVisitor, transporter transport.MessageTransporter) {
if m.PreCheck {
_, ok := c.clientCfgs[m.ProxyName]
if !ok {
_ = transporter.Send(c.GenNatHoleResponse(m.TransactionID, nil, fmt.Sprintf("xtcp server for [%s] doesn't exist", m.ProxyName)))
} else {
_ = transporter.Send(c.GenNatHoleResponse(m.TransactionID, nil, ""))
}
return
}
sid := c.GenSid()
session := &Session{
sid: sid,
visitorMsg: m,
visitorTransporter: transporter,
notifyCh: make(chan struct{}, 1),
}
var (
clientCfg *ClientCfg
ok bool
)
err := func() error {
c.mu.Lock()
defer c.mu.Unlock()
clientCfg, ok = c.clientCfgs[m.ProxyName]
if !ok {
return fmt.Errorf("xtcp server for [%s] doesn't exist", m.ProxyName)
}
if !util.ConstantTimeEqString(m.SignKey, util.GetAuthKey(clientCfg.sk, m.Timestamp)) {
return fmt.Errorf("xtcp connection of [%s] auth failed", m.ProxyName)
}
c.sessions[sid] = session
return nil
}()
if err != nil {
log.Warn("handle visitorMsg error: %v", err)
_ = transporter.Send(c.GenNatHoleResponse(m.TransactionID, nil, err.Error()))
return
}
log.Trace("handle visitor message, sid [%s]", sid)
defer func() {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.sessions, sid)
}()
if err := errors.PanicToError(func() {
clientCfg.sidCh <- sid
}); err != nil {
return
}
// wait for NatHoleClient message
select {
case <-session.notifyCh:
case <-time.After(time.Duration(NatHoleTimeout) * time.Second):
log.Debug("wait for NatHoleClient message timeout, sid [%s]", sid)
return
}
// Make hole-punching decisions based on the NAT information of the client and visitor.
vResp, cResp, err := c.analysis(session)
if err != nil {
log.Debug("sid [%s] analysis error: %v", err)
vResp = c.GenNatHoleResponse(session.visitorMsg.TransactionID, nil, err.Error())
cResp = c.GenNatHoleResponse(session.clientMsg.TransactionID, nil, err.Error())
}
session.cResp = cResp
session.vResp = vResp
// send response to visitor and client
var g errgroup.Group
g.Go(func() error {
// if it's sender, wait for a while to make sure the client has send the detect messages
if vResp.DetectBehavior.Role == "sender" {
time.Sleep(1 * time.Second)
}
_ = session.visitorTransporter.Send(vResp)
return nil
})
g.Go(func() error {
// if it's sender, wait for a while to make sure the client has send the detect messages
if cResp.DetectBehavior.Role == "sender" {
time.Sleep(1 * time.Second)
}
_ = session.clientTransporter.Send(cResp)
return nil
})
_ = g.Wait()
time.Sleep(time.Duration(cResp.DetectBehavior.ReadTimeoutMs+30000) * time.Millisecond)
}
func (c *Controller) HandleClient(m *msg.NatHoleClient, transporter transport.MessageTransporter) {
c.mu.RLock()
session, ok := c.sessions[m.Sid]
c.mu.RUnlock()
if !ok {
return
}
log.Trace("handle client message, sid [%s]", session.sid)
session.clientMsg = m
session.clientTransporter = transporter
select {
case session.notifyCh <- struct{}{}:
default:
}
}
func (c *Controller) HandleReport(m *msg.NatHoleReport) {
c.mu.RLock()
session, ok := c.sessions[m.Sid]
c.mu.RUnlock()
if !ok {
log.Trace("sid [%s] report make hole success: %v, but session not found", m.Sid, m.Success)
return
}
if m.Success {
c.analyzer.ReportSuccess(session.analysisKey, session.recommandMode, session.recommandIndex)
}
log.Info("sid [%s] report make hole success: %v, mode %v, index %v",
m.Sid, m.Success, session.recommandMode, session.recommandIndex)
}
func (c *Controller) GenNatHoleResponse(transactionID string, session *Session, errInfo string) *msg.NatHoleResp {
var sid string
if session != nil {
sid = session.sid
}
return &msg.NatHoleResp{
TransactionID: transactionID,
Sid: sid,
Error: errInfo,
}
}
// analysis analyzes the NAT type and behavior of the visitor and client, then makes hole-punching decisions.
// return the response to the visitor and client.
func (c *Controller) analysis(session *Session) (*msg.NatHoleResp, *msg.NatHoleResp, error) {
cm := session.clientMsg
vm := session.visitorMsg
cNatFeature, err := ClassifyNATFeature(cm.MappedAddrs, parseIPs(cm.AssistedAddrs))
if err != nil {
return nil, nil, fmt.Errorf("classify client nat feature error: %v", err)
}
vNatFeature, err := ClassifyNATFeature(vm.MappedAddrs, parseIPs(vm.AssistedAddrs))
if err != nil {
return nil, nil, fmt.Errorf("classify visitor nat feature error: %v", err)
}
session.cNatFeature = cNatFeature
session.vNatFeature = vNatFeature
session.genAnalysisKey()
mode, index, cBehavior, vBehavior := c.analyzer.GetRecommandBehaviors(session.analysisKey, cNatFeature, vNatFeature)
session.recommandMode = mode
session.recommandIndex = index
session.cBehavior = cBehavior
session.vBehavior = vBehavior
timeoutMs := lo.Max([]int{cBehavior.SendDelayMs, vBehavior.SendDelayMs}) + 5000
if cBehavior.ListenRandomPorts > 0 || vBehavior.ListenRandomPorts > 0 {
timeoutMs += 30000
}
protocol := vm.Protocol
vResp := &msg.NatHoleResp{
TransactionID: vm.TransactionID,
Sid: session.sid,
Protocol: protocol,
CandidateAddrs: lo.Uniq(cm.MappedAddrs),
AssistedAddrs: lo.Uniq(cm.AssistedAddrs),
DetectBehavior: msg.NatHoleDetectBehavior{
Mode: mode,
Role: vBehavior.Role,
TTL: vBehavior.TTL,
SendDelayMs: vBehavior.SendDelayMs,
ReadTimeoutMs: timeoutMs - vBehavior.SendDelayMs,
SendRandomPorts: vBehavior.PortsRandomNumber,
ListenRandomPorts: vBehavior.ListenRandomPorts,
CandidatePorts: getRangePorts(cm.MappedAddrs, cNatFeature.PortsDifference, vBehavior.PortsRangeNumber),
},
}
cResp := &msg.NatHoleResp{
TransactionID: cm.TransactionID,
Sid: session.sid,
Protocol: protocol,
CandidateAddrs: lo.Uniq(vm.MappedAddrs),
AssistedAddrs: lo.Uniq(vm.AssistedAddrs),
DetectBehavior: msg.NatHoleDetectBehavior{
Mode: mode,
Role: cBehavior.Role,
TTL: cBehavior.TTL,
SendDelayMs: cBehavior.SendDelayMs,
ReadTimeoutMs: timeoutMs - cBehavior.SendDelayMs,
SendRandomPorts: cBehavior.PortsRandomNumber,
ListenRandomPorts: cBehavior.ListenRandomPorts,
CandidatePorts: getRangePorts(vm.MappedAddrs, vNatFeature.PortsDifference, cBehavior.PortsRangeNumber),
},
}
log.Debug("sid [%s] visitor nat: %+v, candidateAddrs: %v; client nat: %+v, candidateAddrs: %v, protocol: %s",
session.sid, *vNatFeature, vm.MappedAddrs, *cNatFeature, cm.MappedAddrs, protocol)
log.Debug("sid [%s] visitor detect behavior: %+v", session.sid, vResp.DetectBehavior)
log.Debug("sid [%s] client detect behavior: %+v", session.sid, cResp.DetectBehavior)
return vResp, cResp, nil
}
func getRangePorts(addrs []string, difference, maxNumber int) []msg.PortsRange {
if maxNumber <= 0 {
return nil
}
addr, err := lo.Last(addrs)
if err != nil {
return nil
}
var ports []msg.PortsRange
_, portStr, err := net.SplitHostPort(addr)
if err != nil {
return nil
}
port, err := strconv.Atoi(portStr)
if err != nil {
return nil
}
ports = append(ports, msg.PortsRange{
From: lo.Max([]int{port - difference - 5, port - maxNumber, 1}),
To: lo.Min([]int{port + difference + 5, port + maxNumber, 65535}),
})
return ports
}

View File

@ -20,6 +20,8 @@ import (
"time" "time"
"github.com/pion/stun" "github.com/pion/stun"
"github.com/fatedier/frp/pkg/msg"
) )
var responseTimeout = 3 * time.Second var responseTimeout = 3 * time.Second
@ -29,27 +31,35 @@ type Message struct {
Addr string Addr string
} }
// If the localAddr is empty, it will listen on a random port. func Discover(serverAddress string, stunServers []string, key []byte) ([]string, error) {
func Discover(stunServers []string, localAddr string) ([]string, net.Addr, error) {
// create a discoverConn and get response from messageChan // create a discoverConn and get response from messageChan
discoverConn, err := listen(localAddr) discoverConn, err := listen()
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
defer discoverConn.Close() defer discoverConn.Close()
go discoverConn.readLoop() go discoverConn.readLoop()
addresses := make([]string, 0, len(stunServers)) addresses := make([]string, 0, len(stunServers)+1)
if serverAddress != "" {
// get external address from frp server
externalAddr, err := discoverConn.discoverFromServer(serverAddress, key)
if err != nil {
return nil, err
}
addresses = append(addresses, externalAddr)
}
for _, addr := range stunServers { for _, addr := range stunServers {
// get external address from stun server // get external address from stun server
externalAddrs, err := discoverConn.discoverFromStunServer(addr) externalAddrs, err := discoverConn.discoverFromStunServer(addr)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
addresses = append(addresses, externalAddrs...) addresses = append(addresses, externalAddrs...)
} }
return addresses, discoverConn.localAddr, nil return addresses, nil
} }
type stunResponse struct { type stunResponse struct {
@ -64,16 +74,8 @@ type discoverConn struct {
messageChan chan *Message messageChan chan *Message
} }
func listen(localAddr string) (*discoverConn, error) { func listen() (*discoverConn, error) {
var local *net.UDPAddr conn, err := net.ListenUDP("udp4", nil)
if localAddr != "" {
addr, err := net.ResolveUDPAddr("udp4", localAddr)
if err != nil {
return nil, err
}
local = addr
}
conn, err := net.ListenUDP("udp4", local)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -157,6 +159,43 @@ func (c *discoverConn) doSTUNRequest(addr string) (*stunResponse, error) {
return resp, nil return resp, nil
} }
func (c *discoverConn) discoverFromServer(serverAddress string, key []byte) (string, error) {
addr, err := net.ResolveUDPAddr("udp4", serverAddress)
if err != nil {
return "", err
}
m := &msg.NatHoleBinding{
TransactionID: NewTransactionID(),
}
buf, err := EncodeMessage(m, key)
if err != nil {
return "", err
}
if _, err := c.conn.WriteTo(buf, addr); err != nil {
return "", err
}
var respMsg msg.NatHoleBindingResp
select {
case rawMsg := <-c.messageChan:
if err := DecodeMessageInto(rawMsg.Body, key, &respMsg); err != nil {
return "", err
}
case <-time.After(responseTimeout):
return "", fmt.Errorf("wait response from frp server timeout")
}
if respMsg.TransactionID == "" {
return "", fmt.Errorf("error format: no transaction id found")
}
if respMsg.Error != "" {
return "", fmt.Errorf("get externalAddr from frp server error: %s", respMsg.Error)
}
return respMsg.Address, nil
}
func (c *discoverConn) discoverFromStunServer(addr string) ([]string, error) { func (c *discoverConn) discoverFromStunServer(addr string) ([]string, error) {
resp, err := c.doSTUNRequest(addr) resp, err := c.doSTUNRequest(addr)
if err != nil { if err != nil {

View File

@ -15,426 +15,249 @@
package nathole package nathole
import ( import (
"context" "bytes"
"fmt" "fmt"
"math/rand"
"net" "net"
"strconv" "sync"
"strings"
"time" "time"
"github.com/fatedier/golib/crypto"
"github.com/fatedier/golib/errors"
"github.com/fatedier/golib/pool" "github.com/fatedier/golib/pool"
"github.com/samber/lo"
"golang.org/x/net/ipv4"
"k8s.io/apimachinery/pkg/util/sets"
"github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/util/log"
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/util"
) )
var ( // NatHoleTimeout seconds.
// mode 0: simple detect mode, usually for both EasyNAT or HardNAT & EasyNAT(Public Network) var NatHoleTimeout int64 = 10
// a. receiver sends detect message with low TTL
// b. sender sends normal detect message to receiver
// c. receiver receives detect message and sends back a message to sender
//
// mode 1: For HardNAT & EasyNAT, send detect messages to multiple guessed ports.
// Usually applicable to scenarios where port changes are regular.
// Most of the steps are the same as mode 0, but EasyNAT is fixed as the receiver and will send detect messages
// with low TTL to multiple guessed ports of the sender.
//
// mode 2: For HardNAT & EasyNAT, ports changes are not regular.
// a. HardNAT machine will listen on multiple ports and send detect messages with low TTL to EasyNAT machine
// b. EasyNAT machine will send detect messages to random ports of HardNAT machine.
//
// mode 3: For HardNAT & HardNAT, both changes in the ports are regular.
// Most of the steps are the same as mode 1, but the sender also needs to send detect messages to multiple guessed
// ports of the receiver.
//
// mode 4: For HardNAT & HardNAT, one of the changes in the ports is regular.
// Regular port changes are usually on the sender side.
// a. Receiver listens on multiple ports and sends detect messages with low TTL to the sender's guessed range ports.
// b. Sender sends detect messages to random ports of the receiver.
SupportedModes = []int{DetectMode0, DetectMode1, DetectMode2, DetectMode3, DetectMode4}
SupportedRoles = []string{DetectRoleSender, DetectRoleReceiver}
DetectMode0 = 0 func NewTransactionID() string {
DetectMode1 = 1 id, _ := util.RandID()
DetectMode2 = 2 return fmt.Sprintf("%d%s", time.Now().Unix(), id)
DetectMode3 = 3
DetectMode4 = 4
DetectRoleSender = "sender"
DetectRoleReceiver = "receiver"
)
type PrepareResult struct {
Addrs []string
AssistedAddrs []string
ListenConn *net.UDPConn
NatType string
Behavior string
} }
// PreCheck is used to check if the proxy is ready for penetration. type SidRequest struct {
// Call this function before calling Prepare to avoid unnecessary preparation work. Sid string
func PreCheck( NotifyCh chan struct{}
ctx context.Context, transporter transport.MessageTransporter,
proxyName string, timeout time.Duration,
) error {
timeoutCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
var natHoleRespMsg *msg.NatHoleResp
transactionID := NewTransactionID()
m, err := transporter.Do(timeoutCtx, &msg.NatHoleVisitor{
TransactionID: transactionID,
ProxyName: proxyName,
PreCheck: true,
}, transactionID, msg.TypeNameNatHoleResp)
if err != nil {
return fmt.Errorf("get natHoleRespMsg error: %v", err)
}
mm, ok := m.(*msg.NatHoleResp)
if !ok {
return fmt.Errorf("get natHoleRespMsg error: invalid message type")
}
natHoleRespMsg = mm
if natHoleRespMsg.Error != "" {
return fmt.Errorf("%s", natHoleRespMsg.Error)
}
return nil
} }
// Prepare is used to do some preparation work before penetration. type Controller struct {
func Prepare(stunServers []string) (*PrepareResult, error) { listener *net.UDPConn
// discover for Nat type
addrs, localAddr, err := Discover(stunServers, "")
if err != nil {
return nil, fmt.Errorf("discover error: %v", err)
}
if len(addrs) < 2 {
return nil, fmt.Errorf("discover error: not enough addresses")
}
localIPs, _ := ListLocalIPsForNatHole(10) clientCfgs map[string]*ClientCfg
natFeature, err := ClassifyNATFeature(addrs, localIPs) sessions map[string]*Session
if err != nil {
return nil, fmt.Errorf("classify nat feature error: %v", err)
}
laddr, err := net.ResolveUDPAddr("udp4", localAddr.String()) encryptionKey []byte
if err != nil { mu sync.RWMutex
return nil, fmt.Errorf("resolve local udp addr error: %v", err)
}
listenConn, err := net.ListenUDP("udp4", laddr)
if err != nil {
return nil, fmt.Errorf("listen local udp addr error: %v", err)
}
assistedAddrs := make([]string, 0, len(localIPs))
for _, ip := range localIPs {
assistedAddrs = append(assistedAddrs, net.JoinHostPort(ip, strconv.Itoa(laddr.Port)))
}
return &PrepareResult{
Addrs: addrs,
AssistedAddrs: assistedAddrs,
ListenConn: listenConn,
NatType: natFeature.NatType,
Behavior: natFeature.Behavior,
}, nil
} }
// ExchangeInfo is used to exchange information between client and visitor. func NewController(udpBindAddr string, encryptionKey []byte) (nc *Controller, err error) {
// 1. Send input message to server by msgTransporter. addr, err := net.ResolveUDPAddr("udp", udpBindAddr)
// 2. Server will gather information from client and visitor and analyze it. Then send back a NatHoleResp message to them to tell them how to do next.
// 3. Receive NatHoleResp message from server.
func ExchangeInfo(
ctx context.Context, transporter transport.MessageTransporter,
laneKey string, m msg.Message, timeout time.Duration,
) (*msg.NatHoleResp, error) {
timeoutCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
var natHoleRespMsg *msg.NatHoleResp
m, err := transporter.Do(timeoutCtx, m, laneKey, msg.TypeNameNatHoleResp)
if err != nil { if err != nil {
return nil, fmt.Errorf("get natHoleRespMsg error: %v", err) return nil, err
} }
mm, ok := m.(*msg.NatHoleResp) lconn, err := net.ListenUDP("udp", addr)
if !ok { if err != nil {
return nil, fmt.Errorf("get natHoleRespMsg error: invalid message type") return nil, err
} }
natHoleRespMsg = mm nc = &Controller{
listener: lconn,
if natHoleRespMsg.Error != "" { clientCfgs: make(map[string]*ClientCfg),
return nil, fmt.Errorf("natHoleRespMsg get error info: %s", natHoleRespMsg.Error) sessions: make(map[string]*Session),
encryptionKey: encryptionKey,
} }
if len(natHoleRespMsg.CandidateAddrs) == 0 { return nc, nil
return nil, fmt.Errorf("natHoleRespMsg get empty candidate addresses")
}
return natHoleRespMsg, nil
} }
// MakeHole is used to make a NAT hole between client and visitor. func (nc *Controller) ListenClient(name string, sk string) (sidCh chan *SidRequest) {
func MakeHole(ctx context.Context, listenConn *net.UDPConn, m *msg.NatHoleResp, key []byte) (*net.UDPConn, *net.UDPAddr, error) { clientCfg := &ClientCfg{
xl := xlog.FromContextSafe(ctx) Name: name,
transactionID := NewTransactionID() Sk: sk,
sendToRangePortsFunc := func(conn *net.UDPConn, addr string) error { SidCh: make(chan *SidRequest),
return sendSidMessage(ctx, conn, m.Sid, transactionID, addr, key, m.DetectBehavior.TTL)
}
listenConns := []*net.UDPConn{listenConn}
var detectAddrs []string
if m.DetectBehavior.Role == DetectRoleSender {
// sender
if m.DetectBehavior.SendDelayMs > 0 {
time.Sleep(time.Duration(m.DetectBehavior.SendDelayMs) * time.Millisecond)
}
detectAddrs = m.AssistedAddrs
detectAddrs = append(detectAddrs, m.CandidateAddrs...)
} else {
// receiver
if len(m.DetectBehavior.CandidatePorts) == 0 {
detectAddrs = m.CandidateAddrs
}
if m.DetectBehavior.ListenRandomPorts > 0 {
for i := 0; i < m.DetectBehavior.ListenRandomPorts; i++ {
tmpConn, err := net.ListenUDP("udp4", nil)
if err != nil {
xl.Warn("listen random udp addr error: %v", err)
continue
}
listenConns = append(listenConns, tmpConn)
}
}
}
detectAddrs = lo.Uniq(detectAddrs)
for _, detectAddr := range detectAddrs {
for _, conn := range listenConns {
if err := sendSidMessage(ctx, conn, m.Sid, transactionID, detectAddr, key, m.DetectBehavior.TTL); err != nil {
xl.Trace("send sid message from %s to %s error: %v", conn.LocalAddr(), detectAddr, err)
}
}
}
if len(m.DetectBehavior.CandidatePorts) > 0 {
for _, conn := range listenConns {
sendSidMessageToRangePorts(ctx, conn, m.CandidateAddrs, m.DetectBehavior.CandidatePorts, sendToRangePortsFunc)
}
}
if m.DetectBehavior.SendRandomPorts > 0 {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
for i := range listenConns {
go sendSidMessageToRandomPorts(ctx, listenConns[i], m.CandidateAddrs, m.DetectBehavior.SendRandomPorts, sendToRangePortsFunc)
}
}
timeout := 5 * time.Second
if m.DetectBehavior.ReadTimeoutMs > 0 {
timeout = time.Duration(m.DetectBehavior.ReadTimeoutMs) * time.Millisecond
}
if len(listenConns) == 1 {
raddr, err := waitDetectMessage(ctx, listenConns[0], m.Sid, key, timeout, m.DetectBehavior.Role)
if err != nil {
return nil, nil, fmt.Errorf("wait detect message error: %v", err)
}
return listenConns[0], raddr, nil
}
type result struct {
lConn *net.UDPConn
raddr *net.UDPAddr
}
resultCh := make(chan result)
for _, conn := range listenConns {
go func(lConn *net.UDPConn) {
addr, err := waitDetectMessage(ctx, lConn, m.Sid, key, timeout, m.DetectBehavior.Role)
if err != nil {
lConn.Close()
return
}
select {
case resultCh <- result{lConn: lConn, raddr: addr}:
default:
lConn.Close()
}
}(conn)
}
select {
case result := <-resultCh:
return result.lConn, result.raddr, nil
case <-time.After(timeout):
return nil, nil, fmt.Errorf("wait detect message timeout")
case <-ctx.Done():
return nil, nil, fmt.Errorf("wait detect message canceled")
} }
nc.mu.Lock()
nc.clientCfgs[name] = clientCfg
nc.mu.Unlock()
return clientCfg.SidCh
} }
func waitDetectMessage( func (nc *Controller) CloseClient(name string) {
ctx context.Context, conn *net.UDPConn, sid string, key []byte, nc.mu.Lock()
timeout time.Duration, role string, defer nc.mu.Unlock()
) (*net.UDPAddr, error) { delete(nc.clientCfgs, name)
xl := xlog.FromContextSafe(ctx) }
func (nc *Controller) Run() {
for { for {
buf := pool.GetBuf(1024) buf := pool.GetBuf(1024)
_ = conn.SetReadDeadline(time.Now().Add(timeout)) n, raddr, err := nc.listener.ReadFromUDP(buf)
n, raddr, err := conn.ReadFromUDP(buf)
_ = conn.SetReadDeadline(time.Time{})
if err != nil { if err != nil {
return nil, err log.Warn("nat hole listener read from udp error: %v", err)
return
} }
xl.Debug("get udp message local %s, from %s", conn.LocalAddr(), raddr) plain, err := crypto.Decode(buf[:n], nc.encryptionKey)
var m msg.NatHoleSid if err != nil {
if err := DecodeMessageInto(buf[:n], key, &m); err != nil { log.Warn("nathole listener decode from %s error: %v", raddr.String(), err)
xl.Warn("decode sid message error: %v", err) continue
}
rawMsg, err := msg.ReadMsg(bytes.NewReader(plain))
if err != nil {
log.Warn("read nat hole message error: %v", err)
continue
}
switch m := rawMsg.(type) {
case *msg.NatHoleBinding:
go nc.HandleBinding(m, raddr)
case *msg.NatHoleVisitor:
go nc.HandleVisitor(m, raddr)
case *msg.NatHoleClient:
go nc.HandleClient(m, raddr)
default:
log.Trace("unknown nat hole message type")
continue continue
} }
pool.PutBuf(buf) pool.PutBuf(buf)
if m.Sid != sid {
xl.Warn("get sid message with wrong sid: %s, expect: %s", m.Sid, sid)
continue
}
if !m.Response {
// only wait for response messages if we are a sender
if role == DetectRoleSender {
continue
}
m.Response = true
buf2, err := EncodeMessage(&m, key)
if err != nil {
xl.Warn("encode sid message error: %v", err)
continue
}
_, _ = conn.WriteToUDP(buf2, raddr)
}
return raddr, nil
} }
} }
func sendSidMessage( func (nc *Controller) GenSid() string {
ctx context.Context, conn *net.UDPConn, t := time.Now().Unix()
sid string, transactionID string, addr string, key []byte, ttl int, id, _ := util.RandID()
) error { return fmt.Sprintf("%d%s", t, id)
xl := xlog.FromContextSafe(ctx) }
ttlStr := ""
if ttl > 0 { func (nc *Controller) HandleBinding(m *msg.NatHoleBinding, raddr *net.UDPAddr) {
ttlStr = fmt.Sprintf(" with ttl %d", ttl) log.Trace("handle binding message from %s", raddr.String())
resp := &msg.NatHoleBindingResp{
TransactionID: m.TransactionID,
Address: raddr.String(),
} }
xl.Trace("send sid message from %s to %s%s", conn.LocalAddr(), addr, ttlStr) plain, err := msg.Pack(resp)
raddr, err := net.ResolveUDPAddr("udp4", addr)
if err != nil { if err != nil {
return err log.Error("pack nat hole binding response error: %v", err)
return
} }
if transactionID == "" { buf, err := crypto.Encode(plain, nc.encryptionKey)
transactionID = NewTransactionID()
}
m := &msg.NatHoleSid{
TransactionID: transactionID,
Sid: sid,
Response: false,
Nonce: strings.Repeat("0", rand.Intn(20)),
}
buf, err := EncodeMessage(m, key)
if err != nil { if err != nil {
return err log.Error("encode nat hole binding response error: %v", err)
return
} }
if ttl > 0 { _, err = nc.listener.WriteToUDP(buf, raddr)
uConn := ipv4.NewConn(conn) if err != nil {
original, err := uConn.TTL() log.Error("write nat hole binding response to %s error: %v", raddr.String(), err)
if err != nil { return
xl.Trace("get ttl error %v", err)
return err
}
xl.Trace("original ttl %d", original)
err = uConn.SetTTL(ttl)
if err != nil {
xl.Trace("set ttl error %v", err)
} else {
defer func() {
_ = uConn.SetTTL(original)
}()
}
}
if _, err := conn.WriteToUDP(buf, raddr); err != nil {
return err
}
return nil
}
func sendSidMessageToRangePorts(
ctx context.Context, conn *net.UDPConn, addrs []string, ports []msg.PortsRange,
sendFunc func(*net.UDPConn, string) error,
) {
xl := xlog.FromContextSafe(ctx)
for _, ip := range lo.Uniq(parseIPs(addrs)) {
for _, portsRange := range ports {
for i := portsRange.From; i <= portsRange.To; i++ {
detectAddr := net.JoinHostPort(ip, strconv.Itoa(i))
if err := sendFunc(conn, detectAddr); err != nil {
xl.Trace("send sid message from %s to %s error: %v", conn.LocalAddr(), detectAddr, err)
}
time.Sleep(2 * time.Millisecond)
}
}
} }
} }
func sendSidMessageToRandomPorts( func (nc *Controller) HandleVisitor(m *msg.NatHoleVisitor, raddr *net.UDPAddr) {
ctx context.Context, conn *net.UDPConn, addrs []string, count int, sid := nc.GenSid()
sendFunc func(*net.UDPConn, string) error, session := &Session{
) { Sid: sid,
xl := xlog.FromContextSafe(ctx) VisitorAddr: raddr,
used := sets.New[int]() NotifyCh: make(chan struct{}),
getUnusedPort := func() int { }
for i := 0; i < 10; i++ { nc.mu.Lock()
port := rand.Intn(65535-1024) + 1024 clientCfg, ok := nc.clientCfgs[m.ProxyName]
if !used.Has(port) { if !ok {
used.Insert(port) nc.mu.Unlock()
return port errInfo := fmt.Sprintf("xtcp server for [%s] doesn't exist", m.ProxyName)
} log.Debug(errInfo)
} _, _ = nc.listener.WriteToUDP(nc.GenNatHoleResponse(nil, errInfo), raddr)
return 0 return
}
if m.SignKey != util.GetAuthKey(clientCfg.Sk, m.Timestamp) {
nc.mu.Unlock()
errInfo := fmt.Sprintf("xtcp connection of [%s] auth failed", m.ProxyName)
log.Debug(errInfo)
_, _ = nc.listener.WriteToUDP(nc.GenNatHoleResponse(nil, errInfo), raddr)
return
} }
for i := 0; i < count; i++ { nc.sessions[sid] = session
select { nc.mu.Unlock()
case <-ctx.Done(): log.Trace("handle visitor message, sid [%s]", sid)
return
default:
}
port := getUnusedPort() defer func() {
if port == 0 { nc.mu.Lock()
continue delete(nc.sessions, sid)
} nc.mu.Unlock()
}()
for _, ip := range lo.Uniq(parseIPs(addrs)) { err := errors.PanicToError(func() {
detectAddr := net.JoinHostPort(ip, strconv.Itoa(port)) clientCfg.SidCh <- &SidRequest{
if err := sendFunc(conn, detectAddr); err != nil { Sid: sid,
xl.Trace("send sid message from %s to %s error: %v", conn.LocalAddr(), detectAddr, err) NotifyCh: session.NotifyCh,
}
time.Sleep(time.Millisecond * 15)
} }
})
if err != nil {
return
}
// Wait client connections.
select {
case <-session.NotifyCh:
resp := nc.GenNatHoleResponse(session, "")
log.Trace("send nat hole response to visitor")
_, _ = nc.listener.WriteToUDP(resp, raddr)
case <-time.After(time.Duration(NatHoleTimeout) * time.Second):
return
} }
} }
func parseIPs(addrs []string) []string { func (nc *Controller) HandleClient(m *msg.NatHoleClient, raddr *net.UDPAddr) {
var ips []string nc.mu.RLock()
for _, addr := range addrs { session, ok := nc.sessions[m.Sid]
if ip, _, err := net.SplitHostPort(addr); err == nil { nc.mu.RUnlock()
ips = append(ips, ip) if !ok {
} return
} }
return ips log.Trace("handle client message, sid [%s]", session.Sid)
session.ClientAddr = raddr
resp := nc.GenNatHoleResponse(session, "")
log.Trace("send nat hole response to client")
_, _ = nc.listener.WriteToUDP(resp, raddr)
}
func (nc *Controller) GenNatHoleResponse(session *Session, errInfo string) []byte {
var (
sid string
visitorAddr string
clientAddr string
)
if session != nil {
sid = session.Sid
visitorAddr = session.VisitorAddr.String()
clientAddr = session.ClientAddr.String()
}
m := &msg.NatHoleResp{
Sid: sid,
VisitorAddr: visitorAddr,
ClientAddr: clientAddr,
Error: errInfo,
}
b := bytes.NewBuffer(nil)
err := msg.WriteMsg(b, m)
if err != nil {
return []byte("")
}
return b.Bytes()
}
type Session struct {
Sid string
VisitorAddr *net.UDPAddr
ClientAddr *net.UDPAddr
NotifyCh chan struct{}
}
type ClientCfg struct {
Name string
Sk string
SidCh chan *SidRequest
} }

View File

@ -16,7 +16,6 @@ package nathole
import ( import (
"bytes" "bytes"
"fmt"
"net" "net"
"strconv" "strconv"
@ -64,49 +63,3 @@ func (s *ChangedAddress) GetFrom(m *stun.Message) error {
func (s *ChangedAddress) String() string { func (s *ChangedAddress) String() string {
return net.JoinHostPort(s.IP.String(), strconv.Itoa(s.Port)) return net.JoinHostPort(s.IP.String(), strconv.Itoa(s.Port))
} }
func ListAllLocalIPs() ([]net.IP, error) {
addrs, err := net.InterfaceAddrs()
if err != nil {
return nil, err
}
ips := make([]net.IP, 0, len(addrs))
for _, addr := range addrs {
ip, _, err := net.ParseCIDR(addr.String())
if err != nil {
continue
}
ips = append(ips, ip)
}
return ips, nil
}
func ListLocalIPsForNatHole(max int) ([]string, error) {
if max <= 0 {
return nil, fmt.Errorf("max must be greater than 0")
}
ips, err := ListAllLocalIPs()
if err != nil {
return nil, err
}
filtered := make([]string, 0, max)
for _, ip := range ips {
if len(filtered) >= max {
break
}
// ignore ipv6 address
if ip.To4() == nil {
continue
}
// ignore localhost IP
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
continue
}
filtered = append(filtered, ip.String())
}
return filtered, nil
}

View File

@ -23,7 +23,7 @@ import (
"net/http/httputil" "net/http/httputil"
"strings" "strings"
utilnet "github.com/fatedier/frp/pkg/util/net" frpNet "github.com/fatedier/frp/pkg/util/net"
) )
const PluginHTTP2HTTPS = "http2https" const PluginHTTP2HTTPS = "http2https"
@ -98,7 +98,7 @@ func NewHTTP2HTTPSPlugin(params map[string]string) (Plugin, error) {
} }
func (p *HTTP2HTTPSPlugin) Handle(conn io.ReadWriteCloser, realConn net.Conn, extraBufToLocal []byte) { func (p *HTTP2HTTPSPlugin) Handle(conn io.ReadWriteCloser, realConn net.Conn, extraBufToLocal []byte) {
wrapConn := utilnet.WrapReadWriteCloserToConn(conn, realConn) wrapConn := frpNet.WrapReadWriteCloserToConn(conn, realConn)
_ = p.l.PutConn(wrapConn) _ = p.l.PutConn(wrapConn)
} }

View File

@ -21,13 +21,11 @@ import (
"net" "net"
"net/http" "net/http"
"strings" "strings"
"time"
libio "github.com/fatedier/golib/io" frpIo "github.com/fatedier/golib/io"
libnet "github.com/fatedier/golib/net" gnet "github.com/fatedier/golib/net"
utilnet "github.com/fatedier/frp/pkg/util/net" frpNet "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/util"
) )
const PluginHTTPProxy = "http_proxy" const PluginHTTPProxy = "http_proxy"
@ -69,9 +67,9 @@ func (hp *HTTPProxy) Name() string {
} }
func (hp *HTTPProxy) Handle(conn io.ReadWriteCloser, realConn net.Conn, extraBufToLocal []byte) { func (hp *HTTPProxy) Handle(conn io.ReadWriteCloser, realConn net.Conn, extraBufToLocal []byte) {
wrapConn := utilnet.WrapReadWriteCloserToConn(conn, realConn) wrapConn := frpNet.WrapReadWriteCloserToConn(conn, realConn)
sc, rd := libnet.NewSharedConn(wrapConn) sc, rd := gnet.NewSharedConn(wrapConn)
firstBytes := make([]byte, 7) firstBytes := make([]byte, 7)
_, err := rd.Read(firstBytes) _, err := rd.Read(firstBytes)
if err != nil { if err != nil {
@ -86,7 +84,7 @@ func (hp *HTTPProxy) Handle(conn io.ReadWriteCloser, realConn net.Conn, extraBuf
wrapConn.Close() wrapConn.Close()
return return
} }
hp.handleConnectReq(request, libio.WrapReadWriteCloser(bufRd, wrapConn, wrapConn.Close)) hp.handleConnectReq(request, frpIo.WrapReadWriteCloser(bufRd, wrapConn, wrapConn.Close))
return return
} }
@ -158,7 +156,7 @@ func (hp *HTTPProxy) ConnectHandler(rw http.ResponseWriter, req *http.Request) {
} }
_, _ = client.Write([]byte("HTTP/1.1 200 OK\r\n\r\n")) _, _ = client.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
go libio.Join(remote, client) go frpIo.Join(remote, client)
} }
func (hp *HTTPProxy) Auth(req *http.Request) bool { func (hp *HTTPProxy) Auth(req *http.Request) bool {
@ -181,9 +179,7 @@ func (hp *HTTPProxy) Auth(req *http.Request) bool {
return false return false
} }
if !util.ConstantTimeEqString(pair[0], hp.AuthUser) || if pair[0] != hp.AuthUser || pair[1] != hp.AuthPasswd {
!util.ConstantTimeEqString(pair[1], hp.AuthPasswd) {
time.Sleep(200 * time.Millisecond)
return false return false
} }
return true return true
@ -213,7 +209,7 @@ func (hp *HTTPProxy) handleConnectReq(req *http.Request, rwc io.ReadWriteCloser)
} }
_, _ = rwc.Write([]byte("HTTP/1.1 200 OK\r\n\r\n")) _, _ = rwc.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
libio.Join(remote, rwc) frpIo.Join(remote, rwc)
} }
func copyHeaders(dst, src http.Header) { func copyHeaders(dst, src http.Header) {

View File

@ -24,7 +24,7 @@ import (
"strings" "strings"
"github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/transport"
utilnet "github.com/fatedier/frp/pkg/util/net" frpNet "github.com/fatedier/frp/pkg/util/net"
) )
const PluginHTTPS2HTTP = "https2http" const PluginHTTPS2HTTP = "https2http"
@ -123,7 +123,7 @@ func (p *HTTPS2HTTPPlugin) genTLSConfig() (*tls.Config, error) {
} }
func (p *HTTPS2HTTPPlugin) Handle(conn io.ReadWriteCloser, realConn net.Conn, extraBufToLocal []byte) { func (p *HTTPS2HTTPPlugin) Handle(conn io.ReadWriteCloser, realConn net.Conn, extraBufToLocal []byte) {
wrapConn := utilnet.WrapReadWriteCloserToConn(conn, realConn) wrapConn := frpNet.WrapReadWriteCloserToConn(conn, realConn)
_ = p.l.PutConn(wrapConn) _ = p.l.PutConn(wrapConn)
} }

View File

@ -24,7 +24,7 @@ import (
"strings" "strings"
"github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/transport"
utilnet "github.com/fatedier/frp/pkg/util/net" frpNet "github.com/fatedier/frp/pkg/util/net"
) )
const PluginHTTPS2HTTPS = "https2https" const PluginHTTPS2HTTPS = "https2https"
@ -128,7 +128,7 @@ func (p *HTTPS2HTTPSPlugin) genTLSConfig() (*tls.Config, error) {
} }
func (p *HTTPS2HTTPSPlugin) Handle(conn io.ReadWriteCloser, realConn net.Conn, extraBufToLocal []byte) { func (p *HTTPS2HTTPSPlugin) Handle(conn io.ReadWriteCloser, realConn net.Conn, extraBufToLocal []byte) {
wrapConn := utilnet.WrapReadWriteCloserToConn(conn, realConn) wrapConn := frpNet.WrapReadWriteCloserToConn(conn, realConn)
_ = p.l.PutConn(wrapConn) _ = p.l.PutConn(wrapConn)
} }

View File

@ -21,7 +21,7 @@ import (
gosocks5 "github.com/armon/go-socks5" gosocks5 "github.com/armon/go-socks5"
utilnet "github.com/fatedier/frp/pkg/util/net" frpNet "github.com/fatedier/frp/pkg/util/net"
) )
const PluginSocks5 = "socks5" const PluginSocks5 = "socks5"
@ -52,7 +52,7 @@ func NewSocks5Plugin(params map[string]string) (p Plugin, err error) {
func (sp *Socks5Plugin) Handle(conn io.ReadWriteCloser, realConn net.Conn, extraBufToLocal []byte) { func (sp *Socks5Plugin) Handle(conn io.ReadWriteCloser, realConn net.Conn, extraBufToLocal []byte) {
defer conn.Close() defer conn.Close()
wrapConn := utilnet.WrapReadWriteCloserToConn(conn, realConn) wrapConn := frpNet.WrapReadWriteCloserToConn(conn, realConn)
_ = sp.Server.ServeConn(wrapConn) _ = sp.Server.ServeConn(wrapConn)
} }

View File

@ -18,11 +18,10 @@ import (
"io" "io"
"net" "net"
"net/http" "net/http"
"time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
utilnet "github.com/fatedier/frp/pkg/util/net" frpNet "github.com/fatedier/frp/pkg/util/net"
) )
const PluginStaticFile = "static_file" const PluginStaticFile = "static_file"
@ -65,8 +64,8 @@ func NewStaticFilePlugin(params map[string]string) (Plugin, error) {
} }
router := mux.NewRouter() router := mux.NewRouter()
router.Use(utilnet.NewHTTPAuthMiddleware(httpUser, httpPasswd).SetAuthFailDelay(200 * time.Millisecond).Middleware) router.Use(frpNet.NewHTTPAuthMiddleware(httpUser, httpPasswd).Middleware)
router.PathPrefix(prefix).Handler(utilnet.MakeHTTPGzipHandler(http.StripPrefix(prefix, http.FileServer(http.Dir(localPath))))).Methods("GET") router.PathPrefix(prefix).Handler(frpNet.MakeHTTPGzipHandler(http.StripPrefix(prefix, http.FileServer(http.Dir(localPath))))).Methods("GET")
sp.s = &http.Server{ sp.s = &http.Server{
Handler: router, Handler: router,
} }
@ -77,7 +76,7 @@ func NewStaticFilePlugin(params map[string]string) (Plugin, error) {
} }
func (sp *StaticFilePlugin) Handle(conn io.ReadWriteCloser, realConn net.Conn, extraBufToLocal []byte) { func (sp *StaticFilePlugin) Handle(conn io.ReadWriteCloser, realConn net.Conn, extraBufToLocal []byte) {
wrapConn := utilnet.WrapReadWriteCloserToConn(conn, realConn) wrapConn := frpNet.WrapReadWriteCloserToConn(conn, realConn)
_ = sp.l.PutConn(wrapConn) _ = sp.l.PutConn(wrapConn)
} }

View File

@ -19,7 +19,7 @@ import (
"io" "io"
"net" "net"
libio "github.com/fatedier/golib/io" frpIo "github.com/fatedier/golib/io"
) )
const PluginUnixDomainSocket = "unix_domain_socket" const PluginUnixDomainSocket = "unix_domain_socket"
@ -62,7 +62,7 @@ func (uds *UnixDomainSocketPlugin) Handle(conn io.ReadWriteCloser, realConn net.
} }
} }
libio.Join(localConn, conn) frpIo.Join(localConn, conn)
} }
func (uds *UnixDomainSocketPlugin) Name() string { func (uds *UnixDomainSocketPlugin) Name() string {

View File

@ -1,119 +0,0 @@
// Copyright 2023 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package transport
import (
"context"
"reflect"
"sync"
"github.com/fatedier/golib/errors"
"github.com/fatedier/frp/pkg/msg"
)
type MessageTransporter interface {
Send(msg.Message) error
// Recv(ctx context.Context, laneKey string, msgType string) (Message, error)
// Do will first send msg, then recv msg with the same laneKey and specified msgType.
Do(ctx context.Context, req msg.Message, laneKey, recvMsgType string) (msg.Message, error)
Dispatch(m msg.Message, laneKey string) bool
DispatchWithType(m msg.Message, msgType, laneKey string) bool
}
func NewMessageTransporter(sendCh chan msg.Message) MessageTransporter {
return &transporterImpl{
sendCh: sendCh,
registry: make(map[string]map[string]chan msg.Message),
}
}
type transporterImpl struct {
sendCh chan msg.Message
// First key is message type and second key is lane key.
// Dispatch will dispatch message to releated channel by its message type
// and lane key.
registry map[string]map[string]chan msg.Message
mu sync.RWMutex
}
func (impl *transporterImpl) Send(m msg.Message) error {
return errors.PanicToError(func() {
impl.sendCh <- m
})
}
func (impl *transporterImpl) Do(ctx context.Context, req msg.Message, laneKey, recvMsgType string) (msg.Message, error) {
ch := make(chan msg.Message, 1)
defer close(ch)
unregisterFn := impl.registerMsgChan(ch, laneKey, recvMsgType)
defer unregisterFn()
if err := impl.Send(req); err != nil {
return nil, err
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case resp := <-ch:
return resp, nil
}
}
func (impl *transporterImpl) DispatchWithType(m msg.Message, msgType, laneKey string) bool {
var ch chan msg.Message
impl.mu.RLock()
byLaneKey, ok := impl.registry[msgType]
if ok {
ch = byLaneKey[laneKey]
}
impl.mu.RUnlock()
if ch == nil {
return false
}
if err := errors.PanicToError(func() {
ch <- m
}); err != nil {
return false
}
return true
}
func (impl *transporterImpl) Dispatch(m msg.Message, laneKey string) bool {
msgType := reflect.TypeOf(m).Elem().Name()
return impl.DispatchWithType(m, msgType, laneKey)
}
func (impl *transporterImpl) registerMsgChan(recvCh chan msg.Message, laneKey string, msgType string) (unregister func()) {
impl.mu.Lock()
byLaneKey, ok := impl.registry[msgType]
if !ok {
byLaneKey = make(map[string]chan msg.Message)
impl.registry[msgType] = byLaneKey
}
byLaneKey[laneKey] = recvCh
impl.mu.Unlock()
unregister = func() {
impl.mu.Lock()
delete(byLaneKey, laneKey)
impl.mu.Unlock()
}
return
}

View File

@ -1,17 +1,3 @@
// Copyright 2023 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package transport package transport
import ( import (

View File

@ -19,9 +19,6 @@ import (
"io" "io"
"net/http" "net/http"
"strings" "strings"
"time"
"github.com/fatedier/frp/pkg/util/util"
) )
type HTTPAuthWraper struct { type HTTPAuthWraper struct {
@ -49,9 +46,8 @@ func (aw *HTTPAuthWraper) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
type HTTPAuthMiddleware struct { type HTTPAuthMiddleware struct {
user string user string
passwd string passwd string
authFailDelay time.Duration
} }
func NewHTTPAuthMiddleware(user, passwd string) *HTTPAuthMiddleware { func NewHTTPAuthMiddleware(user, passwd string) *HTTPAuthMiddleware {
@ -61,28 +57,32 @@ func NewHTTPAuthMiddleware(user, passwd string) *HTTPAuthMiddleware {
} }
} }
func (authMid *HTTPAuthMiddleware) SetAuthFailDelay(delay time.Duration) *HTTPAuthMiddleware {
authMid.authFailDelay = delay
return authMid
}
func (authMid *HTTPAuthMiddleware) Middleware(next http.Handler) http.Handler { func (authMid *HTTPAuthMiddleware) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqUser, reqPasswd, hasAuth := r.BasicAuth() reqUser, reqPasswd, hasAuth := r.BasicAuth()
if (authMid.user == "" && authMid.passwd == "") || if (authMid.user == "" && authMid.passwd == "") ||
(hasAuth && util.ConstantTimeEqString(reqUser, authMid.user) && (hasAuth && reqUser == authMid.user && reqPasswd == authMid.passwd) {
util.ConstantTimeEqString(reqPasswd, authMid.passwd)) {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
} else { } else {
if authMid.authFailDelay > 0 {
time.Sleep(authMid.authFailDelay)
}
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
} }
}) })
} }
func HTTPBasicAuth(h http.HandlerFunc, user, passwd string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
reqUser, reqPasswd, hasAuth := r.BasicAuth()
if (user == "" && passwd == "") ||
(hasAuth && reqUser == user && reqPasswd == passwd) {
h.ServeHTTP(w, r)
} else {
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
}
}
}
type HTTPGzipWraper struct { type HTTPGzipWraper struct {
h http.Handler h http.Handler
} }

View File

@ -22,21 +22,20 @@ import (
"github.com/fatedier/golib/errors" "github.com/fatedier/golib/errors"
) )
// InternalListener is a listener that can be used to accept connections from // Custom listener
// other goroutines. type CustomListener struct {
type InternalListener struct {
acceptCh chan net.Conn acceptCh chan net.Conn
closed bool closed bool
mu sync.Mutex mu sync.Mutex
} }
func NewInternalListener() *InternalListener { func NewCustomListener() *CustomListener {
return &InternalListener{ return &CustomListener{
acceptCh: make(chan net.Conn, 128), acceptCh: make(chan net.Conn, 64),
} }
} }
func (l *InternalListener) Accept() (net.Conn, error) { func (l *CustomListener) Accept() (net.Conn, error) {
conn, ok := <-l.acceptCh conn, ok := <-l.acceptCh
if !ok { if !ok {
return nil, fmt.Errorf("listener closed") return nil, fmt.Errorf("listener closed")
@ -44,7 +43,7 @@ func (l *InternalListener) Accept() (net.Conn, error) {
return conn, nil return conn, nil
} }
func (l *InternalListener) PutConn(conn net.Conn) error { func (l *CustomListener) PutConn(conn net.Conn) error {
err := errors.PanicToError(func() { err := errors.PanicToError(func() {
select { select {
case l.acceptCh <- conn: case l.acceptCh <- conn:
@ -55,7 +54,7 @@ func (l *InternalListener) PutConn(conn net.Conn) error {
return err return err
} }
func (l *InternalListener) Close() error { func (l *CustomListener) Close() error {
l.mu.Lock() l.mu.Lock()
defer l.mu.Unlock() defer l.mu.Unlock()
if !l.closed { if !l.closed {
@ -65,16 +64,6 @@ func (l *InternalListener) Close() error {
return nil return nil
} }
func (l *InternalListener) Addr() net.Addr { func (l *CustomListener) Addr() net.Addr {
return &InternalAddr{} return (*net.TCPAddr)(nil)
}
type InternalAddr struct{}
func (ia *InternalAddr) Network() string {
return "internal"
}
func (ia *InternalAddr) String() string {
return "internal"
} }

View File

@ -20,7 +20,7 @@ import (
"net" "net"
"time" "time"
libnet "github.com/fatedier/golib/net" gnet "github.com/fatedier/golib/net"
) )
var FRPTLSHeadByte = 0x17 var FRPTLSHeadByte = 0x17
@ -28,7 +28,7 @@ var FRPTLSHeadByte = 0x17
func CheckAndEnableTLSServerConnWithTimeout( func CheckAndEnableTLSServerConnWithTimeout(
c net.Conn, tlsConfig *tls.Config, tlsOnly bool, timeout time.Duration, c net.Conn, tlsConfig *tls.Config, tlsOnly bool, timeout time.Duration,
) (out net.Conn, isTLS bool, custom bool, err error) { ) (out net.Conn, isTLS bool, custom bool, err error) {
sc, r := libnet.NewSharedConnSize(c, 2) sc, r := gnet.NewSharedConnSize(c, 2)
buf := make([]byte, 1) buf := make([]byte, 1)
var n int var n int
_ = c.SetReadDeadline(time.Now().Add(timeout)) _ = c.SetReadDeadline(time.Now().Add(timeout))

View File

@ -256,11 +256,3 @@ func (l *UDPListener) Close() error {
func (l *UDPListener) Addr() net.Addr { func (l *UDPListener) Addr() net.Addr {
return l.addr return l.addr
} }
// ConnectedUDPConn is a wrapper for net.UDPConn which converts WriteTo syscalls
// to Write syscalls that are 4 times faster on some OS'es. This should only be
// used for connections that were produced by a net.Dial* call.
type ConnectedUDPConn struct{ *net.UDPConn }
// WriteTo redirects all writes to the Write syscall, which is 4 times faster.
func (c *ConnectedUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { return c.Write(b) }

View File

@ -22,7 +22,7 @@ import (
"net/http" "net/http"
"time" "time"
libnet "github.com/fatedier/golib/net" gnet "github.com/fatedier/golib/net"
"github.com/fatedier/frp/pkg/util/util" "github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/vhost" "github.com/fatedier/frp/pkg/util/vhost"
@ -94,7 +94,7 @@ func (muxer *HTTPConnectTCPMuxer) auth(c net.Conn, username, password string, re
func (muxer *HTTPConnectTCPMuxer) getHostFromHTTPConnect(c net.Conn) (net.Conn, map[string]string, error) { func (muxer *HTTPConnectTCPMuxer) getHostFromHTTPConnect(c net.Conn) (net.Conn, map[string]string, error) {
reqInfoMap := make(map[string]string, 0) reqInfoMap := make(map[string]string, 0)
sc, rd := libnet.NewSharedConn(c) sc, rd := gnet.NewSharedConn(c)
host, httpUser, httpPwd, err := muxer.readHTTPConnectRequest(rd) host, httpUser, httpPwd, err := muxer.readHTTPConnectRequest(rd)
if err != nil { if err != nil {

25
pkg/util/util/slice.go Normal file
View File

@ -0,0 +1,25 @@
package util
func InSlice[T comparable](v T, s []T) bool {
for _, vv := range s {
if v == vv {
return true
}
}
return false
}
func InSliceAny[T any](v T, s []T, equalFn func(a, b T) bool) bool {
for _, vv := range s {
if equalFn(v, vv) {
return true
}
}
return false
}
func InSliceAnyFunc[T any](equalFn func(a, b T) bool) func(v T, s []T) bool {
return func(v T, s []T) bool {
return InSliceAny(v, s, equalFn)
}
}

View File

@ -0,0 +1,49 @@
package util
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestInSlice(t *testing.T) {
require := require.New(t)
require.True(InSlice(1, []int{1, 2, 3}))
require.False(InSlice(0, []int{1, 2, 3}))
require.True(InSlice("foo", []string{"foo", "bar"}))
require.False(InSlice("not exist", []string{"foo", "bar"}))
}
type testStructA struct {
Name string
Age int
}
func TestInSliceAny(t *testing.T) {
require := require.New(t)
a := testStructA{Name: "foo", Age: 20}
b := testStructA{Name: "foo", Age: 30}
c := testStructA{Name: "bar", Age: 20}
equalFn := func(o, p testStructA) bool {
return o.Name == p.Name
}
require.True(InSliceAny(a, []testStructA{b, c}, equalFn))
require.False(InSliceAny(c, []testStructA{a, b}, equalFn))
}
func TestInSliceAnyFunc(t *testing.T) {
require := require.New(t)
a := testStructA{Name: "foo", Age: 20}
b := testStructA{Name: "foo", Age: 30}
c := testStructA{Name: "bar", Age: 20}
equalFn := func(o, p testStructA) bool {
return o.Name == p.Name
}
testStructAInSlice := InSliceAnyFunc(equalFn)
require.True(testStructAInSlice(a, []testStructA{b, c}))
require.False(testStructAInSlice(c, []testStructA{a, b}))
}

View File

@ -17,7 +17,6 @@ package util
import ( import (
"crypto/md5" "crypto/md5"
"crypto/rand" "crypto/rand"
"crypto/subtle"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
mathrand "math/rand" mathrand "math/rand"
@ -29,32 +28,19 @@ import (
// RandID return a rand string used in frp. // RandID return a rand string used in frp.
func RandID() (id string, err error) { func RandID() (id string, err error) {
return RandIDWithLen(16) return RandIDWithLen(8)
} }
// RandIDWithLen return a rand string with idLen length. // RandIDWithLen return a rand string with idLen length.
func RandIDWithLen(idLen int) (id string, err error) { func RandIDWithLen(idLen int) (id string, err error) {
if idLen <= 0 { b := make([]byte, idLen)
return "", nil
}
b := make([]byte, idLen/2+1)
_, err = rand.Read(b) _, err = rand.Read(b)
if err != nil { if err != nil {
return return
} }
id = fmt.Sprintf("%x", b) id = fmt.Sprintf("%x", b)
return id[:idLen], nil return
}
// RandIDWithRandLen return a rand string with length between [start, end).
func RandIDWithRandLen(start, end int) (id string, err error) {
if start >= end {
err = fmt.Errorf("start should be less than end")
return
}
idLen := mathrand.Intn(end-start) + start
return RandIDWithLen(idLen)
} }
func GetAuthKey(token string, timestamp int64) (key string) { func GetAuthKey(token string, timestamp int64) (key string) {
@ -140,7 +126,3 @@ func RandomSleep(duration time.Duration, minRatio, maxRatio float64) time.Durati
time.Sleep(d) time.Sleep(d)
return d return d
} }
func ConstantTimeEqString(a, b string) bool {
return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1
}

View File

@ -14,51 +14,10 @@ func TestRandId(t *testing.T) {
assert.Equal(16, len(id)) assert.Equal(16, len(id))
} }
func TestRandIDWithRandLen(t *testing.T) {
tests := []struct {
name string
start int
end int
expectErr bool
}{
{
name: "start and end are equal",
start: 5,
end: 5,
expectErr: true,
},
{
name: "start is less than end",
start: 5,
end: 10,
expectErr: false,
},
{
name: "start is greater than end",
start: 10,
end: 5,
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert := assert.New(t)
id, err := RandIDWithRandLen(tt.start, tt.end)
if tt.expectErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.GreaterOrEqual(len(id), tt.start)
assert.Less(len(id), tt.end)
}
})
}
}
func TestGetAuthKey(t *testing.T) { func TestGetAuthKey(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
key := GetAuthKey("1234", 1488720000) key := GetAuthKey("1234", 1488720000)
t.Log(key)
assert.Equal("6df41a43725f0c770fd56379e12acf8c", key) assert.Equal("6df41a43725f0c770fd56379e12acf8c", key)
} }

View File

@ -19,7 +19,7 @@ import (
"strings" "strings"
) )
var version = "0.49.0" var version = "0.48.0"
func Full() string { func Full() string {
return version return version

View File

@ -28,7 +28,7 @@ import (
"strings" "strings"
"time" "time"
libio "github.com/fatedier/golib/io" frpIo "github.com/fatedier/golib/io"
"github.com/fatedier/golib/pool" "github.com/fatedier/golib/pool"
frpLog "github.com/fatedier/frp/pkg/util/log" frpLog "github.com/fatedier/frp/pkg/util/log"
@ -256,7 +256,7 @@ func (rp *HTTPReverseProxy) connectHandler(rw http.ResponseWriter, req *http.Req
return return
} }
_ = req.Write(remote) _ = req.Write(remote)
go libio.Join(remote, client) go frpIo.Join(remote, client)
} }
func parseBasicAuth(auth string) (username, password string, ok bool) { func parseBasicAuth(auth string) (username, password string, ok bool) {

View File

@ -20,7 +20,7 @@ import (
"net" "net"
"time" "time"
libnet "github.com/fatedier/golib/net" gnet "github.com/fatedier/golib/net"
) )
type HTTPSMuxer struct { type HTTPSMuxer struct {
@ -37,7 +37,7 @@ func NewHTTPSMuxer(listener net.Listener, timeout time.Duration) (*HTTPSMuxer, e
func GetHTTPSHostname(c net.Conn) (_ net.Conn, _ map[string]string, err error) { func GetHTTPSHostname(c net.Conn) (_ net.Conn, _ map[string]string, err error) {
reqInfoMap := make(map[string]string, 0) reqInfoMap := make(map[string]string, 0)
sc, rd := libnet.NewSharedConn(c) sc, rd := gnet.NewSharedConn(c)
clientHello, err := readClientHello(rd) clientHello, err := readClientHello(rd)
if err != nil { if err != nil {

View File

@ -22,7 +22,7 @@ import (
"github.com/fatedier/golib/errors" "github.com/fatedier/golib/errors"
"github.com/fatedier/frp/pkg/util/log" "github.com/fatedier/frp/pkg/util/log"
utilnet "github.com/fatedier/frp/pkg/util/net" frpNet "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/xlog"
) )
@ -282,7 +282,7 @@ func (l *Listener) Accept() (net.Conn, error) {
xl.Debug("rewrite host to [%s] success", l.rewriteHost) xl.Debug("rewrite host to [%s] success", l.rewriteHost)
conn = sConn conn = sConn
} }
return utilnet.NewContextConn(l.ctx, conn), nil return frpNet.NewContextConn(l.ctx, conn), nil
} }
func (l *Listener) Close() error { func (l *Listener) Close() error {

View File

@ -30,10 +30,9 @@ import (
"github.com/fatedier/frp/pkg/auth" "github.com/fatedier/frp/pkg/auth"
"github.com/fatedier/frp/pkg/config" "github.com/fatedier/frp/pkg/config"
"github.com/fatedier/frp/pkg/consts" "github.com/fatedier/frp/pkg/consts"
pkgerr "github.com/fatedier/frp/pkg/errors" frpErr "github.com/fatedier/frp/pkg/errors"
"github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/msg"
plugin "github.com/fatedier/frp/pkg/plugin/server" plugin "github.com/fatedier/frp/pkg/plugin/server"
"github.com/fatedier/frp/pkg/transport"
"github.com/fatedier/frp/pkg/util/util" "github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/version" "github.com/fatedier/frp/pkg/util/version"
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/xlog"
@ -83,16 +82,6 @@ func (cm *ControlManager) GetByID(runID string) (ctl *Control, ok bool) {
return return
} }
func (cm *ControlManager) Close() error {
cm.mu.Lock()
defer cm.mu.Unlock()
for _, ctl := range cm.ctlsByRunID {
ctl.Close()
}
cm.ctlsByRunID = make(map[string]*Control)
return nil
}
type Control struct { type Control struct {
// all resource managers and controllers // all resource managers and controllers
rc *controller.ResourceController rc *controller.ResourceController
@ -106,9 +95,6 @@ type Control struct {
// verifies authentication based on selected method // verifies authentication based on selected method
authVerifier auth.Verifier authVerifier auth.Verifier
// other components can use this to communicate with client
msgTransporter transport.MessageTransporter
// login message // login message
loginMsg *msg.Login loginMsg *msg.Login
@ -172,7 +158,7 @@ func NewControl(
if poolCount > int(serverCfg.MaxPoolCount) { if poolCount > int(serverCfg.MaxPoolCount) {
poolCount = int(serverCfg.MaxPoolCount) poolCount = int(serverCfg.MaxPoolCount)
} }
ctl := &Control{ return &Control{
rc: rc, rc: rc,
pxyManager: pxyManager, pxyManager: pxyManager,
pluginManager: pluginManager, pluginManager: pluginManager,
@ -196,16 +182,15 @@ func NewControl(
xl: xlog.FromContextSafe(ctx), xl: xlog.FromContextSafe(ctx),
ctx: ctx, ctx: ctx,
} }
ctl.msgTransporter = transport.NewMessageTransporter(ctl.sendCh)
return ctl
} }
// Start send a login success message to client and start working. // Start send a login success message to client and start working.
func (ctl *Control) Start() { func (ctl *Control) Start() {
loginRespMsg := &msg.LoginResp{ loginRespMsg := &msg.LoginResp{
Version: version.Full(), Version: version.Full(),
RunID: ctl.runID, RunID: ctl.runID,
Error: "", ServerUDPPort: ctl.serverCfg.BindUDPPort,
Error: "",
} }
_ = msg.WriteMsg(ctl.conn, loginRespMsg) _ = msg.WriteMsg(ctl.conn, loginRespMsg)
@ -219,18 +204,6 @@ func (ctl *Control) Start() {
go ctl.stoper() go ctl.stoper()
} }
func (ctl *Control) Close() error {
ctl.allShutdown.Start()
return nil
}
func (ctl *Control) Replaced(newCtl *Control) {
xl := ctl.xl
xl.Info("Replaced by client [%s]", newCtl.runID)
ctl.runID = ""
ctl.allShutdown.Start()
}
func (ctl *Control) RegisterWorkConn(conn net.Conn) error { func (ctl *Control) RegisterWorkConn(conn net.Conn) error {
xl := ctl.xl xl := ctl.xl
defer func() { defer func() {
@ -268,7 +241,7 @@ func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) {
select { select {
case workConn, ok = <-ctl.workConnCh: case workConn, ok = <-ctl.workConnCh:
if !ok { if !ok {
err = pkgerr.ErrCtlClosed err = frpErr.ErrCtlClosed
return return
} }
xl.Debug("get work connection from pool") xl.Debug("get work connection from pool")
@ -283,7 +256,7 @@ func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) {
select { select {
case workConn, ok = <-ctl.workConnCh: case workConn, ok = <-ctl.workConnCh:
if !ok { if !ok {
err = pkgerr.ErrCtlClosed err = frpErr.ErrCtlClosed
xl.Warn("no work connections available, %v", err) xl.Warn("no work connections available, %v", err)
return return
} }
@ -302,6 +275,13 @@ func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) {
return return
} }
func (ctl *Control) Replaced(newCtl *Control) {
xl := ctl.xl
xl.Info("Replaced by client [%s]", newCtl.runID)
ctl.runID = ""
ctl.allShutdown.Start()
}
func (ctl *Control) writer() { func (ctl *Control) writer() {
xl := ctl.xl xl := ctl.xl
defer func() { defer func() {
@ -394,7 +374,7 @@ func (ctl *Control) stoper() {
for _, pxy := range ctl.proxies { for _, pxy := range ctl.proxies {
pxy.Close() pxy.Close()
ctl.pxyManager.Del(pxy.GetName()) ctl.pxyManager.Del(pxy.GetName())
metrics.Server.CloseProxy(pxy.GetName(), pxy.GetConf().GetBaseConfig().ProxyType) metrics.Server.CloseProxy(pxy.GetName(), pxy.GetConf().GetBaseInfo().ProxyType)
notifyContent := &plugin.CloseProxyContent{ notifyContent := &plugin.CloseProxyContent{
User: plugin.UserInfo{ User: plugin.UserInfo{
@ -485,12 +465,6 @@ func (ctl *Control) manager() {
metrics.Server.NewProxy(m.ProxyName, m.ProxyType) metrics.Server.NewProxy(m.ProxyName, m.ProxyType)
} }
ctl.sendCh <- resp ctl.sendCh <- resp
case *msg.NatHoleVisitor:
go ctl.HandleNatHoleVisitor(m)
case *msg.NatHoleClient:
go ctl.HandleNatHoleClient(m)
case *msg.NatHoleReport:
go ctl.HandleNatHoleReport(m)
case *msg.CloseProxy: case *msg.CloseProxy:
_ = ctl.CloseProxy(m) _ = ctl.CloseProxy(m)
xl.Info("close proxy [%s] success", m.ProxyName) xl.Info("close proxy [%s] success", m.ProxyName)
@ -523,18 +497,6 @@ func (ctl *Control) manager() {
} }
} }
func (ctl *Control) HandleNatHoleVisitor(m *msg.NatHoleVisitor) {
ctl.rc.NatHoleController.HandleVisitor(m, ctl.msgTransporter)
}
func (ctl *Control) HandleNatHoleClient(m *msg.NatHoleClient) {
ctl.rc.NatHoleController.HandleClient(m, ctl.msgTransporter)
}
func (ctl *Control) HandleNatHoleReport(m *msg.NatHoleReport) {
ctl.rc.NatHoleController.HandleReport(m)
}
func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err error) { func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err error) {
var pxyConf config.ProxyConf var pxyConf config.ProxyConf
// Load configures from NewProxy message and check. // Load configures from NewProxy message and check.
@ -614,7 +576,7 @@ func (ctl *Control) CloseProxy(closeMsg *msg.CloseProxy) (err error) {
delete(ctl.proxies, closeMsg.ProxyName) delete(ctl.proxies, closeMsg.ProxyName)
ctl.mu.Unlock() ctl.mu.Unlock()
metrics.Server.CloseProxy(pxy.GetName(), pxy.GetConf().GetBaseConfig().ProxyType) metrics.Server.CloseProxy(pxy.GetName(), pxy.GetConf().GetBaseInfo().ProxyType)
notifyContent := &plugin.CloseProxyContent{ notifyContent := &plugin.CloseProxyContent{
User: plugin.UserInfo{ User: plugin.UserInfo{

View File

@ -25,7 +25,7 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/fatedier/frp/assets" "github.com/fatedier/frp/assets"
utilnet "github.com/fatedier/frp/pkg/util/net" frpNet "github.com/fatedier/frp/pkg/util/net"
) )
var ( var (
@ -50,7 +50,7 @@ func (svr *Service) RunDashboardServer(address string) (err error) {
subRouter := router.NewRoute().Subrouter() subRouter := router.NewRoute().Subrouter()
user, passwd := svr.cfg.DashboardUser, svr.cfg.DashboardPwd user, passwd := svr.cfg.DashboardUser, svr.cfg.DashboardPwd
subRouter.Use(utilnet.NewHTTPAuthMiddleware(user, passwd).SetAuthFailDelay(200 * time.Millisecond).Middleware) subRouter.Use(frpNet.NewHTTPAuthMiddleware(user, passwd).Middleware)
// metrics // metrics
if svr.cfg.EnablePrometheus { if svr.cfg.EnablePrometheus {
@ -65,7 +65,7 @@ func (svr *Service) RunDashboardServer(address string) (err error) {
// view // view
subRouter.Handle("/favicon.ico", http.FileServer(assets.FileSystem)).Methods("GET") subRouter.Handle("/favicon.ico", http.FileServer(assets.FileSystem)).Methods("GET")
subRouter.PathPrefix("/static/").Handler(utilnet.MakeHTTPGzipHandler(http.StripPrefix("/static/", http.FileServer(assets.FileSystem)))).Methods("GET") subRouter.PathPrefix("/static/").Handler(frpNet.MakeHTTPGzipHandler(http.StripPrefix("/static/", http.FileServer(assets.FileSystem)))).Methods("GET")
subRouter.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { subRouter.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/static/", http.StatusMovedPermanently) http.Redirect(w, r, "/static/", http.StatusMovedPermanently)

View File

@ -35,6 +35,7 @@ type GeneralResponse struct {
type serverInfoResp struct { type serverInfoResp struct {
Version string `json:"version"` Version string `json:"version"`
BindPort int `json:"bind_port"` BindPort int `json:"bind_port"`
BindUDPPort int `json:"bind_udp_port"`
VhostHTTPPort int `json:"vhost_http_port"` VhostHTTPPort int `json:"vhost_http_port"`
VhostHTTPSPort int `json:"vhost_https_port"` VhostHTTPSPort int `json:"vhost_https_port"`
TCPMuxHTTPConnectPort int `json:"tcpmux_httpconnect_port"` TCPMuxHTTPConnectPort int `json:"tcpmux_httpconnect_port"`
@ -75,6 +76,7 @@ func (svr *Service) APIServerInfo(w http.ResponseWriter, r *http.Request) {
svrResp := serverInfoResp{ svrResp := serverInfoResp{
Version: version.Full(), Version: version.Full(),
BindPort: svr.cfg.BindPort, BindPort: svr.cfg.BindPort,
BindUDPPort: svr.cfg.BindUDPPort,
VhostHTTPPort: svr.cfg.VhostHTTPPort, VhostHTTPPort: svr.cfg.VhostHTTPPort,
VhostHTTPSPort: svr.cfg.VhostHTTPSPort, VhostHTTPSPort: svr.cfg.VhostHTTPSPort,
TCPMuxHTTPConnectPort: svr.cfg.TCPMuxHTTPConnectPort, TCPMuxHTTPConnectPort: svr.cfg.TCPMuxHTTPConnectPort,

View File

@ -19,12 +19,12 @@ import (
"net" "net"
"strings" "strings"
libio "github.com/fatedier/golib/io" frpIo "github.com/fatedier/golib/io"
"golang.org/x/time/rate" "golang.org/x/time/rate"
"github.com/fatedier/frp/pkg/config" "github.com/fatedier/frp/pkg/config"
"github.com/fatedier/frp/pkg/util/limit" "github.com/fatedier/frp/pkg/util/limit"
utilnet "github.com/fatedier/frp/pkg/util/net" frpNet "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/util" "github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/vhost" "github.com/fatedier/frp/pkg/util/vhost"
"github.com/fatedier/frp/server/metrics" "github.com/fatedier/frp/server/metrics"
@ -157,31 +157,31 @@ func (pxy *HTTPProxy) GetRealConn(remoteAddr string) (workConn net.Conn, err err
var rwc io.ReadWriteCloser = tmpConn var rwc io.ReadWriteCloser = tmpConn
if pxy.cfg.UseEncryption { if pxy.cfg.UseEncryption {
rwc, err = libio.WithEncryption(rwc, []byte(pxy.serverCfg.Token)) rwc, err = frpIo.WithEncryption(rwc, []byte(pxy.serverCfg.Token))
if err != nil { if err != nil {
xl.Error("create encryption stream error: %v", err) xl.Error("create encryption stream error: %v", err)
return return
} }
} }
if pxy.cfg.UseCompression { if pxy.cfg.UseCompression {
rwc = libio.WithCompression(rwc) rwc = frpIo.WithCompression(rwc)
} }
if pxy.GetLimiter() != nil { if pxy.GetLimiter() != nil {
rwc = libio.WrapReadWriteCloser(limit.NewReader(rwc, pxy.GetLimiter()), limit.NewWriter(rwc, pxy.GetLimiter()), func() error { rwc = frpIo.WrapReadWriteCloser(limit.NewReader(rwc, pxy.GetLimiter()), limit.NewWriter(rwc, pxy.GetLimiter()), func() error {
return rwc.Close() return rwc.Close()
}) })
} }
workConn = utilnet.WrapReadWriteCloserToConn(rwc, tmpConn) workConn = frpNet.WrapReadWriteCloserToConn(rwc, tmpConn)
workConn = utilnet.WrapStatsConn(workConn, pxy.updateStatsAfterClosedConn) workConn = frpNet.WrapStatsConn(workConn, pxy.updateStatsAfterClosedConn)
metrics.Server.OpenConnection(pxy.GetName(), pxy.GetConf().GetBaseConfig().ProxyType) metrics.Server.OpenConnection(pxy.GetName(), pxy.GetConf().GetBaseInfo().ProxyType)
return return
} }
func (pxy *HTTPProxy) updateStatsAfterClosedConn(totalRead, totalWrite int64) { func (pxy *HTTPProxy) updateStatsAfterClosedConn(totalRead, totalWrite int64) {
name := pxy.GetName() name := pxy.GetName()
proxyType := pxy.GetConf().GetBaseConfig().ProxyType proxyType := pxy.GetConf().GetBaseInfo().ProxyType
metrics.Server.CloseConnection(name, proxyType) metrics.Server.CloseConnection(name, proxyType)
metrics.Server.AddTrafficIn(name, proxyType, totalWrite) metrics.Server.AddTrafficIn(name, proxyType, totalWrite)
metrics.Server.AddTrafficOut(name, proxyType, totalRead) metrics.Server.AddTrafficOut(name, proxyType, totalRead)

View File

@ -23,14 +23,14 @@ import (
"sync" "sync"
"time" "time"
libio "github.com/fatedier/golib/io" frpIo "github.com/fatedier/golib/io"
"golang.org/x/time/rate" "golang.org/x/time/rate"
"github.com/fatedier/frp/pkg/config" "github.com/fatedier/frp/pkg/config"
"github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/msg"
plugin "github.com/fatedier/frp/pkg/plugin/server" plugin "github.com/fatedier/frp/pkg/plugin/server"
"github.com/fatedier/frp/pkg/util/limit" "github.com/fatedier/frp/pkg/util/limit"
utilnet "github.com/fatedier/frp/pkg/util/net" frpNet "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/xlog"
"github.com/fatedier/frp/server/controller" "github.com/fatedier/frp/server/controller"
"github.com/fatedier/frp/server/metrics" "github.com/fatedier/frp/server/metrics"
@ -113,7 +113,7 @@ func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn net.Conn,
} }
xl.Debug("get a new work connection: [%s]", workConn.RemoteAddr().String()) xl.Debug("get a new work connection: [%s]", workConn.RemoteAddr().String())
xl.Spawn().AppendPrefix(pxy.GetName()) xl.Spawn().AppendPrefix(pxy.GetName())
workConn = utilnet.NewContextConn(pxy.ctx, workConn) workConn = frpNet.NewContextConn(pxy.ctx, workConn)
var ( var (
srcAddr string srcAddr string
@ -156,7 +156,7 @@ func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn net.Conn,
} }
// startListenHandler start a goroutine handler for each listener. // startListenHandler start a goroutine handler for each listener.
// p: p will just be passed to handler(Proxy, utilnet.Conn). // p: p will just be passed to handler(Proxy, frpNet.Conn).
// handler: each proxy type can set different handler function to deal with connections accepted from listeners. // handler: each proxy type can set different handler function to deal with connections accepted from listeners.
func (pxy *BaseProxy) startListenHandler(p Proxy, handler func(Proxy, net.Conn, config.ServerCommonConf)) { func (pxy *BaseProxy) startListenHandler(p Proxy, handler func(Proxy, net.Conn, config.ServerCommonConf)) {
xl := xlog.FromContextSafe(pxy.ctx) xl := xlog.FromContextSafe(pxy.ctx)
@ -196,16 +196,16 @@ func (pxy *BaseProxy) startListenHandler(p Proxy, handler func(Proxy, net.Conn,
func NewProxy(ctx context.Context, userInfo plugin.UserInfo, rc *controller.ResourceController, poolCount int, func NewProxy(ctx context.Context, userInfo plugin.UserInfo, rc *controller.ResourceController, poolCount int,
getWorkConnFn GetWorkConnFn, pxyConf config.ProxyConf, serverCfg config.ServerCommonConf, loginMsg *msg.Login, getWorkConnFn GetWorkConnFn, pxyConf config.ProxyConf, serverCfg config.ServerCommonConf, loginMsg *msg.Login,
) (pxy Proxy, err error) { ) (pxy Proxy, err error) {
xl := xlog.FromContextSafe(ctx).Spawn().AppendPrefix(pxyConf.GetBaseConfig().ProxyName) xl := xlog.FromContextSafe(ctx).Spawn().AppendPrefix(pxyConf.GetBaseInfo().ProxyName)
var limiter *rate.Limiter var limiter *rate.Limiter
limitBytes := pxyConf.GetBaseConfig().BandwidthLimit.Bytes() limitBytes := pxyConf.GetBaseInfo().BandwidthLimit.Bytes()
if limitBytes > 0 && pxyConf.GetBaseConfig().BandwidthLimitMode == config.BandwidthLimitModeServer { if limitBytes > 0 && pxyConf.GetBaseInfo().BandwidthLimitMode == config.BandwidthLimitModeServer {
limiter = rate.NewLimiter(rate.Limit(float64(limitBytes)), int(limitBytes)) limiter = rate.NewLimiter(rate.Limit(float64(limitBytes)), int(limitBytes))
} }
basePxy := BaseProxy{ basePxy := BaseProxy{
name: pxyConf.GetBaseConfig().ProxyName, name: pxyConf.GetBaseInfo().ProxyName,
rc: rc, rc: rc,
listeners: make([]net.Listener, 0), listeners: make([]net.Listener, 0),
poolCount: poolCount, poolCount: poolCount,
@ -277,7 +277,7 @@ func HandleUserTCPConnection(pxy Proxy, userConn net.Conn, serverCfg config.Serv
content := &plugin.NewUserConnContent{ content := &plugin.NewUserConnContent{
User: pxy.GetUserInfo(), User: pxy.GetUserInfo(),
ProxyName: pxy.GetName(), ProxyName: pxy.GetName(),
ProxyType: pxy.GetConf().GetBaseConfig().ProxyType, ProxyType: pxy.GetConf().GetBaseInfo().ProxyType,
RemoteAddr: userConn.RemoteAddr().String(), RemoteAddr: userConn.RemoteAddr().String(),
} }
_, err := rc.PluginManager.NewUserConn(content) _, err := rc.PluginManager.NewUserConn(content)
@ -294,21 +294,21 @@ func HandleUserTCPConnection(pxy Proxy, userConn net.Conn, serverCfg config.Serv
defer workConn.Close() defer workConn.Close()
var local io.ReadWriteCloser = workConn var local io.ReadWriteCloser = workConn
cfg := pxy.GetConf().GetBaseConfig() cfg := pxy.GetConf().GetBaseInfo()
xl.Trace("handler user tcp connection, use_encryption: %t, use_compression: %t", cfg.UseEncryption, cfg.UseCompression) xl.Trace("handler user tcp connection, use_encryption: %t, use_compression: %t", cfg.UseEncryption, cfg.UseCompression)
if cfg.UseEncryption { if cfg.UseEncryption {
local, err = libio.WithEncryption(local, []byte(serverCfg.Token)) local, err = frpIo.WithEncryption(local, []byte(serverCfg.Token))
if err != nil { if err != nil {
xl.Error("create encryption stream error: %v", err) xl.Error("create encryption stream error: %v", err)
return return
} }
} }
if cfg.UseCompression { if cfg.UseCompression {
local = libio.WithCompression(local) local = frpIo.WithCompression(local)
} }
if pxy.GetLimiter() != nil { if pxy.GetLimiter() != nil {
local = libio.WrapReadWriteCloser(limit.NewReader(local, pxy.GetLimiter()), limit.NewWriter(local, pxy.GetLimiter()), func() error { local = frpIo.WrapReadWriteCloser(limit.NewReader(local, pxy.GetLimiter()), limit.NewWriter(local, pxy.GetLimiter()), func() error {
return local.Close() return local.Close()
}) })
} }
@ -317,9 +317,9 @@ func HandleUserTCPConnection(pxy Proxy, userConn net.Conn, serverCfg config.Serv
workConn.RemoteAddr().String(), userConn.LocalAddr().String(), userConn.RemoteAddr().String()) workConn.RemoteAddr().String(), userConn.LocalAddr().String(), userConn.RemoteAddr().String())
name := pxy.GetName() name := pxy.GetName()
proxyType := pxy.GetConf().GetBaseConfig().ProxyType proxyType := pxy.GetConf().GetBaseInfo().ProxyType
metrics.Server.OpenConnection(name, proxyType) metrics.Server.OpenConnection(name, proxyType)
inCount, outCount, _ := libio.Join(local, userConn) inCount, outCount, _ := frpIo.Join(local, userConn)
metrics.Server.CloseConnection(name, proxyType) metrics.Server.CloseConnection(name, proxyType)
metrics.Server.AddTrafficIn(name, proxyType, inCount) metrics.Server.AddTrafficIn(name, proxyType, inCount)
metrics.Server.AddTrafficOut(name, proxyType, outCount) metrics.Server.AddTrafficOut(name, proxyType, outCount)

View File

@ -23,14 +23,14 @@ import (
"time" "time"
"github.com/fatedier/golib/errors" "github.com/fatedier/golib/errors"
libio "github.com/fatedier/golib/io" frpIo "github.com/fatedier/golib/io"
"golang.org/x/time/rate" "golang.org/x/time/rate"
"github.com/fatedier/frp/pkg/config" "github.com/fatedier/frp/pkg/config"
"github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/proto/udp" "github.com/fatedier/frp/pkg/proto/udp"
"github.com/fatedier/frp/pkg/util/limit" "github.com/fatedier/frp/pkg/util/limit"
utilnet "github.com/fatedier/frp/pkg/util/net" frpNet "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/server/metrics" "github.com/fatedier/frp/server/metrics"
) )
@ -124,7 +124,7 @@ func (pxy *UDPProxy) Run() (remoteAddr string, err error) {
pxy.readCh <- m pxy.readCh <- m
metrics.Server.AddTrafficOut( metrics.Server.AddTrafficOut(
pxy.GetName(), pxy.GetName(),
pxy.GetConf().GetBaseConfig().ProxyType, pxy.GetConf().GetBaseInfo().ProxyType,
int64(len(m.Content)), int64(len(m.Content)),
) )
}); errRet != nil { }); errRet != nil {
@ -154,7 +154,7 @@ func (pxy *UDPProxy) Run() (remoteAddr string, err error) {
xl.Trace("send message to udp workConn: %s", udpMsg.Content) xl.Trace("send message to udp workConn: %s", udpMsg.Content)
metrics.Server.AddTrafficIn( metrics.Server.AddTrafficIn(
pxy.GetName(), pxy.GetName(),
pxy.GetConf().GetBaseConfig().ProxyType, pxy.GetConf().GetBaseInfo().ProxyType,
int64(len(udpMsg.Content)), int64(len(udpMsg.Content)),
) )
continue continue
@ -189,7 +189,7 @@ func (pxy *UDPProxy) Run() (remoteAddr string, err error) {
var rwc io.ReadWriteCloser = workConn var rwc io.ReadWriteCloser = workConn
if pxy.cfg.UseEncryption { if pxy.cfg.UseEncryption {
rwc, err = libio.WithEncryption(rwc, []byte(pxy.serverCfg.Token)) rwc, err = frpIo.WithEncryption(rwc, []byte(pxy.serverCfg.Token))
if err != nil { if err != nil {
xl.Error("create encryption stream error: %v", err) xl.Error("create encryption stream error: %v", err)
workConn.Close() workConn.Close()
@ -197,16 +197,16 @@ func (pxy *UDPProxy) Run() (remoteAddr string, err error) {
} }
} }
if pxy.cfg.UseCompression { if pxy.cfg.UseCompression {
rwc = libio.WithCompression(rwc) rwc = frpIo.WithCompression(rwc)
} }
if pxy.GetLimiter() != nil { if pxy.GetLimiter() != nil {
rwc = libio.WrapReadWriteCloser(limit.NewReader(rwc, pxy.GetLimiter()), limit.NewWriter(rwc, pxy.GetLimiter()), func() error { rwc = frpIo.WrapReadWriteCloser(limit.NewReader(rwc, pxy.GetLimiter()), limit.NewWriter(rwc, pxy.GetLimiter()), func() error {
return rwc.Close() return rwc.Close()
}) })
} }
pxy.workConn = utilnet.WrapReadWriteCloserToConn(rwc, workConn) pxy.workConn = frpNet.WrapReadWriteCloserToConn(rwc, workConn)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
go workConnReaderFn(pxy.workConn) go workConnReaderFn(pxy.workConn)
go workConnSenderFn(pxy.workConn, ctx) go workConnSenderFn(pxy.workConn, ctx)

View File

@ -44,20 +44,41 @@ func (pxy *XTCPProxy) Run() (remoteAddr string, err error) {
for { for {
select { select {
case <-pxy.closeCh: case <-pxy.closeCh:
return break
case sid := <-sidCh: case sidRequest := <-sidCh:
sr := sidRequest
workConn, errRet := pxy.GetWorkConnFromPool(nil, nil) workConn, errRet := pxy.GetWorkConnFromPool(nil, nil)
if errRet != nil { if errRet != nil {
continue continue
} }
m := &msg.NatHoleSid{ m := &msg.NatHoleSid{
Sid: sid, Sid: sr.Sid,
} }
errRet = msg.WriteMsg(workConn, m) errRet = msg.WriteMsg(workConn, m)
if errRet != nil { if errRet != nil {
xl.Warn("write nat hole sid package error, %v", errRet) xl.Warn("write nat hole sid package error, %v", errRet)
workConn.Close()
break
} }
workConn.Close()
go func() {
raw, errRet := msg.ReadMsg(workConn)
if errRet != nil {
xl.Warn("read nat hole client ok package error: %v", errRet)
workConn.Close()
return
}
if _, ok := raw.(*msg.NatHoleClientDetectOK); !ok {
xl.Warn("read nat hole client ok package format error")
workConn.Close()
return
}
select {
case sr.NotifyCh <- struct{}{}:
default:
}
}()
} }
} }
}() }()

View File

@ -39,7 +39,7 @@ import (
plugin "github.com/fatedier/frp/pkg/plugin/server" plugin "github.com/fatedier/frp/pkg/plugin/server"
"github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/transport"
"github.com/fatedier/frp/pkg/util/log" "github.com/fatedier/frp/pkg/util/log"
utilnet "github.com/fatedier/frp/pkg/util/net" frpNet "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/tcpmux" "github.com/fatedier/frp/pkg/util/tcpmux"
"github.com/fatedier/frp/pkg/util/util" "github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/version" "github.com/fatedier/frp/pkg/util/version"
@ -99,11 +99,6 @@ type Service struct {
tlsConfig *tls.Config tlsConfig *tls.Config
cfg config.ServerCommonConf cfg config.ServerCommonConf
// service context
ctx context.Context
// call cancel to stop service
cancel context.CancelFunc
} }
func NewService(cfg config.ServerCommonConf) (svr *Service, err error) { func NewService(cfg config.ServerCommonConf) (svr *Service, err error) {
@ -115,7 +110,6 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) {
return return
} }
ctx, cancel := context.WithCancel(context.Background())
svr = &Service{ svr = &Service{
ctlManager: NewControlManager(), ctlManager: NewControlManager(),
pxyManager: proxy.NewManager(), pxyManager: proxy.NewManager(),
@ -129,8 +123,6 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) {
authVerifier: auth.NewAuthVerifier(cfg.ServerConfig), authVerifier: auth.NewAuthVerifier(cfg.ServerConfig),
tlsConfig: tlsConfig, tlsConfig: tlsConfig,
cfg: cfg, cfg: cfg,
ctx: ctx,
cancel: cancel,
} }
// Create tcpmux httpconnect multiplexer. // Create tcpmux httpconnect multiplexer.
@ -210,7 +202,7 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) {
// Listen for accepting connections from client using kcp protocol. // Listen for accepting connections from client using kcp protocol.
if cfg.KCPBindPort > 0 { if cfg.KCPBindPort > 0 {
address := net.JoinHostPort(cfg.BindAddr, strconv.Itoa(cfg.KCPBindPort)) address := net.JoinHostPort(cfg.BindAddr, strconv.Itoa(cfg.KCPBindPort))
svr.kcpListener, err = utilnet.ListenKcp(address) svr.kcpListener, err = frpNet.ListenKcp(address)
if err != nil { if err != nil {
err = fmt.Errorf("listen on kcp udp address %s error: %v", address, err) err = fmt.Errorf("listen on kcp udp address %s error: %v", address, err)
return return
@ -235,11 +227,11 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) {
} }
// Listen for accepting connections from client using websocket protocol. // Listen for accepting connections from client using websocket protocol.
websocketPrefix := []byte("GET " + utilnet.FrpWebsocketPath) websocketPrefix := []byte("GET " + frpNet.FrpWebsocketPath)
websocketLn := svr.muxer.Listen(0, uint32(len(websocketPrefix)), func(data []byte) bool { websocketLn := svr.muxer.Listen(0, uint32(len(websocketPrefix)), func(data []byte) bool {
return bytes.Equal(data, websocketPrefix) return bytes.Equal(data, websocketPrefix)
}) })
svr.websocketListener = utilnet.NewWebsocketListener(websocketLn) svr.websocketListener = frpNet.NewWebsocketListener(websocketLn)
// Create http vhost muxer. // Create http vhost muxer.
if cfg.VhostHTTPPort > 0 { if cfg.VhostHTTPPort > 0 {
@ -294,16 +286,21 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) {
// frp tls listener // frp tls listener
svr.tlsListener = svr.muxer.Listen(2, 1, func(data []byte) bool { svr.tlsListener = svr.muxer.Listen(2, 1, func(data []byte) bool {
// tls first byte can be 0x16 only when vhost https port is not same with bind port // tls first byte can be 0x16 only when vhost https port is not same with bind port
return int(data[0]) == utilnet.FRPTLSHeadByte || int(data[0]) == 0x16 return int(data[0]) == frpNet.FRPTLSHeadByte || int(data[0]) == 0x16
}) })
// Create nat hole controller. // Create nat hole controller.
nc, err := nathole.NewController(time.Duration(cfg.NatHoleAnalysisDataReserveHours) * time.Hour) if cfg.BindUDPPort > 0 {
if err != nil { var nc *nathole.Controller
err = fmt.Errorf("create nat hole controller error, %v", err) address := net.JoinHostPort(cfg.BindAddr, strconv.Itoa(cfg.BindUDPPort))
return nc, err = nathole.NewController(address, []byte(cfg.Token))
if err != nil {
err = fmt.Errorf("create nat hole controller error, %v", err)
return
}
svr.rc.NatHoleController = nc
log.Info("nat hole udp service listen on %s", address)
} }
svr.rc.NatHoleController = nc
var statsEnable bool var statsEnable bool
// Create dashboard web server. // Create dashboard web server.
@ -330,43 +327,22 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) {
} }
func (svr *Service) Run() { func (svr *Service) Run() {
if svr.rc.NatHoleController != nil {
go svr.rc.NatHoleController.Run()
}
if svr.kcpListener != nil { if svr.kcpListener != nil {
go svr.HandleListener(svr.kcpListener) go svr.HandleListener(svr.kcpListener)
} }
if svr.quicListener != nil { if svr.quicListener != nil {
go svr.HandleQUICListener(svr.quicListener) go svr.HandleQUICListener(svr.quicListener)
} }
go svr.HandleListener(svr.websocketListener) go svr.HandleListener(svr.websocketListener)
go svr.HandleListener(svr.tlsListener) go svr.HandleListener(svr.tlsListener)
if svr.rc.NatHoleController != nil {
go svr.rc.NatHoleController.CleanWorker(svr.ctx)
}
svr.HandleListener(svr.listener) svr.HandleListener(svr.listener)
} }
func (svr *Service) Close() error {
if svr.kcpListener != nil {
svr.kcpListener.Close()
}
if svr.quicListener != nil {
svr.quicListener.Close()
}
if svr.websocketListener != nil {
svr.websocketListener.Close()
}
if svr.tlsListener != nil {
svr.tlsListener.Close()
}
if svr.listener != nil {
svr.listener.Close()
}
svr.cancel()
svr.ctlManager.Close()
return nil
}
func (svr *Service) handleConnection(ctx context.Context, conn net.Conn) { func (svr *Service) handleConnection(ctx context.Context, conn net.Conn) {
xl := xlog.FromContextSafe(ctx) xl := xlog.FromContextSafe(ctx)
@ -442,12 +418,12 @@ func (svr *Service) HandleListener(l net.Listener) {
xl := xlog.New() xl := xlog.New()
ctx := context.Background() ctx := context.Background()
c = utilnet.NewContextConn(xlog.NewContext(ctx, xl), c) c = frpNet.NewContextConn(xlog.NewContext(ctx, xl), c)
log.Trace("start check TLS connection...") log.Trace("start check TLS connection...")
originConn := c originConn := c
var isTLS, custom bool var isTLS, custom bool
c, isTLS, custom, err = utilnet.CheckAndEnableTLSServerConnWithTimeout(c, svr.tlsConfig, svr.cfg.TLSOnly, connReadTimeout) c, isTLS, custom, err = frpNet.CheckAndEnableTLSServerConnWithTimeout(c, svr.tlsConfig, svr.cfg.TLSOnly, connReadTimeout)
if err != nil { if err != nil {
log.Warn("CheckAndEnableTLSServerConnWithTimeout error: %v", err) log.Warn("CheckAndEnableTLSServerConnWithTimeout error: %v", err)
originConn.Close() originConn.Close()
@ -501,7 +477,7 @@ func (svr *Service) HandleQUICListener(l quic.Listener) {
_ = frpConn.CloseWithError(0, "") _ = frpConn.CloseWithError(0, "")
return return
} }
go svr.handleConnection(ctx, utilnet.QuicStreamToNetConn(stream, frpConn)) go svr.handleConnection(ctx, frpNet.QuicStreamToNetConn(stream, frpConn))
} }
}(context.Background(), c) }(context.Background(), c)
} }
@ -517,7 +493,7 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login) (err
} }
} }
ctx := utilnet.NewContextFromConn(ctlConn) ctx := frpNet.NewContextFromConn(ctlConn)
xl := xlog.FromContextSafe(ctx) xl := xlog.FromContextSafe(ctx)
xl.AppendPrefix(loginMsg.RunID) xl.AppendPrefix(loginMsg.RunID)
ctx = xlog.NewContext(ctx, xl) ctx = xlog.NewContext(ctx, xl)
@ -555,7 +531,7 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login) (err
// RegisterWorkConn register a new work connection to control and proxies need it. // RegisterWorkConn register a new work connection to control and proxies need it.
func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn) error { func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn) error {
xl := utilnet.NewLogFromConn(workConn) xl := frpNet.NewLogFromConn(workConn)
ctl, exist := svr.ctlManager.GetByID(newMsg.RunID) ctl, exist := svr.ctlManager.GetByID(newMsg.RunID)
if !exist { if !exist {
xl.Warn("No client control found for run id [%s]", newMsg.RunID) xl.Warn("No client control found for run id [%s]", newMsg.RunID)

View File

@ -20,15 +20,15 @@ import (
"net" "net"
"sync" "sync"
libio "github.com/fatedier/golib/io" frpIo "github.com/fatedier/golib/io"
utilnet "github.com/fatedier/frp/pkg/util/net" frpNet "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/util" "github.com/fatedier/frp/pkg/util/util"
) )
// Manager for visitor listeners. // Manager for visitor listeners.
type Manager struct { type Manager struct {
visitorListeners map[string]*utilnet.InternalListener visitorListeners map[string]*frpNet.CustomListener
skMap map[string]string skMap map[string]string
mu sync.RWMutex mu sync.RWMutex
@ -36,12 +36,12 @@ type Manager struct {
func NewManager() *Manager { func NewManager() *Manager {
return &Manager{ return &Manager{
visitorListeners: make(map[string]*utilnet.InternalListener), visitorListeners: make(map[string]*frpNet.CustomListener),
skMap: make(map[string]string), skMap: make(map[string]string),
} }
} }
func (vm *Manager) Listen(name string, sk string) (l *utilnet.InternalListener, err error) { func (vm *Manager) Listen(name string, sk string) (l *frpNet.CustomListener, err error) {
vm.mu.Lock() vm.mu.Lock()
defer vm.mu.Unlock() defer vm.mu.Unlock()
@ -50,7 +50,7 @@ func (vm *Manager) Listen(name string, sk string) (l *utilnet.InternalListener,
return return
} }
l = utilnet.NewInternalListener() l = frpNet.NewCustomListener()
vm.visitorListeners[name] = l vm.visitorListeners[name] = l
vm.skMap[name] = sk vm.skMap[name] = sk
return return
@ -71,15 +71,15 @@ func (vm *Manager) NewConn(name string, conn net.Conn, timestamp int64, signKey
var rwc io.ReadWriteCloser = conn var rwc io.ReadWriteCloser = conn
if useEncryption { if useEncryption {
if rwc, err = libio.WithEncryption(rwc, []byte(sk)); err != nil { if rwc, err = frpIo.WithEncryption(rwc, []byte(sk)); err != nil {
err = fmt.Errorf("create encryption connection failed: %v", err) err = fmt.Errorf("create encryption connection failed: %v", err)
return return
} }
} }
if useCompression { if useCompression {
rwc = libio.WithCompression(rwc) rwc = frpIo.WithCompression(rwc)
} }
err = l.PutConn(utilnet.WrapReadWriteCloserToConn(rwc, conn)) err = l.PutConn(frpNet.WrapReadWriteCloserToConn(rwc, conn))
} else { } else {
err = fmt.Errorf("custom listener for [%s] doesn't exist", name) err = fmt.Errorf("custom listener for [%s] doesn't exist", name)
return return

View File

@ -4,7 +4,6 @@ import (
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"strings" "strings"
"time"
"github.com/onsi/ginkgo/v2" "github.com/onsi/ginkgo/v2"
@ -276,8 +275,8 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
}) })
}) })
ginkgo.Describe("STCP && SUDP && XTCP", func() { ginkgo.Describe("STCP && SUDP", func() {
types := []string{"stcp", "sudp", "xtcp"} types := []string{"stcp", "sudp"}
for _, t := range types { for _, t := range types {
proxyType := t proxyType := t
ginkgo.It(fmt.Sprintf("Expose echo server with %s", strings.ToUpper(proxyType)), func() { ginkgo.It(fmt.Sprintf("Expose echo server with %s", strings.ToUpper(proxyType)), func() {
@ -294,9 +293,6 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
case "sudp": case "sudp":
localPortName = framework.UDPEchoServerPort localPortName = framework.UDPEchoServerPort
protocol = "udp" protocol = "udp"
case "xtcp":
localPortName = framework.TCPEchoServerPort
protocol = "tcp"
} }
correctSK := "abc" correctSK := "abc"
@ -375,9 +371,6 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() {
for _, test := range tests { for _, test := range tests {
framework.NewRequestExpect(f). framework.NewRequestExpect(f).
RequestModify(func(r *request.Request) {
r.Timeout(10 * time.Second)
}).
Protocol(protocol). Protocol(protocol).
PortName(test.bindPortName). PortName(test.bindPortName).
Explain(test.proxyName). Explain(test.proxyName).

View File

@ -1,52 +0,0 @@
package basic
import (
"fmt"
"time"
"github.com/onsi/ginkgo/v2"
"github.com/fatedier/frp/test/e2e/framework"
"github.com/fatedier/frp/test/e2e/framework/consts"
"github.com/fatedier/frp/test/e2e/pkg/port"
"github.com/fatedier/frp/test/e2e/pkg/request"
)
var _ = ginkgo.Describe("[Feature: XTCP]", func() {
f := framework.NewDefaultFramework()
ginkgo.It("Fallback To STCP", func() {
serverConf := consts.DefaultServerConfig
clientConf := consts.DefaultClientConfig
bindPortName := port.GenName("XTCP")
clientConf += fmt.Sprintf(`
[foo]
type = stcp
local_port = {{ .%s }}
[foo-visitor]
type = stcp
role = visitor
server_name = foo
bind_port = -1
[bar-visitor]
type = xtcp
role = visitor
server_name = bar
bind_port = {{ .%s }}
keep_tunnel_open = true
fallback_to = foo-visitor
fallback_timeout_ms = 200
`, framework.TCPEchoServerPort, bindPortName)
f.RunProcesses([]string{serverConf}, []string{clientConf})
framework.NewRequestExpect(f).
RequestModify(func(r *request.Request) {
r.Timeout(time.Second)
}).
PortName(bindPortName).
Ensure()
})
})

View File

@ -66,8 +66,8 @@ func NewDefaultFramework() *Framework {
options := Options{ options := Options{
TotalParallelNode: suiteConfig.ParallelTotal, TotalParallelNode: suiteConfig.ParallelTotal,
CurrentNodeIndex: suiteConfig.ParallelProcess, CurrentNodeIndex: suiteConfig.ParallelProcess,
FromPortIndex: 10000, FromPortIndex: 20000,
ToPortIndex: 60000, ToPortIndex: 50000,
} }
return NewFramework(options) return NewFramework(options)
} }
@ -118,14 +118,14 @@ func (f *Framework) AfterEach() {
// stop processor // stop processor
for _, p := range f.serverProcesses { for _, p := range f.serverProcesses {
_ = p.Stop() _ = p.Stop()
if TestContext.Debug || ginkgo.CurrentSpecReport().Failed() { if TestContext.Debug {
fmt.Println(p.ErrorOutput()) fmt.Println(p.ErrorOutput())
fmt.Println(p.StdOutput()) fmt.Println(p.StdOutput())
} }
} }
for _, p := range f.clientProcesses { for _, p := range f.clientProcesses {
_ = p.Stop() _ = p.Stop()
if TestContext.Debug || ginkgo.CurrentSpecReport().Failed() { if TestContext.Debug {
fmt.Println(p.ErrorOutput()) fmt.Println(p.ErrorOutput())
fmt.Println(p.StdOutput()) fmt.Println(p.StdOutput())
} }

View File

@ -38,7 +38,7 @@ func (f *Framework) RunProcesses(serverTemplates []string, clientTemplates []str
err = p.Start() err = p.Start()
ExpectNoError(err) ExpectNoError(err)
} }
time.Sleep(1 * time.Second) time.Sleep(2 * time.Second)
currentClientProcesses := make([]*process.Process, 0, len(clientTemplates)) currentClientProcesses := make([]*process.Process, 0, len(clientTemplates))
for i := range clientTemplates { for i := range clientTemplates {
@ -56,7 +56,7 @@ func (f *Framework) RunProcesses(serverTemplates []string, clientTemplates []str
ExpectNoError(err) ExpectNoError(err)
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
} }
time.Sleep(2 * time.Second) time.Sleep(5 * time.Second)
return currentServerProcesses, currentClientProcesses return currentServerProcesses, currentClientProcesses
} }

View File

@ -58,7 +58,7 @@ func (pa *Allocator) GetByName(portName string) int {
return 0 return 0
} }
l, err := net.Listen("tcp", net.JoinHostPort("0.0.0.0", strconv.Itoa(port))) l, err := net.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(port)))
if err != nil { if err != nil {
// Maybe not controlled by us, mark it used. // Maybe not controlled by us, mark it used.
pa.used.Insert(port) pa.used.Insert(port)
@ -66,7 +66,7 @@ func (pa *Allocator) GetByName(portName string) int {
} }
l.Close() l.Close()
udpAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort("0.0.0.0", strconv.Itoa(port))) udpAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort("127.0.0.1", strconv.Itoa(port)))
if err != nil { if err != nil {
continue continue
} }

View File

@ -14,6 +14,9 @@
<el-form-item label="BindPort"> <el-form-item label="BindPort">
<span>{{ data.bind_port }}</span> <span>{{ data.bind_port }}</span>
</el-form-item> </el-form-item>
<el-form-item label="Bind UDP Port" v-if="data.bind_udp_port != 0">
<span>{{ data.bind_udp_port }}</span>
</el-form-item>
<el-form-item label="KCP Bind Port" v-if="data.kcp_bind_port != 0"> <el-form-item label="KCP Bind Port" v-if="data.kcp_bind_port != 0">
<span>{{ data.kcp_bind_port }}</span> <span>{{ data.kcp_bind_port }}</span>
</el-form-item> </el-form-item>
@ -88,6 +91,7 @@ import LongSpan from './LongSpan.vue'
let data = ref({ let data = ref({
version: '', version: '',
bind_port: 0, bind_port: 0,
bind_udp_port: 0,
kcp_bind_port: 0, kcp_bind_port: 0,
quic_bind_port: 0, quic_bind_port: 0,
vhost_http_port: 0, vhost_http_port: 0,
@ -110,6 +114,7 @@ const fetchData = () => {
.then((json) => { .then((json) => {
data.value.version = json.version data.value.version = json.version
data.value.bind_port = json.bind_port data.value.bind_port = json.bind_port
data.value.bind_udp_port = json.bind_udp_port
data.value.kcp_bind_port = json.kcp_bind_port data.value.kcp_bind_port = json.kcp_bind_port
data.value.quic_bind_port = json.quic_bind_port data.value.quic_bind_port = json.quic_bind_port
data.value.vhost_http_port = json.vhost_http_port data.value.vhost_http_port = json.vhost_http_port