mirror of
https://github.com/fatedier/frp.git
synced 2025-07-27 07:35:07 +00:00
vendor: add package golib/net
This commit is contained in:
@@ -15,7 +15,6 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -133,49 +132,6 @@ func ConnectServerByProxy(proxyUrl string, protocol string, addr string) (c Conn
|
||||
}
|
||||
}
|
||||
|
||||
type SharedConn struct {
|
||||
Conn
|
||||
buf *bytes.Buffer
|
||||
}
|
||||
|
||||
// the bytes you read in io.Reader, will be reserved in SharedConn
|
||||
func NewShareConn(conn Conn) (*SharedConn, io.Reader) {
|
||||
sc := &SharedConn{
|
||||
Conn: conn,
|
||||
buf: bytes.NewBuffer(make([]byte, 0, 1024)),
|
||||
}
|
||||
return sc, io.TeeReader(conn, sc.buf)
|
||||
}
|
||||
|
||||
func NewShareConnSize(conn Conn, bufSize int) (*SharedConn, io.Reader) {
|
||||
sc := &SharedConn{
|
||||
Conn: conn,
|
||||
buf: bytes.NewBuffer(make([]byte, 0, bufSize)),
|
||||
}
|
||||
return sc, io.TeeReader(conn, sc.buf)
|
||||
}
|
||||
|
||||
// Not thread safety.
|
||||
func (sc *SharedConn) Read(p []byte) (n int, err error) {
|
||||
if sc.buf == nil {
|
||||
return sc.Conn.Read(p)
|
||||
}
|
||||
n, err = sc.buf.Read(p)
|
||||
if err == io.EOF {
|
||||
sc.buf = nil
|
||||
var n2 int
|
||||
n2, err = sc.Conn.Read(p[n:])
|
||||
n += n2
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (sc *SharedConn) WriteBuff(buffer []byte) (err error) {
|
||||
sc.buf.Reset()
|
||||
_, err = sc.buf.Write(buffer)
|
||||
return err
|
||||
}
|
||||
|
||||
type StatsConn struct {
|
||||
Conn
|
||||
|
||||
|
@@ -1,211 +0,0 @@
|
||||
package mux
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
frpNet "github.com/fatedier/frp/utils/net"
|
||||
|
||||
"github.com/fatedier/golib/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultTimeout is the default length of time to wait for bytes we need.
|
||||
DefaultTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
type Mux struct {
|
||||
ln net.Listener
|
||||
|
||||
defaultLn *listener
|
||||
lns []*listener
|
||||
maxNeedBytesNum uint32
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewMux() (mux *Mux) {
|
||||
mux = &Mux{
|
||||
lns: make([]*listener, 0),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (mux *Mux) Listen(priority int, needBytesNum uint32, fn MatchFunc) net.Listener {
|
||||
ln := &listener{
|
||||
c: make(chan net.Conn),
|
||||
mux: mux,
|
||||
needBytesNum: needBytesNum,
|
||||
matchFn: fn,
|
||||
}
|
||||
|
||||
mux.mu.Lock()
|
||||
defer mux.mu.Unlock()
|
||||
if needBytesNum > mux.maxNeedBytesNum {
|
||||
mux.maxNeedBytesNum = needBytesNum
|
||||
}
|
||||
|
||||
newlns := append(mux.copyLns(), ln)
|
||||
sort.Slice(newlns, func(i, j int) bool {
|
||||
return newlns[i].needBytesNum < newlns[j].needBytesNum
|
||||
})
|
||||
mux.lns = newlns
|
||||
return ln
|
||||
}
|
||||
|
||||
func (mux *Mux) ListenHttp(priority int) net.Listener {
|
||||
return mux.Listen(priority, HttpNeedBytesNum, HttpMatchFunc)
|
||||
}
|
||||
|
||||
func (mux *Mux) ListenHttps(priority int) net.Listener {
|
||||
return mux.Listen(priority, HttpsNeedBytesNum, HttpsMatchFunc)
|
||||
}
|
||||
|
||||
func (mux *Mux) DefaultListener() net.Listener {
|
||||
mux.mu.Lock()
|
||||
defer mux.mu.Unlock()
|
||||
if mux.defaultLn == nil {
|
||||
mux.defaultLn = &listener{
|
||||
c: make(chan net.Conn),
|
||||
mux: mux,
|
||||
}
|
||||
}
|
||||
return mux.defaultLn
|
||||
}
|
||||
|
||||
func (mux *Mux) release(ln *listener) bool {
|
||||
result := false
|
||||
mux.mu.Lock()
|
||||
defer mux.mu.Unlock()
|
||||
lns := mux.copyLns()
|
||||
|
||||
for i, l := range lns {
|
||||
if l == ln {
|
||||
lns = append(lns[:i], lns[i+1:]...)
|
||||
result = true
|
||||
}
|
||||
}
|
||||
mux.lns = lns
|
||||
return result
|
||||
}
|
||||
|
||||
func (mux *Mux) copyLns() []*listener {
|
||||
lns := make([]*listener, 0, len(mux.lns))
|
||||
for _, l := range mux.lns {
|
||||
lns = append(lns, l)
|
||||
}
|
||||
return lns
|
||||
}
|
||||
|
||||
// Serve handles connections from ln and multiplexes then across registered listeners.
|
||||
func (mux *Mux) Serve(ln net.Listener) error {
|
||||
mux.mu.Lock()
|
||||
mux.ln = ln
|
||||
mux.mu.Unlock()
|
||||
for {
|
||||
// Wait for the next connection.
|
||||
// If it returns a temporary error then simply retry.
|
||||
// If it returns any other error then exit immediately.
|
||||
conn, err := ln.Accept()
|
||||
if err, ok := err.(interface {
|
||||
Temporary() bool
|
||||
}); ok && err.Temporary() {
|
||||
continue
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go mux.handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (mux *Mux) handleConn(conn net.Conn) {
|
||||
mux.mu.RLock()
|
||||
maxNeedBytesNum := mux.maxNeedBytesNum
|
||||
lns := mux.lns
|
||||
defaultLn := mux.defaultLn
|
||||
mux.mu.RUnlock()
|
||||
|
||||
shareConn, rd := frpNet.NewShareConnSize(frpNet.WrapConn(conn), int(maxNeedBytesNum))
|
||||
data := make([]byte, maxNeedBytesNum)
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(DefaultTimeout))
|
||||
_, err := io.ReadFull(rd, data)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
conn.SetReadDeadline(time.Time{})
|
||||
|
||||
for _, ln := range lns {
|
||||
if match := ln.matchFn(data); match {
|
||||
err = errors.PanicToError(func() {
|
||||
ln.c <- shareConn
|
||||
})
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// No match listeners
|
||||
if defaultLn != nil {
|
||||
err = errors.PanicToError(func() {
|
||||
defaultLn.c <- shareConn
|
||||
})
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// No listeners for this connection, close it.
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
type listener struct {
|
||||
mux *Mux
|
||||
|
||||
needBytesNum uint32
|
||||
matchFn MatchFunc
|
||||
|
||||
c chan net.Conn
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Accept waits for and returns the next connection to the listener.
|
||||
func (ln *listener) Accept() (net.Conn, error) {
|
||||
conn, ok := <-ln.c
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("network connection closed")
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Close removes this listener from the parent mux and closes the channel.
|
||||
func (ln *listener) Close() error {
|
||||
if ok := ln.mux.release(ln); ok {
|
||||
// Close done to signal to any RLock holders to release their lock.
|
||||
close(ln.c)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ln *listener) Addr() net.Addr {
|
||||
if ln.mux == nil {
|
||||
return nil
|
||||
}
|
||||
ln.mux.mu.RLock()
|
||||
defer ln.mux.mu.RUnlock()
|
||||
if ln.mux.ln == nil {
|
||||
return nil
|
||||
}
|
||||
return ln.mux.ln.Addr()
|
||||
}
|
@@ -1,95 +0,0 @@
|
||||
package mux
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func runHttpSvr(ln net.Listener) *httptest.Server {
|
||||
svr := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("http service"))
|
||||
}))
|
||||
svr.Listener = ln
|
||||
svr.Start()
|
||||
return svr
|
||||
}
|
||||
|
||||
func runHttpsSvr(ln net.Listener) *httptest.Server {
|
||||
svr := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("https service"))
|
||||
}))
|
||||
svr.Listener = ln
|
||||
svr.StartTLS()
|
||||
return svr
|
||||
}
|
||||
|
||||
func runEchoSvr(ln net.Listener) {
|
||||
go func() {
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
rd := bufio.NewReader(conn)
|
||||
data, err := rd.ReadString('\n')
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
conn.Write([]byte(data))
|
||||
conn.Close()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func TestMux(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:")
|
||||
assert.NoError(err)
|
||||
|
||||
mux := NewMux()
|
||||
httpLn := mux.ListenHttp(0)
|
||||
httpsLn := mux.ListenHttps(0)
|
||||
defaultLn := mux.DefaultListener()
|
||||
go mux.Serve(ln)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
httpSvr := runHttpSvr(httpLn)
|
||||
defer httpSvr.Close()
|
||||
httpsSvr := runHttpsSvr(httpsLn)
|
||||
defer httpsSvr.Close()
|
||||
runEchoSvr(defaultLn)
|
||||
defer ln.Close()
|
||||
|
||||
// test http service
|
||||
resp, err := http.Get(httpSvr.URL)
|
||||
assert.NoError(err)
|
||||
data, err := ioutil.ReadAll(resp.Body)
|
||||
assert.NoError(err)
|
||||
assert.Equal("http service", string(data))
|
||||
|
||||
// test https service
|
||||
client := httpsSvr.Client()
|
||||
resp, err = client.Get(httpsSvr.URL)
|
||||
assert.NoError(err)
|
||||
data, err = ioutil.ReadAll(resp.Body)
|
||||
assert.NoError(err)
|
||||
assert.Equal("https service", string(data))
|
||||
|
||||
// test echo service
|
||||
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||
assert.NoError(err)
|
||||
_, err = conn.Write([]byte("test echo\n"))
|
||||
assert.NoError(err)
|
||||
data = make([]byte, 1024)
|
||||
n, err := conn.Read(data)
|
||||
assert.NoError(err)
|
||||
assert.Equal("test echo\n", string(data[:n]))
|
||||
}
|
@@ -1,55 +0,0 @@
|
||||
package mux
|
||||
|
||||
type MatchFunc func(data []byte) (match bool)
|
||||
|
||||
var (
|
||||
HttpsNeedBytesNum uint32 = 1
|
||||
HttpNeedBytesNum uint32 = 3
|
||||
YamuxNeedBytesNum uint32 = 2
|
||||
)
|
||||
|
||||
var HttpsMatchFunc MatchFunc = func(data []byte) bool {
|
||||
if len(data) < int(HttpsNeedBytesNum) {
|
||||
return false
|
||||
}
|
||||
|
||||
if data[0] == 0x16 {
|
||||
return true
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// From https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods
|
||||
var httpHeadBytes = map[string]struct{}{
|
||||
"GET": struct{}{},
|
||||
"HEA": struct{}{},
|
||||
"POS": struct{}{},
|
||||
"PUT": struct{}{},
|
||||
"DEL": struct{}{},
|
||||
"CON": struct{}{},
|
||||
"OPT": struct{}{},
|
||||
"TRA": struct{}{},
|
||||
"PAT": struct{}{},
|
||||
}
|
||||
|
||||
var HttpMatchFunc MatchFunc = func(data []byte) bool {
|
||||
if len(data) < int(HttpNeedBytesNum) {
|
||||
return false
|
||||
}
|
||||
|
||||
_, ok := httpHeadBytes[string(data[:3])]
|
||||
return ok
|
||||
}
|
||||
|
||||
// From https://github.com/hashicorp/yamux/blob/master/spec.md
|
||||
var YamuxMatchFunc MatchFunc = func(data []byte) bool {
|
||||
if len(data) < int(YamuxNeedBytesNum) {
|
||||
return false
|
||||
}
|
||||
|
||||
if data[0] == 0 && data[1] >= 0x0 && data[1] <= 0x3 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
@@ -27,6 +27,7 @@ import (
|
||||
|
||||
frpNet "github.com/fatedier/frp/utils/net"
|
||||
|
||||
gnet "github.com/fatedier/golib/net"
|
||||
"github.com/fatedier/golib/pool"
|
||||
)
|
||||
|
||||
@@ -36,11 +37,11 @@ type HttpMuxer struct {
|
||||
|
||||
func GetHttpRequestInfo(c frpNet.Conn) (_ frpNet.Conn, _ map[string]string, err error) {
|
||||
reqInfoMap := make(map[string]string, 0)
|
||||
sc, rd := frpNet.NewShareConn(c)
|
||||
sc, rd := gnet.NewSharedConn(c)
|
||||
|
||||
request, err := http.ReadRequest(bufio.NewReader(rd))
|
||||
if err != nil {
|
||||
return sc, reqInfoMap, err
|
||||
return nil, reqInfoMap, err
|
||||
}
|
||||
// hostName
|
||||
tmpArr := strings.Split(request.Host, ":")
|
||||
@@ -54,7 +55,7 @@ func GetHttpRequestInfo(c frpNet.Conn) (_ frpNet.Conn, _ map[string]string, err
|
||||
reqInfoMap["Authorization"] = authStr
|
||||
}
|
||||
request.Body.Close()
|
||||
return sc, reqInfoMap, nil
|
||||
return frpNet.WrapConn(sc), reqInfoMap, nil
|
||||
}
|
||||
|
||||
func NewHttpMuxer(listener frpNet.Listener, timeout time.Duration) (*HttpMuxer, error) {
|
||||
@@ -63,14 +64,14 @@ func NewHttpMuxer(listener frpNet.Listener, timeout time.Duration) (*HttpMuxer,
|
||||
}
|
||||
|
||||
func ModifyHttpRequest(c frpNet.Conn, rewriteHost string) (_ frpNet.Conn, err error) {
|
||||
sc, rd := frpNet.NewShareConn(c)
|
||||
sc, rd := gnet.NewSharedConn(c)
|
||||
var buff []byte
|
||||
remoteIP := strings.Split(c.RemoteAddr().String(), ":")[0]
|
||||
if buff, err = hostNameRewrite(rd, rewriteHost, remoteIP); err != nil {
|
||||
return sc, err
|
||||
return nil, err
|
||||
}
|
||||
err = sc.WriteBuff(buff)
|
||||
return sc, err
|
||||
err = sc.ResetBuf(buff)
|
||||
return frpNet.WrapConn(sc), err
|
||||
}
|
||||
|
||||
func hostNameRewrite(request io.Reader, rewriteHost string, remoteIP string) (_ []byte, err error) {
|
||||
|
@@ -21,6 +21,8 @@ import (
|
||||
"time"
|
||||
|
||||
frpNet "github.com/fatedier/frp/utils/net"
|
||||
|
||||
gnet "github.com/fatedier/golib/net"
|
||||
"github.com/fatedier/golib/pool"
|
||||
)
|
||||
|
||||
@@ -180,14 +182,14 @@ func readHandshake(rd io.Reader) (host string, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func GetHttpsHostname(c frpNet.Conn) (sc frpNet.Conn, _ map[string]string, err error) {
|
||||
func GetHttpsHostname(c frpNet.Conn) (_ frpNet.Conn, _ map[string]string, err error) {
|
||||
reqInfoMap := make(map[string]string, 0)
|
||||
sc, rd := frpNet.NewShareConn(c)
|
||||
sc, rd := gnet.NewSharedConn(c)
|
||||
host, err := readHandshake(rd)
|
||||
if err != nil {
|
||||
return sc, reqInfoMap, err
|
||||
return nil, reqInfoMap, err
|
||||
}
|
||||
reqInfoMap["Host"] = host
|
||||
reqInfoMap["Scheme"] = "https"
|
||||
return sc, reqInfoMap, nil
|
||||
return frpNet.WrapConn(sc), reqInfoMap, nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user