diff --git a/client/admin.go b/client/admin.go new file mode 100644 index 00000000..f728483e --- /dev/null +++ b/client/admin.go @@ -0,0 +1,60 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package client + +import ( + "fmt" + "net" + "net/http" + "time" + + "github.com/fatedier/frp/models/config" + frpNet "github.com/fatedier/frp/utils/net" + + "github.com/julienschmidt/httprouter" +) + +var ( + httpServerReadTimeout = 10 * time.Second + httpServerWriteTimeout = 10 * time.Second +) + +func (svr *Service) RunAdminServer(addr string, port int64) (err error) { + // url router + router := httprouter.New() + + user, passwd := config.ClientCommonCfg.AdminUser, config.ClientCommonCfg.AdminPwd + + // api, see dashboard_api.go + router.GET("/api/reload", frpNet.HttprouterBasicAuth(svr.apiReload, user, passwd)) + + address := fmt.Sprintf("%s:%d", addr, port) + server := &http.Server{ + Addr: address, + Handler: router, + ReadTimeout: httpServerReadTimeout, + WriteTimeout: httpServerWriteTimeout, + } + if address == "" { + address = ":http" + } + ln, err := net.Listen("tcp", address) + if err != nil { + return err + } + + go server.Serve(ln) + return +} diff --git a/client/admin_api.go b/client/admin_api.go new file mode 100644 index 00000000..72fae04e --- /dev/null +++ b/client/admin_api.go @@ -0,0 +1,78 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package client + +import ( + "encoding/json" + "net/http" + + "github.com/julienschmidt/httprouter" + ini "github.com/vaughan0/go-ini" + + "github.com/fatedier/frp/models/config" + "github.com/fatedier/frp/utils/log" +) + +type GeneralResponse struct { + Code int64 `json:"code"` + Msg string `json:"msg"` +} + +// api/reload +type ReloadResp struct { + GeneralResponse +} + +func (svr *Service) apiReload(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + var ( + buf []byte + res ReloadResp + ) + defer func() { + log.Info("Http response [/api/reload]: code [%d]", res.Code) + buf, _ = json.Marshal(&res) + w.Write(buf) + }() + + log.Info("Http request: [/api/reload]") + + conf, err := ini.LoadFile(config.ClientCommonCfg.ConfigFile) + if err != nil { + res.Code = 1 + res.Msg = err.Error() + log.Error("reload frpc config file error: %v", err) + return + } + + newCommonCfg, err := config.LoadClientCommonConf(conf) + if err != nil { + res.Code = 2 + res.Msg = err.Error() + log.Error("reload frpc common section error: %v", err) + return + } + + pxyCfgs, vistorCfgs, err := config.LoadProxyConfFromFile(newCommonCfg.User, conf, newCommonCfg.Start) + if err != nil { + res.Code = 3 + res.Msg = err.Error() + log.Error("reload frpc proxy config error: %v", err) + return + } + + svr.ctl.reloadConf(pxyCfgs, vistorCfgs) + log.Info("success reload conf") + return +} diff --git a/client/control.go b/client/control.go index fb1ba716..29dca60c 100644 --- a/client/control.go +++ b/client/control.go @@ -388,7 +388,7 @@ func (ctl *Control) manager() { ctl.Warn("[%s] start error: %s", m.ProxyName, m.Error) continue } - cfg, ok := ctl.pxyCfgs[m.ProxyName] + cfg, ok := ctl.getProxyConf(m.ProxyName) if !ok { // it will never go to this branch now ctl.Warn("[%s] no proxy conf found", m.ProxyName) @@ -424,20 +424,36 @@ func (ctl *Control) controler() { maxDelayTime := 30 * time.Second delayTime := time.Second - checkInterval := 30 * time.Second + checkInterval := 10 * time.Second checkProxyTicker := time.NewTicker(checkInterval) for { select { case <-checkProxyTicker.C: - // Every 30 seconds, check which proxy registered failed and reregister it to server. + // Every 10 seconds, check which proxy registered failed and reregister it to server. + ctl.mu.RLock() for _, cfg := range ctl.pxyCfgs { - if _, exist := ctl.getProxy(cfg.GetName()); !exist { - ctl.Info("try to reregister proxy [%s]", cfg.GetName()) + if _, exist := ctl.proxies[cfg.GetName()]; !exist { + ctl.Info("try to register proxy [%s]", cfg.GetName()) var newProxyMsg msg.NewProxy cfg.UnMarshalToMsg(&newProxyMsg) ctl.sendCh <- &newProxyMsg } } + + for _, cfg := range ctl.vistorCfgs { + if _, exist := ctl.vistors[cfg.GetName()]; !exist { + ctl.Info("try to start vistor [%s]", cfg.GetName()) + vistor := NewVistor(ctl, cfg) + err = vistor.Run() + if err != nil { + vistor.Warn("start error: %v", err) + continue + } + ctl.vistors[cfg.GetName()] = vistor + vistor.Info("start vistor success") + } + } + ctl.mu.RUnlock() case _, ok := <-ctl.closedCh: // we won't get any variable from this channel if !ok { @@ -485,11 +501,13 @@ func (ctl *Control) controler() { go ctl.reader() // send NewProxy message for all configured proxies + ctl.mu.RLock() for _, cfg := range ctl.pxyCfgs { var newProxyMsg msg.NewProxy cfg.UnMarshalToMsg(&newProxyMsg) ctl.sendCh <- &newProxyMsg } + ctl.mu.RUnlock() checkProxyTicker.Stop() checkProxyTicker = time.NewTicker(checkInterval) @@ -522,3 +540,82 @@ func (ctl *Control) addProxy(name string, pxy Proxy) { defer ctl.mu.Unlock() ctl.proxies[name] = pxy } + +func (ctl *Control) getProxyConf(name string) (conf config.ProxyConf, ok bool) { + ctl.mu.RLock() + defer ctl.mu.RUnlock() + conf, ok = ctl.pxyCfgs[name] + return +} + +func (ctl *Control) reloadConf(pxyCfgs map[string]config.ProxyConf, vistorCfgs map[string]config.ProxyConf) { + ctl.mu.Lock() + defer ctl.mu.Unlock() + + removedPxyNames := make([]string, 0) + for name, oldCfg := range ctl.pxyCfgs { + del := false + cfg, ok := pxyCfgs[name] + if !ok { + del = true + } else { + if !oldCfg.Compare(cfg) { + del = true + } + } + + if del { + removedPxyNames = append(removedPxyNames, name) + delete(ctl.pxyCfgs, name) + if pxy, ok := ctl.proxies[name]; ok { + pxy.Close() + } + delete(ctl.proxies, name) + ctl.sendCh <- &msg.CloseProxy{ + ProxyName: name, + } + } + } + ctl.Info("proxy removed: %v", removedPxyNames) + + addedPxyNames := make([]string, 0) + for name, cfg := range pxyCfgs { + if _, ok := ctl.pxyCfgs[name]; !ok { + ctl.pxyCfgs[name] = cfg + addedPxyNames = append(addedPxyNames, name) + } + } + ctl.Info("proxy added: %v", addedPxyNames) + + removedVistorName := make([]string, 0) + for name, oldVistorCfg := range ctl.vistorCfgs { + del := false + cfg, ok := vistorCfgs[name] + if !ok { + del = true + } else { + if !oldVistorCfg.Compare(cfg) { + del = true + } + } + + if del { + removedVistorName = append(removedVistorName, name) + delete(ctl.vistorCfgs, name) + if vistor, ok := ctl.vistors[name]; ok { + vistor.Close() + } + delete(ctl.vistors, name) + } + } + ctl.Info("vistor removed: %v", removedVistorName) + + addedVistorName := make([]string, 0) + for name, vistorCfg := range vistorCfgs { + if _, ok := ctl.vistorCfgs[name]; !ok { + ctl.vistorCfgs[name] = vistorCfg + addedVistorName = append(addedVistorName, name) + } + } + ctl.Info("vistor added: %v", addedVistorName) +} diff --git a/client/service.go b/client/service.go index ff28cc91..241a435c 100644 --- a/client/service.go +++ b/client/service.go @@ -14,7 +14,10 @@ package client -import "github.com/fatedier/frp/models/config" +import ( + "github.com/fatedier/frp/models/config" + "github.com/fatedier/frp/utils/log" +) type Service struct { // manager control connection with server @@ -38,6 +41,14 @@ func (svr *Service) Run() error { return err } + if config.ClientCommonCfg.AdminPort != 0 { + err = svr.RunAdminServer(config.ClientCommonCfg.AdminAddr, config.ClientCommonCfg.AdminPort) + if err != nil { + log.Warn("run admin server error: %v", err) + } + log.Info("admin server listen on %s:%d", config.ClientCommonCfg.AdminAddr, config.ClientCommonCfg.AdminPort) + } + <-svr.closedCh return nil } diff --git a/cmd/frpc/main.go b/cmd/frpc/main.go index 88fc59b9..986592ff 100644 --- a/cmd/frpc/main.go +++ b/cmd/frpc/main.go @@ -54,7 +54,7 @@ Options: func main() { var err error - confFile := "./frpc.ini" + confFile := "./frps.ini" // the configures parsed from file will be replaced by those from command line if exist args, err := docopt.Parse(usage, nil, true, version.Full(), false) @@ -73,6 +73,7 @@ func main() { fmt.Println(err) os.Exit(1) } + config.ClientCommonCfg.ConfigFile = confFile if args["-L"] != nil { if args["-L"].(string) == "console" { diff --git a/conf/frpc_full.ini b/conf/frpc_full.ini index 4cc6dc17..3a73e43f 100644 --- a/conf/frpc_full.ini +++ b/conf/frpc_full.ini @@ -20,6 +20,12 @@ log_max_days = 3 # for authentication privilege_token = 12345678 +# set admin address for control frpc's action by http api such as reload +admin_addr = 127.0.0.1 +admin_port = 7400 +admin_user = admin +admin_pwd = admin + # connections will be established in advance, default value is zero pool_count = 5 diff --git a/conf/frps_full.ini b/conf/frps_full.ini index 446d86aa..3b4740ed 100644 --- a/conf/frps_full.ini +++ b/conf/frps_full.ini @@ -16,7 +16,7 @@ kcp_bind_port = 7000 vhost_http_port = 80 vhost_https_port = 443 -# if you want to configure or reload frps by dashboard, dashboard_port must be set +# set dashboard_port to view dashboard of frps dashboard_port = 7500 # dashboard user and pwd for basic auth protect, if not set, both default value is admin diff --git a/models/config/client_common.go b/models/config/client_common.go index 8ec6cd89..749b6b13 100644 --- a/models/config/client_common.go +++ b/models/config/client_common.go @@ -36,6 +36,10 @@ type ClientCommonConf struct { LogLevel string LogMaxDays int64 PrivilegeToken string + AdminAddr string + AdminPort int64 + AdminUser string + AdminPwd string PoolCount int TcpMux bool User string @@ -57,6 +61,10 @@ func GetDeaultClientCommonConf() *ClientCommonConf { LogLevel: "info", LogMaxDays: 3, PrivilegeToken: "", + AdminAddr: "127.0.0.1", + AdminPort: 0, + AdminUser: "", + AdminPwd: "", PoolCount: 1, TcpMux: true, User: "", @@ -111,7 +119,9 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) { tmpStr, ok = conf.Get("common", "log_max_days") if ok { - cfg.LogMaxDays, _ = strconv.ParseInt(tmpStr, 10, 64) + if v, err = strconv.ParseInt(tmpStr, 10, 64); err == nil { + cfg.LogMaxDays = v + } } tmpStr, ok = conf.Get("common", "privilege_token") @@ -119,6 +129,28 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) { cfg.PrivilegeToken = tmpStr } + tmpStr, ok = conf.Get("common", "admin_addr") + if ok { + cfg.AdminAddr = tmpStr + } + + tmpStr, ok = conf.Get("common", "admin_port") + if ok { + if v, err = strconv.ParseInt(tmpStr, 10, 64); err == nil { + cfg.AdminPort = v + } + } + + tmpStr, ok = conf.Get("common", "admin_user") + if ok { + cfg.AdminUser = tmpStr + } + + tmpStr, ok = conf.Get("common", "admin_pwd") + if ok { + cfg.AdminPwd = tmpStr + } + tmpStr, ok = conf.Get("common", "pool_count") if ok { v, err = strconv.ParseInt(tmpStr, 10, 64) @@ -145,7 +177,7 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) { if ok { proxyNames := strings.Split(tmpStr, ",") for _, name := range proxyNames { - cfg.Start[name] = struct{}{} + cfg.Start[strings.TrimSpace(name)] = struct{}{} } } diff --git a/models/config/proxy.go b/models/config/proxy.go index 90c982bf..b42f416c 100644 --- a/models/config/proxy.go +++ b/models/config/proxy.go @@ -56,6 +56,7 @@ type ProxyConf interface { LoadFromFile(name string, conf ini.Section) error UnMarshalToMsg(pMsg *msg.NewProxy) Check() error + Compare(conf ProxyConf) bool } func NewProxyConf(pMsg *msg.NewProxy) (cfg ProxyConf, err error) { @@ -105,6 +106,16 @@ func (cfg *BaseProxyConf) GetBaseInfo() *BaseProxyConf { return cfg } +func (cfg *BaseProxyConf) compare(cmp *BaseProxyConf) bool { + if cfg.ProxyName != cmp.ProxyName || + cfg.ProxyType != cmp.ProxyType || + cfg.UseEncryption != cmp.UseEncryption || + cfg.UseCompression != cmp.UseCompression { + return false + } + return true +} + func (cfg *BaseProxyConf) LoadFromMsg(pMsg *msg.NewProxy) { cfg.ProxyName = pMsg.ProxyName cfg.ProxyType = pMsg.ProxyType @@ -149,8 +160,16 @@ type BindInfoConf struct { RemotePort int64 `json:"remote_port"` } +func (cfg *BindInfoConf) compare(cmp *BindInfoConf) bool { + if cfg.BindAddr != cmp.BindAddr || + cfg.RemotePort != cmp.RemotePort { + return false + } + return true +} + func (cfg *BindInfoConf) LoadFromMsg(pMsg *msg.NewProxy) { - cfg.BindAddr = ServerCommonCfg.BindAddr + cfg.BindAddr = ServerCommonCfg.ProxyBindAddr cfg.RemotePort = pMsg.RemotePort } @@ -188,6 +207,14 @@ type DomainConf struct { SubDomain string `json:"sub_domain"` } +func (cfg *DomainConf) compare(cmp *DomainConf) bool { + if strings.Join(cfg.CustomDomains, " ") != strings.Join(cmp.CustomDomains, " ") || + cfg.SubDomain != cmp.SubDomain { + return false + } + return true +} + func (cfg *DomainConf) LoadFromMsg(pMsg *msg.NewProxy) { cfg.CustomDomains = pMsg.CustomDomains cfg.SubDomain = pMsg.SubDomain @@ -246,6 +273,14 @@ type LocalSvrConf struct { LocalPort int `json:"-"` } +func (cfg *LocalSvrConf) compare(cmp *LocalSvrConf) bool { + if cfg.LocalIp != cmp.LocalIp || + cfg.LocalPort != cmp.LocalPort { + return false + } + return true +} + func (cfg *LocalSvrConf) LoadFromFile(name string, section ini.Section) (err error) { if cfg.LocalIp = section["local_ip"]; cfg.LocalIp == "" { cfg.LocalIp = "127.0.0.1" @@ -266,6 +301,20 @@ type PluginConf struct { PluginParams map[string]string `json:"-"` } +func (cfg *PluginConf) compare(cmp *PluginConf) bool { + if cfg.Plugin != cmp.Plugin || + len(cfg.PluginParams) != len(cmp.PluginParams) { + return false + } + for k, v := range cfg.PluginParams { + value, ok := cmp.PluginParams[k] + if !ok || v != value { + return false + } + } + return true +} + func (cfg *PluginConf) LoadFromFile(name string, section ini.Section) (err error) { cfg.Plugin = section["plugin"] cfg.PluginParams = make(map[string]string) @@ -291,6 +340,21 @@ type TcpProxyConf struct { PluginConf } +func (cfg *TcpProxyConf) Compare(cmp ProxyConf) bool { + cmpConf, ok := cmp.(*TcpProxyConf) + if !ok { + return false + } + + if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) || + !cfg.BindInfoConf.compare(&cmpConf.BindInfoConf) || + !cfg.LocalSvrConf.compare(&cmpConf.LocalSvrConf) || + !cfg.PluginConf.compare(&cmpConf.PluginConf) { + return false + } + return true +} + func (cfg *TcpProxyConf) LoadFromMsg(pMsg *msg.NewProxy) { cfg.BaseProxyConf.LoadFromMsg(pMsg) cfg.BindInfoConf.LoadFromMsg(pMsg) @@ -330,6 +394,20 @@ type UdpProxyConf struct { LocalSvrConf } +func (cfg *UdpProxyConf) Compare(cmp ProxyConf) bool { + cmpConf, ok := cmp.(*UdpProxyConf) + if !ok { + return false + } + + if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) || + !cfg.BindInfoConf.compare(&cmpConf.BindInfoConf) || + !cfg.LocalSvrConf.compare(&cmpConf.LocalSvrConf) { + return false + } + return true +} + func (cfg *UdpProxyConf) LoadFromMsg(pMsg *msg.NewProxy) { cfg.BaseProxyConf.LoadFromMsg(pMsg) cfg.BindInfoConf.LoadFromMsg(pMsg) @@ -372,6 +450,25 @@ type HttpProxyConf struct { HttpPwd string `json:"-"` } +func (cfg *HttpProxyConf) Compare(cmp ProxyConf) bool { + cmpConf, ok := cmp.(*HttpProxyConf) + if !ok { + return false + } + + if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) || + !cfg.DomainConf.compare(&cmpConf.DomainConf) || + !cfg.LocalSvrConf.compare(&cmpConf.LocalSvrConf) || + !cfg.PluginConf.compare(&cmpConf.PluginConf) || + strings.Join(cfg.Locations, " ") != strings.Join(cmpConf.Locations, " ") || + cfg.HostHeaderRewrite != cmpConf.HostHeaderRewrite || + cfg.HttpUser != cmpConf.HttpUser || + cfg.HttpPwd != cmpConf.HttpPwd { + return false + } + return true +} + func (cfg *HttpProxyConf) LoadFromMsg(pMsg *msg.NewProxy) { cfg.BaseProxyConf.LoadFromMsg(pMsg) cfg.DomainConf.LoadFromMsg(pMsg) @@ -438,6 +535,21 @@ type HttpsProxyConf struct { PluginConf } +func (cfg *HttpsProxyConf) Compare(cmp ProxyConf) bool { + cmpConf, ok := cmp.(*HttpsProxyConf) + if !ok { + return false + } + + if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) || + !cfg.DomainConf.compare(&cmpConf.DomainConf) || + !cfg.LocalSvrConf.compare(&cmpConf.LocalSvrConf) || + !cfg.PluginConf.compare(&cmpConf.PluginConf) { + return false + } + return true +} + func (cfg *HttpsProxyConf) LoadFromMsg(pMsg *msg.NewProxy) { cfg.BaseProxyConf.LoadFromMsg(pMsg) cfg.DomainConf.LoadFromMsg(pMsg) @@ -488,6 +600,25 @@ type StcpProxyConf struct { BindPort int `json:"bind_port"` } +func (cfg *StcpProxyConf) Compare(cmp ProxyConf) bool { + cmpConf, ok := cmp.(*StcpProxyConf) + if !ok { + return false + } + + if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) || + !cfg.LocalSvrConf.compare(&cmpConf.LocalSvrConf) || + !cfg.PluginConf.compare(&cmpConf.PluginConf) || + cfg.Role != cmpConf.Role || + cfg.Sk != cmpConf.Sk || + cfg.ServerName != cmpConf.ServerName || + cfg.BindAddr != cmpConf.BindAddr || + cfg.BindPort != cmpConf.BindPort { + return false + } + return true +} + // Only for role server. func (cfg *StcpProxyConf) LoadFromMsg(pMsg *msg.NewProxy) { cfg.BaseProxyConf.LoadFromMsg(pMsg) diff --git a/server/dashboard.go b/server/dashboard.go index 84eac81d..01f71591 100644 --- a/server/dashboard.go +++ b/server/dashboard.go @@ -15,16 +15,14 @@ package server import ( - "compress/gzip" "fmt" - "io" "net" "net/http" - "strings" "time" "github.com/fatedier/frp/assets" "github.com/fatedier/frp/models/config" + frpNet "github.com/fatedier/frp/utils/net" "github.com/julienschmidt/httprouter" ) @@ -38,20 +36,24 @@ func RunDashboardServer(addr string, port int64) (err error) { // url router router := httprouter.New() + user, passwd := config.ServerCommonCfg.DashboardUser, config.ServerCommonCfg.DashboardPwd + // api, see dashboard_api.go - router.GET("/api/serverinfo", httprouterBasicAuth(apiServerInfo)) - router.GET("/api/proxy/tcp", httprouterBasicAuth(apiProxyTcp)) - router.GET("/api/proxy/udp", httprouterBasicAuth(apiProxyUdp)) - router.GET("/api/proxy/http", httprouterBasicAuth(apiProxyHttp)) - router.GET("/api/proxy/https", httprouterBasicAuth(apiProxyHttps)) - router.GET("/api/proxy/traffic/:name", httprouterBasicAuth(apiProxyTraffic)) + router.GET("/api/serverinfo", frpNet.HttprouterBasicAuth(apiServerInfo, user, passwd)) + router.GET("/api/proxy/tcp", frpNet.HttprouterBasicAuth(apiProxyTcp, user, passwd)) + router.GET("/api/proxy/udp", frpNet.HttprouterBasicAuth(apiProxyUdp, user, passwd)) + router.GET("/api/proxy/http", frpNet.HttprouterBasicAuth(apiProxyHttp, user, passwd)) + router.GET("/api/proxy/https", frpNet.HttprouterBasicAuth(apiProxyHttps, user, passwd)) + router.GET("/api/proxy/traffic/:name", frpNet.HttprouterBasicAuth(apiProxyTraffic, user, passwd)) // view router.Handler("GET", "/favicon.ico", http.FileServer(assets.FileSystem)) - router.Handler("GET", "/static/*filepath", MakeGzipHandler(basicAuthWraper(http.StripPrefix("/static/", http.FileServer(assets.FileSystem))))) - router.HandlerFunc("GET", "/", basicAuth(func(w http.ResponseWriter, r *http.Request) { + router.Handler("GET", "/static/*filepath", frpNet.MakeHttpGzipHandler( + frpNet.NewHttpBasicAuthWraper(http.StripPrefix("/static/", http.FileServer(assets.FileSystem)), user, passwd))) + + router.HandlerFunc("GET", "/", frpNet.HttpBasicAuth(func(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/static/", http.StatusMovedPermanently) - })) + }, user, passwd)) address := fmt.Sprintf("%s:%d", addr, port) server := &http.Server{ @@ -71,91 +73,3 @@ func RunDashboardServer(addr string, port int64) (err error) { go server.Serve(ln) return } - -func use(h http.HandlerFunc, middleware ...func(http.HandlerFunc) http.HandlerFunc) http.HandlerFunc { - for _, m := range middleware { - h = m(h) - } - return h -} - -type AuthWraper struct { - h http.Handler - user string - passwd string -} - -func (aw *AuthWraper) ServeHTTP(w http.ResponseWriter, r *http.Request) { - user, passwd, hasAuth := r.BasicAuth() - if (aw.user == "" && aw.passwd == "") || (hasAuth && user == aw.user && passwd == aw.passwd) { - aw.h.ServeHTTP(w, r) - } else { - w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) - http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) - } -} - -func basicAuthWraper(h http.Handler) http.Handler { - return &AuthWraper{ - h: h, - user: config.ServerCommonCfg.DashboardUser, - passwd: config.ServerCommonCfg.DashboardPwd, - } -} - -func basicAuth(h http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - user, passwd, hasAuth := r.BasicAuth() - if (config.ServerCommonCfg.DashboardUser == "" && config.ServerCommonCfg.DashboardPwd == "") || - (hasAuth && user == config.ServerCommonCfg.DashboardUser && passwd == config.ServerCommonCfg.DashboardPwd) { - h.ServeHTTP(w, r) - } else { - w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) - http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) - } - } -} - -func httprouterBasicAuth(h httprouter.Handle) httprouter.Handle { - return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - user, passwd, hasAuth := r.BasicAuth() - if (config.ServerCommonCfg.DashboardUser == "" && config.ServerCommonCfg.DashboardPwd == "") || - (hasAuth && user == config.ServerCommonCfg.DashboardUser && passwd == config.ServerCommonCfg.DashboardPwd) { - h(w, r, ps) - } else { - w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) - http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) - } - } -} - -type GzipWraper struct { - h http.Handler -} - -func (gw *GzipWraper) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { - gw.h.ServeHTTP(w, r) - return - } - w.Header().Set("Content-Encoding", "gzip") - gz := gzip.NewWriter(w) - defer gz.Close() - gzr := gzipResponseWriter{Writer: gz, ResponseWriter: w} - gw.h.ServeHTTP(gzr, r) -} - -func MakeGzipHandler(h http.Handler) http.Handler { - return &GzipWraper{ - h: h, - } -} - -type gzipResponseWriter struct { - io.Writer - http.ResponseWriter -} - -func (w gzipResponseWriter) Write(b []byte) (int, error) { - return w.Writer.Write(b) -} diff --git a/utils/net/http.go b/utils/net/http.go new file mode 100644 index 00000000..acc0f43e --- /dev/null +++ b/utils/net/http.go @@ -0,0 +1,105 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package net + +import ( + "compress/gzip" + "io" + "net/http" + "strings" + + "github.com/julienschmidt/httprouter" +) + +type HttpAuthWraper struct { + h http.Handler + user string + passwd string +} + +func NewHttpBasicAuthWraper(h http.Handler, user, passwd string) http.Handler { + return &HttpAuthWraper{ + h: h, + user: user, + passwd: passwd, + } +} + +func (aw *HttpAuthWraper) ServeHTTP(w http.ResponseWriter, r *http.Request) { + user, passwd, hasAuth := r.BasicAuth() + if (aw.user == "" && aw.passwd == "") || (hasAuth && user == aw.user && passwd == aw.passwd) { + aw.h.ServeHTTP(w, r) + } else { + w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + } +} + +func HttpBasicAuth(h http.HandlerFunc, user, passwd string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + reqUser, reqPasswd, hasAuth := r.BasicAuth() + if (user == "" && passwd == "") || + (hasAuth && reqUser == user && reqPasswd == passwd) { + h.ServeHTTP(w, r) + } else { + w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + } + } +} + +func HttprouterBasicAuth(h httprouter.Handle, user, passwd string) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + reqUser, reqPasswd, hasAuth := r.BasicAuth() + if (user == "" && passwd == "") || + (hasAuth && reqUser == user && reqPasswd == passwd) { + h(w, r, ps) + } else { + w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + } + } +} + +type HttpGzipWraper struct { + h http.Handler +} + +func (gw *HttpGzipWraper) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { + gw.h.ServeHTTP(w, r) + return + } + w.Header().Set("Content-Encoding", "gzip") + gz := gzip.NewWriter(w) + defer gz.Close() + gzr := gzipResponseWriter{Writer: gz, ResponseWriter: w} + gw.h.ServeHTTP(gzr, r) +} + +func MakeHttpGzipHandler(h http.Handler) http.Handler { + return &HttpGzipWraper{ + h: h, + } +} + +type gzipResponseWriter struct { + io.Writer + http.ResponseWriter +} + +func (w gzipResponseWriter) Write(b []byte) (int, error) { + return w.Writer.Write(b) +}