mirror of
https://github.com/fatedier/frp.git
synced 2026-01-11 22:23:12 +00:00
rename models to pkg (#2005)
This commit is contained in:
151
pkg/auth/auth.go
Normal file
151
pkg/auth/auth.go
Normal file
@@ -0,0 +1,151 @@
|
||||
// Copyright 2020 guylewin, guy@lewin.co.il
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/fatedier/frp/pkg/consts"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
|
||||
"github.com/vaughan0/go-ini"
|
||||
)
|
||||
|
||||
type baseConfig struct {
|
||||
// AuthenticationMethod specifies what authentication method to use to
|
||||
// authenticate frpc with frps. If "token" is specified - token will be
|
||||
// read into login message. If "oidc" is specified - OIDC (Open ID Connect)
|
||||
// token will be issued using OIDC settings. By default, this value is "token".
|
||||
AuthenticationMethod string `json:"authentication_method"`
|
||||
// AuthenticateHeartBeats specifies whether to include authentication token in
|
||||
// heartbeats sent to frps. By default, this value is false.
|
||||
AuthenticateHeartBeats bool `json:"authenticate_heartbeats"`
|
||||
// AuthenticateNewWorkConns specifies whether to include authentication token in
|
||||
// new work connections sent to frps. By default, this value is false.
|
||||
AuthenticateNewWorkConns bool `json:"authenticate_new_work_conns"`
|
||||
}
|
||||
|
||||
func getDefaultBaseConf() baseConfig {
|
||||
return baseConfig{
|
||||
AuthenticationMethod: "token",
|
||||
AuthenticateHeartBeats: false,
|
||||
AuthenticateNewWorkConns: false,
|
||||
}
|
||||
}
|
||||
|
||||
func unmarshalBaseConfFromIni(conf ini.File) baseConfig {
|
||||
var (
|
||||
tmpStr string
|
||||
ok bool
|
||||
)
|
||||
|
||||
cfg := getDefaultBaseConf()
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "authentication_method"); ok {
|
||||
cfg.AuthenticationMethod = tmpStr
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "authenticate_heartbeats"); ok && tmpStr == "true" {
|
||||
cfg.AuthenticateHeartBeats = true
|
||||
} else {
|
||||
cfg.AuthenticateHeartBeats = false
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "authenticate_new_work_conns"); ok && tmpStr == "true" {
|
||||
cfg.AuthenticateNewWorkConns = true
|
||||
} else {
|
||||
cfg.AuthenticateNewWorkConns = false
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
type ClientConfig struct {
|
||||
baseConfig
|
||||
oidcClientConfig
|
||||
tokenConfig
|
||||
}
|
||||
|
||||
func GetDefaultClientConf() ClientConfig {
|
||||
return ClientConfig{
|
||||
baseConfig: getDefaultBaseConf(),
|
||||
oidcClientConfig: getDefaultOidcClientConf(),
|
||||
tokenConfig: getDefaultTokenConf(),
|
||||
}
|
||||
}
|
||||
|
||||
func UnmarshalClientConfFromIni(conf ini.File) (cfg ClientConfig) {
|
||||
cfg.baseConfig = unmarshalBaseConfFromIni(conf)
|
||||
cfg.oidcClientConfig = unmarshalOidcClientConfFromIni(conf)
|
||||
cfg.tokenConfig = unmarshalTokenConfFromIni(conf)
|
||||
return cfg
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
baseConfig
|
||||
oidcServerConfig
|
||||
tokenConfig
|
||||
}
|
||||
|
||||
func GetDefaultServerConf() ServerConfig {
|
||||
return ServerConfig{
|
||||
baseConfig: getDefaultBaseConf(),
|
||||
oidcServerConfig: getDefaultOidcServerConf(),
|
||||
tokenConfig: getDefaultTokenConf(),
|
||||
}
|
||||
}
|
||||
|
||||
func UnmarshalServerConfFromIni(conf ini.File) (cfg ServerConfig) {
|
||||
cfg.baseConfig = unmarshalBaseConfFromIni(conf)
|
||||
cfg.oidcServerConfig = unmarshalOidcServerConfFromIni(conf)
|
||||
cfg.tokenConfig = unmarshalTokenConfFromIni(conf)
|
||||
return cfg
|
||||
}
|
||||
|
||||
type Setter interface {
|
||||
SetLogin(*msg.Login) error
|
||||
SetPing(*msg.Ping) error
|
||||
SetNewWorkConn(*msg.NewWorkConn) error
|
||||
}
|
||||
|
||||
func NewAuthSetter(cfg ClientConfig) (authProvider Setter) {
|
||||
switch cfg.AuthenticationMethod {
|
||||
case consts.TokenAuthMethod:
|
||||
authProvider = NewTokenAuth(cfg.baseConfig, cfg.tokenConfig)
|
||||
case consts.OidcAuthMethod:
|
||||
authProvider = NewOidcAuthSetter(cfg.baseConfig, cfg.oidcClientConfig)
|
||||
default:
|
||||
panic(fmt.Sprintf("wrong authentication method: '%s'", cfg.AuthenticationMethod))
|
||||
}
|
||||
|
||||
return authProvider
|
||||
}
|
||||
|
||||
type Verifier interface {
|
||||
VerifyLogin(*msg.Login) error
|
||||
VerifyPing(*msg.Ping) error
|
||||
VerifyNewWorkConn(*msg.NewWorkConn) error
|
||||
}
|
||||
|
||||
func NewAuthVerifier(cfg ServerConfig) (authVerifier Verifier) {
|
||||
switch cfg.AuthenticationMethod {
|
||||
case consts.TokenAuthMethod:
|
||||
authVerifier = NewTokenAuth(cfg.baseConfig, cfg.tokenConfig)
|
||||
case consts.OidcAuthMethod:
|
||||
authVerifier = NewOidcAuthVerifier(cfg.baseConfig, cfg.oidcServerConfig)
|
||||
}
|
||||
|
||||
return authVerifier
|
||||
}
|
||||
255
pkg/auth/oidc.go
Normal file
255
pkg/auth/oidc.go
Normal file
@@ -0,0 +1,255 @@
|
||||
// Copyright 2020 guylewin, guy@lewin.co.il
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/vaughan0/go-ini"
|
||||
"golang.org/x/oauth2/clientcredentials"
|
||||
)
|
||||
|
||||
type oidcClientConfig struct {
|
||||
// OidcClientID specifies the client ID to use to get a token in OIDC
|
||||
// authentication if AuthenticationMethod == "oidc". By default, this value
|
||||
// is "".
|
||||
OidcClientID string `json:"oidc_client_id"`
|
||||
// OidcClientSecret specifies the client secret to use to get a token in OIDC
|
||||
// authentication if AuthenticationMethod == "oidc". By default, this value
|
||||
// is "".
|
||||
OidcClientSecret string `json:"oidc_client_secret"`
|
||||
// OidcAudience specifies the audience of the token in OIDC authentication
|
||||
//if AuthenticationMethod == "oidc". By default, this value is "".
|
||||
OidcAudience string `json:"oidc_audience"`
|
||||
// OidcTokenEndpointURL specifies the URL which implements OIDC Token Endpoint.
|
||||
// It will be used to get an OIDC token if AuthenticationMethod == "oidc".
|
||||
// By default, this value is "".
|
||||
OidcTokenEndpointURL string `json:"oidc_token_endpoint_url"`
|
||||
}
|
||||
|
||||
func getDefaultOidcClientConf() oidcClientConfig {
|
||||
return oidcClientConfig{
|
||||
OidcClientID: "",
|
||||
OidcClientSecret: "",
|
||||
OidcAudience: "",
|
||||
OidcTokenEndpointURL: "",
|
||||
}
|
||||
}
|
||||
|
||||
func unmarshalOidcClientConfFromIni(conf ini.File) oidcClientConfig {
|
||||
var (
|
||||
tmpStr string
|
||||
ok bool
|
||||
)
|
||||
|
||||
cfg := getDefaultOidcClientConf()
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "oidc_client_id"); ok {
|
||||
cfg.OidcClientID = tmpStr
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "oidc_client_secret"); ok {
|
||||
cfg.OidcClientSecret = tmpStr
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "oidc_audience"); ok {
|
||||
cfg.OidcAudience = tmpStr
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "oidc_token_endpoint_url"); ok {
|
||||
cfg.OidcTokenEndpointURL = tmpStr
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
type oidcServerConfig struct {
|
||||
// OidcIssuer specifies the issuer to verify OIDC tokens with. This issuer
|
||||
// will be used to load public keys to verify signature and will be compared
|
||||
// with the issuer claim in the OIDC token. It will be used if
|
||||
// AuthenticationMethod == "oidc". By default, this value is "".
|
||||
OidcIssuer string `json:"oidc_issuer"`
|
||||
// OidcAudience specifies the audience OIDC tokens should contain when validated.
|
||||
// If this value is empty, audience ("client ID") verification will be skipped.
|
||||
// It will be used when AuthenticationMethod == "oidc". By default, this
|
||||
// value is "".
|
||||
OidcAudience string `json:"oidc_audience"`
|
||||
// OidcSkipExpiryCheck specifies whether to skip checking if the OIDC token is
|
||||
// expired. It will be used when AuthenticationMethod == "oidc". By default, this
|
||||
// value is false.
|
||||
OidcSkipExpiryCheck bool `json:"oidc_skip_expiry_check"`
|
||||
// OidcSkipIssuerCheck specifies whether to skip checking if the OIDC token's
|
||||
// issuer claim matches the issuer specified in OidcIssuer. It will be used when
|
||||
// AuthenticationMethod == "oidc". By default, this value is false.
|
||||
OidcSkipIssuerCheck bool `json:"oidc_skip_issuer_check"`
|
||||
}
|
||||
|
||||
func getDefaultOidcServerConf() oidcServerConfig {
|
||||
return oidcServerConfig{
|
||||
OidcIssuer: "",
|
||||
OidcAudience: "",
|
||||
OidcSkipExpiryCheck: false,
|
||||
OidcSkipIssuerCheck: false,
|
||||
}
|
||||
}
|
||||
|
||||
func unmarshalOidcServerConfFromIni(conf ini.File) oidcServerConfig {
|
||||
var (
|
||||
tmpStr string
|
||||
ok bool
|
||||
)
|
||||
|
||||
cfg := getDefaultOidcServerConf()
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "oidc_issuer"); ok {
|
||||
cfg.OidcIssuer = tmpStr
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "oidc_audience"); ok {
|
||||
cfg.OidcAudience = tmpStr
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "oidc_skip_expiry_check"); ok && tmpStr == "true" {
|
||||
cfg.OidcSkipExpiryCheck = true
|
||||
} else {
|
||||
cfg.OidcSkipExpiryCheck = false
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "oidc_skip_issuer_check"); ok && tmpStr == "true" {
|
||||
cfg.OidcSkipIssuerCheck = true
|
||||
} else {
|
||||
cfg.OidcSkipIssuerCheck = false
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
type OidcAuthProvider struct {
|
||||
baseConfig
|
||||
|
||||
tokenGenerator *clientcredentials.Config
|
||||
}
|
||||
|
||||
func NewOidcAuthSetter(baseCfg baseConfig, cfg oidcClientConfig) *OidcAuthProvider {
|
||||
tokenGenerator := &clientcredentials.Config{
|
||||
ClientID: cfg.OidcClientID,
|
||||
ClientSecret: cfg.OidcClientSecret,
|
||||
Scopes: []string{cfg.OidcAudience},
|
||||
TokenURL: cfg.OidcTokenEndpointURL,
|
||||
}
|
||||
|
||||
return &OidcAuthProvider{
|
||||
baseConfig: baseCfg,
|
||||
tokenGenerator: tokenGenerator,
|
||||
}
|
||||
}
|
||||
|
||||
func (auth *OidcAuthProvider) generateAccessToken() (accessToken string, err error) {
|
||||
tokenObj, err := auth.tokenGenerator.Token(context.Background())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("couldn't generate OIDC token for login: %v", err)
|
||||
}
|
||||
return tokenObj.AccessToken, nil
|
||||
}
|
||||
|
||||
func (auth *OidcAuthProvider) SetLogin(loginMsg *msg.Login) (err error) {
|
||||
loginMsg.PrivilegeKey, err = auth.generateAccessToken()
|
||||
return err
|
||||
}
|
||||
|
||||
func (auth *OidcAuthProvider) SetPing(pingMsg *msg.Ping) (err error) {
|
||||
if !auth.AuthenticateHeartBeats {
|
||||
return nil
|
||||
}
|
||||
|
||||
pingMsg.PrivilegeKey, err = auth.generateAccessToken()
|
||||
return err
|
||||
}
|
||||
|
||||
func (auth *OidcAuthProvider) SetNewWorkConn(newWorkConnMsg *msg.NewWorkConn) (err error) {
|
||||
if !auth.AuthenticateNewWorkConns {
|
||||
return nil
|
||||
}
|
||||
|
||||
newWorkConnMsg.PrivilegeKey, err = auth.generateAccessToken()
|
||||
return err
|
||||
}
|
||||
|
||||
type OidcAuthConsumer struct {
|
||||
baseConfig
|
||||
|
||||
verifier *oidc.IDTokenVerifier
|
||||
subjectFromLogin string
|
||||
}
|
||||
|
||||
func NewOidcAuthVerifier(baseCfg baseConfig, cfg oidcServerConfig) *OidcAuthConsumer {
|
||||
provider, err := oidc.NewProvider(context.Background(), cfg.OidcIssuer)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
verifierConf := oidc.Config{
|
||||
ClientID: cfg.OidcAudience,
|
||||
SkipClientIDCheck: cfg.OidcAudience == "",
|
||||
SkipExpiryCheck: cfg.OidcSkipExpiryCheck,
|
||||
SkipIssuerCheck: cfg.OidcSkipIssuerCheck,
|
||||
}
|
||||
return &OidcAuthConsumer{
|
||||
baseConfig: baseCfg,
|
||||
verifier: provider.Verifier(&verifierConf),
|
||||
}
|
||||
}
|
||||
|
||||
func (auth *OidcAuthConsumer) VerifyLogin(loginMsg *msg.Login) (err error) {
|
||||
token, err := auth.verifier.Verify(context.Background(), loginMsg.PrivilegeKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid OIDC token in login: %v", err)
|
||||
}
|
||||
auth.subjectFromLogin = token.Subject
|
||||
return nil
|
||||
}
|
||||
|
||||
func (auth *OidcAuthConsumer) verifyPostLoginToken(privilegeKey string) (err error) {
|
||||
token, err := auth.verifier.Verify(context.Background(), privilegeKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid OIDC token in ping: %v", err)
|
||||
}
|
||||
if token.Subject != auth.subjectFromLogin {
|
||||
return fmt.Errorf("received different OIDC subject in login and ping. "+
|
||||
"original subject: %s, "+
|
||||
"new subject: %s",
|
||||
auth.subjectFromLogin, token.Subject)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (auth *OidcAuthConsumer) VerifyPing(pingMsg *msg.Ping) (err error) {
|
||||
if !auth.AuthenticateHeartBeats {
|
||||
return nil
|
||||
}
|
||||
|
||||
return auth.verifyPostLoginToken(pingMsg.PrivilegeKey)
|
||||
}
|
||||
|
||||
func (auth *OidcAuthConsumer) VerifyNewWorkConn(newWorkConnMsg *msg.NewWorkConn) (err error) {
|
||||
if !auth.AuthenticateNewWorkConns {
|
||||
return nil
|
||||
}
|
||||
|
||||
return auth.verifyPostLoginToken(newWorkConnMsg.PrivilegeKey)
|
||||
}
|
||||
120
pkg/auth/token.go
Normal file
120
pkg/auth/token.go
Normal file
@@ -0,0 +1,120 @@
|
||||
// Copyright 2020 guylewin, guy@lewin.co.il
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/util/util"
|
||||
|
||||
"github.com/vaughan0/go-ini"
|
||||
)
|
||||
|
||||
type tokenConfig struct {
|
||||
// Token specifies the authorization token used to create keys to be sent
|
||||
// to the server. The server must have a matching token for authorization
|
||||
// to succeed. By default, this value is "".
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
func getDefaultTokenConf() tokenConfig {
|
||||
return tokenConfig{
|
||||
Token: "",
|
||||
}
|
||||
}
|
||||
|
||||
func unmarshalTokenConfFromIni(conf ini.File) tokenConfig {
|
||||
var (
|
||||
tmpStr string
|
||||
ok bool
|
||||
)
|
||||
|
||||
cfg := getDefaultTokenConf()
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "token"); ok {
|
||||
cfg.Token = tmpStr
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
type TokenAuthSetterVerifier struct {
|
||||
baseConfig
|
||||
|
||||
token string
|
||||
}
|
||||
|
||||
func NewTokenAuth(baseCfg baseConfig, cfg tokenConfig) *TokenAuthSetterVerifier {
|
||||
return &TokenAuthSetterVerifier{
|
||||
baseConfig: baseCfg,
|
||||
token: cfg.Token,
|
||||
}
|
||||
}
|
||||
|
||||
func (auth *TokenAuthSetterVerifier) SetLogin(loginMsg *msg.Login) (err error) {
|
||||
loginMsg.PrivilegeKey = util.GetAuthKey(auth.token, loginMsg.Timestamp)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (auth *TokenAuthSetterVerifier) SetPing(pingMsg *msg.Ping) error {
|
||||
if !auth.AuthenticateHeartBeats {
|
||||
return nil
|
||||
}
|
||||
|
||||
pingMsg.Timestamp = time.Now().Unix()
|
||||
pingMsg.PrivilegeKey = util.GetAuthKey(auth.token, pingMsg.Timestamp)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (auth *TokenAuthSetterVerifier) SetNewWorkConn(newWorkConnMsg *msg.NewWorkConn) error {
|
||||
if !auth.AuthenticateNewWorkConns {
|
||||
return nil
|
||||
}
|
||||
|
||||
newWorkConnMsg.Timestamp = time.Now().Unix()
|
||||
newWorkConnMsg.PrivilegeKey = util.GetAuthKey(auth.token, newWorkConnMsg.Timestamp)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (auth *TokenAuthSetterVerifier) VerifyLogin(loginMsg *msg.Login) error {
|
||||
if util.GetAuthKey(auth.token, loginMsg.Timestamp) != loginMsg.PrivilegeKey {
|
||||
return fmt.Errorf("token in login doesn't match token from configuration")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (auth *TokenAuthSetterVerifier) VerifyPing(pingMsg *msg.Ping) error {
|
||||
if !auth.AuthenticateHeartBeats {
|
||||
return nil
|
||||
}
|
||||
|
||||
if util.GetAuthKey(auth.token, pingMsg.Timestamp) != pingMsg.PrivilegeKey {
|
||||
return fmt.Errorf("token in heartbeat doesn't match token from configuration")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (auth *TokenAuthSetterVerifier) VerifyNewWorkConn(newWorkConnMsg *msg.NewWorkConn) error {
|
||||
if !auth.AuthenticateNewWorkConns {
|
||||
return nil
|
||||
}
|
||||
|
||||
if util.GetAuthKey(auth.token, newWorkConnMsg.Timestamp) != newWorkConnMsg.PrivilegeKey {
|
||||
return fmt.Errorf("token in NewWorkConn doesn't match token from configuration")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
365
pkg/config/client_common.go
Normal file
365
pkg/config/client_common.go
Normal file
@@ -0,0 +1,365 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/fatedier/frp/pkg/auth"
|
||||
|
||||
ini "github.com/vaughan0/go-ini"
|
||||
)
|
||||
|
||||
// ClientCommonConf contains information for a client service. It is
|
||||
// recommended to use GetDefaultClientConf instead of creating this object
|
||||
// directly, so that all unspecified fields have reasonable default values.
|
||||
type ClientCommonConf struct {
|
||||
auth.ClientConfig
|
||||
// ServerAddr specifies the address of the server to connect to. By
|
||||
// default, this value is "0.0.0.0".
|
||||
ServerAddr string `json:"server_addr"`
|
||||
// ServerPort specifies the port to connect to the server on. By default,
|
||||
// this value is 7000.
|
||||
ServerPort int `json:"server_port"`
|
||||
// HTTPProxy specifies a proxy address to connect to the server through. If
|
||||
// this value is "", the server will be connected to directly. By default,
|
||||
// this value is read from the "http_proxy" environment variable.
|
||||
HTTPProxy string `json:"http_proxy"`
|
||||
// LogFile specifies a file where logs will be written to. This value will
|
||||
// only be used if LogWay is set appropriately. By default, this value is
|
||||
// "console".
|
||||
LogFile string `json:"log_file"`
|
||||
// LogWay specifies the way logging is managed. Valid values are "console"
|
||||
// or "file". If "console" is used, logs will be printed to stdout. If
|
||||
// "file" is used, logs will be printed to LogFile. By default, this value
|
||||
// is "console".
|
||||
LogWay string `json:"log_way"`
|
||||
// LogLevel specifies the minimum log level. Valid values are "trace",
|
||||
// "debug", "info", "warn", and "error". By default, this value is "info".
|
||||
LogLevel string `json:"log_level"`
|
||||
// LogMaxDays specifies the maximum number of days to store log information
|
||||
// before deletion. This is only used if LogWay == "file". By default, this
|
||||
// value is 0.
|
||||
LogMaxDays int64 `json:"log_max_days"`
|
||||
// DisableLogColor disables log colors when LogWay == "console" when set to
|
||||
// true. By default, this value is false.
|
||||
DisableLogColor bool `json:"disable_log_color"`
|
||||
// AdminAddr specifies the address that the admin server binds to. By
|
||||
// default, this value is "127.0.0.1".
|
||||
AdminAddr string `json:"admin_addr"`
|
||||
// AdminPort specifies the port for the admin server to listen on. If this
|
||||
// value is 0, the admin server will not be started. By default, this value
|
||||
// is 0.
|
||||
AdminPort int `json:"admin_port"`
|
||||
// AdminUser specifies the username that the admin server will use for
|
||||
// login. By default, this value is "admin".
|
||||
AdminUser string `json:"admin_user"`
|
||||
// AdminPwd specifies the password that the admin server will use for
|
||||
// login. By default, this value is "admin".
|
||||
AdminPwd string `json:"admin_pwd"`
|
||||
// AssetsDir specifies the local directory that the admin server will load
|
||||
// resources from. If this value is "", assets will be loaded from the
|
||||
// bundled executable using statik. By default, this value is "".
|
||||
AssetsDir string `json:"assets_dir"`
|
||||
// PoolCount specifies the number of connections the client will make to
|
||||
// the server in advance. By default, this value is 0.
|
||||
PoolCount int `json:"pool_count"`
|
||||
// TCPMux toggles TCP stream multiplexing. This allows multiple requests
|
||||
// from a client to share a single TCP connection. If this value is true,
|
||||
// the server must have TCP multiplexing enabled as well. By default, this
|
||||
// value is true.
|
||||
TCPMux bool `json:"tcp_mux"`
|
||||
// User specifies a prefix for proxy names to distinguish them from other
|
||||
// clients. If this value is not "", proxy names will automatically be
|
||||
// changed to "{user}.{proxy_name}". By default, this value is "".
|
||||
User string `json:"user"`
|
||||
// DNSServer specifies a DNS server address for FRPC to use. If this value
|
||||
// is "", the default DNS will be used. By default, this value is "".
|
||||
DNSServer string `json:"dns_server"`
|
||||
// LoginFailExit controls whether or not the client should exit after a
|
||||
// failed login attempt. If false, the client will retry until a login
|
||||
// attempt succeeds. By default, this value is true.
|
||||
LoginFailExit bool `json:"login_fail_exit"`
|
||||
// Start specifies a set of enabled proxies by name. If this set is empty,
|
||||
// all supplied proxies are enabled. By default, this value is an empty
|
||||
// set.
|
||||
Start map[string]struct{} `json:"start"`
|
||||
// Protocol specifies the protocol to use when interacting with the server.
|
||||
// Valid values are "tcp", "kcp" and "websocket". By default, this value
|
||||
// is "tcp".
|
||||
Protocol string `json:"protocol"`
|
||||
// TLSEnable specifies whether or not TLS should be used when communicating
|
||||
// with the server. If "tls_cert_file" and "tls_key_file" are valid,
|
||||
// client will load the supplied tls configuration.
|
||||
TLSEnable bool `json:"tls_enable"`
|
||||
// ClientTLSCertPath specifies the path of the cert file that client will
|
||||
// load. It only works when "tls_enable" is true and "tls_key_file" is valid.
|
||||
TLSCertFile string `json:"tls_cert_file"`
|
||||
// ClientTLSKeyPath specifies the path of the secret key file that client
|
||||
// will load. It only works when "tls_enable" is true and "tls_cert_file"
|
||||
// are valid.
|
||||
TLSKeyFile string `json:"tls_key_file"`
|
||||
// TrustedCaFile specifies the path of the trusted ca file that will load.
|
||||
// It only works when "tls_enable" is valid and tls configuration of server
|
||||
// has been specified.
|
||||
TLSTrustedCaFile string `json:"tls_trusted_ca_file"`
|
||||
// HeartBeatInterval specifies at what interval heartbeats are sent to the
|
||||
// server, in seconds. It is not recommended to change this value. By
|
||||
// default, this value is 30.
|
||||
HeartBeatInterval int64 `json:"heartbeat_interval"`
|
||||
// HeartBeatTimeout specifies the maximum allowed heartbeat response delay
|
||||
// before the connection is terminated, in seconds. It is not recommended
|
||||
// to change this value. By default, this value is 90.
|
||||
HeartBeatTimeout int64 `json:"heartbeat_timeout"`
|
||||
// Client meta info
|
||||
Metas map[string]string `json:"metas"`
|
||||
// UDPPacketSize specifies the udp packet size
|
||||
// By default, this value is 1500
|
||||
UDPPacketSize int64 `json:"udp_packet_size"`
|
||||
}
|
||||
|
||||
// GetDefaultClientConf returns a client configuration with default values.
|
||||
func GetDefaultClientConf() ClientCommonConf {
|
||||
return ClientCommonConf{
|
||||
ServerAddr: "0.0.0.0",
|
||||
ServerPort: 7000,
|
||||
HTTPProxy: os.Getenv("http_proxy"),
|
||||
LogFile: "console",
|
||||
LogWay: "console",
|
||||
LogLevel: "info",
|
||||
LogMaxDays: 3,
|
||||
DisableLogColor: false,
|
||||
AdminAddr: "127.0.0.1",
|
||||
AdminPort: 0,
|
||||
AdminUser: "",
|
||||
AdminPwd: "",
|
||||
AssetsDir: "",
|
||||
PoolCount: 1,
|
||||
TCPMux: true,
|
||||
User: "",
|
||||
DNSServer: "",
|
||||
LoginFailExit: true,
|
||||
Start: make(map[string]struct{}),
|
||||
Protocol: "tcp",
|
||||
TLSEnable: false,
|
||||
TLSCertFile: "",
|
||||
TLSKeyFile: "",
|
||||
TLSTrustedCaFile: "",
|
||||
HeartBeatInterval: 30,
|
||||
HeartBeatTimeout: 90,
|
||||
Metas: make(map[string]string),
|
||||
UDPPacketSize: 1500,
|
||||
}
|
||||
}
|
||||
|
||||
func UnmarshalClientConfFromIni(content string) (cfg ClientCommonConf, err error) {
|
||||
cfg = GetDefaultClientConf()
|
||||
|
||||
conf, err := ini.Load(strings.NewReader(content))
|
||||
if err != nil {
|
||||
return ClientCommonConf{}, fmt.Errorf("parse ini conf file error: %v", err)
|
||||
}
|
||||
|
||||
cfg.ClientConfig = auth.UnmarshalClientConfFromIni(conf)
|
||||
|
||||
var (
|
||||
tmpStr string
|
||||
ok bool
|
||||
v int64
|
||||
)
|
||||
if tmpStr, ok = conf.Get("common", "server_addr"); ok {
|
||||
cfg.ServerAddr = tmpStr
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "server_port"); ok {
|
||||
v, err = strconv.ParseInt(tmpStr, 10, 64)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Parse conf error: invalid server_port")
|
||||
return
|
||||
}
|
||||
cfg.ServerPort = int(v)
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "disable_log_color"); ok && tmpStr == "true" {
|
||||
cfg.DisableLogColor = true
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "http_proxy"); ok {
|
||||
cfg.HTTPProxy = tmpStr
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "log_file"); ok {
|
||||
cfg.LogFile = tmpStr
|
||||
if cfg.LogFile == "console" {
|
||||
cfg.LogWay = "console"
|
||||
} else {
|
||||
cfg.LogWay = "file"
|
||||
}
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "log_level"); ok {
|
||||
cfg.LogLevel = tmpStr
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "log_max_days"); ok {
|
||||
if v, err = strconv.ParseInt(tmpStr, 10, 64); err == nil {
|
||||
cfg.LogMaxDays = v
|
||||
}
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "admin_addr"); ok {
|
||||
cfg.AdminAddr = tmpStr
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "admin_port"); ok {
|
||||
if v, err = strconv.ParseInt(tmpStr, 10, 64); err == nil {
|
||||
cfg.AdminPort = int(v)
|
||||
} else {
|
||||
err = fmt.Errorf("Parse conf error: invalid admin_port")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "admin_user"); ok {
|
||||
cfg.AdminUser = tmpStr
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "admin_pwd"); ok {
|
||||
cfg.AdminPwd = tmpStr
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "assets_dir"); ok {
|
||||
cfg.AssetsDir = tmpStr
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "pool_count"); ok {
|
||||
if v, err = strconv.ParseInt(tmpStr, 10, 64); err == nil {
|
||||
cfg.PoolCount = int(v)
|
||||
}
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "tcp_mux"); ok && tmpStr == "false" {
|
||||
cfg.TCPMux = false
|
||||
} else {
|
||||
cfg.TCPMux = true
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "user"); ok {
|
||||
cfg.User = tmpStr
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "dns_server"); ok {
|
||||
cfg.DNSServer = tmpStr
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "start"); ok {
|
||||
proxyNames := strings.Split(tmpStr, ",")
|
||||
for _, name := range proxyNames {
|
||||
cfg.Start[strings.TrimSpace(name)] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "login_fail_exit"); ok && tmpStr == "false" {
|
||||
cfg.LoginFailExit = false
|
||||
} else {
|
||||
cfg.LoginFailExit = true
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "protocol"); ok {
|
||||
// Now it only support tcp and kcp and websocket.
|
||||
if tmpStr != "tcp" && tmpStr != "kcp" && tmpStr != "websocket" {
|
||||
err = fmt.Errorf("Parse conf error: invalid protocol")
|
||||
return
|
||||
}
|
||||
cfg.Protocol = tmpStr
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "tls_enable"); ok && tmpStr == "true" {
|
||||
cfg.TLSEnable = true
|
||||
} else {
|
||||
cfg.TLSEnable = false
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "tls_cert_file"); ok {
|
||||
cfg.TLSCertFile = tmpStr
|
||||
}
|
||||
|
||||
if tmpStr, ok := conf.Get("common", "tls_key_file"); ok {
|
||||
cfg.TLSKeyFile = tmpStr
|
||||
}
|
||||
|
||||
if tmpStr, ok := conf.Get("common", "tls_trusted_ca_file"); ok {
|
||||
cfg.TLSTrustedCaFile = tmpStr
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "heartbeat_timeout"); ok {
|
||||
if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
|
||||
err = fmt.Errorf("Parse conf error: invalid heartbeat_timeout")
|
||||
return
|
||||
}
|
||||
cfg.HeartBeatTimeout = v
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "heartbeat_interval"); ok {
|
||||
if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
|
||||
err = fmt.Errorf("Parse conf error: invalid heartbeat_interval")
|
||||
return
|
||||
}
|
||||
cfg.HeartBeatInterval = v
|
||||
}
|
||||
for k, v := range conf.Section("common") {
|
||||
if strings.HasPrefix(k, "meta_") {
|
||||
cfg.Metas[strings.TrimPrefix(k, "meta_")] = v
|
||||
}
|
||||
}
|
||||
if tmpStr, ok = conf.Get("common", "udp_packet_size"); ok {
|
||||
if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
|
||||
err = fmt.Errorf("Parse conf error: invalid udp_packet_size")
|
||||
return
|
||||
}
|
||||
cfg.UDPPacketSize = v
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (cfg *ClientCommonConf) Check() (err error) {
|
||||
if cfg.HeartBeatInterval <= 0 {
|
||||
err = fmt.Errorf("Parse conf error: invalid heartbeat_interval")
|
||||
return
|
||||
}
|
||||
|
||||
if cfg.HeartBeatTimeout < cfg.HeartBeatInterval {
|
||||
err = fmt.Errorf("Parse conf error: invalid heartbeat_timeout, heartbeat_timeout is less than heartbeat_interval")
|
||||
return
|
||||
}
|
||||
|
||||
if cfg.TLSEnable == false {
|
||||
if cfg.TLSCertFile != "" {
|
||||
fmt.Println("WARNING! tls_cert_file is invalid when tls_enable is false")
|
||||
}
|
||||
|
||||
if cfg.TLSKeyFile != "" {
|
||||
fmt.Println("WARNING! tls_key_file is invalid when tls_enable is false")
|
||||
}
|
||||
|
||||
if cfg.TLSTrustedCaFile != "" {
|
||||
fmt.Println("WARNING! tls_trusted_ca_file is invalid when tls_enable is false")
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
1187
pkg/config/proxy.go
Normal file
1187
pkg/config/proxy.go
Normal file
File diff suppressed because it is too large
Load Diff
477
pkg/config/server_common.go
Normal file
477
pkg/config/server_common.go
Normal file
@@ -0,0 +1,477 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/fatedier/frp/pkg/auth"
|
||||
plugin "github.com/fatedier/frp/pkg/plugin/server"
|
||||
"github.com/fatedier/frp/pkg/util/util"
|
||||
|
||||
ini "github.com/vaughan0/go-ini"
|
||||
)
|
||||
|
||||
// ServerCommonConf contains information for a server service. It is
|
||||
// recommended to use GetDefaultServerConf instead of creating this object
|
||||
// directly, so that all unspecified fields have reasonable default values.
|
||||
type ServerCommonConf struct {
|
||||
auth.ServerConfig
|
||||
// BindAddr specifies the address that the server binds to. By default,
|
||||
// this value is "0.0.0.0".
|
||||
BindAddr string `json:"bind_addr"`
|
||||
// BindPort specifies the port that the server listens on. By default, this
|
||||
// value is 7000.
|
||||
BindPort int `json:"bind_port"`
|
||||
// BindUDPPort specifies the UDP port that the server listens on. If this
|
||||
// value is 0, the server will not listen for UDP connections. By default,
|
||||
// this value is 0
|
||||
BindUDPPort int `json:"bind_udp_port"`
|
||||
// KCPBindPort specifies the KCP port that the server listens on. If this
|
||||
// value is 0, the server will not listen for KCP connections. By default,
|
||||
// this value is 0.
|
||||
KCPBindPort int `json:"kcp_bind_port"`
|
||||
// ProxyBindAddr specifies the address that the proxy binds to. This value
|
||||
// may be the same as BindAddr. By default, this value is "0.0.0.0".
|
||||
ProxyBindAddr string `json:"proxy_bind_addr"`
|
||||
// VhostHTTPPort specifies the port that the server listens for HTTP Vhost
|
||||
// requests. If this value is 0, the server will not listen for HTTP
|
||||
// requests. By default, this value is 0.
|
||||
VhostHTTPPort int `json:"vhost_http_port"`
|
||||
// VhostHTTPSPort specifies the port that the server listens for HTTPS
|
||||
// Vhost requests. If this value is 0, the server will not listen for HTTPS
|
||||
// requests. By default, this value is 0.
|
||||
VhostHTTPSPort int `json:"vhost_https_port"`
|
||||
// TCPMuxHTTPConnectPort specifies the port that the server listens for TCP
|
||||
// HTTP CONNECT requests. If the value is 0, the server will not multiplex TCP
|
||||
// requests on one single port. If it's not - it will listen on this value for
|
||||
// HTTP CONNECT requests. By default, this value is 0.
|
||||
TCPMuxHTTPConnectPort int `json:"tcpmux_httpconnect_port"`
|
||||
// VhostHTTPTimeout specifies the response header timeout for the Vhost
|
||||
// HTTP server, in seconds. By default, this value is 60.
|
||||
VhostHTTPTimeout int64 `json:"vhost_http_timeout"`
|
||||
// DashboardAddr specifies the address that the dashboard binds to. By
|
||||
// default, this value is "0.0.0.0".
|
||||
DashboardAddr string `json:"dashboard_addr"`
|
||||
// DashboardPort specifies the port that the dashboard listens on. If this
|
||||
// value is 0, the dashboard will not be started. By default, this value is
|
||||
// 0.
|
||||
DashboardPort int `json:"dashboard_port"`
|
||||
// DashboardUser specifies the username that the dashboard will use for
|
||||
// login. By default, this value is "admin".
|
||||
DashboardUser string `json:"dashboard_user"`
|
||||
// DashboardUser specifies the password that the dashboard will use for
|
||||
// login. By default, this value is "admin".
|
||||
DashboardPwd string `json:"dashboard_pwd"`
|
||||
// EnablePrometheus will export prometheus metrics on {dashboard_addr}:{dashboard_port}
|
||||
// in /metrics api.
|
||||
EnablePrometheus bool `json:"enable_prometheus"`
|
||||
// AssetsDir specifies the local directory that the dashboard will load
|
||||
// resources from. If this value is "", assets will be loaded from the
|
||||
// bundled executable using statik. By default, this value is "".
|
||||
AssetsDir string `json:"asserts_dir"`
|
||||
// LogFile specifies a file where logs will be written to. This value will
|
||||
// only be used if LogWay is set appropriately. By default, this value is
|
||||
// "console".
|
||||
LogFile string `json:"log_file"`
|
||||
// LogWay specifies the way logging is managed. Valid values are "console"
|
||||
// or "file". If "console" is used, logs will be printed to stdout. If
|
||||
// "file" is used, logs will be printed to LogFile. By default, this value
|
||||
// is "console".
|
||||
LogWay string `json:"log_way"`
|
||||
// LogLevel specifies the minimum log level. Valid values are "trace",
|
||||
// "debug", "info", "warn", and "error". By default, this value is "info".
|
||||
LogLevel string `json:"log_level"`
|
||||
// LogMaxDays specifies the maximum number of days to store log information
|
||||
// before deletion. This is only used if LogWay == "file". By default, this
|
||||
// value is 0.
|
||||
LogMaxDays int64 `json:"log_max_days"`
|
||||
// DisableLogColor disables log colors when LogWay == "console" when set to
|
||||
// true. By default, this value is false.
|
||||
DisableLogColor bool `json:"disable_log_color"`
|
||||
// DetailedErrorsToClient defines whether to send the specific error (with
|
||||
// debug info) to frpc. By default, this value is true.
|
||||
DetailedErrorsToClient bool `json:"detailed_errors_to_client"`
|
||||
|
||||
// SubDomainHost specifies the domain that will be attached to sub-domains
|
||||
// requested by the client when using Vhost proxying. For example, if this
|
||||
// value is set to "frps.com" and the client requested the subdomain
|
||||
// "test", the resulting URL would be "test.frps.com". By default, this
|
||||
// value is "".
|
||||
SubDomainHost string `json:"subdomain_host"`
|
||||
// TCPMux toggles TCP stream multiplexing. This allows multiple requests
|
||||
// from a client to share a single TCP connection. By default, this value
|
||||
// is true.
|
||||
TCPMux bool `json:"tcp_mux"`
|
||||
// Custom404Page specifies a path to a custom 404 page to display. If this
|
||||
// value is "", a default page will be displayed. By default, this value is
|
||||
// "".
|
||||
Custom404Page string `json:"custom_404_page"`
|
||||
|
||||
// AllowPorts specifies a set of ports that clients are able to proxy to.
|
||||
// If the length of this value is 0, all ports are allowed. By default,
|
||||
// this value is an empty set.
|
||||
AllowPorts map[int]struct{}
|
||||
// MaxPoolCount specifies the maximum pool size for each proxy. By default,
|
||||
// this value is 5.
|
||||
MaxPoolCount int64 `json:"max_pool_count"`
|
||||
// MaxPortsPerClient specifies the maximum number of ports a single client
|
||||
// may proxy to. If this value is 0, no limit will be applied. By default,
|
||||
// this value is 0.
|
||||
MaxPortsPerClient int64 `json:"max_ports_per_client"`
|
||||
// TLSOnly specifies whether to only accept TLS-encrypted connections.
|
||||
// By default, the value is false.
|
||||
TLSOnly bool `json:"tls_only"`
|
||||
// TLSCertFile specifies the path of the cert file that the server will
|
||||
// load. If "tls_cert_file", "tls_key_file" are valid, the server will use this
|
||||
// supplied tls configuration. Otherwise, the server will use the tls
|
||||
// configuration generated by itself.
|
||||
TLSCertFile string `json:"tls_cert_file"`
|
||||
// TLSKeyFile specifies the path of the secret key that the server will
|
||||
// load. If "tls_cert_file", "tls_key_file" are valid, the server will use this
|
||||
// supplied tls configuration. Otherwise, the server will use the tls
|
||||
// configuration generated by itself.
|
||||
TLSKeyFile string `json:"tls_key_file"`
|
||||
// TLSTrustedCaFile specifies the paths of the client cert files that the
|
||||
// server will load. It only works when "tls_only" is true. If
|
||||
// "tls_trusted_ca_file" is valid, the server will verify each client's
|
||||
// certificate.
|
||||
TLSTrustedCaFile string `json:"tls_trusted_ca_file"`
|
||||
// HeartBeatTimeout specifies the maximum time to wait for a heartbeat
|
||||
// before terminating the connection. It is not recommended to change this
|
||||
// value. By default, this value is 90.
|
||||
HeartBeatTimeout int64 `json:"heart_beat_timeout"`
|
||||
// UserConnTimeout specifies the maximum time to wait for a work
|
||||
// connection. By default, this value is 10.
|
||||
UserConnTimeout int64 `json:"user_conn_timeout"`
|
||||
// HTTPPlugins specify the server plugins support HTTP protocol.
|
||||
HTTPPlugins map[string]plugin.HTTPPluginOptions `json:"http_plugins"`
|
||||
// UDPPacketSize specifies the UDP packet size
|
||||
// By default, this value is 1500
|
||||
UDPPacketSize int64 `json:"udp_packet_size"`
|
||||
}
|
||||
|
||||
// GetDefaultServerConf returns a server configuration with reasonable
|
||||
// defaults.
|
||||
func GetDefaultServerConf() ServerCommonConf {
|
||||
return ServerCommonConf{
|
||||
BindAddr: "0.0.0.0",
|
||||
BindPort: 7000,
|
||||
BindUDPPort: 0,
|
||||
KCPBindPort: 0,
|
||||
ProxyBindAddr: "0.0.0.0",
|
||||
VhostHTTPPort: 0,
|
||||
VhostHTTPSPort: 0,
|
||||
TCPMuxHTTPConnectPort: 0,
|
||||
VhostHTTPTimeout: 60,
|
||||
DashboardAddr: "0.0.0.0",
|
||||
DashboardPort: 0,
|
||||
DashboardUser: "admin",
|
||||
DashboardPwd: "admin",
|
||||
EnablePrometheus: false,
|
||||
AssetsDir: "",
|
||||
LogFile: "console",
|
||||
LogWay: "console",
|
||||
LogLevel: "info",
|
||||
LogMaxDays: 3,
|
||||
DisableLogColor: false,
|
||||
DetailedErrorsToClient: true,
|
||||
SubDomainHost: "",
|
||||
TCPMux: true,
|
||||
AllowPorts: make(map[int]struct{}),
|
||||
MaxPoolCount: 5,
|
||||
MaxPortsPerClient: 0,
|
||||
TLSOnly: false,
|
||||
TLSCertFile: "",
|
||||
TLSKeyFile: "",
|
||||
TLSTrustedCaFile: "",
|
||||
HeartBeatTimeout: 90,
|
||||
UserConnTimeout: 10,
|
||||
Custom404Page: "",
|
||||
HTTPPlugins: make(map[string]plugin.HTTPPluginOptions),
|
||||
UDPPacketSize: 1500,
|
||||
}
|
||||
}
|
||||
|
||||
// UnmarshalServerConfFromIni parses the contents of a server configuration ini
|
||||
// file and returns the resulting server configuration.
|
||||
func UnmarshalServerConfFromIni(content string) (cfg ServerCommonConf, err error) {
|
||||
cfg = GetDefaultServerConf()
|
||||
|
||||
conf, err := ini.Load(strings.NewReader(content))
|
||||
if err != nil {
|
||||
err = fmt.Errorf("parse ini conf file error: %v", err)
|
||||
return ServerCommonConf{}, err
|
||||
}
|
||||
|
||||
UnmarshalPluginsFromIni(conf, &cfg)
|
||||
|
||||
cfg.ServerConfig = auth.UnmarshalServerConfFromIni(conf)
|
||||
|
||||
var (
|
||||
tmpStr string
|
||||
ok bool
|
||||
v int64
|
||||
)
|
||||
if tmpStr, ok = conf.Get("common", "bind_addr"); ok {
|
||||
cfg.BindAddr = tmpStr
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "bind_port"); ok {
|
||||
if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
|
||||
err = fmt.Errorf("Parse conf error: invalid bind_port")
|
||||
return
|
||||
}
|
||||
cfg.BindPort = int(v)
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "bind_udp_port"); ok {
|
||||
if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
|
||||
err = fmt.Errorf("Parse conf error: invalid bind_udp_port")
|
||||
return
|
||||
}
|
||||
cfg.BindUDPPort = int(v)
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "kcp_bind_port"); ok {
|
||||
if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
|
||||
err = fmt.Errorf("Parse conf error: invalid kcp_bind_port")
|
||||
return
|
||||
}
|
||||
cfg.KCPBindPort = int(v)
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "proxy_bind_addr"); ok {
|
||||
cfg.ProxyBindAddr = tmpStr
|
||||
} else {
|
||||
cfg.ProxyBindAddr = cfg.BindAddr
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "vhost_http_port"); ok {
|
||||
if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
|
||||
err = fmt.Errorf("Parse conf error: invalid vhost_http_port")
|
||||
return
|
||||
}
|
||||
cfg.VhostHTTPPort = int(v)
|
||||
} else {
|
||||
cfg.VhostHTTPPort = 0
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "vhost_https_port"); ok {
|
||||
if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
|
||||
err = fmt.Errorf("Parse conf error: invalid vhost_https_port")
|
||||
return
|
||||
}
|
||||
cfg.VhostHTTPSPort = int(v)
|
||||
} else {
|
||||
cfg.VhostHTTPSPort = 0
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "tcpmux_httpconnect_port"); ok {
|
||||
if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
|
||||
err = fmt.Errorf("Parse conf error: invalid tcpmux_httpconnect_port")
|
||||
return
|
||||
}
|
||||
cfg.TCPMuxHTTPConnectPort = int(v)
|
||||
} else {
|
||||
cfg.TCPMuxHTTPConnectPort = 0
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "vhost_http_timeout"); ok {
|
||||
v, errRet := strconv.ParseInt(tmpStr, 10, 64)
|
||||
if errRet != nil || v < 0 {
|
||||
err = fmt.Errorf("Parse conf error: invalid vhost_http_timeout")
|
||||
return
|
||||
}
|
||||
cfg.VhostHTTPTimeout = v
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "dashboard_addr"); ok {
|
||||
cfg.DashboardAddr = tmpStr
|
||||
} else {
|
||||
cfg.DashboardAddr = cfg.BindAddr
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "dashboard_port"); ok {
|
||||
if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
|
||||
err = fmt.Errorf("Parse conf error: invalid dashboard_port")
|
||||
return
|
||||
}
|
||||
cfg.DashboardPort = int(v)
|
||||
} else {
|
||||
cfg.DashboardPort = 0
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "dashboard_user"); ok {
|
||||
cfg.DashboardUser = tmpStr
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "dashboard_pwd"); ok {
|
||||
cfg.DashboardPwd = tmpStr
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "enable_prometheus"); ok && tmpStr == "true" {
|
||||
cfg.EnablePrometheus = true
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "assets_dir"); ok {
|
||||
cfg.AssetsDir = tmpStr
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "log_file"); ok {
|
||||
cfg.LogFile = tmpStr
|
||||
if cfg.LogFile == "console" {
|
||||
cfg.LogWay = "console"
|
||||
} else {
|
||||
cfg.LogWay = "file"
|
||||
}
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "log_level"); ok {
|
||||
cfg.LogLevel = tmpStr
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "log_max_days"); ok {
|
||||
v, err = strconv.ParseInt(tmpStr, 10, 64)
|
||||
if err == nil {
|
||||
cfg.LogMaxDays = v
|
||||
}
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "disable_log_color"); ok && tmpStr == "true" {
|
||||
cfg.DisableLogColor = true
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "detailed_errors_to_client"); ok && tmpStr == "false" {
|
||||
cfg.DetailedErrorsToClient = false
|
||||
} else {
|
||||
cfg.DetailedErrorsToClient = true
|
||||
}
|
||||
|
||||
if allowPortsStr, ok := conf.Get("common", "allow_ports"); ok {
|
||||
// e.g. 1000-2000,2001,2002,3000-4000
|
||||
ports, errRet := util.ParseRangeNumbers(allowPortsStr)
|
||||
if errRet != nil {
|
||||
err = fmt.Errorf("Parse conf error: allow_ports: %v", errRet)
|
||||
return
|
||||
}
|
||||
|
||||
for _, port := range ports {
|
||||
cfg.AllowPorts[int(port)] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "max_pool_count"); ok {
|
||||
if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
|
||||
err = fmt.Errorf("Parse conf error: invalid max_pool_count")
|
||||
return
|
||||
}
|
||||
|
||||
if v < 0 {
|
||||
err = fmt.Errorf("Parse conf error: invalid max_pool_count")
|
||||
return
|
||||
}
|
||||
cfg.MaxPoolCount = v
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "max_ports_per_client"); ok {
|
||||
if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
|
||||
err = fmt.Errorf("Parse conf error: invalid max_ports_per_client")
|
||||
return
|
||||
}
|
||||
|
||||
if v < 0 {
|
||||
err = fmt.Errorf("Parse conf error: invalid max_ports_per_client")
|
||||
return
|
||||
}
|
||||
cfg.MaxPortsPerClient = v
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "subdomain_host"); ok {
|
||||
cfg.SubDomainHost = strings.ToLower(strings.TrimSpace(tmpStr))
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "tcp_mux"); ok && tmpStr == "false" {
|
||||
cfg.TCPMux = false
|
||||
} else {
|
||||
cfg.TCPMux = true
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "custom_404_page"); ok {
|
||||
cfg.Custom404Page = tmpStr
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "heartbeat_timeout"); ok {
|
||||
v, errRet := strconv.ParseInt(tmpStr, 10, 64)
|
||||
if errRet != nil {
|
||||
err = fmt.Errorf("Parse conf error: heartbeat_timeout is incorrect")
|
||||
return
|
||||
}
|
||||
cfg.HeartBeatTimeout = v
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "tls_only"); ok && tmpStr == "true" {
|
||||
cfg.TLSOnly = true
|
||||
} else {
|
||||
cfg.TLSOnly = false
|
||||
}
|
||||
|
||||
if tmpStr, ok = conf.Get("common", "udp_packet_size"); ok {
|
||||
if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
|
||||
err = fmt.Errorf("Parse conf error: invalid udp_packet_size")
|
||||
return
|
||||
}
|
||||
cfg.UDPPacketSize = v
|
||||
}
|
||||
|
||||
if tmpStr, ok := conf.Get("common", "tls_cert_file"); ok {
|
||||
cfg.TLSCertFile = tmpStr
|
||||
}
|
||||
|
||||
if tmpStr, ok := conf.Get("common", "tls_key_file"); ok {
|
||||
cfg.TLSKeyFile = tmpStr
|
||||
}
|
||||
|
||||
if tmpStr, ok := conf.Get("common", "tls_trusted_ca_file"); ok {
|
||||
cfg.TLSTrustedCaFile = tmpStr
|
||||
cfg.TLSOnly = true
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func UnmarshalPluginsFromIni(sections ini.File, cfg *ServerCommonConf) {
|
||||
for name, section := range sections {
|
||||
if strings.HasPrefix(name, "plugin.") {
|
||||
name = strings.TrimSpace(strings.TrimPrefix(name, "plugin."))
|
||||
options := plugin.HTTPPluginOptions{
|
||||
Name: name,
|
||||
Addr: section["addr"],
|
||||
Path: section["path"],
|
||||
Ops: strings.Split(section["ops"], ","),
|
||||
}
|
||||
for i := range options.Ops {
|
||||
options.Ops[i] = strings.TrimSpace(options.Ops[i])
|
||||
}
|
||||
cfg.HTTPPlugins[name] = options
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cfg *ServerCommonConf) Check() error {
|
||||
return nil
|
||||
}
|
||||
112
pkg/config/types.go
Normal file
112
pkg/config/types.go
Normal file
@@ -0,0 +1,112 @@
|
||||
// Copyright 2019 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
MB = 1024 * 1024
|
||||
KB = 1024
|
||||
)
|
||||
|
||||
type BandwidthQuantity struct {
|
||||
s string // MB or KB
|
||||
|
||||
i int64 // bytes
|
||||
}
|
||||
|
||||
func NewBandwidthQuantity(s string) (BandwidthQuantity, error) {
|
||||
q := BandwidthQuantity{}
|
||||
err := q.UnmarshalString(s)
|
||||
if err != nil {
|
||||
return q, err
|
||||
}
|
||||
return q, nil
|
||||
}
|
||||
|
||||
func (q *BandwidthQuantity) Equal(u *BandwidthQuantity) bool {
|
||||
if q == nil && u == nil {
|
||||
return true
|
||||
}
|
||||
if q != nil && u != nil {
|
||||
return q.i == u.i
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (q *BandwidthQuantity) String() string {
|
||||
return q.s
|
||||
}
|
||||
|
||||
func (q *BandwidthQuantity) UnmarshalString(s string) error {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
base int64
|
||||
f float64
|
||||
err error
|
||||
)
|
||||
if strings.HasSuffix(s, "MB") {
|
||||
base = MB
|
||||
fstr := strings.TrimSuffix(s, "MB")
|
||||
f, err = strconv.ParseFloat(fstr, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if strings.HasSuffix(s, "KB") {
|
||||
base = KB
|
||||
fstr := strings.TrimSuffix(s, "KB")
|
||||
f, err = strconv.ParseFloat(fstr, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return errors.New("unit not support")
|
||||
}
|
||||
|
||||
q.s = s
|
||||
q.i = int64(f * float64(base))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *BandwidthQuantity) UnmarshalJSON(b []byte) error {
|
||||
if len(b) == 4 && string(b) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var str string
|
||||
err := json.Unmarshal(b, &str)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return q.UnmarshalString(str)
|
||||
}
|
||||
|
||||
func (q *BandwidthQuantity) MarshalJSON() ([]byte, error) {
|
||||
return []byte("\"" + q.s + "\""), nil
|
||||
}
|
||||
|
||||
func (q *BandwidthQuantity) Bytes() int64 {
|
||||
return q.i
|
||||
}
|
||||
40
pkg/config/types_test.go
Normal file
40
pkg/config/types_test.go
Normal file
@@ -0,0 +1,40 @@
|
||||
// Copyright 2019 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type Wrap struct {
|
||||
B BandwidthQuantity `json:"b"`
|
||||
Int int `json:"int"`
|
||||
}
|
||||
|
||||
func TestBandwidthQuantity(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
var w Wrap
|
||||
err := json.Unmarshal([]byte(`{"b":"1KB","int":5}`), &w)
|
||||
assert.NoError(err)
|
||||
assert.EqualValues(1*KB, w.B.Bytes())
|
||||
|
||||
buf, err := json.Marshal(&w)
|
||||
assert.NoError(err)
|
||||
assert.Equal(`{"b":"1KB","int":5}`, string(buf))
|
||||
}
|
||||
64
pkg/config/value.go
Normal file
64
pkg/config/value.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"strings"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
var (
|
||||
glbEnvs map[string]string
|
||||
)
|
||||
|
||||
func init() {
|
||||
glbEnvs = make(map[string]string)
|
||||
envs := os.Environ()
|
||||
for _, env := range envs {
|
||||
kv := strings.Split(env, "=")
|
||||
if len(kv) != 2 {
|
||||
continue
|
||||
}
|
||||
glbEnvs[kv[0]] = kv[1]
|
||||
}
|
||||
}
|
||||
|
||||
type Values struct {
|
||||
Envs map[string]string // environment vars
|
||||
}
|
||||
|
||||
func GetValues() *Values {
|
||||
return &Values{
|
||||
Envs: glbEnvs,
|
||||
}
|
||||
}
|
||||
|
||||
func RenderContent(in string) (out string, err error) {
|
||||
tmpl, errRet := template.New("frp").Parse(in)
|
||||
if errRet != nil {
|
||||
err = errRet
|
||||
return
|
||||
}
|
||||
|
||||
buffer := bytes.NewBufferString("")
|
||||
v := GetValues()
|
||||
err = tmpl.Execute(buffer, v)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
out = buffer.String()
|
||||
return
|
||||
}
|
||||
|
||||
func GetRenderedConfFromFile(path string) (out string, err error) {
|
||||
var b []byte
|
||||
b, err = ioutil.ReadFile(path)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
content := string(b)
|
||||
|
||||
out, err = RenderContent(content)
|
||||
return
|
||||
}
|
||||
244
pkg/config/visitor.go
Normal file
244
pkg/config/visitor.go
Normal file
@@ -0,0 +1,244 @@
|
||||
// Copyright 2018 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
|
||||
"github.com/fatedier/frp/pkg/consts"
|
||||
|
||||
ini "github.com/vaughan0/go-ini"
|
||||
)
|
||||
|
||||
var (
|
||||
visitorConfTypeMap map[string]reflect.Type
|
||||
)
|
||||
|
||||
func init() {
|
||||
visitorConfTypeMap = make(map[string]reflect.Type)
|
||||
visitorConfTypeMap[consts.STCPProxy] = reflect.TypeOf(STCPVisitorConf{})
|
||||
visitorConfTypeMap[consts.XTCPProxy] = reflect.TypeOf(XTCPVisitorConf{})
|
||||
visitorConfTypeMap[consts.SUDPProxy] = reflect.TypeOf(SUDPVisitorConf{})
|
||||
}
|
||||
|
||||
type VisitorConf interface {
|
||||
GetBaseInfo() *BaseVisitorConf
|
||||
Compare(cmp VisitorConf) bool
|
||||
UnmarshalFromIni(prefix string, name string, section ini.Section) error
|
||||
Check() error
|
||||
}
|
||||
|
||||
func NewVisitorConfByType(cfgType string) VisitorConf {
|
||||
v, ok := visitorConfTypeMap[cfgType]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
cfg := reflect.New(v).Interface().(VisitorConf)
|
||||
return cfg
|
||||
}
|
||||
|
||||
func NewVisitorConfFromIni(prefix string, name string, section ini.Section) (cfg VisitorConf, err error) {
|
||||
cfgType := section["type"]
|
||||
if cfgType == "" {
|
||||
err = fmt.Errorf("visitor [%s] type shouldn't be empty", name)
|
||||
return
|
||||
}
|
||||
cfg = NewVisitorConfByType(cfgType)
|
||||
if cfg == nil {
|
||||
err = fmt.Errorf("visitor [%s] type [%s] error", name, cfgType)
|
||||
return
|
||||
}
|
||||
if err = cfg.UnmarshalFromIni(prefix, name, section); err != nil {
|
||||
return
|
||||
}
|
||||
if err = cfg.Check(); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type BaseVisitorConf struct {
|
||||
ProxyName string `json:"proxy_name"`
|
||||
ProxyType string `json:"proxy_type"`
|
||||
UseEncryption bool `json:"use_encryption"`
|
||||
UseCompression bool `json:"use_compression"`
|
||||
Role string `json:"role"`
|
||||
Sk string `json:"sk"`
|
||||
ServerName string `json:"server_name"`
|
||||
BindAddr string `json:"bind_addr"`
|
||||
BindPort int `json:"bind_port"`
|
||||
}
|
||||
|
||||
func (cfg *BaseVisitorConf) GetBaseInfo() *BaseVisitorConf {
|
||||
return cfg
|
||||
}
|
||||
|
||||
func (cfg *BaseVisitorConf) compare(cmp *BaseVisitorConf) bool {
|
||||
if cfg.ProxyName != cmp.ProxyName ||
|
||||
cfg.ProxyType != cmp.ProxyType ||
|
||||
cfg.UseEncryption != cmp.UseEncryption ||
|
||||
cfg.UseCompression != cmp.UseCompression ||
|
||||
cfg.Role != cmp.Role ||
|
||||
cfg.Sk != cmp.Sk ||
|
||||
cfg.ServerName != cmp.ServerName ||
|
||||
cfg.BindAddr != cmp.BindAddr ||
|
||||
cfg.BindPort != cmp.BindPort {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (cfg *BaseVisitorConf) check() (err error) {
|
||||
if cfg.Role != "visitor" {
|
||||
err = fmt.Errorf("invalid role")
|
||||
return
|
||||
}
|
||||
if cfg.BindAddr == "" {
|
||||
err = fmt.Errorf("bind_addr shouldn't be empty")
|
||||
return
|
||||
}
|
||||
if cfg.BindPort <= 0 {
|
||||
err = fmt.Errorf("bind_port is required")
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (cfg *BaseVisitorConf) UnmarshalFromIni(prefix string, name string, section ini.Section) (err error) {
|
||||
var (
|
||||
tmpStr string
|
||||
ok bool
|
||||
)
|
||||
cfg.ProxyName = prefix + name
|
||||
cfg.ProxyType = section["type"]
|
||||
|
||||
if tmpStr, ok = section["use_encryption"]; ok && tmpStr == "true" {
|
||||
cfg.UseEncryption = true
|
||||
}
|
||||
if tmpStr, ok = section["use_compression"]; ok && tmpStr == "true" {
|
||||
cfg.UseCompression = true
|
||||
}
|
||||
|
||||
cfg.Role = section["role"]
|
||||
if cfg.Role != "visitor" {
|
||||
return fmt.Errorf("Parse conf error: proxy [%s] incorrect role [%s]", name, cfg.Role)
|
||||
}
|
||||
cfg.Sk = section["sk"]
|
||||
cfg.ServerName = prefix + section["server_name"]
|
||||
if cfg.BindAddr = section["bind_addr"]; cfg.BindAddr == "" {
|
||||
cfg.BindAddr = "127.0.0.1"
|
||||
}
|
||||
|
||||
if tmpStr, ok = section["bind_port"]; ok {
|
||||
if cfg.BindPort, err = strconv.Atoi(tmpStr); err != nil {
|
||||
return fmt.Errorf("Parse conf error: proxy [%s] bind_port incorrect", name)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("Parse conf error: proxy [%s] bind_port not found", name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type SUDPVisitorConf struct {
|
||||
BaseVisitorConf
|
||||
}
|
||||
|
||||
func (cfg *SUDPVisitorConf) Compare(cmp VisitorConf) bool {
|
||||
cmpConf, ok := cmp.(*SUDPVisitorConf)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
if !cfg.BaseVisitorConf.compare(&cmpConf.BaseVisitorConf) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (cfg *SUDPVisitorConf) UnmarshalFromIni(prefix string, name string, section ini.Section) (err error) {
|
||||
if err = cfg.BaseVisitorConf.UnmarshalFromIni(prefix, name, section); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (cfg *SUDPVisitorConf) Check() (err error) {
|
||||
if err = cfg.BaseVisitorConf.check(); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type STCPVisitorConf struct {
|
||||
BaseVisitorConf
|
||||
}
|
||||
|
||||
func (cfg *STCPVisitorConf) Compare(cmp VisitorConf) bool {
|
||||
cmpConf, ok := cmp.(*STCPVisitorConf)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
if !cfg.BaseVisitorConf.compare(&cmpConf.BaseVisitorConf) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (cfg *STCPVisitorConf) UnmarshalFromIni(prefix string, name string, section ini.Section) (err error) {
|
||||
if err = cfg.BaseVisitorConf.UnmarshalFromIni(prefix, name, section); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (cfg *STCPVisitorConf) Check() (err error) {
|
||||
if err = cfg.BaseVisitorConf.check(); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type XTCPVisitorConf struct {
|
||||
BaseVisitorConf
|
||||
}
|
||||
|
||||
func (cfg *XTCPVisitorConf) Compare(cmp VisitorConf) bool {
|
||||
cmpConf, ok := cmp.(*XTCPVisitorConf)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
if !cfg.BaseVisitorConf.compare(&cmpConf.BaseVisitorConf) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (cfg *XTCPVisitorConf) UnmarshalFromIni(prefix string, name string, section ini.Section) (err error) {
|
||||
if err = cfg.BaseVisitorConf.UnmarshalFromIni(prefix, name, section); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (cfg *XTCPVisitorConf) Check() (err error) {
|
||||
if err = cfg.BaseVisitorConf.check(); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
41
pkg/consts/consts.go
Normal file
41
pkg/consts/consts.go
Normal file
@@ -0,0 +1,41 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package consts
|
||||
|
||||
var (
|
||||
// proxy status
|
||||
Idle string = "idle"
|
||||
Working string = "working"
|
||||
Closed string = "closed"
|
||||
Online string = "online"
|
||||
Offline string = "offline"
|
||||
|
||||
// proxy type
|
||||
TCPProxy string = "tcp"
|
||||
UDPProxy string = "udp"
|
||||
TCPMuxProxy string = "tcpmux"
|
||||
HTTPProxy string = "http"
|
||||
HTTPSProxy string = "https"
|
||||
STCPProxy string = "stcp"
|
||||
XTCPProxy string = "xtcp"
|
||||
SUDPProxy string = "sudp"
|
||||
|
||||
// authentication method
|
||||
TokenAuthMethod string = "token"
|
||||
OidcAuthMethod string = "oidc"
|
||||
|
||||
// TCP multiplexer
|
||||
HTTPConnectTCPMultiplexer string = "httpconnect"
|
||||
)
|
||||
24
pkg/errors/errors.go
Normal file
24
pkg/errors/errors.go
Normal file
@@ -0,0 +1,24 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package errors
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrMsgType = errors.New("message type error")
|
||||
ErrCtlClosed = errors.New("control is closed")
|
||||
)
|
||||
93
pkg/metrics/aggregate/server.go
Normal file
93
pkg/metrics/aggregate/server.go
Normal file
@@ -0,0 +1,93 @@
|
||||
// Copyright 2020 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package aggregate
|
||||
|
||||
import (
|
||||
"github.com/fatedier/frp/pkg/metrics/mem"
|
||||
"github.com/fatedier/frp/pkg/metrics/prometheus"
|
||||
"github.com/fatedier/frp/server/metrics"
|
||||
)
|
||||
|
||||
// EnableMem start to mark metrics to memory monitor system.
|
||||
func EnableMem() {
|
||||
sm.Add(mem.ServerMetrics)
|
||||
}
|
||||
|
||||
// EnablePrometheus start to mark metrics to prometheus.
|
||||
func EnablePrometheus() {
|
||||
sm.Add(prometheus.ServerMetrics)
|
||||
}
|
||||
|
||||
var sm *serverMetrics = &serverMetrics{}
|
||||
|
||||
func init() {
|
||||
metrics.Register(sm)
|
||||
}
|
||||
|
||||
type serverMetrics struct {
|
||||
ms []metrics.ServerMetrics
|
||||
}
|
||||
|
||||
func (m *serverMetrics) Add(sm metrics.ServerMetrics) {
|
||||
m.ms = append(m.ms, sm)
|
||||
}
|
||||
|
||||
func (m *serverMetrics) NewClient() {
|
||||
for _, v := range m.ms {
|
||||
v.NewClient()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *serverMetrics) CloseClient() {
|
||||
for _, v := range m.ms {
|
||||
v.CloseClient()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *serverMetrics) NewProxy(name string, proxyType string) {
|
||||
for _, v := range m.ms {
|
||||
v.NewProxy(name, proxyType)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *serverMetrics) CloseProxy(name string, proxyType string) {
|
||||
for _, v := range m.ms {
|
||||
v.CloseProxy(name, proxyType)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *serverMetrics) OpenConnection(name string, proxyType string) {
|
||||
for _, v := range m.ms {
|
||||
v.OpenConnection(name, proxyType)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *serverMetrics) CloseConnection(name string, proxyType string) {
|
||||
for _, v := range m.ms {
|
||||
v.CloseConnection(name, proxyType)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *serverMetrics) AddTrafficIn(name string, proxyType string, trafficBytes int64) {
|
||||
for _, v := range m.ms {
|
||||
v.AddTrafficIn(name, proxyType, trafficBytes)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *serverMetrics) AddTrafficOut(name string, proxyType string, trafficBytes int64) {
|
||||
for _, v := range m.ms {
|
||||
v.AddTrafficOut(name, proxyType, trafficBytes)
|
||||
}
|
||||
}
|
||||
262
pkg/metrics/mem/server.go
Normal file
262
pkg/metrics/mem/server.go
Normal file
@@ -0,0 +1,262 @@
|
||||
// Copyright 2019 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mem
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/frp/pkg/util/log"
|
||||
"github.com/fatedier/frp/pkg/util/metric"
|
||||
server "github.com/fatedier/frp/server/metrics"
|
||||
)
|
||||
|
||||
var sm *serverMetrics = newServerMetrics()
|
||||
var ServerMetrics server.ServerMetrics
|
||||
var StatsCollector Collector
|
||||
|
||||
func init() {
|
||||
ServerMetrics = sm
|
||||
StatsCollector = sm
|
||||
sm.run()
|
||||
}
|
||||
|
||||
type serverMetrics struct {
|
||||
info *ServerStatistics
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newServerMetrics() *serverMetrics {
|
||||
return &serverMetrics{
|
||||
info: &ServerStatistics{
|
||||
TotalTrafficIn: metric.NewDateCounter(ReserveDays),
|
||||
TotalTrafficOut: metric.NewDateCounter(ReserveDays),
|
||||
CurConns: metric.NewCounter(),
|
||||
|
||||
ClientCounts: metric.NewCounter(),
|
||||
ProxyTypeCounts: make(map[string]metric.Counter),
|
||||
|
||||
ProxyStatistics: make(map[string]*ProxyStatistics),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *serverMetrics) run() {
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(12 * time.Hour)
|
||||
log.Debug("start to clear useless proxy statistics data...")
|
||||
m.clearUselessInfo()
|
||||
log.Debug("finish to clear useless proxy statistics data")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (m *serverMetrics) clearUselessInfo() {
|
||||
// To check if there are proxies that closed than 7 days and drop them.
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
for name, data := range m.info.ProxyStatistics {
|
||||
if !data.LastCloseTime.IsZero() && time.Since(data.LastCloseTime) > time.Duration(7*24)*time.Hour {
|
||||
delete(m.info.ProxyStatistics, name)
|
||||
log.Trace("clear proxy [%s]'s statistics data, lastCloseTime: [%s]", name, data.LastCloseTime.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *serverMetrics) NewClient() {
|
||||
m.info.ClientCounts.Inc(1)
|
||||
}
|
||||
|
||||
func (m *serverMetrics) CloseClient() {
|
||||
m.info.ClientCounts.Dec(1)
|
||||
}
|
||||
|
||||
func (m *serverMetrics) NewProxy(name string, proxyType string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
counter, ok := m.info.ProxyTypeCounts[proxyType]
|
||||
if !ok {
|
||||
counter = metric.NewCounter()
|
||||
}
|
||||
counter.Inc(1)
|
||||
m.info.ProxyTypeCounts[proxyType] = counter
|
||||
|
||||
proxyStats, ok := m.info.ProxyStatistics[name]
|
||||
if !(ok && proxyStats.ProxyType == proxyType) {
|
||||
proxyStats = &ProxyStatistics{
|
||||
Name: name,
|
||||
ProxyType: proxyType,
|
||||
CurConns: metric.NewCounter(),
|
||||
TrafficIn: metric.NewDateCounter(ReserveDays),
|
||||
TrafficOut: metric.NewDateCounter(ReserveDays),
|
||||
}
|
||||
m.info.ProxyStatistics[name] = proxyStats
|
||||
}
|
||||
proxyStats.LastStartTime = time.Now()
|
||||
}
|
||||
|
||||
func (m *serverMetrics) CloseProxy(name string, proxyType string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if counter, ok := m.info.ProxyTypeCounts[proxyType]; ok {
|
||||
counter.Dec(1)
|
||||
}
|
||||
if proxyStats, ok := m.info.ProxyStatistics[name]; ok {
|
||||
proxyStats.LastCloseTime = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *serverMetrics) OpenConnection(name string, proxyType string) {
|
||||
m.info.CurConns.Inc(1)
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
proxyStats, ok := m.info.ProxyStatistics[name]
|
||||
if ok {
|
||||
proxyStats.CurConns.Inc(1)
|
||||
m.info.ProxyStatistics[name] = proxyStats
|
||||
}
|
||||
}
|
||||
|
||||
func (m *serverMetrics) CloseConnection(name string, proxyType string) {
|
||||
m.info.CurConns.Dec(1)
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
proxyStats, ok := m.info.ProxyStatistics[name]
|
||||
if ok {
|
||||
proxyStats.CurConns.Dec(1)
|
||||
m.info.ProxyStatistics[name] = proxyStats
|
||||
}
|
||||
}
|
||||
|
||||
func (m *serverMetrics) AddTrafficIn(name string, proxyType string, trafficBytes int64) {
|
||||
m.info.TotalTrafficIn.Inc(trafficBytes)
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
proxyStats, ok := m.info.ProxyStatistics[name]
|
||||
if ok {
|
||||
proxyStats.TrafficIn.Inc(trafficBytes)
|
||||
m.info.ProxyStatistics[name] = proxyStats
|
||||
}
|
||||
}
|
||||
|
||||
func (m *serverMetrics) AddTrafficOut(name string, proxyType string, trafficBytes int64) {
|
||||
m.info.TotalTrafficOut.Inc(trafficBytes)
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
proxyStats, ok := m.info.ProxyStatistics[name]
|
||||
if ok {
|
||||
proxyStats.TrafficOut.Inc(trafficBytes)
|
||||
m.info.ProxyStatistics[name] = proxyStats
|
||||
}
|
||||
}
|
||||
|
||||
// Get stats data api.
|
||||
|
||||
func (m *serverMetrics) GetServer() *ServerStats {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
s := &ServerStats{
|
||||
TotalTrafficIn: m.info.TotalTrafficIn.TodayCount(),
|
||||
TotalTrafficOut: m.info.TotalTrafficOut.TodayCount(),
|
||||
CurConns: int64(m.info.CurConns.Count()),
|
||||
ClientCounts: int64(m.info.ClientCounts.Count()),
|
||||
ProxyTypeCounts: make(map[string]int64),
|
||||
}
|
||||
for k, v := range m.info.ProxyTypeCounts {
|
||||
s.ProxyTypeCounts[k] = int64(v.Count())
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (m *serverMetrics) GetProxiesByType(proxyType string) []*ProxyStats {
|
||||
res := make([]*ProxyStats, 0)
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for name, proxyStats := range m.info.ProxyStatistics {
|
||||
if proxyStats.ProxyType != proxyType {
|
||||
continue
|
||||
}
|
||||
|
||||
ps := &ProxyStats{
|
||||
Name: name,
|
||||
Type: proxyStats.ProxyType,
|
||||
TodayTrafficIn: proxyStats.TrafficIn.TodayCount(),
|
||||
TodayTrafficOut: proxyStats.TrafficOut.TodayCount(),
|
||||
CurConns: int64(proxyStats.CurConns.Count()),
|
||||
}
|
||||
if !proxyStats.LastStartTime.IsZero() {
|
||||
ps.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05")
|
||||
}
|
||||
if !proxyStats.LastCloseTime.IsZero() {
|
||||
ps.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05")
|
||||
}
|
||||
res = append(res, ps)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (m *serverMetrics) GetProxiesByTypeAndName(proxyType string, proxyName string) (res *ProxyStats) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for name, proxyStats := range m.info.ProxyStatistics {
|
||||
if proxyStats.ProxyType != proxyType {
|
||||
continue
|
||||
}
|
||||
|
||||
if name != proxyName {
|
||||
continue
|
||||
}
|
||||
|
||||
res = &ProxyStats{
|
||||
Name: name,
|
||||
Type: proxyStats.ProxyType,
|
||||
TodayTrafficIn: proxyStats.TrafficIn.TodayCount(),
|
||||
TodayTrafficOut: proxyStats.TrafficOut.TodayCount(),
|
||||
CurConns: int64(proxyStats.CurConns.Count()),
|
||||
}
|
||||
if !proxyStats.LastStartTime.IsZero() {
|
||||
res.LastStartTime = proxyStats.LastStartTime.Format("01-02 15:04:05")
|
||||
}
|
||||
if !proxyStats.LastCloseTime.IsZero() {
|
||||
res.LastCloseTime = proxyStats.LastCloseTime.Format("01-02 15:04:05")
|
||||
}
|
||||
break
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m *serverMetrics) GetProxyTraffic(name string) (res *ProxyTrafficInfo) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
proxyStats, ok := m.info.ProxyStatistics[name]
|
||||
if ok {
|
||||
res = &ProxyTrafficInfo{
|
||||
Name: name,
|
||||
}
|
||||
res.TrafficIn = proxyStats.TrafficIn.GetLastDaysCount(ReserveDays)
|
||||
res.TrafficOut = proxyStats.TrafficOut.GetLastDaysCount(ReserveDays)
|
||||
}
|
||||
return
|
||||
}
|
||||
82
pkg/metrics/mem/types.go
Normal file
82
pkg/metrics/mem/types.go
Normal file
@@ -0,0 +1,82 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mem
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/frp/pkg/util/metric"
|
||||
)
|
||||
|
||||
const (
|
||||
ReserveDays = 7
|
||||
)
|
||||
|
||||
type ServerStats struct {
|
||||
TotalTrafficIn int64
|
||||
TotalTrafficOut int64
|
||||
CurConns int64
|
||||
ClientCounts int64
|
||||
ProxyTypeCounts map[string]int64
|
||||
}
|
||||
|
||||
type ProxyStats struct {
|
||||
Name string
|
||||
Type string
|
||||
TodayTrafficIn int64
|
||||
TodayTrafficOut int64
|
||||
LastStartTime string
|
||||
LastCloseTime string
|
||||
CurConns int64
|
||||
}
|
||||
|
||||
type ProxyTrafficInfo struct {
|
||||
Name string
|
||||
TrafficIn []int64
|
||||
TrafficOut []int64
|
||||
}
|
||||
|
||||
type ProxyStatistics struct {
|
||||
Name string
|
||||
ProxyType string
|
||||
TrafficIn metric.DateCounter
|
||||
TrafficOut metric.DateCounter
|
||||
CurConns metric.Counter
|
||||
LastStartTime time.Time
|
||||
LastCloseTime time.Time
|
||||
}
|
||||
|
||||
type ServerStatistics struct {
|
||||
TotalTrafficIn metric.DateCounter
|
||||
TotalTrafficOut metric.DateCounter
|
||||
CurConns metric.Counter
|
||||
|
||||
// counter for clients
|
||||
ClientCounts metric.Counter
|
||||
|
||||
// counter for proxy types
|
||||
ProxyTypeCounts map[string]metric.Counter
|
||||
|
||||
// statistics for different proxies
|
||||
// key is proxy name
|
||||
ProxyStatistics map[string]*ProxyStatistics
|
||||
}
|
||||
|
||||
type Collector interface {
|
||||
GetServer() *ServerStats
|
||||
GetProxiesByType(proxyType string) []*ProxyStats
|
||||
GetProxiesByTypeAndName(proxyType string, proxyName string) *ProxyStats
|
||||
GetProxyTraffic(name string) *ProxyTrafficInfo
|
||||
}
|
||||
8
pkg/metrics/metrics.go
Normal file
8
pkg/metrics/metrics.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"github.com/fatedier/frp/pkg/metrics/aggregate"
|
||||
)
|
||||
|
||||
var EnableMem = aggregate.EnableMem
|
||||
var EnablePrometheus = aggregate.EnablePrometheus
|
||||
95
pkg/metrics/prometheus/server.go
Normal file
95
pkg/metrics/prometheus/server.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package prometheus
|
||||
|
||||
import (
|
||||
"github.com/fatedier/frp/server/metrics"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
const (
|
||||
namespace = "frp"
|
||||
serverSubsystem = "server"
|
||||
)
|
||||
|
||||
var ServerMetrics metrics.ServerMetrics = newServerMetrics()
|
||||
|
||||
type serverMetrics struct {
|
||||
clientCount prometheus.Gauge
|
||||
proxyCount *prometheus.GaugeVec
|
||||
connectionCount *prometheus.GaugeVec
|
||||
trafficIn *prometheus.CounterVec
|
||||
trafficOut *prometheus.CounterVec
|
||||
}
|
||||
|
||||
func (m *serverMetrics) NewClient() {
|
||||
m.clientCount.Inc()
|
||||
}
|
||||
|
||||
func (m *serverMetrics) CloseClient() {
|
||||
m.clientCount.Dec()
|
||||
}
|
||||
|
||||
func (m *serverMetrics) NewProxy(name string, proxyType string) {
|
||||
m.proxyCount.WithLabelValues(proxyType).Inc()
|
||||
}
|
||||
|
||||
func (m *serverMetrics) CloseProxy(name string, proxyType string) {
|
||||
m.proxyCount.WithLabelValues(proxyType).Dec()
|
||||
}
|
||||
|
||||
func (m *serverMetrics) OpenConnection(name string, proxyType string) {
|
||||
m.connectionCount.WithLabelValues(name, proxyType).Inc()
|
||||
}
|
||||
|
||||
func (m *serverMetrics) CloseConnection(name string, proxyType string) {
|
||||
m.connectionCount.WithLabelValues(name, proxyType).Dec()
|
||||
}
|
||||
|
||||
func (m *serverMetrics) AddTrafficIn(name string, proxyType string, trafficBytes int64) {
|
||||
m.trafficIn.WithLabelValues(name, proxyType).Add(float64(trafficBytes))
|
||||
}
|
||||
|
||||
func (m *serverMetrics) AddTrafficOut(name string, proxyType string, trafficBytes int64) {
|
||||
m.trafficOut.WithLabelValues(name, proxyType).Add(float64(trafficBytes))
|
||||
}
|
||||
|
||||
func newServerMetrics() *serverMetrics {
|
||||
m := &serverMetrics{
|
||||
clientCount: prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
Namespace: namespace,
|
||||
Subsystem: serverSubsystem,
|
||||
Name: "client_counts",
|
||||
Help: "The current client counts of frps",
|
||||
}),
|
||||
proxyCount: prometheus.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Namespace: namespace,
|
||||
Subsystem: serverSubsystem,
|
||||
Name: "proxy_counts",
|
||||
Help: "The current proxy counts",
|
||||
}, []string{"type"}),
|
||||
connectionCount: prometheus.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Namespace: namespace,
|
||||
Subsystem: serverSubsystem,
|
||||
Name: "connection_counts",
|
||||
Help: "The current connection counts",
|
||||
}, []string{"name", "type"}),
|
||||
trafficIn: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: namespace,
|
||||
Subsystem: serverSubsystem,
|
||||
Name: "traffic_in",
|
||||
Help: "The total in traffic",
|
||||
}, []string{"name", "type"}),
|
||||
trafficOut: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: namespace,
|
||||
Subsystem: serverSubsystem,
|
||||
Name: "traffic_out",
|
||||
Help: "The total out traffic",
|
||||
}, []string{"name", "type"}),
|
||||
}
|
||||
prometheus.MustRegister(m.clientCount)
|
||||
prometheus.MustRegister(m.proxyCount)
|
||||
prometheus.MustRegister(m.connectionCount)
|
||||
prometheus.MustRegister(m.trafficIn)
|
||||
prometheus.MustRegister(m.trafficOut)
|
||||
return m
|
||||
}
|
||||
46
pkg/msg/ctl.go
Normal file
46
pkg/msg/ctl.go
Normal file
@@ -0,0 +1,46 @@
|
||||
// Copyright 2018 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msg
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
jsonMsg "github.com/fatedier/golib/msg/json"
|
||||
)
|
||||
|
||||
type Message = jsonMsg.Message
|
||||
|
||||
var (
|
||||
msgCtl *jsonMsg.MsgCtl
|
||||
)
|
||||
|
||||
func init() {
|
||||
msgCtl = jsonMsg.NewMsgCtl()
|
||||
for typeByte, msg := range msgTypeMap {
|
||||
msgCtl.RegisterMsg(typeByte, msg)
|
||||
}
|
||||
}
|
||||
|
||||
func ReadMsg(c io.Reader) (msg Message, err error) {
|
||||
return msgCtl.ReadMsg(c)
|
||||
}
|
||||
|
||||
func ReadMsgInto(c io.Reader, msg Message) (err error) {
|
||||
return msgCtl.ReadMsgInto(c, msg)
|
||||
}
|
||||
|
||||
func WriteMsg(c io.Writer, msg interface{}) (err error) {
|
||||
return msgCtl.WriteMsg(c, msg)
|
||||
}
|
||||
194
pkg/msg/msg.go
Normal file
194
pkg/msg/msg.go
Normal file
@@ -0,0 +1,194 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msg
|
||||
|
||||
import "net"
|
||||
|
||||
const (
|
||||
TypeLogin = 'o'
|
||||
TypeLoginResp = '1'
|
||||
TypeNewProxy = 'p'
|
||||
TypeNewProxyResp = '2'
|
||||
TypeCloseProxy = 'c'
|
||||
TypeNewWorkConn = 'w'
|
||||
TypeReqWorkConn = 'r'
|
||||
TypeStartWorkConn = 's'
|
||||
TypeNewVisitorConn = 'v'
|
||||
TypeNewVisitorConnResp = '3'
|
||||
TypePing = 'h'
|
||||
TypePong = '4'
|
||||
TypeUDPPacket = 'u'
|
||||
TypeNatHoleVisitor = 'i'
|
||||
TypeNatHoleClient = 'n'
|
||||
TypeNatHoleResp = 'm'
|
||||
TypeNatHoleClientDetectOK = 'd'
|
||||
TypeNatHoleSid = '5'
|
||||
)
|
||||
|
||||
var (
|
||||
msgTypeMap = map[byte]interface{}{
|
||||
TypeLogin: Login{},
|
||||
TypeLoginResp: LoginResp{},
|
||||
TypeNewProxy: NewProxy{},
|
||||
TypeNewProxyResp: NewProxyResp{},
|
||||
TypeCloseProxy: CloseProxy{},
|
||||
TypeNewWorkConn: NewWorkConn{},
|
||||
TypeReqWorkConn: ReqWorkConn{},
|
||||
TypeStartWorkConn: StartWorkConn{},
|
||||
TypeNewVisitorConn: NewVisitorConn{},
|
||||
TypeNewVisitorConnResp: NewVisitorConnResp{},
|
||||
TypePing: Ping{},
|
||||
TypePong: Pong{},
|
||||
TypeUDPPacket: UDPPacket{},
|
||||
TypeNatHoleVisitor: NatHoleVisitor{},
|
||||
TypeNatHoleClient: NatHoleClient{},
|
||||
TypeNatHoleResp: NatHoleResp{},
|
||||
TypeNatHoleClientDetectOK: NatHoleClientDetectOK{},
|
||||
TypeNatHoleSid: NatHoleSid{},
|
||||
}
|
||||
)
|
||||
|
||||
// When frpc start, client send this message to login to server.
|
||||
type Login struct {
|
||||
Version string `json:"version"`
|
||||
Hostname string `json:"hostname"`
|
||||
Os string `json:"os"`
|
||||
Arch string `json:"arch"`
|
||||
User string `json:"user"`
|
||||
PrivilegeKey string `json:"privilege_key"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
RunID string `json:"run_id"`
|
||||
Metas map[string]string `json:"metas"`
|
||||
|
||||
// Some global configures.
|
||||
PoolCount int `json:"pool_count"`
|
||||
}
|
||||
|
||||
type LoginResp struct {
|
||||
Version string `json:"version"`
|
||||
RunID string `json:"run_id"`
|
||||
ServerUDPPort int `json:"server_udp_port"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
// When frpc login success, send this message to frps for running a new proxy.
|
||||
type NewProxy struct {
|
||||
ProxyName string `json:"proxy_name"`
|
||||
ProxyType string `json:"proxy_type"`
|
||||
UseEncryption bool `json:"use_encryption"`
|
||||
UseCompression bool `json:"use_compression"`
|
||||
Group string `json:"group"`
|
||||
GroupKey string `json:"group_key"`
|
||||
Metas map[string]string `json:"metas"`
|
||||
|
||||
// tcp and udp only
|
||||
RemotePort int `json:"remote_port"`
|
||||
|
||||
// http and https only
|
||||
CustomDomains []string `json:"custom_domains"`
|
||||
SubDomain string `json:"subdomain"`
|
||||
Locations []string `json:"locations"`
|
||||
HTTPUser string `json:"http_user"`
|
||||
HTTPPwd string `json:"http_pwd"`
|
||||
HostHeaderRewrite string `json:"host_header_rewrite"`
|
||||
Headers map[string]string `json:"headers"`
|
||||
|
||||
// stcp
|
||||
Sk string `json:"sk"`
|
||||
|
||||
// tcpmux
|
||||
Multiplexer string `json:"multiplexer"`
|
||||
}
|
||||
|
||||
type NewProxyResp struct {
|
||||
ProxyName string `json:"proxy_name"`
|
||||
RemoteAddr string `json:"remote_addr"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
type CloseProxy struct {
|
||||
ProxyName string `json:"proxy_name"`
|
||||
}
|
||||
|
||||
type NewWorkConn struct {
|
||||
RunID string `json:"run_id"`
|
||||
PrivilegeKey string `json:"privilege_key"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
}
|
||||
|
||||
type ReqWorkConn struct {
|
||||
}
|
||||
|
||||
type StartWorkConn struct {
|
||||
ProxyName string `json:"proxy_name"`
|
||||
SrcAddr string `json:"src_addr"`
|
||||
DstAddr string `json:"dst_addr"`
|
||||
SrcPort uint16 `json:"src_port"`
|
||||
DstPort uint16 `json:"dst_port"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
type NewVisitorConn struct {
|
||||
ProxyName string `json:"proxy_name"`
|
||||
SignKey string `json:"sign_key"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
UseEncryption bool `json:"use_encryption"`
|
||||
UseCompression bool `json:"use_compression"`
|
||||
}
|
||||
|
||||
type NewVisitorConnResp struct {
|
||||
ProxyName string `json:"proxy_name"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
type Ping struct {
|
||||
PrivilegeKey string `json:"privilege_key"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
}
|
||||
|
||||
type Pong struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
type UDPPacket struct {
|
||||
Content string `json:"c"`
|
||||
LocalAddr *net.UDPAddr `json:"l"`
|
||||
RemoteAddr *net.UDPAddr `json:"r"`
|
||||
}
|
||||
|
||||
type NatHoleVisitor struct {
|
||||
ProxyName string `json:"proxy_name"`
|
||||
SignKey string `json:"sign_key"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
}
|
||||
|
||||
type NatHoleClient struct {
|
||||
ProxyName string `json:"proxy_name"`
|
||||
Sid string `json:"sid"`
|
||||
}
|
||||
|
||||
type NatHoleResp struct {
|
||||
Sid string `json:"sid"`
|
||||
VisitorAddr string `json:"visitor_addr"`
|
||||
ClientAddr string `json:"client_addr"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
type NatHoleClientDetectOK struct {
|
||||
}
|
||||
|
||||
type NatHoleSid struct {
|
||||
Sid string `json:"sid"`
|
||||
}
|
||||
212
pkg/nathole/nathole.go
Normal file
212
pkg/nathole/nathole.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package nathole
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/util/log"
|
||||
"github.com/fatedier/frp/pkg/util/util"
|
||||
|
||||
"github.com/fatedier/golib/errors"
|
||||
"github.com/fatedier/golib/pool"
|
||||
)
|
||||
|
||||
// Timeout seconds.
|
||||
var NatHoleTimeout int64 = 10
|
||||
|
||||
type SidRequest struct {
|
||||
Sid string
|
||||
NotifyCh chan struct{}
|
||||
}
|
||||
|
||||
type Controller struct {
|
||||
listener *net.UDPConn
|
||||
|
||||
clientCfgs map[string]*ClientCfg
|
||||
sessions map[string]*Session
|
||||
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewController(udpBindAddr string) (nc *Controller, err error) {
|
||||
addr, err := net.ResolveUDPAddr("udp", udpBindAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lconn, err := net.ListenUDP("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nc = &Controller{
|
||||
listener: lconn,
|
||||
clientCfgs: make(map[string]*ClientCfg),
|
||||
sessions: make(map[string]*Session),
|
||||
}
|
||||
return nc, nil
|
||||
}
|
||||
|
||||
func (nc *Controller) ListenClient(name string, sk string) (sidCh chan *SidRequest) {
|
||||
clientCfg := &ClientCfg{
|
||||
Name: name,
|
||||
Sk: sk,
|
||||
SidCh: make(chan *SidRequest),
|
||||
}
|
||||
nc.mu.Lock()
|
||||
nc.clientCfgs[name] = clientCfg
|
||||
nc.mu.Unlock()
|
||||
return clientCfg.SidCh
|
||||
}
|
||||
|
||||
func (nc *Controller) CloseClient(name string) {
|
||||
nc.mu.Lock()
|
||||
defer nc.mu.Unlock()
|
||||
delete(nc.clientCfgs, name)
|
||||
}
|
||||
|
||||
func (nc *Controller) Run() {
|
||||
for {
|
||||
buf := pool.GetBuf(1024)
|
||||
n, raddr, err := nc.listener.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
log.Trace("nat hole listener read from udp error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
rd := bytes.NewReader(buf[:n])
|
||||
rawMsg, err := msg.ReadMsg(rd)
|
||||
if err != nil {
|
||||
log.Trace("read nat hole message error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
switch m := rawMsg.(type) {
|
||||
case *msg.NatHoleVisitor:
|
||||
go nc.HandleVisitor(m, raddr)
|
||||
case *msg.NatHoleClient:
|
||||
go nc.HandleClient(m, raddr)
|
||||
default:
|
||||
log.Trace("error nat hole message type")
|
||||
continue
|
||||
}
|
||||
pool.PutBuf(buf)
|
||||
}
|
||||
}
|
||||
|
||||
func (nc *Controller) GenSid() string {
|
||||
t := time.Now().Unix()
|
||||
id, _ := util.RandID()
|
||||
return fmt.Sprintf("%d%s", t, id)
|
||||
}
|
||||
|
||||
func (nc *Controller) HandleVisitor(m *msg.NatHoleVisitor, raddr *net.UDPAddr) {
|
||||
sid := nc.GenSid()
|
||||
session := &Session{
|
||||
Sid: sid,
|
||||
VisitorAddr: raddr,
|
||||
NotifyCh: make(chan struct{}, 0),
|
||||
}
|
||||
nc.mu.Lock()
|
||||
clientCfg, ok := nc.clientCfgs[m.ProxyName]
|
||||
if !ok {
|
||||
nc.mu.Unlock()
|
||||
errInfo := fmt.Sprintf("xtcp server for [%s] doesn't exist", m.ProxyName)
|
||||
log.Debug(errInfo)
|
||||
nc.listener.WriteToUDP(nc.GenNatHoleResponse(nil, errInfo), raddr)
|
||||
return
|
||||
}
|
||||
if m.SignKey != util.GetAuthKey(clientCfg.Sk, m.Timestamp) {
|
||||
nc.mu.Unlock()
|
||||
errInfo := fmt.Sprintf("xtcp connection of [%s] auth failed", m.ProxyName)
|
||||
log.Debug(errInfo)
|
||||
nc.listener.WriteToUDP(nc.GenNatHoleResponse(nil, errInfo), raddr)
|
||||
return
|
||||
}
|
||||
|
||||
nc.sessions[sid] = session
|
||||
nc.mu.Unlock()
|
||||
log.Trace("handle visitor message, sid [%s]", sid)
|
||||
|
||||
defer func() {
|
||||
nc.mu.Lock()
|
||||
delete(nc.sessions, sid)
|
||||
nc.mu.Unlock()
|
||||
}()
|
||||
|
||||
err := errors.PanicToError(func() {
|
||||
clientCfg.SidCh <- &SidRequest{
|
||||
Sid: sid,
|
||||
NotifyCh: session.NotifyCh,
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Wait client connections.
|
||||
select {
|
||||
case <-session.NotifyCh:
|
||||
resp := nc.GenNatHoleResponse(session, "")
|
||||
log.Trace("send nat hole response to visitor")
|
||||
nc.listener.WriteToUDP(resp, raddr)
|
||||
case <-time.After(time.Duration(NatHoleTimeout) * time.Second):
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (nc *Controller) HandleClient(m *msg.NatHoleClient, raddr *net.UDPAddr) {
|
||||
nc.mu.RLock()
|
||||
session, ok := nc.sessions[m.Sid]
|
||||
nc.mu.RUnlock()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
log.Trace("handle client message, sid [%s]", session.Sid)
|
||||
session.ClientAddr = raddr
|
||||
|
||||
resp := nc.GenNatHoleResponse(session, "")
|
||||
log.Trace("send nat hole response to client")
|
||||
nc.listener.WriteToUDP(resp, raddr)
|
||||
}
|
||||
|
||||
func (nc *Controller) GenNatHoleResponse(session *Session, errInfo string) []byte {
|
||||
var (
|
||||
sid string
|
||||
visitorAddr string
|
||||
clientAddr string
|
||||
)
|
||||
if session != nil {
|
||||
sid = session.Sid
|
||||
visitorAddr = session.VisitorAddr.String()
|
||||
clientAddr = session.ClientAddr.String()
|
||||
}
|
||||
m := &msg.NatHoleResp{
|
||||
Sid: sid,
|
||||
VisitorAddr: visitorAddr,
|
||||
ClientAddr: clientAddr,
|
||||
Error: errInfo,
|
||||
}
|
||||
b := bytes.NewBuffer(nil)
|
||||
err := msg.WriteMsg(b, m)
|
||||
if err != nil {
|
||||
return []byte("")
|
||||
}
|
||||
return b.Bytes()
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
Sid string
|
||||
VisitorAddr *net.UDPAddr
|
||||
ClientAddr *net.UDPAddr
|
||||
|
||||
NotifyCh chan struct{}
|
||||
}
|
||||
|
||||
type ClientCfg struct {
|
||||
Name string
|
||||
Sk string
|
||||
SidCh chan *SidRequest
|
||||
}
|
||||
111
pkg/plugin/client/http2https.go
Normal file
111
pkg/plugin/client/http2https.go
Normal file
@@ -0,0 +1,111 @@
|
||||
// Copyright 2019 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"strings"
|
||||
|
||||
frpNet "github.com/fatedier/frp/pkg/util/net"
|
||||
)
|
||||
|
||||
const PluginHTTP2HTTPS = "http2https"
|
||||
|
||||
func init() {
|
||||
Register(PluginHTTP2HTTPS, NewHTTP2HTTPSPlugin)
|
||||
}
|
||||
|
||||
type HTTP2HTTPSPlugin struct {
|
||||
hostHeaderRewrite string
|
||||
localAddr string
|
||||
headers map[string]string
|
||||
|
||||
l *Listener
|
||||
s *http.Server
|
||||
}
|
||||
|
||||
func NewHTTP2HTTPSPlugin(params map[string]string) (Plugin, error) {
|
||||
localAddr := params["plugin_local_addr"]
|
||||
hostHeaderRewrite := params["plugin_host_header_rewrite"]
|
||||
headers := make(map[string]string)
|
||||
for k, v := range params {
|
||||
if !strings.HasPrefix(k, "plugin_header_") {
|
||||
continue
|
||||
}
|
||||
if k = strings.TrimPrefix(k, "plugin_header_"); k != "" {
|
||||
headers[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
if localAddr == "" {
|
||||
return nil, fmt.Errorf("plugin_local_addr is required")
|
||||
}
|
||||
|
||||
listener := NewProxyListener()
|
||||
|
||||
p := &HTTPS2HTTPPlugin{
|
||||
localAddr: localAddr,
|
||||
hostHeaderRewrite: hostHeaderRewrite,
|
||||
headers: headers,
|
||||
l: listener,
|
||||
}
|
||||
|
||||
tr := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
|
||||
rp := &httputil.ReverseProxy{
|
||||
Director: func(req *http.Request) {
|
||||
req.URL.Scheme = "https"
|
||||
req.URL.Host = p.localAddr
|
||||
if p.hostHeaderRewrite != "" {
|
||||
req.Host = p.hostHeaderRewrite
|
||||
}
|
||||
for k, v := range p.headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
},
|
||||
Transport: tr,
|
||||
}
|
||||
|
||||
p.s = &http.Server{
|
||||
Handler: rp,
|
||||
}
|
||||
|
||||
go p.s.Serve(listener)
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (p *HTTP2HTTPSPlugin) Handle(conn io.ReadWriteCloser, realConn net.Conn, extraBufToLocal []byte) {
|
||||
wrapConn := frpNet.WrapReadWriteCloserToConn(conn, realConn)
|
||||
p.l.PutConn(wrapConn)
|
||||
}
|
||||
|
||||
func (p *HTTP2HTTPSPlugin) Name() string {
|
||||
return PluginHTTP2HTTPS
|
||||
}
|
||||
|
||||
func (p *HTTP2HTTPSPlugin) Close() error {
|
||||
if err := p.s.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
244
pkg/plugin/client/http_proxy.go
Normal file
244
pkg/plugin/client/http_proxy.go
Normal file
@@ -0,0 +1,244 @@
|
||||
// Copyright 2017 frp team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
frpNet "github.com/fatedier/frp/pkg/util/net"
|
||||
|
||||
frpIo "github.com/fatedier/golib/io"
|
||||
gnet "github.com/fatedier/golib/net"
|
||||
)
|
||||
|
||||
const PluginHTTPProxy = "http_proxy"
|
||||
|
||||
func init() {
|
||||
Register(PluginHTTPProxy, NewHTTPProxyPlugin)
|
||||
}
|
||||
|
||||
type HTTPProxy struct {
|
||||
l *Listener
|
||||
s *http.Server
|
||||
AuthUser string
|
||||
AuthPasswd string
|
||||
}
|
||||
|
||||
func NewHTTPProxyPlugin(params map[string]string) (Plugin, error) {
|
||||
user := params["plugin_http_user"]
|
||||
passwd := params["plugin_http_passwd"]
|
||||
listener := NewProxyListener()
|
||||
|
||||
hp := &HTTPProxy{
|
||||
l: listener,
|
||||
AuthUser: user,
|
||||
AuthPasswd: passwd,
|
||||
}
|
||||
|
||||
hp.s = &http.Server{
|
||||
Handler: hp,
|
||||
}
|
||||
|
||||
go hp.s.Serve(listener)
|
||||
return hp, nil
|
||||
}
|
||||
|
||||
func (hp *HTTPProxy) Name() string {
|
||||
return PluginHTTPProxy
|
||||
}
|
||||
|
||||
func (hp *HTTPProxy) Handle(conn io.ReadWriteCloser, realConn net.Conn, extraBufToLocal []byte) {
|
||||
wrapConn := frpNet.WrapReadWriteCloserToConn(conn, realConn)
|
||||
|
||||
sc, rd := gnet.NewSharedConn(wrapConn)
|
||||
firstBytes := make([]byte, 7)
|
||||
_, err := rd.Read(firstBytes)
|
||||
if err != nil {
|
||||
wrapConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if strings.ToUpper(string(firstBytes)) == "CONNECT" {
|
||||
bufRd := bufio.NewReader(sc)
|
||||
request, err := http.ReadRequest(bufRd)
|
||||
if err != nil {
|
||||
wrapConn.Close()
|
||||
return
|
||||
}
|
||||
hp.handleConnectReq(request, frpIo.WrapReadWriteCloser(bufRd, wrapConn, wrapConn.Close))
|
||||
return
|
||||
}
|
||||
|
||||
hp.l.PutConn(sc)
|
||||
return
|
||||
}
|
||||
|
||||
func (hp *HTTPProxy) Close() error {
|
||||
hp.s.Close()
|
||||
hp.l.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hp *HTTPProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
if ok := hp.Auth(req); !ok {
|
||||
rw.Header().Set("Proxy-Authenticate", "Basic")
|
||||
rw.WriteHeader(http.StatusProxyAuthRequired)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Method == http.MethodConnect {
|
||||
// deprecated
|
||||
// Connect request is handled in Handle function.
|
||||
hp.ConnectHandler(rw, req)
|
||||
} else {
|
||||
hp.HTTPHandler(rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
func (hp *HTTPProxy) HTTPHandler(rw http.ResponseWriter, req *http.Request) {
|
||||
removeProxyHeaders(req)
|
||||
|
||||
resp, err := http.DefaultTransport.RoundTrip(req)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
copyHeaders(rw.Header(), resp.Header)
|
||||
rw.WriteHeader(resp.StatusCode)
|
||||
|
||||
_, err = io.Copy(rw, resp.Body)
|
||||
if err != nil && err != io.EOF {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// deprecated
|
||||
// Hijack needs to SetReadDeadline on the Conn of the request, but if we use stream compression here,
|
||||
// we may always get i/o timeout error.
|
||||
func (hp *HTTPProxy) ConnectHandler(rw http.ResponseWriter, req *http.Request) {
|
||||
hj, ok := rw.(http.Hijacker)
|
||||
if !ok {
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
client, _, err := hj.Hijack()
|
||||
if err != nil {
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
remote, err := net.Dial("tcp", req.URL.Host)
|
||||
if err != nil {
|
||||
http.Error(rw, "Failed", http.StatusBadRequest)
|
||||
client.Close()
|
||||
return
|
||||
}
|
||||
client.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
|
||||
|
||||
go frpIo.Join(remote, client)
|
||||
}
|
||||
|
||||
func (hp *HTTPProxy) Auth(req *http.Request) bool {
|
||||
if hp.AuthUser == "" && hp.AuthPasswd == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
s := strings.SplitN(req.Header.Get("Proxy-Authorization"), " ", 2)
|
||||
if len(s) != 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
b, err := base64.StdEncoding.DecodeString(s[1])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
pair := strings.SplitN(string(b), ":", 2)
|
||||
if len(pair) != 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
if pair[0] != hp.AuthUser || pair[1] != hp.AuthPasswd {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (hp *HTTPProxy) handleConnectReq(req *http.Request, rwc io.ReadWriteCloser) {
|
||||
defer rwc.Close()
|
||||
if ok := hp.Auth(req); !ok {
|
||||
res := getBadResponse()
|
||||
res.Write(rwc)
|
||||
return
|
||||
}
|
||||
|
||||
remote, err := net.Dial("tcp", req.URL.Host)
|
||||
if err != nil {
|
||||
res := &http.Response{
|
||||
StatusCode: 400,
|
||||
Proto: "HTTP/1.1",
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
}
|
||||
res.Write(rwc)
|
||||
return
|
||||
}
|
||||
rwc.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
|
||||
|
||||
frpIo.Join(remote, rwc)
|
||||
}
|
||||
|
||||
func copyHeaders(dst, src http.Header) {
|
||||
for key, values := range src {
|
||||
for _, value := range values {
|
||||
dst.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func removeProxyHeaders(req *http.Request) {
|
||||
req.RequestURI = ""
|
||||
req.Header.Del("Proxy-Connection")
|
||||
req.Header.Del("Connection")
|
||||
req.Header.Del("Proxy-Authenticate")
|
||||
req.Header.Del("Proxy-Authorization")
|
||||
req.Header.Del("TE")
|
||||
req.Header.Del("Trailers")
|
||||
req.Header.Del("Transfer-Encoding")
|
||||
req.Header.Del("Upgrade")
|
||||
}
|
||||
|
||||
func getBadResponse() *http.Response {
|
||||
header := make(map[string][]string)
|
||||
header["Proxy-Authenticate"] = []string{"Basic"}
|
||||
header["Connection"] = []string{"close"}
|
||||
res := &http.Response{
|
||||
Status: "407 Not authorized",
|
||||
StatusCode: 407,
|
||||
Proto: "HTTP/1.1",
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Header: header,
|
||||
}
|
||||
return res
|
||||
}
|
||||
133
pkg/plugin/client/https2http.go
Normal file
133
pkg/plugin/client/https2http.go
Normal file
@@ -0,0 +1,133 @@
|
||||
// Copyright 2019 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"strings"
|
||||
|
||||
frpNet "github.com/fatedier/frp/pkg/util/net"
|
||||
)
|
||||
|
||||
const PluginHTTPS2HTTP = "https2http"
|
||||
|
||||
func init() {
|
||||
Register(PluginHTTPS2HTTP, NewHTTPS2HTTPPlugin)
|
||||
}
|
||||
|
||||
type HTTPS2HTTPPlugin struct {
|
||||
crtPath string
|
||||
keyPath string
|
||||
hostHeaderRewrite string
|
||||
localAddr string
|
||||
headers map[string]string
|
||||
|
||||
l *Listener
|
||||
s *http.Server
|
||||
}
|
||||
|
||||
func NewHTTPS2HTTPPlugin(params map[string]string) (Plugin, error) {
|
||||
crtPath := params["plugin_crt_path"]
|
||||
keyPath := params["plugin_key_path"]
|
||||
localAddr := params["plugin_local_addr"]
|
||||
hostHeaderRewrite := params["plugin_host_header_rewrite"]
|
||||
headers := make(map[string]string)
|
||||
for k, v := range params {
|
||||
if !strings.HasPrefix(k, "plugin_header_") {
|
||||
continue
|
||||
}
|
||||
if k = strings.TrimPrefix(k, "plugin_header_"); k != "" {
|
||||
headers[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
if crtPath == "" {
|
||||
return nil, fmt.Errorf("plugin_crt_path is required")
|
||||
}
|
||||
if keyPath == "" {
|
||||
return nil, fmt.Errorf("plugin_key_path is required")
|
||||
}
|
||||
if localAddr == "" {
|
||||
return nil, fmt.Errorf("plugin_local_addr is required")
|
||||
}
|
||||
|
||||
listener := NewProxyListener()
|
||||
|
||||
p := &HTTPS2HTTPPlugin{
|
||||
crtPath: crtPath,
|
||||
keyPath: keyPath,
|
||||
localAddr: localAddr,
|
||||
hostHeaderRewrite: hostHeaderRewrite,
|
||||
headers: headers,
|
||||
l: listener,
|
||||
}
|
||||
|
||||
rp := &httputil.ReverseProxy{
|
||||
Director: func(req *http.Request) {
|
||||
req.URL.Scheme = "http"
|
||||
req.URL.Host = p.localAddr
|
||||
if p.hostHeaderRewrite != "" {
|
||||
req.Host = p.hostHeaderRewrite
|
||||
}
|
||||
for k, v := range p.headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
p.s = &http.Server{
|
||||
Handler: rp,
|
||||
}
|
||||
|
||||
tlsConfig, err := p.genTLSConfig()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gen TLS config error: %v", err)
|
||||
}
|
||||
ln := tls.NewListener(listener, tlsConfig)
|
||||
|
||||
go p.s.Serve(ln)
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (p *HTTPS2HTTPPlugin) genTLSConfig() (*tls.Config, error) {
|
||||
cert, err := tls.LoadX509KeyPair(p.crtPath, p.keyPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config := &tls.Config{Certificates: []tls.Certificate{cert}}
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func (p *HTTPS2HTTPPlugin) Handle(conn io.ReadWriteCloser, realConn net.Conn, extraBufToLocal []byte) {
|
||||
wrapConn := frpNet.WrapReadWriteCloserToConn(conn, realConn)
|
||||
p.l.PutConn(wrapConn)
|
||||
}
|
||||
|
||||
func (p *HTTPS2HTTPPlugin) Name() string {
|
||||
return PluginHTTPS2HTTP
|
||||
}
|
||||
|
||||
func (p *HTTPS2HTTPPlugin) Close() error {
|
||||
if err := p.s.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
92
pkg/plugin/client/plugin.go
Normal file
92
pkg/plugin/client/plugin.go
Normal file
@@ -0,0 +1,92 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/fatedier/golib/errors"
|
||||
)
|
||||
|
||||
// Creators is used for create plugins to handle connections.
|
||||
var creators = make(map[string]CreatorFn)
|
||||
|
||||
// params has prefix "plugin_"
|
||||
type CreatorFn func(params map[string]string) (Plugin, error)
|
||||
|
||||
func Register(name string, fn CreatorFn) {
|
||||
creators[name] = fn
|
||||
}
|
||||
|
||||
func Create(name string, params map[string]string) (p Plugin, err error) {
|
||||
if fn, ok := creators[name]; ok {
|
||||
p, err = fn(params)
|
||||
} else {
|
||||
err = fmt.Errorf("plugin [%s] is not registered", name)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type Plugin interface {
|
||||
Name() string
|
||||
|
||||
// extraBufToLocal will send to local connection first, then join conn with local connection
|
||||
Handle(conn io.ReadWriteCloser, realConn net.Conn, extraBufToLocal []byte)
|
||||
Close() error
|
||||
}
|
||||
|
||||
type Listener struct {
|
||||
conns chan net.Conn
|
||||
closed bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewProxyListener() *Listener {
|
||||
return &Listener{
|
||||
conns: make(chan net.Conn, 64),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) Accept() (net.Conn, error) {
|
||||
conn, ok := <-l.conns
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("listener closed")
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (l *Listener) PutConn(conn net.Conn) error {
|
||||
err := errors.PanicToError(func() {
|
||||
l.conns <- conn
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (l *Listener) Close() error {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
if !l.closed {
|
||||
close(l.conns)
|
||||
l.closed = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Listener) Addr() net.Addr {
|
||||
return (*net.TCPAddr)(nil)
|
||||
}
|
||||
69
pkg/plugin/client/socks5.go
Normal file
69
pkg/plugin/client/socks5.go
Normal file
@@ -0,0 +1,69 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
|
||||
frpNet "github.com/fatedier/frp/pkg/util/net"
|
||||
|
||||
gosocks5 "github.com/armon/go-socks5"
|
||||
)
|
||||
|
||||
const PluginSocks5 = "socks5"
|
||||
|
||||
func init() {
|
||||
Register(PluginSocks5, NewSocks5Plugin)
|
||||
}
|
||||
|
||||
type Socks5Plugin struct {
|
||||
Server *gosocks5.Server
|
||||
|
||||
user string
|
||||
passwd string
|
||||
}
|
||||
|
||||
func NewSocks5Plugin(params map[string]string) (p Plugin, err error) {
|
||||
user := params["plugin_user"]
|
||||
passwd := params["plugin_passwd"]
|
||||
|
||||
cfg := &gosocks5.Config{
|
||||
Logger: log.New(ioutil.Discard, "", log.LstdFlags),
|
||||
}
|
||||
if user != "" || passwd != "" {
|
||||
cfg.Credentials = gosocks5.StaticCredentials(map[string]string{user: passwd})
|
||||
}
|
||||
sp := &Socks5Plugin{}
|
||||
sp.Server, err = gosocks5.New(cfg)
|
||||
p = sp
|
||||
return
|
||||
}
|
||||
|
||||
func (sp *Socks5Plugin) Handle(conn io.ReadWriteCloser, realConn net.Conn, extraBufToLocal []byte) {
|
||||
defer conn.Close()
|
||||
wrapConn := frpNet.WrapReadWriteCloserToConn(conn, realConn)
|
||||
sp.Server.ServeConn(wrapConn)
|
||||
}
|
||||
|
||||
func (sp *Socks5Plugin) Name() string {
|
||||
return PluginSocks5
|
||||
}
|
||||
|
||||
func (sp *Socks5Plugin) Close() error {
|
||||
return nil
|
||||
}
|
||||
89
pkg/plugin/client/static_file.go
Normal file
89
pkg/plugin/client/static_file.go
Normal file
@@ -0,0 +1,89 @@
|
||||
// Copyright 2018 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
frpNet "github.com/fatedier/frp/pkg/util/net"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
const PluginStaticFile = "static_file"
|
||||
|
||||
func init() {
|
||||
Register(PluginStaticFile, NewStaticFilePlugin)
|
||||
}
|
||||
|
||||
type StaticFilePlugin struct {
|
||||
localPath string
|
||||
stripPrefix string
|
||||
httpUser string
|
||||
httpPasswd string
|
||||
|
||||
l *Listener
|
||||
s *http.Server
|
||||
}
|
||||
|
||||
func NewStaticFilePlugin(params map[string]string) (Plugin, error) {
|
||||
localPath := params["plugin_local_path"]
|
||||
stripPrefix := params["plugin_strip_prefix"]
|
||||
httpUser := params["plugin_http_user"]
|
||||
httpPasswd := params["plugin_http_passwd"]
|
||||
|
||||
listener := NewProxyListener()
|
||||
|
||||
sp := &StaticFilePlugin{
|
||||
localPath: localPath,
|
||||
stripPrefix: stripPrefix,
|
||||
httpUser: httpUser,
|
||||
httpPasswd: httpPasswd,
|
||||
|
||||
l: listener,
|
||||
}
|
||||
var prefix string
|
||||
if stripPrefix != "" {
|
||||
prefix = "/" + stripPrefix + "/"
|
||||
} else {
|
||||
prefix = "/"
|
||||
}
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.Use(frpNet.NewHTTPAuthMiddleware(httpUser, httpPasswd).Middleware)
|
||||
router.PathPrefix(prefix).Handler(frpNet.MakeHTTPGzipHandler(http.StripPrefix(prefix, http.FileServer(http.Dir(localPath))))).Methods("GET")
|
||||
sp.s = &http.Server{
|
||||
Handler: router,
|
||||
}
|
||||
go sp.s.Serve(listener)
|
||||
return sp, nil
|
||||
}
|
||||
|
||||
func (sp *StaticFilePlugin) Handle(conn io.ReadWriteCloser, realConn net.Conn, extraBufToLocal []byte) {
|
||||
wrapConn := frpNet.WrapReadWriteCloserToConn(conn, realConn)
|
||||
sp.l.PutConn(wrapConn)
|
||||
}
|
||||
|
||||
func (sp *StaticFilePlugin) Name() string {
|
||||
return PluginStaticFile
|
||||
}
|
||||
|
||||
func (sp *StaticFilePlugin) Close() error {
|
||||
sp.s.Close()
|
||||
sp.l.Close()
|
||||
return nil
|
||||
}
|
||||
72
pkg/plugin/client/unix_domain_socket.go
Normal file
72
pkg/plugin/client/unix_domain_socket.go
Normal file
@@ -0,0 +1,72 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
frpIo "github.com/fatedier/golib/io"
|
||||
)
|
||||
|
||||
const PluginUnixDomainSocket = "unix_domain_socket"
|
||||
|
||||
func init() {
|
||||
Register(PluginUnixDomainSocket, NewUnixDomainSocketPlugin)
|
||||
}
|
||||
|
||||
type UnixDomainSocketPlugin struct {
|
||||
UnixAddr *net.UnixAddr
|
||||
}
|
||||
|
||||
func NewUnixDomainSocketPlugin(params map[string]string) (p Plugin, err error) {
|
||||
unixPath, ok := params["plugin_unix_path"]
|
||||
if !ok {
|
||||
err = fmt.Errorf("plugin_unix_path not found")
|
||||
return
|
||||
}
|
||||
|
||||
unixAddr, errRet := net.ResolveUnixAddr("unix", unixPath)
|
||||
if errRet != nil {
|
||||
err = errRet
|
||||
return
|
||||
}
|
||||
|
||||
p = &UnixDomainSocketPlugin{
|
||||
UnixAddr: unixAddr,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (uds *UnixDomainSocketPlugin) Handle(conn io.ReadWriteCloser, realConn net.Conn, extraBufToLocal []byte) {
|
||||
localConn, err := net.DialUnix("unix", nil, uds.UnixAddr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if len(extraBufToLocal) > 0 {
|
||||
localConn.Write(extraBufToLocal)
|
||||
}
|
||||
|
||||
frpIo.Join(localConn, conn)
|
||||
}
|
||||
|
||||
func (uds *UnixDomainSocketPlugin) Name() string {
|
||||
return PluginUnixDomainSocket
|
||||
}
|
||||
|
||||
func (uds *UnixDomainSocketPlugin) Close() error {
|
||||
return nil
|
||||
}
|
||||
109
pkg/plugin/server/http.go
Normal file
109
pkg/plugin/server/http.go
Normal file
@@ -0,0 +1,109 @@
|
||||
// Copyright 2019 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type HTTPPluginOptions struct {
|
||||
Name string
|
||||
Addr string
|
||||
Path string
|
||||
Ops []string
|
||||
}
|
||||
|
||||
type httpPlugin struct {
|
||||
options HTTPPluginOptions
|
||||
|
||||
url string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func NewHTTPPluginOptions(options HTTPPluginOptions) Plugin {
|
||||
return &httpPlugin{
|
||||
options: options,
|
||||
url: fmt.Sprintf("http://%s%s", options.Addr, options.Path),
|
||||
client: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *httpPlugin) Name() string {
|
||||
return p.options.Name
|
||||
}
|
||||
|
||||
func (p *httpPlugin) IsSupport(op string) bool {
|
||||
for _, v := range p.options.Ops {
|
||||
if v == op {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *httpPlugin) Handle(ctx context.Context, op string, content interface{}) (*Response, interface{}, error) {
|
||||
r := &Request{
|
||||
Version: APIVersion,
|
||||
Op: op,
|
||||
Content: content,
|
||||
}
|
||||
var res Response
|
||||
res.Content = reflect.New(reflect.TypeOf(content)).Interface()
|
||||
if err := p.do(ctx, r, &res); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return &res, res.Content, nil
|
||||
}
|
||||
|
||||
func (p *httpPlugin) do(ctx context.Context, r *Request, res *Response) error {
|
||||
buf, err := json.Marshal(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v := url.Values{}
|
||||
v.Set("version", r.Version)
|
||||
v.Set("op", r.Op)
|
||||
req, err := http.NewRequest("POST", p.url+"?"+v.Encode(), bytes.NewReader(buf))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req = req.WithContext(ctx)
|
||||
req.Header.Set("X-Frp-Reqid", GetReqidFromContext(ctx))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("do http request error code: %d", resp.StatusCode)
|
||||
}
|
||||
buf, err = ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = json.Unmarshal(buf, res); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
230
pkg/plugin/server/manager.go
Normal file
230
pkg/plugin/server/manager.go
Normal file
@@ -0,0 +1,230 @@
|
||||
// Copyright 2019 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/fatedier/frp/pkg/util/util"
|
||||
"github.com/fatedier/frp/pkg/util/xlog"
|
||||
)
|
||||
|
||||
type Manager struct {
|
||||
loginPlugins []Plugin
|
||||
newProxyPlugins []Plugin
|
||||
pingPlugins []Plugin
|
||||
newWorkConnPlugins []Plugin
|
||||
newUserConnPlugins []Plugin
|
||||
}
|
||||
|
||||
func NewManager() *Manager {
|
||||
return &Manager{
|
||||
loginPlugins: make([]Plugin, 0),
|
||||
newProxyPlugins: make([]Plugin, 0),
|
||||
pingPlugins: make([]Plugin, 0),
|
||||
newWorkConnPlugins: make([]Plugin, 0),
|
||||
newUserConnPlugins: make([]Plugin, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) Register(p Plugin) {
|
||||
if p.IsSupport(OpLogin) {
|
||||
m.loginPlugins = append(m.loginPlugins, p)
|
||||
}
|
||||
if p.IsSupport(OpNewProxy) {
|
||||
m.newProxyPlugins = append(m.newProxyPlugins, p)
|
||||
}
|
||||
if p.IsSupport(OpPing) {
|
||||
m.pingPlugins = append(m.pingPlugins, p)
|
||||
}
|
||||
if p.IsSupport(OpNewWorkConn) {
|
||||
m.newWorkConnPlugins = append(m.newWorkConnPlugins, p)
|
||||
}
|
||||
if p.IsSupport(OpNewUserConn) {
|
||||
m.newUserConnPlugins = append(m.newUserConnPlugins, p)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) Login(content *LoginContent) (*LoginContent, error) {
|
||||
if len(m.loginPlugins) == 0 {
|
||||
return content, nil
|
||||
}
|
||||
|
||||
var (
|
||||
res = &Response{
|
||||
Reject: false,
|
||||
Unchange: true,
|
||||
}
|
||||
retContent interface{}
|
||||
err error
|
||||
)
|
||||
reqid, _ := util.RandID()
|
||||
xl := xlog.New().AppendPrefix("reqid: " + reqid)
|
||||
ctx := xlog.NewContext(context.Background(), xl)
|
||||
ctx = NewReqidContext(ctx, reqid)
|
||||
|
||||
for _, p := range m.loginPlugins {
|
||||
res, retContent, err = p.Handle(ctx, OpLogin, *content)
|
||||
if err != nil {
|
||||
xl.Warn("send Login request to plugin [%s] error: %v", p.Name(), err)
|
||||
return nil, errors.New("send Login request to plugin error")
|
||||
}
|
||||
if res.Reject {
|
||||
return nil, fmt.Errorf("%s", res.RejectReason)
|
||||
}
|
||||
if !res.Unchange {
|
||||
content = retContent.(*LoginContent)
|
||||
}
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
|
||||
func (m *Manager) NewProxy(content *NewProxyContent) (*NewProxyContent, error) {
|
||||
if len(m.newProxyPlugins) == 0 {
|
||||
return content, nil
|
||||
}
|
||||
|
||||
var (
|
||||
res = &Response{
|
||||
Reject: false,
|
||||
Unchange: true,
|
||||
}
|
||||
retContent interface{}
|
||||
err error
|
||||
)
|
||||
reqid, _ := util.RandID()
|
||||
xl := xlog.New().AppendPrefix("reqid: " + reqid)
|
||||
ctx := xlog.NewContext(context.Background(), xl)
|
||||
ctx = NewReqidContext(ctx, reqid)
|
||||
|
||||
for _, p := range m.newProxyPlugins {
|
||||
res, retContent, err = p.Handle(ctx, OpNewProxy, *content)
|
||||
if err != nil {
|
||||
xl.Warn("send NewProxy request to plugin [%s] error: %v", p.Name(), err)
|
||||
return nil, errors.New("send NewProxy request to plugin error")
|
||||
}
|
||||
if res.Reject {
|
||||
return nil, fmt.Errorf("%s", res.RejectReason)
|
||||
}
|
||||
if !res.Unchange {
|
||||
content = retContent.(*NewProxyContent)
|
||||
}
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
|
||||
func (m *Manager) Ping(content *PingContent) (*PingContent, error) {
|
||||
if len(m.pingPlugins) == 0 {
|
||||
return content, nil
|
||||
}
|
||||
|
||||
var (
|
||||
res = &Response{
|
||||
Reject: false,
|
||||
Unchange: true,
|
||||
}
|
||||
retContent interface{}
|
||||
err error
|
||||
)
|
||||
reqid, _ := util.RandID()
|
||||
xl := xlog.New().AppendPrefix("reqid: " + reqid)
|
||||
ctx := xlog.NewContext(context.Background(), xl)
|
||||
ctx = NewReqidContext(ctx, reqid)
|
||||
|
||||
for _, p := range m.pingPlugins {
|
||||
res, retContent, err = p.Handle(ctx, OpPing, *content)
|
||||
if err != nil {
|
||||
xl.Warn("send Ping request to plugin [%s] error: %v", p.Name(), err)
|
||||
return nil, errors.New("send Ping request to plugin error")
|
||||
}
|
||||
if res.Reject {
|
||||
return nil, fmt.Errorf("%s", res.RejectReason)
|
||||
}
|
||||
if !res.Unchange {
|
||||
content = retContent.(*PingContent)
|
||||
}
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
|
||||
func (m *Manager) NewWorkConn(content *NewWorkConnContent) (*NewWorkConnContent, error) {
|
||||
if len(m.newWorkConnPlugins) == 0 {
|
||||
return content, nil
|
||||
}
|
||||
|
||||
var (
|
||||
res = &Response{
|
||||
Reject: false,
|
||||
Unchange: true,
|
||||
}
|
||||
retContent interface{}
|
||||
err error
|
||||
)
|
||||
reqid, _ := util.RandID()
|
||||
xl := xlog.New().AppendPrefix("reqid: " + reqid)
|
||||
ctx := xlog.NewContext(context.Background(), xl)
|
||||
ctx = NewReqidContext(ctx, reqid)
|
||||
|
||||
for _, p := range m.pingPlugins {
|
||||
res, retContent, err = p.Handle(ctx, OpPing, *content)
|
||||
if err != nil {
|
||||
xl.Warn("send NewWorkConn request to plugin [%s] error: %v", p.Name(), err)
|
||||
return nil, errors.New("send NewWorkConn request to plugin error")
|
||||
}
|
||||
if res.Reject {
|
||||
return nil, fmt.Errorf("%s", res.RejectReason)
|
||||
}
|
||||
if !res.Unchange {
|
||||
content = retContent.(*NewWorkConnContent)
|
||||
}
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
|
||||
func (m *Manager) NewUserConn(content *NewUserConnContent) (*NewUserConnContent, error) {
|
||||
if len(m.newUserConnPlugins) == 0 {
|
||||
return content, nil
|
||||
}
|
||||
|
||||
var (
|
||||
res = &Response{
|
||||
Reject: false,
|
||||
Unchange: true,
|
||||
}
|
||||
retContent interface{}
|
||||
err error
|
||||
)
|
||||
reqid, _ := util.RandID()
|
||||
xl := xlog.New().AppendPrefix("reqid: " + reqid)
|
||||
ctx := xlog.NewContext(context.Background(), xl)
|
||||
ctx = NewReqidContext(ctx, reqid)
|
||||
|
||||
for _, p := range m.newUserConnPlugins {
|
||||
res, retContent, err = p.Handle(ctx, OpNewUserConn, *content)
|
||||
if err != nil {
|
||||
xl.Info("send NewUserConn request to plugin [%s] error: %v", p.Name(), err)
|
||||
return nil, errors.New("send NewUserConn request to plugin error")
|
||||
}
|
||||
if res.Reject {
|
||||
return nil, fmt.Errorf("%s", res.RejectReason)
|
||||
}
|
||||
if !res.Unchange {
|
||||
content = retContent.(*NewUserConnContent)
|
||||
}
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
35
pkg/plugin/server/plugin.go
Normal file
35
pkg/plugin/server/plugin.go
Normal file
@@ -0,0 +1,35 @@
|
||||
// Copyright 2019 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
const (
|
||||
APIVersion = "0.1.0"
|
||||
|
||||
OpLogin = "Login"
|
||||
OpNewProxy = "NewProxy"
|
||||
OpPing = "Ping"
|
||||
OpNewWorkConn = "NewWorkConn"
|
||||
OpNewUserConn = "NewUserConn"
|
||||
)
|
||||
|
||||
type Plugin interface {
|
||||
Name() string
|
||||
IsSupport(op string) bool
|
||||
Handle(ctx context.Context, op string, content interface{}) (res *Response, retContent interface{}, err error)
|
||||
}
|
||||
34
pkg/plugin/server/tracer.go
Normal file
34
pkg/plugin/server/tracer.go
Normal file
@@ -0,0 +1,34 @@
|
||||
// Copyright 2019 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type key int
|
||||
|
||||
const (
|
||||
reqidKey key = 0
|
||||
)
|
||||
|
||||
func NewReqidContext(ctx context.Context, reqid string) context.Context {
|
||||
return context.WithValue(ctx, reqidKey, reqid)
|
||||
}
|
||||
|
||||
func GetReqidFromContext(ctx context.Context) string {
|
||||
ret, _ := ctx.Value(reqidKey).(string)
|
||||
return ret
|
||||
}
|
||||
64
pkg/plugin/server/types.go
Normal file
64
pkg/plugin/server/types.go
Normal file
@@ -0,0 +1,64 @@
|
||||
// Copyright 2019 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
)
|
||||
|
||||
type Request struct {
|
||||
Version string `json:"version"`
|
||||
Op string `json:"op"`
|
||||
Content interface{} `json:"content"`
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
Reject bool `json:"reject"`
|
||||
RejectReason string `json:"reject_reason"`
|
||||
Unchange bool `json:"unchange"`
|
||||
Content interface{} `json:"content"`
|
||||
}
|
||||
|
||||
type LoginContent struct {
|
||||
msg.Login
|
||||
}
|
||||
|
||||
type UserInfo struct {
|
||||
User string `json:"user"`
|
||||
Metas map[string]string `json:"metas"`
|
||||
RunID string `json:"run_id"`
|
||||
}
|
||||
|
||||
type NewProxyContent struct {
|
||||
User UserInfo `json:"user"`
|
||||
msg.NewProxy
|
||||
}
|
||||
|
||||
type PingContent struct {
|
||||
User UserInfo `json:"user"`
|
||||
msg.Ping
|
||||
}
|
||||
|
||||
type NewWorkConnContent struct {
|
||||
User UserInfo `json:"user"`
|
||||
msg.NewWorkConn
|
||||
}
|
||||
|
||||
type NewUserConnContent struct {
|
||||
User UserInfo `json:"user"`
|
||||
ProxyName string `json:"proxy_name"`
|
||||
ProxyType string `json:"proxy_type"`
|
||||
RemoteAddr string `json:"remote_addr"`
|
||||
}
|
||||
137
pkg/proto/udp/udp.go
Normal file
137
pkg/proto/udp/udp.go
Normal file
@@ -0,0 +1,137 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package udp
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
|
||||
"github.com/fatedier/golib/errors"
|
||||
"github.com/fatedier/golib/pool"
|
||||
)
|
||||
|
||||
func NewUDPPacket(buf []byte, laddr, raddr *net.UDPAddr) *msg.UDPPacket {
|
||||
return &msg.UDPPacket{
|
||||
Content: base64.StdEncoding.EncodeToString(buf),
|
||||
LocalAddr: laddr,
|
||||
RemoteAddr: raddr,
|
||||
}
|
||||
}
|
||||
|
||||
func GetContent(m *msg.UDPPacket) (buf []byte, err error) {
|
||||
buf, err = base64.StdEncoding.DecodeString(m.Content)
|
||||
return
|
||||
}
|
||||
|
||||
func ForwardUserConn(udpConn *net.UDPConn, readCh <-chan *msg.UDPPacket, sendCh chan<- *msg.UDPPacket, bufSize int) {
|
||||
// read
|
||||
go func() {
|
||||
for udpMsg := range readCh {
|
||||
buf, err := GetContent(udpMsg)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
udpConn.WriteToUDP(buf, udpMsg.RemoteAddr)
|
||||
}
|
||||
}()
|
||||
|
||||
// write
|
||||
buf := pool.GetBuf(bufSize)
|
||||
defer pool.PutBuf(buf)
|
||||
for {
|
||||
n, remoteAddr, err := udpConn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// buf[:n] will be encoded to string, so the bytes can be reused
|
||||
udpMsg := NewUDPPacket(buf[:n], nil, remoteAddr)
|
||||
|
||||
select {
|
||||
case sendCh <- udpMsg:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Forwarder(dstAddr *net.UDPAddr, readCh <-chan *msg.UDPPacket, sendCh chan<- msg.Message, bufSize int) {
|
||||
var (
|
||||
mu sync.RWMutex
|
||||
)
|
||||
udpConnMap := make(map[string]*net.UDPConn)
|
||||
|
||||
// read from dstAddr and write to sendCh
|
||||
writerFn := func(raddr *net.UDPAddr, udpConn *net.UDPConn) {
|
||||
addr := raddr.String()
|
||||
defer func() {
|
||||
mu.Lock()
|
||||
delete(udpConnMap, addr)
|
||||
mu.Unlock()
|
||||
udpConn.Close()
|
||||
}()
|
||||
|
||||
buf := pool.GetBuf(bufSize)
|
||||
for {
|
||||
udpConn.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||||
n, _, err := udpConn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
udpMsg := NewUDPPacket(buf[:n], nil, raddr)
|
||||
if err = errors.PanicToError(func() {
|
||||
select {
|
||||
case sendCh <- udpMsg:
|
||||
default:
|
||||
}
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// read from readCh
|
||||
go func() {
|
||||
for udpMsg := range readCh {
|
||||
buf, err := GetContent(udpMsg)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
mu.Lock()
|
||||
udpConn, ok := udpConnMap[udpMsg.RemoteAddr.String()]
|
||||
if !ok {
|
||||
udpConn, err = net.DialUDP("udp", nil, dstAddr)
|
||||
if err != nil {
|
||||
mu.Unlock()
|
||||
continue
|
||||
}
|
||||
udpConnMap[udpMsg.RemoteAddr.String()] = udpConn
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
_, err = udpConn.Write(buf)
|
||||
if err != nil {
|
||||
udpConn.Close()
|
||||
}
|
||||
|
||||
if !ok {
|
||||
go writerFn(udpMsg.RemoteAddr, udpConn)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
18
pkg/proto/udp/udp_test.go
Normal file
18
pkg/proto/udp/udp_test.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestUdpPacket(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
buf := []byte("hello world")
|
||||
udpMsg := NewUDPPacket(buf, nil, nil)
|
||||
|
||||
newBuf, err := GetContent(udpMsg)
|
||||
assert.NoError(err)
|
||||
assert.EqualValues(buf, newBuf)
|
||||
}
|
||||
117
pkg/transport/tls.go
Normal file
117
pkg/transport/tls.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"io/ioutil"
|
||||
"math/big"
|
||||
)
|
||||
|
||||
func newCustomTLSKeyPair(certfile, keyfile string) (*tls.Certificate, error) {
|
||||
tlsCert, err := tls.LoadX509KeyPair(certfile, keyfile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &tlsCert, nil
|
||||
}
|
||||
|
||||
func newRandomTLSKeyPair() *tls.Certificate {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
template := x509.Certificate{SerialNumber: big.NewInt(1)}
|
||||
certDER, err := x509.CreateCertificate(
|
||||
rand.Reader,
|
||||
&template,
|
||||
&template,
|
||||
&key.PublicKey,
|
||||
key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
||||
|
||||
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return &tlsCert
|
||||
}
|
||||
|
||||
// Only supprt one ca file to add
|
||||
func newCertPool(caPath string) (*x509.CertPool, error) {
|
||||
pool := x509.NewCertPool()
|
||||
|
||||
caCrt, err := ioutil.ReadFile(caPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pool.AppendCertsFromPEM(caCrt)
|
||||
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
func NewServerTLSConfig(certPath, keyPath, caPath string) (*tls.Config, error) {
|
||||
var base = &tls.Config{}
|
||||
|
||||
if certPath == "" || keyPath == "" {
|
||||
// server will generate tls conf by itself
|
||||
cert := newRandomTLSKeyPair()
|
||||
base.Certificates = []tls.Certificate{*cert}
|
||||
} else {
|
||||
cert, err := newCustomTLSKeyPair(certPath, keyPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
base.Certificates = []tls.Certificate{*cert}
|
||||
}
|
||||
|
||||
if caPath != "" {
|
||||
pool, err := newCertPool(caPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
base.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
base.ClientCAs = pool
|
||||
}
|
||||
|
||||
return base, nil
|
||||
}
|
||||
|
||||
func NewClientTLSConfig(certPath, keyPath, caPath, servearName string) (*tls.Config, error) {
|
||||
var base = &tls.Config{}
|
||||
|
||||
if certPath == "" || keyPath == "" {
|
||||
// client will not generate tls conf by itself
|
||||
} else {
|
||||
cert, err := newCustomTLSKeyPair(certPath, keyPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
base.Certificates = []tls.Certificate{*cert}
|
||||
}
|
||||
|
||||
if caPath != "" {
|
||||
pool, err := newCertPool(caPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
base.RootCAs = pool
|
||||
base.ServerName = servearName
|
||||
base.InsecureSkipVerify = false
|
||||
} else {
|
||||
base.InsecureSkipVerify = true
|
||||
}
|
||||
|
||||
return base, nil
|
||||
}
|
||||
51
pkg/util/limit/reader.go
Normal file
51
pkg/util/limit/reader.go
Normal file
@@ -0,0 +1,51 @@
|
||||
// Copyright 2019 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package limit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type Reader struct {
|
||||
r io.Reader
|
||||
limiter *rate.Limiter
|
||||
}
|
||||
|
||||
func NewReader(r io.Reader, limiter *rate.Limiter) *Reader {
|
||||
return &Reader{
|
||||
r: r,
|
||||
limiter: limiter,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Reader) Read(p []byte) (n int, err error) {
|
||||
b := r.limiter.Burst()
|
||||
if b < len(p) {
|
||||
p = p[:b]
|
||||
}
|
||||
n, err = r.r.Read(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = r.limiter.WaitN(context.Background(), n)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
60
pkg/util/limit/writer.go
Normal file
60
pkg/util/limit/writer.go
Normal file
@@ -0,0 +1,60 @@
|
||||
// Copyright 2019 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package limit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type Writer struct {
|
||||
w io.Writer
|
||||
limiter *rate.Limiter
|
||||
}
|
||||
|
||||
func NewWriter(w io.Writer, limiter *rate.Limiter) *Writer {
|
||||
return &Writer{
|
||||
w: w,
|
||||
limiter: limiter,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Writer) Write(p []byte) (n int, err error) {
|
||||
var nn int
|
||||
b := w.limiter.Burst()
|
||||
for {
|
||||
end := len(p)
|
||||
if end == 0 {
|
||||
break
|
||||
}
|
||||
if b < len(p) {
|
||||
end = b
|
||||
}
|
||||
err = w.limiter.WaitN(context.Background(), end)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
nn, err = w.w.Write(p[:end])
|
||||
n += nn
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
p = p[end:]
|
||||
}
|
||||
return
|
||||
}
|
||||
93
pkg/util/log/log.go
Normal file
93
pkg/util/log/log.go
Normal file
@@ -0,0 +1,93 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package log
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/fatedier/beego/logs"
|
||||
)
|
||||
|
||||
// Log is the under log object
|
||||
var Log *logs.BeeLogger
|
||||
|
||||
func init() {
|
||||
Log = logs.NewLogger(200)
|
||||
Log.EnableFuncCallDepth(true)
|
||||
Log.SetLogFuncCallDepth(Log.GetLogFuncCallDepth() + 1)
|
||||
}
|
||||
|
||||
func InitLog(logWay string, logFile string, logLevel string, maxdays int64, disableLogColor bool) {
|
||||
SetLogFile(logWay, logFile, maxdays, disableLogColor)
|
||||
SetLogLevel(logLevel)
|
||||
}
|
||||
|
||||
// SetLogFile to configure log params
|
||||
// logWay: file or console
|
||||
func SetLogFile(logWay string, logFile string, maxdays int64, disableLogColor bool) {
|
||||
if logWay == "console" {
|
||||
params := ""
|
||||
if disableLogColor {
|
||||
params = fmt.Sprintf(`{"color": false}`)
|
||||
}
|
||||
Log.SetLogger("console", params)
|
||||
} else {
|
||||
params := fmt.Sprintf(`{"filename": "%s", "maxdays": %d}`, logFile, maxdays)
|
||||
Log.SetLogger("file", params)
|
||||
}
|
||||
}
|
||||
|
||||
// SetLogLevel set log level, default is warning
|
||||
// value: error, warning, info, debug, trace
|
||||
func SetLogLevel(logLevel string) {
|
||||
level := 4 // warning
|
||||
switch logLevel {
|
||||
case "error":
|
||||
level = 3
|
||||
case "warn":
|
||||
level = 4
|
||||
case "info":
|
||||
level = 6
|
||||
case "debug":
|
||||
level = 7
|
||||
case "trace":
|
||||
level = 8
|
||||
default:
|
||||
level = 4
|
||||
}
|
||||
Log.SetLevel(level)
|
||||
}
|
||||
|
||||
// wrap log
|
||||
|
||||
func Error(format string, v ...interface{}) {
|
||||
Log.Error(format, v...)
|
||||
}
|
||||
|
||||
func Warn(format string, v ...interface{}) {
|
||||
Log.Warn(format, v...)
|
||||
}
|
||||
|
||||
func Info(format string, v ...interface{}) {
|
||||
Log.Info(format, v...)
|
||||
}
|
||||
|
||||
func Debug(format string, v ...interface{}) {
|
||||
Log.Debug(format, v...)
|
||||
}
|
||||
|
||||
func Trace(format string, v ...interface{}) {
|
||||
Log.Trace(format, v...)
|
||||
}
|
||||
60
pkg/util/metric/counter.go
Normal file
60
pkg/util/metric/counter.go
Normal file
@@ -0,0 +1,60 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package metric
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type Counter interface {
|
||||
Count() int32
|
||||
Inc(int32)
|
||||
Dec(int32)
|
||||
Snapshot() Counter
|
||||
Clear()
|
||||
}
|
||||
|
||||
func NewCounter() Counter {
|
||||
return &StandardCounter{
|
||||
count: 0,
|
||||
}
|
||||
}
|
||||
|
||||
type StandardCounter struct {
|
||||
count int32
|
||||
}
|
||||
|
||||
func (c *StandardCounter) Count() int32 {
|
||||
return atomic.LoadInt32(&c.count)
|
||||
}
|
||||
|
||||
func (c *StandardCounter) Inc(count int32) {
|
||||
atomic.AddInt32(&c.count, count)
|
||||
}
|
||||
|
||||
func (c *StandardCounter) Dec(count int32) {
|
||||
atomic.AddInt32(&c.count, -count)
|
||||
}
|
||||
|
||||
func (c *StandardCounter) Snapshot() Counter {
|
||||
tmp := &StandardCounter{
|
||||
count: atomic.LoadInt32(&c.count),
|
||||
}
|
||||
return tmp
|
||||
}
|
||||
|
||||
func (c *StandardCounter) Clear() {
|
||||
atomic.StoreInt32(&c.count, 0)
|
||||
}
|
||||
23
pkg/util/metric/counter_test.go
Normal file
23
pkg/util/metric/counter_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package metric
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCounter(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
c := NewCounter()
|
||||
c.Inc(10)
|
||||
assert.EqualValues(10, c.Count())
|
||||
|
||||
c.Dec(5)
|
||||
assert.EqualValues(5, c.Count())
|
||||
|
||||
cTmp := c.Snapshot()
|
||||
assert.EqualValues(5, cTmp.Count())
|
||||
|
||||
c.Clear()
|
||||
assert.EqualValues(0, c.Count())
|
||||
}
|
||||
134
pkg/util/metric/date_counter.go
Normal file
134
pkg/util/metric/date_counter.go
Normal file
@@ -0,0 +1,134 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package metric
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type DateCounter interface {
|
||||
TodayCount() int64
|
||||
GetLastDaysCount(lastdays int64) []int64
|
||||
Inc(int64)
|
||||
Dec(int64)
|
||||
Snapshot() DateCounter
|
||||
Clear()
|
||||
}
|
||||
|
||||
func NewDateCounter(reserveDays int64) DateCounter {
|
||||
if reserveDays <= 0 {
|
||||
reserveDays = 1
|
||||
}
|
||||
return newStandardDateCounter(reserveDays)
|
||||
}
|
||||
|
||||
type StandardDateCounter struct {
|
||||
reserveDays int64
|
||||
counts []int64
|
||||
|
||||
lastUpdateDate time.Time
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newStandardDateCounter(reserveDays int64) *StandardDateCounter {
|
||||
now := time.Now()
|
||||
now = time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
||||
s := &StandardDateCounter{
|
||||
reserveDays: reserveDays,
|
||||
counts: make([]int64, reserveDays),
|
||||
lastUpdateDate: now,
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (c *StandardDateCounter) TodayCount() int64 {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.rotate(time.Now())
|
||||
return c.counts[0]
|
||||
}
|
||||
|
||||
func (c *StandardDateCounter) GetLastDaysCount(lastdays int64) []int64 {
|
||||
if lastdays > c.reserveDays {
|
||||
lastdays = c.reserveDays
|
||||
}
|
||||
counts := make([]int64, lastdays)
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.rotate(time.Now())
|
||||
for i := 0; i < int(lastdays); i++ {
|
||||
counts[i] = c.counts[i]
|
||||
}
|
||||
return counts
|
||||
}
|
||||
|
||||
func (c *StandardDateCounter) Inc(count int64) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.rotate(time.Now())
|
||||
c.counts[0] += count
|
||||
}
|
||||
|
||||
func (c *StandardDateCounter) Dec(count int64) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.rotate(time.Now())
|
||||
c.counts[0] -= count
|
||||
}
|
||||
|
||||
func (c *StandardDateCounter) Snapshot() DateCounter {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
tmp := newStandardDateCounter(c.reserveDays)
|
||||
for i := 0; i < int(c.reserveDays); i++ {
|
||||
tmp.counts[i] = c.counts[i]
|
||||
}
|
||||
return tmp
|
||||
}
|
||||
|
||||
func (c *StandardDateCounter) Clear() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
for i := 0; i < int(c.reserveDays); i++ {
|
||||
c.counts[i] = 0
|
||||
}
|
||||
}
|
||||
|
||||
// rotate
|
||||
// Must hold the lock before calling this function.
|
||||
func (c *StandardDateCounter) rotate(now time.Time) {
|
||||
now = time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
||||
days := int(now.Sub(c.lastUpdateDate).Hours() / 24)
|
||||
|
||||
defer func() {
|
||||
c.lastUpdateDate = now
|
||||
}()
|
||||
|
||||
if days <= 0 {
|
||||
return
|
||||
} else if days >= int(c.reserveDays) {
|
||||
c.counts = make([]int64, c.reserveDays)
|
||||
return
|
||||
}
|
||||
newCounts := make([]int64, c.reserveDays)
|
||||
|
||||
for i := days; i < int(c.reserveDays); i++ {
|
||||
newCounts[i] = c.counts[i-days]
|
||||
}
|
||||
c.counts = newCounts
|
||||
}
|
||||
27
pkg/util/metric/date_counter_test.go
Normal file
27
pkg/util/metric/date_counter_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package metric
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDateCounter(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
dc := NewDateCounter(3)
|
||||
dc.Inc(10)
|
||||
assert.EqualValues(10, dc.TodayCount())
|
||||
|
||||
dc.Dec(5)
|
||||
assert.EqualValues(5, dc.TodayCount())
|
||||
|
||||
counts := dc.GetLastDaysCount(3)
|
||||
assert.EqualValues(3, len(counts))
|
||||
assert.EqualValues(5, counts[0])
|
||||
assert.EqualValues(0, counts[1])
|
||||
assert.EqualValues(0, counts[2])
|
||||
|
||||
dcTmp := dc.Snapshot()
|
||||
assert.EqualValues(5, dcTmp.TodayCount())
|
||||
}
|
||||
34
pkg/util/metric/metrics.go
Normal file
34
pkg/util/metric/metrics.go
Normal file
@@ -0,0 +1,34 @@
|
||||
// Copyright 2020 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package metric
|
||||
|
||||
// GaugeMetric represents a single numerical value that can arbitrarily go up
|
||||
// and down.
|
||||
type GaugeMetric interface {
|
||||
Inc()
|
||||
Dec()
|
||||
Set(float64)
|
||||
}
|
||||
|
||||
// CounterMetric represents a single numerical value that only ever
|
||||
// goes up.
|
||||
type CounterMetric interface {
|
||||
Inc()
|
||||
}
|
||||
|
||||
// HistogramMetric counts individual observations.
|
||||
type HistogramMetric interface {
|
||||
Observe(float64)
|
||||
}
|
||||
243
pkg/util/net/conn.go
Normal file
243
pkg/util/net/conn.go
Normal file
@@ -0,0 +1,243 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/frp/pkg/util/xlog"
|
||||
|
||||
gnet "github.com/fatedier/golib/net"
|
||||
kcp "github.com/fatedier/kcp-go"
|
||||
)
|
||||
|
||||
type ContextGetter interface {
|
||||
Context() context.Context
|
||||
}
|
||||
|
||||
type ContextSetter interface {
|
||||
WithContext(ctx context.Context)
|
||||
}
|
||||
|
||||
func NewLogFromConn(conn net.Conn) *xlog.Logger {
|
||||
if c, ok := conn.(ContextGetter); ok {
|
||||
return xlog.FromContextSafe(c.Context())
|
||||
}
|
||||
return xlog.New()
|
||||
}
|
||||
|
||||
func NewContextFromConn(conn net.Conn) context.Context {
|
||||
if c, ok := conn.(ContextGetter); ok {
|
||||
return c.Context()
|
||||
}
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
// ContextConn is the connection with context
|
||||
type ContextConn struct {
|
||||
net.Conn
|
||||
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func NewContextConn(ctx context.Context, c net.Conn) *ContextConn {
|
||||
return &ContextConn{
|
||||
Conn: c,
|
||||
ctx: ctx,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ContextConn) WithContext(ctx context.Context) {
|
||||
c.ctx = ctx
|
||||
}
|
||||
|
||||
func (c *ContextConn) Context() context.Context {
|
||||
return c.ctx
|
||||
}
|
||||
|
||||
type WrapReadWriteCloserConn struct {
|
||||
io.ReadWriteCloser
|
||||
|
||||
underConn net.Conn
|
||||
}
|
||||
|
||||
func WrapReadWriteCloserToConn(rwc io.ReadWriteCloser, underConn net.Conn) net.Conn {
|
||||
return &WrapReadWriteCloserConn{
|
||||
ReadWriteCloser: rwc,
|
||||
underConn: underConn,
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *WrapReadWriteCloserConn) LocalAddr() net.Addr {
|
||||
if conn.underConn != nil {
|
||||
return conn.underConn.LocalAddr()
|
||||
}
|
||||
return (*net.TCPAddr)(nil)
|
||||
}
|
||||
|
||||
func (conn *WrapReadWriteCloserConn) RemoteAddr() net.Addr {
|
||||
if conn.underConn != nil {
|
||||
return conn.underConn.RemoteAddr()
|
||||
}
|
||||
return (*net.TCPAddr)(nil)
|
||||
}
|
||||
|
||||
func (conn *WrapReadWriteCloserConn) SetDeadline(t time.Time) error {
|
||||
if conn.underConn != nil {
|
||||
return conn.underConn.SetDeadline(t)
|
||||
}
|
||||
return &net.OpError{Op: "set", Net: "wrap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (conn *WrapReadWriteCloserConn) SetReadDeadline(t time.Time) error {
|
||||
if conn.underConn != nil {
|
||||
return conn.underConn.SetReadDeadline(t)
|
||||
}
|
||||
return &net.OpError{Op: "set", Net: "wrap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (conn *WrapReadWriteCloserConn) SetWriteDeadline(t time.Time) error {
|
||||
if conn.underConn != nil {
|
||||
return conn.underConn.SetWriteDeadline(t)
|
||||
}
|
||||
return &net.OpError{Op: "set", Net: "wrap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
type CloseNotifyConn struct {
|
||||
net.Conn
|
||||
|
||||
// 1 means closed
|
||||
closeFlag int32
|
||||
|
||||
closeFn func()
|
||||
}
|
||||
|
||||
// closeFn will be only called once
|
||||
func WrapCloseNotifyConn(c net.Conn, closeFn func()) net.Conn {
|
||||
return &CloseNotifyConn{
|
||||
Conn: c,
|
||||
closeFn: closeFn,
|
||||
}
|
||||
}
|
||||
|
||||
func (cc *CloseNotifyConn) Close() (err error) {
|
||||
pflag := atomic.SwapInt32(&cc.closeFlag, 1)
|
||||
if pflag == 0 {
|
||||
err = cc.Close()
|
||||
if cc.closeFn != nil {
|
||||
cc.closeFn()
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type StatsConn struct {
|
||||
net.Conn
|
||||
|
||||
closed int64 // 1 means closed
|
||||
totalRead int64
|
||||
totalWrite int64
|
||||
statsFunc func(totalRead, totalWrite int64)
|
||||
}
|
||||
|
||||
func WrapStatsConn(conn net.Conn, statsFunc func(total, totalWrite int64)) *StatsConn {
|
||||
return &StatsConn{
|
||||
Conn: conn,
|
||||
statsFunc: statsFunc,
|
||||
}
|
||||
}
|
||||
|
||||
func (statsConn *StatsConn) Read(p []byte) (n int, err error) {
|
||||
n, err = statsConn.Conn.Read(p)
|
||||
statsConn.totalRead += int64(n)
|
||||
return
|
||||
}
|
||||
|
||||
func (statsConn *StatsConn) Write(p []byte) (n int, err error) {
|
||||
n, err = statsConn.Conn.Write(p)
|
||||
statsConn.totalWrite += int64(n)
|
||||
return
|
||||
}
|
||||
|
||||
func (statsConn *StatsConn) Close() (err error) {
|
||||
old := atomic.SwapInt64(&statsConn.closed, 1)
|
||||
if old != 1 {
|
||||
err = statsConn.Conn.Close()
|
||||
if statsConn.statsFunc != nil {
|
||||
statsConn.statsFunc(statsConn.totalRead, statsConn.totalWrite)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func ConnectServer(protocol string, addr string) (c net.Conn, err error) {
|
||||
switch protocol {
|
||||
case "tcp":
|
||||
return net.Dial("tcp", addr)
|
||||
case "kcp":
|
||||
kcpConn, errRet := kcp.DialWithOptions(addr, nil, 10, 3)
|
||||
if errRet != nil {
|
||||
err = errRet
|
||||
return
|
||||
}
|
||||
kcpConn.SetStreamMode(true)
|
||||
kcpConn.SetWriteDelay(true)
|
||||
kcpConn.SetNoDelay(1, 20, 2, 1)
|
||||
kcpConn.SetWindowSize(128, 512)
|
||||
kcpConn.SetMtu(1350)
|
||||
kcpConn.SetACKNoDelay(false)
|
||||
kcpConn.SetReadBuffer(4194304)
|
||||
kcpConn.SetWriteBuffer(4194304)
|
||||
c = kcpConn
|
||||
return
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupport protocol: %s", protocol)
|
||||
}
|
||||
}
|
||||
|
||||
func ConnectServerByProxy(proxyURL string, protocol string, addr string) (c net.Conn, err error) {
|
||||
switch protocol {
|
||||
case "tcp":
|
||||
return gnet.DialTcpByProxy(proxyURL, addr)
|
||||
case "kcp":
|
||||
// http proxy is not supported for kcp
|
||||
return ConnectServer(protocol, addr)
|
||||
case "websocket":
|
||||
return ConnectWebsocketServer(addr)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupport protocol: %s", protocol)
|
||||
}
|
||||
}
|
||||
|
||||
func ConnectServerByProxyWithTLS(proxyURL string, protocol string, addr string, tlsConfig *tls.Config) (c net.Conn, err error) {
|
||||
c, err = ConnectServerByProxy(proxyURL, protocol, addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if tlsConfig == nil {
|
||||
return
|
||||
}
|
||||
|
||||
c = WrapTLSClientConn(c, tlsConfig)
|
||||
return
|
||||
}
|
||||
115
pkg/util/net/http.go
Normal file
115
pkg/util/net/http.go
Normal file
@@ -0,0 +1,115 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type HTTPAuthWraper struct {
|
||||
h http.Handler
|
||||
user string
|
||||
passwd string
|
||||
}
|
||||
|
||||
func NewHTTPBasicAuthWraper(h http.Handler, user, passwd string) http.Handler {
|
||||
return &HTTPAuthWraper{
|
||||
h: h,
|
||||
user: user,
|
||||
passwd: passwd,
|
||||
}
|
||||
}
|
||||
|
||||
func (aw *HTTPAuthWraper) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
user, passwd, hasAuth := r.BasicAuth()
|
||||
if (aw.user == "" && aw.passwd == "") || (hasAuth && user == aw.user && passwd == aw.passwd) {
|
||||
aw.h.ServeHTTP(w, r)
|
||||
} else {
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
|
||||
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||
}
|
||||
}
|
||||
|
||||
type HTTPAuthMiddleware struct {
|
||||
user string
|
||||
passwd string
|
||||
}
|
||||
|
||||
func NewHTTPAuthMiddleware(user, passwd string) *HTTPAuthMiddleware {
|
||||
return &HTTPAuthMiddleware{
|
||||
user: user,
|
||||
passwd: passwd,
|
||||
}
|
||||
}
|
||||
|
||||
func (authMid *HTTPAuthMiddleware) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
reqUser, reqPasswd, hasAuth := r.BasicAuth()
|
||||
if (authMid.user == "" && authMid.passwd == "") ||
|
||||
(hasAuth && reqUser == authMid.user && reqPasswd == authMid.passwd) {
|
||||
next.ServeHTTP(w, r)
|
||||
} else {
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
|
||||
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func HTTPBasicAuth(h http.HandlerFunc, user, passwd string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
reqUser, reqPasswd, hasAuth := r.BasicAuth()
|
||||
if (user == "" && passwd == "") ||
|
||||
(hasAuth && reqUser == user && reqPasswd == passwd) {
|
||||
h.ServeHTTP(w, r)
|
||||
} else {
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
|
||||
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type HTTPGzipWraper struct {
|
||||
h http.Handler
|
||||
}
|
||||
|
||||
func (gw *HTTPGzipWraper) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
|
||||
gw.h.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Encoding", "gzip")
|
||||
gz := gzip.NewWriter(w)
|
||||
defer gz.Close()
|
||||
gzr := gzipResponseWriter{Writer: gz, ResponseWriter: w}
|
||||
gw.h.ServeHTTP(gzr, r)
|
||||
}
|
||||
|
||||
func MakeHTTPGzipHandler(h http.Handler) http.Handler {
|
||||
return &HTTPGzipWraper{
|
||||
h: h,
|
||||
}
|
||||
}
|
||||
|
||||
type gzipResponseWriter struct {
|
||||
io.Writer
|
||||
http.ResponseWriter
|
||||
}
|
||||
|
||||
func (w gzipResponseWriter) Write(b []byte) (int, error) {
|
||||
return w.Writer.Write(b)
|
||||
}
|
||||
99
pkg/util/net/kcp.go
Normal file
99
pkg/util/net/kcp.go
Normal file
@@ -0,0 +1,99 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
kcp "github.com/fatedier/kcp-go"
|
||||
)
|
||||
|
||||
type KCPListener struct {
|
||||
listener net.Listener
|
||||
acceptCh chan net.Conn
|
||||
closeFlag bool
|
||||
}
|
||||
|
||||
func ListenKcp(bindAddr string, bindPort int) (l *KCPListener, err error) {
|
||||
listener, err := kcp.ListenWithOptions(fmt.Sprintf("%s:%d", bindAddr, bindPort), nil, 10, 3)
|
||||
if err != nil {
|
||||
return l, err
|
||||
}
|
||||
listener.SetReadBuffer(4194304)
|
||||
listener.SetWriteBuffer(4194304)
|
||||
|
||||
l = &KCPListener{
|
||||
listener: listener,
|
||||
acceptCh: make(chan net.Conn),
|
||||
closeFlag: false,
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := listener.AcceptKCP()
|
||||
if err != nil {
|
||||
if l.closeFlag {
|
||||
close(l.acceptCh)
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
conn.SetStreamMode(true)
|
||||
conn.SetWriteDelay(true)
|
||||
conn.SetNoDelay(1, 20, 2, 1)
|
||||
conn.SetMtu(1350)
|
||||
conn.SetWindowSize(1024, 1024)
|
||||
conn.SetACKNoDelay(false)
|
||||
|
||||
l.acceptCh <- conn
|
||||
}
|
||||
}()
|
||||
return l, err
|
||||
}
|
||||
|
||||
func (l *KCPListener) Accept() (net.Conn, error) {
|
||||
conn, ok := <-l.acceptCh
|
||||
if !ok {
|
||||
return conn, fmt.Errorf("channel for kcp listener closed")
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (l *KCPListener) Close() error {
|
||||
if !l.closeFlag {
|
||||
l.closeFlag = true
|
||||
l.listener.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *KCPListener) Addr() net.Addr {
|
||||
return l.listener.Addr()
|
||||
}
|
||||
|
||||
func NewKCPConnFromUDP(conn *net.UDPConn, connected bool, raddr string) (net.Conn, error) {
|
||||
kcpConn, err := kcp.NewConnEx(1, connected, raddr, nil, 10, 3, conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
kcpConn.SetStreamMode(true)
|
||||
kcpConn.SetWriteDelay(true)
|
||||
kcpConn.SetNoDelay(1, 20, 2, 1)
|
||||
kcpConn.SetMtu(1350)
|
||||
kcpConn.SetWindowSize(1024, 1024)
|
||||
kcpConn.SetACKNoDelay(false)
|
||||
return kcpConn, nil
|
||||
}
|
||||
69
pkg/util/net/listener.go
Normal file
69
pkg/util/net/listener.go
Normal file
@@ -0,0 +1,69 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/fatedier/golib/errors"
|
||||
)
|
||||
|
||||
// Custom listener
|
||||
type CustomListener struct {
|
||||
acceptCh chan net.Conn
|
||||
closed bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewCustomListener() *CustomListener {
|
||||
return &CustomListener{
|
||||
acceptCh: make(chan net.Conn, 64),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *CustomListener) Accept() (net.Conn, error) {
|
||||
conn, ok := <-l.acceptCh
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("listener closed")
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (l *CustomListener) PutConn(conn net.Conn) error {
|
||||
err := errors.PanicToError(func() {
|
||||
select {
|
||||
case l.acceptCh <- conn:
|
||||
default:
|
||||
conn.Close()
|
||||
}
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (l *CustomListener) Close() error {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
if !l.closed {
|
||||
close(l.acceptCh)
|
||||
l.closed = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *CustomListener) Addr() net.Addr {
|
||||
return (*net.TCPAddr)(nil)
|
||||
}
|
||||
57
pkg/util/net/tls.go
Normal file
57
pkg/util/net/tls.go
Normal file
@@ -0,0 +1,57 @@
|
||||
// Copyright 2019 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
gnet "github.com/fatedier/golib/net"
|
||||
)
|
||||
|
||||
var (
|
||||
FRPTLSHeadByte = 0x17
|
||||
)
|
||||
|
||||
func WrapTLSClientConn(c net.Conn, tlsConfig *tls.Config) (out net.Conn) {
|
||||
c.Write([]byte{byte(FRPTLSHeadByte)})
|
||||
out = tls.Client(c, tlsConfig)
|
||||
return
|
||||
}
|
||||
|
||||
func CheckAndEnableTLSServerConnWithTimeout(c net.Conn, tlsConfig *tls.Config, tlsOnly bool, timeout time.Duration) (out net.Conn, err error) {
|
||||
sc, r := gnet.NewSharedConnSize(c, 2)
|
||||
buf := make([]byte, 1)
|
||||
var n int
|
||||
c.SetReadDeadline(time.Now().Add(timeout))
|
||||
n, err = r.Read(buf)
|
||||
c.SetReadDeadline(time.Time{})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if n == 1 && int(buf[0]) == FRPTLSHeadByte {
|
||||
out = tls.Server(c, tlsConfig)
|
||||
} else {
|
||||
if tlsOnly {
|
||||
err = fmt.Errorf("non-TLS connection received on a TlsOnly server")
|
||||
return
|
||||
}
|
||||
out = sc
|
||||
}
|
||||
return
|
||||
}
|
||||
258
pkg/util/net/udp.go
Normal file
258
pkg/util/net/udp.go
Normal file
@@ -0,0 +1,258 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/golib/pool"
|
||||
)
|
||||
|
||||
type UDPPacket struct {
|
||||
Buf []byte
|
||||
LocalAddr net.Addr
|
||||
RemoteAddr net.Addr
|
||||
}
|
||||
|
||||
type FakeUDPConn struct {
|
||||
l *UDPListener
|
||||
|
||||
localAddr net.Addr
|
||||
remoteAddr net.Addr
|
||||
packets chan []byte
|
||||
closeFlag bool
|
||||
|
||||
lastActive time.Time
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewFakeUDPConn(l *UDPListener, laddr, raddr net.Addr) *FakeUDPConn {
|
||||
fc := &FakeUDPConn{
|
||||
l: l,
|
||||
localAddr: laddr,
|
||||
remoteAddr: raddr,
|
||||
packets: make(chan []byte, 20),
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(5 * time.Second)
|
||||
fc.mu.RLock()
|
||||
if time.Now().Sub(fc.lastActive) > 10*time.Second {
|
||||
fc.mu.RUnlock()
|
||||
fc.Close()
|
||||
break
|
||||
}
|
||||
fc.mu.RUnlock()
|
||||
}
|
||||
}()
|
||||
return fc
|
||||
}
|
||||
|
||||
func (c *FakeUDPConn) putPacket(content []byte) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case c.packets <- content:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (c *FakeUDPConn) Read(b []byte) (n int, err error) {
|
||||
content, ok := <-c.packets
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
}
|
||||
c.mu.Lock()
|
||||
c.lastActive = time.Now()
|
||||
c.mu.Unlock()
|
||||
|
||||
if len(b) < len(content) {
|
||||
n = len(b)
|
||||
} else {
|
||||
n = len(content)
|
||||
}
|
||||
copy(b, content)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *FakeUDPConn) Write(b []byte) (n int, err error) {
|
||||
c.mu.RLock()
|
||||
if c.closeFlag {
|
||||
c.mu.RUnlock()
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
|
||||
packet := &UDPPacket{
|
||||
Buf: b,
|
||||
LocalAddr: c.localAddr,
|
||||
RemoteAddr: c.remoteAddr,
|
||||
}
|
||||
c.l.writeUDPPacket(packet)
|
||||
|
||||
c.mu.Lock()
|
||||
c.lastActive = time.Now()
|
||||
c.mu.Unlock()
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (c *FakeUDPConn) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if !c.closeFlag {
|
||||
c.closeFlag = true
|
||||
close(c.packets)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *FakeUDPConn) IsClosed() bool {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.closeFlag
|
||||
}
|
||||
|
||||
func (c *FakeUDPConn) LocalAddr() net.Addr {
|
||||
return c.localAddr
|
||||
}
|
||||
|
||||
func (c *FakeUDPConn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func (c *FakeUDPConn) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *FakeUDPConn) SetReadDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *FakeUDPConn) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type UDPListener struct {
|
||||
addr net.Addr
|
||||
acceptCh chan net.Conn
|
||||
writeCh chan *UDPPacket
|
||||
readConn net.Conn
|
||||
closeFlag bool
|
||||
|
||||
fakeConns map[string]*FakeUDPConn
|
||||
}
|
||||
|
||||
func ListenUDP(bindAddr string, bindPort int) (l *UDPListener, err error) {
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
|
||||
if err != nil {
|
||||
return l, err
|
||||
}
|
||||
readConn, err := net.ListenUDP("udp", udpAddr)
|
||||
|
||||
l = &UDPListener{
|
||||
addr: udpAddr,
|
||||
acceptCh: make(chan net.Conn),
|
||||
writeCh: make(chan *UDPPacket, 1000),
|
||||
fakeConns: make(map[string]*FakeUDPConn),
|
||||
}
|
||||
|
||||
// for reading
|
||||
go func() {
|
||||
for {
|
||||
buf := pool.GetBuf(1450)
|
||||
n, remoteAddr, err := readConn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
close(l.acceptCh)
|
||||
close(l.writeCh)
|
||||
return
|
||||
}
|
||||
|
||||
fakeConn, exist := l.fakeConns[remoteAddr.String()]
|
||||
if !exist || fakeConn.IsClosed() {
|
||||
fakeConn = NewFakeUDPConn(l, l.Addr(), remoteAddr)
|
||||
l.fakeConns[remoteAddr.String()] = fakeConn
|
||||
}
|
||||
fakeConn.putPacket(buf[:n])
|
||||
|
||||
l.acceptCh <- fakeConn
|
||||
}
|
||||
}()
|
||||
|
||||
// for writing
|
||||
go func() {
|
||||
for {
|
||||
packet, ok := <-l.writeCh
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if addr, ok := packet.RemoteAddr.(*net.UDPAddr); ok {
|
||||
readConn.WriteToUDP(packet.Buf, addr)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *UDPListener) writeUDPPacket(packet *UDPPacket) (err error) {
|
||||
defer func() {
|
||||
if errRet := recover(); errRet != nil {
|
||||
err = fmt.Errorf("udp write closed listener")
|
||||
}
|
||||
}()
|
||||
l.writeCh <- packet
|
||||
return
|
||||
}
|
||||
|
||||
func (l *UDPListener) WriteMsg(buf []byte, remoteAddr *net.UDPAddr) (err error) {
|
||||
// only set remote addr here
|
||||
packet := &UDPPacket{
|
||||
Buf: buf,
|
||||
RemoteAddr: remoteAddr,
|
||||
}
|
||||
err = l.writeUDPPacket(packet)
|
||||
return
|
||||
}
|
||||
|
||||
func (l *UDPListener) Accept() (net.Conn, error) {
|
||||
conn, ok := <-l.acceptCh
|
||||
if !ok {
|
||||
return conn, fmt.Errorf("channel for udp listener closed")
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (l *UDPListener) Close() error {
|
||||
if !l.closeFlag {
|
||||
l.closeFlag = true
|
||||
if l.readConn != nil {
|
||||
l.readConn.Close()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *UDPListener) Addr() net.Addr {
|
||||
return l.addr
|
||||
}
|
||||
103
pkg/util/net/websocket.go
Normal file
103
pkg/util/net/websocket.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrWebsocketListenerClosed = errors.New("websocket listener closed")
|
||||
)
|
||||
|
||||
const (
|
||||
FrpWebsocketPath = "/~!frp"
|
||||
)
|
||||
|
||||
type WebsocketListener struct {
|
||||
ln net.Listener
|
||||
acceptCh chan net.Conn
|
||||
|
||||
server *http.Server
|
||||
httpMutex *http.ServeMux
|
||||
}
|
||||
|
||||
// NewWebsocketListener to handle websocket connections
|
||||
// ln: tcp listener for websocket connections
|
||||
func NewWebsocketListener(ln net.Listener) (wl *WebsocketListener) {
|
||||
wl = &WebsocketListener{
|
||||
acceptCh: make(chan net.Conn),
|
||||
}
|
||||
|
||||
muxer := http.NewServeMux()
|
||||
muxer.Handle(FrpWebsocketPath, websocket.Handler(func(c *websocket.Conn) {
|
||||
notifyCh := make(chan struct{})
|
||||
conn := WrapCloseNotifyConn(c, func() {
|
||||
close(notifyCh)
|
||||
})
|
||||
wl.acceptCh <- conn
|
||||
<-notifyCh
|
||||
}))
|
||||
|
||||
wl.server = &http.Server{
|
||||
Addr: ln.Addr().String(),
|
||||
Handler: muxer,
|
||||
}
|
||||
|
||||
go wl.server.Serve(ln)
|
||||
return
|
||||
}
|
||||
|
||||
func ListenWebsocket(bindAddr string, bindPort int) (*WebsocketListener, error) {
|
||||
tcpLn, err := net.Listen("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l := NewWebsocketListener(tcpLn)
|
||||
return l, nil
|
||||
}
|
||||
|
||||
func (p *WebsocketListener) Accept() (net.Conn, error) {
|
||||
c, ok := <-p.acceptCh
|
||||
if !ok {
|
||||
return nil, ErrWebsocketListenerClosed
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (p *WebsocketListener) Close() error {
|
||||
return p.server.Close()
|
||||
}
|
||||
|
||||
func (p *WebsocketListener) Addr() net.Addr {
|
||||
return p.ln.Addr()
|
||||
}
|
||||
|
||||
// addr: domain:port
|
||||
func ConnectWebsocketServer(addr string) (net.Conn, error) {
|
||||
addr = "ws://" + addr + FrpWebsocketPath
|
||||
uri, err := url.Parse(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
origin := "http://" + uri.Host
|
||||
cfg, err := websocket.NewConfig(addr, origin)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cfg.Dialer = &net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
conn, err := websocket.DialConfig(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
68
pkg/util/tcpmux/httpconnect.go
Normal file
68
pkg/util/tcpmux/httpconnect.go
Normal file
@@ -0,0 +1,68 @@
|
||||
// Copyright 2020 guylewin, guy@lewin.co.il
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tcpmux
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/frp/pkg/util/util"
|
||||
"github.com/fatedier/frp/pkg/util/vhost"
|
||||
)
|
||||
|
||||
type HTTPConnectTCPMuxer struct {
|
||||
*vhost.Muxer
|
||||
}
|
||||
|
||||
func NewHTTPConnectTCPMuxer(listener net.Listener, timeout time.Duration) (*HTTPConnectTCPMuxer, error) {
|
||||
mux, err := vhost.NewMuxer(listener, getHostFromHTTPConnect, nil, sendHTTPOk, nil, timeout)
|
||||
return &HTTPConnectTCPMuxer{mux}, err
|
||||
}
|
||||
|
||||
func readHTTPConnectRequest(rd io.Reader) (host string, err error) {
|
||||
bufioReader := bufio.NewReader(rd)
|
||||
|
||||
req, err := http.ReadRequest(bufioReader)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if req.Method != "CONNECT" {
|
||||
err = fmt.Errorf("connections to tcp vhost must be of method CONNECT")
|
||||
return
|
||||
}
|
||||
|
||||
host = util.GetHostFromAddr(req.Host)
|
||||
return
|
||||
}
|
||||
|
||||
func sendHTTPOk(c net.Conn) error {
|
||||
return util.OkResponse().Write(c)
|
||||
}
|
||||
|
||||
func getHostFromHTTPConnect(c net.Conn) (_ net.Conn, _ map[string]string, err error) {
|
||||
reqInfoMap := make(map[string]string, 0)
|
||||
host, err := readHTTPConnectRequest(c)
|
||||
if err != nil {
|
||||
return nil, reqInfoMap, err
|
||||
}
|
||||
reqInfoMap["Host"] = host
|
||||
reqInfoMap["Scheme"] = "tcp"
|
||||
return c, reqInfoMap, nil
|
||||
}
|
||||
44
pkg/util/util/http.go
Normal file
44
pkg/util/util/http.go
Normal file
@@ -0,0 +1,44 @@
|
||||
// Copyright 2020 guylewin, guy@lewin.co.il
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package util
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func OkResponse() *http.Response {
|
||||
header := make(http.Header)
|
||||
|
||||
res := &http.Response{
|
||||
Status: "OK",
|
||||
StatusCode: 200,
|
||||
Proto: "HTTP/1.1",
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Header: header,
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func GetHostFromAddr(addr string) (host string) {
|
||||
strs := strings.Split(addr, ":")
|
||||
if len(strs) > 1 {
|
||||
host = strs[0]
|
||||
} else {
|
||||
host = addr
|
||||
}
|
||||
return
|
||||
}
|
||||
110
pkg/util/util/util.go
Normal file
110
pkg/util/util/util.go
Normal file
@@ -0,0 +1,110 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package util
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// RandID return a rand string used in frp.
|
||||
func RandID() (id string, err error) {
|
||||
return RandIDWithLen(8)
|
||||
}
|
||||
|
||||
// RandIDWithLen return a rand string with idLen length.
|
||||
func RandIDWithLen(idLen int) (id string, err error) {
|
||||
b := make([]byte, idLen)
|
||||
_, err = rand.Read(b)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
id = fmt.Sprintf("%x", b)
|
||||
return
|
||||
}
|
||||
|
||||
func GetAuthKey(token string, timestamp int64) (key string) {
|
||||
token = token + fmt.Sprintf("%d", timestamp)
|
||||
md5Ctx := md5.New()
|
||||
md5Ctx.Write([]byte(token))
|
||||
data := md5Ctx.Sum(nil)
|
||||
return hex.EncodeToString(data)
|
||||
}
|
||||
|
||||
func CanonicalAddr(host string, port int) (addr string) {
|
||||
if port == 80 || port == 443 {
|
||||
addr = host
|
||||
} else {
|
||||
addr = fmt.Sprintf("%s:%d", host, port)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func ParseRangeNumbers(rangeStr string) (numbers []int64, err error) {
|
||||
rangeStr = strings.TrimSpace(rangeStr)
|
||||
numbers = make([]int64, 0)
|
||||
// e.g. 1000-2000,2001,2002,3000-4000
|
||||
numRanges := strings.Split(rangeStr, ",")
|
||||
for _, numRangeStr := range numRanges {
|
||||
// 1000-2000 or 2001
|
||||
numArray := strings.Split(numRangeStr, "-")
|
||||
// length: only 1 or 2 is correct
|
||||
rangeType := len(numArray)
|
||||
if rangeType == 1 {
|
||||
// single number
|
||||
singleNum, errRet := strconv.ParseInt(strings.TrimSpace(numArray[0]), 10, 64)
|
||||
if errRet != nil {
|
||||
err = fmt.Errorf("range number is invalid, %v", errRet)
|
||||
return
|
||||
}
|
||||
numbers = append(numbers, singleNum)
|
||||
} else if rangeType == 2 {
|
||||
// range numbers
|
||||
min, errRet := strconv.ParseInt(strings.TrimSpace(numArray[0]), 10, 64)
|
||||
if errRet != nil {
|
||||
err = fmt.Errorf("range number is invalid, %v", errRet)
|
||||
return
|
||||
}
|
||||
max, errRet := strconv.ParseInt(strings.TrimSpace(numArray[1]), 10, 64)
|
||||
if errRet != nil {
|
||||
err = fmt.Errorf("range number is invalid, %v", errRet)
|
||||
return
|
||||
}
|
||||
if max < min {
|
||||
err = fmt.Errorf("range number is invalid")
|
||||
return
|
||||
}
|
||||
for i := min; i <= max; i++ {
|
||||
numbers = append(numbers, i)
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("range number is invalid")
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func GenerateResponseErrorString(summary string, err error, detailed bool) string {
|
||||
if detailed {
|
||||
return err.Error()
|
||||
}
|
||||
return summary
|
||||
}
|
||||
48
pkg/util/util/util_test.go
Normal file
48
pkg/util/util/util_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRandId(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
id, err := RandID()
|
||||
assert.NoError(err)
|
||||
t.Log(id)
|
||||
assert.Equal(16, len(id))
|
||||
}
|
||||
|
||||
func TestGetAuthKey(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
key := GetAuthKey("1234", 1488720000)
|
||||
t.Log(key)
|
||||
assert.Equal("6df41a43725f0c770fd56379e12acf8c", key)
|
||||
}
|
||||
|
||||
func TestParseRangeNumbers(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
numbers, err := ParseRangeNumbers("2-5")
|
||||
if assert.NoError(err) {
|
||||
assert.Equal([]int64{2, 3, 4, 5}, numbers)
|
||||
}
|
||||
|
||||
numbers, err = ParseRangeNumbers("1")
|
||||
if assert.NoError(err) {
|
||||
assert.Equal([]int64{1}, numbers)
|
||||
}
|
||||
|
||||
numbers, err = ParseRangeNumbers("3-5,8")
|
||||
if assert.NoError(err) {
|
||||
assert.Equal([]int64{3, 4, 5, 8}, numbers)
|
||||
}
|
||||
|
||||
numbers, err = ParseRangeNumbers(" 3-5,8, 10-12 ")
|
||||
if assert.NoError(err) {
|
||||
assert.Equal([]int64{3, 4, 5, 8, 10, 11, 12}, numbers)
|
||||
}
|
||||
|
||||
_, err = ParseRangeNumbers("3-a")
|
||||
assert.Error(err)
|
||||
}
|
||||
82
pkg/util/version/version.go
Normal file
82
pkg/util/version/version.go
Normal file
@@ -0,0 +1,82 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package version
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var version string = "0.34.0"
|
||||
|
||||
func Full() string {
|
||||
return version
|
||||
}
|
||||
|
||||
func getSubVersion(v string, position int) int64 {
|
||||
arr := strings.Split(v, ".")
|
||||
if len(arr) < 3 {
|
||||
return 0
|
||||
}
|
||||
res, _ := strconv.ParseInt(arr[position], 10, 64)
|
||||
return res
|
||||
}
|
||||
|
||||
func Proto(v string) int64 {
|
||||
return getSubVersion(v, 0)
|
||||
}
|
||||
|
||||
func Major(v string) int64 {
|
||||
return getSubVersion(v, 1)
|
||||
}
|
||||
|
||||
func Minor(v string) int64 {
|
||||
return getSubVersion(v, 2)
|
||||
}
|
||||
|
||||
// add every case there if server will not accept client's protocol and return false
|
||||
func Compat(client string) (ok bool, msg string) {
|
||||
if LessThan(client, "0.18.0") {
|
||||
return false, "Please upgrade your frpc version to at least 0.18.0"
|
||||
}
|
||||
return true, ""
|
||||
}
|
||||
|
||||
func LessThan(client string, server string) bool {
|
||||
vc := Proto(client)
|
||||
vs := Proto(server)
|
||||
if vc > vs {
|
||||
return false
|
||||
} else if vc < vs {
|
||||
return true
|
||||
}
|
||||
|
||||
vc = Major(client)
|
||||
vs = Major(server)
|
||||
if vc > vs {
|
||||
return false
|
||||
} else if vc < vs {
|
||||
return true
|
||||
}
|
||||
|
||||
vc = Minor(client)
|
||||
vs = Minor(server)
|
||||
if vc > vs {
|
||||
return false
|
||||
} else if vc < vs {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
65
pkg/util/version/version_test.go
Normal file
65
pkg/util/version/version_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package version
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFull(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
version := Full()
|
||||
arr := strings.Split(version, ".")
|
||||
assert.Equal(3, len(arr))
|
||||
|
||||
proto, err := strconv.ParseInt(arr[0], 10, 64)
|
||||
assert.NoError(err)
|
||||
assert.True(proto >= 0)
|
||||
|
||||
major, err := strconv.ParseInt(arr[1], 10, 64)
|
||||
assert.NoError(err)
|
||||
assert.True(major >= 0)
|
||||
|
||||
minor, err := strconv.ParseInt(arr[2], 10, 64)
|
||||
assert.NoError(err)
|
||||
assert.True(minor >= 0)
|
||||
}
|
||||
|
||||
func TestVersion(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
proto := Proto(Full())
|
||||
major := Major(Full())
|
||||
minor := Minor(Full())
|
||||
parseVerion := fmt.Sprintf("%d.%d.%d", proto, major, minor)
|
||||
version := Full()
|
||||
assert.Equal(parseVerion, version)
|
||||
}
|
||||
|
||||
func TestCompact(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
ok, _ := Compat("0.9.0")
|
||||
assert.False(ok)
|
||||
|
||||
ok, _ = Compat("10.0.0")
|
||||
assert.True(ok)
|
||||
|
||||
ok, _ = Compat("0.10.0")
|
||||
assert.False(ok)
|
||||
}
|
||||
206
pkg/util/vhost/http.go
Normal file
206
pkg/util/vhost/http.go
Normal file
@@ -0,0 +1,206 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package vhost
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
frpLog "github.com/fatedier/frp/pkg/util/log"
|
||||
"github.com/fatedier/frp/pkg/util/util"
|
||||
|
||||
"github.com/fatedier/golib/pool"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNoDomain = errors.New("no such domain")
|
||||
)
|
||||
|
||||
type HTTPReverseProxyOptions struct {
|
||||
ResponseHeaderTimeoutS int64
|
||||
}
|
||||
|
||||
type HTTPReverseProxy struct {
|
||||
proxy *ReverseProxy
|
||||
vhostRouter *Routers
|
||||
|
||||
responseHeaderTimeout time.Duration
|
||||
}
|
||||
|
||||
func NewHTTPReverseProxy(option HTTPReverseProxyOptions, vhostRouter *Routers) *HTTPReverseProxy {
|
||||
if option.ResponseHeaderTimeoutS <= 0 {
|
||||
option.ResponseHeaderTimeoutS = 60
|
||||
}
|
||||
rp := &HTTPReverseProxy{
|
||||
responseHeaderTimeout: time.Duration(option.ResponseHeaderTimeoutS) * time.Second,
|
||||
vhostRouter: vhostRouter,
|
||||
}
|
||||
proxy := &ReverseProxy{
|
||||
Director: func(req *http.Request) {
|
||||
req.URL.Scheme = "http"
|
||||
url := req.Context().Value(RouteInfoURL).(string)
|
||||
oldHost := util.GetHostFromAddr(req.Context().Value(RouteInfoHost).(string))
|
||||
host := rp.GetRealHost(oldHost, url)
|
||||
if host != "" {
|
||||
req.Host = host
|
||||
}
|
||||
req.URL.Host = req.Host
|
||||
|
||||
headers := rp.GetHeaders(oldHost, url)
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
},
|
||||
Transport: &http.Transport{
|
||||
ResponseHeaderTimeout: rp.responseHeaderTimeout,
|
||||
DisableKeepAlives: true,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
url := ctx.Value(RouteInfoURL).(string)
|
||||
host := util.GetHostFromAddr(ctx.Value(RouteInfoHost).(string))
|
||||
remote := ctx.Value(RouteInfoRemote).(string)
|
||||
return rp.CreateConnection(host, url, remote)
|
||||
},
|
||||
},
|
||||
BufferPool: newWrapPool(),
|
||||
ErrorLog: log.New(newWrapLogger(), "", 0),
|
||||
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||
frpLog.Warn("do http proxy request error: %v", err)
|
||||
rw.WriteHeader(http.StatusNotFound)
|
||||
rw.Write(getNotFoundPageContent())
|
||||
},
|
||||
}
|
||||
rp.proxy = proxy
|
||||
return rp
|
||||
}
|
||||
|
||||
// Register register the route config to reverse proxy
|
||||
// reverse proxy will use CreateConnFn from routeCfg to create a connection to the remote service
|
||||
func (rp *HTTPReverseProxy) Register(routeCfg RouteConfig) error {
|
||||
err := rp.vhostRouter.Add(routeCfg.Domain, routeCfg.Location, &routeCfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnRegister unregister route config by domain and location
|
||||
func (rp *HTTPReverseProxy) UnRegister(domain string, location string) {
|
||||
rp.vhostRouter.Del(domain, location)
|
||||
}
|
||||
|
||||
func (rp *HTTPReverseProxy) GetRealHost(domain string, location string) (host string) {
|
||||
vr, ok := rp.getVhost(domain, location)
|
||||
if ok {
|
||||
host = vr.payload.(*RouteConfig).RewriteHost
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (rp *HTTPReverseProxy) GetHeaders(domain string, location string) (headers map[string]string) {
|
||||
vr, ok := rp.getVhost(domain, location)
|
||||
if ok {
|
||||
headers = vr.payload.(*RouteConfig).Headers
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CreateConnection create a new connection by route config
|
||||
func (rp *HTTPReverseProxy) CreateConnection(domain string, location string, remoteAddr string) (net.Conn, error) {
|
||||
vr, ok := rp.getVhost(domain, location)
|
||||
if ok {
|
||||
fn := vr.payload.(*RouteConfig).CreateConnFn
|
||||
if fn != nil {
|
||||
return fn(remoteAddr)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("%v: %s %s", ErrNoDomain, domain, location)
|
||||
}
|
||||
|
||||
func (rp *HTTPReverseProxy) CheckAuth(domain, location, user, passwd string) bool {
|
||||
vr, ok := rp.getVhost(domain, location)
|
||||
if ok {
|
||||
checkUser := vr.payload.(*RouteConfig).Username
|
||||
checkPasswd := vr.payload.(*RouteConfig).Password
|
||||
if (checkUser != "" || checkPasswd != "") && (checkUser != user || checkPasswd != passwd) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// getVhost get vhost router by domain and location
|
||||
func (rp *HTTPReverseProxy) getVhost(domain string, location string) (vr *Router, ok bool) {
|
||||
// first we check the full hostname
|
||||
// if not exist, then check the wildcard_domain such as *.example.com
|
||||
vr, ok = rp.vhostRouter.Get(domain, location)
|
||||
if ok {
|
||||
return
|
||||
}
|
||||
|
||||
domainSplit := strings.Split(domain, ".")
|
||||
if len(domainSplit) < 3 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
for {
|
||||
if len(domainSplit) < 3 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
domainSplit[0] = "*"
|
||||
domain = strings.Join(domainSplit, ".")
|
||||
vr, ok = rp.vhostRouter.Get(domain, location)
|
||||
if ok {
|
||||
return vr, true
|
||||
}
|
||||
domainSplit = domainSplit[1:]
|
||||
}
|
||||
}
|
||||
|
||||
func (rp *HTTPReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
domain := util.GetHostFromAddr(req.Host)
|
||||
location := req.URL.Path
|
||||
user, passwd, _ := req.BasicAuth()
|
||||
if !rp.CheckAuth(domain, location, user, passwd) {
|
||||
rw.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
|
||||
http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
rp.proxy.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
type wrapPool struct{}
|
||||
|
||||
func newWrapPool() *wrapPool { return &wrapPool{} }
|
||||
|
||||
func (p *wrapPool) Get() []byte { return pool.GetBuf(32 * 1024) }
|
||||
|
||||
func (p *wrapPool) Put(buf []byte) { pool.PutBuf(buf) }
|
||||
|
||||
type wrapLogger struct{}
|
||||
|
||||
func newWrapLogger() *wrapLogger { return &wrapLogger{} }
|
||||
|
||||
func (l *wrapLogger) Write(p []byte) (n int, err error) {
|
||||
frpLog.Warn("%s", string(bytes.TrimRight(p, "\n")))
|
||||
return len(p), nil
|
||||
}
|
||||
193
pkg/util/vhost/https.go
Normal file
193
pkg/util/vhost/https.go
Normal file
@@ -0,0 +1,193 @@
|
||||
// Copyright 2016 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package vhost
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
gnet "github.com/fatedier/golib/net"
|
||||
"github.com/fatedier/golib/pool"
|
||||
)
|
||||
|
||||
const (
|
||||
typeClientHello uint8 = 1 // Type client hello
|
||||
)
|
||||
|
||||
// TLS extension numbers
|
||||
const (
|
||||
extensionServerName uint16 = 0
|
||||
extensionStatusRequest uint16 = 5
|
||||
extensionSupportedCurves uint16 = 10
|
||||
extensionSupportedPoints uint16 = 11
|
||||
extensionSignatureAlgorithms uint16 = 13
|
||||
extensionALPN uint16 = 16
|
||||
extensionSCT uint16 = 18
|
||||
extensionSessionTicket uint16 = 35
|
||||
extensionNextProtoNeg uint16 = 13172 // not IANA assigned
|
||||
extensionRenegotiationInfo uint16 = 0xff01
|
||||
)
|
||||
|
||||
type HTTPSMuxer struct {
|
||||
*Muxer
|
||||
}
|
||||
|
||||
func NewHTTPSMuxer(listener net.Listener, timeout time.Duration) (*HTTPSMuxer, error) {
|
||||
mux, err := NewMuxer(listener, GetHTTPSHostname, nil, nil, nil, timeout)
|
||||
return &HTTPSMuxer{mux}, err
|
||||
}
|
||||
|
||||
func readHandshake(rd io.Reader) (host string, err error) {
|
||||
data := pool.GetBuf(1024)
|
||||
origin := data
|
||||
defer pool.PutBuf(origin)
|
||||
|
||||
_, err = io.ReadFull(rd, data[:47])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
length, err := rd.Read(data[47:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
length += 47
|
||||
data = data[:length]
|
||||
if uint8(data[5]) != typeClientHello {
|
||||
err = fmt.Errorf("readHandshake: type[%d] is not clientHello", uint16(data[5]))
|
||||
return
|
||||
}
|
||||
|
||||
// session
|
||||
sessionIDLen := int(data[43])
|
||||
if sessionIDLen > 32 || len(data) < 44+sessionIDLen {
|
||||
err = fmt.Errorf("readHandshake: sessionIdLen[%d] is long", sessionIDLen)
|
||||
return
|
||||
}
|
||||
data = data[44+sessionIDLen:]
|
||||
if len(data) < 2 {
|
||||
err = fmt.Errorf("readHandshake: dataLen[%d] after session is short", len(data))
|
||||
return
|
||||
}
|
||||
|
||||
// cipher suite numbers
|
||||
cipherSuiteLen := int(data[0])<<8 | int(data[1])
|
||||
if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
|
||||
err = fmt.Errorf("readHandshake: dataLen[%d] after cipher suite is short", len(data))
|
||||
return
|
||||
}
|
||||
data = data[2+cipherSuiteLen:]
|
||||
if len(data) < 1 {
|
||||
err = fmt.Errorf("readHandshake: cipherSuiteLen[%d] is long", cipherSuiteLen)
|
||||
return
|
||||
}
|
||||
|
||||
// compression method
|
||||
compressionMethodsLen := int(data[0])
|
||||
if len(data) < 1+compressionMethodsLen {
|
||||
err = fmt.Errorf("readHandshake: compressionMethodsLen[%d] is long", compressionMethodsLen)
|
||||
return
|
||||
}
|
||||
|
||||
data = data[1+compressionMethodsLen:]
|
||||
if len(data) == 0 {
|
||||
// ClientHello is optionally followed by extension data
|
||||
err = fmt.Errorf("readHandshake: there is no extension data to get servername")
|
||||
return
|
||||
}
|
||||
if len(data) < 2 {
|
||||
err = fmt.Errorf("readHandshake: extension dataLen[%d] is too short", len(data))
|
||||
return
|
||||
}
|
||||
|
||||
extensionsLength := int(data[0])<<8 | int(data[1])
|
||||
data = data[2:]
|
||||
if extensionsLength != len(data) {
|
||||
err = fmt.Errorf("readHandshake: extensionsLen[%d] is not equal to dataLen[%d]", extensionsLength, len(data))
|
||||
return
|
||||
}
|
||||
for len(data) != 0 {
|
||||
if len(data) < 4 {
|
||||
err = fmt.Errorf("readHandshake: extensionsDataLen[%d] is too short", len(data))
|
||||
return
|
||||
}
|
||||
extension := uint16(data[0])<<8 | uint16(data[1])
|
||||
length := int(data[2])<<8 | int(data[3])
|
||||
data = data[4:]
|
||||
if len(data) < length {
|
||||
err = fmt.Errorf("readHandshake: extensionLen[%d] is long", length)
|
||||
return
|
||||
}
|
||||
|
||||
switch extension {
|
||||
case extensionRenegotiationInfo:
|
||||
if length != 1 || data[0] != 0 {
|
||||
err = fmt.Errorf("readHandshake: extension reNegotiationInfoLen[%d] is short", length)
|
||||
return
|
||||
}
|
||||
case extensionNextProtoNeg:
|
||||
case extensionStatusRequest:
|
||||
case extensionServerName:
|
||||
d := data[:length]
|
||||
if len(d) < 2 {
|
||||
err = fmt.Errorf("readHandshake: remiaining dataLen[%d] is short", len(d))
|
||||
return
|
||||
}
|
||||
namesLen := int(d[0])<<8 | int(d[1])
|
||||
d = d[2:]
|
||||
if len(d) != namesLen {
|
||||
err = fmt.Errorf("readHandshake: nameListLen[%d] is not equal to dataLen[%d]", namesLen, len(d))
|
||||
return
|
||||
}
|
||||
for len(d) > 0 {
|
||||
if len(d) < 3 {
|
||||
err = fmt.Errorf("readHandshake: extension serverNameLen[%d] is short", len(d))
|
||||
return
|
||||
}
|
||||
nameType := d[0]
|
||||
nameLen := int(d[1])<<8 | int(d[2])
|
||||
d = d[3:]
|
||||
if len(d) < nameLen {
|
||||
err = fmt.Errorf("readHandshake: nameLen[%d] is not equal to dataLen[%d]", nameLen, len(d))
|
||||
return
|
||||
}
|
||||
if nameType == 0 {
|
||||
serverName := string(d[:nameLen])
|
||||
host = strings.TrimSpace(serverName)
|
||||
return host, nil
|
||||
}
|
||||
d = d[nameLen:]
|
||||
}
|
||||
}
|
||||
data = data[length:]
|
||||
}
|
||||
err = fmt.Errorf("Unknown error")
|
||||
return
|
||||
}
|
||||
|
||||
func GetHTTPSHostname(c net.Conn) (_ net.Conn, _ map[string]string, err error) {
|
||||
reqInfoMap := make(map[string]string, 0)
|
||||
sc, rd := gnet.NewSharedConn(c)
|
||||
host, err := readHandshake(rd)
|
||||
if err != nil {
|
||||
return nil, reqInfoMap, err
|
||||
}
|
||||
reqInfoMap["Host"] = host
|
||||
reqInfoMap["Scheme"] = "https"
|
||||
return sc, reqInfoMap, nil
|
||||
}
|
||||
100
pkg/util/vhost/resource.go
Normal file
100
pkg/util/vhost/resource.go
Normal file
@@ -0,0 +1,100 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package vhost
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
frpLog "github.com/fatedier/frp/pkg/util/log"
|
||||
"github.com/fatedier/frp/pkg/util/version"
|
||||
)
|
||||
|
||||
var (
|
||||
NotFoundPagePath = ""
|
||||
)
|
||||
|
||||
const (
|
||||
NotFound = `<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Not Found</title>
|
||||
<style>
|
||||
body {
|
||||
width: 35em;
|
||||
margin: 0 auto;
|
||||
font-family: Tahoma, Verdana, Arial, sans-serif;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>The page you requested was not found.</h1>
|
||||
<p>Sorry, the page you are looking for is currently unavailable.<br/>
|
||||
Please try again later.</p>
|
||||
<p>The server is powered by <a href="https://github.com/fatedier/frp">frp</a>.</p>
|
||||
<p><em>Faithfully yours, frp.</em></p>
|
||||
</body>
|
||||
</html>
|
||||
`
|
||||
)
|
||||
|
||||
func getNotFoundPageContent() []byte {
|
||||
var (
|
||||
buf []byte
|
||||
err error
|
||||
)
|
||||
if NotFoundPagePath != "" {
|
||||
buf, err = ioutil.ReadFile(NotFoundPagePath)
|
||||
if err != nil {
|
||||
frpLog.Warn("read custom 404 page error: %v", err)
|
||||
buf = []byte(NotFound)
|
||||
}
|
||||
} else {
|
||||
buf = []byte(NotFound)
|
||||
}
|
||||
return buf
|
||||
}
|
||||
|
||||
func notFoundResponse() *http.Response {
|
||||
header := make(http.Header)
|
||||
header.Set("server", "frp/"+version.Full())
|
||||
header.Set("Content-Type", "text/html")
|
||||
|
||||
res := &http.Response{
|
||||
Status: "Not Found",
|
||||
StatusCode: 404,
|
||||
Proto: "HTTP/1.0",
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 0,
|
||||
Header: header,
|
||||
Body: ioutil.NopCloser(bytes.NewReader(getNotFoundPageContent())),
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func noAuthResponse() *http.Response {
|
||||
header := make(map[string][]string)
|
||||
header["WWW-Authenticate"] = []string{`Basic realm="Restricted"`}
|
||||
res := &http.Response{
|
||||
Status: "401 Not authorized",
|
||||
StatusCode: 401,
|
||||
Proto: "HTTP/1.1",
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Header: header,
|
||||
}
|
||||
return res
|
||||
}
|
||||
563
pkg/util/vhost/reverseproxy.go
Normal file
563
pkg/util/vhost/reverseproxy.go
Normal file
@@ -0,0 +1,563 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// HTTP reverse proxy handler
|
||||
|
||||
package vhost
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/http/httpguts"
|
||||
)
|
||||
|
||||
// ReverseProxy is an HTTP Handler that takes an incoming request and
|
||||
// sends it to another server, proxying the response back to the
|
||||
// client.
|
||||
type ReverseProxy struct {
|
||||
// Director must be a function which modifies
|
||||
// the request into a new request to be sent
|
||||
// using Transport. Its response is then copied
|
||||
// back to the original client unmodified.
|
||||
// Director must not access the provided Request
|
||||
// after returning.
|
||||
Director func(*http.Request)
|
||||
|
||||
// The transport used to perform proxy requests.
|
||||
// If nil, http.DefaultTransport is used.
|
||||
Transport http.RoundTripper
|
||||
|
||||
// FlushInterval specifies the flush interval
|
||||
// to flush to the client while copying the
|
||||
// response body.
|
||||
// If zero, no periodic flushing is done.
|
||||
// A negative value means to flush immediately
|
||||
// after each write to the client.
|
||||
// The FlushInterval is ignored when ReverseProxy
|
||||
// recognizes a response as a streaming response;
|
||||
// for such responses, writes are flushed to the client
|
||||
// immediately.
|
||||
FlushInterval time.Duration
|
||||
|
||||
// ErrorLog specifies an optional logger for errors
|
||||
// that occur when attempting to proxy the request.
|
||||
// If nil, logging is done via the log package's standard logger.
|
||||
ErrorLog *log.Logger
|
||||
|
||||
// BufferPool optionally specifies a buffer pool to
|
||||
// get byte slices for use by io.CopyBuffer when
|
||||
// copying HTTP response bodies.
|
||||
BufferPool BufferPool
|
||||
|
||||
// ModifyResponse is an optional function that modifies the
|
||||
// Response from the backend. It is called if the backend
|
||||
// returns a response at all, with any HTTP status code.
|
||||
// If the backend is unreachable, the optional ErrorHandler is
|
||||
// called without any call to ModifyResponse.
|
||||
//
|
||||
// If ModifyResponse returns an error, ErrorHandler is called
|
||||
// with its error value. If ErrorHandler is nil, its default
|
||||
// implementation is used.
|
||||
ModifyResponse func(*http.Response) error
|
||||
|
||||
// ErrorHandler is an optional function that handles errors
|
||||
// reaching the backend or errors from ModifyResponse.
|
||||
//
|
||||
// If nil, the default is to log the provided error and return
|
||||
// a 502 Status Bad Gateway response.
|
||||
ErrorHandler func(http.ResponseWriter, *http.Request, error)
|
||||
}
|
||||
|
||||
// A BufferPool is an interface for getting and returning temporary
|
||||
// byte slices for use by io.CopyBuffer.
|
||||
type BufferPool interface {
|
||||
Get() []byte
|
||||
Put([]byte)
|
||||
}
|
||||
|
||||
func singleJoiningSlash(a, b string) string {
|
||||
aslash := strings.HasSuffix(a, "/")
|
||||
bslash := strings.HasPrefix(b, "/")
|
||||
switch {
|
||||
case aslash && bslash:
|
||||
return a + b[1:]
|
||||
case !aslash && !bslash:
|
||||
return a + "/" + b
|
||||
}
|
||||
return a + b
|
||||
}
|
||||
|
||||
// NewSingleHostReverseProxy returns a new ReverseProxy that routes
|
||||
// URLs to the scheme, host, and base path provided in target. If the
|
||||
// target's path is "/base" and the incoming request was for "/dir",
|
||||
// the target request will be for /base/dir.
|
||||
// NewSingleHostReverseProxy does not rewrite the Host header.
|
||||
// To rewrite Host headers, use ReverseProxy directly with a custom
|
||||
// Director policy.
|
||||
func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
|
||||
targetQuery := target.RawQuery
|
||||
director := func(req *http.Request) {
|
||||
req.URL.Scheme = target.Scheme
|
||||
req.URL.Host = target.Host
|
||||
req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
|
||||
if targetQuery == "" || req.URL.RawQuery == "" {
|
||||
req.URL.RawQuery = targetQuery + req.URL.RawQuery
|
||||
} else {
|
||||
req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
|
||||
}
|
||||
if _, ok := req.Header["User-Agent"]; !ok {
|
||||
// explicitly disable User-Agent so it's not set to default value
|
||||
req.Header.Set("User-Agent", "")
|
||||
}
|
||||
}
|
||||
return &ReverseProxy{Director: director}
|
||||
}
|
||||
|
||||
func copyHeader(dst, src http.Header) {
|
||||
for k, vv := range src {
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Hop-by-hop headers. These are removed when sent to the backend.
|
||||
// As of RFC 7230, hop-by-hop headers are required to appear in the
|
||||
// Connection header field. These are the headers defined by the
|
||||
// obsoleted RFC 2616 (section 13.5.1) and are used for backward
|
||||
// compatibility.
|
||||
var hopHeaders = []string{
|
||||
"Connection",
|
||||
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
|
||||
"Keep-Alive",
|
||||
"Proxy-Authenticate",
|
||||
"Proxy-Authorization",
|
||||
"Te", // canonicalized version of "TE"
|
||||
"Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
|
||||
"Transfer-Encoding",
|
||||
"Upgrade",
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) {
|
||||
p.logf("http: proxy error: %v", err)
|
||||
rw.WriteHeader(http.StatusBadGateway)
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) {
|
||||
if p.ErrorHandler != nil {
|
||||
return p.ErrorHandler
|
||||
}
|
||||
return p.defaultErrorHandler
|
||||
}
|
||||
|
||||
// modifyResponse conditionally runs the optional ModifyResponse hook
|
||||
// and reports whether the request should proceed.
|
||||
func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool {
|
||||
if p.ModifyResponse == nil {
|
||||
return true
|
||||
}
|
||||
if err := p.ModifyResponse(res); err != nil {
|
||||
res.Body.Close()
|
||||
p.getErrorHandler()(rw, req, err)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
transport := p.Transport
|
||||
if transport == nil {
|
||||
transport = http.DefaultTransport
|
||||
}
|
||||
|
||||
ctx := req.Context()
|
||||
if cn, ok := rw.(http.CloseNotifier); ok {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
notifyChan := cn.CloseNotify()
|
||||
go func() {
|
||||
select {
|
||||
case <-notifyChan:
|
||||
cancel()
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
outreq := req.WithContext(ctx)
|
||||
if req.ContentLength == 0 {
|
||||
outreq.Body = nil // Issue 16036: nil Body for http.Transport retries
|
||||
}
|
||||
|
||||
// =============================
|
||||
// Modified for frp
|
||||
outreq = outreq.WithContext(context.WithValue(outreq.Context(), RouteInfoURL, req.URL.Path))
|
||||
outreq = outreq.WithContext(context.WithValue(outreq.Context(), RouteInfoHost, req.Host))
|
||||
outreq = outreq.WithContext(context.WithValue(outreq.Context(), RouteInfoRemote, req.RemoteAddr))
|
||||
// =============================
|
||||
|
||||
p.Director(outreq)
|
||||
outreq.Close = false
|
||||
|
||||
reqUpType := upgradeType(outreq.Header)
|
||||
removeConnectionHeaders(outreq.Header)
|
||||
|
||||
// Remove hop-by-hop headers to the backend. Especially
|
||||
// important is "Connection" because we want a persistent
|
||||
// connection, regardless of what the client sent to us.
|
||||
for _, h := range hopHeaders {
|
||||
hv := outreq.Header.Get(h)
|
||||
if hv == "" {
|
||||
continue
|
||||
}
|
||||
if h == "Te" && hv == "trailers" {
|
||||
// Issue 21096: tell backend applications that
|
||||
// care about trailer support that we support
|
||||
// trailers. (We do, but we don't go out of
|
||||
// our way to advertise that unless the
|
||||
// incoming client request thought it was
|
||||
// worth mentioning)
|
||||
continue
|
||||
}
|
||||
outreq.Header.Del(h)
|
||||
}
|
||||
|
||||
// After stripping all the hop-by-hop connection headers above, add back any
|
||||
// necessary for protocol upgrades, such as for websockets.
|
||||
if reqUpType != "" {
|
||||
outreq.Header.Set("Connection", "Upgrade")
|
||||
outreq.Header.Set("Upgrade", reqUpType)
|
||||
}
|
||||
|
||||
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
|
||||
// If we aren't the first proxy retain prior
|
||||
// X-Forwarded-For information as a comma+space
|
||||
// separated list and fold multiple headers into one.
|
||||
if prior, ok := outreq.Header["X-Forwarded-For"]; ok {
|
||||
clientIP = strings.Join(prior, ", ") + ", " + clientIP
|
||||
}
|
||||
outreq.Header.Set("X-Forwarded-For", clientIP)
|
||||
}
|
||||
|
||||
res, err := transport.RoundTrip(outreq)
|
||||
if err != nil {
|
||||
p.getErrorHandler()(rw, outreq, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
|
||||
if res.StatusCode == http.StatusSwitchingProtocols {
|
||||
if !p.modifyResponse(rw, res, outreq) {
|
||||
return
|
||||
}
|
||||
p.handleUpgradeResponse(rw, outreq, res)
|
||||
return
|
||||
}
|
||||
|
||||
removeConnectionHeaders(res.Header)
|
||||
|
||||
for _, h := range hopHeaders {
|
||||
res.Header.Del(h)
|
||||
}
|
||||
|
||||
if !p.modifyResponse(rw, res, outreq) {
|
||||
return
|
||||
}
|
||||
|
||||
copyHeader(rw.Header(), res.Header)
|
||||
|
||||
// The "Trailer" header isn't included in the Transport's response,
|
||||
// at least for *http.Transport. Build it up from Trailer.
|
||||
announcedTrailers := len(res.Trailer)
|
||||
if announcedTrailers > 0 {
|
||||
trailerKeys := make([]string, 0, len(res.Trailer))
|
||||
for k := range res.Trailer {
|
||||
trailerKeys = append(trailerKeys, k)
|
||||
}
|
||||
rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
|
||||
}
|
||||
|
||||
rw.WriteHeader(res.StatusCode)
|
||||
|
||||
err = p.copyResponse(rw, res.Body, p.flushInterval(req, res))
|
||||
if err != nil {
|
||||
defer res.Body.Close()
|
||||
// Since we're streaming the response, if we run into an error all we can do
|
||||
// is abort the request. Issue 23643: ReverseProxy should use ErrAbortHandler
|
||||
// on read error while copying body.
|
||||
if !shouldPanicOnCopyError(req) {
|
||||
p.logf("suppressing panic for copyResponse error in test; copy error: %v", err)
|
||||
return
|
||||
}
|
||||
panic(http.ErrAbortHandler)
|
||||
}
|
||||
res.Body.Close() // close now, instead of defer, to populate res.Trailer
|
||||
|
||||
if len(res.Trailer) > 0 {
|
||||
// Force chunking if we saw a response trailer.
|
||||
// This prevents net/http from calculating the length for short
|
||||
// bodies and adding a Content-Length.
|
||||
if fl, ok := rw.(http.Flusher); ok {
|
||||
fl.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
if len(res.Trailer) == announcedTrailers {
|
||||
copyHeader(rw.Header(), res.Trailer)
|
||||
return
|
||||
}
|
||||
|
||||
for k, vv := range res.Trailer {
|
||||
k = http.TrailerPrefix + k
|
||||
for _, v := range vv {
|
||||
rw.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var inOurTests bool // whether we're in our own tests
|
||||
|
||||
// shouldPanicOnCopyError reports whether the reverse proxy should
|
||||
// panic with http.ErrAbortHandler. This is the right thing to do by
|
||||
// default, but Go 1.10 and earlier did not, so existing unit tests
|
||||
// weren't expecting panics. Only panic in our own tests, or when
|
||||
// running under the HTTP server.
|
||||
func shouldPanicOnCopyError(req *http.Request) bool {
|
||||
if inOurTests {
|
||||
// Our tests know to handle this panic.
|
||||
return true
|
||||
}
|
||||
if req.Context().Value(http.ServerContextKey) != nil {
|
||||
// We seem to be running under an HTTP server, so
|
||||
// it'll recover the panic.
|
||||
return true
|
||||
}
|
||||
// Otherwise act like Go 1.10 and earlier to not break
|
||||
// existing tests.
|
||||
return false
|
||||
}
|
||||
|
||||
// removeConnectionHeaders removes hop-by-hop headers listed in the "Connection" header of h.
|
||||
// See RFC 7230, section 6.1
|
||||
func removeConnectionHeaders(h http.Header) {
|
||||
for _, f := range h["Connection"] {
|
||||
for _, sf := range strings.Split(f, ",") {
|
||||
if sf = strings.TrimSpace(sf); sf != "" {
|
||||
h.Del(sf)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// flushInterval returns the p.FlushInterval value, conditionally
|
||||
// overriding its value for a specific request/response.
|
||||
func (p *ReverseProxy) flushInterval(req *http.Request, res *http.Response) time.Duration {
|
||||
resCT := res.Header.Get("Content-Type")
|
||||
|
||||
// For Server-Sent Events responses, flush immediately.
|
||||
// The MIME type is defined in https://www.w3.org/TR/eventsource/#text-event-stream
|
||||
if resCT == "text/event-stream" {
|
||||
return -1 // negative means immediately
|
||||
}
|
||||
|
||||
// TODO: more specific cases? e.g. res.ContentLength == -1?
|
||||
return p.FlushInterval
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error {
|
||||
if flushInterval != 0 {
|
||||
if wf, ok := dst.(writeFlusher); ok {
|
||||
mlw := &maxLatencyWriter{
|
||||
dst: wf,
|
||||
latency: flushInterval,
|
||||
}
|
||||
defer mlw.stop()
|
||||
|
||||
// set up initial timer so headers get flushed even if body writes are delayed
|
||||
mlw.flushPending = true
|
||||
mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
|
||||
|
||||
dst = mlw
|
||||
}
|
||||
}
|
||||
|
||||
var buf []byte
|
||||
if p.BufferPool != nil {
|
||||
buf = p.BufferPool.Get()
|
||||
defer p.BufferPool.Put(buf)
|
||||
}
|
||||
_, err := p.copyBuffer(dst, src, buf)
|
||||
return err
|
||||
}
|
||||
|
||||
// copyBuffer returns any write errors or non-EOF read errors, and the amount
|
||||
// of bytes written.
|
||||
func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
|
||||
if len(buf) == 0 {
|
||||
buf = make([]byte, 32*1024)
|
||||
}
|
||||
var written int64
|
||||
for {
|
||||
nr, rerr := src.Read(buf)
|
||||
if rerr != nil && rerr != io.EOF && rerr != context.Canceled {
|
||||
p.logf("httputil: ReverseProxy read error during body copy: %v", rerr)
|
||||
}
|
||||
if nr > 0 {
|
||||
nw, werr := dst.Write(buf[:nr])
|
||||
if nw > 0 {
|
||||
written += int64(nw)
|
||||
}
|
||||
if werr != nil {
|
||||
return written, werr
|
||||
}
|
||||
if nr != nw {
|
||||
return written, io.ErrShortWrite
|
||||
}
|
||||
}
|
||||
if rerr != nil {
|
||||
if rerr == io.EOF {
|
||||
rerr = nil
|
||||
}
|
||||
return written, rerr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) logf(format string, args ...interface{}) {
|
||||
if p.ErrorLog != nil {
|
||||
p.ErrorLog.Printf(format, args...)
|
||||
} else {
|
||||
log.Printf(format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
type writeFlusher interface {
|
||||
io.Writer
|
||||
http.Flusher
|
||||
}
|
||||
|
||||
type maxLatencyWriter struct {
|
||||
dst writeFlusher
|
||||
latency time.Duration // non-zero; negative means to flush immediately
|
||||
|
||||
mu sync.Mutex // protects t, flushPending, and dst.Flush
|
||||
t *time.Timer
|
||||
flushPending bool
|
||||
}
|
||||
|
||||
func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
n, err = m.dst.Write(p)
|
||||
if m.latency < 0 {
|
||||
m.dst.Flush()
|
||||
return
|
||||
}
|
||||
if m.flushPending {
|
||||
return
|
||||
}
|
||||
if m.t == nil {
|
||||
m.t = time.AfterFunc(m.latency, m.delayedFlush)
|
||||
} else {
|
||||
m.t.Reset(m.latency)
|
||||
}
|
||||
m.flushPending = true
|
||||
return
|
||||
}
|
||||
|
||||
func (m *maxLatencyWriter) delayedFlush() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if !m.flushPending { // if stop was called but AfterFunc already started this goroutine
|
||||
return
|
||||
}
|
||||
m.dst.Flush()
|
||||
m.flushPending = false
|
||||
}
|
||||
|
||||
func (m *maxLatencyWriter) stop() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.flushPending = false
|
||||
if m.t != nil {
|
||||
m.t.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func upgradeType(h http.Header) string {
|
||||
if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
|
||||
return ""
|
||||
}
|
||||
return strings.ToLower(h.Get("Upgrade"))
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
|
||||
reqUpType := upgradeType(req.Header)
|
||||
resUpType := upgradeType(res.Header)
|
||||
if reqUpType != resUpType {
|
||||
p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
|
||||
return
|
||||
}
|
||||
|
||||
copyHeader(res.Header, rw.Header())
|
||||
|
||||
hj, ok := rw.(http.Hijacker)
|
||||
if !ok {
|
||||
p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
|
||||
return
|
||||
}
|
||||
backConn, ok := res.Body.(io.ReadWriteCloser)
|
||||
if !ok {
|
||||
p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
|
||||
return
|
||||
}
|
||||
defer backConn.Close()
|
||||
conn, brw, err := hj.Hijack()
|
||||
if err != nil {
|
||||
p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err))
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
|
||||
if err := res.Write(brw); err != nil {
|
||||
p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
|
||||
return
|
||||
}
|
||||
if err := brw.Flush(); err != nil {
|
||||
p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
|
||||
return
|
||||
}
|
||||
errc := make(chan error, 1)
|
||||
spc := switchProtocolCopier{user: conn, backend: backConn}
|
||||
go spc.copyToBackend(errc)
|
||||
go spc.copyFromBackend(errc)
|
||||
<-errc
|
||||
return
|
||||
}
|
||||
|
||||
// switchProtocolCopier exists so goroutines proxying data back and
|
||||
// forth have nice names in stacks.
|
||||
type switchProtocolCopier struct {
|
||||
user, backend io.ReadWriter
|
||||
}
|
||||
|
||||
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
|
||||
_, err := io.Copy(c.user, c.backend)
|
||||
errc <- err
|
||||
}
|
||||
|
||||
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
|
||||
_, err := io.Copy(c.backend, c.user)
|
||||
errc <- err
|
||||
}
|
||||
119
pkg/util/vhost/router.go
Normal file
119
pkg/util/vhost/router.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package vhost
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrRouterConfigConflict = errors.New("router config conflict")
|
||||
)
|
||||
|
||||
type Routers struct {
|
||||
RouterByDomain map[string][]*Router
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
type Router struct {
|
||||
domain string
|
||||
location string
|
||||
|
||||
payload interface{}
|
||||
}
|
||||
|
||||
func NewRouters() *Routers {
|
||||
return &Routers{
|
||||
RouterByDomain: make(map[string][]*Router),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Routers) Add(domain, location string, payload interface{}) error {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
if _, exist := r.exist(domain, location); exist {
|
||||
return ErrRouterConfigConflict
|
||||
}
|
||||
|
||||
vrs, found := r.RouterByDomain[domain]
|
||||
if !found {
|
||||
vrs = make([]*Router, 0, 1)
|
||||
}
|
||||
|
||||
vr := &Router{
|
||||
domain: domain,
|
||||
location: location,
|
||||
payload: payload,
|
||||
}
|
||||
vrs = append(vrs, vr)
|
||||
|
||||
sort.Sort(sort.Reverse(ByLocation(vrs)))
|
||||
r.RouterByDomain[domain] = vrs
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Routers) Del(domain, location string) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
vrs, found := r.RouterByDomain[domain]
|
||||
if !found {
|
||||
return
|
||||
}
|
||||
newVrs := make([]*Router, 0)
|
||||
for _, vr := range vrs {
|
||||
if vr.location != location {
|
||||
newVrs = append(newVrs, vr)
|
||||
}
|
||||
}
|
||||
r.RouterByDomain[domain] = newVrs
|
||||
}
|
||||
|
||||
func (r *Routers) Get(host, path string) (vr *Router, exist bool) {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
|
||||
vrs, found := r.RouterByDomain[host]
|
||||
if !found {
|
||||
return
|
||||
}
|
||||
|
||||
// can't support load balance, will to do
|
||||
for _, vr = range vrs {
|
||||
if strings.HasPrefix(path, vr.location) {
|
||||
return vr, true
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (r *Routers) exist(host, path string) (vr *Router, exist bool) {
|
||||
vrs, found := r.RouterByDomain[host]
|
||||
if !found {
|
||||
return
|
||||
}
|
||||
|
||||
for _, vr = range vrs {
|
||||
if path == vr.location {
|
||||
return vr, true
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// sort by location
|
||||
type ByLocation []*Router
|
||||
|
||||
func (a ByLocation) Len() int {
|
||||
return len(a)
|
||||
}
|
||||
func (a ByLocation) Swap(i, j int) {
|
||||
a[i], a[j] = a[j], a[i]
|
||||
}
|
||||
func (a ByLocation) Less(i, j int) bool {
|
||||
return strings.Compare(a[i].location, a[j].location) < 0
|
||||
}
|
||||
245
pkg/util/vhost/vhost.go
Normal file
245
pkg/util/vhost/vhost.go
Normal file
@@ -0,0 +1,245 @@
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package vhost
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/frp/pkg/util/log"
|
||||
frpNet "github.com/fatedier/frp/pkg/util/net"
|
||||
"github.com/fatedier/frp/pkg/util/xlog"
|
||||
|
||||
"github.com/fatedier/golib/errors"
|
||||
)
|
||||
|
||||
type RouteInfo string
|
||||
|
||||
const (
|
||||
RouteInfoURL RouteInfo = "url"
|
||||
RouteInfoHost RouteInfo = "host"
|
||||
RouteInfoRemote RouteInfo = "remote"
|
||||
)
|
||||
|
||||
type muxFunc func(net.Conn) (net.Conn, map[string]string, error)
|
||||
type httpAuthFunc func(net.Conn, string, string, string) (bool, error)
|
||||
type hostRewriteFunc func(net.Conn, string) (net.Conn, error)
|
||||
type successFunc func(net.Conn) error
|
||||
|
||||
type Muxer struct {
|
||||
listener net.Listener
|
||||
timeout time.Duration
|
||||
vhostFunc muxFunc
|
||||
authFunc httpAuthFunc
|
||||
successFunc successFunc
|
||||
rewriteFunc hostRewriteFunc
|
||||
registryRouter *Routers
|
||||
}
|
||||
|
||||
func NewMuxer(listener net.Listener, vhostFunc muxFunc, authFunc httpAuthFunc, successFunc successFunc, rewriteFunc hostRewriteFunc, timeout time.Duration) (mux *Muxer, err error) {
|
||||
mux = &Muxer{
|
||||
listener: listener,
|
||||
timeout: timeout,
|
||||
vhostFunc: vhostFunc,
|
||||
authFunc: authFunc,
|
||||
successFunc: successFunc,
|
||||
rewriteFunc: rewriteFunc,
|
||||
registryRouter: NewRouters(),
|
||||
}
|
||||
go mux.run()
|
||||
return mux, nil
|
||||
}
|
||||
|
||||
type CreateConnFunc func(remoteAddr string) (net.Conn, error)
|
||||
|
||||
// RouteConfig is the params used to match HTTP requests
|
||||
type RouteConfig struct {
|
||||
Domain string
|
||||
Location string
|
||||
RewriteHost string
|
||||
Username string
|
||||
Password string
|
||||
Headers map[string]string
|
||||
|
||||
CreateConnFn CreateConnFunc
|
||||
}
|
||||
|
||||
// listen for a new domain name, if rewriteHost is not empty and rewriteFunc is not nil
|
||||
// then rewrite the host header to rewriteHost
|
||||
func (v *Muxer) Listen(ctx context.Context, cfg *RouteConfig) (l *Listener, err error) {
|
||||
l = &Listener{
|
||||
name: cfg.Domain,
|
||||
location: cfg.Location,
|
||||
rewriteHost: cfg.RewriteHost,
|
||||
userName: cfg.Username,
|
||||
passWord: cfg.Password,
|
||||
mux: v,
|
||||
accept: make(chan net.Conn),
|
||||
ctx: ctx,
|
||||
}
|
||||
err = v.registryRouter.Add(cfg.Domain, cfg.Location, l)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return l, nil
|
||||
}
|
||||
|
||||
func (v *Muxer) getListener(name, path string) (l *Listener, exist bool) {
|
||||
// first we check the full hostname
|
||||
// if not exist, then check the wildcard_domain such as *.example.com
|
||||
vr, found := v.registryRouter.Get(name, path)
|
||||
if found {
|
||||
return vr.payload.(*Listener), true
|
||||
}
|
||||
|
||||
domainSplit := strings.Split(name, ".")
|
||||
if len(domainSplit) < 3 {
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
if len(domainSplit) < 3 {
|
||||
return
|
||||
}
|
||||
|
||||
domainSplit[0] = "*"
|
||||
name = strings.Join(domainSplit, ".")
|
||||
|
||||
vr, found = v.registryRouter.Get(name, path)
|
||||
if found {
|
||||
return vr.payload.(*Listener), true
|
||||
}
|
||||
domainSplit = domainSplit[1:]
|
||||
}
|
||||
}
|
||||
|
||||
func (v *Muxer) run() {
|
||||
for {
|
||||
conn, err := v.listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go v.handle(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (v *Muxer) handle(c net.Conn) {
|
||||
if err := c.SetDeadline(time.Now().Add(v.timeout)); err != nil {
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
|
||||
sConn, reqInfoMap, err := v.vhostFunc(c)
|
||||
if err != nil {
|
||||
log.Warn("get hostname from http/https request error: %v", err)
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
|
||||
name := strings.ToLower(reqInfoMap["Host"])
|
||||
path := strings.ToLower(reqInfoMap["Path"])
|
||||
l, ok := v.getListener(name, path)
|
||||
if !ok {
|
||||
res := notFoundResponse()
|
||||
res.Write(c)
|
||||
log.Debug("http request for host [%s] path [%s] not found", name, path)
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
|
||||
xl := xlog.FromContextSafe(l.ctx)
|
||||
if v.successFunc != nil {
|
||||
if err := v.successFunc(c); err != nil {
|
||||
xl.Info("success func failure on vhost connection: %v", err)
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// if authFunc is exist and userName/password is set
|
||||
// then verify user access
|
||||
if l.mux.authFunc != nil && l.userName != "" && l.passWord != "" {
|
||||
bAccess, err := l.mux.authFunc(c, l.userName, l.passWord, reqInfoMap["Authorization"])
|
||||
if bAccess == false || err != nil {
|
||||
xl.Debug("check http Authorization failed")
|
||||
res := noAuthResponse()
|
||||
res.Write(c)
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err = sConn.SetDeadline(time.Time{}); err != nil {
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
c = sConn
|
||||
|
||||
xl.Debug("get new http request host [%s] path [%s]", name, path)
|
||||
err = errors.PanicToError(func() {
|
||||
l.accept <- c
|
||||
})
|
||||
if err != nil {
|
||||
xl.Warn("listener is already closed, ignore this request")
|
||||
}
|
||||
}
|
||||
|
||||
type Listener struct {
|
||||
name string
|
||||
location string
|
||||
rewriteHost string
|
||||
userName string
|
||||
passWord string
|
||||
mux *Muxer // for closing Muxer
|
||||
accept chan net.Conn
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (l *Listener) Accept() (net.Conn, error) {
|
||||
xl := xlog.FromContextSafe(l.ctx)
|
||||
conn, ok := <-l.accept
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Listener closed")
|
||||
}
|
||||
|
||||
// if rewriteFunc is exist
|
||||
// rewrite http requests with a modified host header
|
||||
// if l.rewriteHost is empty, nothing to do
|
||||
if l.mux.rewriteFunc != nil {
|
||||
sConn, err := l.mux.rewriteFunc(conn, l.rewriteHost)
|
||||
if err != nil {
|
||||
xl.Warn("host header rewrite failed: %v", err)
|
||||
return nil, fmt.Errorf("host header rewrite failed")
|
||||
}
|
||||
xl.Debug("rewrite host to [%s] success", l.rewriteHost)
|
||||
conn = sConn
|
||||
}
|
||||
return frpNet.NewContextConn(l.ctx, conn), nil
|
||||
}
|
||||
|
||||
func (l *Listener) Close() error {
|
||||
l.mux.registryRouter.Del(l.name, l.location)
|
||||
close(l.accept)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Listener) Name() string {
|
||||
return l.name
|
||||
}
|
||||
|
||||
func (l *Listener) Addr() net.Addr {
|
||||
return (*net.TCPAddr)(nil)
|
||||
}
|
||||
42
pkg/util/xlog/ctx.go
Normal file
42
pkg/util/xlog/ctx.go
Normal file
@@ -0,0 +1,42 @@
|
||||
// Copyright 2019 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package xlog
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type key int
|
||||
|
||||
const (
|
||||
xlogKey key = 0
|
||||
)
|
||||
|
||||
func NewContext(ctx context.Context, xl *Logger) context.Context {
|
||||
return context.WithValue(ctx, xlogKey, xl)
|
||||
}
|
||||
|
||||
func FromContext(ctx context.Context) (xl *Logger, ok bool) {
|
||||
xl, ok = ctx.Value(xlogKey).(*Logger)
|
||||
return
|
||||
}
|
||||
|
||||
func FromContextSafe(ctx context.Context) *Logger {
|
||||
xl, ok := ctx.Value(xlogKey).(*Logger)
|
||||
if !ok {
|
||||
xl = New()
|
||||
}
|
||||
return xl
|
||||
}
|
||||
73
pkg/util/xlog/xlog.go
Normal file
73
pkg/util/xlog/xlog.go
Normal file
@@ -0,0 +1,73 @@
|
||||
// Copyright 2019 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package xlog
|
||||
|
||||
import (
|
||||
"github.com/fatedier/frp/pkg/util/log"
|
||||
)
|
||||
|
||||
// Logger is not thread safety for operations on prefix
|
||||
type Logger struct {
|
||||
prefixes []string
|
||||
|
||||
prefixString string
|
||||
}
|
||||
|
||||
func New() *Logger {
|
||||
return &Logger{
|
||||
prefixes: make([]string, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) ResetPrefixes() (old []string) {
|
||||
old = l.prefixes
|
||||
l.prefixes = make([]string, 0)
|
||||
l.prefixString = ""
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Logger) AppendPrefix(prefix string) *Logger {
|
||||
l.prefixes = append(l.prefixes, prefix)
|
||||
l.prefixString += "[" + prefix + "] "
|
||||
return l
|
||||
}
|
||||
|
||||
func (l *Logger) Spawn() *Logger {
|
||||
nl := New()
|
||||
for _, v := range l.prefixes {
|
||||
nl.AppendPrefix(v)
|
||||
}
|
||||
return nl
|
||||
}
|
||||
|
||||
func (l *Logger) Error(format string, v ...interface{}) {
|
||||
log.Log.Error(l.prefixString+format, v...)
|
||||
}
|
||||
|
||||
func (l *Logger) Warn(format string, v ...interface{}) {
|
||||
log.Log.Warn(l.prefixString+format, v...)
|
||||
}
|
||||
|
||||
func (l *Logger) Info(format string, v ...interface{}) {
|
||||
log.Log.Info(l.prefixString+format, v...)
|
||||
}
|
||||
|
||||
func (l *Logger) Debug(format string, v ...interface{}) {
|
||||
log.Log.Debug(l.prefixString+format, v...)
|
||||
}
|
||||
|
||||
func (l *Logger) Trace(format string, v ...interface{}) {
|
||||
log.Log.Trace(l.prefixString+format, v...)
|
||||
}
|
||||
Reference in New Issue
Block a user