diff --git a/client/admin_api.go b/client/admin_api.go index 2a3633ae..5d163de7 100644 --- a/client/admin_api.go +++ b/client/admin_api.go @@ -91,7 +91,7 @@ func NewProxyStatusResp(status *proxy.WorkingStatus, serverAddr string) ProxySta Status: status.Phase, Err: status.Err, } - baseCfg := status.Cfg.GetBaseInfo() + baseCfg := status.Cfg.GetBaseConfig() if baseCfg.LocalPort != 0 { psr.LocalAddr = net.JoinHostPort(baseCfg.LocalIP, strconv.Itoa(baseCfg.LocalPort)) } diff --git a/client/control.go b/client/control.go index 7626bde8..95f84f8c 100644 --- a/client/control.go +++ b/client/control.go @@ -109,7 +109,7 @@ func NewControl( ctl.msgTransporter = transport.NewMessageTransporter(ctl.sendCh) ctl.pm = proxy.NewManager(ctl.ctx, clientCfg, ctl.msgTransporter) - ctl.vm = visitor.NewManager(ctl.ctx, ctl.clientCfg, ctl.connectServer, ctl.msgTransporter) + ctl.vm = visitor.NewManager(ctl.ctx, ctl.runID, ctl.clientCfg, ctl.connectServer, ctl.msgTransporter) ctl.vm.Reload(visitorCfgs) return ctl } diff --git a/client/proxy/proxy.go b/client/proxy/proxy.go index e961d60f..b336e13a 100644 --- a/client/proxy/proxy.go +++ b/client/proxy/proxy.go @@ -54,8 +54,8 @@ func NewProxy( msgTransporter transport.MessageTransporter, ) (pxy Proxy) { var limiter *rate.Limiter - limitBytes := pxyConf.GetBaseInfo().BandwidthLimit.Bytes() - if limitBytes > 0 && pxyConf.GetBaseInfo().BandwidthLimitMode == config.BandwidthLimitModeClient { + limitBytes := pxyConf.GetBaseConfig().BandwidthLimit.Bytes() + if limitBytes > 0 && pxyConf.GetBaseConfig().BandwidthLimitMode == config.BandwidthLimitModeClient { limiter = rate.NewLimiter(rate.Limit(float64(limitBytes)), int(limitBytes)) } @@ -148,7 +148,7 @@ func (pxy *TCPProxy) Close() { } func (pxy *TCPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { - HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseInfo(), pxy.limiter, + HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseConfig(), pxy.limiter, conn, []byte(pxy.clientCfg.Token), m) } @@ -177,7 +177,7 @@ func (pxy *TCPMuxProxy) Close() { } func (pxy *TCPMuxProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { - HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseInfo(), pxy.limiter, + HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseConfig(), pxy.limiter, conn, []byte(pxy.clientCfg.Token), m) } @@ -206,7 +206,7 @@ func (pxy *HTTPProxy) Close() { } func (pxy *HTTPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { - HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseInfo(), pxy.limiter, + HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseConfig(), pxy.limiter, conn, []byte(pxy.clientCfg.Token), m) } @@ -235,7 +235,7 @@ func (pxy *HTTPSProxy) Close() { } func (pxy *HTTPSProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { - HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseInfo(), pxy.limiter, + HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseConfig(), pxy.limiter, conn, []byte(pxy.clientCfg.Token), m) } @@ -264,7 +264,7 @@ func (pxy *STCPProxy) Close() { } func (pxy *STCPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { - HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseInfo(), pxy.limiter, + HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseConfig(), pxy.limiter, conn, []byte(pxy.clientCfg.Token), m) } diff --git a/client/proxy/proxy_manager.go b/client/proxy/proxy_manager.go index f5d7502c..9e551ced 100644 --- a/client/proxy/proxy_manager.go +++ b/client/proxy/proxy_manager.go @@ -18,6 +18,7 @@ import ( "context" "fmt" "net" + "reflect" "sync" "github.com/fatedier/frp/client/event" @@ -121,21 +122,18 @@ func (pm *Manager) Reload(pxyCfgs map[string]config.ProxyConf) { for name, pxy := range pm.proxies { del := false cfg, ok := pxyCfgs[name] - if !ok { - del = true - } else if !pxy.Cfg.Compare(cfg) { + if !ok || !reflect.DeepEqual(pxy.Cfg, cfg) { del = true } if del { delPxyNames = append(delPxyNames, name) delete(pm.proxies, name) - pxy.Stop() } } if len(delPxyNames) > 0 { - xl.Info("proxy removed: %v", delPxyNames) + xl.Info("proxy removed: %s", delPxyNames) } addPxyNames := make([]string, 0) @@ -149,6 +147,6 @@ func (pm *Manager) Reload(pxyCfgs map[string]config.ProxyConf) { } } if len(addPxyNames) > 0 { - xl.Info("proxy added: %v", addPxyNames) + xl.Info("proxy added: %s", addPxyNames) } } diff --git a/client/proxy/proxy_wrapper.go b/client/proxy/proxy_wrapper.go index f2caa618..91309427 100644 --- a/client/proxy/proxy_wrapper.go +++ b/client/proxy/proxy_wrapper.go @@ -91,7 +91,7 @@ func NewWrapper( eventHandler event.Handler, msgTransporter transport.MessageTransporter, ) *Wrapper { - baseInfo := cfg.GetBaseInfo() + baseInfo := cfg.GetBaseConfig() xl := xlog.FromContextSafe(ctx).Spawn().AppendPrefix(baseInfo.ProxyName) pw := &Wrapper{ WorkingStatus: WorkingStatus{ diff --git a/client/proxy/xtcp.go b/client/proxy/xtcp.go index 64ce5074..a25dc185 100644 --- a/client/proxy/xtcp.go +++ b/client/proxy/xtcp.go @@ -155,7 +155,7 @@ func (pxy *XTCPProxy) listenByKCP(listenConn *net.UDPConn, raddr *net.UDPAddr, s xl.Error("accept connection error: %v", err) return } - go HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseInfo(), pxy.limiter, + go HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseConfig(), pxy.limiter, muxConn, []byte(pxy.cfg.Sk), startWorkConnMsg) } } @@ -194,7 +194,7 @@ func (pxy *XTCPProxy) listenByQUIC(listenConn *net.UDPConn, _ *net.UDPAddr, star _ = c.CloseWithError(0, "") return } - go HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseInfo(), pxy.limiter, + go HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseConfig(), pxy.limiter, utilnet.QuicStreamToNetConn(stream, c), []byte(pxy.cfg.Sk), startWorkConnMsg) } } diff --git a/client/visitor/stcp.go b/client/visitor/stcp.go index 7086c61f..2ea27f38 100644 --- a/client/visitor/stcp.go +++ b/client/visitor/stcp.go @@ -80,7 +80,7 @@ func (sv *STCPVisitor) handleConn(userConn net.Conn) { defer userConn.Close() xl.Debug("get a new stcp user connection") - visitorConn, err := sv.connectServer() + visitorConn, err := sv.helper.ConnectServer() if err != nil { return } @@ -88,6 +88,7 @@ func (sv *STCPVisitor) handleConn(userConn net.Conn) { now := time.Now().Unix() newVisitorConnMsg := &msg.NewVisitorConn{ + RunID: sv.helper.RunID(), ProxyName: sv.cfg.ServerName, SignKey: util.GetAuthKey(sv.cfg.Sk, now), Timestamp: now, diff --git a/client/visitor/sudp.go b/client/visitor/sudp.go index 1e052c34..5b2d5177 100644 --- a/client/visitor/sudp.go +++ b/client/visitor/sudp.go @@ -199,13 +199,14 @@ func (sv *SUDPVisitor) worker(workConn net.Conn, firstPacket *msg.UDPPacket) { func (sv *SUDPVisitor) getNewVisitorConn() (net.Conn, error) { xl := xlog.FromContextSafe(sv.ctx) - visitorConn, err := sv.connectServer() + visitorConn, err := sv.helper.ConnectServer() if err != nil { return nil, fmt.Errorf("frpc connect frps error: %v", err) } now := time.Now().Unix() newVisitorConnMsg := &msg.NewVisitorConn{ + RunID: sv.helper.RunID(), ProxyName: sv.cfg.ServerName, SignKey: util.GetAuthKey(sv.cfg.Sk, now), Timestamp: now, diff --git a/client/visitor/visitor.go b/client/visitor/visitor.go index 10c0ab13..7020df63 100644 --- a/client/visitor/visitor.go +++ b/client/visitor/visitor.go @@ -25,6 +25,19 @@ import ( "github.com/fatedier/frp/pkg/util/xlog" ) +// Helper wrapps some functions for visitor to use. +type Helper interface { + // ConnectServer directly connects to the frp server. + ConnectServer() (net.Conn, error) + // TransferConn transfers the connection to another visitor. + TransferConn(string, net.Conn) error + // MsgTransporter returns the message transporter that is used to send and receive messages + // to the frp server through the controller. + MsgTransporter() transport.MessageTransporter + // RunID returns the run id of current controller. + RunID() string +} + // Visitor is used for forward traffics from local port tot remote service. type Visitor interface { Run() error @@ -36,18 +49,14 @@ func NewVisitor( ctx context.Context, cfg config.VisitorConf, clientCfg config.ClientCommonConf, - connectServer func() (net.Conn, error), - transferConn func(string, net.Conn) error, - msgTransporter transport.MessageTransporter, + helper Helper, ) (visitor Visitor) { - xl := xlog.FromContextSafe(ctx).Spawn().AppendPrefix(cfg.GetBaseInfo().ProxyName) + xl := xlog.FromContextSafe(ctx).Spawn().AppendPrefix(cfg.GetBaseConfig().ProxyName) baseVisitor := BaseVisitor{ - clientCfg: clientCfg, - connectServer: connectServer, - transferConn: transferConn, - msgTransporter: msgTransporter, - ctx: xlog.NewContext(ctx, xl), - internalLn: utilnet.NewInternalListener(), + clientCfg: clientCfg, + helper: helper, + ctx: xlog.NewContext(ctx, xl), + internalLn: utilnet.NewInternalListener(), } switch cfg := cfg.(type) { case *config.STCPVisitorConf: @@ -72,12 +81,10 @@ func NewVisitor( } type BaseVisitor struct { - clientCfg config.ClientCommonConf - connectServer func() (net.Conn, error) - transferConn func(string, net.Conn) error - msgTransporter transport.MessageTransporter - l net.Listener - internalLn *utilnet.InternalListener + clientCfg config.ClientCommonConf + helper Helper + l net.Listener + internalLn *utilnet.InternalListener mu sync.RWMutex ctx context.Context diff --git a/client/visitor/visitor_manager.go b/client/visitor/visitor_manager.go index 799ee3a5..344101ae 100644 --- a/client/visitor/visitor_manager.go +++ b/client/visitor/visitor_manager.go @@ -18,6 +18,7 @@ import ( "context" "fmt" "net" + "reflect" "sync" "time" @@ -27,11 +28,10 @@ import ( ) type Manager struct { - clientCfg config.ClientCommonConf - connectServer func() (net.Conn, error) - msgTransporter transport.MessageTransporter - cfgs map[string]config.VisitorConf - visitors map[string]Visitor + clientCfg config.ClientCommonConf + cfgs map[string]config.VisitorConf + visitors map[string]Visitor + helper Helper checkInterval time.Duration @@ -43,20 +43,26 @@ type Manager struct { func NewManager( ctx context.Context, + runID string, clientCfg config.ClientCommonConf, connectServer func() (net.Conn, error), msgTransporter transport.MessageTransporter, ) *Manager { - return &Manager{ - clientCfg: clientCfg, - connectServer: connectServer, - msgTransporter: msgTransporter, - cfgs: make(map[string]config.VisitorConf), - visitors: make(map[string]Visitor), - checkInterval: 10 * time.Second, - ctx: ctx, - stopCh: make(chan struct{}), + m := &Manager{ + clientCfg: clientCfg, + cfgs: make(map[string]config.VisitorConf), + visitors: make(map[string]Visitor), + checkInterval: 10 * time.Second, + ctx: ctx, + stopCh: make(chan struct{}), } + m.helper = &visitorHelperImpl{ + connectServerFn: connectServer, + msgTransporter: msgTransporter, + transferConnFn: m.TransferConn, + runID: runID, + } + return m } func (vm *Manager) Run() { @@ -73,7 +79,7 @@ func (vm *Manager) Run() { case <-ticker.C: vm.mu.Lock() for _, cfg := range vm.cfgs { - name := cfg.GetBaseInfo().ProxyName + name := cfg.GetBaseConfig().ProxyName if _, exist := vm.visitors[name]; !exist { xl.Info("try to start visitor [%s]", name) _ = vm.startVisitor(cfg) @@ -100,8 +106,8 @@ func (vm *Manager) Close() { // Hold lock before calling this function. func (vm *Manager) startVisitor(cfg config.VisitorConf) (err error) { xl := xlog.FromContextSafe(vm.ctx) - name := cfg.GetBaseInfo().ProxyName - visitor := NewVisitor(vm.ctx, cfg, vm.clientCfg, vm.connectServer, vm.TransferConn, vm.msgTransporter) + name := cfg.GetBaseConfig().ProxyName + visitor := NewVisitor(vm.ctx, cfg, vm.clientCfg, vm.helper) err = visitor.Run() if err != nil { xl.Warn("start error: %v", err) @@ -121,9 +127,7 @@ func (vm *Manager) Reload(cfgs map[string]config.VisitorConf) { for name, oldCfg := range vm.cfgs { del := false cfg, ok := cfgs[name] - if !ok { - del = true - } else if !oldCfg.Compare(cfg) { + if !ok || !reflect.DeepEqual(oldCfg, cfg) { del = true } @@ -163,3 +167,26 @@ func (vm *Manager) TransferConn(name string, conn net.Conn) error { } return v.AcceptConn(conn) } + +type visitorHelperImpl struct { + connectServerFn func() (net.Conn, error) + msgTransporter transport.MessageTransporter + transferConnFn func(name string, conn net.Conn) error + runID string +} + +func (v *visitorHelperImpl) ConnectServer() (net.Conn, error) { + return v.connectServerFn() +} + +func (v *visitorHelperImpl) TransferConn(name string, conn net.Conn) error { + return v.transferConnFn(name, conn) +} + +func (v *visitorHelperImpl) MsgTransporter() transport.MessageTransporter { + return v.msgTransporter +} + +func (v *visitorHelperImpl) RunID() string { + return v.runID +} diff --git a/client/visitor/xtcp.go b/client/visitor/xtcp.go index 5c4e2cf6..73f30789 100644 --- a/client/visitor/xtcp.go +++ b/client/visitor/xtcp.go @@ -183,7 +183,7 @@ func (sv *XTCPVisitor) handleConn(userConn net.Conn) { } xl.Debug("try to transfer connection to visitor: %s", sv.cfg.FallbackTo) - if err := sv.transferConn(sv.cfg.FallbackTo, userConn); err != nil { + if err := sv.helper.TransferConn(sv.cfg.FallbackTo, userConn); err != nil { xl.Error("transfer connection to visitor %s error: %v", sv.cfg.FallbackTo, err) return } @@ -266,7 +266,7 @@ func (sv *XTCPVisitor) getTunnelConn() (net.Conn, error) { // 4. Create a tunnel session using an underlying UDP connection. func (sv *XTCPVisitor) makeNatHole() { xl := xlog.FromContextSafe(sv.ctx) - if err := nathole.PreCheck(sv.ctx, sv.msgTransporter, sv.cfg.ServerName, 5*time.Second); err != nil { + if err := nathole.PreCheck(sv.ctx, sv.helper.MsgTransporter(), sv.cfg.ServerName, 5*time.Second); err != nil { xl.Warn("nathole precheck error: %v", err) return } @@ -294,7 +294,7 @@ func (sv *XTCPVisitor) makeNatHole() { AssistedAddrs: prepareResult.AssistedAddrs, } - natHoleRespMsg, err := nathole.ExchangeInfo(sv.ctx, sv.msgTransporter, transactionID, natHoleVisitorMsg, 5*time.Second) + natHoleRespMsg, err := nathole.ExchangeInfo(sv.ctx, sv.helper.MsgTransporter(), transactionID, natHoleVisitorMsg, 5*time.Second) if err != nil { listenConn.Close() xl.Warn("nathole exchange info error: %v", err) diff --git a/cmd/frpc/sub/http.go b/cmd/frpc/sub/http.go index 22eeefe6..1d152585 100644 --- a/cmd/frpc/sub/http.go +++ b/cmd/frpc/sub/http.go @@ -79,7 +79,7 @@ var httpCmd = &cobra.Command{ } cfg.BandwidthLimitMode = bandwidthLimitMode - err = cfg.CheckForCli() + err = cfg.ValidateForClient() if err != nil { fmt.Println(err) os.Exit(1) diff --git a/cmd/frpc/sub/https.go b/cmd/frpc/sub/https.go index 187aa99b..2eb6ed6b 100644 --- a/cmd/frpc/sub/https.go +++ b/cmd/frpc/sub/https.go @@ -71,7 +71,7 @@ var httpsCmd = &cobra.Command{ } cfg.BandwidthLimitMode = bandwidthLimitMode - err = cfg.CheckForCli() + err = cfg.ValidateForClient() if err != nil { fmt.Println(err) os.Exit(1) diff --git a/cmd/frpc/sub/stcp.go b/cmd/frpc/sub/stcp.go index ad0a57ce..24aa955a 100644 --- a/cmd/frpc/sub/stcp.go +++ b/cmd/frpc/sub/stcp.go @@ -78,7 +78,7 @@ var stcpCmd = &cobra.Command{ os.Exit(1) } cfg.BandwidthLimitMode = bandwidthLimitMode - err = cfg.CheckForCli() + err = cfg.ValidateForClient() if err != nil { fmt.Println(err) os.Exit(1) @@ -95,7 +95,7 @@ var stcpCmd = &cobra.Command{ cfg.ServerName = serverName cfg.BindAddr = bindAddr cfg.BindPort = bindPort - err = cfg.Check() + err = cfg.Validate() if err != nil { fmt.Println(err) os.Exit(1) diff --git a/cmd/frpc/sub/sudp.go b/cmd/frpc/sub/sudp.go index 0ae8498b..553e4252 100644 --- a/cmd/frpc/sub/sudp.go +++ b/cmd/frpc/sub/sudp.go @@ -78,7 +78,7 @@ var sudpCmd = &cobra.Command{ os.Exit(1) } cfg.BandwidthLimitMode = bandwidthLimitMode - err = cfg.CheckForCli() + err = cfg.ValidateForClient() if err != nil { fmt.Println(err) os.Exit(1) @@ -95,7 +95,7 @@ var sudpCmd = &cobra.Command{ cfg.ServerName = serverName cfg.BindAddr = bindAddr cfg.BindPort = bindPort - err = cfg.Check() + err = cfg.Validate() if err != nil { fmt.Println(err) os.Exit(1) diff --git a/cmd/frpc/sub/tcp.go b/cmd/frpc/sub/tcp.go index 2c597f19..2da9ad61 100644 --- a/cmd/frpc/sub/tcp.go +++ b/cmd/frpc/sub/tcp.go @@ -68,7 +68,7 @@ var tcpCmd = &cobra.Command{ } cfg.BandwidthLimitMode = bandwidthLimitMode - err = cfg.CheckForCli() + err = cfg.ValidateForClient() if err != nil { fmt.Println(err) os.Exit(1) diff --git a/cmd/frpc/sub/tcpmux.go b/cmd/frpc/sub/tcpmux.go index ecdd6002..4b993f9c 100644 --- a/cmd/frpc/sub/tcpmux.go +++ b/cmd/frpc/sub/tcpmux.go @@ -73,7 +73,7 @@ var tcpMuxCmd = &cobra.Command{ } cfg.BandwidthLimitMode = bandwidthLimitMode - err = cfg.CheckForCli() + err = cfg.ValidateForClient() if err != nil { fmt.Println(err) os.Exit(1) diff --git a/cmd/frpc/sub/udp.go b/cmd/frpc/sub/udp.go index f9dfa3f6..9a4803dc 100644 --- a/cmd/frpc/sub/udp.go +++ b/cmd/frpc/sub/udp.go @@ -68,7 +68,7 @@ var udpCmd = &cobra.Command{ } cfg.BandwidthLimitMode = bandwidthLimitMode - err = cfg.CheckForCli() + err = cfg.ValidateForClient() if err != nil { fmt.Println(err) os.Exit(1) diff --git a/cmd/frpc/sub/xtcp.go b/cmd/frpc/sub/xtcp.go index ea201d53..60483afa 100644 --- a/cmd/frpc/sub/xtcp.go +++ b/cmd/frpc/sub/xtcp.go @@ -78,7 +78,7 @@ var xtcpCmd = &cobra.Command{ os.Exit(1) } cfg.BandwidthLimitMode = bandwidthLimitMode - err = cfg.CheckForCli() + err = cfg.ValidateForClient() if err != nil { fmt.Println(err) os.Exit(1) @@ -95,7 +95,7 @@ var xtcpCmd = &cobra.Command{ cfg.ServerName = serverName cfg.BindAddr = bindAddr cfg.BindPort = bindPort - err = cfg.Check() + err = cfg.Validate() if err != nil { fmt.Println(err) os.Exit(1) diff --git a/conf/frpc_full.ini b/conf/frpc_full.ini index 1c609f6c..4186b538 100644 --- a/conf/frpc_full.ini +++ b/conf/frpc_full.ini @@ -326,6 +326,9 @@ local_ip = 127.0.0.1 local_port = 22 use_encryption = false use_compression = false +# If not empty, only visitors from specified users can connect. +# Otherwise, visitors from same user can connect. '*' means allow all users. +allow_users = * # user of frpc should be same in both stcp server and stcp visitor [secret_tcp_visitor] @@ -350,10 +353,15 @@ local_ip = 127.0.0.1 local_port = 22 use_encryption = false use_compression = false +# If not empty, only visitors from specified users can connect. +# Otherwise, visitors from same user can connect. '*' means allow all users. +allow_users = user1, user2 [p2p_tcp_visitor] role = visitor type = xtcp +# if the server user is not set, it defaults to the current user +server_user = user1 server_name = p2p_tcp sk = abcdefg bind_addr = 127.0.0.1 diff --git a/pkg/config/client_test.go b/pkg/config/client_test.go index e6ce3a10..dc43fd97 100644 --- a/pkg/config/client_test.go +++ b/pkg/config/client_test.go @@ -500,8 +500,10 @@ func Test_LoadClientBasicConf(t *testing.T) { }, BandwidthLimitMode: BandwidthLimitModeClient, }, - Role: "server", - Sk: "abcdefg", + RoleServerCommonConf: RoleServerCommonConf{ + Role: "server", + Sk: "abcdefg", + }, }, testUser + ".p2p_tcp": &XTCPProxyConf{ BaseProxyConf: BaseProxyConf{ @@ -513,8 +515,10 @@ func Test_LoadClientBasicConf(t *testing.T) { }, BandwidthLimitMode: BandwidthLimitModeClient, }, - Role: "server", - Sk: "abcdefg", + RoleServerCommonConf: RoleServerCommonConf{ + Role: "server", + Sk: "abcdefg", + }, }, testUser + ".tcpmuxhttpconnect": &TCPMuxProxyConf{ BaseProxyConf: BaseProxyConf{ diff --git a/pkg/config/proxy.go b/pkg/config/proxy.go index bca015eb..9cc89492 100644 --- a/pkg/config/proxy.go +++ b/pkg/config/proxy.go @@ -51,13 +51,23 @@ func NewConfByType(proxyType string) ProxyConf { } type ProxyConf interface { - GetBaseInfo() *BaseProxyConf + // GetBaseConfig returns the BaseProxyConf for this config. + GetBaseConfig() *BaseProxyConf + // SetDefaultValues sets the default values for this config. + SetDefaultValues() + // UnmarshalFromMsg unmarshals a msg.NewProxy message into this config. + // This function will be called on the frps side. UnmarshalFromMsg(*msg.NewProxy) + // UnmarshalFromIni unmarshals a ini.Section into this config. This function + // will be called on the frpc side. UnmarshalFromIni(string, string, *ini.Section) error + // MarshalToMsg marshals this config into a msg.NewProxy message. This + // function will be called on the frpc side. MarshalToMsg(*msg.NewProxy) - CheckForCli() error - CheckForSvr(ServerCommonConf) error - Compare(ProxyConf) bool + // ValidateForClient checks that the config is valid for the frpc side. + ValidateForClient() error + // ValidateForServer checks that the config is valid for the frps side. + ValidateForServer(ServerCommonConf) error } // LocalSvrConf configures what location the client will to, or what @@ -158,6 +168,16 @@ type DomainConf struct { SubDomain string `ini:"subdomain" json:"subdomain"` } +type RoleServerCommonConf struct { + Role string `ini:"role" json:"role"` + Sk string `ini:"sk" json:"sk"` + AllowUsers []string `ini:"allow_users" json:"allow_users"` +} + +func (cfg *RoleServerCommonConf) setDefaultValues() { + cfg.Role = "server" +} + // HTTP type HTTPProxyConf struct { BaseProxyConf `ini:",extends"` @@ -203,73 +223,30 @@ type TCPMuxProxyConf struct { // STCP type STCPProxyConf struct { - BaseProxyConf `ini:",extends"` - - Role string `ini:"role" json:"role"` - Sk string `ini:"sk" json:"sk"` + BaseProxyConf `ini:",extends"` + RoleServerCommonConf `ini:",extends"` } // XTCP type XTCPProxyConf struct { - BaseProxyConf `ini:",extends"` - - Role string `ini:"role" json:"role"` - Sk string `ini:"sk" json:"sk"` + BaseProxyConf `ini:",extends"` + RoleServerCommonConf `ini:",extends"` } // SUDP type SUDPProxyConf struct { - BaseProxyConf `ini:",extends"` - - Role string `ini:"role" json:"role"` - Sk string `ini:"sk" json:"sk"` + BaseProxyConf `ini:",extends"` + RoleServerCommonConf `ini:",extends"` } // Proxy Conf Loader // DefaultProxyConf creates a empty ProxyConf object by proxyType. // If proxyType doesn't exist, return nil. func DefaultProxyConf(proxyType string) ProxyConf { - var conf ProxyConf - switch proxyType { - case consts.TCPProxy: - conf = &TCPProxyConf{ - BaseProxyConf: defaultBaseProxyConf(proxyType), - } - case consts.TCPMuxProxy: - conf = &TCPMuxProxyConf{ - BaseProxyConf: defaultBaseProxyConf(proxyType), - } - case consts.UDPProxy: - conf = &UDPProxyConf{ - BaseProxyConf: defaultBaseProxyConf(proxyType), - } - case consts.HTTPProxy: - conf = &HTTPProxyConf{ - BaseProxyConf: defaultBaseProxyConf(proxyType), - } - case consts.HTTPSProxy: - conf = &HTTPSProxyConf{ - BaseProxyConf: defaultBaseProxyConf(proxyType), - } - case consts.STCPProxy: - conf = &STCPProxyConf{ - BaseProxyConf: defaultBaseProxyConf(proxyType), - Role: "server", - } - case consts.XTCPProxy: - conf = &XTCPProxyConf{ - BaseProxyConf: defaultBaseProxyConf(proxyType), - Role: "server", - } - case consts.SUDPProxy: - conf = &SUDPProxyConf{ - BaseProxyConf: defaultBaseProxyConf(proxyType), - Role: "server", - } - default: - return nil + conf := NewConfByType(proxyType) + if conf != nil { + conf.SetDefaultValues() } - return conf } @@ -290,10 +267,9 @@ func NewProxyConfFromIni(prefix, name string, section *ini.Section) (ProxyConf, return nil, err } - if err := conf.CheckForCli(); err != nil { + if err := conf.ValidateForClient(); err != nil { return nil, err } - return conf, nil } @@ -310,7 +286,7 @@ func NewProxyConfFromMsg(pMsg *msg.NewProxy, serverCfg ServerCommonConf) (ProxyC conf.UnmarshalFromMsg(pMsg) - err := conf.CheckForSvr(serverCfg) + err := conf.ValidateForServer(serverCfg) if err != nil { return nil, err } @@ -319,42 +295,15 @@ func NewProxyConfFromMsg(pMsg *msg.NewProxy, serverCfg ServerCommonConf) (ProxyC } // Base -func defaultBaseProxyConf(proxyType string) BaseProxyConf { - return BaseProxyConf{ - ProxyType: proxyType, - LocalSvrConf: LocalSvrConf{ - LocalIP: "127.0.0.1", - }, - BandwidthLimitMode: BandwidthLimitModeClient, - } -} - -func (cfg *BaseProxyConf) GetBaseInfo() *BaseProxyConf { +func (cfg *BaseProxyConf) GetBaseConfig() *BaseProxyConf { return cfg } -func (cfg *BaseProxyConf) compare(cmp *BaseProxyConf) bool { - if cfg.ProxyName != cmp.ProxyName || - cfg.ProxyType != cmp.ProxyType || - cfg.UseEncryption != cmp.UseEncryption || - cfg.UseCompression != cmp.UseCompression || - cfg.Group != cmp.Group || - cfg.GroupKey != cmp.GroupKey || - cfg.ProxyProtocolVersion != cmp.ProxyProtocolVersion || - !cfg.BandwidthLimit.Equal(&cmp.BandwidthLimit) || - cfg.BandwidthLimitMode != cmp.BandwidthLimitMode || - !reflect.DeepEqual(cfg.Metas, cmp.Metas) { - return false +func (cfg *BaseProxyConf) SetDefaultValues() { + cfg.LocalSvrConf = LocalSvrConf{ + LocalIP: "127.0.0.1", } - - if !reflect.DeepEqual(cfg.LocalSvrConf, cmp.LocalSvrConf) { - return false - } - if !reflect.DeepEqual(cfg.HealthCheckConf, cmp.HealthCheckConf) { - return false - } - - return true + cfg.BandwidthLimitMode = BandwidthLimitModeClient } // BaseProxyConf apply custom logic changes. @@ -423,7 +372,7 @@ func (cfg *BaseProxyConf) unmarshalFromMsg(pMsg *msg.NewProxy) { cfg.Metas = pMsg.Metas } -func (cfg *BaseProxyConf) checkForCli() (err error) { +func (cfg *BaseProxyConf) validateForClient() (err error) { if cfg.ProxyProtocolVersion != "" { if cfg.ProxyProtocolVersion != "v1" && cfg.ProxyProtocolVersion != "v2" { return fmt.Errorf("no support proxy protocol version: %s", cfg.ProxyProtocolVersion) @@ -434,16 +383,16 @@ func (cfg *BaseProxyConf) checkForCli() (err error) { return fmt.Errorf("bandwidth_limit_mode should be client or server") } - if err = cfg.LocalSvrConf.checkForCli(); err != nil { + if err = cfg.LocalSvrConf.validateForClient(); err != nil { return } - if err = cfg.HealthCheckConf.checkForCli(); err != nil { + if err = cfg.HealthCheckConf.validateForClient(); err != nil { return } return nil } -func (cfg *BaseProxyConf) checkForSvr() (err error) { +func (cfg *BaseProxyConf) validateForServer() (err error) { if cfg.BandwidthLimitMode != "client" && cfg.BandwidthLimitMode != "server" { return fmt.Errorf("bandwidth_limit_mode should be client or server") } @@ -459,14 +408,14 @@ func (cfg *DomainConf) check() (err error) { return } -func (cfg *DomainConf) checkForCli() (err error) { +func (cfg *DomainConf) validateForClient() (err error) { if err = cfg.check(); err != nil { return } return } -func (cfg *DomainConf) checkForSvr(serverCfg ServerCommonConf) (err error) { +func (cfg *DomainConf) validateForServer(serverCfg ServerCommonConf) (err error) { if err = cfg.check(); err != nil { return } @@ -491,7 +440,7 @@ func (cfg *DomainConf) checkForSvr(serverCfg ServerCommonConf) (err error) { } // LocalSvrConf -func (cfg *LocalSvrConf) checkForCli() (err error) { +func (cfg *LocalSvrConf) validateForClient() (err error) { if cfg.Plugin == "" { if cfg.LocalIP == "" { err = fmt.Errorf("local ip or plugin is required") @@ -506,7 +455,7 @@ func (cfg *LocalSvrConf) checkForCli() (err error) { } // HealthCheckConf -func (cfg *HealthCheckConf) checkForCli() error { +func (cfg *HealthCheckConf) validateForClient() error { if cfg.HealthCheckType != "" && cfg.HealthCheckType != "tcp" && cfg.HealthCheckType != "http" { return fmt.Errorf("unsupport health check type") } @@ -524,7 +473,7 @@ func preUnmarshalFromIni(cfg ProxyConf, prefix string, name string, section *ini return err } - err = cfg.GetBaseInfo().decorate(prefix, name, section) + err = cfg.GetBaseConfig().decorate(prefix, name, section) if err != nil { return err } @@ -533,24 +482,6 @@ func preUnmarshalFromIni(cfg ProxyConf, prefix string, name string, section *ini } // TCP -func (cfg *TCPProxyConf) Compare(cmp ProxyConf) bool { - cmpConf, ok := cmp.(*TCPProxyConf) - if !ok { - return false - } - - if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) { - return false - } - - // Add custom logic equal if exists. - if cfg.RemotePort != cmpConf.RemotePort { - return false - } - - return true -} - func (cfg *TCPProxyConf) UnmarshalFromMsg(pMsg *msg.NewProxy) { cfg.BaseProxyConf.unmarshalFromMsg(pMsg) @@ -576,8 +507,8 @@ func (cfg *TCPProxyConf) MarshalToMsg(pMsg *msg.NewProxy) { pMsg.RemotePort = cfg.RemotePort } -func (cfg *TCPProxyConf) CheckForCli() (err error) { - if err = cfg.BaseProxyConf.checkForCli(); err != nil { +func (cfg *TCPProxyConf) ValidateForClient() (err error) { + if err = cfg.BaseProxyConf.validateForClient(); err != nil { return } @@ -586,39 +517,14 @@ func (cfg *TCPProxyConf) CheckForCli() (err error) { return } -func (cfg *TCPProxyConf) CheckForSvr(serverCfg ServerCommonConf) error { - if err := cfg.BaseProxyConf.checkForSvr(); err != nil { +func (cfg *TCPProxyConf) ValidateForServer(serverCfg ServerCommonConf) error { + if err := cfg.BaseProxyConf.validateForServer(); err != nil { return err } return nil } // TCPMux -func (cfg *TCPMuxProxyConf) Compare(cmp ProxyConf) bool { - cmpConf, ok := cmp.(*TCPMuxProxyConf) - if !ok { - return false - } - - if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) { - return false - } - - // Add custom logic equal if exists. - if !reflect.DeepEqual(cfg.DomainConf, cmpConf.DomainConf) { - return false - } - - if cfg.Multiplexer != cmpConf.Multiplexer || - cfg.HTTPUser != cmpConf.HTTPUser || - cfg.HTTPPwd != cmpConf.HTTPPwd || - cfg.RouteByHTTPUser != cmpConf.RouteByHTTPUser { - return false - } - - return true -} - func (cfg *TCPMuxProxyConf) UnmarshalFromIni(prefix string, name string, section *ini.Section) error { err := preUnmarshalFromIni(cfg, prefix, name, section) if err != nil { @@ -654,13 +560,13 @@ func (cfg *TCPMuxProxyConf) MarshalToMsg(pMsg *msg.NewProxy) { pMsg.RouteByHTTPUser = cfg.RouteByHTTPUser } -func (cfg *TCPMuxProxyConf) CheckForCli() (err error) { - if err = cfg.BaseProxyConf.checkForCli(); err != nil { +func (cfg *TCPMuxProxyConf) ValidateForClient() (err error) { + if err = cfg.BaseProxyConf.validateForClient(); err != nil { return } // Add custom logic check if exists - if err = cfg.DomainConf.checkForCli(); err != nil { + if err = cfg.DomainConf.validateForClient(); err != nil { return } @@ -671,8 +577,8 @@ func (cfg *TCPMuxProxyConf) CheckForCli() (err error) { return } -func (cfg *TCPMuxProxyConf) CheckForSvr(serverCfg ServerCommonConf) (err error) { - if err := cfg.BaseProxyConf.checkForSvr(); err != nil { +func (cfg *TCPMuxProxyConf) ValidateForServer(serverCfg ServerCommonConf) (err error) { + if err := cfg.BaseProxyConf.validateForServer(); err != nil { return err } @@ -684,7 +590,7 @@ func (cfg *TCPMuxProxyConf) CheckForSvr(serverCfg ServerCommonConf) (err error) return fmt.Errorf("proxy [%s] type [tcpmux] with multiplexer [httpconnect] requires tcpmux_httpconnect_port configuration", cfg.ProxyName) } - if err = cfg.DomainConf.checkForSvr(serverCfg); err != nil { + if err = cfg.DomainConf.validateForServer(serverCfg); err != nil { err = fmt.Errorf("proxy [%s] domain conf check error: %v", cfg.ProxyName, err) return } @@ -693,24 +599,6 @@ func (cfg *TCPMuxProxyConf) CheckForSvr(serverCfg ServerCommonConf) (err error) } // UDP -func (cfg *UDPProxyConf) Compare(cmp ProxyConf) bool { - cmpConf, ok := cmp.(*UDPProxyConf) - if !ok { - return false - } - - if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) { - return false - } - - // Add custom logic equal if exists. - if cfg.RemotePort != cmpConf.RemotePort { - return false - } - - return true -} - func (cfg *UDPProxyConf) UnmarshalFromIni(prefix string, name string, section *ini.Section) error { err := preUnmarshalFromIni(cfg, prefix, name, section) if err != nil { @@ -736,8 +624,8 @@ func (cfg *UDPProxyConf) MarshalToMsg(pMsg *msg.NewProxy) { pMsg.RemotePort = cfg.RemotePort } -func (cfg *UDPProxyConf) CheckForCli() (err error) { - if err = cfg.BaseProxyConf.checkForCli(); err != nil { +func (cfg *UDPProxyConf) ValidateForClient() (err error) { + if err = cfg.BaseProxyConf.validateForClient(); err != nil { return } @@ -746,41 +634,14 @@ func (cfg *UDPProxyConf) CheckForCli() (err error) { return } -func (cfg *UDPProxyConf) CheckForSvr(serverCfg ServerCommonConf) error { - if err := cfg.BaseProxyConf.checkForSvr(); err != nil { +func (cfg *UDPProxyConf) ValidateForServer(serverCfg ServerCommonConf) error { + if err := cfg.BaseProxyConf.validateForServer(); err != nil { return err } return nil } // HTTP -func (cfg *HTTPProxyConf) Compare(cmp ProxyConf) bool { - cmpConf, ok := cmp.(*HTTPProxyConf) - if !ok { - return false - } - - if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) { - return false - } - - // Add custom logic equal if exists. - if !reflect.DeepEqual(cfg.DomainConf, cmpConf.DomainConf) { - return false - } - - if !reflect.DeepEqual(cfg.Locations, cmpConf.Locations) || - cfg.HTTPUser != cmpConf.HTTPUser || - cfg.HTTPPwd != cmpConf.HTTPPwd || - cfg.HostHeaderRewrite != cmpConf.HostHeaderRewrite || - cfg.RouteByHTTPUser != cmpConf.RouteByHTTPUser || - !reflect.DeepEqual(cfg.Headers, cmpConf.Headers) { - return false - } - - return true -} - func (cfg *HTTPProxyConf) UnmarshalFromIni(prefix string, name string, section *ini.Section) error { err := preUnmarshalFromIni(cfg, prefix, name, section) if err != nil { @@ -789,7 +650,6 @@ func (cfg *HTTPProxyConf) UnmarshalFromIni(prefix string, name string, section * // Add custom logic unmarshal if exists cfg.Headers = GetMapWithoutPrefix(section.KeysHash(), "header_") - return nil } @@ -821,21 +681,21 @@ func (cfg *HTTPProxyConf) MarshalToMsg(pMsg *msg.NewProxy) { pMsg.RouteByHTTPUser = cfg.RouteByHTTPUser } -func (cfg *HTTPProxyConf) CheckForCli() (err error) { - if err = cfg.BaseProxyConf.checkForCli(); err != nil { +func (cfg *HTTPProxyConf) ValidateForClient() (err error) { + if err = cfg.BaseProxyConf.validateForClient(); err != nil { return } // Add custom logic check if exists - if err = cfg.DomainConf.checkForCli(); err != nil { + if err = cfg.DomainConf.validateForClient(); err != nil { return } return } -func (cfg *HTTPProxyConf) CheckForSvr(serverCfg ServerCommonConf) (err error) { - if err := cfg.BaseProxyConf.checkForSvr(); err != nil { +func (cfg *HTTPProxyConf) ValidateForServer(serverCfg ServerCommonConf) (err error) { + if err := cfg.BaseProxyConf.validateForServer(); err != nil { return err } @@ -843,7 +703,7 @@ func (cfg *HTTPProxyConf) CheckForSvr(serverCfg ServerCommonConf) (err error) { return fmt.Errorf("type [http] not support when vhost_http_port is not set") } - if err = cfg.DomainConf.checkForSvr(serverCfg); err != nil { + if err = cfg.DomainConf.validateForServer(serverCfg); err != nil { err = fmt.Errorf("proxy [%s] domain conf check error: %v", cfg.ProxyName, err) return } @@ -852,24 +712,6 @@ func (cfg *HTTPProxyConf) CheckForSvr(serverCfg ServerCommonConf) (err error) { } // HTTPS -func (cfg *HTTPSProxyConf) Compare(cmp ProxyConf) bool { - cmpConf, ok := cmp.(*HTTPSProxyConf) - if !ok { - return false - } - - if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) { - return false - } - - // Add custom logic equal if exists. - if !reflect.DeepEqual(cfg.DomainConf, cmpConf.DomainConf) { - return false - } - - return true -} - func (cfg *HTTPSProxyConf) UnmarshalFromIni(prefix string, name string, section *ini.Section) error { err := preUnmarshalFromIni(cfg, prefix, name, section) if err != nil { @@ -877,7 +719,6 @@ func (cfg *HTTPSProxyConf) UnmarshalFromIni(prefix string, name string, section } // Add custom logic unmarshal if exists - return nil } @@ -897,21 +738,20 @@ func (cfg *HTTPSProxyConf) MarshalToMsg(pMsg *msg.NewProxy) { pMsg.SubDomain = cfg.SubDomain } -func (cfg *HTTPSProxyConf) CheckForCli() (err error) { - if err = cfg.BaseProxyConf.checkForCli(); err != nil { +func (cfg *HTTPSProxyConf) ValidateForClient() (err error) { + if err = cfg.BaseProxyConf.validateForClient(); err != nil { return } // Add custom logic check if exists - if err = cfg.DomainConf.checkForCli(); err != nil { + if err = cfg.DomainConf.validateForClient(); err != nil { return } - return } -func (cfg *HTTPSProxyConf) CheckForSvr(serverCfg ServerCommonConf) (err error) { - if err := cfg.BaseProxyConf.checkForSvr(); err != nil { +func (cfg *HTTPSProxyConf) ValidateForServer(serverCfg ServerCommonConf) (err error) { + if err := cfg.BaseProxyConf.validateForServer(); err != nil { return err } @@ -919,7 +759,7 @@ func (cfg *HTTPSProxyConf) CheckForSvr(serverCfg ServerCommonConf) (err error) { return fmt.Errorf("type [https] not support when vhost_https_port is not set") } - if err = cfg.DomainConf.checkForSvr(serverCfg); err != nil { + if err = cfg.DomainConf.validateForServer(serverCfg); err != nil { err = fmt.Errorf("proxy [%s] domain conf check error: %v", cfg.ProxyName, err) return } @@ -928,23 +768,9 @@ func (cfg *HTTPSProxyConf) CheckForSvr(serverCfg ServerCommonConf) (err error) { } // SUDP -func (cfg *SUDPProxyConf) Compare(cmp ProxyConf) bool { - cmpConf, ok := cmp.(*SUDPProxyConf) - if !ok { - return false - } - - if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) { - return false - } - - // Add custom logic equal if exists. - if cfg.Role != cmpConf.Role || - cfg.Sk != cmpConf.Sk { - return false - } - - return true +func (cfg *SUDPProxyConf) SetDefaultValues() { + cfg.BaseProxyConf.SetDefaultValues() + cfg.RoleServerCommonConf.setDefaultValues() } func (cfg *SUDPProxyConf) UnmarshalFromIni(prefix string, name string, section *ini.Section) error { @@ -954,7 +780,6 @@ func (cfg *SUDPProxyConf) UnmarshalFromIni(prefix string, name string, section * } // Add custom logic unmarshal if exists - return nil } @@ -973,8 +798,8 @@ func (cfg *SUDPProxyConf) MarshalToMsg(pMsg *msg.NewProxy) { pMsg.Sk = cfg.Sk } -func (cfg *SUDPProxyConf) CheckForCli() (err error) { - if err := cfg.BaseProxyConf.checkForCli(); err != nil { +func (cfg *SUDPProxyConf) ValidateForClient() (err error) { + if err := cfg.BaseProxyConf.validateForClient(); err != nil { return err } @@ -986,31 +811,17 @@ func (cfg *SUDPProxyConf) CheckForCli() (err error) { return nil } -func (cfg *SUDPProxyConf) CheckForSvr(serverCfg ServerCommonConf) error { - if err := cfg.BaseProxyConf.checkForSvr(); err != nil { +func (cfg *SUDPProxyConf) ValidateForServer(serverCfg ServerCommonConf) error { + if err := cfg.BaseProxyConf.validateForServer(); err != nil { return err } return nil } // STCP -func (cfg *STCPProxyConf) Compare(cmp ProxyConf) bool { - cmpConf, ok := cmp.(*STCPProxyConf) - if !ok { - return false - } - - if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) { - return false - } - - // Add custom logic equal if exists. - if cfg.Role != cmpConf.Role || - cfg.Sk != cmpConf.Sk { - return false - } - - return true +func (cfg *STCPProxyConf) SetDefaultValues() { + cfg.BaseProxyConf.SetDefaultValues() + cfg.RoleServerCommonConf.setDefaultValues() } func (cfg *STCPProxyConf) UnmarshalFromIni(prefix string, name string, section *ini.Section) error { @@ -1023,7 +834,6 @@ func (cfg *STCPProxyConf) UnmarshalFromIni(prefix string, name string, section * if cfg.Role == "" { cfg.Role = "server" } - return nil } @@ -1042,8 +852,8 @@ func (cfg *STCPProxyConf) MarshalToMsg(pMsg *msg.NewProxy) { pMsg.Sk = cfg.Sk } -func (cfg *STCPProxyConf) CheckForCli() (err error) { - if err = cfg.BaseProxyConf.checkForCli(); err != nil { +func (cfg *STCPProxyConf) ValidateForClient() (err error) { + if err = cfg.BaseProxyConf.validateForClient(); err != nil { return } @@ -1055,30 +865,17 @@ func (cfg *STCPProxyConf) CheckForCli() (err error) { return } -func (cfg *STCPProxyConf) CheckForSvr(serverCfg ServerCommonConf) error { - if err := cfg.BaseProxyConf.checkForSvr(); err != nil { +func (cfg *STCPProxyConf) ValidateForServer(serverCfg ServerCommonConf) error { + if err := cfg.BaseProxyConf.validateForServer(); err != nil { return err } return nil } // XTCP -func (cfg *XTCPProxyConf) Compare(cmp ProxyConf) bool { - cmpConf, ok := cmp.(*XTCPProxyConf) - if !ok { - return false - } - - if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) { - return false - } - - // Add custom logic equal if exists. - if cfg.Role != cmpConf.Role || - cfg.Sk != cmpConf.Sk { - return false - } - return true +func (cfg *XTCPProxyConf) SetDefaultValues() { + cfg.BaseProxyConf.SetDefaultValues() + cfg.RoleServerCommonConf.setDefaultValues() } func (cfg *XTCPProxyConf) UnmarshalFromIni(prefix string, name string, section *ini.Section) error { @@ -1109,8 +906,8 @@ func (cfg *XTCPProxyConf) MarshalToMsg(pMsg *msg.NewProxy) { pMsg.Sk = cfg.Sk } -func (cfg *XTCPProxyConf) CheckForCli() (err error) { - if err = cfg.BaseProxyConf.checkForCli(); err != nil { +func (cfg *XTCPProxyConf) ValidateForClient() (err error) { + if err = cfg.BaseProxyConf.validateForClient(); err != nil { return } @@ -1121,8 +918,8 @@ func (cfg *XTCPProxyConf) CheckForCli() (err error) { return } -func (cfg *XTCPProxyConf) CheckForSvr(serverCfg ServerCommonConf) error { - if err := cfg.BaseProxyConf.checkForSvr(); err != nil { +func (cfg *XTCPProxyConf) ValidateForServer(serverCfg ServerCommonConf) error { + if err := cfg.BaseProxyConf.validateForServer(); err != nil { return err } return nil diff --git a/pkg/config/proxy_test.go b/pkg/config/proxy_test.go index 894d6fd3..9ef6f87c 100644 --- a/pkg/config/proxy_test.go +++ b/pkg/config/proxy_test.go @@ -254,8 +254,10 @@ func Test_Proxy_UnmarshalFromIni(t *testing.T) { }, BandwidthLimitMode: BandwidthLimitModeClient, }, - Role: "server", - Sk: "abcdefg", + RoleServerCommonConf: RoleServerCommonConf{ + Role: "server", + Sk: "abcdefg", + }, }, }, { @@ -279,8 +281,10 @@ func Test_Proxy_UnmarshalFromIni(t *testing.T) { }, BandwidthLimitMode: BandwidthLimitModeClient, }, - Role: "server", - Sk: "abcdefg", + RoleServerCommonConf: RoleServerCommonConf{ + Role: "server", + Sk: "abcdefg", + }, }, }, { diff --git a/pkg/config/visitor.go b/pkg/config/visitor.go index 808240bb..1f388bad 100644 --- a/pkg/config/visitor.go +++ b/pkg/config/visitor.go @@ -34,10 +34,12 @@ var ( ) type VisitorConf interface { - GetBaseInfo() *BaseVisitorConf - Compare(cmp VisitorConf) bool + // GetBaseConfig returns the base config of visitor. + GetBaseConfig() *BaseVisitorConf + // UnmarshalFromIni unmarshals config from ini. UnmarshalFromIni(prefix string, name string, section *ini.Section) error - Check() error + // Validate validates config. + Validate() error } type BaseVisitorConf struct { @@ -47,8 +49,10 @@ type BaseVisitorConf struct { UseCompression bool `ini:"use_compression" json:"use_compression"` Role string `ini:"role" json:"role"` Sk string `ini:"sk" json:"sk"` - ServerName string `ini:"server_name" json:"server_name"` - BindAddr string `ini:"bind_addr" json:"bind_addr"` + // if the server user is not set, it defaults to the current user + ServerUser string `ini:"server_user" json:"server_user"` + ServerName string `ini:"server_name" json:"server_name"` + BindAddr string `ini:"bind_addr" json:"bind_addr"` // BindPort is the port that visitor listens on. // It can be less than 0, it means don't bind to the port and only receive connections redirected from // other visitors. (This is not supported for SUDP now) @@ -81,7 +85,6 @@ func DefaultVisitorConf(visitorType string) VisitorConf { if !ok { return nil } - return reflect.New(v).Interface().(VisitorConf) } @@ -103,7 +106,7 @@ func NewVisitorConfFromIni(prefix string, name string, section *ini.Section) (Vi return nil, fmt.Errorf("visitor [%s] type [%s] error", name, visitorType) } - if err := conf.Check(); err != nil { + if err := conf.Validate(); err != nil { return nil, err } @@ -111,26 +114,11 @@ func NewVisitorConfFromIni(prefix string, name string, section *ini.Section) (Vi } // Base -func (cfg *BaseVisitorConf) GetBaseInfo() *BaseVisitorConf { +func (cfg *BaseVisitorConf) GetBaseConfig() *BaseVisitorConf { return cfg } -func (cfg *BaseVisitorConf) compare(cmp *BaseVisitorConf) bool { - if cfg.ProxyName != cmp.ProxyName || - cfg.ProxyType != cmp.ProxyType || - cfg.UseEncryption != cmp.UseEncryption || - cfg.UseCompression != cmp.UseCompression || - cfg.Role != cmp.Role || - cfg.Sk != cmp.Sk || - cfg.ServerName != cmp.ServerName || - cfg.BindAddr != cmp.BindAddr || - cfg.BindPort != cmp.BindPort { - return false - } - return true -} - -func (cfg *BaseVisitorConf) check() (err error) { +func (cfg *BaseVisitorConf) validate() (err error) { if cfg.Role != "visitor" { err = fmt.Errorf("invalid role") return @@ -156,7 +144,11 @@ func (cfg *BaseVisitorConf) unmarshalFromIni(prefix string, name string, section cfg.ProxyName = prefix + name // server_name - cfg.ServerName = prefix + cfg.ServerName + if cfg.ServerUser == "" { + cfg.ServerName = prefix + cfg.ServerName + } else { + cfg.ServerName = cfg.ServerUser + "." + cfg.ServerName + } // bind_addr if cfg.BindAddr == "" { @@ -171,7 +163,7 @@ func preVisitorUnmarshalFromIni(cfg VisitorConf, prefix string, name string, sec return err } - err = cfg.GetBaseInfo().unmarshalFromIni(prefix, name, section) + err = cfg.GetBaseConfig().unmarshalFromIni(prefix, name, section) if err != nil { return err } @@ -181,21 +173,6 @@ func preVisitorUnmarshalFromIni(cfg VisitorConf, prefix string, name string, sec // SUDP var _ VisitorConf = &SUDPVisitorConf{} -func (cfg *SUDPVisitorConf) Compare(cmp VisitorConf) bool { - cmpConf, ok := cmp.(*SUDPVisitorConf) - if !ok { - return false - } - - if !cfg.BaseVisitorConf.compare(&cmpConf.BaseVisitorConf) { - return false - } - - // Add custom login equal, if exists - - return true -} - func (cfg *SUDPVisitorConf) UnmarshalFromIni(prefix string, name string, section *ini.Section) (err error) { err = preVisitorUnmarshalFromIni(cfg, prefix, name, section) if err != nil { @@ -207,8 +184,8 @@ func (cfg *SUDPVisitorConf) UnmarshalFromIni(prefix string, name string, section return } -func (cfg *SUDPVisitorConf) Check() (err error) { - if err = cfg.BaseVisitorConf.check(); err != nil { +func (cfg *SUDPVisitorConf) Validate() (err error) { + if err = cfg.BaseVisitorConf.validate(); err != nil { return } @@ -220,21 +197,6 @@ func (cfg *SUDPVisitorConf) Check() (err error) { // STCP var _ VisitorConf = &STCPVisitorConf{} -func (cfg *STCPVisitorConf) Compare(cmp VisitorConf) bool { - cmpConf, ok := cmp.(*STCPVisitorConf) - if !ok { - return false - } - - if !cfg.BaseVisitorConf.compare(&cmpConf.BaseVisitorConf) { - return false - } - - // Add custom login equal, if exists - - return true -} - func (cfg *STCPVisitorConf) UnmarshalFromIni(prefix string, name string, section *ini.Section) (err error) { err = preVisitorUnmarshalFromIni(cfg, prefix, name, section) if err != nil { @@ -246,8 +208,8 @@ func (cfg *STCPVisitorConf) UnmarshalFromIni(prefix string, name string, section return } -func (cfg *STCPVisitorConf) Check() (err error) { - if err = cfg.BaseVisitorConf.check(); err != nil { +func (cfg *STCPVisitorConf) Validate() (err error) { + if err = cfg.BaseVisitorConf.validate(); err != nil { return } @@ -259,28 +221,6 @@ func (cfg *STCPVisitorConf) Check() (err error) { // XTCP var _ VisitorConf = &XTCPVisitorConf{} -func (cfg *XTCPVisitorConf) Compare(cmp VisitorConf) bool { - cmpConf, ok := cmp.(*XTCPVisitorConf) - if !ok { - return false - } - - if !cfg.BaseVisitorConf.compare(&cmpConf.BaseVisitorConf) { - return false - } - - // Add custom login equal, if exists - if cfg.Protocol != cmpConf.Protocol || - cfg.KeepTunnelOpen != cmpConf.KeepTunnelOpen || - cfg.MaxRetriesAnHour != cmpConf.MaxRetriesAnHour || - cfg.MinRetryInterval != cmpConf.MinRetryInterval || - cfg.FallbackTo != cmpConf.FallbackTo || - cfg.FallbackTimeoutMs != cmpConf.FallbackTimeoutMs { - return false - } - return true -} - func (cfg *XTCPVisitorConf) UnmarshalFromIni(prefix string, name string, section *ini.Section) (err error) { err = preVisitorUnmarshalFromIni(cfg, prefix, name, section) if err != nil { @@ -303,8 +243,8 @@ func (cfg *XTCPVisitorConf) UnmarshalFromIni(prefix string, name string, section return } -func (cfg *XTCPVisitorConf) Check() (err error) { - if err = cfg.BaseVisitorConf.check(); err != nil { +func (cfg *XTCPVisitorConf) Validate() (err error) { + if err = cfg.BaseVisitorConf.validate(); err != nil { return } diff --git a/pkg/msg/msg.go b/pkg/msg/msg.go index c6fec4df..2cb291ac 100644 --- a/pkg/msg/msg.go +++ b/pkg/msg/msg.go @@ -145,6 +145,7 @@ type StartWorkConn struct { } type NewVisitorConn struct { + RunID string `json:"run_id,omitempty"` ProxyName string `json:"proxy_name,omitempty"` SignKey string `json:"sign_key,omitempty"` Timestamp int64 `json:"timestamp,omitempty"` diff --git a/server/control.go b/server/control.go index a4bffe39..01075484 100644 --- a/server/control.go +++ b/server/control.go @@ -394,7 +394,7 @@ func (ctl *Control) stoper() { for _, pxy := range ctl.proxies { pxy.Close() ctl.pxyManager.Del(pxy.GetName()) - metrics.Server.CloseProxy(pxy.GetName(), pxy.GetConf().GetBaseInfo().ProxyType) + metrics.Server.CloseProxy(pxy.GetName(), pxy.GetConf().GetBaseConfig().ProxyType) notifyContent := &plugin.CloseProxyContent{ User: plugin.UserInfo{ @@ -614,7 +614,7 @@ func (ctl *Control) CloseProxy(closeMsg *msg.CloseProxy) (err error) { delete(ctl.proxies, closeMsg.ProxyName) ctl.mu.Unlock() - metrics.Server.CloseProxy(pxy.GetName(), pxy.GetConf().GetBaseInfo().ProxyType) + metrics.Server.CloseProxy(pxy.GetName(), pxy.GetConf().GetBaseConfig().ProxyType) notifyContent := &plugin.CloseProxyContent{ User: plugin.UserInfo{ diff --git a/server/proxy/http.go b/server/proxy/http.go index 143665b8..31742b7f 100644 --- a/server/proxy/http.go +++ b/server/proxy/http.go @@ -175,13 +175,13 @@ func (pxy *HTTPProxy) GetRealConn(remoteAddr string) (workConn net.Conn, err err workConn = utilnet.WrapReadWriteCloserToConn(rwc, tmpConn) workConn = utilnet.WrapStatsConn(workConn, pxy.updateStatsAfterClosedConn) - metrics.Server.OpenConnection(pxy.GetName(), pxy.GetConf().GetBaseInfo().ProxyType) + metrics.Server.OpenConnection(pxy.GetName(), pxy.GetConf().GetBaseConfig().ProxyType) return } func (pxy *HTTPProxy) updateStatsAfterClosedConn(totalRead, totalWrite int64) { name := pxy.GetName() - proxyType := pxy.GetConf().GetBaseInfo().ProxyType + proxyType := pxy.GetConf().GetBaseConfig().ProxyType metrics.Server.CloseConnection(name, proxyType) metrics.Server.AddTrafficIn(name, proxyType, totalWrite) metrics.Server.AddTrafficOut(name, proxyType, totalRead) diff --git a/server/proxy/proxy.go b/server/proxy/proxy.go index 56dcea83..b2d326a9 100644 --- a/server/proxy/proxy.go +++ b/server/proxy/proxy.go @@ -196,16 +196,16 @@ func (pxy *BaseProxy) startListenHandler(p Proxy, handler func(Proxy, net.Conn, func NewProxy(ctx context.Context, userInfo plugin.UserInfo, rc *controller.ResourceController, poolCount int, getWorkConnFn GetWorkConnFn, pxyConf config.ProxyConf, serverCfg config.ServerCommonConf, loginMsg *msg.Login, ) (pxy Proxy, err error) { - xl := xlog.FromContextSafe(ctx).Spawn().AppendPrefix(pxyConf.GetBaseInfo().ProxyName) + xl := xlog.FromContextSafe(ctx).Spawn().AppendPrefix(pxyConf.GetBaseConfig().ProxyName) var limiter *rate.Limiter - limitBytes := pxyConf.GetBaseInfo().BandwidthLimit.Bytes() - if limitBytes > 0 && pxyConf.GetBaseInfo().BandwidthLimitMode == config.BandwidthLimitModeServer { + limitBytes := pxyConf.GetBaseConfig().BandwidthLimit.Bytes() + if limitBytes > 0 && pxyConf.GetBaseConfig().BandwidthLimitMode == config.BandwidthLimitModeServer { limiter = rate.NewLimiter(rate.Limit(float64(limitBytes)), int(limitBytes)) } basePxy := BaseProxy{ - name: pxyConf.GetBaseInfo().ProxyName, + name: pxyConf.GetBaseConfig().ProxyName, rc: rc, listeners: make([]net.Listener, 0), poolCount: poolCount, @@ -277,7 +277,7 @@ func HandleUserTCPConnection(pxy Proxy, userConn net.Conn, serverCfg config.Serv content := &plugin.NewUserConnContent{ User: pxy.GetUserInfo(), ProxyName: pxy.GetName(), - ProxyType: pxy.GetConf().GetBaseInfo().ProxyType, + ProxyType: pxy.GetConf().GetBaseConfig().ProxyType, RemoteAddr: userConn.RemoteAddr().String(), } _, err := rc.PluginManager.NewUserConn(content) @@ -294,7 +294,7 @@ func HandleUserTCPConnection(pxy Proxy, userConn net.Conn, serverCfg config.Serv defer workConn.Close() var local io.ReadWriteCloser = workConn - cfg := pxy.GetConf().GetBaseInfo() + cfg := pxy.GetConf().GetBaseConfig() xl.Trace("handler user tcp connection, use_encryption: %t, use_compression: %t", cfg.UseEncryption, cfg.UseCompression) if cfg.UseEncryption { local, err = libio.WithEncryption(local, []byte(serverCfg.Token)) @@ -317,7 +317,7 @@ func HandleUserTCPConnection(pxy Proxy, userConn net.Conn, serverCfg config.Serv workConn.RemoteAddr().String(), userConn.LocalAddr().String(), userConn.RemoteAddr().String()) name := pxy.GetName() - proxyType := pxy.GetConf().GetBaseInfo().ProxyType + proxyType := pxy.GetConf().GetBaseConfig().ProxyType metrics.Server.OpenConnection(name, proxyType) inCount, outCount, _ := libio.Join(local, userConn) metrics.Server.CloseConnection(name, proxyType) diff --git a/server/proxy/udp.go b/server/proxy/udp.go index f1b7d06d..e047c9de 100644 --- a/server/proxy/udp.go +++ b/server/proxy/udp.go @@ -124,7 +124,7 @@ func (pxy *UDPProxy) Run() (remoteAddr string, err error) { pxy.readCh <- m metrics.Server.AddTrafficOut( pxy.GetName(), - pxy.GetConf().GetBaseInfo().ProxyType, + pxy.GetConf().GetBaseConfig().ProxyType, int64(len(m.Content)), ) }); errRet != nil { @@ -154,7 +154,7 @@ func (pxy *UDPProxy) Run() (remoteAddr string, err error) { xl.Trace("send message to udp workConn: %s", udpMsg.Content) metrics.Server.AddTrafficIn( pxy.GetName(), - pxy.GetConf().GetBaseInfo().ProxyType, + pxy.GetConf().GetBaseConfig().ProxyType, int64(len(udpMsg.Content)), ) continue