From ff4bdec3f75ac07f59edbbbbf1f5e958efcbf63a Mon Sep 17 00:00:00 2001 From: fatedier Date: Sat, 16 Dec 2017 23:59:46 +0800 Subject: [PATCH 1/7] add test case --- tests/echo_server.go | 28 ++++++----- tests/func_test.go | 116 ++++++++++++++++++------------------------- tests/util.go | 57 +++++++++++++++++++++ 3 files changed, 119 insertions(+), 82 deletions(-) create mode 100644 tests/util.go diff --git a/tests/echo_server.go b/tests/echo_server.go index 391c87e7..47b87cb1 100644 --- a/tests/echo_server.go +++ b/tests/echo_server.go @@ -1,7 +1,6 @@ package tests import ( - "bufio" "fmt" "io" "net" @@ -11,8 +10,8 @@ import ( frpNet "github.com/fatedier/frp/utils/net" ) -func StartEchoServer() { - l, err := frpNet.ListenTcp("127.0.0.1", 10701) +func StartTcpEchoServer() { + l, err := frpNet.ListenTcp("127.0.0.1", TEST_TCP_ECHO_PORT) if err != nil { fmt.Printf("echo server listen error: %v\n", err) return @@ -30,7 +29,7 @@ func StartEchoServer() { } func StartUdpEchoServer() { - l, err := frpNet.ListenUDP("127.0.0.1", 10703) + l, err := frpNet.ListenUDP("127.0.0.1", TEST_UDP_ECHO_PORT) if err != nil { fmt.Printf("udp echo server listen error: %v\n", err) return @@ -48,7 +47,7 @@ func StartUdpEchoServer() { } func StartUnixDomainServer() { - unixPath := "/tmp/frp_echo_server.sock" + unixPath := TEST_UNIX_DOMAIN_ADDR os.Remove(unixPath) syscall.Umask(0) l, err := net.Listen("unix", unixPath) @@ -69,17 +68,20 @@ func StartUnixDomainServer() { } func echoWorker(c net.Conn) { - br := bufio.NewReader(c) + buf := make([]byte, 2048) + for { - buf, err := br.ReadString('\n') - if err == io.EOF { - break - } + n, err := c.Read(buf) if err != nil { - fmt.Printf("echo server read error: %v\n", err) - return + if err == io.EOF { + c.Close() + break + } else { + fmt.Printf("echo server read error: %v\n", err) + return + } } - c.Write([]byte(buf + "\n")) + c.Write(buf[:n]) } } diff --git a/tests/func_test.go b/tests/func_test.go index 444e673f..1ac25051 100644 --- a/tests/func_test.go +++ b/tests/func_test.go @@ -12,43 +12,67 @@ import ( "time" frpNet "github.com/fatedier/frp/utils/net" + "github.com/stretchr/testify/assert" ) var ( - ECHO_PORT int64 = 10711 - UDP_ECHO_PORT int64 = 10712 - HTTP_PORT int64 = 10710 - ECHO_TEST_STR string = "Hello World\n" - HTTP_RES_STR string = "Hello World" + TEST_STR = "frp is a fast reverse proxy to help you expose a local server behind a NAT or firewall to the internet." + TEST_TCP_PORT int64 = 10701 + TEST_TCP_FRP_PORT int64 = 10801 + TEST_TCP_EC_FRP_PORT int64 = 10901 + TEST_TCP_ECHO_STR string = "tcp type:" + TEST_STR + + TEST_UDP_PORT int64 = 10702 + TEST_UDP_FRP_PORT int64 = 10802 + TEST_UDP_ECHO_STR string = "udp type:" + TEST_STR + + TEST_UNIX_DOMAIN_ADDR string = "/tmp/frp_echo_server.sock" + TEST_UNIX_DOMAIN_FRP_PORT int64 = 10803 + TEST_UNIX_DOMAIN_STR string = "unix domain type:" + TEST_STR + + TEST_HTTP_PORT int64 = 10704 + TEST_HTTP_FRP_PORT int64 = 10804 + TEST_HTTP_WEB01_STR string = "http web01:" + TEST_STR ) func init() { - go StartEchoServer() + go StartTcpEchoServer() go StartUdpEchoServer() - go StartHttpServer() go StartUnixDomainServer() + go StartHttpServer() time.Sleep(500 * time.Millisecond) } -func TestEchoServer(t *testing.T) { - c, err := frpNet.ConnectTcpServer(fmt.Sprintf("127.0.0.1:%d", ECHO_PORT)) - if err != nil { - t.Fatalf("connect to echo server error: %v", err) - } - timer := time.Now().Add(time.Duration(5) * time.Second) - c.SetDeadline(timer) +func TestTcpServer(t *testing.T) { + assert := assert.New(t) + // Normal + addr := fmt.Sprintf("127.0.0.1:%d", TEST_TCP_FRP_PORT) + res, err := sendTcpMsg(addr, TEST_TCP_ECHO_STR) + assert.NoError(err) + assert.Equal(TEST_TCP_ECHO_STR, res) - c.Write([]byte(ECHO_TEST_STR + "\n")) + // Encrytion and compression + addr = fmt.Sprintf("127.0.0.1:%d", TEST_TCP_EC_FRP_PORT) + res, err = sendTcpMsg(addr, TEST_TCP_ECHO_STR) + assert.NoError(err) + assert.Equal(TEST_TCP_ECHO_STR, res) +} - br := bufio.NewReader(c) - buf, err := br.ReadString('\n') - if err != nil { - t.Fatalf("read from echo server error: %v", err) - } +func TestUdpEchoServer(t *testing.T) { + assert := assert.New(t) + // Normal + addr := fmt.Sprintf("127.0.0.1:%d", TEST_UDP_FRP_PORT) + res, err := sendUdpMsg(addr, TEST_UDP_ECHO_STR) + assert.NoError(err) + assert.Equal(TEST_UDP_ECHO_STR, res) - if ECHO_TEST_STR != buf { - t.Fatalf("content error, send [%s], get [%s]", strings.Trim(ECHO_TEST_STR, "\n"), strings.Trim(buf, "\n")) - } +func TestUnixDomainServer(t *testing.T) { + assert := assert.New(t) + // Normal + addr := fmt.Sprintf("127.0.0.1:%d", TEST_UNIX_DOMAIN_FRP_PORT) + res, err := sendTcpMsg(addr, TEST_UNIX_DOMAIN_STR) + assert.NoError(err) + assert.Equal(TEST_UNIX_DOMAIN_STR, res) } func TestHttpServer(t *testing.T) { @@ -71,49 +95,3 @@ func TestHttpServer(t *testing.T) { t.Fatalf("http code from http server error [%d]", res.StatusCode) } } - -func TestUdpEchoServer(t *testing.T) { - addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:10712") - if err != nil { - t.Fatalf("do udp request error: %v", err) - } - conn, err := net.DialUDP("udp", nil, addr) - if err != nil { - t.Fatalf("dial udp server error: %v", err) - } - defer conn.Close() - _, err = conn.Write([]byte("hello frp\n")) - if err != nil { - t.Fatalf("write to udp server error: %v", err) - } - data := make([]byte, 20) - n, err := conn.Read(data) - if err != nil { - t.Fatalf("read from udp server error: %v", err) - } - - if string(bytes.TrimSpace(data[:n])) != "hello frp" { - t.Fatalf("message got from udp server error, get %s", string(data[:n-1])) - } -} - -func TestUnixDomainServer(t *testing.T) { - c, err := frpNet.ConnectTcpServer(fmt.Sprintf("127.0.0.1:%d", 10704)) - if err != nil { - t.Fatalf("connect to echo server error: %v", err) - } - timer := time.Now().Add(time.Duration(5) * time.Second) - c.SetDeadline(timer) - - c.Write([]byte(ECHO_TEST_STR + "\n")) - - br := bufio.NewReader(c) - buf, err := br.ReadString('\n') - if err != nil { - t.Fatalf("read from echo server error: %v", err) - } - - if ECHO_TEST_STR != buf { - t.Fatalf("content error, send [%s], get [%s]", strings.Trim(ECHO_TEST_STR, "\n"), strings.Trim(buf, "\n")) - } -} diff --git a/tests/util.go b/tests/util.go new file mode 100644 index 00000000..0352463c --- /dev/null +++ b/tests/util.go @@ -0,0 +1,57 @@ +package test + +import ( + "fmt" + "net" + "time" + + frpNet "github.com/fatedier/frp/utils/net" +) + +func sendTcpMsg(addr string, msg string) (res string, err error) { + c, err := frpNet.ConnectTcpServer(addr) + defer c.Close() + if err != nil { + err = fmt.Errorf("connect to tcp server error: %v", err) + return + } + + timer := time.Now().Add(5 * time.Second) + c.SetDeadline(timer) + c.Write([]byte(msg)) + + buf := make([]byte, 2048) + n, errRet := c.Read(buf) + if errRet != nil { + err = fmt.Errorf("read from tcp server error: %v", errRet) + return + } + return string(buf[:n]), nil +} + +func sendUdpMsg(addr string, msg string) (res string, err error) { + udpAddr, errRet := net.ResolveUDPAddr("udp", addr) + if errRet != nil { + err = fmt.Errorf("resolve udp addr error: %v", err) + return + } + conn, errRet := net.DialUDP("udp", nil, udpAddr) + if errRet != nil { + err = fmt.Errorf("dial udp server error: %v", err) + return + } + defer conn.Close() + _, err = conn.Write([]byte(msg)) + if err != nil { + err = fmt.Errorf("write to udp server error: %v", err) + return + } + + buf := make([]byte, 2048) + n, errRet := conn.Read(buf) + if errRet != nil { + err = fmt.Errorf("read from udp server error: %v", err) + return + } + return string(buf[:n]), nil +} From 3bb404dfb51a8948dcee9e513744363d2c5c0f63 Mon Sep 17 00:00:00 2001 From: fatedier Date: Mon, 18 Dec 2017 19:35:09 +0800 Subject: [PATCH 2/7] more test case --- conf/frpc_full.ini | 2 +- tests/clean_test.sh | 6 ++ tests/conf/auto_test_frpc.ini | 90 +++++++++++++++--- tests/conf/auto_test_frpc_visitor.ini | 25 +++++ tests/conf/auto_test_frps.ini | 2 +- tests/echo_server.go | 4 +- tests/func_test.go | 130 +++++++++++++++++++------- tests/http_server.go | 21 ++++- tests/run_test.sh | 1 + tests/util.go | 40 +++++++- 10 files changed, 266 insertions(+), 55 deletions(-) create mode 100644 tests/conf/auto_test_frpc_visitor.ini diff --git a/conf/frpc_full.ini b/conf/frpc_full.ini index a0988e3e..bcb61d22 100644 --- a/conf/frpc_full.ini +++ b/conf/frpc_full.ini @@ -88,7 +88,7 @@ http_pwd = admin # if domain for frps is frps.com, then you can access [web01] proxy by URL http://test.frps.com subdomain = web01 custom_domains = web02.yourdomain.com -# locations is only useful for http type +# locations is only available for http type locations = /,/pic host_header_rewrite = example.com diff --git a/tests/clean_test.sh b/tests/clean_test.sh index 6a6ef562..b0b37636 100755 --- a/tests/clean_test.sh +++ b/tests/clean_test.sh @@ -10,5 +10,11 @@ if [ -n "${pid}" ]; then kill ${pid} fi +pid=`ps aux|grep './../bin/frpc -c ./conf/auto_test_frpc_visitor.ini'|grep -v grep|awk {'print $2'}` +if [ -n "${pid}" ]; then + kill ${pid} +fi + rm -f ./frps.log rm -f ./frpc.log +rm -f ./frpc_visitor.log diff --git a/tests/conf/auto_test_frpc.ini b/tests/conf/auto_test_frpc.ini index d21f01f1..2d243e80 100644 --- a/tests/conf/auto_test_frpc.ini +++ b/tests/conf/auto_test_frpc.ini @@ -6,30 +6,96 @@ log_file = ./frpc.log log_level = debug privilege_token = 123456 -[echo] +[tcp_normal] type = tcp local_ip = 127.0.0.1 local_port = 10701 -remote_port = 10711 -use_encryption = true -use_compression = true +remote_port = 10801 -[web] -type = http +[tcp_ec] +type = tcp local_ip = 127.0.0.1 -local_port = 10702 +local_port = 10701 +remote_port = 10901 use_encryption = true use_compression = true -custom_domains = 127.0.0.1 -[udp] +[udp_normal] type = udp local_ip = 127.0.0.1 -local_port = 10703 -remote_port = 10712 +local_port = 10702 +remote_port = 10802 + +[udp_ec] +type = udp +local_ip = 127.0.0.1 +local_port = 10702 +remote_port = 10902 +use_encryption = true +use_compression = true [unix_domain] type = tcp -remote_port = 10704 +remote_port = 10803 plugin = unix_domain_socket plugin_unix_path = /tmp/frp_echo_server.sock + +[stcp] +type = stcp +sk = abcdefg +local_ip = 127.0.0.1 +local_port = 10701 + +[stcp_ec] +type = stcp +sk = abc +local_ip = 127.0.0.1 +local_port = 10701 +use_encryption = true +use_compression = true + +[web01] +type = http +local_ip = 127.0.0.1 +local_port = 10704 +custom_domains = 127.0.0.1 + +[web02] +type = http +local_ip = 127.0.0.1 +local_port = 10704 +custom_domains = test2.frp.com +host_header_rewrite = test2.frp.com +use_encryption = true +use_compression = true + +[web03] +type = http +local_ip = 127.0.0.1 +local_port = 10704 +custom_domains = test3.frp.com +use_encryption = true +use_compression = true +host_header_rewrite = test3.frp.com +locations = /,/foo + +[web04] +type = http +local_ip = 127.0.0.1 +local_port = 10704 +custom_domains = test3.frp.com +use_encryption = true +use_compression = true +host_header_rewrite = test3.frp.com +locations = /bar + +[web05] +type = http +local_ip = 127.0.0.1 +local_port = 10704 +custom_domains = test5.frp.com +host_header_rewrite = test5.frp.com +use_encryption = true +use_compression = true +http_user = test +http_user = test diff --git a/tests/conf/auto_test_frpc_visitor.ini b/tests/conf/auto_test_frpc_visitor.ini new file mode 100644 index 00000000..cc524deb --- /dev/null +++ b/tests/conf/auto_test_frpc_visitor.ini @@ -0,0 +1,25 @@ +[common] +server_addr = 0.0.0.0 +server_port = 10700 +log_file = ./frpc_visitor.log +# debug, info, warn, error +log_level = debug +privilege_token = 123456 + +[stcp_visitor] +type = stcp +role = visitor +server_name = stcp +sk = abcdefg +bind_addr = 127.0.0.1 +bind_port = 10805 + +[stcp_ec_visitor] +type = stcp +role = visitor +server_name = stcp_ec +sk = abc +bind_addr = 127.0.0.1 +bind_port = 10905 +use_encryption = true +use_compression = true diff --git a/tests/conf/auto_test_frps.ini b/tests/conf/auto_test_frps.ini index 3d933595..c03dc146 100644 --- a/tests/conf/auto_test_frps.ini +++ b/tests/conf/auto_test_frps.ini @@ -1,7 +1,7 @@ [common] bind_addr = 0.0.0.0 bind_port = 10700 -vhost_http_port = 10710 +vhost_http_port = 10804 log_file = ./frps.log log_level = debug privilege_token = 123456 diff --git a/tests/echo_server.go b/tests/echo_server.go index 47b87cb1..5c73fef1 100644 --- a/tests/echo_server.go +++ b/tests/echo_server.go @@ -11,7 +11,7 @@ import ( ) func StartTcpEchoServer() { - l, err := frpNet.ListenTcp("127.0.0.1", TEST_TCP_ECHO_PORT) + l, err := frpNet.ListenTcp("127.0.0.1", TEST_TCP_PORT) if err != nil { fmt.Printf("echo server listen error: %v\n", err) return @@ -29,7 +29,7 @@ func StartTcpEchoServer() { } func StartUdpEchoServer() { - l, err := frpNet.ListenUDP("127.0.0.1", TEST_UDP_ECHO_PORT) + l, err := frpNet.ListenUDP("127.0.0.1", TEST_UDP_PORT) if err != nil { fmt.Printf("udp echo server listen error: %v\n", err) return diff --git a/tests/func_test.go b/tests/func_test.go index 1ac25051..238046fb 100644 --- a/tests/func_test.go +++ b/tests/func_test.go @@ -1,17 +1,10 @@ package tests import ( - "bufio" - "bytes" "fmt" - "io/ioutil" - "net" - "net/http" - "strings" "testing" "time" - frpNet "github.com/fatedier/frp/utils/net" "github.com/stretchr/testify/assert" ) @@ -22,17 +15,24 @@ var ( TEST_TCP_EC_FRP_PORT int64 = 10901 TEST_TCP_ECHO_STR string = "tcp type:" + TEST_STR - TEST_UDP_PORT int64 = 10702 - TEST_UDP_FRP_PORT int64 = 10802 - TEST_UDP_ECHO_STR string = "udp type:" + TEST_STR + TEST_UDP_PORT int64 = 10702 + TEST_UDP_FRP_PORT int64 = 10802 + TEST_UDP_EC_FRP_PORT int64 = 10902 + TEST_UDP_ECHO_STR string = "udp type:" + TEST_STR TEST_UNIX_DOMAIN_ADDR string = "/tmp/frp_echo_server.sock" TEST_UNIX_DOMAIN_FRP_PORT int64 = 10803 TEST_UNIX_DOMAIN_STR string = "unix domain type:" + TEST_STR - TEST_HTTP_PORT int64 = 10704 - TEST_HTTP_FRP_PORT int64 = 10804 - TEST_HTTP_WEB01_STR string = "http web01:" + TEST_STR + TEST_HTTP_PORT int64 = 10704 + TEST_HTTP_FRP_PORT int64 = 10804 + TEST_HTTP_NORMAL_STR string = "http normal string: " + TEST_STR + TEST_HTTP_FOO_STR string = "http foo string: " + TEST_STR + TEST_HTTP_BAR_STR string = "http bar string: " + TEST_STR + + TEST_STCP_FRP_PORT int64 = 10805 + TEST_STCP_EC_FRP_PORT int64 = 10905 + TEST_STCP_ECHO_STR string = "stcp type:" + TEST_STR ) func init() { @@ -43,7 +43,7 @@ func init() { time.Sleep(500 * time.Millisecond) } -func TestTcpServer(t *testing.T) { +func TestTcp(t *testing.T) { assert := assert.New(t) // Normal addr := fmt.Sprintf("127.0.0.1:%d", TEST_TCP_FRP_PORT) @@ -58,7 +58,7 @@ func TestTcpServer(t *testing.T) { assert.Equal(TEST_TCP_ECHO_STR, res) } -func TestUdpEchoServer(t *testing.T) { +func TestUdp(t *testing.T) { assert := assert.New(t) // Normal addr := fmt.Sprintf("127.0.0.1:%d", TEST_UDP_FRP_PORT) @@ -66,32 +66,92 @@ func TestUdpEchoServer(t *testing.T) { assert.NoError(err) assert.Equal(TEST_UDP_ECHO_STR, res) -func TestUnixDomainServer(t *testing.T) { + // Encrytion and compression + addr = fmt.Sprintf("127.0.0.1:%d", TEST_UDP_EC_FRP_PORT) + res, err = sendUdpMsg(addr, TEST_UDP_ECHO_STR) + assert.NoError(err) + assert.Equal(TEST_UDP_ECHO_STR, res) +} + +func TestUnixDomain(t *testing.T) { assert := assert.New(t) // Normal addr := fmt.Sprintf("127.0.0.1:%d", TEST_UNIX_DOMAIN_FRP_PORT) res, err := sendTcpMsg(addr, TEST_UNIX_DOMAIN_STR) - assert.NoError(err) - assert.Equal(TEST_UNIX_DOMAIN_STR, res) + if assert.NoError(err) { + assert.Equal(TEST_UNIX_DOMAIN_STR, res) + } } -func TestHttpServer(t *testing.T) { - client := &http.Client{} - req, _ := http.NewRequest("GET", fmt.Sprintf("http://127.0.0.1:%d", HTTP_PORT), nil) - res, err := client.Do(req) - if err != nil { - t.Fatalf("do http request error: %v", err) +func TestStcp(t *testing.T) { + assert := assert.New(t) + // Normal + addr := fmt.Sprintf("127.0.0.1:%d", TEST_STCP_FRP_PORT) + res, err := sendTcpMsg(addr, TEST_STCP_ECHO_STR) + if assert.NoError(err) { + assert.Equal(TEST_STCP_ECHO_STR, res) } - if res.StatusCode == 200 { - body, err := ioutil.ReadAll(res.Body) - if err != nil { - t.Fatalf("read from http server error: %v", err) - } - bodystr := string(body) - if bodystr != HTTP_RES_STR { - t.Fatalf("content from http server error [%s], correct string is [%s]", bodystr, HTTP_RES_STR) - } - } else { - t.Fatalf("http code from http server error [%d]", res.StatusCode) + + // Encrytion and compression + addr = fmt.Sprintf("127.0.0.1:%d", TEST_STCP_EC_FRP_PORT) + res, err = sendTcpMsg(addr, TEST_STCP_ECHO_STR) + if assert.NoError(err) { + assert.Equal(TEST_STCP_ECHO_STR, res) + } +} + +func TestHttp(t *testing.T) { + assert := assert.New(t) + // web01 + code, body, err := sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", TEST_HTTP_FRP_PORT), "", nil) + if assert.NoError(err) { + assert.Equal(200, code) + assert.Equal(TEST_HTTP_NORMAL_STR, body) + } + + // web02 + code, body, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", TEST_HTTP_FRP_PORT), "test2.frp.com", nil) + if assert.NoError(err) { + assert.Equal(200, code) + assert.Equal(TEST_HTTP_NORMAL_STR, body) + } + + // error host header + code, body, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", TEST_HTTP_FRP_PORT), "errorhost.frp.com", nil) + if assert.NoError(err) { + assert.Equal(404, code) + } + + // web03 + code, body, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", TEST_HTTP_FRP_PORT), "test3.frp.com", nil) + if assert.NoError(err) { + assert.Equal(200, code) + assert.Equal(TEST_HTTP_NORMAL_STR, body) + } + + code, body, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d/foo", TEST_HTTP_FRP_PORT), "test3.frp.com", nil) + if assert.NoError(err) { + assert.Equal(200, code) + assert.Equal(TEST_HTTP_FOO_STR, body) + } + + // web04 + code, body, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d/bar", TEST_HTTP_FRP_PORT), "test3.frp.com", nil) + if assert.NoError(err) { + assert.Equal(200, code) + assert.Equal(TEST_HTTP_BAR_STR, body) + } + + // web05 + code, body, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", TEST_HTTP_FRP_PORT), "test5.frp.com", nil) + if assert.NoError(err) { + assert.Equal(401, code) + } + + header := make(map[string]string) + header["Authorization"] = basicAuth("test", "test") + code, body, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", TEST_HTTP_FRP_PORT), "test5.frp.com", header) + if assert.NoError(err) { + assert.Equal(401, code) } } diff --git a/tests/http_server.go b/tests/http_server.go index 29eaa4de..6564d1dd 100644 --- a/tests/http_server.go +++ b/tests/http_server.go @@ -3,13 +3,30 @@ package tests import ( "fmt" "net/http" + "strings" ) func StartHttpServer() { http.HandleFunc("/", request) - http.ListenAndServe(fmt.Sprintf("0.0.0.0:%d", 10702), nil) + http.ListenAndServe(fmt.Sprintf("0.0.0.0:%d", TEST_HTTP_PORT), nil) } func request(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(HTTP_RES_STR)) + if strings.Contains(r.Host, "127.0.0.1") || strings.Contains(r.Host, "test2.frp.com") || + strings.Contains(r.Host, "test5.frp.com") { + w.WriteHeader(200) + w.Write([]byte(TEST_HTTP_NORMAL_STR)) + } else if strings.Contains(r.Host, "test3.frp.com") { + w.WriteHeader(200) + if strings.Contains(r.URL.Path, "foo") { + w.Write([]byte(TEST_HTTP_FOO_STR)) + } else if strings.Contains(r.URL.Path, "bar") { + w.Write([]byte(TEST_HTTP_BAR_STR)) + } else { + w.Write([]byte(TEST_HTTP_NORMAL_STR)) + } + } else { + w.WriteHeader(404) + } + return } diff --git a/tests/run_test.sh b/tests/run_test.sh index 5ef490a6..a852a3d0 100755 --- a/tests/run_test.sh +++ b/tests/run_test.sh @@ -3,6 +3,7 @@ ./../bin/frps -c ./conf/auto_test_frps.ini & sleep 1 ./../bin/frpc -c ./conf/auto_test_frpc.ini & +./../bin/frpc -c ./conf/auto_test_frpc_visitor.ini & # wait until proxies are connected sleep 2 diff --git a/tests/util.go b/tests/util.go index 0352463c..9eed5802 100644 --- a/tests/util.go +++ b/tests/util.go @@ -1,8 +1,11 @@ -package test +package tests import ( + "encoding/base64" "fmt" + "io/ioutil" "net" + "net/http" "time" frpNet "github.com/fatedier/frp/utils/net" @@ -10,11 +13,11 @@ import ( func sendTcpMsg(addr string, msg string) (res string, err error) { c, err := frpNet.ConnectTcpServer(addr) - defer c.Close() if err != nil { err = fmt.Errorf("connect to tcp server error: %v", err) return } + defer c.Close() timer := time.Now().Add(5 * time.Second) c.SetDeadline(timer) @@ -55,3 +58,36 @@ func sendUdpMsg(addr string, msg string) (res string, err error) { } return string(buf[:n]), nil } + +func sendHttpMsg(method, url string, host string, header map[string]string) (code int, body string, err error) { + req, errRet := http.NewRequest(method, url, nil) + if errRet != nil { + err = errRet + return + } + + if host != "" { + req.Host = host + } + for k, v := range header { + req.Header.Set(k, v) + } + resp, errRet := http.DefaultClient.Do(req) + if errRet != nil { + err = errRet + return + } + code = resp.StatusCode + buf, errRet := ioutil.ReadAll(resp.Body) + if errRet != nil { + err = errRet + return + } + body = string(buf) + return +} + +func basicAuth(username, passwd string) string { + auth := username + ":" + passwd + return "Basic " + base64.StdEncoding.EncodeToString([]byte(auth)) +} From 584e098e8e90c665dfa32ec1317734ab27fef5a4 Mon Sep 17 00:00:00 2001 From: fatedier Date: Wed, 17 Jan 2018 01:09:33 +0800 Subject: [PATCH 3/7] frpc: add status command --- client/admin.go | 1 + client/admin_api.go | 125 ++++++++++++- client/control.go | 348 +++++++++---------------------------- client/proxy.go | 5 +- client/proxy_manager.go | 340 ++++++++++++++++++++++++++++++++++++ client/service.go | 4 +- cmd/frpc/main.go | 176 +++++++++++++++---- models/config/proxy.go | 5 + server/control.go | 6 +- server/service.go | 2 +- utils/shutdown/shutdown.go | 28 +-- 11 files changed, 722 insertions(+), 318 deletions(-) create mode 100644 client/proxy_manager.go diff --git a/client/admin.go b/client/admin.go index f728483e..37cdf4c1 100644 --- a/client/admin.go +++ b/client/admin.go @@ -39,6 +39,7 @@ func (svr *Service) RunAdminServer(addr string, port int64) (err error) { // api, see dashboard_api.go router.GET("/api/reload", frpNet.HttprouterBasicAuth(svr.apiReload, user, passwd)) + router.GET("/api/status", frpNet.HttprouterBasicAuth(svr.apiStatus, user, passwd)) address := fmt.Sprintf("%s:%d", addr, port) server := &http.Server{ diff --git a/client/admin_api.go b/client/admin_api.go index 70842e65..fae1737e 100644 --- a/client/admin_api.go +++ b/client/admin_api.go @@ -16,7 +16,10 @@ package client import ( "encoding/json" + "fmt" "net/http" + "sort" + "strings" "github.com/julienschmidt/httprouter" ini "github.com/vaughan0/go-ini" @@ -72,7 +75,127 @@ func (svr *Service) apiReload(w http.ResponseWriter, r *http.Request, _ httprout return } - svr.ctl.reloadConf(pxyCfgs, visitorCfgs) + err = svr.ctl.reloadConf(pxyCfgs, visitorCfgs) + if err != nil { + res.Code = 4 + res.Msg = err.Error() + log.Error("reload frpc proxy config error: %v", err) + return + } log.Info("success reload conf") return } + +type StatusResp struct { + Tcp []ProxyStatusResp `json:"tcp"` + Udp []ProxyStatusResp `json:"udp"` + Http []ProxyStatusResp `json:"http"` + Https []ProxyStatusResp `json:"https"` + Stcp []ProxyStatusResp `json:"stcp"` + Xtcp []ProxyStatusResp `json:"xtcp"` +} + +type ProxyStatusResp struct { + Name string `json:"name"` + Type string `json:"type"` + Status string `json:"status"` + Err string `json:"err"` + LocalAddr string `json:"local_addr"` + Plugin string `json:"plugin"` + RemoteAddr string `json:"remote_addr"` +} + +type ByProxyStatusResp []ProxyStatusResp + +func (a ByProxyStatusResp) Len() int { return len(a) } +func (a ByProxyStatusResp) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a ByProxyStatusResp) Less(i, j int) bool { return strings.Compare(a[i].Name, a[j].Name) < 0 } + +func NewProxyStatusResp(status *ProxyStatus) ProxyStatusResp { + psr := ProxyStatusResp{ + Name: status.Name, + Type: status.Type, + Status: status.Status, + Err: status.Err, + } + switch cfg := status.Cfg.(type) { + case *config.TcpProxyConf: + if cfg.LocalPort != 0 { + psr.LocalAddr = fmt.Sprintf("%s:%d", cfg.LocalIp, cfg.LocalPort) + } + psr.Plugin = cfg.Plugin + psr.RemoteAddr = fmt.Sprintf(":%d", cfg.RemotePort) + case *config.UdpProxyConf: + if cfg.LocalPort != 0 { + psr.LocalAddr = fmt.Sprintf("%s:%d", cfg.LocalIp, cfg.LocalPort) + } + psr.RemoteAddr = fmt.Sprintf(":%d", cfg.RemotePort) + case *config.HttpProxyConf: + if cfg.LocalPort != 0 { + psr.LocalAddr = fmt.Sprintf("%s:%d", cfg.LocalIp, cfg.LocalPort) + } + psr.Plugin = cfg.Plugin + case *config.HttpsProxyConf: + if cfg.LocalPort != 0 { + psr.LocalAddr = fmt.Sprintf("%s:%d", cfg.LocalIp, cfg.LocalPort) + } + psr.Plugin = cfg.Plugin + case *config.StcpProxyConf: + if cfg.LocalPort != 0 { + psr.LocalAddr = fmt.Sprintf("%s:%d", cfg.LocalIp, cfg.LocalPort) + } + psr.Plugin = cfg.Plugin + case *config.XtcpProxyConf: + if cfg.LocalPort != 0 { + psr.LocalAddr = fmt.Sprintf("%s:%d", cfg.LocalIp, cfg.LocalPort) + } + psr.Plugin = cfg.Plugin + } + return psr +} + +// api/status +func (svr *Service) apiStatus(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + var ( + buf []byte + res StatusResp + ) + res.Tcp = make([]ProxyStatusResp, 0) + res.Udp = make([]ProxyStatusResp, 0) + res.Http = make([]ProxyStatusResp, 0) + res.Https = make([]ProxyStatusResp, 0) + res.Stcp = make([]ProxyStatusResp, 0) + res.Xtcp = make([]ProxyStatusResp, 0) + defer func() { + log.Info("Http response [/api/status]") + buf, _ = json.Marshal(&res) + w.Write(buf) + }() + + log.Info("Http request: [/api/status]") + + ps := svr.ctl.pm.GetAllProxyStatus() + for _, status := range ps { + switch status.Type { + case "tcp": + res.Tcp = append(res.Tcp, NewProxyStatusResp(status)) + case "udp": + res.Udp = append(res.Udp, NewProxyStatusResp(status)) + case "http": + res.Http = append(res.Http, NewProxyStatusResp(status)) + case "https": + res.Https = append(res.Https, NewProxyStatusResp(status)) + case "stcp": + res.Stcp = append(res.Stcp, NewProxyStatusResp(status)) + case "xtcp": + res.Xtcp = append(res.Xtcp, NewProxyStatusResp(status)) + } + } + sort.Sort(ByProxyStatusResp(res.Tcp)) + sort.Sort(ByProxyStatusResp(res.Udp)) + sort.Sort(ByProxyStatusResp(res.Http)) + sort.Sort(ByProxyStatusResp(res.Https)) + sort.Sort(ByProxyStatusResp(res.Stcp)) + sort.Sort(ByProxyStatusResp(res.Xtcp)) + return +} diff --git a/client/control.go b/client/control.go index 788621cd..65bec393 100644 --- a/client/control.go +++ b/client/control.go @@ -24,9 +24,9 @@ import ( "github.com/fatedier/frp/models/config" "github.com/fatedier/frp/models/msg" "github.com/fatedier/frp/utils/crypto" - "github.com/fatedier/frp/utils/errors" "github.com/fatedier/frp/utils/log" frpNet "github.com/fatedier/frp/utils/net" + "github.com/fatedier/frp/utils/shutdown" "github.com/fatedier/frp/utils/util" "github.com/fatedier/frp/utils/version" "github.com/xtaci/smux" @@ -40,20 +40,10 @@ type Control struct { // frpc service svr *Service - // login message to server + // login message to server, only used loginMsg *msg.Login - // proxy configures - pxyCfgs map[string]config.ProxyConf - - // proxies - proxies map[string]Proxy - - // visitor configures - visitorCfgs map[string]config.ProxyConf - - // visitors - visitors map[string]Visitor + pm *ProxyManager // control connection conn frpNet.Conn @@ -79,6 +69,10 @@ type Control struct { // last time got the Pong message lastPong time.Time + readerShutdown *shutdown.Shutdown + writerShutdown *shutdown.Shutdown + msgHandlerShutdown *shutdown.Shutdown + mu sync.RWMutex log.Logger @@ -92,28 +86,22 @@ func NewControl(svr *Service, pxyCfgs map[string]config.ProxyConf, visitorCfgs m User: config.ClientCommonCfg.User, Version: version.Full(), } - return &Control{ - svr: svr, - loginMsg: loginMsg, - pxyCfgs: pxyCfgs, - visitorCfgs: visitorCfgs, - proxies: make(map[string]Proxy), - visitors: make(map[string]Visitor), - sendCh: make(chan msg.Message, 10), - readCh: make(chan msg.Message, 10), - closedCh: make(chan int), - Logger: log.NewPrefixLogger(""), + ctl := &Control{ + svr: svr, + loginMsg: loginMsg, + sendCh: make(chan msg.Message, 10), + readCh: make(chan msg.Message, 10), + closedCh: make(chan int), + readerShutdown: shutdown.New(), + writerShutdown: shutdown.New(), + msgHandlerShutdown: shutdown.New(), + Logger: log.NewPrefixLogger(""), } + ctl.pm = NewProxyManager(ctl, ctl.sendCh, "") + ctl.pm.Reload(pxyCfgs, visitorCfgs) + return ctl } -// 1. login -// 2. start reader() writer() manager() -// 3. connection closed -// 4. In reader(): close closedCh and exit, controler() get it -// 5. In controler(): close readCh and sendCh, manager() and writer() will exit -// 6. In controler(): ini readCh, sendCh, closedCh -// 7. In controler(): start new reader(), writer(), manager() -// controler() will keep running func (ctl *Control) Run() (err error) { for { err = ctl.login() @@ -125,47 +113,29 @@ func (ctl *Control) Run() (err error) { if config.ClientCommonCfg.LoginFailExit { return } else { - time.Sleep(30 * time.Second) + time.Sleep(10 * time.Second) } } else { break } } - go ctl.controler() - go ctl.manager() - go ctl.writer() - go ctl.reader() + go ctl.worker() - // start all local visitors - for _, cfg := range ctl.visitorCfgs { - visitor := NewVisitor(ctl, cfg) - err = visitor.Run() - if err != nil { - visitor.Warn("start error: %v", err) - continue - } - ctl.visitors[cfg.GetName()] = visitor - visitor.Info("start visitor success") - } - - // send NewProxy message for all configured proxies - for _, cfg := range ctl.pxyCfgs { - var newProxyMsg msg.NewProxy - cfg.UnMarshalToMsg(&newProxyMsg) - ctl.sendCh <- &newProxyMsg - } + // start all local visitors and send NewProxy message for all configured proxies + ctl.pm.Reset(ctl.sendCh, ctl.runId) + ctl.pm.CheckAndStartProxy() return nil } -func (ctl *Control) NewWorkConn() { +func (ctl *Control) HandleReqWorkConn(inMsg *msg.ReqWorkConn) { workConn, err := ctl.connectServer() if err != nil { return } m := &msg.NewWorkConn{ - RunId: ctl.getRunId(), + RunId: ctl.runId, } if err = msg.WriteMsg(workConn, m); err != nil { ctl.Warn("work connection write to server error: %v", err) @@ -182,33 +152,26 @@ func (ctl *Control) NewWorkConn() { workConn.AddLogPrefix(startMsg.ProxyName) // dispatch this work connection to related proxy - pxy, ok := ctl.getProxy(startMsg.ProxyName) - if ok { - workConn.Debug("start a new work connection, localAddr: %s remoteAddr: %s", workConn.LocalAddr().String(), workConn.RemoteAddr().String()) - go pxy.InWorkConn(workConn) + ctl.pm.HandleWorkConn(startMsg.ProxyName, workConn) +} + +func (ctl *Control) HandleNewProxyResp(inMsg *msg.NewProxyResp) { + // Server will return NewProxyResp message to each NewProxy message. + // Start a new proxy handler if no error got + err := ctl.pm.StartProxy(inMsg.ProxyName, inMsg.Error) + if err != nil { + ctl.Warn("[%s] start error: %v", inMsg.ProxyName, err) } else { - workConn.Close() + ctl.Info("[%s] start proxy success", inMsg.ProxyName) } } func (ctl *Control) Close() error { ctl.mu.Lock() + defer ctl.mu.Unlock() ctl.exit = true - err := errors.PanicToError(func() { - for name, _ := range ctl.proxies { - ctl.sendCh <- &msg.CloseProxy{ - ProxyName: name, - } - } - }) - ctl.mu.Unlock() - return err -} - -func (ctl *Control) init() { - ctl.sendCh = make(chan msg.Message, 10) - ctl.readCh = make(chan msg.Message, 10) - ctl.closedCh = make(chan int) + ctl.pm.CloseProxies() + return nil } // login send a login message to server and wait for a loginResp message. @@ -249,7 +212,7 @@ func (ctl *Control) login() (err error) { now := time.Now().Unix() ctl.loginMsg.PrivilegeKey = util.GetAuthKey(config.ClientCommonCfg.PrivilegeToken, now) ctl.loginMsg.Timestamp = now - ctl.loginMsg.RunId = ctl.getRunId() + ctl.loginMsg.RunId = ctl.runId if err = msg.WriteMsg(conn, ctl.loginMsg); err != nil { return err @@ -270,16 +233,11 @@ func (ctl *Control) login() (err error) { ctl.conn = conn // update runId got from server - ctl.setRunId(loginRespMsg.RunId) + ctl.runId = loginRespMsg.RunId config.ClientCommonCfg.ServerUdpPort = loginRespMsg.ServerUdpPort ctl.ClearLogPrefix() ctl.AddLogPrefix(loginRespMsg.RunId) ctl.Info("login to server success, get run id [%s], server udp port [%d]", loginRespMsg.RunId, loginRespMsg.ServerUdpPort) - - // login success, so we let closedCh available again - ctl.closedCh = make(chan int) - ctl.lastPong = time.Now() - return nil } @@ -292,7 +250,6 @@ func (ctl *Control) connectServer() (conn frpNet.Conn, err error) { return } conn = frpNet.WrapConn(stream) - } else { conn, err = frpNet.ConnectServerByHttpProxy(config.ClientCommonCfg.HttpProxy, config.ClientCommonCfg.Protocol, fmt.Sprintf("%s:%d", config.ClientCommonCfg.ServerAddr, config.ClientCommonCfg.ServerPort)) @@ -304,12 +261,14 @@ func (ctl *Control) connectServer() (conn frpNet.Conn, err error) { return } +// reader read all messages from frps and send to readCh func (ctl *Control) reader() { defer func() { if err := recover(); err != nil { ctl.Error("panic error: %v", err) } }() + defer ctl.readerShutdown.Done() defer close(ctl.closedCh) encReader := crypto.NewReader(ctl.conn, []byte(config.ClientCommonCfg.PrivilegeToken)) @@ -328,7 +287,9 @@ func (ctl *Control) reader() { } } +// writer writes messages got from sendCh to frps func (ctl *Control) writer() { + defer ctl.writerShutdown.Done() encWriter, err := crypto.NewWriter(ctl.conn, []byte(config.ClientCommonCfg.PrivilegeToken)) if err != nil { ctl.conn.Error("crypto new writer error: %v", err) @@ -348,19 +309,22 @@ func (ctl *Control) writer() { } } -// manager handles all channel events and do corresponding process -func (ctl *Control) manager() { +// msgHandler handles all channel events and do corresponding operations. +func (ctl *Control) msgHandler() { defer func() { if err := recover(); err != nil { ctl.Error("panic error: %v", err) } }() + defer ctl.msgHandlerShutdown.Done() hbSend := time.NewTicker(time.Duration(config.ClientCommonCfg.HeartBeatInterval) * time.Second) defer hbSend.Stop() hbCheck := time.NewTicker(time.Second) defer hbCheck.Stop() + ctl.lastPong = time.Now() + for { select { case <-hbSend.C: @@ -381,35 +345,9 @@ func (ctl *Control) manager() { switch m := rawMsg.(type) { case *msg.ReqWorkConn: - go ctl.NewWorkConn() + go ctl.HandleReqWorkConn(m) case *msg.NewProxyResp: - // Server will return NewProxyResp message to each NewProxy message. - // Start a new proxy handler if no error got - if m.Error != "" { - ctl.Warn("[%s] start error: %s", m.ProxyName, m.Error) - continue - } - cfg, ok := ctl.getProxyConf(m.ProxyName) - if !ok { - // it will never go to this branch now - ctl.Warn("[%s] no proxy conf found", m.ProxyName) - continue - } - - oldPxy, ok := ctl.getProxy(m.ProxyName) - if ok { - oldPxy.Close() - } - pxy := NewProxy(ctl, cfg) - if err := pxy.Run(); err != nil { - ctl.Warn("[%s] proxy start running error: %v", m.ProxyName, err) - ctl.sendCh <- &msg.CloseProxy{ - ProxyName: m.ProxyName, - } - continue - } - ctl.addProxy(m.ProxyName, pxy) - ctl.Info("[%s] start proxy success", m.ProxyName) + ctl.HandleNewProxyResp(m) case *msg.Pong: ctl.lastPong = time.Now() ctl.Debug("receive heartbeat from server") @@ -419,10 +357,14 @@ func (ctl *Control) manager() { } // controler keep watching closedCh, start a new connection if previous control connection is closed. -// If controler is notified by closedCh, reader and writer and manager will exit, then recall these functions. -func (ctl *Control) controler() { +// If controler is notified by closedCh, reader and writer and handler will exit, then recall these functions. +func (ctl *Control) worker() { + go ctl.msgHandler() + go ctl.writer() + go ctl.reader() + var err error - maxDelayTime := 30 * time.Second + maxDelayTime := 20 * time.Second delayTime := time.Second checkInterval := 10 * time.Second @@ -430,41 +372,20 @@ func (ctl *Control) controler() { for { select { case <-checkProxyTicker.C: - // Every 10 seconds, check which proxy registered failed and reregister it to server. - ctl.mu.RLock() - for _, cfg := range ctl.pxyCfgs { - if _, exist := ctl.proxies[cfg.GetName()]; !exist { - ctl.Info("try to register proxy [%s]", cfg.GetName()) - var newProxyMsg msg.NewProxy - cfg.UnMarshalToMsg(&newProxyMsg) - ctl.sendCh <- &newProxyMsg - } - } - - for _, cfg := range ctl.visitorCfgs { - if _, exist := ctl.visitors[cfg.GetName()]; !exist { - ctl.Info("try to start visitor [%s]", cfg.GetName()) - visitor := NewVisitor(ctl, cfg) - err = visitor.Run() - if err != nil { - visitor.Warn("start error: %v", err) - continue - } - ctl.visitors[cfg.GetName()] = visitor - visitor.Info("start visitor success") - } - } - ctl.mu.RUnlock() + // every 10 seconds, check which proxy registered failed and reregister it to server + ctl.pm.CheckAndStartProxy() case _, ok := <-ctl.closedCh: // we won't get any variable from this channel if !ok { - // close related channels + // close related channels and wait until other goroutines done close(ctl.readCh) - close(ctl.sendCh) + ctl.readerShutdown.WaitDone() + ctl.msgHandlerShutdown.WaitDone() - for _, pxy := range ctl.proxies { - pxy.Close() - } + close(ctl.sendCh) + ctl.writerShutdown.WaitDone() + + ctl.pm.CloseProxies() // if ctl.exit is true, just exit ctl.mu.RLock() exit := ctl.exit @@ -473,9 +394,7 @@ func (ctl *Control) controler() { return } - time.Sleep(time.Second) - - // loop util reconnect to server success + // loop util reconnecting to server success for { ctl.Info("try to reconnect to server...") err = ctl.login() @@ -488,27 +407,27 @@ func (ctl *Control) controler() { } continue } - // reconnect success, init the delayTime + // reconnect success, init delayTime delayTime = time.Second break } // init related channels and variables - ctl.init() + ctl.sendCh = make(chan msg.Message, 10) + ctl.readCh = make(chan msg.Message, 10) + ctl.closedCh = make(chan int) + ctl.readerShutdown = shutdown.New() + ctl.writerShutdown = shutdown.New() + ctl.msgHandlerShutdown = shutdown.New() + ctl.pm.Reset(ctl.sendCh, ctl.runId) // previous work goroutines should be closed and start them here - go ctl.manager() + go ctl.msgHandler() go ctl.writer() go ctl.reader() - // send NewProxy message for all configured proxies - ctl.mu.RLock() - for _, cfg := range ctl.pxyCfgs { - var newProxyMsg msg.NewProxy - cfg.UnMarshalToMsg(&newProxyMsg) - ctl.sendCh <- &newProxyMsg - } - ctl.mu.RUnlock() + // start all configured proxies + ctl.pm.CheckAndStartProxy() checkProxyTicker.Stop() checkProxyTicker = time.NewTicker(checkInterval) @@ -517,106 +436,7 @@ func (ctl *Control) controler() { } } -func (ctl *Control) setRunId(runId string) { - ctl.mu.Lock() - defer ctl.mu.Unlock() - ctl.runId = runId -} - -func (ctl *Control) getRunId() string { - ctl.mu.RLock() - defer ctl.mu.RUnlock() - return ctl.runId -} - -func (ctl *Control) getProxy(name string) (pxy Proxy, ok bool) { - ctl.mu.RLock() - defer ctl.mu.RUnlock() - pxy, ok = ctl.proxies[name] - return -} - -func (ctl *Control) addProxy(name string, pxy Proxy) { - ctl.mu.Lock() - defer ctl.mu.Unlock() - ctl.proxies[name] = pxy -} - -func (ctl *Control) getProxyConf(name string) (conf config.ProxyConf, ok bool) { - ctl.mu.RLock() - defer ctl.mu.RUnlock() - conf, ok = ctl.pxyCfgs[name] - return -} - -func (ctl *Control) reloadConf(pxyCfgs map[string]config.ProxyConf, visitorCfgs map[string]config.ProxyConf) { - ctl.mu.Lock() - defer ctl.mu.Unlock() - - removedPxyNames := make([]string, 0) - for name, oldCfg := range ctl.pxyCfgs { - del := false - cfg, ok := pxyCfgs[name] - if !ok { - del = true - } else { - if !oldCfg.Compare(cfg) { - del = true - } - } - - if del { - removedPxyNames = append(removedPxyNames, name) - delete(ctl.pxyCfgs, name) - if pxy, ok := ctl.proxies[name]; ok { - pxy.Close() - } - delete(ctl.proxies, name) - ctl.sendCh <- &msg.CloseProxy{ - ProxyName: name, - } - } - } - ctl.Info("proxy removed: %v", removedPxyNames) - - addedPxyNames := make([]string, 0) - for name, cfg := range pxyCfgs { - if _, ok := ctl.pxyCfgs[name]; !ok { - ctl.pxyCfgs[name] = cfg - addedPxyNames = append(addedPxyNames, name) - } - } - ctl.Info("proxy added: %v", addedPxyNames) - - removedVisitorName := make([]string, 0) - for name, oldVisitorCfg := range ctl.visitorCfgs { - del := false - cfg, ok := visitorCfgs[name] - if !ok { - del = true - } else { - if !oldVisitorCfg.Compare(cfg) { - del = true - } - } - - if del { - removedVisitorName = append(removedVisitorName, name) - delete(ctl.visitorCfgs, name) - if visitor, ok := ctl.visitors[name]; ok { - visitor.Close() - } - delete(ctl.visitors, name) - } - } - ctl.Info("visitor removed: %v", removedVisitorName) - - addedVisitorName := make([]string, 0) - for name, visitorCfg := range visitorCfgs { - if _, ok := ctl.visitorCfgs[name]; !ok { - ctl.visitorCfgs[name] = visitorCfg - addedVisitorName = append(addedVisitorName, name) - } - } - ctl.Info("visitor added: %v", addedVisitorName) +func (ctl *Control) reloadConf(pxyCfgs map[string]config.ProxyConf, visitorCfgs map[string]config.ProxyConf) error { + err := ctl.pm.Reload(pxyCfgs, visitorCfgs) + return err } diff --git a/client/proxy.go b/client/proxy.go index 0b26bf41..4d07cc63 100644 --- a/client/proxy.go +++ b/client/proxy.go @@ -39,13 +39,13 @@ type Proxy interface { // InWorkConn accept work connections registered to server. InWorkConn(conn frpNet.Conn) + Close() log.Logger } -func NewProxy(ctl *Control, pxyConf config.ProxyConf) (pxy Proxy) { +func NewProxy(pxyConf config.ProxyConf) (pxy Proxy) { baseProxy := BaseProxy{ - ctl: ctl, Logger: log.NewPrefixLogger(pxyConf.GetName()), } switch cfg := pxyConf.(type) { @@ -84,7 +84,6 @@ func NewProxy(ctl *Control, pxyConf config.ProxyConf) (pxy Proxy) { } type BaseProxy struct { - ctl *Control closed bool mu sync.RWMutex log.Logger diff --git a/client/proxy_manager.go b/client/proxy_manager.go new file mode 100644 index 00000000..d756b5ba --- /dev/null +++ b/client/proxy_manager.go @@ -0,0 +1,340 @@ +package client + +import ( + "fmt" + "sync" + + "github.com/fatedier/frp/models/config" + "github.com/fatedier/frp/models/msg" + "github.com/fatedier/frp/utils/errors" + "github.com/fatedier/frp/utils/log" + frpNet "github.com/fatedier/frp/utils/net" +) + +const ( + ProxyStatusNew = "new" + ProxyStatusStartErr = "start error" + ProxyStatusRunning = "running" + ProxyStatusClosed = "closed" +) + +type ProxyManager struct { + ctl *Control + + proxies map[string]*ProxyWrapper + + visitorCfgs map[string]config.ProxyConf + visitors map[string]Visitor + + sendCh chan (msg.Message) + + closed bool + mu sync.RWMutex + + log.Logger +} + +type ProxyWrapper struct { + Name string + Type string + Status string + Err string + Cfg config.ProxyConf + + pxy Proxy + + mu sync.RWMutex +} + +type ProxyStatus struct { + Name string `json:"name"` + Type string `json:"type"` + Status string `json:"status"` + Err string `json:"err"` + Cfg config.ProxyConf `json:"cfg"` +} + +func NewProxyWrapper(cfg config.ProxyConf) *ProxyWrapper { + return &ProxyWrapper{ + Name: cfg.GetName(), + Type: cfg.GetType(), + Status: ProxyStatusNew, + Cfg: cfg, + pxy: nil, + } +} + +func (pw *ProxyWrapper) IsRunning() bool { + pw.mu.RLock() + defer pw.mu.RUnlock() + if pw.Status == ProxyStatusRunning { + return true + } else { + return false + } +} + +func (pw *ProxyWrapper) GetStatus() *ProxyStatus { + pw.mu.RLock() + defer pw.mu.RUnlock() + ps := &ProxyStatus{ + Name: pw.Name, + Type: pw.Type, + Status: pw.Status, + Err: pw.Err, + Cfg: pw.Cfg, + } + return ps +} + +func (pw *ProxyWrapper) Start(serverRespErr string) error { + if pw.pxy != nil { + pw.pxy.Close() + pw.pxy = nil + } + + if serverRespErr != "" { + pw.mu.Lock() + pw.Status = ProxyStatusStartErr + pw.Err = serverRespErr + pw.mu.Unlock() + return fmt.Errorf(serverRespErr) + } + + pxy := NewProxy(pw.Cfg) + pw.mu.Lock() + defer pw.mu.Unlock() + if err := pxy.Run(); err != nil { + pw.Status = ProxyStatusStartErr + pw.Err = err.Error() + return err + } + pw.Status = ProxyStatusRunning + pw.Err = "" + pw.pxy = pxy + return nil +} + +func (pw *ProxyWrapper) InWorkConn(workConn frpNet.Conn) { + pw.mu.RLock() + pxy := pw.pxy + pw.mu.RUnlock() + if pxy != nil { + workConn.Debug("start a new work connection, localAddr: %s remoteAddr: %s", workConn.LocalAddr().String(), workConn.RemoteAddr().String()) + go pxy.InWorkConn(workConn) + } else { + workConn.Close() + } +} + +func (pw *ProxyWrapper) Close() { + pw.mu.Lock() + defer pw.mu.Unlock() + if pw.pxy != nil { + pw.pxy.Close() + pw.pxy = nil + } + pw.Status = ProxyStatusClosed +} + +func NewProxyManager(ctl *Control, msgSendCh chan (msg.Message), logPrefix string) *ProxyManager { + return &ProxyManager{ + proxies: make(map[string]*ProxyWrapper), + visitorCfgs: make(map[string]config.ProxyConf), + visitors: make(map[string]Visitor), + sendCh: msgSendCh, + closed: false, + Logger: log.NewPrefixLogger(logPrefix), + } +} + +func (pm *ProxyManager) Reset(msgSendCh chan (msg.Message), logPrefix string) { + pm.mu.Lock() + defer pm.mu.Unlock() + pm.closed = false + pm.sendCh = msgSendCh + pm.ClearLogPrefix() + pm.AddLogPrefix(logPrefix) +} + +// Must hold the lock before calling this function. +func (pm *ProxyManager) sendMsg(m msg.Message) error { + err := errors.PanicToError(func() { + pm.sendCh <- m + }) + if err != nil { + pm.closed = true + } + return err +} + +func (pm *ProxyManager) StartProxy(name string, serverRespErr string) error { + pm.mu.Lock() + defer pm.mu.Unlock() + if pm.closed { + return fmt.Errorf("ProxyManager is closed now") + } + + pxy, ok := pm.proxies[name] + if !ok { + return fmt.Errorf("no proxy found") + } + + if err := pxy.Start(serverRespErr); err != nil { + errRet := err + err = pm.sendMsg(&msg.CloseProxy{ + ProxyName: name, + }) + if err != nil { + errRet = fmt.Errorf("send CloseProxy message error") + } + return errRet + } + return nil +} + +func (pm *ProxyManager) CloseProxies() { + pm.mu.RLock() + defer pm.mu.RUnlock() + for _, pxy := range pm.proxies { + pxy.Close() + } +} + +func (pm *ProxyManager) CheckAndStartProxy() { + pm.mu.RLock() + defer pm.mu.RUnlock() + if pm.closed { + pm.Warn("CheckAndStartProxy error: ProxyManager is closed now") + return + } + + for _, pxy := range pm.proxies { + if !pxy.IsRunning() { + var newProxyMsg msg.NewProxy + pxy.Cfg.UnMarshalToMsg(&newProxyMsg) + err := pm.sendMsg(&newProxyMsg) + if err != nil { + pm.Warn("[%s] proxy send NewProxy message error") + return + } + } + } + + for _, cfg := range pm.visitorCfgs { + if _, exist := pm.visitors[cfg.GetName()]; !exist { + pm.Info("try to start visitor [%s]", cfg.GetName()) + visitor := NewVisitor(pm.ctl, cfg) + err := visitor.Run() + if err != nil { + visitor.Warn("start error: %v", err) + continue + } + pm.visitors[cfg.GetName()] = visitor + visitor.Info("start visitor success") + } + } +} + +func (pm *ProxyManager) Reload(pxyCfgs map[string]config.ProxyConf, visitorCfgs map[string]config.ProxyConf) error { + pm.mu.Lock() + defer pm.mu.Unlock() + if pm.closed { + err := fmt.Errorf("Reload error: ProxyManager is closed now") + pm.Warn(err.Error()) + return err + } + + delPxyNames := make([]string, 0) + for name, pxy := range pm.proxies { + del := false + cfg, ok := pxyCfgs[name] + if !ok { + del = true + } else { + if !pxy.Cfg.Compare(cfg) { + del = true + } + } + + if del { + delPxyNames = append(delPxyNames, name) + delete(pm.proxies, name) + + pxy.Close() + err := pm.sendMsg(&msg.CloseProxy{ + ProxyName: name, + }) + if err != nil { + err = fmt.Errorf("Reload error: ProxyManager is closed now") + pm.Warn(err.Error()) + return err + } + } + } + pm.Info("proxy removed: %v", delPxyNames) + + addPxyNames := make([]string, 0) + for name, cfg := range pxyCfgs { + if _, ok := pm.proxies[name]; !ok { + pxy := NewProxyWrapper(cfg) + pm.proxies[name] = pxy + addPxyNames = append(addPxyNames, name) + } + } + pm.Info("proxy added: %v", addPxyNames) + + delVisitorName := make([]string, 0) + for name, oldVisitorCfg := range pm.visitorCfgs { + del := false + cfg, ok := visitorCfgs[name] + if !ok { + del = true + } else { + if !oldVisitorCfg.Compare(cfg) { + del = true + } + } + + if del { + delVisitorName = append(delVisitorName, name) + delete(pm.visitorCfgs, name) + if visitor, ok := pm.visitors[name]; ok { + visitor.Close() + } + delete(pm.visitors, name) + } + } + pm.Info("visitor removed: %v", delVisitorName) + + addVisitorName := make([]string, 0) + for name, visitorCfg := range visitorCfgs { + if _, ok := pm.visitorCfgs[name]; !ok { + pm.visitorCfgs[name] = visitorCfg + addVisitorName = append(addVisitorName, name) + } + } + pm.Info("visitor added: %v", addVisitorName) + return nil +} + +func (pm *ProxyManager) HandleWorkConn(name string, workConn frpNet.Conn) { + pm.mu.RLock() + pw, ok := pm.proxies[name] + pm.mu.RUnlock() + if ok { + pw.InWorkConn(workConn) + } else { + workConn.Close() + } +} + +func (pm *ProxyManager) GetAllProxyStatus() []*ProxyStatus { + ps := make([]*ProxyStatus, 0) + pm.mu.RLock() + defer pm.mu.RUnlock() + for _, pxy := range pm.proxies { + ps = append(ps, pxy.GetStatus()) + } + return ps +} diff --git a/client/service.go b/client/service.go index 49c78486..c5a2f1e4 100644 --- a/client/service.go +++ b/client/service.go @@ -53,6 +53,6 @@ func (svr *Service) Run() error { return nil } -func (svr *Service) Close() error { - return svr.ctl.Close() +func (svr *Service) Close() { + svr.ctl.Close() } diff --git a/cmd/frpc/main.go b/cmd/frpc/main.go index f0d438f8..f1836db2 100644 --- a/cmd/frpc/main.go +++ b/cmd/frpc/main.go @@ -28,6 +28,7 @@ import ( "time" docopt "github.com/docopt/docopt-go" + "github.com/rodaine/table" ini "github.com/vaughan0/go-ini" "github.com/fatedier/frp/client" @@ -44,7 +45,8 @@ var usage string = `frpc is the client of frp Usage: frpc [-c config_file] [-L log_file] [--log-level=] [--server-addr=] - frpc [-c config_file] --reload + frpc reload [-c config_file] + frpc status [-c config_file] frpc -h | --help frpc -v | --version @@ -53,7 +55,6 @@ Options: -L log_file set output log file, including console --log-level= set log level: debug, info, warn, error --server-addr= addr which frps is listening for, example: 0.0.0.0:7000 - --reload reload configure file without program exit -h --help show this screen -v --version show version ` @@ -82,40 +83,25 @@ func main() { config.ClientCommonCfg.ConfigFile = confFile // check if reload command - if args["--reload"] != nil { - if args["--reload"].(bool) { - req, err := http.NewRequest("GET", "http://"+ - config.ClientCommonCfg.AdminAddr+":"+fmt.Sprintf("%d", config.ClientCommonCfg.AdminPort)+"/api/reload", nil) - if err != nil { + if args["reload"] != nil { + if args["reload"].(bool) { + if err = CmdReload(); err != nil { fmt.Printf("frps reload error: %v\n", err) os.Exit(1) + } else { + fmt.Printf("reload success\n") + os.Exit(0) } + } + } - authStr := "Basic " + base64.StdEncoding.EncodeToString([]byte(config.ClientCommonCfg.AdminUser+":"+ - config.ClientCommonCfg.AdminPwd)) - - req.Header.Add("Authorization", authStr) - resp, err := http.DefaultClient.Do(req) - if err != nil { - fmt.Printf("frpc reload error: %v\n", err) + // check if status command + if args["status"] != nil { + if args["status"].(bool) { + if err = CmdStatus(); err != nil { + fmt.Println("frps get status error: %v\n", err) os.Exit(1) } else { - defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - fmt.Printf("frpc reload error: %v\n", err) - os.Exit(1) - } - res := &client.GeneralResponse{} - err = json.Unmarshal(body, &res) - if err != nil { - fmt.Printf("http response error: %s\n", strings.TrimSpace(string(body))) - os.Exit(1) - } else if res.Code != 0 { - fmt.Printf("reload error: %s\n", res.Msg) - os.Exit(1) - } - fmt.Printf("reload success\n") os.Exit(0) } } @@ -187,3 +173,133 @@ func HandleSignal(svr *client.Service) { time.Sleep(250 * time.Millisecond) os.Exit(0) } + +func CmdReload() error { + if config.ClientCommonCfg.AdminPort == 0 { + return fmt.Errorf("admin_port shoud be set if you want to use reload feature") + } + + req, err := http.NewRequest("GET", "http://"+ + config.ClientCommonCfg.AdminAddr+":"+fmt.Sprintf("%d", config.ClientCommonCfg.AdminPort)+"/api/reload", nil) + if err != nil { + return err + } + + authStr := "Basic " + base64.StdEncoding.EncodeToString([]byte(config.ClientCommonCfg.AdminUser+":"+ + config.ClientCommonCfg.AdminPwd)) + + req.Header.Add("Authorization", authStr) + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } else { + if resp.StatusCode != 200 { + return fmt.Errorf("admin api status code [%d]", resp.StatusCode) + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return err + } + res := &client.GeneralResponse{} + err = json.Unmarshal(body, &res) + if err != nil { + return fmt.Errorf("unmarshal http response error: %s", strings.TrimSpace(string(body))) + } else if res.Code != 0 { + return fmt.Errorf(res.Msg) + } + } + return nil +} + +func CmdStatus() error { + if config.ClientCommonCfg.AdminPort == 0 { + return fmt.Errorf("admin_port shoud be set if you want to get proxy status") + } + + req, err := http.NewRequest("GET", "http://"+ + config.ClientCommonCfg.AdminAddr+":"+fmt.Sprintf("%d", config.ClientCommonCfg.AdminPort)+"/api/status", nil) + if err != nil { + return err + } + + authStr := "Basic " + base64.StdEncoding.EncodeToString([]byte(config.ClientCommonCfg.AdminUser+":"+ + config.ClientCommonCfg.AdminPwd)) + + req.Header.Add("Authorization", authStr) + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } else { + if resp.StatusCode != 200 { + return fmt.Errorf("admin api status code [%d]", resp.StatusCode) + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return err + } + res := &client.StatusResp{} + err = json.Unmarshal(body, &res) + if err != nil { + return fmt.Errorf("unmarshal http response error: %s", strings.TrimSpace(string(body))) + } + + fmt.Println("Proxy Status...") + if len(res.Tcp) > 0 { + fmt.Printf("TCP") + tbl := table.New("Name", "Status", "LocalAddr", "Plugin", "RemoteAddr", "Error") + for _, ps := range res.Tcp { + tbl.AddRow(ps.Name, ps.Status, ps.LocalAddr, ps.Plugin, ps.RemoteAddr, ps.Err) + } + tbl.Print() + fmt.Println("") + } + if len(res.Udp) > 0 { + fmt.Printf("UDP") + tbl := table.New("Name", "Status", "LocalAddr", "Plugin", "RemoteAddr", "Error") + for _, ps := range res.Udp { + tbl.AddRow(ps.Name, ps.Status, ps.LocalAddr, ps.Plugin, ps.RemoteAddr, ps.Err) + } + tbl.Print() + fmt.Println("") + } + if len(res.Http) > 0 { + fmt.Printf("HTTP") + tbl := table.New("Name", "Status", "LocalAddr", "Plugin", "RemoteAddr", "Error") + for _, ps := range res.Http { + tbl.AddRow(ps.Name, ps.Status, ps.LocalAddr, ps.Plugin, ps.RemoteAddr, ps.Err) + } + tbl.Print() + fmt.Println("") + } + if len(res.Https) > 0 { + fmt.Printf("HTTPS") + tbl := table.New("Name", "Status", "LocalAddr", "Plugin", "RemoteAddr", "Error") + for _, ps := range res.Https { + tbl.AddRow(ps.Name, ps.Status, ps.LocalAddr, ps.Plugin, ps.RemoteAddr, ps.Err) + } + tbl.Print() + fmt.Println("") + } + if len(res.Stcp) > 0 { + fmt.Printf("STCP") + tbl := table.New("Name", "Status", "LocalAddr", "Plugin", "RemoteAddr", "Error") + for _, ps := range res.Stcp { + tbl.AddRow(ps.Name, ps.Status, ps.LocalAddr, ps.Plugin, ps.RemoteAddr, ps.Err) + } + tbl.Print() + fmt.Println("") + } + if len(res.Xtcp) > 0 { + fmt.Printf("XTCP") + tbl := table.New("Name", "Status", "LocalAddr", "Plugin", "RemoteAddr", "Error") + for _, ps := range res.Xtcp { + tbl.AddRow(ps.Name, ps.Status, ps.LocalAddr, ps.Plugin, ps.RemoteAddr, ps.Err) + } + tbl.Print() + fmt.Println("") + } + } + return nil +} diff --git a/models/config/proxy.go b/models/config/proxy.go index ce4c1c2d..e87b7eca 100644 --- a/models/config/proxy.go +++ b/models/config/proxy.go @@ -52,6 +52,7 @@ func NewConfByType(proxyType string) ProxyConf { type ProxyConf interface { GetName() string + GetType() string GetBaseInfo() *BaseProxyConf LoadFromMsg(pMsg *msg.NewProxy) LoadFromFile(name string, conf ini.Section) error @@ -103,6 +104,10 @@ func (cfg *BaseProxyConf) GetName() string { return cfg.ProxyName } +func (cfg *BaseProxyConf) GetType() string { + return cfg.ProxyType +} + func (cfg *BaseProxyConf) GetBaseInfo() *BaseProxyConf { return cfg } diff --git a/server/control.go b/server/control.go index 2833277b..5e8fe95e 100644 --- a/server/control.go +++ b/server/control.go @@ -253,13 +253,13 @@ func (ctl *Control) stoper() { ctl.allShutdown.WaitStart() close(ctl.readCh) - ctl.managerShutdown.WaitDown() + ctl.managerShutdown.WaitDone() close(ctl.sendCh) - ctl.writerShutdown.WaitDown() + ctl.writerShutdown.WaitDone() ctl.conn.Close() - ctl.readerShutdown.WaitDown() + ctl.readerShutdown.WaitDone() close(ctl.workConnCh) for workConn := range ctl.workConnCh { diff --git a/server/service.go b/server/service.go index 5997f3dc..a510b179 100644 --- a/server/service.go +++ b/server/service.go @@ -283,7 +283,7 @@ func (svr *Service) RegisterControl(ctlConn frpNet.Conn, loginMsg *msg.Login) (e ctl := NewControl(svr, ctlConn, loginMsg) if oldCtl := svr.ctlManager.Add(loginMsg.RunId, ctl); oldCtl != nil { - oldCtl.allShutdown.WaitDown() + oldCtl.allShutdown.WaitDone() } ctlConn.AddLogPrefix(loginMsg.RunId) diff --git a/utils/shutdown/shutdown.go b/utils/shutdown/shutdown.go index cdd87268..7fd7bfcc 100644 --- a/utils/shutdown/shutdown.go +++ b/utils/shutdown/shutdown.go @@ -19,19 +19,19 @@ import ( ) type Shutdown struct { - doing bool - ending bool - start chan struct{} - down chan struct{} - mu sync.Mutex + doing bool + ending bool + startCh chan struct{} + doneCh chan struct{} + mu sync.Mutex } func New() *Shutdown { return &Shutdown{ - doing: false, - ending: false, - start: make(chan struct{}), - down: make(chan struct{}), + doing: false, + ending: false, + startCh: make(chan struct{}), + doneCh: make(chan struct{}), } } @@ -40,12 +40,12 @@ func (s *Shutdown) Start() { defer s.mu.Unlock() if !s.doing { s.doing = true - close(s.start) + close(s.startCh) } } func (s *Shutdown) WaitStart() { - <-s.start + <-s.startCh } func (s *Shutdown) Done() { @@ -53,10 +53,10 @@ func (s *Shutdown) Done() { defer s.mu.Unlock() if !s.ending { s.ending = true - close(s.down) + close(s.doneCh) } } -func (s *Shutdown) WaitDown() { - <-s.down +func (s *Shutdown) WaitDone() { + <-s.doneCh } From afde0c515cb00454bdfe55ed2095a0f7dea814ab Mon Sep 17 00:00:00 2001 From: fatedier Date: Wed, 17 Jan 2018 01:15:34 +0800 Subject: [PATCH 4/7] packages: add package github.com/rodaine/table --- glide.lock | 6 +- glide.yaml | 2 + vendor/github.com/rodaine/table/.travis.yml | 10 + vendor/github.com/rodaine/table/license | 9 + vendor/github.com/rodaine/table/makefile | 9 + vendor/github.com/rodaine/table/readme.md | 61 ++++ vendor/github.com/rodaine/table/table.go | 267 ++++++++++++++++++ vendor/github.com/rodaine/table/table_test.go | 181 ++++++++++++ 8 files changed, 543 insertions(+), 2 deletions(-) create mode 100644 vendor/github.com/rodaine/table/.travis.yml create mode 100644 vendor/github.com/rodaine/table/license create mode 100644 vendor/github.com/rodaine/table/makefile create mode 100644 vendor/github.com/rodaine/table/readme.md create mode 100644 vendor/github.com/rodaine/table/table.go create mode 100644 vendor/github.com/rodaine/table/table_test.go diff --git a/glide.lock b/glide.lock index f09731f7..30c2faeb 100644 --- a/glide.lock +++ b/glide.lock @@ -1,5 +1,5 @@ -hash: 03ff8b71f63e9038c0182a4ef2a55aa9349782f4813c331e2d1f02f3dd15b4f8 -updated: 2017-11-01T16:16:18.577622991+08:00 +hash: 188e1149e415ff9cefab8db2cded030efae57558a0b9551795c5c7d0b0572a7b +updated: 2018-01-17T01:14:34.435613+08:00 imports: - name: github.com/armon/go-socks5 version: e75332964ef517daa070d7c38a9466a0d687e0a5 @@ -33,6 +33,8 @@ imports: version: 274df120e9065bdd08eb1120e0375e3dc1ae8465 subpackages: - fs +- name: github.com/rodaine/table + version: 212a2ad1c462ed4d5b5511ea2b480a573281dbbd - name: github.com/stretchr/testify version: 2402e8e7a02fc811447d11f881aa9746cdc57983 subpackages: diff --git a/glide.yaml b/glide.yaml index d69c3f95..f5e0f561 100644 --- a/glide.yaml +++ b/glide.yaml @@ -71,3 +71,5 @@ import: - internal/iana - internal/socket - ipv4 +- package: github.com/rodaine/table + version: v1.0.0 diff --git a/vendor/github.com/rodaine/table/.travis.yml b/vendor/github.com/rodaine/table/.travis.yml new file mode 100644 index 00000000..65da59c0 --- /dev/null +++ b/vendor/github.com/rodaine/table/.travis.yml @@ -0,0 +1,10 @@ +sudo: false +language: go +go: 1.8 + +branches: + only: + - master + +install: go get -t ./... github.com/golang/lint/golint +script: make lint test diff --git a/vendor/github.com/rodaine/table/license b/vendor/github.com/rodaine/table/license new file mode 100644 index 00000000..4a1a5779 --- /dev/null +++ b/vendor/github.com/rodaine/table/license @@ -0,0 +1,9 @@ +The MIT License (MIT) + +Copyright (c) 2015 Chris Roche (rodaine+github@gmail.com) + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/rodaine/table/makefile b/vendor/github.com/rodaine/table/makefile new file mode 100644 index 00000000..56a42885 --- /dev/null +++ b/vendor/github.com/rodaine/table/makefile @@ -0,0 +1,9 @@ +.PHONY: lint +lint: + gofmt -d -s . + golint -set_exit_status ./... + go tool vet -all -shadow -shadowstrict . + +.PHONY: test +test: + go test -v -cover -race ./... diff --git a/vendor/github.com/rodaine/table/readme.md b/vendor/github.com/rodaine/table/readme.md new file mode 100644 index 00000000..7905afff --- /dev/null +++ b/vendor/github.com/rodaine/table/readme.md @@ -0,0 +1,61 @@ +# table
[![GoDoc](https://godoc.org/github.com/rodaine/table?status.svg)](https://godoc.org/github.com/rodaine/table) [![Build Status](https://travis-ci.org/rodaine/table.svg)](https://travis-ci.org/rodaine/table) + +![Example Table Output With ANSI Colors](http://res.cloudinary.com/rodaine/image/upload/v1442524799/go-table-example0.png) + +Package table provides a convenient way to generate tabular output of any data, primarily useful for CLI tools. + +## Features + +- Accepts all data types (`string`, `int`, `interface{}`, everything!) and will use the `String() string` method of a type if available. +- Can specify custom formatting for the header and first column cells for better readability. +- Columns are left-aligned and sized to fit the data, with customizable padding. +- The printed output can be sent to any `io.Writer`, defaulting to `os.Stdout`. +- Built to an interface, so you can roll your own `Table` implementation. +- Works well with ANSI colors ([fatih/color](https://github.com/fatih/color) in the example)! +- Can provide a custom `WidthFunc` to accomodate multi- and zero-width characters (such as [runewidth](https://github.com/mattn/go-runewidth)) + +## Usage + +**Download the package:** + +```sh +go get -u github.com/rodaine/table +``` + +**Example:** + +```go +package main + +import ( + "fmt" + "strings" + + "github.com/fatih/color" + "github.com/rodaine/table" +) + +func main() { + headerFmt := color.New(color.FgGreen, color.Underline).SprintfFunc() + columnFmt := color.New(color.FgYellow).SprintfFunc() + + tbl := table.New("ID", "Name", "Score", "Added") + tbl.WithHeaderFormatter(headerFmt).WithFirstColumnFormatter(columnFmt) + + for _, widget := range getWidgets() { + tbl.AddRow(widget.ID, widget.Name, widget.Cost, widget.Added) + } + + tbl.Print() +} +``` + +_Consult the [documentation](https://godoc.org/github.com/rodaine/table) for further examples and usage information_ + +## Contributing + +Please feel free to submit an [issue](https://github.com/rodaine/table/issues) or [PR](https://github.com/rodaine/table/pulls) to this repository for features or bugs. All submitted code must pass the scripts specified within [.travis.yml](https://github.com/rodaine/table/blob/master/.travis.yml) and should include tests to back up the changes. + +## License + +table is released under the MIT License (Expat). See the [full license](https://github.com/rodaine/table/blob/master/license). diff --git a/vendor/github.com/rodaine/table/table.go b/vendor/github.com/rodaine/table/table.go new file mode 100644 index 00000000..76f2e63e --- /dev/null +++ b/vendor/github.com/rodaine/table/table.go @@ -0,0 +1,267 @@ +// Package table provides a convenient way to generate tabular output of any +// data, primarily useful for CLI tools. +// +// Columns are left-aligned and padded to accomodate the largest cell in that +// column. +// +// Source: https://github.com/rodaine/table +// +// table.DefaultHeaderFormatter = func(format string, vals ...interface{}) string { +// return strings.ToUpper(fmt.Sprintf(format, vals...)) +// } +// +// tbl := table.New("ID", "Name", "Cost ($)") +// +// for _, widget := range Widgets { +// tbl.AddRow(widget.ID, widget.Name, widget.Cost) +// } +// +// tbl.Print() +// +// // Output: +// // ID NAME COST ($) +// // 1 Foobar 1.23 +// // 2 Fizzbuzz 4.56 +// // 3 Gizmo 78.90 +package table + +import ( + "fmt" + "io" + "os" + "strings" + "unicode/utf8" +) + +// These are the default properties for all Tables created from this package +// and can be modified. +var ( + // DefaultPadding specifies the number of spaces between columns in a table. + DefaultPadding = 2 + + // DefaultWriter specifies the output io.Writer for the Table.Print method. + DefaultWriter io.Writer = os.Stdout + + // DefaultHeaderFormatter specifies the default Formatter for the table header. + DefaultHeaderFormatter Formatter + + // DefaultFirstColumnFormatter specifies the default Formatter for the first column cells. + DefaultFirstColumnFormatter Formatter + + // DefaultWidthFunc specifies the default WidthFunc for calculating column widths + DefaultWidthFunc WidthFunc = utf8.RuneCountInString +) + +// Formatter functions expose a fmt.Sprintf signature that can be used to modify +// the display of the text in either the header or first column of a Table. +// The formatter should not change the width of original text as printed since +// column widths are calculated pre-formatting (though this issue can be mitigated +// with increased padding). +// +// tbl.WithHeaderFormatter(func(format string, vals ...interface{}) string { +// return strings.ToUpper(fmt.Sprintf(format, vals...)) +// }) +// +// A good use case for formatters is to use ANSI escape codes to color the cells +// for a nicer interface. The package color (https://github.com/fatih/color) makes +// it easy to generate these automatically: http://godoc.org/github.com/fatih/color#Color.SprintfFunc +type Formatter func(string, ...interface{}) string + +// A WidthFunc calculates the width of a string. By default, the number of runes +// is used but this may not be appropriate for certain character sets. The +// package runewidth (https://github.com/mattn/go-runewidth) could be used to +// accomodate multi-cell characters (such as emoji or CJK characters). +type WidthFunc func(string) int + +// Table describes the interface for building up a tabular representation of data. +// It exposes fluent/chainable methods for convenient table building. +// +// WithHeaderFormatter and WithFirstColumnFormatter sets the Formatter for the +// header and first column, respectively. If nil is passed in (the default), no +// formatting will be applied. +// +// New("foo", "bar").WithFirstColumnFormatter(func(f string, v ...interface{}) string { +// return strings.ToUpper(fmt.Sprintf(f, v...)) +// }) +// +// WithPadding specifies the minimum padding between cells in a row and defaults +// to DefaultPadding. Padding values less than or equal to zero apply no extra +// padding between the columns. +// +// New("foo", "bar").WithPadding(3) +// +// WithWriter modifies the writer which Print outputs to, defaulting to DefaultWriter +// when instantiated. If nil is passed, os.Stdout will be used. +// +// New("foo", "bar").WithWriter(os.Stderr) +// +// WithWidthFunc sets the function used to calculate the width of the string in +// a column. By default, the number of utf8 runes in the string is used. +// +// AddRow adds another row of data to the table. Any values can be passed in and +// will be output as its string representation as described in the fmt standard +// package. Rows can have less cells than the total number of columns in the table; +// subsequent cells will be rendered empty. Rows with more cells than the total +// number of columns will be truncated. References to the data are not held, so +// the passed in values can be modified without affecting the table's output. +// +// New("foo", "bar").AddRow("fizz", "buzz").AddRow(time.Now()).AddRow(1, 2, 3).Print() +// // Output: +// // foo bar +// // fizz buzz +// // 2006-01-02 15:04:05.0 -0700 MST +// // 1 2 +// +// Print writes the string representation of the table to the provided writer. +// Print can be called multiple times, even after subsequent mutations of the +// provided data. The output is always preceded and followed by a new line. +type Table interface { + WithHeaderFormatter(f Formatter) Table + WithFirstColumnFormatter(f Formatter) Table + WithPadding(p int) Table + WithWriter(w io.Writer) Table + WithWidthFunc(f WidthFunc) Table + + AddRow(vals ...interface{}) Table + Print() +} + +// New creates a Table instance with the specified header(s) provided. The number +// of columns is fixed at this point to len(columnHeaders) and the defined defaults +// are set on the instance. +func New(columnHeaders ...interface{}) Table { + t := table{header: make([]string, len(columnHeaders))} + + t.WithPadding(DefaultPadding) + t.WithWriter(DefaultWriter) + t.WithHeaderFormatter(DefaultHeaderFormatter) + t.WithFirstColumnFormatter(DefaultFirstColumnFormatter) + t.WithWidthFunc(DefaultWidthFunc) + + for i, col := range columnHeaders { + t.header[i] = fmt.Sprint(col) + } + + return &t +} + +type table struct { + FirstColumnFormatter Formatter + HeaderFormatter Formatter + Padding int + Writer io.Writer + Width WidthFunc + + header []string + rows [][]string + widths []int +} + +func (t *table) WithHeaderFormatter(f Formatter) Table { + t.HeaderFormatter = f + return t +} + +func (t *table) WithFirstColumnFormatter(f Formatter) Table { + t.FirstColumnFormatter = f + return t +} + +func (t *table) WithPadding(p int) Table { + if p < 0 { + p = 0 + } + + t.Padding = p + return t +} + +func (t *table) WithWriter(w io.Writer) Table { + if w == nil { + w = os.Stdout + } + + t.Writer = w + return t +} + +func (t *table) WithWidthFunc(f WidthFunc) Table { + t.Width = f + return t +} + +func (t *table) AddRow(vals ...interface{}) Table { + row := make([]string, len(t.header)) + for i, val := range vals { + if i >= len(t.header) { + break + } + row[i] = fmt.Sprint(val) + } + t.rows = append(t.rows, row) + + return t +} + +func (t *table) Print() { + format := strings.Repeat("%s", len(t.header)) + "\n" + t.calculateWidths() + fmt.Fprintln(t.Writer) + t.printHeader(format) + for _, row := range t.rows { + t.printRow(format, row) + } +} + +func (t *table) printHeader(format string) { + vals := t.applyWidths(t.header, t.widths) + if t.HeaderFormatter != nil { + txt := t.HeaderFormatter(format, vals...) + fmt.Fprint(t.Writer, txt) + } else { + fmt.Fprintf(t.Writer, format, vals...) + } +} + +func (t *table) printRow(format string, row []string) { + vals := t.applyWidths(row, t.widths) + + if t.FirstColumnFormatter != nil { + vals[0] = t.FirstColumnFormatter("%s", vals[0]) + } + + fmt.Fprintf(t.Writer, format, vals...) +} + +func (t *table) calculateWidths() { + t.widths = make([]int, len(t.header)) + for _, row := range t.rows { + for i, v := range row { + if w := t.Width(v) + t.Padding; w > t.widths[i] { + t.widths[i] = w + } + } + } + + for i, v := range t.header { + if w := t.Width(v) + t.Padding; w > t.widths[i] { + t.widths[i] = w + } + } +} + +func (t *table) applyWidths(row []string, widths []int) []interface{} { + out := make([]interface{}, len(row)) + for i, s := range row { + out[i] = s + t.lenOffset(s, widths[i]) + } + return out +} + +func (t *table) lenOffset(s string, w int) string { + l := w - t.Width(s) + if l <= 0 { + return "" + } + return strings.Repeat(" ", l) +} diff --git a/vendor/github.com/rodaine/table/table_test.go b/vendor/github.com/rodaine/table/table_test.go new file mode 100644 index 00000000..aebe8bcc --- /dev/null +++ b/vendor/github.com/rodaine/table/table_test.go @@ -0,0 +1,181 @@ +package table + +import ( + "bytes" + "fmt" + "io/ioutil" + "os" + "strings" + "testing" + + "github.com/mattn/go-runewidth" + "github.com/stretchr/testify/assert" +) + +func TestFormatter(t *testing.T) { + t.Parallel() + + var formatter Formatter + + fn := func(a string, b ...interface{}) string { return "" } + f := Formatter(fn) + + assert.IsType(t, formatter, f) +} + +func TestTable_New(t *testing.T) { + t.Parallel() + + buf := bytes.Buffer{} + New("foo", "bar").WithWriter(&buf).Print() + out := buf.String() + + assert.Contains(t, out, "foo") + assert.Contains(t, out, "bar") + + buf.Reset() + New().WithWriter(&buf).Print() + out = buf.String() + + assert.Empty(t, strings.TrimSpace(out)) +} + +func TestTable_WithHeaderFormatter(t *testing.T) { + t.Parallel() + + uppercase := func(f string, v ...interface{}) string { + return strings.ToUpper(fmt.Sprintf(f, v...)) + } + buf := bytes.Buffer{} + + tbl := New("foo", "bar").WithWriter(&buf).WithHeaderFormatter(uppercase) + tbl.Print() + out := buf.String() + + assert.Contains(t, out, "FOO") + assert.Contains(t, out, "BAR") + + buf.Reset() + tbl.WithHeaderFormatter(nil).Print() + out = buf.String() + + assert.Contains(t, out, "foo") + assert.Contains(t, out, "bar") +} + +func TestTable_WithFirstColumnFormatter(t *testing.T) { + t.Parallel() + + uppercase := func(f string, v ...interface{}) string { + return strings.ToUpper(fmt.Sprintf(f, v...)) + } + + buf := bytes.Buffer{} + + tbl := New("foo", "bar").WithWriter(&buf).WithFirstColumnFormatter(uppercase).AddRow("fizz", "buzz") + tbl.Print() + out := buf.String() + + assert.Contains(t, out, "foo") + assert.Contains(t, out, "bar") + assert.Contains(t, out, "FIZZ") + assert.Contains(t, out, "buzz") + + buf.Reset() + tbl.WithFirstColumnFormatter(nil).Print() + out = buf.String() + + assert.Contains(t, out, "fizz") + assert.Contains(t, out, "buzz") +} + +func TestTable_WithPadding(t *testing.T) { + t.Parallel() + + // zero value + buf := bytes.Buffer{} + tbl := New("foo", "bar").WithWriter(&buf).WithPadding(0) + tbl.Print() + out := buf.String() + assert.Contains(t, out, "foobar") + + // positive value + buf.Reset() + tbl.WithPadding(4).Print() + out = buf.String() + assert.Contains(t, out, "foo bar ") + + // negative value + buf.Reset() + tbl.WithPadding(-1).Print() + out = buf.String() + assert.Contains(t, out, "foobar") +} + +func TestTable_WithWriter(t *testing.T) { + t.Parallel() + + // not that we haven't been using it in all these tests but: + buf := bytes.Buffer{} + New("foo", "bar").WithWriter(&buf).Print() + assert.NotEmpty(t, buf.String()) + + stdout := os.Stdout + temp, _ := ioutil.TempFile("", "") + os.Stdout = temp + defer func() { + os.Stdout = stdout + temp.Close() + }() + + New("foo", "bar").WithWriter(nil).Print() + temp.Seek(0, 0) + + out, _ := ioutil.ReadAll(temp) + assert.NotEmpty(t, out) +} + +func TestTable_AddRow(t *testing.T) { + t.Parallel() + + buf := bytes.Buffer{} + tbl := New("foo", "bar").WithWriter(&buf).AddRow("fizz", "buzz") + tbl.Print() + out := buf.String() + assert.Contains(t, out, "fizz") + assert.Contains(t, out, "buzz") + lines := strings.Count(out, "\n") + + // empty should add empty line + buf.Reset() + tbl.AddRow().Print() + assert.Equal(t, lines+1, strings.Count(buf.String(), "\n")) + + // less than one will fill left-to-right + buf.Reset() + tbl.AddRow("cat").Print() + assert.Contains(t, buf.String(), "\ncat") + + // more than initial length are truncated + buf.Reset() + tbl.AddRow("bippity", "boppity", "boo").Print() + assert.NotContains(t, buf.String(), "boo") +} + +func TestTable_WithWidthFunc(t *testing.T) { + t.Parallel() + + buf := bytes.Buffer{} + + New("", ""). + WithWriter(&buf). + WithPadding(1). + WithWidthFunc(runewidth.StringWidth). + AddRow("请求", "alpha"). + AddRow("abc", "beta"). + Print() + + actual := buf.String() + assert.Contains(t, actual, "请求 alpha") + assert.Contains(t, actual, "abc beta") +} From 9a5f0c23c4246131aedebfe180487a2f9cbab5c5 Mon Sep 17 00:00:00 2001 From: fatedier Date: Wed, 17 Jan 2018 01:18:40 +0800 Subject: [PATCH 5/7] fix ci --- utils/shutdown/shutdown_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/shutdown/shutdown_test.go b/utils/shutdown/shutdown_test.go index d1cc389f..bd6507a5 100644 --- a/utils/shutdown/shutdown_test.go +++ b/utils/shutdown/shutdown_test.go @@ -17,5 +17,5 @@ func TestShutdown(t *testing.T) { time.Sleep(time.Millisecond) s.Done() }() - s.WaitDown() + s.WaitDone() } From 3f6799c06a0fe79e77acb7c7bd7cce25cf5698f0 Mon Sep 17 00:00:00 2001 From: fatedier Date: Wed, 17 Jan 2018 14:40:08 +0800 Subject: [PATCH 6/7] add remoteAddr in NewProxyResp message --- client/admin_api.go | 6 ++- client/control.go | 2 +- client/proxy_manager.go | 25 +++++++---- models/msg/msg.go | 5 ++- server/control.go | 17 +++---- server/proxy.go | 98 +++++++++++++++++++++++++---------------- utils/util/util.go | 9 ++++ 7 files changed, 102 insertions(+), 60 deletions(-) diff --git a/client/admin_api.go b/client/admin_api.go index fae1737e..3c64b917 100644 --- a/client/admin_api.go +++ b/client/admin_api.go @@ -124,22 +124,24 @@ func NewProxyStatusResp(status *ProxyStatus) ProxyStatusResp { psr.LocalAddr = fmt.Sprintf("%s:%d", cfg.LocalIp, cfg.LocalPort) } psr.Plugin = cfg.Plugin - psr.RemoteAddr = fmt.Sprintf(":%d", cfg.RemotePort) + psr.RemoteAddr = config.ClientCommonCfg.ServerAddr + status.RemoteAddr case *config.UdpProxyConf: if cfg.LocalPort != 0 { psr.LocalAddr = fmt.Sprintf("%s:%d", cfg.LocalIp, cfg.LocalPort) } - psr.RemoteAddr = fmt.Sprintf(":%d", cfg.RemotePort) + psr.RemoteAddr = config.ClientCommonCfg.ServerAddr + status.RemoteAddr case *config.HttpProxyConf: if cfg.LocalPort != 0 { psr.LocalAddr = fmt.Sprintf("%s:%d", cfg.LocalIp, cfg.LocalPort) } psr.Plugin = cfg.Plugin + psr.RemoteAddr = status.RemoteAddr case *config.HttpsProxyConf: if cfg.LocalPort != 0 { psr.LocalAddr = fmt.Sprintf("%s:%d", cfg.LocalIp, cfg.LocalPort) } psr.Plugin = cfg.Plugin + psr.RemoteAddr = status.RemoteAddr case *config.StcpProxyConf: if cfg.LocalPort != 0 { psr.LocalAddr = fmt.Sprintf("%s:%d", cfg.LocalIp, cfg.LocalPort) diff --git a/client/control.go b/client/control.go index 65bec393..a4bb9e09 100644 --- a/client/control.go +++ b/client/control.go @@ -158,7 +158,7 @@ func (ctl *Control) HandleReqWorkConn(inMsg *msg.ReqWorkConn) { func (ctl *Control) HandleNewProxyResp(inMsg *msg.NewProxyResp) { // Server will return NewProxyResp message to each NewProxy message. // Start a new proxy handler if no error got - err := ctl.pm.StartProxy(inMsg.ProxyName, inMsg.Error) + err := ctl.pm.StartProxy(inMsg.ProxyName, inMsg.RemoteAddr, inMsg.Error) if err != nil { ctl.Warn("[%s] start error: %v", inMsg.ProxyName, err) } else { diff --git a/client/proxy_manager.go b/client/proxy_manager.go index d756b5ba..77823986 100644 --- a/client/proxy_manager.go +++ b/client/proxy_manager.go @@ -41,6 +41,8 @@ type ProxyWrapper struct { Err string Cfg config.ProxyConf + RemoteAddr string + pxy Proxy mu sync.RWMutex @@ -52,6 +54,9 @@ type ProxyStatus struct { Status string `json:"status"` Err string `json:"err"` Cfg config.ProxyConf `json:"cfg"` + + // Got from server. + RemoteAddr string `json:"remote_addr"` } func NewProxyWrapper(cfg config.ProxyConf) *ProxyWrapper { @@ -78,16 +83,17 @@ func (pw *ProxyWrapper) GetStatus() *ProxyStatus { pw.mu.RLock() defer pw.mu.RUnlock() ps := &ProxyStatus{ - Name: pw.Name, - Type: pw.Type, - Status: pw.Status, - Err: pw.Err, - Cfg: pw.Cfg, + Name: pw.Name, + Type: pw.Type, + Status: pw.Status, + Err: pw.Err, + Cfg: pw.Cfg, + RemoteAddr: pw.RemoteAddr, } return ps } -func (pw *ProxyWrapper) Start(serverRespErr string) error { +func (pw *ProxyWrapper) Start(remoteAddr string, serverRespErr string) error { if pw.pxy != nil { pw.pxy.Close() pw.pxy = nil @@ -96,6 +102,7 @@ func (pw *ProxyWrapper) Start(serverRespErr string) error { if serverRespErr != "" { pw.mu.Lock() pw.Status = ProxyStatusStartErr + pw.RemoteAddr = remoteAddr pw.Err = serverRespErr pw.mu.Unlock() return fmt.Errorf(serverRespErr) @@ -104,6 +111,7 @@ func (pw *ProxyWrapper) Start(serverRespErr string) error { pxy := NewProxy(pw.Cfg) pw.mu.Lock() defer pw.mu.Unlock() + pw.RemoteAddr = remoteAddr if err := pxy.Run(); err != nil { pw.Status = ProxyStatusStartErr pw.Err = err.Error() @@ -139,6 +147,7 @@ func (pw *ProxyWrapper) Close() { func NewProxyManager(ctl *Control, msgSendCh chan (msg.Message), logPrefix string) *ProxyManager { return &ProxyManager{ + ctl: ctl, proxies: make(map[string]*ProxyWrapper), visitorCfgs: make(map[string]config.ProxyConf), visitors: make(map[string]Visitor), @@ -168,7 +177,7 @@ func (pm *ProxyManager) sendMsg(m msg.Message) error { return err } -func (pm *ProxyManager) StartProxy(name string, serverRespErr string) error { +func (pm *ProxyManager) StartProxy(name string, remoteAddr string, serverRespErr string) error { pm.mu.Lock() defer pm.mu.Unlock() if pm.closed { @@ -180,7 +189,7 @@ func (pm *ProxyManager) StartProxy(name string, serverRespErr string) error { return fmt.Errorf("no proxy found") } - if err := pxy.Start(serverRespErr); err != nil { + if err := pxy.Start(remoteAddr, serverRespErr); err != nil { errRet := err err = pm.sendMsg(&msg.CloseProxy{ ProxyName: name, diff --git a/models/msg/msg.go b/models/msg/msg.go index aac0ce70..dd0dde71 100644 --- a/models/msg/msg.go +++ b/models/msg/msg.go @@ -119,8 +119,9 @@ type NewProxy struct { } type NewProxyResp struct { - ProxyName string `json:"proxy_name"` - Error string `json:"error"` + ProxyName string `json:"proxy_name"` + RemoteAddr string `json:"remote_addr"` + Error string `json:"error"` } type CloseProxy struct { diff --git a/server/control.go b/server/control.go index 5e8fe95e..7492ce4b 100644 --- a/server/control.go +++ b/server/control.go @@ -308,7 +308,7 @@ func (ctl *Control) manager() { switch m := rawMsg.(type) { case *msg.NewProxy: // register proxy in this control - err := ctl.RegisterProxy(m) + remoteAddr, err := ctl.RegisterProxy(m) resp := &msg.NewProxyResp{ ProxyName: m.ProxyName, } @@ -316,6 +316,7 @@ func (ctl *Control) manager() { resp.Error = err.Error() ctl.conn.Warn("new proxy [%s] error: %v", m.ProxyName, err) } else { + resp.RemoteAddr = remoteAddr ctl.conn.Info("new proxy [%s] success", m.ProxyName) StatsNewProxy(m.ProxyName, m.ProxyType) } @@ -332,24 +333,24 @@ func (ctl *Control) manager() { } } -func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (err error) { +func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err error) { var pxyConf config.ProxyConf // Load configures from NewProxy message and check. pxyConf, err = config.NewProxyConf(pxyMsg) if err != nil { - return err + return } // NewProxy will return a interface Proxy. // In fact it create different proxies by different proxy type, we just call run() here. pxy, err := NewProxy(ctl, pxyConf) if err != nil { - return err + return remoteAddr, err } - err = pxy.Run() + remoteAddr, err = pxy.Run() if err != nil { - return err + return } defer func() { if err != nil { @@ -359,13 +360,13 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (err error) { err = ctl.svr.RegisterProxy(pxyMsg.ProxyName, pxy) if err != nil { - return err + return } ctl.mu.Lock() ctl.proxies[pxy.GetName()] = pxy ctl.mu.Unlock() - return nil + return } func (ctl *Control) CloseProxy(closeMsg *msg.CloseProxy) (err error) { diff --git a/server/proxy.go b/server/proxy.go index 8ce1e2b4..f744b8ba 100644 --- a/server/proxy.go +++ b/server/proxy.go @@ -19,6 +19,7 @@ import ( "fmt" "io" "net" + "strings" "sync" "time" @@ -29,11 +30,12 @@ import ( frpIo "github.com/fatedier/frp/utils/io" "github.com/fatedier/frp/utils/log" frpNet "github.com/fatedier/frp/utils/net" + "github.com/fatedier/frp/utils/util" "github.com/fatedier/frp/utils/vhost" ) type Proxy interface { - Run() error + Run() (remoteAddr string, err error) GetControl() *Control GetName() string GetConf() config.ProxyConf @@ -165,17 +167,19 @@ type TcpProxy struct { cfg *config.TcpProxyConf } -func (pxy *TcpProxy) Run() error { - listener, err := frpNet.ListenTcp(config.ServerCommonCfg.ProxyBindAddr, pxy.cfg.RemotePort) - if err != nil { - return err +func (pxy *TcpProxy) Run() (remoteAddr string, err error) { + remoteAddr = fmt.Sprintf(":%d", pxy.cfg.RemotePort) + listener, errRet := frpNet.ListenTcp(config.ServerCommonCfg.ProxyBindAddr, pxy.cfg.RemotePort) + if errRet != nil { + err = errRet + return } listener.AddLogPrefix(pxy.name) pxy.listeners = append(pxy.listeners, listener) pxy.Info("tcp proxy listen port [%d]", pxy.cfg.RemotePort) pxy.startListenHandler(pxy, HandleUserTcpConnection) - return nil + return } func (pxy *TcpProxy) GetConf() config.ProxyConf { @@ -193,7 +197,7 @@ type HttpProxy struct { closeFuncs []func() } -func (pxy *HttpProxy) Run() (err error) { +func (pxy *HttpProxy) Run() (remoteAddr string, err error) { routeConfig := vhost.VhostRouteConfig{ RewriteHost: pxy.cfg.HostHeaderRewrite, Username: pxy.cfg.HttpUser, @@ -205,16 +209,19 @@ func (pxy *HttpProxy) Run() (err error) { if len(locations) == 0 { locations = []string{""} } + + addrs := make([]string, 0) for _, domain := range pxy.cfg.CustomDomains { routeConfig.Domain = domain for _, location := range locations { routeConfig.Location = location - err := pxy.ctl.svr.httpReverseProxy.Register(routeConfig) + err = pxy.ctl.svr.httpReverseProxy.Register(routeConfig) if err != nil { - return err + return } tmpDomain := routeConfig.Domain tmpLocation := routeConfig.Location + addrs = append(addrs, util.CanonicalAddr(tmpDomain, int(config.ServerCommonCfg.VhostHttpPort))) pxy.closeFuncs = append(pxy.closeFuncs, func() { pxy.ctl.svr.httpReverseProxy.UnRegister(tmpDomain, tmpLocation) }) @@ -226,18 +233,20 @@ func (pxy *HttpProxy) Run() (err error) { routeConfig.Domain = pxy.cfg.SubDomain + "." + config.ServerCommonCfg.SubDomainHost for _, location := range locations { routeConfig.Location = location - err := pxy.ctl.svr.httpReverseProxy.Register(routeConfig) + err = pxy.ctl.svr.httpReverseProxy.Register(routeConfig) if err != nil { - return err + return } tmpDomain := routeConfig.Domain tmpLocation := routeConfig.Location + addrs = append(addrs, util.CanonicalAddr(tmpDomain, int(config.ServerCommonCfg.VhostHttpPort))) pxy.closeFuncs = append(pxy.closeFuncs, func() { pxy.ctl.svr.httpReverseProxy.UnRegister(tmpDomain, tmpLocation) }) pxy.Info("http proxy listen for host [%s] location [%s]", routeConfig.Domain, routeConfig.Location) } } + remoteAddr = strings.Join(addrs, ",") return } @@ -279,32 +288,38 @@ type HttpsProxy struct { cfg *config.HttpsProxyConf } -func (pxy *HttpsProxy) Run() (err error) { +func (pxy *HttpsProxy) Run() (remoteAddr string, err error) { routeConfig := &vhost.VhostRouteConfig{} + addrs := make([]string, 0) for _, domain := range pxy.cfg.CustomDomains { routeConfig.Domain = domain - l, err := pxy.ctl.svr.VhostHttpsMuxer.Listen(routeConfig) - if err != nil { - return err + l, errRet := pxy.ctl.svr.VhostHttpsMuxer.Listen(routeConfig) + if errRet != nil { + err = errRet + return } l.AddLogPrefix(pxy.name) pxy.Info("https proxy listen for host [%s]", routeConfig.Domain) pxy.listeners = append(pxy.listeners, l) + addrs = append(addrs, util.CanonicalAddr(routeConfig.Domain, int(config.ServerCommonCfg.VhostHttpsPort))) } if pxy.cfg.SubDomain != "" { routeConfig.Domain = pxy.cfg.SubDomain + "." + config.ServerCommonCfg.SubDomainHost - l, err := pxy.ctl.svr.VhostHttpsMuxer.Listen(routeConfig) - if err != nil { - return err + l, errRet := pxy.ctl.svr.VhostHttpsMuxer.Listen(routeConfig) + if errRet != nil { + err = errRet + return } l.AddLogPrefix(pxy.name) pxy.Info("https proxy listen for host [%s]", routeConfig.Domain) pxy.listeners = append(pxy.listeners, l) + addrs = append(addrs, util.CanonicalAddr(routeConfig.Domain, int(config.ServerCommonCfg.VhostHttpsPort))) } pxy.startListenHandler(pxy, HandleUserTcpConnection) + remoteAddr = strings.Join(addrs, ",") return } @@ -321,17 +336,18 @@ type StcpProxy struct { cfg *config.StcpProxyConf } -func (pxy *StcpProxy) Run() error { - listener, err := pxy.ctl.svr.visitorManager.Listen(pxy.GetName(), pxy.cfg.Sk) - if err != nil { - return err +func (pxy *StcpProxy) Run() (remoteAddr string, err error) { + listener, errRet := pxy.ctl.svr.visitorManager.Listen(pxy.GetName(), pxy.cfg.Sk) + if errRet != nil { + err = errRet + return } listener.AddLogPrefix(pxy.name) pxy.listeners = append(pxy.listeners, listener) pxy.Info("stcp proxy custom listen success") pxy.startListenHandler(pxy, HandleUserTcpConnection) - return nil + return } func (pxy *StcpProxy) GetConf() config.ProxyConf { @@ -350,10 +366,11 @@ type XtcpProxy struct { closeCh chan struct{} } -func (pxy *XtcpProxy) Run() error { +func (pxy *XtcpProxy) Run() (remoteAddr string, err error) { if pxy.ctl.svr.natHoleController == nil { pxy.Error("udp port for xtcp is not specified.") - return fmt.Errorf("xtcp is not supported in frps") + err = fmt.Errorf("xtcp is not supported in frps") + return } sidCh := pxy.ctl.svr.natHoleController.ListenClient(pxy.GetName(), pxy.cfg.Sk) go func() { @@ -362,21 +379,21 @@ func (pxy *XtcpProxy) Run() error { case <-pxy.closeCh: break case sid := <-sidCh: - workConn, err := pxy.GetWorkConnFromPool() - if err != nil { + workConn, errRet := pxy.GetWorkConnFromPool() + if errRet != nil { continue } m := &msg.NatHoleSid{ Sid: sid, } - err = msg.WriteMsg(workConn, m) - if err != nil { - pxy.Warn("write nat hole sid package error, %v", err) + errRet = msg.WriteMsg(workConn, m) + if errRet != nil { + pxy.Warn("write nat hole sid package error, %v", errRet) } } } }() - return nil + return } func (pxy *XtcpProxy) GetConf() config.ProxyConf { @@ -414,15 +431,18 @@ type UdpProxy struct { isClosed bool } -func (pxy *UdpProxy) Run() (err error) { - addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", config.ServerCommonCfg.ProxyBindAddr, pxy.cfg.RemotePort)) - if err != nil { - return err +func (pxy *UdpProxy) Run() (remoteAddr string, err error) { + remoteAddr = fmt.Sprintf(":%d", pxy.cfg.RemotePort) + addr, errRet := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", config.ServerCommonCfg.ProxyBindAddr, pxy.cfg.RemotePort)) + if errRet != nil { + err = errRet + return } - udpConn, err := net.ListenUDP("udp", addr) - if err != nil { + udpConn, errRet := net.ListenUDP("udp", addr) + if errRet != nil { + err = errRet pxy.Warn("listen udp port error: %v", err) - return err + return } pxy.Info("udp proxy listen port [%d]", pxy.cfg.RemotePort) @@ -537,7 +557,7 @@ func (pxy *UdpProxy) Run() (err error) { udp.ForwardUserConn(udpConn, pxy.readCh, pxy.sendCh) pxy.Close() }() - return nil + return remoteAddr, nil } func (pxy *UdpProxy) GetConf() config.ProxyConf { diff --git a/utils/util/util.go b/utils/util/util.go index 87a6907a..88180e35 100644 --- a/utils/util/util.go +++ b/utils/util/util.go @@ -110,3 +110,12 @@ func PortRangesCut(portRanges [][2]int64, port int64) [][2]int64 { } return tmpRanges } + +func CanonicalAddr(host string, port int) (addr string) { + if port == 80 || port == 443 { + addr = host + } else { + addr = fmt.Sprintf("%s:%d", host, port) + } + return +} From b2c846664d0a627aa16f164bf7b5d1e19212b225 Mon Sep 17 00:00:00 2001 From: fatedier Date: Wed, 17 Jan 2018 21:49:37 +0800 Subject: [PATCH 7/7] new feature: assign a random port if remote_port is 0 in type tcp and udp --- .travis.yml | 2 +- client/admin.go | 2 +- client/admin_api.go | 12 ++- client/visitor.go | 8 +- cmd/frpc/main.go | 4 +- cmd/frps/main.go | 2 +- models/config/client_common.go | 26 +++-- models/config/proxy.go | 13 +-- models/config/server_common.go | 155 +++++++++++++++++----------- models/msg/msg.go | 4 +- server/dashboard.go | 2 +- server/dashboard_api.go | 4 +- server/ports.go | 180 +++++++++++++++++++++++++++++++++ server/proxy.go | 36 ++++++- server/service.go | 10 +- tests/func_test.go | 22 ++-- utils/net/kcp.go | 2 +- utils/net/tcp.go | 2 +- utils/net/udp.go | 2 +- utils/util/util.go | 65 ------------ utils/util/util_test.go | 64 ------------ 21 files changed, 379 insertions(+), 238 deletions(-) create mode 100644 server/ports.go diff --git a/.travis.yml b/.travis.yml index 303e1a21..51c2421c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,7 +3,7 @@ language: go go: - 1.8.x - - 1.x + - 1.9.x install: - make diff --git a/client/admin.go b/client/admin.go index 37cdf4c1..e34f44d2 100644 --- a/client/admin.go +++ b/client/admin.go @@ -31,7 +31,7 @@ var ( httpServerWriteTimeout = 10 * time.Second ) -func (svr *Service) RunAdminServer(addr string, port int64) (err error) { +func (svr *Service) RunAdminServer(addr string, port int) (err error) { // url router router := httprouter.New() diff --git a/client/admin_api.go b/client/admin_api.go index 3c64b917..1a947521 100644 --- a/client/admin_api.go +++ b/client/admin_api.go @@ -124,12 +124,20 @@ func NewProxyStatusResp(status *ProxyStatus) ProxyStatusResp { psr.LocalAddr = fmt.Sprintf("%s:%d", cfg.LocalIp, cfg.LocalPort) } psr.Plugin = cfg.Plugin - psr.RemoteAddr = config.ClientCommonCfg.ServerAddr + status.RemoteAddr + if status.Err != "" { + psr.RemoteAddr = fmt.Sprintf("%s:%d", config.ClientCommonCfg.ServerAddr, cfg.RemotePort) + } else { + psr.RemoteAddr = config.ClientCommonCfg.ServerAddr + status.RemoteAddr + } case *config.UdpProxyConf: if cfg.LocalPort != 0 { psr.LocalAddr = fmt.Sprintf("%s:%d", cfg.LocalIp, cfg.LocalPort) } - psr.RemoteAddr = config.ClientCommonCfg.ServerAddr + status.RemoteAddr + if status.Err != "" { + psr.RemoteAddr = fmt.Sprintf("%s:%d", config.ClientCommonCfg.ServerAddr, cfg.RemotePort) + } else { + psr.RemoteAddr = config.ClientCommonCfg.ServerAddr + status.RemoteAddr + } case *config.HttpProxyConf: if cfg.LocalPort != 0 { psr.LocalAddr = fmt.Sprintf("%s:%d", cfg.LocalIp, cfg.LocalPort) diff --git a/client/visitor.go b/client/visitor.go index e7a22d80..fd182255 100644 --- a/client/visitor.go +++ b/client/visitor.go @@ -77,7 +77,7 @@ type StcpVisitor struct { } func (sv *StcpVisitor) Run() (err error) { - sv.l, err = frpNet.ListenTcp(sv.cfg.BindAddr, int64(sv.cfg.BindPort)) + sv.l, err = frpNet.ListenTcp(sv.cfg.BindAddr, sv.cfg.BindPort) if err != nil { return } @@ -164,7 +164,7 @@ type XtcpVisitor struct { } func (sv *XtcpVisitor) Run() (err error) { - sv.l, err = frpNet.ListenTcp(sv.cfg.BindAddr, int64(sv.cfg.BindPort)) + sv.l, err = frpNet.ListenTcp(sv.cfg.BindAddr, sv.cfg.BindPort) if err != nil { return } @@ -255,7 +255,7 @@ func (sv *XtcpVisitor) handleConn(userConn frpNet.Conn) { sv.Error("get natHoleResp client address error: %s", natHoleRespMsg.ClientAddr) return } - sv.sendDetectMsg(array[0], int64(port), laddr, []byte(natHoleRespMsg.Sid)) + sv.sendDetectMsg(array[0], int(port), laddr, []byte(natHoleRespMsg.Sid)) sv.Trace("send all detect msg done") // Listen for visitorConn's address and wait for client connection. @@ -302,7 +302,7 @@ func (sv *XtcpVisitor) handleConn(userConn frpNet.Conn) { sv.Debug("join connections closed") } -func (sv *XtcpVisitor) sendDetectMsg(addr string, port int64, laddr *net.UDPAddr, content []byte) (err error) { +func (sv *XtcpVisitor) sendDetectMsg(addr string, port int, laddr *net.UDPAddr, content []byte) (err error) { daddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", addr, port)) if err != nil { return err diff --git a/cmd/frpc/main.go b/cmd/frpc/main.go index f1836db2..234d9e79 100644 --- a/cmd/frpc/main.go +++ b/cmd/frpc/main.go @@ -99,7 +99,7 @@ func main() { if args["status"] != nil { if args["status"].(bool) { if err = CmdStatus(); err != nil { - fmt.Println("frps get status error: %v\n", err) + fmt.Printf("frps get status error: %v\n", err) os.Exit(1) } else { os.Exit(0) @@ -132,7 +132,7 @@ func main() { os.Exit(1) } config.ClientCommonCfg.ServerAddr = addr[0] - config.ClientCommonCfg.ServerPort = serverPort + config.ClientCommonCfg.ServerPort = int(serverPort) } if args["-v"] != nil { diff --git a/cmd/frps/main.go b/cmd/frps/main.go index fc5d6436..c3b495ad 100644 --- a/cmd/frps/main.go +++ b/cmd/frps/main.go @@ -91,7 +91,7 @@ func main() { os.Exit(1) } config.ServerCommonCfg.BindAddr = addr[0] - config.ServerCommonCfg.BindPort = bindPort + config.ServerCommonCfg.BindPort = int(bindPort) } if args["-v"] != nil { diff --git a/models/config/client_common.go b/models/config/client_common.go index f98169e7..4b9ede72 100644 --- a/models/config/client_common.go +++ b/models/config/client_common.go @@ -29,8 +29,8 @@ var ClientCommonCfg *ClientCommonConf type ClientCommonConf struct { ConfigFile string ServerAddr string - ServerPort int64 - ServerUdpPort int64 // this is specified by login response message from frps + ServerPort int + ServerUdpPort int // this is specified by login response message from frps HttpProxy string LogFile string LogWay string @@ -38,7 +38,7 @@ type ClientCommonConf struct { LogMaxDays int64 PrivilegeToken string AdminAddr string - AdminPort int64 + AdminPort int AdminUser string AdminPwd string PoolCount int @@ -93,7 +93,12 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) { tmpStr, ok = conf.Get("common", "server_port") if ok { - cfg.ServerPort, _ = strconv.ParseInt(tmpStr, 10, 64) + v, err = strconv.ParseInt(tmpStr, 10, 64) + if err != nil { + err = fmt.Errorf("Parse conf error: invalid server_port") + return + } + cfg.ServerPort = int(v) } tmpStr, ok = conf.Get("common", "http_proxy") @@ -139,7 +144,10 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) { tmpStr, ok = conf.Get("common", "admin_port") if ok { if v, err = strconv.ParseInt(tmpStr, 10, 64); err == nil { - cfg.AdminPort = v + cfg.AdminPort = int(v) + } else { + err = fmt.Errorf("Parse conf error: invalid admin_port") + return } } @@ -203,7 +211,7 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) { if ok { v, err = strconv.ParseInt(tmpStr, 10, 64) if err != nil { - err = fmt.Errorf("Parse conf error: heartbeat_timeout is incorrect") + err = fmt.Errorf("Parse conf error: invalid heartbeat_timeout") return } else { cfg.HeartBeatTimeout = v @@ -214,7 +222,7 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) { if ok { v, err = strconv.ParseInt(tmpStr, 10, 64) if err != nil { - err = fmt.Errorf("Parse conf error: heartbeat_interval is incorrect") + err = fmt.Errorf("Parse conf error: invalid heartbeat_interval") return } else { cfg.HeartBeatInterval = v @@ -222,12 +230,12 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) { } if cfg.HeartBeatInterval <= 0 { - err = fmt.Errorf("Parse conf error: heartbeat_interval is incorrect") + err = fmt.Errorf("Parse conf error: invalid heartbeat_interval") return } if cfg.HeartBeatTimeout < cfg.HeartBeatInterval { - err = fmt.Errorf("Parse conf error: heartbeat_timeout is incorrect, heartbeat_timeout is less than heartbeat_interval") + err = fmt.Errorf("Parse conf error: invalid heartbeat_timeout, heartbeat_timeout is less than heartbeat_interval") return } return diff --git a/models/config/proxy.go b/models/config/proxy.go index e87b7eca..022e64f4 100644 --- a/models/config/proxy.go +++ b/models/config/proxy.go @@ -23,7 +23,6 @@ import ( "github.com/fatedier/frp/models/consts" "github.com/fatedier/frp/models/msg" - "github.com/fatedier/frp/utils/util" ini "github.com/vaughan0/go-ini" ) @@ -163,7 +162,7 @@ func (cfg *BaseProxyConf) UnMarshalToMsg(pMsg *msg.NewProxy) { // Bind info type BindInfoConf struct { BindAddr string `json:"bind_addr"` - RemotePort int64 `json:"remote_port"` + RemotePort int `json:"remote_port"` } func (cfg *BindInfoConf) compare(cmp *BindInfoConf) bool { @@ -183,10 +182,13 @@ func (cfg *BindInfoConf) LoadFromFile(name string, section ini.Section) (err err var ( tmpStr string ok bool + v int64 ) if tmpStr, ok = section["remote_port"]; ok { - if cfg.RemotePort, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { + if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { return fmt.Errorf("Parse conf error: proxy [%s] remote_port error", name) + } else { + cfg.RemotePort = int(v) } } else { return fmt.Errorf("Parse conf error: proxy [%s] remote_port not found", name) @@ -199,11 +201,6 @@ func (cfg *BindInfoConf) UnMarshalToMsg(pMsg *msg.NewProxy) { } func (cfg *BindInfoConf) check() (err error) { - if len(ServerCommonCfg.PrivilegeAllowPorts) != 0 { - if ok := util.ContainsPort(ServerCommonCfg.PrivilegeAllowPorts, cfg.RemotePort); !ok { - return fmt.Errorf("remote port [%d] isn't allowed", cfg.RemotePort) - } - } return nil } diff --git a/models/config/server_common.go b/models/config/server_common.go index 4d177665..37892b4e 100644 --- a/models/config/server_common.go +++ b/models/config/server_common.go @@ -19,7 +19,6 @@ import ( "strconv" "strings" - "github.com/fatedier/frp/utils/util" ini "github.com/vaughan0/go-ini" ) @@ -29,20 +28,20 @@ var ServerCommonCfg *ServerCommonConf type ServerCommonConf struct { ConfigFile string BindAddr string - BindPort int64 - BindUdpPort int64 - KcpBindPort int64 + BindPort int + BindUdpPort int + KcpBindPort int ProxyBindAddr string // If VhostHttpPort equals 0, don't listen a public port for http protocol. - VhostHttpPort int64 + VhostHttpPort int // if VhostHttpsPort equals 0, don't listen a public port for https protocol - VhostHttpsPort int64 + VhostHttpsPort int DashboardAddr string // if DashboardPort equals 0, dashboard is not available - DashboardPort int64 + DashboardPort int DashboardUser string DashboardPwd string AssetsDir string @@ -56,8 +55,7 @@ type ServerCommonConf struct { SubDomainHost string TcpMux bool - // if PrivilegeAllowPorts is not nil, tcp proxies which remote port exist in this map can be connected - PrivilegeAllowPorts [][2]int64 + PrivilegeAllowPorts map[int]struct{} MaxPoolCount int64 HeartBeatTimeout int64 UserConnTimeout int64 @@ -65,31 +63,32 @@ type ServerCommonConf struct { func GetDefaultServerCommonConf() *ServerCommonConf { return &ServerCommonConf{ - ConfigFile: "./frps.ini", - BindAddr: "0.0.0.0", - BindPort: 7000, - BindUdpPort: 0, - KcpBindPort: 0, - ProxyBindAddr: "0.0.0.0", - VhostHttpPort: 0, - VhostHttpsPort: 0, - DashboardAddr: "0.0.0.0", - DashboardPort: 0, - DashboardUser: "admin", - DashboardPwd: "admin", - AssetsDir: "", - LogFile: "console", - LogWay: "console", - LogLevel: "info", - LogMaxDays: 3, - PrivilegeMode: true, - PrivilegeToken: "", - AuthTimeout: 900, - SubDomainHost: "", - TcpMux: true, - MaxPoolCount: 5, - HeartBeatTimeout: 90, - UserConnTimeout: 10, + ConfigFile: "./frps.ini", + BindAddr: "0.0.0.0", + BindPort: 7000, + BindUdpPort: 0, + KcpBindPort: 0, + ProxyBindAddr: "0.0.0.0", + VhostHttpPort: 0, + VhostHttpsPort: 0, + DashboardAddr: "0.0.0.0", + DashboardPort: 0, + DashboardUser: "admin", + DashboardPwd: "admin", + AssetsDir: "", + LogFile: "console", + LogWay: "console", + LogLevel: "info", + LogMaxDays: 3, + PrivilegeMode: true, + PrivilegeToken: "", + AuthTimeout: 900, + SubDomainHost: "", + TcpMux: true, + PrivilegeAllowPorts: make(map[int]struct{}), + MaxPoolCount: 5, + HeartBeatTimeout: 90, + UserConnTimeout: 10, } } @@ -109,25 +108,31 @@ func LoadServerCommonConf(conf ini.File) (cfg *ServerCommonConf, err error) { tmpStr, ok = conf.Get("common", "bind_port") if ok { - v, err = strconv.ParseInt(tmpStr, 10, 64) - if err == nil { - cfg.BindPort = v + if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { + err = fmt.Errorf("Parse conf error: invalid bind_port") + return + } else { + cfg.BindPort = int(v) } } tmpStr, ok = conf.Get("common", "bind_udp_port") if ok { - v, err = strconv.ParseInt(tmpStr, 10, 64) - if err == nil { - cfg.BindUdpPort = v + if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { + err = fmt.Errorf("Parse conf error: invalid bind_udp_port") + return + } else { + cfg.BindUdpPort = int(v) } } tmpStr, ok = conf.Get("common", "kcp_bind_port") if ok { - v, err = strconv.ParseInt(tmpStr, 10, 64) - if err == nil && v > 0 { - cfg.KcpBindPort = v + if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { + err = fmt.Errorf("Parse conf error: invalid kcp_bind_port") + return + } else { + cfg.KcpBindPort = int(v) } } @@ -140,10 +145,11 @@ func LoadServerCommonConf(conf ini.File) (cfg *ServerCommonConf, err error) { tmpStr, ok = conf.Get("common", "vhost_http_port") if ok { - cfg.VhostHttpPort, err = strconv.ParseInt(tmpStr, 10, 64) - if err != nil { - err = fmt.Errorf("Parse conf error: vhost_http_port is incorrect") + if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { + err = fmt.Errorf("Parse conf error: invalid vhost_http_port") return + } else { + cfg.VhostHttpPort = int(v) } } else { cfg.VhostHttpPort = 0 @@ -151,10 +157,11 @@ func LoadServerCommonConf(conf ini.File) (cfg *ServerCommonConf, err error) { tmpStr, ok = conf.Get("common", "vhost_https_port") if ok { - cfg.VhostHttpsPort, err = strconv.ParseInt(tmpStr, 10, 64) - if err != nil { - err = fmt.Errorf("Parse conf error: vhost_https_port is incorrect") + if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { + err = fmt.Errorf("Parse conf error: invalid vhost_https_port") return + } else { + cfg.VhostHttpsPort = int(v) } } else { cfg.VhostHttpsPort = 0 @@ -169,10 +176,11 @@ func LoadServerCommonConf(conf ini.File) (cfg *ServerCommonConf, err error) { tmpStr, ok = conf.Get("common", "dashboard_port") if ok { - cfg.DashboardPort, err = strconv.ParseInt(tmpStr, 10, 64) - if err != nil { - err = fmt.Errorf("Parse conf error: dashboard_port is incorrect") + if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { + err = fmt.Errorf("Parse conf error: invalid dashboard_port") return + } else { + cfg.DashboardPort = int(v) } } else { cfg.DashboardPort = 0 @@ -228,12 +236,45 @@ func LoadServerCommonConf(conf ini.File) (cfg *ServerCommonConf, err error) { cfg.PrivilegeToken, _ = conf.Get("common", "privilege_token") allowPortsStr, ok := conf.Get("common", "privilege_allow_ports") - // TODO: check if conflicts exist in port ranges if ok { - cfg.PrivilegeAllowPorts, err = util.GetPortRanges(allowPortsStr) - if err != nil { - err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", err) - return + // e.g. 1000-2000,2001,2002,3000-4000 + portRanges := strings.Split(allowPortsStr, ",") + for _, portRangeStr := range portRanges { + // 1000-2000 or 2001 + portArray := strings.Split(portRangeStr, "-") + // length: only 1 or 2 is correct + rangeType := len(portArray) + if rangeType == 1 { + // single port + singlePort, errRet := strconv.ParseInt(portArray[0], 10, 64) + if errRet != nil { + err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", errRet) + return + } + cfg.PrivilegeAllowPorts[int(singlePort)] = struct{}{} + } else if rangeType == 2 { + // range ports + min, errRet := strconv.ParseInt(portArray[0], 10, 64) + if errRet != nil { + err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", errRet) + return + } + max, errRet := strconv.ParseInt(portArray[1], 10, 64) + if errRet != nil { + err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", errRet) + return + } + if max < min { + err = fmt.Errorf("Parse conf error: privilege_allow_ports range incorrect") + return + } + for i := min; i <= max; i++ { + cfg.PrivilegeAllowPorts[int(i)] = struct{}{} + } + } else { + err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect") + return + } } } } diff --git a/models/msg/msg.go b/models/msg/msg.go index dd0dde71..0cdb3f47 100644 --- a/models/msg/msg.go +++ b/models/msg/msg.go @@ -92,7 +92,7 @@ type Login struct { type LoginResp struct { Version string `json:"version"` RunId string `json:"run_id"` - ServerUdpPort int64 `json:"server_udp_port"` + ServerUdpPort int `json:"server_udp_port"` Error string `json:"error"` } @@ -104,7 +104,7 @@ type NewProxy struct { UseCompression bool `json:"use_compression"` // tcp and udp only - RemotePort int64 `json:"remote_port"` + RemotePort int `json:"remote_port"` // http and https only CustomDomains []string `json:"custom_domains"` diff --git a/server/dashboard.go b/server/dashboard.go index 01f71591..3c77875c 100644 --- a/server/dashboard.go +++ b/server/dashboard.go @@ -32,7 +32,7 @@ var ( httpServerWriteTimeout = 10 * time.Second ) -func RunDashboardServer(addr string, port int64) (err error) { +func RunDashboardServer(addr string, port int) (err error) { // url router router := httprouter.New() diff --git a/server/dashboard_api.go b/server/dashboard_api.go index 89d285c5..3f9acd0f 100644 --- a/server/dashboard_api.go +++ b/server/dashboard_api.go @@ -36,8 +36,8 @@ type ServerInfoResp struct { GeneralResponse Version string `json:"version"` - VhostHttpPort int64 `json:"vhost_http_port"` - VhostHttpsPort int64 `json:"vhost_https_port"` + VhostHttpPort int `json:"vhost_http_port"` + VhostHttpsPort int `json:"vhost_https_port"` AuthTimeout int64 `json:"auth_timeout"` SubdomainHost string `json:"subdomain_host"` MaxPoolCount int64 `json:"max_pool_count"` diff --git a/server/ports.go b/server/ports.go new file mode 100644 index 00000000..b9cc4c16 --- /dev/null +++ b/server/ports.go @@ -0,0 +1,180 @@ +package server + +import ( + "errors" + "fmt" + "net" + "sync" + "time" +) + +const ( + MinPort = 1025 + MaxPort = 65535 + MaxPortReservedDuration = time.Duration(24) * time.Hour + CleanReservedPortsInterval = time.Hour +) + +var ( + ErrPortAlreadyUsed = errors.New("port already used") + ErrPortNotAllowed = errors.New("port not allowed") + ErrPortUnAvailable = errors.New("port unavailable") + ErrNoAvailablePort = errors.New("no available port") +) + +type PortCtx struct { + ProxyName string + Port int + Closed bool + UpdateTime time.Time +} + +type PortManager struct { + reservedPorts map[string]*PortCtx + usedPorts map[int]*PortCtx + freePorts map[int]struct{} + + bindAddr string + netType string + mu sync.Mutex +} + +func NewPortManager(netType string, bindAddr string, allowPorts map[int]struct{}) *PortManager { + pm := &PortManager{ + reservedPorts: make(map[string]*PortCtx), + usedPorts: make(map[int]*PortCtx), + freePorts: make(map[int]struct{}), + bindAddr: bindAddr, + netType: netType, + } + if len(allowPorts) > 0 { + for port, _ := range allowPorts { + pm.freePorts[port] = struct{}{} + } + } else { + for i := MinPort; i <= MaxPort; i++ { + pm.freePorts[i] = struct{}{} + } + } + go pm.cleanReservedPortsWorker() + return pm +} + +func (pm *PortManager) Acquire(name string, port int) (realPort int, err error) { + portCtx := &PortCtx{ + ProxyName: name, + Closed: false, + UpdateTime: time.Now(), + } + + var ok bool + + pm.mu.Lock() + defer func() { + if err == nil { + portCtx.Port = realPort + } + pm.mu.Unlock() + }() + + // check reserved ports first + if port == 0 { + if ctx, ok := pm.reservedPorts[name]; ok { + if pm.isPortAvailable(ctx.Port) { + realPort = ctx.Port + pm.usedPorts[realPort] = portCtx + pm.reservedPorts[name] = portCtx + delete(pm.freePorts, realPort) + return + } + } + } + + if port == 0 { + // get random port + count := 0 + maxTryTimes := 5 + for k, _ := range pm.freePorts { + count++ + if count > maxTryTimes { + break + } + if pm.isPortAvailable(k) { + realPort = k + pm.usedPorts[realPort] = portCtx + pm.reservedPorts[name] = portCtx + delete(pm.freePorts, realPort) + break + } + } + if realPort == 0 { + err = ErrNoAvailablePort + } + } else { + // specified port + if _, ok = pm.freePorts[port]; ok { + if pm.isPortAvailable(port) { + realPort = port + pm.usedPorts[realPort] = portCtx + pm.reservedPorts[name] = portCtx + delete(pm.freePorts, realPort) + } else { + err = ErrPortUnAvailable + } + } else { + if _, ok = pm.usedPorts[port]; ok { + err = ErrPortAlreadyUsed + } else { + err = ErrPortNotAllowed + } + } + } + return +} + +func (pm *PortManager) isPortAvailable(port int) bool { + if pm.netType == "udp" { + addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pm.bindAddr, port)) + if err != nil { + return false + } + l, err := net.ListenUDP("udp", addr) + if err != nil { + return false + } + l.Close() + return true + } else { + l, err := net.Listen(pm.netType, fmt.Sprintf("%s:%d", pm.bindAddr, port)) + if err != nil { + return false + } + l.Close() + return true + } +} + +func (pm *PortManager) Release(port int) { + pm.mu.Lock() + defer pm.mu.Unlock() + if ctx, ok := pm.usedPorts[port]; ok { + pm.freePorts[port] = struct{}{} + delete(pm.usedPorts, port) + ctx.Closed = true + ctx.UpdateTime = time.Now() + } +} + +// Release reserved port if it isn't used in last 24 hours. +func (pm *PortManager) cleanReservedPortsWorker() { + for { + time.Sleep(CleanReservedPortsInterval) + pm.mu.Lock() + for name, ctx := range pm.reservedPorts { + if ctx.Closed && time.Since(ctx.UpdateTime) > MaxPortReservedDuration { + delete(pm.reservedPorts, name) + } + } + pm.mu.Unlock() + } +} diff --git a/server/proxy.go b/server/proxy.go index f744b8ba..bfb9793a 100644 --- a/server/proxy.go +++ b/server/proxy.go @@ -165,11 +165,24 @@ func NewProxy(ctl *Control, pxyConf config.ProxyConf) (pxy Proxy, err error) { type TcpProxy struct { BaseProxy cfg *config.TcpProxyConf + + realPort int } func (pxy *TcpProxy) Run() (remoteAddr string, err error) { - remoteAddr = fmt.Sprintf(":%d", pxy.cfg.RemotePort) - listener, errRet := frpNet.ListenTcp(config.ServerCommonCfg.ProxyBindAddr, pxy.cfg.RemotePort) + pxy.realPort, err = pxy.ctl.svr.tcpPortManager.Acquire(pxy.name, pxy.cfg.RemotePort) + if err != nil { + return + } + defer func() { + if err != nil { + pxy.ctl.svr.tcpPortManager.Release(pxy.realPort) + } + }() + + remoteAddr = fmt.Sprintf(":%d", pxy.realPort) + pxy.cfg.RemotePort = pxy.realPort + listener, errRet := frpNet.ListenTcp(config.ServerCommonCfg.ProxyBindAddr, pxy.realPort) if errRet != nil { err = errRet return @@ -188,6 +201,7 @@ func (pxy *TcpProxy) GetConf() config.ProxyConf { func (pxy *TcpProxy) Close() { pxy.BaseProxy.Close() + pxy.ctl.svr.tcpPortManager.Release(pxy.realPort) } type HttpProxy struct { @@ -412,6 +426,8 @@ type UdpProxy struct { BaseProxy cfg *config.UdpProxyConf + realPort int + // udpConn is the listener of udp packages udpConn *net.UDPConn @@ -432,8 +448,19 @@ type UdpProxy struct { } func (pxy *UdpProxy) Run() (remoteAddr string, err error) { - remoteAddr = fmt.Sprintf(":%d", pxy.cfg.RemotePort) - addr, errRet := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", config.ServerCommonCfg.ProxyBindAddr, pxy.cfg.RemotePort)) + pxy.realPort, err = pxy.ctl.svr.udpPortManager.Acquire(pxy.name, pxy.cfg.RemotePort) + if err != nil { + return + } + defer func() { + if err != nil { + pxy.ctl.svr.udpPortManager.Release(pxy.realPort) + } + }() + + remoteAddr = fmt.Sprintf(":%d", pxy.realPort) + pxy.cfg.RemotePort = pxy.realPort + addr, errRet := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", config.ServerCommonCfg.ProxyBindAddr, pxy.realPort)) if errRet != nil { err = errRet return @@ -581,6 +608,7 @@ func (pxy *UdpProxy) Close() { close(pxy.readCh) close(pxy.sendCh) } + pxy.ctl.svr.udpPortManager.Release(pxy.realPort) } // HandleUserTcpConnection is used for incoming tcp user connections. diff --git a/server/service.go b/server/service.go index a510b179..e976658a 100644 --- a/server/service.go +++ b/server/service.go @@ -60,17 +60,25 @@ type Service struct { // Manage all visitor listeners. visitorManager *VisitorManager + // Manage all tcp ports. + tcpPortManager *PortManager + + // Manage all udp ports. + udpPortManager *PortManager + // Controller for nat hole connections. natHoleController *NatHoleController } func NewService() (svr *Service, err error) { + cfg := config.ServerCommonCfg svr = &Service{ ctlManager: NewControlManager(), pxyManager: NewProxyManager(), visitorManager: NewVisitorManager(), + tcpPortManager: NewPortManager("tcp", cfg.ProxyBindAddr, cfg.PrivilegeAllowPorts), + udpPortManager: NewPortManager("udp", cfg.ProxyBindAddr, cfg.PrivilegeAllowPorts), } - cfg := config.ServerCommonCfg // Init assets. err = assets.Load(cfg.AssetsDir) diff --git a/tests/func_test.go b/tests/func_test.go index 238046fb..1f154089 100644 --- a/tests/func_test.go +++ b/tests/func_test.go @@ -10,28 +10,28 @@ import ( var ( TEST_STR = "frp is a fast reverse proxy to help you expose a local server behind a NAT or firewall to the internet." - TEST_TCP_PORT int64 = 10701 - TEST_TCP_FRP_PORT int64 = 10801 - TEST_TCP_EC_FRP_PORT int64 = 10901 + TEST_TCP_PORT int = 10701 + TEST_TCP_FRP_PORT int = 10801 + TEST_TCP_EC_FRP_PORT int = 10901 TEST_TCP_ECHO_STR string = "tcp type:" + TEST_STR - TEST_UDP_PORT int64 = 10702 - TEST_UDP_FRP_PORT int64 = 10802 - TEST_UDP_EC_FRP_PORT int64 = 10902 + TEST_UDP_PORT int = 10702 + TEST_UDP_FRP_PORT int = 10802 + TEST_UDP_EC_FRP_PORT int = 10902 TEST_UDP_ECHO_STR string = "udp type:" + TEST_STR TEST_UNIX_DOMAIN_ADDR string = "/tmp/frp_echo_server.sock" - TEST_UNIX_DOMAIN_FRP_PORT int64 = 10803 + TEST_UNIX_DOMAIN_FRP_PORT int = 10803 TEST_UNIX_DOMAIN_STR string = "unix domain type:" + TEST_STR - TEST_HTTP_PORT int64 = 10704 - TEST_HTTP_FRP_PORT int64 = 10804 + TEST_HTTP_PORT int = 10704 + TEST_HTTP_FRP_PORT int = 10804 TEST_HTTP_NORMAL_STR string = "http normal string: " + TEST_STR TEST_HTTP_FOO_STR string = "http foo string: " + TEST_STR TEST_HTTP_BAR_STR string = "http bar string: " + TEST_STR - TEST_STCP_FRP_PORT int64 = 10805 - TEST_STCP_EC_FRP_PORT int64 = 10905 + TEST_STCP_FRP_PORT int = 10805 + TEST_STCP_EC_FRP_PORT int = 10905 TEST_STCP_ECHO_STR string = "stcp type:" + TEST_STR ) diff --git a/utils/net/kcp.go b/utils/net/kcp.go index 18979c12..3d080fdd 100644 --- a/utils/net/kcp.go +++ b/utils/net/kcp.go @@ -31,7 +31,7 @@ type KcpListener struct { log.Logger } -func ListenKcp(bindAddr string, bindPort int64) (l *KcpListener, err error) { +func ListenKcp(bindAddr string, bindPort int) (l *KcpListener, err error) { listener, err := kcp.ListenWithOptions(fmt.Sprintf("%s:%d", bindAddr, bindPort), nil, 10, 3) if err != nil { return l, err diff --git a/utils/net/tcp.go b/utils/net/tcp.go index ca71de0a..b2c5a2b6 100644 --- a/utils/net/tcp.go +++ b/utils/net/tcp.go @@ -33,7 +33,7 @@ type TcpListener struct { log.Logger } -func ListenTcp(bindAddr string, bindPort int64) (l *TcpListener, err error) { +func ListenTcp(bindAddr string, bindPort int) (l *TcpListener, err error) { tcpAddr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort)) if err != nil { return l, err diff --git a/utils/net/udp.go b/utils/net/udp.go index ec2fb261..f2e9a797 100644 --- a/utils/net/udp.go +++ b/utils/net/udp.go @@ -167,7 +167,7 @@ type UdpListener struct { log.Logger } -func ListenUDP(bindAddr string, bindPort int64) (l *UdpListener, err error) { +func ListenUDP(bindAddr string, bindPort int) (l *UdpListener, err error) { udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", bindAddr, bindPort)) if err != nil { return l, err diff --git a/utils/util/util.go b/utils/util/util.go index 88180e35..4439f1aa 100644 --- a/utils/util/util.go +++ b/utils/util/util.go @@ -19,8 +19,6 @@ import ( "crypto/rand" "encoding/hex" "fmt" - "strconv" - "strings" ) // RandId return a rand string used in frp. @@ -48,69 +46,6 @@ func GetAuthKey(token string, timestamp int64) (key string) { return hex.EncodeToString(data) } -// for example: rangeStr is "1000-2000,2001,2002,3000-4000", return an array as port ranges. -func GetPortRanges(rangeStr string) (portRanges [][2]int64, err error) { - // for example: 1000-2000,2001,2002,3000-4000 - rangeArray := strings.Split(rangeStr, ",") - for _, portRangeStr := range rangeArray { - // 1000-2000 or 2001 - portArray := strings.Split(portRangeStr, "-") - // length: only 1 or 2 is correct - rangeType := len(portArray) - if rangeType == 1 { - singlePort, err := strconv.ParseInt(portArray[0], 10, 64) - if err != nil { - return [][2]int64{}, err - } - portRanges = append(portRanges, [2]int64{singlePort, singlePort}) - } else if rangeType == 2 { - min, err := strconv.ParseInt(portArray[0], 10, 64) - if err != nil { - return [][2]int64{}, err - } - max, err := strconv.ParseInt(portArray[1], 10, 64) - if err != nil { - return [][2]int64{}, err - } - if max < min { - return [][2]int64{}, fmt.Errorf("range incorrect") - } - portRanges = append(portRanges, [2]int64{min, max}) - } else { - return [][2]int64{}, fmt.Errorf("format error") - } - } - return portRanges, nil -} - -func ContainsPort(portRanges [][2]int64, port int64) bool { - for _, pr := range portRanges { - if port >= pr[0] && port <= pr[1] { - return true - } - } - return false -} - -func PortRangesCut(portRanges [][2]int64, port int64) [][2]int64 { - var tmpRanges [][2]int64 - for _, pr := range portRanges { - if port >= pr[0] && port <= pr[1] { - leftRange := [2]int64{pr[0], port - 1} - rightRange := [2]int64{port + 1, pr[1]} - if leftRange[0] <= leftRange[1] { - tmpRanges = append(tmpRanges, leftRange) - } - if rightRange[0] <= rightRange[1] { - tmpRanges = append(tmpRanges, rightRange) - } - } else { - tmpRanges = append(tmpRanges, pr) - } - } - return tmpRanges -} - func CanonicalAddr(host string, port int) (addr string) { if port == 80 || port == 443 { addr = host diff --git a/utils/util/util_test.go b/utils/util/util_test.go index 17d77547..8210c613 100644 --- a/utils/util/util_test.go +++ b/utils/util/util_test.go @@ -20,67 +20,3 @@ func TestGetAuthKey(t *testing.T) { t.Log(key) assert.Equal("6df41a43725f0c770fd56379e12acf8c", key) } - -func TestGetPortRanges(t *testing.T) { - assert := assert.New(t) - - rangesStr := "2000-3000,3001,4000-50000" - expect := [][2]int64{ - [2]int64{2000, 3000}, - [2]int64{3001, 3001}, - [2]int64{4000, 50000}, - } - actual, err := GetPortRanges(rangesStr) - assert.Nil(err) - t.Log(actual) - assert.Equal(expect, actual) -} - -func TestContainsPort(t *testing.T) { - assert := assert.New(t) - - rangesStr := "2000-3000,3001,4000-50000" - portRanges, err := GetPortRanges(rangesStr) - assert.Nil(err) - - type Case struct { - Port int64 - Answer bool - } - cases := []Case{ - Case{ - Port: 3001, - Answer: true, - }, - Case{ - Port: 3002, - Answer: false, - }, - Case{ - Port: 44444, - Answer: true, - }, - } - for _, elem := range cases { - ok := ContainsPort(portRanges, elem.Port) - assert.Equal(elem.Answer, ok) - } -} - -func TestPortRangesCut(t *testing.T) { - assert := assert.New(t) - - rangesStr := "2000-3000,3001,4000-50000" - portRanges, err := GetPortRanges(rangesStr) - assert.Nil(err) - - expect := [][2]int64{ - [2]int64{2000, 3000}, - [2]int64{3001, 3001}, - [2]int64{4000, 44443}, - [2]int64{44445, 50000}, - } - actual := PortRangesCut(portRanges, 44444) - t.Log(actual) - assert.Equal(expect, actual) -}