start refactoring

This commit is contained in:
fatedier
2017-03-09 02:03:47 +08:00
parent b006540141
commit 88083d21e8
115 changed files with 5515 additions and 3491 deletions

View File

@@ -0,0 +1,52 @@
// Copyright 2017 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package crypto
import (
"bytes"
"io"
"testing"
"github.com/stretchr/testify/assert"
)
func TestWriter(t *testing.T) {
// Empty key.
assert := assert.New(t)
key := ""
buffer := bytes.NewBuffer(nil)
_, err := NewWriter(buffer, []byte(key))
assert.NoError(err)
}
func TestCrypto(t *testing.T) {
assert := assert.New(t)
text := "1234567890abcdefghigklmnopqrstuvwxyzeeeeeeeeeeeeeeeeeeeeeewwwwwwwwwwwwwwwwwwwwwwwwwwzzzzzzzzzzzzzzzzzzzzzzzzdddddddddddddddddddddddddddddddddddddrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrllllllllllllllllllllllllllllllllllqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeewwwwwwwwwwwwwwwwwwwwww"
key := "123456"
buffer := bytes.NewBuffer(nil)
encWriter, err := NewWriter(buffer, []byte(key))
assert.NoError(err)
encWriter.Write([]byte(text))
decReader, err := NewReader(buffer, []byte(key))
assert.NoError(err)
c := bytes.NewBuffer(nil)
io.Copy(c, decReader)
assert.Equal(text, string(c.Bytes()))
}

75
utils/crypto/decode.go Normal file
View File

@@ -0,0 +1,75 @@
// Copyright 2017 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package crypto
import (
"crypto/aes"
"crypto/cipher"
"crypto/sha1"
"io"
"golang.org/x/crypto/pbkdf2"
)
// NewReader returns a new Reader that decrypts bytes from r
func NewReader(r io.Reader, key []byte) (*Reader, error) {
key = pbkdf2.Key(key, []byte(salt), 64, aes.BlockSize, sha1.New)
return &Reader{
r: r,
key: key,
}, nil
}
// Reader is an io.Reader that can read encrypted bytes.
// Now it only supports aes-128-cfb.
type Reader struct {
r io.Reader
dec *cipher.StreamReader
key []byte
iv []byte
err error
}
// Read satisfies the io.Reader interface.
func (r *Reader) Read(p []byte) (nRet int, errRet error) {
if r.err != nil {
return 0, r.err
}
if r.dec == nil {
iv := make([]byte, aes.BlockSize)
if _, errRet = io.ReadFull(r.r, iv); errRet != nil {
return
}
r.iv = iv
block, err := aes.NewCipher(r.key)
if err != nil {
errRet = err
return
}
r.dec = &cipher.StreamReader{
S: cipher.NewCFBDecrypter(block, iv),
R: r.r,
}
}
nRet, errRet = r.dec.Read(p)
if errRet != nil {
r.err = errRet
}
return
}

93
utils/crypto/encode.go Normal file
View File

@@ -0,0 +1,93 @@
// Copyright 2017 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package crypto
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha1"
"io"
"golang.org/x/crypto/pbkdf2"
)
const (
salt = "frp"
)
// NewWriter returns a new Writer that encrypts bytes to w.
func NewWriter(w io.Writer, key []byte) (*Writer, error) {
key = pbkdf2.Key(key, []byte(salt), 64, aes.BlockSize, sha1.New)
// random iv
iv := make([]byte, aes.BlockSize)
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
return nil, err
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return &Writer{
w: w,
enc: &cipher.StreamWriter{
S: cipher.NewCFBEncrypter(block, iv),
W: w,
},
key: key,
iv: iv,
}, nil
}
// Writer is an io.Writer that can write encrypted bytes.
// Now it only support aes-128-cfb.
type Writer struct {
w io.Writer
enc *cipher.StreamWriter
key []byte
iv []byte
ivSend bool
err error
}
// Write satisfies the io.Writer interface.
func (w *Writer) Write(p []byte) (nRet int, errRet error) {
return w.write(p)
}
func (w *Writer) write(p []byte) (nRet int, errRet error) {
if w.err != nil {
return 0, w.err
}
// When write is first called, iv will be written to w.w
if !w.ivSend {
w.ivSend = true
_, errRet = w.w.Write(w.iv)
if errRet != nil {
w.err = errRet
return
}
}
nRet, errRet = w.enc.Write(p)
if errRet != nil {
w.err = errRet
}
return
}

35
utils/errors/errors.go Normal file
View File

@@ -0,0 +1,35 @@
// Copyright 2016 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package errors
import (
"errors"
"fmt"
)
var (
ErrMsgType = errors.New("message type error")
)
func PanicToError(fn func()) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("Panic error: %v", r)
}
}()
fn()
return
}

131
utils/log/log.go Normal file
View File

