Merge pull request #603 from fatedier/test

add test cases and new feature assgin a random port if remote_port is 0
This commit is contained in:
fatedier 2018-01-17 22:45:02 +08:00 committed by GitHub
commit 5b08201e5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
45 changed files with 2077 additions and 701 deletions

View File

@ -3,7 +3,7 @@ language: go
go: go:
- 1.8.x - 1.8.x
- 1.x - 1.9.x
install: install:
- make - make

View File

@ -31,7 +31,7 @@ var (
httpServerWriteTimeout = 10 * time.Second httpServerWriteTimeout = 10 * time.Second
) )
func (svr *Service) RunAdminServer(addr string, port int64) (err error) { func (svr *Service) RunAdminServer(addr string, port int) (err error) {
// url router // url router
router := httprouter.New() router := httprouter.New()
@ -39,6 +39,7 @@ func (svr *Service) RunAdminServer(addr string, port int64) (err error) {
// api, see dashboard_api.go // api, see dashboard_api.go
router.GET("/api/reload", frpNet.HttprouterBasicAuth(svr.apiReload, user, passwd)) router.GET("/api/reload", frpNet.HttprouterBasicAuth(svr.apiReload, user, passwd))
router.GET("/api/status", frpNet.HttprouterBasicAuth(svr.apiStatus, user, passwd))
address := fmt.Sprintf("%s:%d", addr, port) address := fmt.Sprintf("%s:%d", addr, port)
server := &http.Server{ server := &http.Server{

View File

@ -16,7 +16,10 @@ package client
import ( import (
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
"sort"
"strings"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
ini "github.com/vaughan0/go-ini" ini "github.com/vaughan0/go-ini"
@ -72,7 +75,137 @@ func (svr *Service) apiReload(w http.ResponseWriter, r *http.Request, _ httprout
return return
} }
svr.ctl.reloadConf(pxyCfgs, visitorCfgs) err = svr.ctl.reloadConf(pxyCfgs, visitorCfgs)
if err != nil {
res.Code = 4
res.Msg = err.Error()
log.Error("reload frpc proxy config error: %v", err)
return
}
log.Info("success reload conf") log.Info("success reload conf")
return return
} }
type StatusResp struct {
Tcp []ProxyStatusResp `json:"tcp"`
Udp []ProxyStatusResp `json:"udp"`
Http []ProxyStatusResp `json:"http"`
Https []ProxyStatusResp `json:"https"`
Stcp []ProxyStatusResp `json:"stcp"`
Xtcp []ProxyStatusResp `json:"xtcp"`
}
type ProxyStatusResp struct {
Name string `json:"name"`
Type string `json:"type"`
Status string `json:"status"`
Err string `json:"err"`
LocalAddr string `json:"local_addr"`
Plugin string `json:"plugin"`
RemoteAddr string `json:"remote_addr"`
}
type ByProxyStatusResp []ProxyStatusResp
func (a ByProxyStatusResp) Len() int { return len(a) }
func (a ByProxyStatusResp) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a ByProxyStatusResp) Less(i, j int) bool { return strings.Compare(a[i].Name, a[j].Name) < 0 }
func NewProxyStatusResp(status *ProxyStatus) ProxyStatusResp {
psr := ProxyStatusResp{
Name: status.Name,
Type: status.Type,
Status: status.Status,
Err: status.Err,
}
switch cfg := status.Cfg.(type) {
case *config.TcpProxyConf:
if cfg.LocalPort != 0 {
psr.LocalAddr = fmt.Sprintf("%s:%d", cfg.LocalIp, cfg.LocalPort)
}
psr.Plugin = cfg.Plugin
if status.Err != "" {
psr.RemoteAddr = fmt.Sprintf("%s:%d", config.ClientCommonCfg.ServerAddr, cfg.RemotePort)
} else {
psr.RemoteAddr = config.ClientCommonCfg.ServerAddr + status.RemoteAddr
}
case *config.UdpProxyConf:
if cfg.LocalPort != 0 {
psr.LocalAddr = fmt.Sprintf("%s:%d", cfg.LocalIp, cfg.LocalPort)
}
if status.Err != "" {
psr.RemoteAddr = fmt.Sprintf("%s:%d", config.ClientCommonCfg.ServerAddr, cfg.RemotePort)
} else {
psr.RemoteAddr = config.ClientCommonCfg.ServerAddr + status.RemoteAddr
}
case *config.HttpProxyConf:
if cfg.LocalPort != 0 {
psr.LocalAddr = fmt.Sprintf("%s:%d", cfg.LocalIp, cfg.LocalPort)
}
psr.Plugin = cfg.Plugin
psr.RemoteAddr = status.RemoteAddr
case *config.HttpsProxyConf:
if cfg.LocalPort != 0 {
psr.LocalAddr = fmt.Sprintf("%s:%d", cfg.LocalIp, cfg.LocalPort)
}
psr.Plugin = cfg.Plugin
psr.RemoteAddr = status.RemoteAddr
case *config.StcpProxyConf:
if cfg.LocalPort != 0 {
psr.LocalAddr = fmt.Sprintf("%s:%d", cfg.LocalIp, cfg.LocalPort)
}
psr.Plugin = cfg.Plugin
case *config.XtcpProxyConf:
if cfg.LocalPort != 0 {
psr.LocalAddr = fmt.Sprintf("%s:%d", cfg.LocalIp, cfg.LocalPort)
}
psr.Plugin = cfg.Plugin
}
return psr
}
// api/status
func (svr *Service) apiStatus(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
var (
buf []byte
res StatusResp
)
res.Tcp = make([]ProxyStatusResp, 0)
res.Udp = make([]ProxyStatusResp, 0)
res.Http = make([]ProxyStatusResp, 0)
res.Https = make([]ProxyStatusResp, 0)
res.Stcp = make([]ProxyStatusResp, 0)
res.Xtcp = make([]ProxyStatusResp, 0)
defer func() {
log.Info("Http response [/api/status]")
buf, _ = json.Marshal(&res)
w.Write(buf)
}()
log.Info("Http request: [/api/status]")
ps := svr.ctl.pm.GetAllProxyStatus()
for _, status := range ps {
switch status.Type {
case "tcp":
res.Tcp = append(res.Tcp, NewProxyStatusResp(status))
case "udp":
res.Udp = append(res.Udp, NewProxyStatusResp(status))
case "http":
res.Http = append(res.Http, NewProxyStatusResp(status))
case "https":
res.Https = append(res.Https, NewProxyStatusResp(status))
case "stcp":
res.Stcp = append(res.Stcp, NewProxyStatusResp(status))
case "xtcp":
res.Xtcp = append(res.Xtcp, NewProxyStatusResp(status))
}
}
sort.Sort(ByProxyStatusResp(res.Tcp))
sort.Sort(ByProxyStatusResp(res.Udp))
sort.Sort(ByProxyStatusResp(res.Http))
sort.Sort(ByProxyStatusResp(res.Https))
sort.Sort(ByProxyStatusResp(res.Stcp))
sort.Sort(ByProxyStatusResp(res.Xtcp))
return
}

View File

@ -24,9 +24,9 @@ import (
"github.com/fatedier/frp/models/config" "github.com/fatedier/frp/models/config"
"github.com/fatedier/frp/models/msg" "github.com/fatedier/frp/models/msg"
"github.com/fatedier/frp/utils/crypto" "github.com/fatedier/frp/utils/crypto"
"github.com/fatedier/frp/utils/errors"
"github.com/fatedier/frp/utils/log" "github.com/fatedier/frp/utils/log"
frpNet "github.com/fatedier/frp/utils/net" frpNet "github.com/fatedier/frp/utils/net"
"github.com/fatedier/frp/utils/shutdown"
"github.com/fatedier/frp/utils/util" "github.com/fatedier/frp/utils/util"
"github.com/fatedier/frp/utils/version" "github.com/fatedier/frp/utils/version"
"github.com/xtaci/smux" "github.com/xtaci/smux"
@ -40,20 +40,10 @@ type Control struct {
// frpc service // frpc service
svr *Service svr *Service
// login message to server // login message to server, only used
loginMsg *msg.Login loginMsg *msg.Login
// proxy configures pm *ProxyManager
pxyCfgs map[string]config.ProxyConf
// proxies
proxies map[string]Proxy
// visitor configures
visitorCfgs map[string]config.ProxyConf
// visitors
visitors map[string]Visitor
// control connection // control connection
conn frpNet.Conn conn frpNet.Conn
@ -79,6 +69,10 @@ type Control struct {
// last time got the Pong message // last time got the Pong message
lastPong time.Time lastPong time.Time
readerShutdown *shutdown.Shutdown
writerShutdown *shutdown.Shutdown
msgHandlerShutdown *shutdown.Shutdown
mu sync.RWMutex mu sync.RWMutex
log.Logger log.Logger
@ -92,28 +86,22 @@ func NewControl(svr *Service, pxyCfgs map[string]config.ProxyConf, visitorCfgs m
User: config.ClientCommonCfg.User, User: config.ClientCommonCfg.User,
Version: version.Full(), Version: version.Full(),
} }
return &Control{ ctl := &Control{
svr: svr, svr: svr,
loginMsg: loginMsg, loginMsg: loginMsg,
pxyCfgs: pxyCfgs,
visitorCfgs: visitorCfgs,
proxies: make(map[string]Proxy),
visitors: make(map[string]Visitor),
sendCh: make(chan msg.Message, 10), sendCh: make(chan msg.Message, 10),
readCh: make(chan msg.Message, 10), readCh: make(chan msg.Message, 10),
closedCh: make(chan int), closedCh: make(chan int),
readerShutdown: shutdown.New(),
writerShutdown: shutdown.New(),
msgHandlerShutdown: shutdown.New(),
Logger: log.NewPrefixLogger(""), Logger: log.NewPrefixLogger(""),
} }
ctl.pm = NewProxyManager(ctl, ctl.sendCh, "")
ctl.pm.Reload(pxyCfgs, visitorCfgs)
return ctl
} }
// 1. login
// 2. start reader() writer() manager()
// 3. connection closed
// 4. In reader(): close closedCh and exit, controler() get it
// 5. In controler(): close readCh and sendCh, manager() and writer() will exit
// 6. In controler(): ini readCh, sendCh, closedCh
// 7. In controler(): start new reader(), writer(), manager()
// controler() will keep running
func (ctl *Control) Run() (err error) { func (ctl *Control) Run() (err error) {
for { for {
err = ctl.login() err = ctl.login()
@ -125,47 +113,29 @@ func (ctl *Control) Run() (err error) {
if config.ClientCommonCfg.LoginFailExit { if config.ClientCommonCfg.LoginFailExit {
return return
} else { } else {
time.Sleep(30 * time.Second) time.Sleep(10 * time.Second)
} }
} else { } else {
break break
} }
} }
go ctl.controler() go ctl.worker()
go ctl.manager()
go ctl.writer()
go ctl.reader()
// start all local visitors // start all local visitors and send NewProxy message for all configured proxies
for _, cfg := range ctl.visitorCfgs { ctl.pm.Reset(ctl.sendCh, ctl.runId)
visitor := NewVisitor(ctl, cfg) ctl.pm.CheckAndStartProxy()
err = visitor.Run()
if err != nil {
visitor.Warn("start error: %v", err)
continue
}
ctl.visitors[cfg.GetName()] = visitor
visitor.Info("start visitor success")
}
// send NewProxy message for all configured proxies
for _, cfg := range ctl.pxyCfgs {
var newProxyMsg msg.NewProxy
cfg.UnMarshalToMsg(&newProxyMsg)
ctl.sendCh <- &newProxyMsg
}
return nil return nil
} }
func (ctl *Control) NewWorkConn() { func (ctl *Control) HandleReqWorkConn(inMsg *msg.ReqWorkConn) {
workConn, err := ctl.connectServer() workConn, err := ctl.connectServer()
if err != nil { if err != nil {
return return
} }
m := &msg.NewWorkConn{ m := &msg.NewWorkConn{
RunId: ctl.getRunId(), RunId: ctl.runId,
} }
if err = msg.WriteMsg(workConn, m); err != nil { if err = msg.WriteMsg(workConn, m); err != nil {
ctl.Warn("work connection write to server error: %v", err) ctl.Warn("work connection write to server error: %v", err)
@ -182,33 +152,26 @@ func (ctl *Control) NewWorkConn() {
workConn.AddLogPrefix(startMsg.ProxyName) workConn.AddLogPrefix(startMsg.ProxyName)
// dispatch this work connection to related proxy // dispatch this work connection to related proxy
pxy, ok := ctl.getProxy(startMsg.ProxyName) ctl.pm.HandleWorkConn(startMsg.ProxyName, workConn)
if ok { }
workConn.Debug("start a new work connection, localAddr: %s remoteAddr: %s", workConn.LocalAddr().String(), workConn.RemoteAddr().String())
go pxy.InWorkConn(workConn) func (ctl *Control) HandleNewProxyResp(inMsg *msg.NewProxyResp) {
// Server will return NewProxyResp message to each NewProxy message.
// Start a new proxy handler if no error got
err := ctl.pm.StartProxy(inMsg.ProxyName, inMsg.RemoteAddr, inMsg.Error)
if err != nil {
ctl.Warn("[%s] start error: %v", inMsg.ProxyName, err)
} else { } else {
workConn.Close() ctl.Info("[%s] start proxy success", inMsg.ProxyName)
} }
} }
func (ctl *Control) Close() error { func (ctl *Control) Close() error {
ctl.mu.Lock() ctl.mu.Lock()
defer ctl.mu.Unlock()
ctl.exit = true ctl.exit = true
err := errors.PanicToError(func() { ctl.pm.CloseProxies()
for name, _ := range ctl.proxies { return nil
ctl.sendCh <- &msg.CloseProxy{
ProxyName: name,
}
}
})
ctl.mu.Unlock()
return err
}
func (ctl *Control) init() {
ctl.sendCh = make(chan msg.Message, 10)
ctl.readCh = make(chan msg.Message, 10)
ctl.closedCh = make(chan int)
} }
// login send a login message to server and wait for a loginResp message. // login send a login message to server and wait for a loginResp message.
@ -249,7 +212,7 @@ func (ctl *Control) login() (err error) {
now := time.Now().Unix() now := time.Now().Unix()
ctl.loginMsg.PrivilegeKey = util.GetAuthKey(config.ClientCommonCfg.PrivilegeToken, now) ctl.loginMsg.PrivilegeKey = util.GetAuthKey(config.ClientCommonCfg.PrivilegeToken, now)
ctl.loginMsg.Timestamp = now ctl.loginMsg.Timestamp = now
ctl.loginMsg.RunId = ctl.getRunId() ctl.loginMsg.RunId = ctl.runId
if err = msg.WriteMsg(conn, ctl.loginMsg); err != nil { if err = msg.WriteMsg(conn, ctl.loginMsg); err != nil {
return err return err
@ -270,16 +233,11 @@ func (ctl *Control) login() (err error) {
ctl.conn = conn ctl.conn = conn
// update runId got from server // update runId got from server
ctl.setRunId(loginRespMsg.RunId) ctl.runId = loginRespMsg.RunId
config.ClientCommonCfg.ServerUdpPort = loginRespMsg.ServerUdpPort config.ClientCommonCfg.ServerUdpPort = loginRespMsg.ServerUdpPort
ctl.ClearLogPrefix() ctl.ClearLogPrefix()
ctl.AddLogPrefix(loginRespMsg.RunId) ctl.AddLogPrefix(loginRespMsg.RunId)
ctl.Info("login to server success, get run id [%s], server udp port [%d]", loginRespMsg.RunId, loginRespMsg.ServerUdpPort) ctl.Info("login to server success, get run id [%s], server udp port [%d]", loginRespMsg.RunId, loginRespMsg.ServerUdpPort)
// login success, so we let closedCh available again
ctl.closedCh = make(chan int)
ctl.lastPong = time.Now()
return nil return nil
} }
@ -292,7 +250,6 @@ func (ctl *Control) connectServer() (conn frpNet.Conn, err error) {
return return
} }
conn = frpNet.WrapConn(stream) conn = frpNet.WrapConn(stream)
} else { } else {
conn, err = frpNet.ConnectServerByHttpProxy(config.ClientCommonCfg.HttpProxy, config.ClientCommonCfg.Protocol, conn, err = frpNet.ConnectServerByHttpProxy(config.ClientCommonCfg.HttpProxy, config.ClientCommonCfg.Protocol,
fmt.Sprintf("%s:%d", config.ClientCommonCfg.ServerAddr, config.ClientCommonCfg.ServerPort)) fmt.Sprintf("%s:%d", config.ClientCommonCfg.ServerAddr, config.ClientCommonCfg.ServerPort))
@ -304,12 +261,14 @@ func (ctl *Control) connectServer() (conn frpNet.Conn, err error) {
return return
} }
// reader read all messages from frps and send to readCh
func (ctl *Control) reader() { func (ctl *Control) reader() {
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
ctl.Error("panic error: %v", err) ctl.Error("panic error: %v", err)
} }
}() }()
defer ctl.readerShutdown.Done()
defer close(ctl.closedCh) defer close(ctl.closedCh)
encReader := crypto.NewReader(ctl.conn, []byte(config.ClientCommonCfg.PrivilegeToken)) encReader := crypto.NewReader(ctl.conn, []byte(config.ClientCommonCfg.PrivilegeToken))
@ -328,7 +287,9 @@ func (ctl *Control) reader() {
} }
} }
// writer writes messages got from sendCh to frps
func (ctl *Control) writer() { func (ctl *Control) writer() {
defer ctl.writerShutdown.Done()
encWriter, err := crypto.NewWriter(ctl.conn, []byte(config.ClientCommonCfg.PrivilegeToken)) encWriter, err := crypto.NewWriter(ctl.conn, []byte(config.ClientCommonCfg.PrivilegeToken))
if err != nil { if err != nil {
ctl.conn.Error("crypto new writer error: %v", err) ctl.conn.Error("crypto new writer error: %v", err)
@ -348,19 +309,22 @@ func (ctl *Control) writer() {
} }
} }
// manager handles all channel events and do corresponding process // msgHandler handles all channel events and do corresponding operations.
func (ctl *Control) manager() { func (ctl *Control) msgHandler() {
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
ctl.Error("panic error: %v", err) ctl.Error("panic error: %v", err)
} }
}() }()
defer ctl.msgHandlerShutdown.Done()
hbSend := time.NewTicker(time.Duration(config.ClientCommonCfg.HeartBeatInterval) * time.Second) hbSend := time.NewTicker(time.Duration(config.ClientCommonCfg.HeartBeatInterval) * time.Second)
defer hbSend.Stop() defer hbSend.Stop()
hbCheck := time.NewTicker(time.Second) hbCheck := time.NewTicker(time.Second)
defer hbCheck.Stop() defer hbCheck.Stop()
ctl.lastPong = time.Now()
for { for {
select { select {
case <-hbSend.C: case <-hbSend.C:
@ -381,35 +345,9 @@ func (ctl *Control) manager() {
switch m := rawMsg.(type) { switch m := rawMsg.(type) {
case *msg.ReqWorkConn: case *msg.ReqWorkConn:
go ctl.NewWorkConn() go ctl.HandleReqWorkConn(m)
case *msg.NewProxyResp: case *msg.NewProxyResp:
// Server will return NewProxyResp message to each NewProxy message. ctl.HandleNewProxyResp(m)
// Start a new proxy handler if no error got
if m.Error != "" {
ctl.Warn("[%s] start error: %s", m.ProxyName, m.Error)
continue
}
cfg, ok := ctl.getProxyConf(m.ProxyName)
if !ok {
// it will never go to this branch now
ctl.Warn("[%s] no proxy conf found", m.ProxyName)
continue
}
oldPxy, ok := ctl.getProxy(m.ProxyName)
if ok {
oldPxy.Close()
}
pxy := NewProxy(ctl, cfg)
if err := pxy.Run(); err != nil {
ctl.Warn("[%s] proxy start running error: %v", m.ProxyName, err)
ctl.sendCh <- &msg.CloseProxy{
ProxyName: m.ProxyName,
}
continue
}
ctl.addProxy(m.ProxyName, pxy)
ctl.Info("[%s] start proxy success", m.ProxyName)
case *msg.Pong: case *msg.Pong:
ctl.lastPong = time.Now() ctl.lastPong = time.Now()
ctl.Debug("receive heartbeat from server") ctl.Debug("receive heartbeat from server")
@ -419,10 +357,14 @@ func (ctl *Control) manager() {
} }
// controler keep watching closedCh, start a new connection if previous control connection is closed. // controler keep watching closedCh, start a new connection if previous control connection is closed.
// If controler is notified by closedCh, reader and writer and manager will exit, then recall these functions. // If controler is notified by closedCh, reader and writer and handler will exit, then recall these functions.
func (ctl *Control) controler() { func (ctl *Control) worker() {
go ctl.msgHandler()
go ctl.writer()
go ctl.reader()
var err error var err error
maxDelayTime := 30 * time.Second maxDelayTime := 20 * time.Second
delayTime := time.Second delayTime := time.Second
checkInterval := 10 * time.Second checkInterval := 10 * time.Second
@ -430,41 +372,20 @@ func (ctl *Control) controler() {
for { for {
select { select {
case <-checkProxyTicker.C: case <-checkProxyTicker.C:
// Every 10 seconds, check which proxy registered failed and reregister it to server. // every 10 seconds, check which proxy registered failed and reregister it to server
ctl.mu.RLock() ctl.pm.CheckAndStartProxy()
for _, cfg := range ctl.pxyCfgs {
if _, exist := ctl.proxies[cfg.GetName()]; !exist {
ctl.Info("try to register proxy [%s]", cfg.GetName())
var newProxyMsg msg.NewProxy
cfg.UnMarshalToMsg(&newProxyMsg)
ctl.sendCh <- &newProxyMsg
}
}
for _, cfg := range ctl.visitorCfgs {
if _, exist := ctl.visitors[cfg.GetName()]; !exist {
ctl.Info("try to start visitor [%s]", cfg.GetName())
visitor := NewVisitor(ctl, cfg)
err = visitor.Run()
if err != nil {
visitor.Warn("start error: %v", err)
continue
}
ctl.visitors[cfg.GetName()] = visitor
visitor.Info("start visitor success")
}
}
ctl.mu.RUnlock()
case _, ok := <-ctl.closedCh: case _, ok := <-ctl.closedCh:
// we won't get any variable from this channel // we won't get any variable from this channel
if !ok { if !ok {
// close related channels // close related channels and wait until other goroutines done
close(ctl.readCh) close(ctl.readCh)
close(ctl.sendCh) ctl.readerShutdown.WaitDone()
ctl.msgHandlerShutdown.WaitDone()
for _, pxy := range ctl.proxies { close(ctl.sendCh)
pxy.Close() ctl.writerShutdown.WaitDone()
}
ctl.pm.CloseProxies()
// if ctl.exit is true, just exit // if ctl.exit is true, just exit
ctl.mu.RLock() ctl.mu.RLock()
exit := ctl.exit exit := ctl.exit
@ -473,9 +394,7 @@ func (ctl *Control) controler() {
return return
} }
time.Sleep(time.Second) // loop util reconnecting to server success
// loop util reconnect to server success
for { for {
ctl.Info("try to reconnect to server...") ctl.Info("try to reconnect to server...")
err = ctl.login() err = ctl.login()
@ -488,27 +407,27 @@ func (ctl *Control) controler() {
} }
continue continue
} }
// reconnect success, init the delayTime // reconnect success, init delayTime
delayTime = time.Second delayTime = time.Second
break break
} }
// init related channels and variables // init related channels and variables
ctl.init() ctl.sendCh = make(chan msg.Message, 10)
ctl.readCh = make(chan msg.Message, 10)
ctl.closedCh = make(chan int)
ctl.readerShutdown = shutdown.New()
ctl.writerShutdown = shutdown.New()
ctl.msgHandlerShutdown = shutdown.New()
ctl.pm.Reset(ctl.sendCh, ctl.runId)
// previous work goroutines should be closed and start them here // previous work goroutines should be closed and start them here
go ctl.manager() go ctl.msgHandler()
go ctl.writer() go ctl.writer()
go ctl.reader() go ctl.reader()
// send NewProxy message for all configured proxies // start all configured proxies
ctl.mu.RLock() ctl.pm.CheckAndStartProxy()
for _, cfg := range ctl.pxyCfgs {
var newProxyMsg msg.NewProxy
cfg.UnMarshalToMsg(&newProxyMsg)
ctl.sendCh <- &newProxyMsg
}
ctl.mu.RUnlock()
checkProxyTicker.Stop() checkProxyTicker.Stop()
checkProxyTicker = time.NewTicker(checkInterval) checkProxyTicker = time.NewTicker(checkInterval)
@ -517,106 +436,7 @@ func (ctl *Control) controler() {
} }
} }
func (ctl *Control) setRunId(runId string) { func (ctl *Control) reloadConf(pxyCfgs map[string]config.ProxyConf, visitorCfgs map[string]config.ProxyConf) error {
ctl.mu.Lock() err := ctl.pm.Reload(pxyCfgs, visitorCfgs)
defer ctl.mu.Unlock() return err
ctl.runId = runId
}
func (ctl *Control) getRunId() string {
ctl.mu.RLock()
defer ctl.mu.RUnlock()
return ctl.runId
}
func (ctl *Control) getProxy(name string) (pxy Proxy, ok bool) {
ctl.mu.RLock()
defer ctl.mu.RUnlock()
pxy, ok = ctl.proxies[name]
return
}
func (ctl *Control) addProxy(name string, pxy Proxy) {
ctl.mu.Lock()
defer ctl.mu.Unlock()
ctl.proxies[name] = pxy
}
func (ctl *Control) getProxyConf(name string) (conf config.ProxyConf, ok bool) {
ctl.mu.RLock()
defer ctl.mu.RUnlock()
conf, ok = ctl.pxyCfgs[name]
return
}
func (ctl *Control) reloadConf(pxyCfgs map[string]config.ProxyConf, visitorCfgs map[string]config.ProxyConf) {
ctl.mu.Lock()
defer ctl.mu.Unlock()
removedPxyNames := make([]string, 0)
for name, oldCfg := range ctl.pxyCfgs {
del := false
cfg, ok := pxyCfgs[name]
if !ok {
del = true
} else {
if !oldCfg.Compare(cfg) {
del = true
}
}
if del {
removedPxyNames = append(removedPxyNames, name)
delete(ctl.pxyCfgs, name)
if pxy, ok := ctl.proxies[name]; ok {
pxy.Close()
}
delete(ctl.proxies, name)
ctl.sendCh <- &msg.CloseProxy{
ProxyName: name,
}
}
}
ctl.Info("proxy removed: %v", removedPxyNames)
addedPxyNames := make([]string, 0)
for name, cfg := range pxyCfgs {
if _, ok := ctl.pxyCfgs[name]; !ok {
ctl.pxyCfgs[name] = cfg
addedPxyNames = append(addedPxyNames, name)
}
}
ctl.Info("proxy added: %v", addedPxyNames)
removedVisitorName := make([]string, 0)
for name, oldVisitorCfg := range ctl.visitorCfgs {
del := false
cfg, ok := visitorCfgs[name]
if !ok {
del = true
} else {
if !oldVisitorCfg.Compare(cfg) {
del = true
}
}
if del {
removedVisitorName = append(removedVisitorName, name)
delete(ctl.visitorCfgs, name)
if visitor, ok := ctl.visitors[name]; ok {
visitor.Close()
}
delete(ctl.visitors, name)
}
}
ctl.Info("visitor removed: %v", removedVisitorName)
addedVisitorName := make([]string, 0)
for name, visitorCfg := range visitorCfgs {
if _, ok := ctl.visitorCfgs[name]; !ok {
ctl.visitorCfgs[name] = visitorCfg
addedVisitorName = append(addedVisitorName, name)
}
}
ctl.Info("visitor added: %v", addedVisitorName)
} }

View File

@ -39,13 +39,13 @@ type Proxy interface {
// InWorkConn accept work connections registered to server. // InWorkConn accept work connections registered to server.
InWorkConn(conn frpNet.Conn) InWorkConn(conn frpNet.Conn)
Close() Close()
log.Logger log.Logger
} }
func NewProxy(ctl *Control, pxyConf config.ProxyConf) (pxy Proxy) { func NewProxy(pxyConf config.ProxyConf) (pxy Proxy) {
baseProxy := BaseProxy{ baseProxy := BaseProxy{
ctl: ctl,
Logger: log.NewPrefixLogger(pxyConf.GetName()), Logger: log.NewPrefixLogger(pxyConf.GetName()),
} }
switch cfg := pxyConf.(type) { switch cfg := pxyConf.(type) {
@ -84,7 +84,6 @@ func NewProxy(ctl *Control, pxyConf config.ProxyConf) (pxy Proxy) {
} }
type BaseProxy struct { type BaseProxy struct {
ctl *Control
closed bool closed bool
mu sync.RWMutex mu sync.RWMutex
log.Logger log.Logger

349
client/proxy_manager.go Normal file
View File

@ -0,0 +1,349 @@
package client
import (
"fmt"
"sync"
"github.com/fatedier/frp/models/config"
"github.com/fatedier/frp/models/msg"
"github.com/fatedier/frp/utils/errors"
"github.com/fatedier/frp/utils/log"
frpNet "github.com/fatedier/frp/utils/net"
)
const (
ProxyStatusNew = "new"
ProxyStatusStartErr = "start error"
ProxyStatusRunning = "running"
ProxyStatusClosed = "closed"
)
type ProxyManager struct {
ctl *Control
proxies map[string]*ProxyWrapper
visitorCfgs map[string]config.ProxyConf
visitors map[string]Visitor
sendCh chan (msg.Message)
closed bool
mu sync.RWMutex
log.Logger
}
type ProxyWrapper struct {
Name string
Type string
Status string
Err string
Cfg config.ProxyConf
RemoteAddr string
pxy Proxy
mu sync.RWMutex
}
type ProxyStatus struct {
Name string `json:"name"`
Type string `json:"type"`
Status string `json:"status"`
Err string `json:"err"`
Cfg config.ProxyConf `json:"cfg"`
// Got from server.
RemoteAddr string `json:"remote_addr"`
}
func NewProxyWrapper(cfg config.ProxyConf) *ProxyWrapper {
return &ProxyWrapper{
Name: cfg.GetName(),
Type: cfg.GetType(),
Status: ProxyStatusNew,
Cfg: cfg,
pxy: nil,
}
}
func (pw *ProxyWrapper) IsRunning() bool {
pw.mu.RLock()
defer pw.mu.RUnlock()
if pw.Status == ProxyStatusRunning {
return true
} else {
return false
}
}
func (pw *ProxyWrapper) GetStatus() *ProxyStatus {
pw.mu.RLock()
defer pw.mu.RUnlock()
ps := &ProxyStatus{
Name: pw.Name,
Type: pw.Type,
Status: pw.Status,
Err: pw.Err,
Cfg: pw.Cfg,
RemoteAddr: pw.RemoteAddr,
}
return ps
}
func (pw *ProxyWrapper) Start(remoteAddr string, serverRespErr string) error {
if pw.pxy != nil {
pw.pxy.Close()
pw.pxy = nil
}
if serverRespErr != "" {
pw.mu.Lock()
pw.Status = ProxyStatusStartErr
pw.RemoteAddr = remoteAddr
pw.Err = serverRespErr
pw.mu.Unlock()
return fmt.Errorf(serverRespErr)
}
pxy := NewProxy(pw.Cfg)
pw.mu.Lock()
defer pw.mu.Unlock()
pw.RemoteAddr = remoteAddr
if err := pxy.Run(); err != nil {
pw.Status = ProxyStatusStartErr
pw.Err = err.Error()
return err
}
pw.Status = ProxyStatusRunning
pw.Err = ""
pw.pxy = pxy
return nil
}
func (pw *ProxyWrapper) InWorkConn(workConn frpNet.Conn) {
pw.mu.RLock()
pxy := pw.pxy
pw.mu.RUnlock()
if pxy != nil {
workConn.Debug("start a new work connection, localAddr: %s remoteAddr: %s", workConn.LocalAddr().String(), workConn.RemoteAddr().String())
go pxy.InWorkConn(workConn)
} else {
workConn.Close()
}
}
func (pw *ProxyWrapper) Close() {
pw.mu.Lock()
defer pw.mu.Unlock()
if pw.pxy != nil {
pw.pxy.Close()
pw.pxy = nil
}
pw.Status = ProxyStatusClosed
}
func NewProxyManager(ctl *Control, msgSendCh chan (msg.Message), logPrefix string) *ProxyManager {
return &ProxyManager{
ctl: ctl,
proxies: make(map[string]*ProxyWrapper),
visitorCfgs: make(map[string]config.ProxyConf),
visitors: make(map[string]Visitor),
sendCh: msgSendCh,
closed: false,
Logger: log.NewPrefixLogger(logPrefix),
}
}
func (pm *ProxyManager) Reset(msgSendCh chan (msg.Message), logPrefix string) {
pm.mu.Lock()
defer pm.mu.Unlock()
pm.closed = false
pm.sendCh = msgSendCh
pm.ClearLogPrefix()
pm.AddLogPrefix(logPrefix)
}
// Must hold the lock before calling this function.
func (pm *ProxyManager) sendMsg(m msg.Message) error {
err := errors.PanicToError(func() {
pm.sendCh <- m
})
if err != nil {
pm.closed = true
}
return err
}
func (pm *ProxyManager) StartProxy(name string, remoteAddr string, serverRespErr string) error {
pm.mu.Lock()
defer pm.mu.Unlock()
if pm.closed {
return fmt.Errorf("ProxyManager is closed now")
}
pxy, ok := pm.proxies[name]
if !ok {
return fmt.Errorf("no proxy found")
}
if err := pxy.Start(remoteAddr, serverRespErr); err != nil {
errRet := err
err = pm.sendMsg(&msg.CloseProxy{
ProxyName: name,
})
if err != nil {
errRet = fmt.Errorf("send CloseProxy message error")
}
return errRet
}
return nil
}
func (pm *ProxyManager) CloseProxies() {
pm.mu.RLock()
defer pm.mu.RUnlock()
for _, pxy := range pm.proxies {
pxy.Close()
}
}
func (pm *ProxyManager) CheckAndStartProxy() {
pm.mu.RLock()
defer pm.mu.RUnlock()
if pm.closed {
pm.Warn("CheckAndStartProxy error: ProxyManager is closed now")
return
}
for _, pxy := range pm.proxies {
if !pxy.IsRunning() {
var newProxyMsg msg.NewProxy
pxy.Cfg.UnMarshalToMsg(&newProxyMsg)
err := pm.sendMsg(&newProxyMsg)
if err != nil {
pm.Warn("[%s] proxy send NewProxy message error")
return
}
}
}
for _, cfg := range pm.visitorCfgs {
if _, exist := pm.visitors[cfg.GetName()]; !exist {
pm.Info("try to start visitor [%s]", cfg.GetName())
visitor := NewVisitor(pm.ctl, cfg)
err := visitor.Run()
if err != nil {
visitor.Warn("start error: %v", err)
continue
}
pm.visitors[cfg.GetName()] = visitor
visitor.Info("start visitor success")
}
}
}
func (pm *ProxyManager) Reload(pxyCfgs map[string]config.ProxyConf, visitorCfgs map[string]config.ProxyConf) error {
pm.mu.Lock()
defer pm.mu.Unlock()
if pm.closed {
err := fmt.Errorf("Reload error: ProxyManager is closed now")
pm.Warn(err.Error())
return err
}
delPxyNames := make([]string, 0)
for name, pxy := range pm.proxies {
del := false
cfg, ok := pxyCfgs[name]
if !ok {
del = true
} else {
if !pxy.Cfg.Compare(cfg) {
del = true
}
}
if del {
delPxyNames = append(delPxyNames, name)
delete(pm.proxies, name)
pxy.Close()
err := pm.sendMsg(&msg.CloseProxy{
ProxyName: name,
})
if err != nil {
err = fmt.Errorf("Reload error: ProxyManager is closed now")
pm.Warn(err.Error())
return err
}
}
}
pm.Info("proxy removed: %v", delPxyNames)
addPxyNames := make([]string, 0)
for name, cfg := range pxyCfgs {
if _, ok := pm.proxies[name]; !ok {
pxy := NewProxyWrapper(cfg)
pm.proxies[name] = pxy
addPxyNames = append(addPxyNames, name)
}
}
pm.Info("proxy added: %v", addPxyNames)
delVisitorName := make([]string, 0)
for name, oldVisitorCfg := range pm.visitorCfgs {
del := false
cfg, ok := visitorCfgs[name]
if !ok {
del = true
} else {
if !oldVisitorCfg.Compare(cfg) {
del = true
}
}
if del {
delVisitorName = append(delVisitorName, name)
delete(pm.visitorCfgs, name)
if visitor, ok := pm.visitors[name]; ok {
visitor.Close()
}
delete(pm.visitors, name)
}
}
pm.Info("visitor removed: %v", delVisitorName)
addVisitorName := make([]string, 0)
for name, visitorCfg := range visitorCfgs {
if _, ok := pm.visitorCfgs[name]; !ok {
pm.visitorCfgs[name] = visitorCfg
addVisitorName = append(addVisitorName, name)
}
}
pm.Info("visitor added: %v", addVisitorName)
return nil
}
func (pm *ProxyManager) HandleWorkConn(name string, workConn frpNet.Conn) {
pm.mu.RLock()
pw, ok := pm.proxies[name]
pm.mu.RUnlock()
if ok {
pw.InWorkConn(workConn)
} else {
workConn.Close()
}
}
func (pm *ProxyManager) GetAllProxyStatus() []*ProxyStatus {
ps := make([]*ProxyStatus, 0)
pm.mu.RLock()
defer pm.mu.RUnlock()
for _, pxy := range pm.proxies {
ps = append(ps, pxy.GetStatus())
}
return ps
}

View File

@ -53,6 +53,6 @@ func (svr *Service) Run() error {
return nil return nil
} }
func (svr *Service) Close() error { func (svr *Service) Close() {
return svr.ctl.Close() svr.ctl.Close()
} }

View File

@ -77,7 +77,7 @@ type StcpVisitor struct {
} }
func (sv *StcpVisitor) Run() (err error) { func (sv *StcpVisitor) Run() (err error) {
sv.l, err = frpNet.ListenTcp(sv.cfg.BindAddr, int64(sv.cfg.BindPort)) sv.l, err = frpNet.ListenTcp(sv.cfg.BindAddr, sv.cfg.BindPort)
if err != nil { if err != nil {
return return
} }
@ -164,7 +164,7 @@ type XtcpVisitor struct {
} }
func (sv *XtcpVisitor) Run() (err error) { func (sv *XtcpVisitor) Run() (err error) {
sv.l, err = frpNet.ListenTcp(sv.cfg.BindAddr, int64(sv.cfg.BindPort)) sv.l, err = frpNet.ListenTcp(sv.cfg.BindAddr, sv.cfg.BindPort)
if err != nil { if err != nil {
return return
} }
@ -255,7 +255,7 @@ func (sv *XtcpVisitor) handleConn(userConn frpNet.Conn) {
sv.Error("get natHoleResp client address error: %s", natHoleRespMsg.ClientAddr) sv.Error("get natHoleResp client address error: %s", natHoleRespMsg.ClientAddr)
return return
} }
sv.sendDetectMsg(array[0], int64(port), laddr, []byte(natHoleRespMsg.Sid)) sv.sendDetectMsg(array[0], int(port), laddr, []byte(natHoleRespMsg.Sid))
sv.Trace("send all detect msg done") sv.Trace("send all detect msg done")
// Listen for visitorConn's address and wait for client connection. // Listen for visitorConn's address and wait for client connection.
@ -302,7 +302,7 @@ func (sv *XtcpVisitor) handleConn(userConn frpNet.Conn) {
sv.Debug("join connections closed") sv.Debug("join connections closed")
} }
func (sv *XtcpVisitor) sendDetectMsg(addr string, port int64, laddr *net.UDPAddr, content []byte) (err error) { func (sv *XtcpVisitor) sendDetectMsg(addr string, port int, laddr *net.UDPAddr, content []byte) (err error) {
daddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", addr, port)) daddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", addr, port))
if err != nil { if err != nil {
return err return err

View File

@ -28,6 +28,7 @@ import (
"time" "time"
docopt "github.com/docopt/docopt-go" docopt "github.com/docopt/docopt-go"
"github.com/rodaine/table"
ini "github.com/vaughan0/go-ini" ini "github.com/vaughan0/go-ini"
"github.com/fatedier/frp/client" "github.com/fatedier/frp/client"
@ -44,7 +45,8 @@ var usage string = `frpc is the client of frp
Usage: Usage:
frpc [-c config_file] [-L log_file] [--log-level=<log_level>] [--server-addr=<server_addr>] frpc [-c config_file] [-L log_file] [--log-level=<log_level>] [--server-addr=<server_addr>]
frpc [-c config_file] --reload frpc reload [-c config_file]
frpc status [-c config_file]
frpc -h | --help frpc -h | --help
frpc -v | --version frpc -v | --version
@ -53,7 +55,6 @@ Options:
-L log_file set output log file, including console -L log_file set output log file, including console
--log-level=<log_level> set log level: debug, info, warn, error --log-level=<log_level> set log level: debug, info, warn, error
--server-addr=<server_addr> addr which frps is listening for, example: 0.0.0.0:7000 --server-addr=<server_addr> addr which frps is listening for, example: 0.0.0.0:7000
--reload reload configure file without program exit
-h --help show this screen -h --help show this screen
-v --version show version -v --version show version
` `
@ -82,40 +83,25 @@ func main() {
config.ClientCommonCfg.ConfigFile = confFile config.ClientCommonCfg.ConfigFile = confFile
// check if reload command // check if reload command
if args["--reload"] != nil { if args["reload"] != nil {
if args["--reload"].(bool) { if args["reload"].(bool) {
req, err := http.NewRequest("GET", "http://"+ if err = CmdReload(); err != nil {
config.ClientCommonCfg.AdminAddr+":"+fmt.Sprintf("%d", config.ClientCommonCfg.AdminPort)+"/api/reload", nil)
if err != nil {
fmt.Printf("frps reload error: %v\n", err) fmt.Printf("frps reload error: %v\n", err)
os.Exit(1) os.Exit(1)
} else {
fmt.Printf("reload success\n")
os.Exit(0)
}
}
} }
authStr := "Basic " + base64.StdEncoding.EncodeToString([]byte(config.ClientCommonCfg.AdminUser+":"+ // check if status command
config.ClientCommonCfg.AdminPwd)) if args["status"] != nil {
if args["status"].(bool) {
req.Header.Add("Authorization", authStr) if err = CmdStatus(); err != nil {
resp, err := http.DefaultClient.Do(req) fmt.Printf("frps get status error: %v\n", err)
if err != nil {
fmt.Printf("frpc reload error: %v\n", err)
os.Exit(1) os.Exit(1)
} else { } else {
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
fmt.Printf("frpc reload error: %v\n", err)
os.Exit(1)
}
res := &client.GeneralResponse{}
err = json.Unmarshal(body, &res)
if err != nil {
fmt.Printf("http response error: %s\n", strings.TrimSpace(string(body)))
os.Exit(1)
} else if res.Code != 0 {
fmt.Printf("reload error: %s\n", res.Msg)
os.Exit(1)
}
fmt.Printf("reload success\n")
os.Exit(0) os.Exit(0)
} }
} }
@ -146,7 +132,7 @@ func main() {
os.Exit(1) os.Exit(1)
} }
config.ClientCommonCfg.ServerAddr = addr[0] config.ClientCommonCfg.ServerAddr = addr[0]
config.ClientCommonCfg.ServerPort = serverPort config.ClientCommonCfg.ServerPort = int(serverPort)
} }
if args["-v"] != nil { if args["-v"] != nil {
@ -187,3 +173,133 @@ func HandleSignal(svr *client.Service) {
time.Sleep(250 * time.Millisecond) time.Sleep(250 * time.Millisecond)
os.Exit(0) os.Exit(0)
} }
func CmdReload() error {
if config.ClientCommonCfg.AdminPort == 0 {
return fmt.Errorf("admin_port shoud be set if you want to use reload feature")
}
req, err := http.NewRequest("GET", "http://"+
config.ClientCommonCfg.AdminAddr+":"+fmt.Sprintf("%d", config.ClientCommonCfg.AdminPort)+"/api/reload", nil)
if err != nil {
return err
}
authStr := "Basic " + base64.StdEncoding.EncodeToString([]byte(config.ClientCommonCfg.AdminUser+":"+
config.ClientCommonCfg.AdminPwd))
req.Header.Add("Authorization", authStr)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
} else {
if resp.StatusCode != 200 {
return fmt.Errorf("admin api status code [%d]", resp.StatusCode)
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return err
}
res := &client.GeneralResponse{}
err = json.Unmarshal(body, &res)
if err != nil {
return fmt.Errorf("unmarshal http response error: %s", strings.TrimSpace(string(body)))
} else if res.Code != 0 {
return fmt.Errorf(res.Msg)
}
}
return nil
}
func CmdStatus() error {
if config.ClientCommonCfg.AdminPort == 0 {
return fmt.Errorf("admin_port shoud be set if you want to get proxy status")
}
req, err := http.NewRequest("GET", "http://"+
config.ClientCommonCfg.AdminAddr+":"+fmt.Sprintf("%d", config.ClientCommonCfg.AdminPort)+"/api/status", nil)
if err != nil {
return err
}
authStr := "Basic " + base64.StdEncoding.EncodeToString([]byte(config.ClientCommonCfg.AdminUser+":"+
config.ClientCommonCfg.AdminPwd))
req.Header.Add("Authorization", authStr)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
} else {
if resp.StatusCode != 200 {
return fmt.Errorf("admin api status code [%d]", resp.StatusCode)
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return err
}
res := &client.StatusResp{}
err = json.Unmarshal(body, &res)
if err != nil {
return fmt.Errorf("unmarshal http response error: %s", strings.TrimSpace(string(body)))
}
fmt.Println("Proxy Status...")
if len(res.Tcp) > 0 {
fmt.Printf("TCP")
tbl := table.New("Name", "Status", "LocalAddr", "Plugin", "RemoteAddr", "Error")
for _, ps := range res.Tcp {
tbl.AddRow(ps.Name, ps.Status, ps.LocalAddr, ps.Plugin, ps.RemoteAddr, ps.Err)
}
tbl.Print()
fmt.Println("")
}
if len(res.Udp) > 0 {
fmt.Printf("UDP")
tbl := table.New("Name", "Status", "LocalAddr", "Plugin", "RemoteAddr", "Error")
for _, ps := range res.Udp {
tbl.AddRow(ps.Name, ps.Status, ps.LocalAddr, ps.Plugin, ps.RemoteAddr, ps.Err)
}
tbl.Print()
fmt.Println("")
}
if len(res.Http) > 0 {
fmt.Printf("HTTP")
tbl := table.New("Name", "Status", "LocalAddr", "Plugin", "RemoteAddr", "Error")
for _, ps := range res.Http {
tbl.AddRow(ps.Name, ps.Status, ps.LocalAddr, ps.Plugin, ps.RemoteAddr, ps.Err)
}
tbl.Print()
fmt.Println("")
}
if len(res.Https) > 0 {
fmt.Printf("HTTPS")
tbl := table.New("Name", "Status", "LocalAddr", "Plugin", "RemoteAddr", "Error")
for _, ps := range res.Https {
tbl.AddRow(ps.Name, ps.Status, ps.LocalAddr, ps.Plugin, ps.RemoteAddr, ps.Err)
}
tbl.Print()
fmt.Println("")
}
if len(res.Stcp) > 0 {
fmt.Printf("STCP")
tbl := table.New("Name", "Status", "LocalAddr", "Plugin", "RemoteAddr", "Error")
for _, ps := range res.Stcp {
tbl.AddRow(ps.Name, ps.Status, ps.LocalAddr, ps.Plugin, ps.RemoteAddr, ps.Err)
}
tbl.Print()
fmt.Println("")
}
if len(res.Xtcp) > 0 {
fmt.Printf("XTCP")
tbl := table.New("Name", "Status", "LocalAddr", "Plugin", "RemoteAddr", "Error")
for _, ps := range res.Xtcp {
tbl.AddRow(ps.Name, ps.Status, ps.LocalAddr, ps.Plugin, ps.RemoteAddr, ps.Err)
}
tbl.Print()
fmt.Println("")
}
}
return nil
}

View File

@ -91,7 +91,7 @@ func main() {
os.Exit(1) os.Exit(1)
} }
config.ServerCommonCfg.BindAddr = addr[0] config.ServerCommonCfg.BindAddr = addr[0]
config.ServerCommonCfg.BindPort = bindPort config.ServerCommonCfg.BindPort = int(bindPort)
} }
if args["-v"] != nil { if args["-v"] != nil {

View File

@ -88,7 +88,7 @@ http_pwd = admin
# if domain for frps is frps.com, then you can access [web01] proxy by URL http://test.frps.com # if domain for frps is frps.com, then you can access [web01] proxy by URL http://test.frps.com
subdomain = web01 subdomain = web01
custom_domains = web02.yourdomain.com custom_domains = web02.yourdomain.com
# locations is only useful for http type # locations is only available for http type
locations = /,/pic locations = /,/pic
host_header_rewrite = example.com host_header_rewrite = example.com

6
glide.lock generated
View File

@ -1,5 +1,5 @@
hash: 03ff8b71f63e9038c0182a4ef2a55aa9349782f4813c331e2d1f02f3dd15b4f8 hash: 188e1149e415ff9cefab8db2cded030efae57558a0b9551795c5c7d0b0572a7b
updated: 2017-11-01T16:16:18.577622991+08:00 updated: 2018-01-17T01:14:34.435613+08:00
imports: imports:
- name: github.com/armon/go-socks5 - name: github.com/armon/go-socks5
version: e75332964ef517daa070d7c38a9466a0d687e0a5 version: e75332964ef517daa070d7c38a9466a0d687e0a5
@ -33,6 +33,8 @@ imports:
version: 274df120e9065bdd08eb1120e0375e3dc1ae8465 version: 274df120e9065bdd08eb1120e0375e3dc1ae8465
subpackages: subpackages:
- fs - fs
- name: github.com/rodaine/table
version: 212a2ad1c462ed4d5b5511ea2b480a573281dbbd
- name: github.com/stretchr/testify - name: github.com/stretchr/testify
version: 2402e8e7a02fc811447d11f881aa9746cdc57983 version: 2402e8e7a02fc811447d11f881aa9746cdc57983
subpackages: subpackages:

View File

@ -71,3 +71,5 @@ import:
- internal/iana - internal/iana
- internal/socket - internal/socket
- ipv4 - ipv4
- package: github.com/rodaine/table
version: v1.0.0

View File

@ -29,8 +29,8 @@ var ClientCommonCfg *ClientCommonConf
type ClientCommonConf struct { type ClientCommonConf struct {
ConfigFile string ConfigFile string
ServerAddr string ServerAddr string
ServerPort int64 ServerPort int
ServerUdpPort int64 // this is specified by login response message from frps ServerUdpPort int // this is specified by login response message from frps
HttpProxy string HttpProxy string
LogFile string LogFile string
LogWay string LogWay string
@ -38,7 +38,7 @@ type ClientCommonConf struct {
LogMaxDays int64 LogMaxDays int64
PrivilegeToken string PrivilegeToken string
AdminAddr string AdminAddr string
AdminPort int64 AdminPort int
AdminUser string AdminUser string
AdminPwd string AdminPwd string
PoolCount int PoolCount int
@ -93,7 +93,12 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) {
tmpStr, ok = conf.Get("common", "server_port") tmpStr, ok = conf.Get("common", "server_port")
if ok { if ok {
cfg.ServerPort, _ = strconv.ParseInt(tmpStr, 10, 64) v, err = strconv.ParseInt(tmpStr, 10, 64)
if err != nil {
err = fmt.Errorf("Parse conf error: invalid server_port")
return
}
cfg.ServerPort = int(v)
} }
tmpStr, ok = conf.Get("common", "http_proxy") tmpStr, ok = conf.Get("common", "http_proxy")
@ -139,7 +144,10 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) {
tmpStr, ok = conf.Get("common", "admin_port") tmpStr, ok = conf.Get("common", "admin_port")
if ok { if ok {
if v, err = strconv.ParseInt(tmpStr, 10, 64); err == nil { if v, err = strconv.ParseInt(tmpStr, 10, 64); err == nil {
cfg.AdminPort = v cfg.AdminPort = int(v)
} else {
err = fmt.Errorf("Parse conf error: invalid admin_port")
return
} }
} }
@ -203,7 +211,7 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) {
if ok { if ok {
v, err = strconv.ParseInt(tmpStr, 10, 64) v, err = strconv.ParseInt(tmpStr, 10, 64)
if err != nil { if err != nil {
err = fmt.Errorf("Parse conf error: heartbeat_timeout is incorrect") err = fmt.Errorf("Parse conf error: invalid heartbeat_timeout")
return return
} else { } else {
cfg.HeartBeatTimeout = v cfg.HeartBeatTimeout = v
@ -214,7 +222,7 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) {
if ok { if ok {
v, err = strconv.ParseInt(tmpStr, 10, 64) v, err = strconv.ParseInt(tmpStr, 10, 64)
if err != nil { if err != nil {
err = fmt.Errorf("Parse conf error: heartbeat_interval is incorrect") err = fmt.Errorf("Parse conf error: invalid heartbeat_interval")
return return
} else { } else {
cfg.HeartBeatInterval = v cfg.HeartBeatInterval = v
@ -222,12 +230,12 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) {
} }
if cfg.HeartBeatInterval <= 0 { if cfg.HeartBeatInterval <= 0 {
err = fmt.Errorf("Parse conf error: heartbeat_interval is incorrect") err = fmt.Errorf("Parse conf error: invalid heartbeat_interval")
return return
} }
if cfg.HeartBeatTimeout < cfg.HeartBeatInterval { if cfg.HeartBeatTimeout < cfg.HeartBeatInterval {
err = fmt.Errorf("Parse conf error: heartbeat_timeout is incorrect, heartbeat_timeout is less than heartbeat_interval") err = fmt.Errorf("Parse conf error: invalid heartbeat_timeout, heartbeat_timeout is less than heartbeat_interval")
return return
} }
return return

View File

@ -23,7 +23,6 @@ import (
"github.com/fatedier/frp/models/consts" "github.com/fatedier/frp/models/consts"
"github.com/fatedier/frp/models/msg" "github.com/fatedier/frp/models/msg"
"github.com/fatedier/frp/utils/util"
ini "github.com/vaughan0/go-ini" ini "github.com/vaughan0/go-ini"
) )
@ -52,6 +51,7 @@ func NewConfByType(proxyType string) ProxyConf {
type ProxyConf interface { type ProxyConf interface {
GetName() string GetName() string
GetType() string
GetBaseInfo() *BaseProxyConf GetBaseInfo() *BaseProxyConf
LoadFromMsg(pMsg *msg.NewProxy) LoadFromMsg(pMsg *msg.NewProxy)
LoadFromFile(name string, conf ini.Section) error LoadFromFile(name string, conf ini.Section) error
@ -103,6 +103,10 @@ func (cfg *BaseProxyConf) GetName() string {
return cfg.ProxyName return cfg.ProxyName
} }
func (cfg *BaseProxyConf) GetType() string {
return cfg.ProxyType
}
func (cfg *BaseProxyConf) GetBaseInfo() *BaseProxyConf { func (cfg *BaseProxyConf) GetBaseInfo() *BaseProxyConf {
return cfg return cfg
} }
@ -158,7 +162,7 @@ func (cfg *BaseProxyConf) UnMarshalToMsg(pMsg *msg.NewProxy) {
// Bind info // Bind info
type BindInfoConf struct { type BindInfoConf struct {
BindAddr string `json:"bind_addr"` BindAddr string `json:"bind_addr"`
RemotePort int64 `json:"remote_port"` RemotePort int `json:"remote_port"`
} }
func (cfg *BindInfoConf) compare(cmp *BindInfoConf) bool { func (cfg *BindInfoConf) compare(cmp *BindInfoConf) bool {
@ -178,10 +182,13 @@ func (cfg *BindInfoConf) LoadFromFile(name string, section ini.Section) (err err
var ( var (
tmpStr string tmpStr string
ok bool ok bool
v int64
) )
if tmpStr, ok = section["remote_port"]; ok { if tmpStr, ok = section["remote_port"]; ok {
if cfg.RemotePort, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
return fmt.Errorf("Parse conf error: proxy [%s] remote_port error", name) return fmt.Errorf("Parse conf error: proxy [%s] remote_port error", name)
} else {
cfg.RemotePort = int(v)
} }
} else { } else {
return fmt.Errorf("Parse conf error: proxy [%s] remote_port not found", name) return fmt.Errorf("Parse conf error: proxy [%s] remote_port not found", name)
@ -194,11 +201,6 @@ func (cfg *BindInfoConf) UnMarshalToMsg(pMsg *msg.NewProxy) {
} }
func (cfg *BindInfoConf) check() (err error) { func (cfg *BindInfoConf) check() (err error) {
if len(ServerCommonCfg.PrivilegeAllowPorts) != 0 {
if ok := util.ContainsPort(ServerCommonCfg.PrivilegeAllowPorts, cfg.RemotePort); !ok {
return fmt.Errorf("remote port [%d] isn't allowed", cfg.RemotePort)
}
}
return nil return nil
} }

View File

@ -19,7 +19,6 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/fatedier/frp/utils/util"
ini "github.com/vaughan0/go-ini" ini "github.com/vaughan0/go-ini"
) )
@ -29,20 +28,20 @@ var ServerCommonCfg *ServerCommonConf
type ServerCommonConf struct { type ServerCommonConf struct {
ConfigFile string ConfigFile string
BindAddr string BindAddr string
BindPort int64 BindPort int
BindUdpPort int64 BindUdpPort int
KcpBindPort int64 KcpBindPort int
ProxyBindAddr string ProxyBindAddr string
// If VhostHttpPort equals 0, don't listen a public port for http protocol. // If VhostHttpPort equals 0, don't listen a public port for http protocol.
VhostHttpPort int64 VhostHttpPort int
// if VhostHttpsPort equals 0, don't listen a public port for https protocol // if VhostHttpsPort equals 0, don't listen a public port for https protocol
VhostHttpsPort int64 VhostHttpsPort int
DashboardAddr string DashboardAddr string
// if DashboardPort equals 0, dashboard is not available // if DashboardPort equals 0, dashboard is not available
DashboardPort int64 DashboardPort int
DashboardUser string DashboardUser string
DashboardPwd string DashboardPwd string
AssetsDir string AssetsDir string
@ -56,8 +55,7 @@ type ServerCommonConf struct {
SubDomainHost string SubDomainHost string
TcpMux bool TcpMux bool
// if PrivilegeAllowPorts is not nil, tcp proxies which remote port exist in this map can be connected PrivilegeAllowPorts map[int]struct{}
PrivilegeAllowPorts [][2]int64
MaxPoolCount int64 MaxPoolCount int64
HeartBeatTimeout int64 HeartBeatTimeout int64
UserConnTimeout int64 UserConnTimeout int64
@ -87,6 +85,7 @@ func GetDefaultServerCommonConf() *ServerCommonConf {
AuthTimeout: 900, AuthTimeout: 900,
SubDomainHost: "", SubDomainHost: "",
TcpMux: true, TcpMux: true,
PrivilegeAllowPorts: make(map[int]struct{}),
MaxPoolCount: 5, MaxPoolCount: 5,
HeartBeatTimeout: 90, HeartBeatTimeout: 90,
UserConnTimeout: 10, UserConnTimeout: 10,
@ -109,25 +108,31 @@ func LoadServerCommonConf(conf ini.File) (cfg *ServerCommonConf, err error) {
tmpStr, ok = conf.Get("common", "bind_port") tmpStr, ok = conf.Get("common", "bind_port")
if ok { if ok {
v, err = strconv.ParseInt(tmpStr, 10, 64) if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
if err == nil { err = fmt.Errorf("Parse conf error: invalid bind_port")
cfg.BindPort = v return
} else {
cfg.BindPort = int(v)
} }
} }
tmpStr, ok = conf.Get("common", "bind_udp_port") tmpStr, ok = conf.Get("common", "bind_udp_port")
if ok { if ok {
v, err = strconv.ParseInt(tmpStr, 10, 64) if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
if err == nil { err = fmt.Errorf("Parse conf error: invalid bind_udp_port")
cfg.BindUdpPort = v return
} else {
cfg.BindUdpPort = int(v)
} }
} }
tmpStr, ok = conf.Get("common", "kcp_bind_port") tmpStr, ok = conf.Get("common", "kcp_bind_port")
if ok { if ok {
v, err = strconv.ParseInt(tmpStr, 10, 64) if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
if err == nil && v > 0 { err = fmt.Errorf("Parse conf error: invalid kcp_bind_port")
cfg.KcpBindPort = v return
} else {
cfg.KcpBindPort = int(v)
} }
} }
@ -140,10 +145,11 @@ func LoadServerCommonConf(conf ini.File) (cfg *ServerCommonConf, err error) {
tmpStr, ok = conf.Get("common", "vhost_http_port") tmpStr, ok = conf.Get("common", "vhost_http_port")
if ok { if ok {
cfg.VhostHttpPort, err = strconv.ParseInt(tmpStr, 10, 64) if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
if err != nil { err = fmt.Errorf("Parse conf error: invalid vhost_http_port")
err = fmt.Errorf("Parse conf error: vhost_http_port is incorrect")
return return
} else {
cfg.VhostHttpPort = int(v)
} }
} else { } else {
cfg.VhostHttpPort = 0 cfg.VhostHttpPort = 0
@ -151,10 +157,11 @@ func LoadServerCommonConf(conf ini.File) (cfg *ServerCommonConf, err error) {
tmpStr, ok = conf.Get("common", "vhost_https_port") tmpStr, ok = conf.Get("common", "vhost_https_port")
if ok { if ok {
cfg.VhostHttpsPort, err = strconv.ParseInt(tmpStr, 10, 64) if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
if err != nil { err = fmt.Errorf("Parse conf error: invalid vhost_https_port")
err = fmt.Errorf("Parse conf error: vhost_https_port is incorrect")
return return
} else {
cfg.VhostHttpsPort = int(v)
} }
} else { } else {
cfg.VhostHttpsPort = 0 cfg.VhostHttpsPort = 0
@ -169,10 +176,11 @@ func LoadServerCommonConf(conf ini.File) (cfg *ServerCommonConf, err error) {
tmpStr, ok = conf.Get("common", "dashboard_port") tmpStr, ok = conf.Get("common", "dashboard_port")
if ok { if ok {
cfg.DashboardPort, err = strconv.ParseInt(tmpStr, 10, 64) if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
if err != nil { err = fmt.Errorf("Parse conf error: invalid dashboard_port")
err = fmt.Errorf("Parse conf error: dashboard_port is incorrect")
return return
} else {
cfg.DashboardPort = int(v)
} }
} else { } else {
cfg.DashboardPort = 0 cfg.DashboardPort = 0
@ -228,13 +236,46 @@ func LoadServerCommonConf(conf ini.File) (cfg *ServerCommonConf, err error) {
cfg.PrivilegeToken, _ = conf.Get("common", "privilege_token") cfg.PrivilegeToken, _ = conf.Get("common", "privilege_token")
allowPortsStr, ok := conf.Get("common", "privilege_allow_ports") allowPortsStr, ok := conf.Get("common", "privilege_allow_ports")
// TODO: check if conflicts exist in port ranges
if ok { if ok {
cfg.PrivilegeAllowPorts, err = util.GetPortRanges(allowPortsStr) // e.g. 1000-2000,2001,2002,3000-4000
if err != nil { portRanges := strings.Split(allowPortsStr, ",")
err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", err) for _, portRangeStr := range portRanges {
// 1000-2000 or 2001
portArray := strings.Split(portRangeStr, "-")
// length: only 1 or 2 is correct
rangeType := len(portArray)
if rangeType == 1 {
// single port
singlePort, errRet := strconv.ParseInt(portArray[0], 10, 64)
if errRet != nil {
err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", errRet)
return return
} }
cfg.PrivilegeAllowPorts[int(singlePort)] = struct{}{}
} else if rangeType == 2 {
// range ports
min, errRet := strconv.ParseInt(portArray[0], 10, 64)
if errRet != nil {
err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", errRet)
return
}
max, errRet := strconv.ParseInt(portArray[1], 10, 64)
if errRet != nil {
err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", errRet)
return
}
if max < min {
err = fmt.Errorf("Parse conf error: privilege_allow_ports range incorrect")
return
}
for i := min; i <= max; i++ {
cfg.PrivilegeAllowPorts[int(i)] = struct{}{}
}
} else {
err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect")
return
}
}
} }
} }

View File

@ -92,7 +92,7 @@ type Login struct {
type LoginResp struct { type LoginResp struct {
Version string `json:"version"` Version string `json:"version"`
RunId string `json:"run_id"` RunId string `json:"run_id"`
ServerUdpPort int64 `json:"server_udp_port"` ServerUdpPort int `json:"server_udp_port"`
Error string `json:"error"` Error string `json:"error"`
} }
@ -104,7 +104,7 @@ type NewProxy struct {
UseCompression bool `json:"use_compression"` UseCompression bool `json:"use_compression"`
// tcp and udp only // tcp and udp only
RemotePort int64 `json:"remote_port"` RemotePort int `json:"remote_port"`
// http and https only // http and https only
CustomDomains []string `json:"custom_domains"` CustomDomains []string `json:"custom_domains"`
@ -120,6 +120,7 @@ type NewProxy struct {
type NewProxyResp struct { type NewProxyResp struct {
ProxyName string `json:"proxy_name"` ProxyName string `json:"proxy_name"`
RemoteAddr string `json:"remote_addr"`
Error string `json:"error"` Error string `json:"error"`
} }

View File

@ -253,13 +253,13 @@ func (ctl *Control) stoper() {
ctl.allShutdown.WaitStart() ctl.allShutdown.WaitStart()
close(ctl.readCh) close(ctl.readCh)
ctl.managerShutdown.WaitDown() ctl.managerShutdown.WaitDone()
close(ctl.sendCh) close(ctl.sendCh)
ctl.writerShutdown.WaitDown() ctl.writerShutdown.WaitDone()
ctl.conn.Close() ctl.conn.Close()
ctl.readerShutdown.WaitDown() ctl.readerShutdown.WaitDone()
close(ctl.workConnCh) close(ctl.workConnCh)
for workConn := range ctl.workConnCh { for workConn := range ctl.workConnCh {
@ -308,7 +308,7 @@ func (ctl *Control) manager() {
switch m := rawMsg.(type) { switch m := rawMsg.(type) {
case *msg.NewProxy: case *msg.NewProxy:
// register proxy in this control // register proxy in this control
err := ctl.RegisterProxy(m) remoteAddr, err := ctl.RegisterProxy(m)
resp := &msg.NewProxyResp{ resp := &msg.NewProxyResp{
ProxyName: m.ProxyName, ProxyName: m.ProxyName,
} }
@ -316,6 +316,7 @@ func (ctl *Control) manager() {
resp.Error = err.Error() resp.Error = err.Error()
ctl.conn.Warn("new proxy [%s] error: %v", m.ProxyName, err) ctl.conn.Warn("new proxy [%s] error: %v", m.ProxyName, err)
} else { } else {
resp.RemoteAddr = remoteAddr
ctl.conn.Info("new proxy [%s] success", m.ProxyName) ctl.conn.Info("new proxy [%s] success", m.ProxyName)
StatsNewProxy(m.ProxyName, m.ProxyType) StatsNewProxy(m.ProxyName, m.ProxyType)
} }
@ -332,24 +333,24 @@ func (ctl *Control) manager() {
} }
} }
func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (err error) { func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err error) {
var pxyConf config.ProxyConf var pxyConf config.ProxyConf
// Load configures from NewProxy message and check. // Load configures from NewProxy message and check.
pxyConf, err = config.NewProxyConf(pxyMsg) pxyConf, err = config.NewProxyConf(pxyMsg)
if err != nil { if err != nil {
return err return
} }
// NewProxy will return a interface Proxy. // NewProxy will return a interface Proxy.
// In fact it create different proxies by different proxy type, we just call run() here. // In fact it create different proxies by different proxy type, we just call run() here.
pxy, err := NewProxy(ctl, pxyConf) pxy, err := NewProxy(ctl, pxyConf)
if err != nil { if err != nil {
return err return remoteAddr, err
} }
err = pxy.Run() remoteAddr, err = pxy.Run()
if err != nil { if err != nil {
return err return
} }
defer func() { defer func() {
if err != nil { if err != nil {
@ -359,13 +360,13 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (err error) {
err = ctl.svr.RegisterProxy(pxyMsg.ProxyName, pxy) err = ctl.svr.RegisterProxy(pxyMsg.ProxyName, pxy)
if err != nil { if err != nil {
return err return
} }
ctl.mu.Lock() ctl.mu.Lock()
ctl.proxies[pxy.GetName()] = pxy ctl.proxies[pxy.GetName()] = pxy
ctl.mu.Unlock() ctl.mu.Unlock()
return nil return
} }
func (ctl *Control) CloseProxy(closeMsg *msg.CloseProxy) (err error) { func (ctl *Control) CloseProxy(closeMsg *msg.CloseProxy) (err error) {

View File

@ -32,7 +32,7 @@ var (
httpServerWriteTimeout = 10 * time.Second httpServerWriteTimeout = 10 * time.Second
) )
func RunDashboardServer(addr string, port int64) (err error) { func RunDashboardServer(addr string, port int) (err error) {
// url router // url router
router := httprouter.New() router := httprouter.New()

View File

@ -36,8 +36,8 @@ type ServerInfoResp struct {
GeneralResponse GeneralResponse
Version string `json:"version"` Version string `json:"version"`
VhostHttpPort int64 `json:"vhost_http_port"` VhostHttpPort int `json:"vhost_http_port"`
VhostHttpsPort int64 `json:"vhost_https_port"` VhostHttpsPort int `json:"vhost_https_port"`
AuthTimeout int64 `json:"auth_timeout"` AuthTimeout int64 `json:"auth_timeout"`
SubdomainHost string `json:"subdomain_host"` SubdomainHost string `json:"subdomain_host"`
MaxPoolCount int64 `json:"max_pool_count"` MaxPoolCount int64 `json:"max_pool_count"`

180
server/ports.go Normal file
View File

@ -0,0 +1,180 @@
package server
import (
"errors"
"fmt"
"net"
"sync"
"time"
)
const (
MinPort = 1025
MaxPort = 65535
MaxPortReservedDuration = time.Duration(24) * time.Hour
CleanReservedPortsInterval = time.Hour
)
var (
ErrPortAlreadyUsed = errors.New("port already used")
ErrPortNotAllowed = errors.New("port not allowed")
ErrPortUnAvailable = errors.New("port unavailable")
ErrNoAvailablePort = errors.New("no available port")
)
type PortCtx struct {
ProxyName string
Port int
Closed bool
UpdateTime time.Time
}
type PortManager struct {
reservedPorts map[string]*PortCtx
usedPorts map[int]*PortCtx
freePorts map[int]struct{}
bindAddr string
netType string
mu sync.Mutex
}
func NewPortManager(netType string, bindAddr string, allowPorts map[int]struct{}) *PortManager {
pm := &PortManager{
reservedPorts: make(map[string]*PortCtx),
usedPorts: make(map[int]*PortCtx),
freePorts: make(map[int]struct{}),
bindAddr: bindAddr,
netType: netType,
}
if len(allowPorts) > 0 {
for port, _ := range allowPorts {
pm.freePorts[port] = struct{}{}
}
} else {
for i := MinPort; i <= MaxPort; i++ {
pm.freePorts[i] = struct{}{}
}
}
go pm.cleanReservedPortsWorker()
return pm
}
func (pm *PortManager) Acquire(name string, port int) (realPort int, err error) {
portCtx := &PortCtx{
ProxyName: name,
Closed: false,
UpdateTime: time.Now(),
}
var ok bool
pm.mu.Lock()
defer func() {
if err == nil {
portCtx.Port = realPort
}
pm.mu.Unlock()
}()
// check reserved ports first
if port == 0 {
if ctx, ok := pm.reservedPorts[name]; ok {
if pm.isPortAvailable(ctx.Port) {
realPort = ctx.Port
pm.usedPorts[realPort] = portCtx
pm.reservedPorts[name] = portCtx
delete(pm.freePorts, realPort)
return
}
}
}
if port == 0 {
// get random port
count := 0
maxTryTimes := 5
for k, _ := range pm.freePorts {
count++
if count > maxTryTimes {
break
}
if pm.isPortAvailable(k) {
realPort = k
pm.usedPorts[realPort] = portCtx
pm.reservedPorts[name] = portCtx
delete(pm.freePorts, realPort)
break
}
}
if realPort == 0 {
err = ErrNoAvailablePort
}
} else {
// specified port
if _, ok = pm.freePorts[port]; ok {
if pm.isPortAvailable(port) {
realPort = port
pm.usedPorts[realPort] = portCtx
pm.reservedPorts[name] = portCtx
delete(pm.freePorts, realPort)
} else {
err = ErrPortUnAvailable
}
} else {
if _, ok = pm.usedPorts[port]; ok {
err = ErrPortAlreadyUsed
} else {
err = ErrPortNotAllowed
}
}
}
return
}
func (pm *PortManager) isPortAvailable(port int) bool {
if pm.netType == "udp" {
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pm.bindAddr, port))
if err != nil {
return false
}
l, err := net.ListenUDP("udp", addr)
if err != nil {
return false
}
l.Close()
return true
} else {
l, err := net.Listen(pm.netType, fmt.Sprintf("%s:%d", pm.bindAddr, port))
if err != nil {
return false
}
l.Close()
return true
}
}
func (pm *PortManager) Release(port int) {
pm.mu.Lock()
defer pm.mu.Unlock()
if ctx, ok := pm.usedPorts[port]; ok {
pm.freePorts[port] = struct{}{}
delete(pm.usedPorts, port)
ctx.Closed = true
ctx.UpdateTime = time.Now()
}
}
// Release reserved port if it isn't used in last 24 hours.
func (pm *PortManager) cleanReservedPortsWorker() {
for {
time.Sleep(CleanReservedPortsInterval)
pm.mu.Lock()
for name, ctx := range pm.reservedPorts {
if ctx.Closed && time.Since(ctx.UpdateTime) > MaxPortReservedDuration {
delete(pm.reservedPorts, name)
}
}
pm.mu.Unlock()
}
}

View File

@ -19,6 +19,7 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"strings"
"sync" "sync"
"time" "time"
@ -29,11 +30,12 @@ import (
frpIo "github.com/fatedier/frp/utils/io" frpIo "github.com/fatedier/frp/utils/io"
"github.com/fatedier/frp/utils/log" "github.com/fatedier/frp/utils/log"
frpNet "github.com/fatedier/frp/utils/net" frpNet "github.com/fatedier/frp/utils/net"
"github.com/fatedier/frp/utils/util"
"github.com/fatedier/frp/utils/vhost" "github.com/fatedier/frp/utils/vhost"
) )
type Proxy interface { type Proxy interface {
Run() error Run() (remoteAddr string, err error)
GetControl() *Control GetControl() *Control
GetName() string GetName() string
GetConf() config.ProxyConf GetConf() config.ProxyConf
@ -163,19 +165,34 @@ func NewProxy(ctl *Control, pxyConf config.ProxyConf) (pxy Proxy, err error) {
type TcpProxy struct { type TcpProxy struct {
BaseProxy BaseProxy
cfg *config.TcpProxyConf cfg *config.TcpProxyConf
realPort int
} }
func (pxy *TcpProxy) Run() error { func (pxy *TcpProxy) Run() (remoteAddr string, err error) {
listener, err := frpNet.ListenTcp(config.ServerCommonCfg.ProxyBindAddr, pxy.cfg.RemotePort) pxy.realPort, err = pxy.ctl.svr.tcpPortManager.Acquire(pxy.name, pxy.cfg.RemotePort)
if err != nil { if err != nil {
return err return
}
defer func() {
if err != nil {
pxy.ctl.svr.tcpPortManager.Release(pxy.realPort)
}
}()
remoteAddr = fmt.Sprintf(":%d", pxy.realPort)
pxy.cfg.RemotePort = pxy.realPort
listener, errRet := frpNet.ListenTcp(config.ServerCommonCfg.ProxyBindAddr, pxy.realPort)
if errRet != nil {
err = errRet
return
} }
listener.AddLogPrefix(pxy.name) listener.AddLogPrefix(pxy.name)
pxy.listeners = append(pxy.listeners, listener) pxy.listeners = append(pxy.listeners, listener)
pxy.Info("tcp proxy listen port [%d]", pxy.cfg.RemotePort) pxy.Info("tcp proxy listen port [%d]", pxy.cfg.RemotePort)
pxy.startListenHandler(pxy, HandleUserTcpConnection) pxy.startListenHandler(pxy, HandleUserTcpConnection)
return nil return
} }
func (pxy *TcpProxy) GetConf() config.ProxyConf { func (pxy *TcpProxy) GetConf() config.ProxyConf {
@ -184,6 +201,7 @@ func (pxy *TcpProxy) GetConf() config.ProxyConf {
func (pxy *TcpProxy) Close() { func (pxy *TcpProxy) Close() {
pxy.BaseProxy.Close() pxy.BaseProxy.Close()
pxy.ctl.svr.tcpPortManager.Release(pxy.realPort)
} }
type HttpProxy struct { type HttpProxy struct {
@ -193,7 +211,7 @@ type HttpProxy struct {
closeFuncs []func() closeFuncs []func()
} }
func (pxy *HttpProxy) Run() (err error) { func (pxy *HttpProxy) Run() (remoteAddr string, err error) {
routeConfig := vhost.VhostRouteConfig{ routeConfig := vhost.VhostRouteConfig{
RewriteHost: pxy.cfg.HostHeaderRewrite, RewriteHost: pxy.cfg.HostHeaderRewrite,
Username: pxy.cfg.HttpUser, Username: pxy.cfg.HttpUser,
@ -205,16 +223,19 @@ func (pxy *HttpProxy) Run() (err error) {
if len(locations) == 0 { if len(locations) == 0 {
locations = []string{""} locations = []string{""}
} }
addrs := make([]string, 0)
for _, domain := range pxy.cfg.CustomDomains { for _, domain := range pxy.cfg.CustomDomains {
routeConfig.Domain = domain routeConfig.Domain = domain
for _, location := range locations { for _, location := range locations {
routeConfig.Location = location routeConfig.Location = location
err := pxy.ctl.svr.httpReverseProxy.Register(routeConfig) err = pxy.ctl.svr.httpReverseProxy.Register(routeConfig)
if err != nil { if err != nil {
return err return
} }
tmpDomain := routeConfig.Domain tmpDomain := routeConfig.Domain
tmpLocation := routeConfig.Location tmpLocation := routeConfig.Location
addrs = append(addrs, util.CanonicalAddr(tmpDomain, int(config.ServerCommonCfg.VhostHttpPort)))
pxy.closeFuncs = append(pxy.closeFuncs, func() { pxy.closeFuncs = append(pxy.closeFuncs, func() {
pxy.ctl.svr.httpReverseProxy.UnRegister(tmpDomain, tmpLocation) pxy.ctl.svr.httpReverseProxy.UnRegister(tmpDomain, tmpLocation)
}) })
@ -226,18 +247,20 @@ func (pxy *HttpProxy) Run() (err error) {
routeConfig.Domain = pxy.cfg.SubDomain + "." + config.ServerCommonCfg.SubDomainHost routeConfig.Domain = pxy.cfg.SubDomain + "." + config.ServerCommonCfg.SubDomainHost
for _, location := range locations { for _, location := range locations {
routeConfig.Location = location routeConfig.Location = location
err := pxy.ctl.svr.httpReverseProxy.Register(routeConfig) err = pxy.ctl.svr.httpReverseProxy.Register(routeConfig)
if err != nil { if err != nil {
return err return
} }
tmpDomain := routeConfig.Domain tmpDomain := routeConfig.Domain
tmpLocation := routeConfig.Location tmpLocation := routeConfig.Location
addrs = append(addrs, util.CanonicalAddr(tmpDomain, int(config.ServerCommonCfg.VhostHttpPort)))
pxy.closeFuncs = append(pxy.closeFuncs, func() { pxy.closeFuncs = append(pxy.closeFuncs, func() {
pxy.ctl.svr.httpReverseProxy.UnRegister(tmpDomain, tmpLocation) pxy.ctl.svr.httpReverseProxy.UnRegister(tmpDomain, tmpLocation)
}) })
pxy.Info("http proxy listen for host [%s] location [%s]", routeConfig.Domain, routeConfig.Location) pxy.Info("http proxy listen for host [%s] location [%s]", routeConfig.Domain, routeConfig.Location)
} }
} }
remoteAddr = strings.Join(addrs, ",")
return return
} }
@ -279,32 +302,38 @@ type HttpsProxy struct {
cfg *config.HttpsProxyConf cfg *config.HttpsProxyConf
} }
func (pxy *HttpsProxy) Run() (err error) { func (pxy *HttpsProxy) Run() (remoteAddr string, err error) {
routeConfig := &vhost.VhostRouteConfig{} routeConfig := &vhost.VhostRouteConfig{}
addrs := make([]string, 0)
for _, domain := range pxy.cfg.CustomDomains { for _, domain := range pxy.cfg.CustomDomains {
routeConfig.Domain = domain routeConfig.Domain = domain
l, err := pxy.ctl.svr.VhostHttpsMuxer.Listen(routeConfig) l, errRet := pxy.ctl.svr.VhostHttpsMuxer.Listen(routeConfig)
if err != nil { if errRet != nil {
return err err = errRet
return
} }
l.AddLogPrefix(pxy.name) l.AddLogPrefix(pxy.name)
pxy.Info("https proxy listen for host [%s]", routeConfig.Domain) pxy.Info("https proxy listen for host [%s]", routeConfig.Domain)
pxy.listeners = append(pxy.listeners, l) pxy.listeners = append(pxy.listeners, l)
addrs = append(addrs, util.CanonicalAddr(routeConfig.Domain, int(config.ServerCommonCfg.VhostHttpsPort)))
} }
if pxy.cfg.SubDomain != "" { if pxy.cfg.SubDomain != "" {
routeConfig.Domain = pxy.cfg.SubDomain + "." + config.ServerCommonCfg.SubDomainHost routeConfig.Domain = pxy.cfg.SubDomain + "." + config.ServerCommonCfg.SubDomainHost
l, err := pxy.ctl.svr.VhostHttpsMuxer.Listen(routeConfig) l, errRet := pxy.ctl.svr.VhostHttpsMuxer.Listen(routeConfig)
if err != nil { if errRet != nil {
return err err = errRet
return
} }
l.AddLogPrefix(pxy.name) l.AddLogPrefix(pxy.name)
pxy.Info("https proxy listen for host [%s]", routeConfig.Domain) pxy.Info("https proxy listen for host [%s]", routeConfig.Domain)
pxy.listeners = append(pxy.listeners, l) pxy.listeners = append(pxy.listeners, l)
addrs = append(addrs, util.CanonicalAddr(routeConfig.Domain, int(config.ServerCommonCfg.VhostHttpsPort)))
} }
pxy.startListenHandler(pxy, HandleUserTcpConnection) pxy.startListenHandler(pxy, HandleUserTcpConnection)
remoteAddr = strings.Join(addrs, ",")
return return
} }
@ -321,17 +350,18 @@ type StcpProxy struct {
cfg *config.StcpProxyConf cfg *config.StcpProxyConf
} }
func (pxy *StcpProxy) Run() error { func (pxy *StcpProxy) Run() (remoteAddr string, err error) {
listener, err := pxy.ctl.svr.visitorManager.Listen(pxy.GetName(), pxy.cfg.Sk) listener, errRet := pxy.ctl.svr.visitorManager.Listen(pxy.GetName(), pxy.cfg.Sk)
if err != nil { if errRet != nil {
return err err = errRet
return
} }
listener.AddLogPrefix(pxy.name) listener.AddLogPrefix(pxy.name)
pxy.listeners = append(pxy.listeners, listener) pxy.listeners = append(pxy.listeners, listener)
pxy.Info("stcp proxy custom listen success") pxy.Info("stcp proxy custom listen success")
pxy.startListenHandler(pxy, HandleUserTcpConnection) pxy.startListenHandler(pxy, HandleUserTcpConnection)
return nil return
} }
func (pxy *StcpProxy) GetConf() config.ProxyConf { func (pxy *StcpProxy) GetConf() config.ProxyConf {
@ -350,10 +380,11 @@ type XtcpProxy struct {
closeCh chan struct{} closeCh chan struct{}
} }
func (pxy *XtcpProxy) Run() error { func (pxy *XtcpProxy) Run() (remoteAddr string, err error) {
if pxy.ctl.svr.natHoleController == nil { if pxy.ctl.svr.natHoleController == nil {
pxy.Error("udp port for xtcp is not specified.") pxy.Error("udp port for xtcp is not specified.")
return fmt.Errorf("xtcp is not supported in frps") err = fmt.Errorf("xtcp is not supported in frps")
return
} }
sidCh := pxy.ctl.svr.natHoleController.ListenClient(pxy.GetName(), pxy.cfg.Sk) sidCh := pxy.ctl.svr.natHoleController.ListenClient(pxy.GetName(), pxy.cfg.Sk)
go func() { go func() {
@ -362,21 +393,21 @@ func (pxy *XtcpProxy) Run() error {
case <-pxy.closeCh: case <-pxy.closeCh:
break break
case sid := <-sidCh: case sid := <-sidCh:
workConn, err := pxy.GetWorkConnFromPool() workConn, errRet := pxy.GetWorkConnFromPool()
if err != nil { if errRet != nil {
continue continue
} }
m := &msg.NatHoleSid{ m := &msg.NatHoleSid{
Sid: sid, Sid: sid,
} }
err = msg.WriteMsg(workConn, m) errRet = msg.WriteMsg(workConn, m)
if err != nil { if errRet != nil {
pxy.Warn("write nat hole sid package error, %v", err) pxy.Warn("write nat hole sid package error, %v", errRet)
} }
} }
} }
}() }()
return nil return
} }
func (pxy *XtcpProxy) GetConf() config.ProxyConf { func (pxy *XtcpProxy) GetConf() config.ProxyConf {
@ -395,6 +426,8 @@ type UdpProxy struct {
BaseProxy BaseProxy
cfg *config.UdpProxyConf cfg *config.UdpProxyConf
realPort int
// udpConn is the listener of udp packages // udpConn is the listener of udp packages
udpConn *net.UDPConn udpConn *net.UDPConn
@ -414,15 +447,29 @@ type UdpProxy struct {
isClosed bool isClosed bool
} }
func (pxy *UdpProxy) Run() (err error) { func (pxy *UdpProxy) Run() (remoteAddr string, err error) {
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", config.ServerCommonCfg.ProxyBindAddr, pxy.cfg.RemotePort)) pxy.realPort, err = pxy.ctl.svr.udpPortManager.Acquire(pxy.name, pxy.cfg.RemotePort)
if err != nil { if err != nil {
return err return
} }
udpConn, err := net.ListenUDP("udp", addr) defer func() {
if err != nil { if err != nil {
pxy.ctl.svr.udpPortManager.Release(pxy.realPort)
}
}()
remoteAddr = fmt.Sprintf(":%d", pxy.realPort)
pxy.cfg.RemotePort = pxy.realPort
addr, errRet := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", config.ServerCommonCfg.ProxyBindAddr, pxy.realPort))
if errRet != nil {
err = errRet
return
}
udpConn, errRet := net.ListenUDP("udp", addr)
if errRet != nil {
err = errRet
pxy.Warn("listen udp port error: %v", err) pxy.Warn("listen udp port error: %v", err)
return err return
} }
pxy.Info("udp proxy listen port [%d]", pxy.cfg.RemotePort) pxy.Info("udp proxy listen port [%d]", pxy.cfg.RemotePort)
@ -537,7 +584,7 @@ func (pxy *UdpProxy) Run() (err error) {
udp.ForwardUserConn(udpConn, pxy.readCh, pxy.sendCh) udp.ForwardUserConn(udpConn, pxy.readCh, pxy.sendCh)
pxy.Close() pxy.Close()
}() }()
return nil return remoteAddr, nil
} }
func (pxy *UdpProxy) GetConf() config.ProxyConf { func (pxy *UdpProxy) GetConf() config.ProxyConf {
@ -561,6 +608,7 @@ func (pxy *UdpProxy) Close() {
close(pxy.readCh) close(pxy.readCh)
close(pxy.sendCh) close(pxy.sendCh)
} }
pxy.ctl.svr.udpPortManager.Release(pxy.realPort)
} }
// HandleUserTcpConnection is used for incoming tcp user connections. // HandleUserTcpConnection is used for incoming tcp user connections.

View File

@ -60,17 +60,25 @@ type Service struct {
// Manage all visitor listeners. // Manage all visitor listeners.
visitorManager *VisitorManager visitorManager *VisitorManager
// Manage all tcp ports.
tcpPortManager *PortManager
// Manage all udp ports.
udpPortManager *PortManager
// Controller for nat hole connections. // Controller for nat hole connections.
natHoleController *NatHoleController natHoleController *NatHoleController
} }
func NewService() (svr *Service, err error) { func NewService() (svr *Service, err error) {
cfg := config.ServerCommonCfg
svr = &Service{ svr = &Service{
ctlManager: NewControlManager(), ctlManager: NewControlManager(),
pxyManager: NewProxyManager(), pxyManager: NewProxyManager(),
visitorManager: NewVisitorManager(), visitorManager: NewVisitorManager(),
tcpPortManager: NewPortManager("tcp", cfg.ProxyBindAddr, cfg.PrivilegeAllowPorts),
udpPortManager: NewPortManager("udp", cfg.ProxyBindAddr, cfg.PrivilegeAllowPorts),
} }
cfg := config.ServerCommonCfg
// Init assets. // Init assets.
err = assets.Load(cfg.AssetsDir) err = assets.Load(cfg.AssetsDir)
@ -283,7 +291,7 @@ func (svr *Service) RegisterControl(ctlConn frpNet.Conn, loginMsg *msg.Login) (e
ctl := NewControl(svr, ctlConn, loginMsg) ctl := NewControl(svr, ctlConn, loginMsg)
if oldCtl := svr.ctlManager.Add(loginMsg.RunId, ctl); oldCtl != nil { if oldCtl := svr.ctlManager.Add(loginMsg.RunId, ctl); oldCtl != nil {
oldCtl.allShutdown.WaitDown() oldCtl.allShutdown.WaitDone()
} }
ctlConn.AddLogPrefix(loginMsg.RunId) ctlConn.AddLogPrefix(loginMsg.RunId)

View File

@ -10,5 +10,11 @@ if [ -n "${pid}" ]; then
kill ${pid} kill ${pid}
fi fi
pid=`ps aux|grep './../bin/frpc -c ./conf/auto_test_frpc_visitor.ini'|grep -v grep|awk {'print $2'}`
if [ -n "${pid}" ]; then
kill ${pid}
fi
rm -f ./frps.log rm -f ./frps.log
rm -f ./frpc.log rm -f ./frpc.log
rm -f ./frpc_visitor.log

View File

@ -6,30 +6,96 @@ log_file = ./frpc.log
log_level = debug log_level = debug
privilege_token = 123456 privilege_token = 123456
[echo] [tcp_normal]
type = tcp type = tcp
local_ip = 127.0.0.1 local_ip = 127.0.0.1
local_port = 10701 local_port = 10701
remote_port = 10711 remote_port = 10801
use_encryption = true
use_compression = true
[web] [tcp_ec]
type = http type = tcp
local_ip = 127.0.0.1 local_ip = 127.0.0.1
local_port = 10702 local_port = 10701
remote_port = 10901
use_encryption = true use_encryption = true
use_compression = true use_compression = true
custom_domains = 127.0.0.1
[udp] [udp_normal]
type = udp type = udp
local_ip = 127.0.0.1 local_ip = 127.0.0.1
local_port = 10703 local_port = 10702
remote_port = 10712 remote_port = 10802
[udp_ec]
type = udp
local_ip = 127.0.0.1
local_port = 10702
remote_port = 10902
use_encryption = true
use_compression = true
[unix_domain] [unix_domain]
type = tcp type = tcp
remote_port = 10704 remote_port = 10803
plugin = unix_domain_socket plugin = unix_domain_socket
plugin_unix_path = /tmp/frp_echo_server.sock plugin_unix_path = /tmp/frp_echo_server.sock
[stcp]
type = stcp
sk = abcdefg
local_ip = 127.0.0.1
local_port = 10701
[stcp_ec]
type = stcp
sk = abc
local_ip = 127.0.0.1
local_port = 10701
use_encryption = true
use_compression = true
[web01]
type = http
local_ip = 127.0.0.1
local_port = 10704
custom_domains = 127.0.0.1
[web02]
type = http
local_ip = 127.0.0.1
local_port = 10704
custom_domains = test2.frp.com
host_header_rewrite = test2.frp.com
use_encryption = true
use_compression = true
[web03]
type = http
local_ip = 127.0.0.1
local_port = 10704
custom_domains = test3.frp.com
use_encryption = true
use_compression = true
host_header_rewrite = test3.frp.com
locations = /,/foo
[web04]
type = http
local_ip = 127.0.0.1
local_port = 10704
custom_domains = test3.frp.com
use_encryption = true
use_compression = true
host_header_rewrite = test3.frp.com
locations = /bar
[web05]
type = http
local_ip = 127.0.0.1
local_port = 10704
custom_domains = test5.frp.com
host_header_rewrite = test5.frp.com
use_encryption = true
use_compression = true
http_user = test
http_user = test

View File

@ -0,0 +1,25 @@
[common]
server_addr = 0.0.0.0
server_port = 10700
log_file = ./frpc_visitor.log
# debug, info, warn, error
log_level = debug
privilege_token = 123456
[stcp_visitor]
type = stcp
role = visitor
server_name = stcp
sk = abcdefg
bind_addr = 127.0.0.1
bind_port = 10805
[stcp_ec_visitor]
type = stcp
role = visitor
server_name = stcp_ec
sk = abc
bind_addr = 127.0.0.1
bind_port = 10905
use_encryption = true
use_compression = true

View File

@ -1,7 +1,7 @@
[common] [common]
bind_addr = 0.0.0.0 bind_addr = 0.0.0.0
bind_port = 10700 bind_port = 10700
vhost_http_port = 10710 vhost_http_port = 10804
log_file = ./frps.log log_file = ./frps.log
log_level = debug log_level = debug
privilege_token = 123456 privilege_token = 123456

View File

@ -1,7 +1,6 @@
package tests package tests
import ( import (
"bufio"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -11,8 +10,8 @@ import (
frpNet "github.com/fatedier/frp/utils/net" frpNet "github.com/fatedier/frp/utils/net"
) )
func StartEchoServer() { func StartTcpEchoServer() {
l, err := frpNet.ListenTcp("127.0.0.1", 10701) l, err := frpNet.ListenTcp("127.0.0.1", TEST_TCP_PORT)
if err != nil { if err != nil {
fmt.Printf("echo server listen error: %v\n", err) fmt.Printf("echo server listen error: %v\n", err)
return return
@ -30,7 +29,7 @@ func StartEchoServer() {
} }
func StartUdpEchoServer() { func StartUdpEchoServer() {
l, err := frpNet.ListenUDP("127.0.0.1", 10703) l, err := frpNet.ListenUDP("127.0.0.1", TEST_UDP_PORT)
if err != nil { if err != nil {
fmt.Printf("udp echo server listen error: %v\n", err) fmt.Printf("udp echo server listen error: %v\n", err)
return return
@ -48,7 +47,7 @@ func StartUdpEchoServer() {
} }
func StartUnixDomainServer() { func StartUnixDomainServer() {
unixPath := "/tmp/frp_echo_server.sock" unixPath := TEST_UNIX_DOMAIN_ADDR
os.Remove(unixPath) os.Remove(unixPath)
syscall.Umask(0) syscall.Umask(0)
l, err := net.Listen("unix", unixPath) l, err := net.Listen("unix", unixPath)
@ -69,17 +68,20 @@ func StartUnixDomainServer() {
} }
func echoWorker(c net.Conn) { func echoWorker(c net.Conn) {
br := bufio.NewReader(c) buf := make([]byte, 2048)
for { for {
buf, err := br.ReadString('\n') n, err := c.Read(buf)
if err == io.EOF {
break
}
if err != nil { if err != nil {
if err == io.EOF {
c.Close()
break
} else {
fmt.Printf("echo server read error: %v\n", err) fmt.Printf("echo server read error: %v\n", err)
return return
} }
}
c.Write([]byte(buf + "\n")) c.Write(buf[:n])
} }
} }

View File

@ -1,119 +1,157 @@
package tests package tests
import ( import (
"bufio"
"bytes"
"fmt" "fmt"
"io/ioutil"
"net"
"net/http"
"strings"
"testing" "testing"
"time" "time"
frpNet "github.com/fatedier/frp/utils/net" "github.com/stretchr/testify/assert"
) )
var ( var (
ECHO_PORT int64 = 10711 TEST_STR = "frp is a fast reverse proxy to help you expose a local server behind a NAT or firewall to the internet."
UDP_ECHO_PORT int64 = 10712 TEST_TCP_PORT int = 10701
HTTP_PORT int64 = 10710 TEST_TCP_FRP_PORT int = 10801
ECHO_TEST_STR string = "Hello World\n" TEST_TCP_EC_FRP_PORT int = 10901
HTTP_RES_STR string = "Hello World" TEST_TCP_ECHO_STR string = "tcp type:" + TEST_STR
TEST_UDP_PORT int = 10702
TEST_UDP_FRP_PORT int = 10802
TEST_UDP_EC_FRP_PORT int = 10902
TEST_UDP_ECHO_STR string = "udp type:" + TEST_STR
TEST_UNIX_DOMAIN_ADDR string = "/tmp/frp_echo_server.sock"
TEST_UNIX_DOMAIN_FRP_PORT int = 10803
TEST_UNIX_DOMAIN_STR string = "unix domain type:" + TEST_STR
TEST_HTTP_PORT int = 10704
TEST_HTTP_FRP_PORT int = 10804
TEST_HTTP_NORMAL_STR string = "http normal string: " + TEST_STR
TEST_HTTP_FOO_STR string = "http foo string: " + TEST_STR
TEST_HTTP_BAR_STR string = "http bar string: " + TEST_STR
TEST_STCP_FRP_PORT int = 10805
TEST_STCP_EC_FRP_PORT int = 10905
TEST_STCP_ECHO_STR string = "stcp type:" + TEST_STR
) )
func init() { func init() {
go StartEchoServer() go StartTcpEchoServer()
go StartUdpEchoServer() go StartUdpEchoServer()
go StartHttpServer()
go StartUnixDomainServer() go StartUnixDomainServer()
go StartHttpServer()
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
} }
func TestEchoServer(t *testing.T) { func TestTcp(t *testing.T) {
c, err := frpNet.ConnectTcpServer(fmt.Sprintf("127.0.0.1:%d", ECHO_PORT)) assert := assert.New(t)
if err != nil { // Normal
t.Fatalf("connect to echo server error: %v", err) addr := fmt.Sprintf("127.0.0.1:%d", TEST_TCP_FRP_PORT)
} res, err := sendTcpMsg(addr, TEST_TCP_ECHO_STR)
timer := time.Now().Add(time.Duration(5) * time.Second) assert.NoError(err)
c.SetDeadline(timer) assert.Equal(TEST_TCP_ECHO_STR, res)
c.Write([]byte(ECHO_TEST_STR + "\n")) // Encrytion and compression
addr = fmt.Sprintf("127.0.0.1:%d", TEST_TCP_EC_FRP_PORT)
br := bufio.NewReader(c) res, err = sendTcpMsg(addr, TEST_TCP_ECHO_STR)
buf, err := br.ReadString('\n') assert.NoError(err)
if err != nil { assert.Equal(TEST_TCP_ECHO_STR, res)
t.Fatalf("read from echo server error: %v", err)
} }
if ECHO_TEST_STR != buf { func TestUdp(t *testing.T) {
t.Fatalf("content error, send [%s], get [%s]", strings.Trim(ECHO_TEST_STR, "\n"), strings.Trim(buf, "\n")) assert := assert.New(t)
// Normal
addr := fmt.Sprintf("127.0.0.1:%d", TEST_UDP_FRP_PORT)
res, err := sendUdpMsg(addr, TEST_UDP_ECHO_STR)
assert.NoError(err)
assert.Equal(TEST_UDP_ECHO_STR, res)
// Encrytion and compression
addr = fmt.Sprintf("127.0.0.1:%d", TEST_UDP_EC_FRP_PORT)
res, err = sendUdpMsg(addr, TEST_UDP_ECHO_STR)
assert.NoError(err)
assert.Equal(TEST_UDP_ECHO_STR, res)
}
func TestUnixDomain(t *testing.T) {
assert := assert.New(t)
// Normal
addr := fmt.Sprintf("127.0.0.1:%d", TEST_UNIX_DOMAIN_FRP_PORT)
res, err := sendTcpMsg(addr, TEST_UNIX_DOMAIN_STR)
if assert.NoError(err) {
assert.Equal(TEST_UNIX_DOMAIN_STR, res)
} }
} }
func TestHttpServer(t *testing.T) { func TestStcp(t *testing.T) {
client := &http.Client{} assert := assert.New(t)
req, _ := http.NewRequest("GET", fmt.Sprintf("http://127.0.0.1:%d", HTTP_PORT), nil) // Normal
res, err := client.Do(req) addr := fmt.Sprintf("127.0.0.1:%d", TEST_STCP_FRP_PORT)
if err != nil { res, err := sendTcpMsg(addr, TEST_STCP_ECHO_STR)
t.Fatalf("do http request error: %v", err) if assert.NoError(err) {
assert.Equal(TEST_STCP_ECHO_STR, res)
} }
if res.StatusCode == 200 {
body, err := ioutil.ReadAll(res.Body) // Encrytion and compression
if err != nil { addr = fmt.Sprintf("127.0.0.1:%d", TEST_STCP_EC_FRP_PORT)
t.Fatalf("read from http server error: %v", err) res, err = sendTcpMsg(addr, TEST_STCP_ECHO_STR)
} if assert.NoError(err) {
bodystr := string(body) assert.Equal(TEST_STCP_ECHO_STR, res)
if bodystr != HTTP_RES_STR {
t.Fatalf("content from http server error [%s], correct string is [%s]", bodystr, HTTP_RES_STR)
}
} else {
t.Fatalf("http code from http server error [%d]", res.StatusCode)
} }
} }
func TestUdpEchoServer(t *testing.T) { func TestHttp(t *testing.T) {
addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:10712") assert := assert.New(t)
if err != nil { // web01
t.Fatalf("do udp request error: %v", err) code, body, err := sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", TEST_HTTP_FRP_PORT), "", nil)
} if assert.NoError(err) {
conn, err := net.DialUDP("udp", nil, addr) assert.Equal(200, code)
if err != nil { assert.Equal(TEST_HTTP_NORMAL_STR, body)
t.Fatalf("dial udp server error: %v", err)
}
defer conn.Close()
_, err = conn.Write([]byte("hello frp\n"))
if err != nil {
t.Fatalf("write to udp server error: %v", err)
}
data := make([]byte, 20)
n, err := conn.Read(data)
if err != nil {
t.Fatalf("read from udp server error: %v", err)
} }
if string(bytes.TrimSpace(data[:n])) != "hello frp" { // web02
t.Fatalf("message got from udp server error, get %s", string(data[:n-1])) code, body, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", TEST_HTTP_FRP_PORT), "test2.frp.com", nil)
} if assert.NoError(err) {
assert.Equal(200, code)
assert.Equal(TEST_HTTP_NORMAL_STR, body)
} }
func TestUnixDomainServer(t *testing.T) { // error host header
c, err := frpNet.ConnectTcpServer(fmt.Sprintf("127.0.0.1:%d", 10704)) code, body, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", TEST_HTTP_FRP_PORT), "errorhost.frp.com", nil)
if err != nil { if assert.NoError(err) {
t.Fatalf("connect to echo server error: %v", err) assert.Equal(404, code)
}
timer := time.Now().Add(time.Duration(5) * time.Second)
c.SetDeadline(timer)
c.Write([]byte(ECHO_TEST_STR + "\n"))
br := bufio.NewReader(c)
buf, err := br.ReadString('\n')
if err != nil {
t.Fatalf("read from echo server error: %v", err)
} }
if ECHO_TEST_STR != buf { // web03
t.Fatalf("content error, send [%s], get [%s]", strings.Trim(ECHO_TEST_STR, "\n"), strings.Trim(buf, "\n")) code, body, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", TEST_HTTP_FRP_PORT), "test3.frp.com", nil)
if assert.NoError(err) {
assert.Equal(200, code)
assert.Equal(TEST_HTTP_NORMAL_STR, body)
}
code, body, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d/foo", TEST_HTTP_FRP_PORT), "test3.frp.com", nil)
if assert.NoError(err) {
assert.Equal(200, code)
assert.Equal(TEST_HTTP_FOO_STR, body)
}
// web04
code, body, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d/bar", TEST_HTTP_FRP_PORT), "test3.frp.com", nil)
if assert.NoError(err) {
assert.Equal(200, code)
assert.Equal(TEST_HTTP_BAR_STR, body)
}
// web05
code, body, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", TEST_HTTP_FRP_PORT), "test5.frp.com", nil)
if assert.NoError(err) {
assert.Equal(401, code)
}
header := make(map[string]string)
header["Authorization"] = basicAuth("test", "test")
code, body, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", TEST_HTTP_FRP_PORT), "test5.frp.com", header)
if assert.NoError(err) {
assert.Equal(401, code)
} }
} }

View File

@ -3,13 +3,30 @@ package tests
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"strings"
) )
func StartHttpServer() { func StartHttpServer() {
http.HandleFunc("/", request) http.HandleFunc("/", request)
http.ListenAndServe(fmt.Sprintf("0.0.0.0:%d", 10702), nil) http.ListenAndServe(fmt.Sprintf("0.0.0.0:%d", TEST_HTTP_PORT), nil)
} }
func request(w http.ResponseWriter, r *http.Request) { func request(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(HTTP_RES_STR)) if strings.Contains(r.Host, "127.0.0.1") || strings.Contains(r.Host, "test2.frp.com") ||
strings.Contains(r.Host, "test5.frp.com") {
w.WriteHeader(200)
w.Write([]byte(TEST_HTTP_NORMAL_STR))
} else if strings.Contains(r.Host, "test3.frp.com") {
w.WriteHeader(200)
if strings.Contains(r.URL.Path, "foo") {
w.Write([]byte(TEST_HTTP_FOO_STR))
} else if strings.Contains(r.URL.Path, "bar") {
w.Write([]byte(TEST_HTTP_BAR_STR))
} else {
w.Write([]byte(TEST_HTTP_NORMAL_STR))
}
} else {
w.WriteHeader(404)
}
return
} }

View File

@ -3,6 +3,7 @@
./../bin/frps -c ./conf/auto_test_frps.ini & ./../bin/frps -c ./conf/auto_test_frps.ini &
sleep 1 sleep 1
./../bin/frpc -c ./conf/auto_test_frpc.ini & ./../bin/frpc -c ./conf/auto_test_frpc.ini &
./../bin/frpc -c ./conf/auto_test_frpc_visitor.ini &
# wait until proxies are connected # wait until proxies are connected
sleep 2 sleep 2

93
tests/util.go Normal file
View File

@ -0,0 +1,93 @@
package tests
import (
"encoding/base64"
"fmt"
"io/ioutil"
"net"
"net/http"
"time"
frpNet "github.com/fatedier/frp/utils/net"
)
func sendTcpMsg(addr string, msg string) (res string, err error) {
c, err := frpNet.ConnectTcpServer(addr)
if err != nil {
err = fmt.Errorf("connect to tcp server error: %v", err)
return
}
defer c.Close()
timer := time.Now().Add(5 * time.Second)
c.SetDeadline(timer)
c.Write([]byte(msg))
buf := make([]byte, 2048)
n, errRet := c.Read(buf)
if errRet != nil {
err = fmt.Errorf("read from tcp server error: %v", errRet)
return
}
return string(buf[:n]), nil
}
func sendUdpMsg(addr string, msg string) (res string, err error) {
udpAddr, errRet := net.ResolveUDPAddr("udp", addr)
if errRet != nil {
err = fmt.Errorf("resolve udp addr error: %v", err)
return
}
conn, errRet := net.DialUDP("udp", nil, udpAddr)
if errRet != nil {
err = fmt.Errorf("dial udp server error: %v", err)
return
}
defer conn.Close()
_, err = conn.Write([]byte(msg))
if err != nil {
err = fmt.Errorf("write to udp server error: %v", err)
return
}
buf := make([]byte, 2048)
n, errRet := conn.Read(buf)
if errRet != nil {
err = fmt.Errorf("read from udp server error: %v", err)
return
}
return string(buf[:n]), nil
}
func sendHttpMsg(method, url string, host string, header map[string]string) (code int, body string, err error) {
req, errRet := http.NewRequest(method, url, nil)
if errRet != nil {
err = errRet
return
}
if host != "" {
req.Host = host
}
for k, v := range header {
req.Header.Set(k, v)
}
resp, errRet := http.DefaultClient.Do(req)
if errRet != nil {
err = errRet
return
}
code = resp.StatusCode
buf, errRet := ioutil.ReadAll(resp.Body)
if errRet != nil {
err = errRet
return
}
body = string(buf)
return
}
func basicAuth(username, passwd string) string {
auth := username + ":" + passwd
return "Basic " + base64.StdEncoding.EncodeToString([]byte(auth))
}

View File

@ -31,7 +31,7 @@ type KcpListener struct {
log.Logger log.Logger
} }
func ListenKcp(bindAddr string, bindPort int64) (l *KcpListener, err error) { func ListenKcp(bindAddr string, bindPort int) (l *KcpListener, err error) {
listener, err := kcp.ListenWithOptions(fmt.Sprintf("%s:%d", bindAddr, bindPort), nil, 10, 3) listener, err := kcp.ListenWithOptions(fmt.Sprintf("%s:%d", bindAddr, bindPort), nil, 10, 3)
if err != nil { if err != nil {
return l, err return l, err

View File

@ -33,7 +33,7 @@ type TcpListener struct {
log.Logger log.Logger
} }
func ListenTcp(bindAddr string, bindPort int64) (l *TcpListener, err error) { func ListenTcp(bindAddr string, bindPort int) (l *TcpListener, err error) {
tcpAddr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort)) tcpAddr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
if err != nil { if err != nil {
return l, err return l, err

View File

@ -167,7 +167,7 @@ type UdpListener struct {
log.Logger log.Logger
} }
func ListenUDP(bindAddr string, bindPort int64) (l *UdpListener, err error) { func ListenUDP(bindAddr string, bindPort int) (l *UdpListener, err error) {
udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", bindAddr, bindPort)) udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
if err != nil { if err != nil {
return l, err return l, err

View File

@ -21,8 +21,8 @@ import (
type Shutdown struct { type Shutdown struct {
doing bool doing bool
ending bool ending bool
start chan struct{} startCh chan struct{}
down chan struct{} doneCh chan struct{}
mu sync.Mutex mu sync.Mutex
} }
@ -30,8 +30,8 @@ func New() *Shutdown {
return &Shutdown{ return &Shutdown{
doing: false, doing: false,
ending: false, ending: false,
start: make(chan struct{}), startCh: make(chan struct{}),
down: make(chan struct{}), doneCh: make(chan struct{}),
} }
} }
@ -40,12 +40,12 @@ func (s *Shutdown) Start() {
defer s.mu.Unlock() defer s.mu.Unlock()
if !s.doing { if !s.doing {
s.doing = true s.doing = true
close(s.start) close(s.startCh)
} }
} }
func (s *Shutdown) WaitStart() { func (s *Shutdown) WaitStart() {
<-s.start <-s.startCh
} }
func (s *Shutdown) Done() { func (s *Shutdown) Done() {
@ -53,10 +53,10 @@ func (s *Shutdown) Done() {
defer s.mu.Unlock() defer s.mu.Unlock()
if !s.ending { if !s.ending {
s.ending = true s.ending = true
close(s.down) close(s.doneCh)
} }
} }
func (s *Shutdown) WaitDown() { func (s *Shutdown) WaitDone() {
<-s.down <-s.doneCh
} }

View File

@ -17,5 +17,5 @@ func TestShutdown(t *testing.T) {
time.Sleep(time.Millisecond) time.Sleep(time.Millisecond)
s.Done() s.Done()
}() }()
s.WaitDown() s.WaitDone()
} }

View File

@ -19,8 +19,6 @@ import (
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"strconv"
"strings"
) )
// RandId return a rand string used in frp. // RandId return a rand string used in frp.
@ -48,65 +46,11 @@ func GetAuthKey(token string, timestamp int64) (key string) {
return hex.EncodeToString(data) return hex.EncodeToString(data)
} }
// for example: rangeStr is "1000-2000,2001,2002,3000-4000", return an array as port ranges. func CanonicalAddr(host string, port int) (addr string) {
func GetPortRanges(rangeStr string) (portRanges [][2]int64, err error) { if port == 80 || port == 443 {
// for example: 1000-2000,2001,2002,3000-4000 addr = host
rangeArray := strings.Split(rangeStr, ",")
for _, portRangeStr := range rangeArray {
// 1000-2000 or 2001
portArray := strings.Split(portRangeStr, "-")
// length: only 1 or 2 is correct
rangeType := len(portArray)
if rangeType == 1 {
singlePort, err := strconv.ParseInt(portArray[0], 10, 64)
if err != nil {
return [][2]int64{}, err
}
portRanges = append(portRanges, [2]int64{singlePort, singlePort})
} else if rangeType == 2 {
min, err := strconv.ParseInt(portArray[0], 10, 64)
if err != nil {
return [][2]int64{}, err
}
max, err := strconv.ParseInt(portArray[1], 10, 64)
if err != nil {
return [][2]int64{}, err
}
if max < min {
return [][2]int64{}, fmt.Errorf("range incorrect")
}
portRanges = append(portRanges, [2]int64{min, max})
} else { } else {
return [][2]int64{}, fmt.Errorf("format error") addr = fmt.Sprintf("%s:%d", host, port)
} }
} return
return portRanges, nil
}
func ContainsPort(portRanges [][2]int64, port int64) bool {
for _, pr := range portRanges {
if port >= pr[0] && port <= pr[1] {
return true
}
}
return false
}
func PortRangesCut(portRanges [][2]int64, port int64) [][2]int64 {
var tmpRanges [][2]int64
for _, pr := range portRanges {
if port >= pr[0] && port <= pr[1] {
leftRange := [2]int64{pr[0], port - 1}
rightRange := [2]int64{port + 1, pr[1]}
if leftRange[0] <= leftRange[1] {
tmpRanges = append(tmpRanges, leftRange)
}
if rightRange[0] <= rightRange[1] {
tmpRanges = append(tmpRanges, rightRange)
}
} else {
tmpRanges = append(tmpRanges, pr)
}
}
return tmpRanges
} }

View File

@ -20,67 +20,3 @@ func TestGetAuthKey(t *testing.T) {
t.Log(key) t.Log(key)
assert.Equal("6df41a43725f0c770fd56379e12acf8c", key) assert.Equal("6df41a43725f0c770fd56379e12acf8c", key)
} }
func TestGetPortRanges(t *testing.T) {
assert := assert.New(t)
rangesStr := "2000-3000,3001,4000-50000"
expect := [][2]int64{
[2]int64{2000, 3000},
[2]int64{3001, 3001},
[2]int64{4000, 50000},
}
actual, err := GetPortRanges(rangesStr)
assert.Nil(err)
t.Log(actual)
assert.Equal(expect, actual)
}
func TestContainsPort(t *testing.T) {
assert := assert.New(t)
rangesStr := "2000-3000,3001,4000-50000"
portRanges, err := GetPortRanges(rangesStr)
assert.Nil(err)
type Case struct {
Port int64
Answer bool
}
cases := []Case{
Case{
Port: 3001,
Answer: true,
},
Case{
Port: 3002,
Answer: false,
},
Case{
Port: 44444,
Answer: true,
},
}
for _, elem := range cases {
ok := ContainsPort(portRanges, elem.Port)
assert.Equal(elem.Answer, ok)
}
}
func TestPortRangesCut(t *testing.T) {
assert := assert.New(t)
rangesStr := "2000-3000,3001,4000-50000"
portRanges, err := GetPortRanges(rangesStr)
assert.Nil(err)
expect := [][2]int64{
[2]int64{2000, 3000},
[2]int64{3001, 3001},
[2]int64{4000, 44443},
[2]int64{44445, 50000},
}
actual := PortRangesCut(portRanges, 44444)
t.Log(actual)
assert.Equal(expect, actual)
}

10
vendor/github.com/rodaine/table/.travis.yml generated vendored Normal file
View File

@ -0,0 +1,10 @@
sudo: false
language: go
go: 1.8
branches:
only:
- master
install: go get -t ./... github.com/golang/lint/golint
script: make lint test

9
vendor/github.com/rodaine/table/license generated vendored Normal file
View File

@ -0,0 +1,9 @@
The MIT License (MIT)
Copyright (c) 2015 Chris Roche (rodaine+github@gmail.com)
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

9
vendor/github.com/rodaine/table/makefile generated vendored Normal file
View File

@ -0,0 +1,9 @@
.PHONY: lint
lint:
gofmt -d -s .
golint -set_exit_status ./...
go tool vet -all -shadow -shadowstrict .
.PHONY: test
test:
go test -v -cover -race ./...

61
vendor/github.com/rodaine/table/readme.md generated vendored Normal file
View File

@ -0,0 +1,61 @@
# table <br/> [![GoDoc](https://godoc.org/github.com/rodaine/table?status.svg)](https://godoc.org/github.com/rodaine/table) [![Build Status](https://travis-ci.org/rodaine/table.svg)](https://travis-ci.org/rodaine/table)
![Example Table Output With ANSI Colors](http://res.cloudinary.com/rodaine/image/upload/v1442524799/go-table-example0.png)
Package table provides a convenient way to generate tabular output of any data, primarily useful for CLI tools.
## Features
- Accepts all data types (`string`, `int`, `interface{}`, everything!) and will use the `String() string` method of a type if available.
- Can specify custom formatting for the header and first column cells for better readability.
- Columns are left-aligned and sized to fit the data, with customizable padding.
- The printed output can be sent to any `io.Writer`, defaulting to `os.Stdout`.
- Built to an interface, so you can roll your own `Table` implementation.
- Works well with ANSI colors ([fatih/color](https://github.com/fatih/color) in the example)!
- Can provide a custom `WidthFunc` to accomodate multi- and zero-width characters (such as [runewidth](https://github.com/mattn/go-runewidth))
## Usage
**Download the package:**
```sh
go get -u github.com/rodaine/table
```
**Example:**
```go
package main
import (
"fmt"
"strings"
"github.com/fatih/color"
"github.com/rodaine/table"
)
func main() {
headerFmt := color.New(color.FgGreen, color.Underline).SprintfFunc()
columnFmt := color.New(color.FgYellow).SprintfFunc()
tbl := table.New("ID", "Name", "Score", "Added")
tbl.WithHeaderFormatter(headerFmt).WithFirstColumnFormatter(columnFmt)
for _, widget := range getWidgets() {
tbl.AddRow(widget.ID, widget.Name, widget.Cost, widget.Added)
}
tbl.Print()
}
```
_Consult the [documentation](https://godoc.org/github.com/rodaine/table) for further examples and usage information_
## Contributing
Please feel free to submit an [issue](https://github.com/rodaine/table/issues) or [PR](https://github.com/rodaine/table/pulls) to this repository for features or bugs. All submitted code must pass the scripts specified within [.travis.yml](https://github.com/rodaine/table/blob/master/.travis.yml) and should include tests to back up the changes.
## License
table is released under the MIT License (Expat). See the [full license](https://github.com/rodaine/table/blob/master/license).

267
vendor/github.com/rodaine/table/table.go generated vendored Normal file
View File

@ -0,0 +1,267 @@
// Package table provides a convenient way to generate tabular output of any
// data, primarily useful for CLI tools.
//
// Columns are left-aligned and padded to accomodate the largest cell in that
// column.
//
// Source: https://github.com/rodaine/table
//
// table.DefaultHeaderFormatter = func(format string, vals ...interface{}) string {
// return strings.ToUpper(fmt.Sprintf(format, vals...))
// }
//
// tbl := table.New("ID", "Name", "Cost ($)")
//
// for _, widget := range Widgets {
// tbl.AddRow(widget.ID, widget.Name, widget.Cost)
// }
//
// tbl.Print()
//
// // Output:
// // ID NAME COST ($)
// // 1 Foobar 1.23
// // 2 Fizzbuzz 4.56
// // 3 Gizmo 78.90
package table
import (
"fmt"
"io"
"os"
"strings"
"unicode/utf8"
)
// These are the default properties for all Tables created from this package
// and can be modified.
var (
// DefaultPadding specifies the number of spaces between columns in a table.
DefaultPadding = 2
// DefaultWriter specifies the output io.Writer for the Table.Print method.
DefaultWriter io.Writer = os.Stdout
// DefaultHeaderFormatter specifies the default Formatter for the table header.
DefaultHeaderFormatter Formatter
// DefaultFirstColumnFormatter specifies the default Formatter for the first column cells.
DefaultFirstColumnFormatter Formatter
// DefaultWidthFunc specifies the default WidthFunc for calculating column widths
DefaultWidthFunc WidthFunc = utf8.RuneCountInString
)
// Formatter functions expose a fmt.Sprintf signature that can be used to modify
// the display of the text in either the header or first column of a Table.
// The formatter should not change the width of original text as printed since
// column widths are calculated pre-formatting (though this issue can be mitigated
// with increased padding).
//
// tbl.WithHeaderFormatter(func(format string, vals ...interface{}) string {
// return strings.ToUpper(fmt.Sprintf(format, vals...))
// })
//
// A good use case for formatters is to use ANSI escape codes to color the cells
// for a nicer interface. The package color (https://github.com/fatih/color) makes
// it easy to generate these automatically: http://godoc.org/github.com/fatih/color#Color.SprintfFunc
type Formatter func(string, ...interface{}) string
// A WidthFunc calculates the width of a string. By default, the number of runes
// is used but this may not be appropriate for certain character sets. The
// package runewidth (https://github.com/mattn/go-runewidth) could be used to
// accomodate multi-cell characters (such as emoji or CJK characters).
type WidthFunc func(string) int
// Table describes the interface for building up a tabular representation of data.
// It exposes fluent/chainable methods for convenient table building.
//
// WithHeaderFormatter and WithFirstColumnFormatter sets the Formatter for the
// header and first column, respectively. If nil is passed in (the default), no
// formatting will be applied.
//
// New("foo", "bar").WithFirstColumnFormatter(func(f string, v ...interface{}) string {
// return strings.ToUpper(fmt.Sprintf(f, v...))
// })
//
// WithPadding specifies the minimum padding between cells in a row and defaults
// to DefaultPadding. Padding values less than or equal to zero apply no extra
// padding between the columns.
//
// New("foo", "bar").WithPadding(3)
//
// WithWriter modifies the writer which Print outputs to, defaulting to DefaultWriter
// when instantiated. If nil is passed, os.Stdout will be used.
//
// New("foo", "bar").WithWriter(os.Stderr)
//
// WithWidthFunc sets the function used to calculate the width of the string in
// a column. By default, the number of utf8 runes in the string is used.
//
// AddRow adds another row of data to the table. Any values can be passed in and
// will be output as its string representation as described in the fmt standard
// package. Rows can have less cells than the total number of columns in the table;
// subsequent cells will be rendered empty. Rows with more cells than the total
// number of columns will be truncated. References to the data are not held, so
// the passed in values can be modified without affecting the table's output.
//
// New("foo", "bar").AddRow("fizz", "buzz").AddRow(time.Now()).AddRow(1, 2, 3).Print()
// // Output:
// // foo bar
// // fizz buzz
// // 2006-01-02 15:04:05.0 -0700 MST
// // 1 2
//
// Print writes the string representation of the table to the provided writer.
// Print can be called multiple times, even after subsequent mutations of the
// provided data. The output is always preceded and followed by a new line.
type Table interface {
WithHeaderFormatter(f Formatter) Table
WithFirstColumnFormatter(f Formatter) Table
WithPadding(p int) Table
WithWriter(w io.Writer) Table
WithWidthFunc(f WidthFunc) Table
AddRow(vals ...interface{}) Table
Print()
}
// New creates a Table instance with the specified header(s) provided. The number
// of columns is fixed at this point to len(columnHeaders) and the defined defaults
// are set on the instance.
func New(columnHeaders ...interface{}) Table {
t := table{header: make([]string, len(columnHeaders))}
t.WithPadding(DefaultPadding)
t.WithWriter(DefaultWriter)
t.WithHeaderFormatter(DefaultHeaderFormatter)
t.WithFirstColumnFormatter(DefaultFirstColumnFormatter)
t.WithWidthFunc(DefaultWidthFunc)
for i, col := range columnHeaders {
t.header[i] = fmt.Sprint(col)
}
return &t
}
type table struct {
FirstColumnFormatter Formatter
HeaderFormatter Formatter
Padding int
Writer io.Writer
Width WidthFunc
header []string
rows [][]string
widths []int
}
func (t *table) WithHeaderFormatter(f Formatter) Table {
t.HeaderFormatter = f
return t
}
func (t *table) WithFirstColumnFormatter(f Formatter) Table {
t.FirstColumnFormatter = f
return t
}
func (t *table) WithPadding(p int) Table {
if p < 0 {
p = 0
}
t.Padding = p
return t
}
func (t *table) WithWriter(w io.Writer) Table {
if w == nil {
w = os.Stdout
}
t.Writer = w
return t
}
func (t *table) WithWidthFunc(f WidthFunc) Table {
t.Width = f
return t
}
func (t *table) AddRow(vals ...interface{}) Table {
row := make([]string, len(t.header))
for i, val := range vals {
if i >= len(t.header) {
break
}
row[i] = fmt.Sprint(val)
}
t.rows = append(t.rows, row)
return t
}
func (t *table) Print() {
format := strings.Repeat("%s", len(t.header)) + "\n"
t.calculateWidths()
fmt.Fprintln(t.Writer)
t.printHeader(format)
for _, row := range t.rows {
t.printRow(format, row)
}
}
func (t *table) printHeader(format string) {
vals := t.applyWidths(t.header, t.widths)
if t.HeaderFormatter != nil {
txt := t.HeaderFormatter(format, vals...)
fmt.Fprint(t.Writer, txt)
} else {
fmt.Fprintf(t.Writer, format, vals...)
}
}
func (t *table) printRow(format string, row []string) {
vals := t.applyWidths(row, t.widths)
if t.FirstColumnFormatter != nil {
vals[0] = t.FirstColumnFormatter("%s", vals[0])
}
fmt.Fprintf(t.Writer, format, vals...)
}
func (t *table) calculateWidths() {
t.widths = make([]int, len(t.header))
for _, row := range t.rows {
for i, v := range row {
if w := t.Width(v) + t.Padding; w > t.widths[i] {
t.widths[i] = w
}
}
}
for i, v := range t.header {
if w := t.Width(v) + t.Padding; w > t.widths[i] {
t.widths[i] = w
}
}
}
func (t *table) applyWidths(row []string, widths []int) []interface{} {
out := make([]interface{}, len(row))
for i, s := range row {
out[i] = s + t.lenOffset(s, widths[i])
}
return out
}
func (t *table) lenOffset(s string, w int) string {
l := w - t.Width(s)
if l <= 0 {
return ""
}
return strings.Repeat(" ", l)
}

181
vendor/github.com/rodaine/table/table_test.go generated vendored Normal file
View File

@ -0,0 +1,181 @@
package table
import (
"bytes"
"fmt"
"io/ioutil"
"os"
"strings"
"testing"
"github.com/mattn/go-runewidth"
"github.com/stretchr/testify/assert"
)
func TestFormatter(t *testing.T) {
t.Parallel()
var formatter Formatter
fn := func(a string, b ...interface{}) string { return "" }
f := Formatter(fn)
assert.IsType(t, formatter, f)
}
func TestTable_New(t *testing.T) {
t.Parallel()
buf := bytes.Buffer{}
New("foo", "bar").WithWriter(&buf).Print()
out := buf.String()
assert.Contains(t, out, "foo")
assert.Contains(t, out, "bar")
buf.Reset()
New().WithWriter(&buf).Print()
out = buf.String()
assert.Empty(t, strings.TrimSpace(out))
}
func TestTable_WithHeaderFormatter(t *testing.T) {
t.Parallel()
uppercase := func(f string, v ...interface{}) string {
return strings.ToUpper(fmt.Sprintf(f, v...))
}
buf := bytes.Buffer{}
tbl := New("foo", "bar").WithWriter(&buf).WithHeaderFormatter(uppercase)
tbl.Print()
out := buf.String()
assert.Contains(t, out, "FOO")
assert.Contains(t, out, "BAR")
buf.Reset()
tbl.WithHeaderFormatter(nil).Print()
out = buf.String()
assert.Contains(t, out, "foo")
assert.Contains(t, out, "bar")
}
func TestTable_WithFirstColumnFormatter(t *testing.T) {
t.Parallel()
uppercase := func(f string, v ...interface{}) string {
return strings.ToUpper(fmt.Sprintf(f, v...))
}
buf := bytes.Buffer{}
tbl := New("foo", "bar").WithWriter(&buf).WithFirstColumnFormatter(uppercase).AddRow("fizz", "buzz")
tbl.Print()
out := buf.String()
assert.Contains(t, out, "foo")
assert.Contains(t, out, "bar")
assert.Contains(t, out, "FIZZ")
assert.Contains(t, out, "buzz")
buf.Reset()
tbl.WithFirstColumnFormatter(nil).Print()
out = buf.String()
assert.Contains(t, out, "fizz")
assert.Contains(t, out, "buzz")
}
func TestTable_WithPadding(t *testing.T) {
t.Parallel()
// zero value
buf := bytes.Buffer{}
tbl := New("foo", "bar").WithWriter(&buf).WithPadding(0)
tbl.Print()
out := buf.String()
assert.Contains(t, out, "foobar")
// positive value
buf.Reset()
tbl.WithPadding(4).Print()
out = buf.String()
assert.Contains(t, out, "foo bar ")
// negative value
buf.Reset()
tbl.WithPadding(-1).Print()
out = buf.String()
assert.Contains(t, out, "foobar")
}
func TestTable_WithWriter(t *testing.T) {
t.Parallel()
// not that we haven't been using it in all these tests but:
buf := bytes.Buffer{}
New("foo", "bar").WithWriter(&buf).Print()
assert.NotEmpty(t, buf.String())
stdout := os.Stdout
temp, _ := ioutil.TempFile("", "")
os.Stdout = temp
defer func() {
os.Stdout = stdout
temp.Close()
}()
New("foo", "bar").WithWriter(nil).Print()
temp.Seek(0, 0)
out, _ := ioutil.ReadAll(temp)
assert.NotEmpty(t, out)
}
func TestTable_AddRow(t *testing.T) {
t.Parallel()
buf := bytes.Buffer{}
tbl := New("foo", "bar").WithWriter(&buf).AddRow("fizz", "buzz")
tbl.Print()
out := buf.String()
assert.Contains(t, out, "fizz")
assert.Contains(t, out, "buzz")
lines := strings.Count(out, "\n")
// empty should add empty line
buf.Reset()
tbl.AddRow().Print()
assert.Equal(t, lines+1, strings.Count(buf.String(), "\n"))
// less than one will fill left-to-right
buf.Reset()
tbl.AddRow("cat").Print()
assert.Contains(t, buf.String(), "\ncat")
// more than initial length are truncated
buf.Reset()
tbl.AddRow("bippity", "boppity", "boo").Print()
assert.NotContains(t, buf.String(), "boo")
}
func TestTable_WithWidthFunc(t *testing.T) {
t.Parallel()
buf := bytes.Buffer{}
New("", "").
WithWriter(&buf).
WithPadding(1).
WithWidthFunc(runewidth.StringWidth).
AddRow("请求", "alpha").
AddRow("abc", "beta").
Print()
actual := buf.String()
assert.Contains(t, actual, "请求 alpha")
assert.Contains(t, actual, "abc beta")
}