@@ -0,0 +1,131 @@
// Copyright 2016 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package log
import (
"fmt"
"github.com/fatedier/beego/logs"
)
var Log *logs.BeeLogger
func init() {
Log = logs.NewLogger(100)
Log.EnableFuncCallDepth(true)
Log.SetLogFuncCallDepth(Log.GetLogFuncCallDepth() + 1)
}
func InitLog(logWay string, logFile string, logLevel string, maxdays int64) {
SetLogFile(logWay, logFile, maxdays)
SetLogLevel(logLevel)
}
// logWay: file or console
func SetLogFile(logWay string, logFile string, maxdays int64) {
if logWay == "console" {
Log.SetLogger("console", "")
} else {
params := fmt.Sprintf(`{"filename": "%s", "maxdays": %d}`, logFile, maxdays)
Log.SetLogger("file", params)
}
}
// value: error, warning, info, debug
func SetLogLevel(logLevel string) {
level := 4 // warning
switch logLevel {
case "error":
level = 3
case "warn":
level = 4
case "info":
level = 6
case "debug":
level = 7
default:
level = 4
}
Log.SetLevel(level)
}
// wrap log
func Error(format string, v ...interface{}) {
Log.Error(format, v...)
}
func Warn(format string, v ...interface{}) {
Log.Warn(format, v...)
}
func Info(format string, v ...interface{}) {
Log.Info(format, v...)
}
func Debug(format string, v ...interface{}) {
Log.Debug(format, v...)
}
// Logger
type Logger interface {
AddLogPrefix(string)
ClearLogPrefix()
Error(string, ...interface{})
Warn(string, ...interface{})
Info(string, ...interface{})
Debug(string, ...interface{})
}
type PrefixLogger struct {
prefix string
}
func NewPrefixLogger(prefix string) *PrefixLogger {
logger := &PrefixLogger{}
logger.AddLogPrefix(prefix)
return logger
}
func (pl *PrefixLogger) AddLogPrefix(prefix string) {
if len(prefix) == 0 {
return
}
if len(pl.prefix) > 0 {
pl.prefix += " "
}
pl.prefix += "[" + prefix + "] "
}
func (pl *PrefixLogger) ClearLogPrefix() {
pl.prefix = ""
}
func (pl *PrefixLogger) Error(format string, v ...interface{}) {
Log.Error(pl.prefix+format, v...)
}
func (pl *PrefixLogger) Warn(format string, v ...interface{}) {
Log.Warn(pl.prefix+format, v...)
}
func (pl *PrefixLogger) Info(format string, v ...interface{}) {
Log.Info(pl.prefix+format, v...)
}
func (pl *PrefixLogger) Debug(format string, v ...interface{}) {
Log.Debug(pl.prefix+format, v...)
}

32
utils/net/conn.go Normal file
View File

@@ -0,0 +1,32 @@
// Copyright 2016 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package net
import (
"net"
"github.com/fatedier/frp/utils/log"
)
// Conn is the interface of connections used in frp.
type Conn interface {
net.Conn
log.Logger
}
type Listener interface {
Accept() (Conn, error)
Close() error
}

162
utils/net/tcp.go Normal file
View File

@@ -0,0 +1,162 @@
// Copyright 2016 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package net
import (
"bufio"
"encoding/base64"
"fmt"
"net"
"net/http"
"net/url"
"github.com/fatedier/frp/utils/log"
)
type TcpListener struct {
net.Addr
listener net.Listener
accept chan Conn
closeFlag bool
}
func ListenTcp(bindAddr string, bindPort int64) (l *TcpListener, err error) {
tcpAddr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
if err != nil {
return l, err
}
listener, err := net.ListenTCP("tcp", tcpAddr)
if err != nil {
return l, err
}
l = &TcpListener{
Addr: listener.Addr(),
listener: listener,
accept: make(chan Conn),
closeFlag: false,
}
go func() {
for {
conn, err := listener.AcceptTCP()
if err != nil {
if l.closeFlag {
close(l.accept)
return
}
continue
}
c := NewTcpConn(conn)
l.accept <- c
}
}()
return l, err
}
// Wait util get one new connection or listener is closed
// if listener is closed, err returned.
func (l *TcpListener) Accept() (Conn, error) {
conn, ok := <-l.accept
if !ok {
return conn, fmt.Errorf("channel for tcp listener closed")
}
return conn, nil
}
func (l *TcpListener) Close() error {
if !l.closeFlag {
l.closeFlag = true
l.listener.Close()
}
return nil
}
// Wrap for TCPConn.
type TcpConn struct {
net.Conn
log.Logger
}
func NewTcpConn(conn *net.TCPConn) (c *TcpConn) {
c = &TcpConn{
Conn: conn,
Logger: log.NewPrefixLogger(""),
}
return
}
func ConnectTcpServer(addr string) (c Conn, err error) {
servertAddr, err := net.ResolveTCPAddr("tcp", addr)
if err != nil {
return
}
conn, err := net.DialTCP("tcp", nil, servertAddr)
if err != nil {
return
}
c = NewTcpConn(conn)
return
}
// ConnectTcpServerByHttpProxy try to connect remote server by http proxy.
// If httpProxy is empty, it will connect server directly.
func ConnectTcpServerByHttpProxy(httpProxy string, serverAddr string) (c Conn, err error) {
if httpProxy == "" {
return ConnectTcpServer(serverAddr)
}
var proxyUrl *url.URL
if proxyUrl, err = url.Parse(httpProxy); err != nil {
return
}
var proxyAuth string
if proxyUrl.User != nil {
proxyAuth = "Basic " + base64.StdEncoding.EncodeToString([]byte(proxyUrl.User.String()))
}
if proxyUrl.Scheme != "http" {
err = fmt.Errorf("Proxy URL scheme must be http, not [%s]", proxyUrl.Scheme)
return
}
if c, err = ConnectTcpServer(proxyUrl.Host); err != nil {
return
}
req, err := http.NewRequest("CONNECT", "http://"+serverAddr, nil)
if err != nil {
return
}
if proxyAuth != "" {
req.Header.Set("Proxy-Authorization", proxyAuth)
}
req.Header.Set("User-Agent", "Mozilla/5.0")
req.Write(c)
resp, err := http.ReadResponse(bufio.NewReader(c), req)
if err != nil {
return
}
resp.Body.Close()
if resp.StatusCode != 200 {
err = fmt.Errorf("ConnectTcpServer using proxy error, StatusCode [%d]", resp.StatusCode)
return
}
return
}

243
utils/net/udp.go Normal file
View File

@@ -0,0 +1,243 @@
// Copyright 2017 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package net
import (
"fmt"
"io"
"net"
"sync"
"time"
flog "github.com/fatedier/frp/utils/log"
"github.com/fatedier/frp/utils/pool"
)
type UdpPacket struct {
Buf []byte
LocalAddr net.Addr
RemoteAddr net.Addr
}
type FakeUdpConn struct {
flog.Logger
l *UdpListener
localAddr net.Addr
remoteAddr net.Addr
packets chan []byte
closeFlag bool
lastActive time.Time
mu sync.RWMutex
}
func NewFakeUdpConn(l *UdpListener, laddr, raddr net.Addr) *FakeUdpConn {
fc := &FakeUdpConn{
Logger: flog.NewPrefixLogger(""),
l: l,
localAddr: laddr,
remoteAddr: raddr,
packets: make(chan []byte, 20),
}
go func() {
for {
time.Sleep(5 * time.Second)
fc.mu.RLock()
if time.Now().Sub(fc.lastActive) > 10*time.Second {
fc.mu.RUnlock()
fc.Close()
break
}
fc.mu.RUnlock()
}
}()
return fc
}
func (c *FakeUdpConn) putPacket(content []byte) {
defer func() {
if err := recover(); err != nil {
}
}()
select {
case c.packets <- content:
default:
}
}
func (c *FakeUdpConn) Read(b []byte) (n int, err error) {
content, ok := <-c.packets
if !ok {
return 0, io.EOF
}
c.mu.Lock()
c.lastActive = time.Now()
c.mu.Unlock()
if len(b) < len(content) {
n = len(b)
} else {
n = len(content)
}
copy(b, content)
return n, nil
}
func (c *FakeUdpConn) Write(b []byte) (n int, err error) {
c.mu.RLock()
if c.closeFlag {
c.mu.RUnlock()
return 0, io.ErrClosedPipe
}
c.mu.RUnlock()
packet := &UdpPacket{
Buf: b,
LocalAddr: c.localAddr,
RemoteAddr: c.remoteAddr,
}
c.l.writeUdpPacket(packet)
c.mu.Lock()
c.lastActive = time.Now()
c.mu.Unlock()
return len(b), nil
}
func (c *FakeUdpConn) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if !c.closeFlag {
c.closeFlag = true
close(c.packets)
}
return nil
}
func (c *FakeUdpConn) IsClosed() bool {
c.mu.RLock()
defer c.mu.RUnlock()
return c.closeFlag
}
func (c *FakeUdpConn) LocalAddr() net.Addr {
return c.localAddr
}
func (c *FakeUdpConn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func (c *FakeUdpConn) SetDeadline(t time.Time) error {
return nil
}
func (c *FakeUdpConn) SetReadDeadline(t time.Time) error {
return nil
}
func (c *FakeUdpConn) SetWriteDeadline(t time.Time) error {
return nil
}
type UdpListener struct {
net.Addr
accept chan Conn
writeCh chan *UdpPacket
readConn net.Conn
closeFlag bool
fakeConns map[string]*FakeUdpConn
}
func ListenUDP(bindAddr string, bindPort int64) (l *UdpListener, err error) {
udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
if err != nil {
return l, err
}
readConn, err := net.ListenUDP("udp", udpAddr)
l = &UdpListener{
Addr: udpAddr,
accept: make(chan Conn),
writeCh: make(chan *UdpPacket, 1000),
fakeConns: make(map[string]*FakeUdpConn),
}
// for reading
go func() {
for {
buf := pool.GetBuf(1450)
n, remoteAddr, err := readConn.ReadFromUDP(buf)
if err != nil {
close(l.accept)
close(l.writeCh)
return
}
fakeConn, exist := l.fakeConns[remoteAddr.String()]
if !exist || fakeConn.IsClosed() {
fakeConn = NewFakeUdpConn(l, l.Addr, remoteAddr)
l.fakeConns[remoteAddr.String()] = fakeConn
}
fakeConn.putPacket(buf[:n])
l.accept <- fakeConn
}
}()
// for writing
go func() {
for {
packet, ok := <-l.writeCh
if !ok {
return
}
if addr, ok := packet.RemoteAddr.(*net.UDPAddr); ok {
readConn.WriteToUDP(packet.Buf, addr)
}
}
}()
return
}
func (l *UdpListener) writeUdpPacket(packet *UdpPacket) {
defer func() {
if err := recover(); err != nil {
}
}()
l.writeCh <- packet
}
func (l *UdpListener) Accept() (Conn, error) {
conn, ok := <-l.accept
if !ok {
return conn, fmt.Errorf("channel for udp listener closed")
}
return conn, nil
}
func (l *UdpListener) Close() error {
if !l.closeFlag {
l.closeFlag = true
l.readConn.Close()
}
return nil
}

27
utils/net/udp_test.go Normal file
View File

@@ -0,0 +1,27 @@
package net
import (
"fmt"
"testing"
)
func TestA(t *testing.T) {
l, err := ListenUDP("0.0.0.0", 9000)
if err != nil {
fmt.Println(err)
}
for {
c, _ := l.Accept()
go func() {
for {
buf := make([]byte, 1450)
n, err := c.Read(buf)
if err != nil {
fmt.Println(buf[:n])
}
c.Write(buf[:n])
}
}()
}
}

58
utils/pool/pool.go Normal file
View File

@@ -0,0 +1,58 @@
// Copyright 2016 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pool
import "sync"
var (
bufPool5k sync.Pool
bufPool2k sync.Pool
bufPool1k sync.Pool
bufPool sync.Pool
)
func GetBuf(size int) []byte {
var x interface{}
if size >= 5*1024 {
x = bufPool5k.Get()
} else if size >= 2*1024 {
x = bufPool2k.Get()
} else if size >= 1*1024 {
x = bufPool1k.Get()
} else {
x = bufPool.Get()
}
if x == nil {
return make([]byte, size)
}
buf := x.([]byte)
if cap(buf) < size {
return make([]byte, size)
}
return buf[:size]
}
func PutBuf(buf []byte) {
size := cap(buf)
if size >= 5*1024 {
bufPool5k.Put(buf)
} else if size >= 2*1024 {
bufPool2k.Put(buf)
} else if size >= 1*1024 {
bufPool1k.Put(buf)
} else {
bufPool.Put(buf)
}
}

51
utils/pool/pool_test.go Normal file
View File

@@ -0,0 +1,51 @@
// Copyright 2016 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pool
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestPutBuf(t *testing.T) {
buf := make([]byte, 512)
PutBuf(buf)
buf = make([]byte, 1025)
PutBuf(buf)
buf = make([]byte, 2*1025)
PutBuf(buf)
buf = make([]byte, 5*1025)
PutBuf(buf)
}
func TestGetBuf(t *testing.T) {
assert := assert.New(t)
buf := GetBuf(200)
assert.Len(buf, 200)
buf = GetBuf(1025)
assert.Len(buf, 1025)
buf = GetBuf(2 * 1024)
assert.Len(buf, 2*1024)
buf = GetBuf(5 * 2000)
assert.Len(buf, 5*2000)
}

View File

@@ -0,0 +1,62 @@
// Copyright 2017 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package shutdown
import (
"sync"
)
type Shutdown struct {
doing bool
ending bool
start chan struct{}
down chan struct{}
mu sync.Mutex
}
func New() *Shutdown {
return &Shutdown{
doing: false,
ending: false,
start: make(chan struct{}),
down: make(chan struct{}),
}
}
func (s *Shutdown) Start() {
s.mu.Lock()
defer s.mu.Unlock()
if !s.doing {
s.doing = true
close(s.start)
}
}
func (s *Shutdown) WaitStart() {
<-s.start
}
func (s *Shutdown) Done() {
s.mu.Lock()
defer s.mu.Unlock()
if !s.ending {
s.ending = true
close(s.down)
}
}
func (s *Shutdown) WaitDown() {
<-s.down
}

View File

@@ -0,0 +1,21 @@
package shutdown
import (
"testing"
"time"
)
func TestShutdown(t *testing.T) {
s := New()
go func() {
time.Sleep(time.Millisecond)
s.Start()
}()
s.WaitStart()
go func() {
time.Sleep(time.Millisecond)
s.Done()
}()
s.WaitDown()
}

47
utils/util/util.go Normal file
View File

@@ -0,0 +1,47 @@
// Copyright 2017 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package util
import (
"crypto/md5"
"crypto/rand"
"encoding/hex"
"fmt"
)
// RandId return a rand string used in frp.
func RandId() (id string, err error) {
return RandIdWithLen(8)
}
// RandIdWithLen return a rand string with idLen length.
func RandIdWithLen(idLen int) (id string, err error) {
b := make([]byte, idLen)
_, err = rand.Read(b)
if err != nil {
return
}
id = fmt.Sprintf("%x", b)
return
}
func GetAuthKey(token string, timestamp int64) (key string) {
token = token + fmt.Sprintf("%d", timestamp)
md5Ctx := md5.New()
md5Ctx.Write([]byte(token))
data := md5Ctx.Sum(nil)
return hex.EncodeToString(data)
}

22
utils/util/util_test.go Normal file
View File

@@ -0,0 +1,22 @@
package util
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestRandId(t *testing.T) {
assert := assert.New(t)
id, err := RandId()
assert.NoError(err)
t.Log(id)
assert.Equal(16, len(id))
}
func TestGetAuthKey(t *testing.T) {
assert := assert.New(t)
key := GetAuthKey("1234", 1488720000)
t.Log(key)
assert.Equal("6df41a43725f0c770fd56379e12acf8c", key)
}

88
utils/version/version.go Normal file
View File

@@ -0,0 +1,88 @@
// Copyright 2016 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package version
import (
"strconv"
"strings"
)
var version string = "0.10.0"
func Full() string {
return version
}
func Proto(v string) int64 {
arr := strings.Split(v, ".")
if len(arr) < 3 {
return 0
}
res, _ := strconv.ParseInt(arr[0], 10, 64)
return res
}
func Major(v string) int64 {
arr := strings.Split(v, ".")
if len(arr) < 3 {
return 0
}
res, _ := strconv.ParseInt(arr[1], 10, 64)
return res
}
func Minor(v string) int64 {
arr := strings.Split(v, ".")
if len(arr) < 3 {
return 0
}
res, _ := strconv.ParseInt(arr[2], 10, 64)
return res
}
// add every case there if server will not accept client's protocol and return false
func Compat(client string) (ok bool, msg string) {
if LessThan(client, version) {
return false, "Please upgrade your frpc version to 0.10.0"
}
return true, ""
}
func LessThan(client string, server string) bool {
vc := Proto(client)
vs := Proto(server)
if vc > vs {
return false
} else if vc < vs {
return true
}
vc = Major(client)
vs = Major(server)
if vc > vs {
return false
} else if vc < vs {
return true
}
vc = Minor(client)
vs = Minor(server)
if vc > vs {
return false
} else if vc < vs {
return true
}
return false
}

View File

@@ -0,0 +1,65 @@
// Copyright 2016 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package version
import (
"fmt"
"strconv"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestFull(t *testing.T) {
assert := assert.New(t)
version := Full()
arr := strings.Split(version, ".")
assert.Equal(3, len(arr))
proto, err := strconv.ParseInt(arr[0], 10, 64)
assert.NoError(err)
assert.True(proto >= 0)
major, err := strconv.ParseInt(arr[1], 10, 64)
assert.NoError(err)
assert.True(major >= 0)
minor, err := strconv.ParseInt(arr[2], 10, 64)
assert.NoError(err)
assert.True(minor >= 0)
}
func TestVersion(t *testing.T) {
assert := assert.New(t)
proto := Proto(Full())
major := Major(Full())
minor := Minor(Full())
parseVerion := fmt.Sprintf("%d.%d.%d", proto, major, minor)
version := Full()
assert.Equal(parseVerion, version)
}
func TestCompact(t *testing.T) {
assert := assert.New(t)
ok, _ := Compat("0.9.0")
assert.False(ok)
ok, _ = Compat("10.0.0")
assert.True(ok)
ok, _ = Compat("0.10.0")
assert.False(ok)
}

216
utils/vhost/http.go Normal file
View File

@@ -0,0 +1,216 @@
// Copyright 2016 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package vhost
import (
"bufio"
"bytes"
"encoding/base64"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
frpNet "github.com/fatedier/frp/utils/net"
"github.com/fatedier/frp/utils/pool"
)
type HttpMuxer struct {
*VhostMuxer
}
func GetHttpRequestInfo(c frpNet.Conn) (_ frpNet.Conn, _ map[string]string, err error) {
reqInfoMap := make(map[string]string, 0)
sc, rd := newShareConn(c)
request, err := http.ReadRequest(bufio.NewReader(rd))
if err != nil {
return sc, reqInfoMap, err
}
// hostName
tmpArr := strings.Split(request.Host, ":")
reqInfoMap["Host"] = tmpArr[0]
reqInfoMap["Path"] = request.URL.Path
reqInfoMap["Scheme"] = request.URL.Scheme
// Authorization
authStr := request.Header.Get("Authorization")
if authStr != "" {
reqInfoMap["Authorization"] = authStr
}
request.Body.Close()
return sc, reqInfoMap, nil
}
func NewHttpMuxer(listener frpNet.Listener, timeout time.Duration) (*HttpMuxer, error) {
mux, err := NewVhostMuxer(listener, GetHttpRequestInfo, HttpAuthFunc, HttpHostNameRewrite, timeout)
return &HttpMuxer{mux}, err
}
func HttpHostNameRewrite(c frpNet.Conn, rewriteHost string) (_ frpNet.Conn, err error) {
sc, rd := newShareConn(c)
var buff []byte
if buff, err = hostNameRewrite(rd, rewriteHost); err != nil {
return sc, err
}
err = sc.WriteBuff(buff)
return sc, err
}
func hostNameRewrite(request io.Reader, rewriteHost string) (_ []byte, err error) {
buf := pool.GetBuf(1024)
defer pool.PutBuf(buf)
request.Read(buf)
retBuffer, err := parseRequest(buf, rewriteHost)
return retBuffer, err
}
func parseRequest(org []byte, rewriteHost string) (ret []byte, err error) {
tp := bytes.NewBuffer(org)
// First line: GET /index.html HTTP/1.0
var b []byte
if b, err = tp.ReadBytes('\n'); err != nil {
return nil, err
}
req := new(http.Request)
// we invoked ReadRequest in GetHttpHostname before, so we ignore error
req.Method, req.RequestURI, req.Proto, _ = parseRequestLine(string(b))
rawurl := req.RequestURI
// CONNECT www.google.com:443 HTTP/1.1
justAuthority := req.Method == "CONNECT" && !strings.HasPrefix(rawurl, "/")
if justAuthority {
rawurl = "http://" + rawurl
}
req.URL, _ = url.ParseRequestURI(rawurl)
if justAuthority {
// Strip the bogus "http://" back off.
req.URL.Scheme = ""
}
// RFC2616: first case
// GET /index.html HTTP/1.1
// Host: www.google.com
if req.URL.Host == "" {
changedBuf, err := changeHostName(tp, rewriteHost)
buf := new(bytes.Buffer)
buf.Write(b)
buf.Write(changedBuf)
return buf.Bytes(), err
}
// RFC2616: second case
// GET http://www.google.com/index.html HTTP/1.1
// Host: doesntmatter
// In this case, any Host line is ignored.
hostPort := strings.Split(req.URL.Host, ":")
if len(hostPort) == 1 {
req.URL.Host = rewriteHost
} else if len(hostPort) == 2 {
req.URL.Host = fmt.Sprintf("%s:%s", rewriteHost, hostPort[1])
}
firstLine := req.Method + " " + req.URL.String() + " " + req.Proto
buf := new(bytes.Buffer)
buf.WriteString(firstLine)
tp.WriteTo(buf)
return buf.Bytes(), err
}
// parseRequestLine parses "GET /foo HTTP/1.1" into its three parts.
func parseRequestLine(line string) (method, requestURI, proto string, ok bool) {
s1 := strings.Index(line, " ")
s2 := strings.Index(line[s1+1:], " ")
if s1 < 0 || s2 < 0 {
return
}
s2 += s1 + 1
return line[:s1], line[s1+1 : s2], line[s2+1:], true
}
func changeHostName(buff *bytes.Buffer, rewriteHost string) (_ []byte, err error) {
retBuf := new(bytes.Buffer)
peek := buff.Bytes()
for len(peek) > 0 {
i := bytes.IndexByte(peek, '\n')
if i < 3 {
// Not present (-1) or found within the next few bytes,
// implying we're at the end ("\r\n\r\n" or "\n\n")
return nil, err
}
kv := peek[:i]
j := bytes.IndexByte(kv, ':')
if j < 0 {
return nil, fmt.Errorf("malformed MIME header line: " + string(kv))
}
if strings.Contains(strings.ToLower(string(kv[:j])), "host") {
var hostHeader string
portPos := bytes.IndexByte(kv[j+1:], ':')
if portPos == -1 {
hostHeader = fmt.Sprintf("Host: %s\n", rewriteHost)
} else {
hostHeader = fmt.Sprintf("Host: %s:%s\n", rewriteHost, kv[portPos+1:])
}
retBuf.WriteString(hostHeader)
peek = peek[i+1:]
break
} else {
retBuf.Write(peek[:i])
retBuf.WriteByte('\n')
}
peek = peek[i+1:]
}
retBuf.Write(peek)
return retBuf.Bytes(), err
}
func HttpAuthFunc(c frpNet.Conn, userName, passWord, authorization string) (bAccess bool, err error) {
s := strings.SplitN(authorization, " ", 2)
if len(s) != 2 {
res := noAuthResponse()
res.Write(c)
return
}
b, err := base64.StdEncoding.DecodeString(s[1])
if err != nil {
return
}
pair := strings.SplitN(string(b), ":", 2)
if len(pair) != 2 {
return
}
if pair[0] != userName || pair[1] != passWord {
return
}
return true, nil
}
func noAuthResponse() *http.Response {
header := make(map[string][]string)
header["WWW-Authenticate"] = []string{`Basic realm="Restricted"`}
res := &http.Response{
Status: "401 Not authorized",
StatusCode: 401,
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: header,
}
return res
}

190
utils/vhost/https.go Normal file
View File

@@ -0,0 +1,190 @@
// Copyright 2016 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package vhost
import (
"fmt"
"io"
"strings"
"time"
frpNet "github.com/fatedier/frp/utils/net"
"github.com/fatedier/frp/utils/pool"
)
const (
typeClientHello uint8 = 1 // Type client hello
)
// TLS extension numbers
const (
extensionServerName uint16 = 0
extensionStatusRequest uint16 = 5
extensionSupportedCurves uint16 = 10
extensionSupportedPoints uint16 = 11
extensionSignatureAlgorithms uint16 = 13
extensionALPN uint16 = 16
extensionSCT uint16 = 18
extensionSessionTicket uint16 = 35
extensionNextProtoNeg uint16 = 13172 // not IANA assigned
extensionRenegotiationInfo uint16 = 0xff01
)
type HttpsMuxer struct {
*VhostMuxer
}
func NewHttpsMuxer(listener frpNet.Listener, timeout time.Duration) (*HttpsMuxer, error) {
mux, err := NewVhostMuxer(listener, GetHttpsHostname, nil, nil, timeout)
return &HttpsMuxer{mux}, err
}
func readHandshake(rd io.Reader) (host string, err error) {
data := pool.GetBuf(1024)
origin := data
defer pool.PutBuf(origin)
length, err := rd.Read(data)
if err != nil {
return
} else {
if length < 47 {
err = fmt.Errorf("readHandshake: proto length[%d] is too short", length)
return
}
}
data = data[:length]
if uint8(data[5]) != typeClientHello {
err = fmt.Errorf("readHandshake: type[%d] is not clientHello", uint16(data[5]))
return
}
// session
sessionIdLen := int(data[43])
if sessionIdLen > 32 || len(data) < 44+sessionIdLen {
err = fmt.Errorf("readHandshake: sessionIdLen[%d] is long", sessionIdLen)
return
}
data = data[44+sessionIdLen:]
if len(data) < 2 {
err = fmt.Errorf("readHandshake: dataLen[%d] after session is short", len(data))
return
}
// cipher suite numbers
cipherSuiteLen := int(data[0])<<8 | int(data[1])
if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
err = fmt.Errorf("readHandshake: dataLen[%d] after cipher suite is short", len(data))
return
}
data = data[2+cipherSuiteLen:]
if len(data) < 1 {
err = fmt.Errorf("readHandshake: cipherSuiteLen[%d] is long", cipherSuiteLen)
return
}
// compression method
compressionMethodsLen := int(data[0])
if len(data) < 1+compressionMethodsLen {
err = fmt.Errorf("readHandshake: compressionMethodsLen[%d] is long", compressionMethodsLen)
return
}
data = data[1+compressionMethodsLen:]
if len(data) == 0 {
// ClientHello is optionally followed by extension data
err = fmt.Errorf("readHandshake: there is no extension data to get servername")
return
}
if len(data) < 2 {
err = fmt.Errorf("readHandshake: extension dataLen[%d] is too short")
return
}
extensionsLength := int(data[0])<<8 | int(data[1])
data = data[2:]
if extensionsLength != len(data) {
err = fmt.Errorf("readHandshake: extensionsLen[%d] is not equal to dataLen[%d]", extensionsLength, len(data))
return
}
for len(data) != 0 {
if len(data) < 4 {
err = fmt.Errorf("readHandshake: extensionsDataLen[%d] is too short", len(data))
return
}
extension := uint16(data[0])<<8 | uint16(data[1])
length := int(data[2])<<8 | int(data[3])
data = data[4:]
if len(data) < length {
err = fmt.Errorf("readHandshake: extensionLen[%d] is long", length)
return
}
switch extension {
case extensionRenegotiationInfo:
if length != 1 || data[0] != 0 {
err = fmt.Errorf("readHandshake: extension reNegotiationInfoLen[%d] is short", length)
return
}
case extensionNextProtoNeg:
case extensionStatusRequest:
case extensionServerName:
d := data[:length]
if len(d) < 2 {
err = fmt.Errorf("readHandshake: remiaining dataLen[%d] is short", len(d))
return
}
namesLen := int(d[0])<<8 | int(d[1])
d = d[2:]
if len(d) != namesLen {
err = fmt.Errorf("readHandshake: nameListLen[%d] is not equal to dataLen[%d]", namesLen, len(d))
return
}
for len(d) > 0 {
if len(d) < 3 {
err = fmt.Errorf("readHandshake: extension serverNameLen[%d] is short", len(d))
return
}
nameType := d[0]
nameLen := int(d[1])<<8 | int(d[2])
d = d[3:]
if len(d) < nameLen {
err = fmt.Errorf("readHandshake: nameLen[%d] is not equal to dataLen[%d]", nameLen, len(d))
return
}
if nameType == 0 {
serverName := string(d[:nameLen])
host = strings.TrimSpace(serverName)
return host, nil
}
d = d[nameLen:]
}
}
data = data[length:]
}
err = fmt.Errorf("Unknow error")
return
}
func GetHttpsHostname(c frpNet.Conn) (sc frpNet.Conn, _ map[string]string, err error) {
reqInfoMap := make(map[string]string, 0)
sc, rd := newShareConn(c)
host, err := readHandshake(rd)
if err != nil {
return sc, reqInfoMap, err
}
reqInfoMap["Host"] = host
reqInfoMap["Scheme"] = "https"
return sc, reqInfoMap, nil
}

114
utils/vhost/router.go Normal file
View File

@@ -0,0 +1,114 @@
package vhost
import (
"sort"
"strings"
"sync"
)
type VhostRouters struct {
RouterByDomain map[string][]*VhostRouter
mutex sync.RWMutex
}
type VhostRouter struct {
domain string
location string
listener *Listener
}
func NewVhostRouters() *VhostRouters {
return &VhostRouters{
RouterByDomain: make(map[string][]*VhostRouter),
}
}
func (r *VhostRouters) Add(domain, location string, l *Listener) {
r.mutex.Lock()
defer r.mutex.Unlock()
vrs, found := r.RouterByDomain[domain]
if !found {
vrs = make([]*VhostRouter, 0, 1)
}
vr := &VhostRouter{
domain: domain,
location: location,
listener: l,
}
vrs = append(vrs, vr)
sort.Sort(sort.Reverse(ByLocation(vrs)))
r.RouterByDomain[domain] = vrs
}
func (r *VhostRouters) Del(domain, location string) {
r.mutex.Lock()
defer r.mutex.Unlock()
vrs, found := r.RouterByDomain[domain]
if !found {
return
}
for i, vr := range vrs {
if vr.location == location {
if len(vrs) > i+1 {
r.RouterByDomain[domain] = append(vrs[:i], vrs[i+1:]...)
} else {
r.RouterByDomain[domain] = vrs[:i]
}
}
}
}
func (r *VhostRouters) Get(host, path string) (vr *VhostRouter, exist bool) {
r.mutex.RLock()
defer r.mutex.RUnlock()
vrs, found := r.RouterByDomain[host]
if !found {
return
}
// can't support load balance, will to do
for _, vr = range vrs {
if strings.HasPrefix(path, vr.location) {
return vr, true
}
}
return
}
func (r *VhostRouters) Exist(host, path string) (vr *VhostRouter, exist bool) {
r.mutex.RLock()
defer r.mutex.RUnlock()
vrs, found := r.RouterByDomain[host]
if !found {
return
}
for _, vr = range vrs {
if path == vr.location {
return vr, true
}
}
return
}
// sort by location
type ByLocation []*VhostRouter
func (a ByLocation) Len() int {
return len(a)
}
func (a ByLocation) Swap(i, j int) {
a[i], a[j] = a[j], a[i]
}
func (a ByLocation) Less(i, j int) bool {
return strings.Compare(a[i].location, a[j].location) < 0
}

233
utils/vhost/vhost.go Normal file
View File

@@ -0,0 +1,233 @@
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package vhost
import (
"bytes"
"fmt"
"io"
"strings"
"sync"
"time"
frpNet "github.com/fatedier/frp/utils/net"
)
type muxFunc func(frpNet.Conn) (frpNet.Conn, map[string]string, error)
type httpAuthFunc func(frpNet.Conn, string, string, string) (bool, error)
type hostRewriteFunc func(frpNet.Conn, string) (frpNet.Conn, error)
type VhostMuxer struct {
listener frpNet.Listener
timeout time.Duration
vhostFunc muxFunc
authFunc httpAuthFunc
rewriteFunc hostRewriteFunc
registryRouter *VhostRouters
mutex sync.RWMutex
}
func NewVhostMuxer(listener frpNet.Listener, vhostFunc muxFunc, authFunc httpAuthFunc, rewriteFunc hostRewriteFunc, timeout time.Duration) (mux *VhostMuxer, err error) {
mux = &VhostMuxer{
listener: listener,
timeout: timeout,
vhostFunc: vhostFunc,
authFunc: authFunc,
rewriteFunc: rewriteFunc,
registryRouter: NewVhostRouters(),
}
go mux.run()
return mux, nil
}
// listen for a new domain name, if rewriteHost is not empty and rewriteFunc is not nil, then rewrite the host header to rewriteHost
func (v *VhostMuxer) Listen(name, location, rewriteHost, userName, passWord string) (l *Listener, err error) {
v.mutex.Lock()
defer v.mutex.Unlock()
_, ok := v.registryRouter.Exist(name, location)
if ok {
return nil, fmt.Errorf("hostname [%s] location [%s] is already registered", name, location)
}
l = &Listener{
name: name,
location: location,
rewriteHost: rewriteHost,
userName: userName,
passWord: passWord,
mux: v,
accept: make(chan frpNet.Conn),
}
v.registryRouter.Add(name, location, l)
return l, nil
}
func (v *VhostMuxer) getListener(name, path string) (l *Listener, exist bool) {
v.mutex.RLock()
defer v.mutex.RUnlock()
// first we check the full hostname
// if not exist, then check the wildcard_domain such as *.example.com
vr, found := v.registryRouter.Get(name, path)
if found {
return vr.listener, true
}
domainSplit := strings.Split(name, ".")
if len(domainSplit) < 3 {
return l, false
}
domainSplit[0] = "*"
name = strings.Join(domainSplit, ".")
vr, found = v.registryRouter.Get(name, path)
if !found {
return
}
return vr.listener, true
}
func (v *VhostMuxer) run() {
for {
conn, err := v.listener.Accept()
if err != nil {
return
}
go v.handle(conn)
}
}
func (v *VhostMuxer) handle(c frpNet.Conn) {
if err := c.SetDeadline(time.Now().Add(v.timeout)); err != nil {
c.Close()
return
}
sConn, reqInfoMap, err := v.vhostFunc(c)
if err != nil {
c.Close()
return
}
name := strings.ToLower(reqInfoMap["Host"])
path := strings.ToLower(reqInfoMap["Path"])
l, ok := v.getListener(name, path)
if !ok {
c.Close()
return
}
// if authFunc is exist and userName/password is set
// verify user access
if l.mux.authFunc != nil &&
l.userName != "" && l.passWord != "" {
bAccess, err := l.mux.authFunc(c, l.userName, l.passWord, reqInfoMap["Authorization"])
if bAccess == false || err != nil {
res := noAuthResponse()
res.Write(c)
c.Close()
return
}
}
if err = sConn.SetDeadline(time.Time{}); err != nil {
c.Close()
return
}
c = sConn
l.accept <- c
}
type Listener struct {
name string
location string
rewriteHost string
userName string
passWord string
mux *VhostMuxer // for closing VhostMuxer
accept chan frpNet.Conn
}
func (l *Listener) Accept() (frpNet.Conn, error) {
conn, ok := <-l.accept
if !ok {
return nil, fmt.Errorf("Listener closed")
}
// if rewriteFunc is exist and rewriteHost is set
// rewrite http requests with a modified host header
if l.mux.rewriteFunc != nil && l.rewriteHost != "" {
sConn, err := l.mux.rewriteFunc(conn, l.rewriteHost)
if err != nil {
return nil, fmt.Errorf("http host header rewrite failed")
}
conn = sConn
}
return conn, nil
}
func (l *Listener) Close() error {
l.mux.registryRouter.Del(l.name, l.location)
close(l.accept)
return nil
}
func (l *Listener) Name() string {
return l.name
}
type sharedConn struct {
frpNet.Conn
sync.Mutex
buff *bytes.Buffer
}
// the bytes you read in io.Reader, will be reserved in sharedConn
func newShareConn(conn frpNet.Conn) (*sharedConn, io.Reader) {
sc := &sharedConn{
Conn: conn,
buff: bytes.NewBuffer(make([]byte, 0, 1024)),
}
return sc, io.TeeReader(conn, sc.buff)
}
func (sc *sharedConn) Read(p []byte) (n int, err error) {
sc.Lock()
if sc.buff == nil {
sc.Unlock()
return sc.Conn.Read(p)
}
sc.Unlock()
n, err = sc.buff.Read(p)
if err == io.EOF {
sc.Lock()
sc.buff = nil
sc.Unlock()
var n2 int
n2, err = sc.Conn.Read(p[n:])
n += n2
}
return
}
func (sc *sharedConn) WriteBuff(buffer []byte) (err error) {
sc.buff.Reset()
_, err = sc.buff.Write(buffer)
return err
}