frp/src/utils/conn/conn.go
2016-12-19 01:22:21 +08:00

268 lines
5.5 KiB
Go

// Copyright 2016 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package conn
import (
"bufio"
"encoding/base64"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
)
// Listener accepts TCP connections on a background goroutine and hands them
// out, wrapped as *Conn, through Accept.
type Listener struct {
	addr      net.Addr
	l         *net.TCPListener
	accept    chan *Conn
	closeFlag bool

	// mutex guards closeFlag, which is written by Close and read by the
	// accept goroutine.
	mutex sync.Mutex
	// done is closed by Close so that the accept goroutine can stop even
	// while it is blocked handing a connection over to Accept.
	done chan struct{}
}

// Listen starts listening for TCP connections on bindAddr:bindPort and
// launches the goroutine that feeds accepted connections to Accept.
//
// It returns a nil Listener and the error if the address cannot be resolved
// or the socket cannot be opened.
func Listen(bindAddr string, bindPort int64) (l *Listener, err error) {
	tcpAddr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
	if err != nil {
		return nil, err
	}
	listener, err := net.ListenTCP("tcp", tcpAddr)
	if err != nil {
		return nil, err
	}

	l = &Listener{
		addr:      listener.Addr(),
		l:         listener,
		accept:    make(chan *Conn),
		closeFlag: false,
		done:      make(chan struct{}),
	}
	go l.acceptLoop()
	return l, nil
}

// acceptLoop accepts connections until the listener is closed and forwards
// each one to the accept channel.
func (l *Listener) acceptLoop() {
	// This goroutine is the only sender on l.accept, so it is also the one
	// that closes it. Closing the channel from Close could coincide with a
	// pending send below and panic.
	defer close(l.accept)

	for {
		tcpConn, err := l.l.AcceptTCP()
		if err != nil {
			if l.isClosed() {
				return
			}
			// Pause briefly so that a persistent accept failure does not
			// turn this loop into a busy spin.
			select {
			case <-time.After(5 * time.Millisecond):
			case <-l.done:
				return
			}
			continue
		}

		c := &Conn{
			TcpConn:   tcpConn,
			closeFlag: false,
		}
		c.Reader = bufio.NewReader(c.TcpConn)

		select {
		case l.accept <- c:
		case <-l.done:
			// Nobody will ever receive this connection; release it.
			tcpConn.Close()
			return
		}
	}
}

// isClosed reports whether Close has been called.
func (l *Listener) isClosed() bool {
	l.mutex.Lock()
	defer l.mutex.Unlock()
	return l.closeFlag
}

// Accept waits until a new connection is available or the listener is
// closed. Once the accept channel has been closed it returns a nil Conn and
// an error.
func (l *Listener) Accept() (*Conn, error) {
	conn, ok := <-l.accept
	if !ok {
		return conn, fmt.Errorf("channel close")
	}
	return conn, nil
}

// Close stops the listener. The first call closes the underlying TCP listener
// and signals the accept goroutine to exit, which in turn closes the accept
// channel and unblocks pending Accept calls. Later calls do nothing.
//
// It returns the error reported by the underlying listener, if any.
func (l *Listener) Close() error {
	l.mutex.Lock()
	defer l.mutex.Unlock()

	if l.l == nil || l.closeFlag {
		return nil
	}
	l.closeFlag = true
	if l.done != nil {
		close(l.done)
	}
	return l.l.Close()
}
// Conn wraps a network connection together with a buffered reader over it
// and a flag recording whether the connection is considered closed.
type Conn struct {
	// TcpConn is the wrapped connection; writes and deadlines go straight
	// to it.
	TcpConn net.Conn
	// Reader buffers reads from TcpConn; Read and ReadLine go through it.
	Reader *bufio.Reader

	// closeFlag is true once the connection has been marked closed.
	closeFlag bool
	// mutex protects closeFlag and is held while SetTcpConn swaps the
	// connection.
	mutex sync.RWMutex
}
// NewConn wraps conn in a Conn with a fresh buffered reader and the closed
// flag cleared.
func NewConn(conn net.Conn) (c *Conn) {
	return &Conn{
		TcpConn:   conn,
		Reader:    bufio.NewReader(conn),
		closeFlag: false,
	}
}
// ConnectServer resolves addr, dials it over TCP and wraps the resulting
// connection in a Conn.
//
// When resolving or dialing fails, the error is returned together with a
// non-nil Conn that has no underlying connection.
func ConnectServer(addr string) (c *Conn, err error) {
	c = new(Conn)

	var remote *net.TCPAddr
	if remote, err = net.ResolveTCPAddr("tcp", addr); err != nil {
		return c, err
	}

	var tcpConn *net.TCPConn
	if tcpConn, err = net.DialTCP("tcp", nil, remote); err != nil {
		return c, err
	}

	c.TcpConn = tcpConn
	c.Reader = bufio.NewReader(tcpConn)
	c.closeFlag = false
	return c, nil
}
// ConnectServerByHttpProxy connects to the proxy described by httpProxy and
// asks it, with an HTTP CONNECT request, to open a tunnel to serverAddr.
//
// httpProxy must be a URL with the "http" scheme. If it carries user info,
// the credentials are sent as a Basic Proxy-Authorization header. serverAddr
// is appended to "http://" to form the CONNECT request target.
//
// On success the returned Conn is the connection to the proxy, ready to be
// used as the tunnel. If anything fails after the proxy connection has been
// established, that connection is closed before the error is returned.
func ConnectServerByHttpProxy(httpProxy string, serverAddr string) (c *Conn, err error) {
	var proxyUrl *url.URL
	if proxyUrl, err = url.Parse(httpProxy); err != nil {
		return nil, err
	}
	if proxyUrl.Scheme != "http" {
		return nil, fmt.Errorf("Proxy URL scheme must be http, not [%s]", proxyUrl.Scheme)
	}

	var proxyAuth string
	if proxyUrl.User != nil {
		// Basic credentials are the raw "user:password" pair. User.String()
		// would percent-encode reserved characters and drop the colon when
		// no password is set, producing credentials the proxy rejects.
		username := proxyUrl.User.Username()
		password, _ := proxyUrl.User.Password()
		proxyAuth = "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
	}

	if c, err = ConnectServer(proxyUrl.Host); err != nil {
		return c, err
	}

	req, err := http.NewRequest("CONNECT", "http://"+serverAddr, nil)
	if err != nil {
		c.Close()
		return c, err
	}
	if proxyAuth != "" {
		req.Header.Set("Proxy-Authorization", proxyAuth)
	}
	req.Header.Set("User-Agent", "Mozilla/5.0")
	if err = req.Write(c.TcpConn); err != nil {
		c.Close()
		return c, err
	}

	// Read the reply through c.Reader instead of a throwaway bufio.Reader so
	// that any bytes buffered past the response stay available to later
	// reads on c.
	resp, err := http.ReadResponse(c.Reader, req)
	if err != nil {
		c.Close()
		return c, err
	}
	resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		c.Close()
		return c, fmt.Errorf("ConnectServer using proxy error, StatusCode [%d]", resp.StatusCode)
	}
	return c, nil
}
// SetTcpConn makes tcpConn the wrapped connection, attaches a new buffered
// reader to it and clears the closed flag.
//
// If tcpConn is a different connection from the current c.TcpConn, call
// c.Close() first.
func (c *Conn) SetTcpConn(tcpConn net.Conn) {
	reader := bufio.NewReader(tcpConn)

	c.mutex.Lock()
	defer c.mutex.Unlock()
	c.TcpConn = tcpConn
	c.Reader = reader
	c.closeFlag = false
}
// GetRemoteAddr returns the remote address of the wrapped connection in
// string form.
func (c *Conn) GetRemoteAddr() (addr string) {
	remote := c.TcpConn.RemoteAddr()
	return remote.String()
}
// GetLocalAddr returns the local address of the wrapped connection in string
// form.
func (c *Conn) GetLocalAddr() (addr string) {
	local := c.TcpConn.LocalAddr()
	return local.String()
}
// Read reads into p through the connection's buffered reader and returns
// that reader's byte count and error.
func (c *Conn) Read(p []byte) (n int, err error) {
	return c.Reader.Read(p)
}
// ReadLine reads from the buffered reader up to and including the next '\n'
// and returns the data read together with any error.
//
// When the read fails with io.EOF, or with an error whose text contains
// "wsarecv", the connection is treated as closed: the closed flag is set and
// the underlying connection is closed. Close skips connections that are
// already flagged, so without closing here the socket would never be
// released explicitly.
func (c *Conn) ReadLine() (buff string, err error) {
	buff, err = c.Reader.ReadString('\n')
	if err == nil {
		return buff, nil
	}

	// wsarecv error in windows means connection closed?
	if err == io.EOF || strings.Contains(err.Error(), "wsarecv") {
		c.mutex.Lock()
		if !c.closeFlag {
			c.closeFlag = true
			if c.TcpConn != nil {
				// The read error is what the caller needs to see; a failure
				// while closing an already-dead connection adds nothing.
				c.TcpConn.Close()
			}
		}
		c.mutex.Unlock()
	}
	return buff, err
}
// Write writes content directly to the wrapped connection and returns its
// byte count and error.
func (c *Conn) Write(content []byte) (n int, err error) {
	return c.TcpConn.Write(content)
}
// WriteString writes the bytes of content directly to the wrapped connection
// and returns the write error, discarding the byte count.
func (c *Conn) WriteString(content string) (err error) {
	payload := []byte(content)
	_, err = c.TcpConn.Write(payload)
	return err
}
// SetDeadline forwards t to SetDeadline on the wrapped connection and returns
// its error.
func (c *Conn) SetDeadline(t time.Time) error {
	err := c.TcpConn.SetDeadline(t)
	return err
}
// SetReadDeadline forwards t to SetReadDeadline on the wrapped connection and
// returns its error.
func (c *Conn) SetReadDeadline(t time.Time) error {
	err := c.TcpConn.SetReadDeadline(t)
	return err
}
// Close marks the connection closed and closes the wrapped connection.
//
// It does nothing and returns nil when there is no wrapped connection or the
// closed flag is already set. Otherwise it returns the error reported by the
// wrapped connection's Close instead of discarding it.
func (c *Conn) Close() error {
	c.mutex.Lock()
	defer c.mutex.Unlock()

	if c.TcpConn == nil || c.closeFlag {
		return nil
	}
	c.closeFlag = true
	return c.TcpConn.Close()
}
// IsClosed returns the current value of the closed flag, read under the read
// lock.
func (c *Conn) IsClosed() (closeFlag bool) {
	c.mutex.RLock()
	closeFlag = c.closeFlag
	c.mutex.RUnlock()
	return closeFlag
}
// CheckClosed probes whether the connection has been closed and returns true
// if it has.
//
// It returns true at once when the closed flag is already set. Otherwise it
// attempts a one-byte read on the wrapped connection with a one-millisecond
// read deadline: io.EOF means closed, any other outcome is treated as still
// open, and the read deadline is then cleared. Whenever the probe decides
// the connection is closed, or a deadline cannot be set, the connection is
// closed through Close.
//
// When you call this function, make sure the remote side will not send any
// bytes on this socket: a byte that arrives during the probe is consumed and
// discarded.
func (c *Conn) CheckClosed() bool {
	// Copy the flag and release the lock before returning. Returning while
	// still holding the read lock would block every later Lock() forever.
	c.mutex.RLock()
	closed := c.closeFlag
	c.mutex.RUnlock()
	if closed {
		return true
	}

	err := c.TcpConn.SetReadDeadline(time.Now().Add(time.Millisecond))
	if err != nil {
		c.Close()
		return true
	}

	tmp := make([]byte, 1)
	_, err = c.TcpConn.Read(tmp)
	if err == io.EOF {
		// Close here as the other failure paths do, so the flag matches the
		// result and the socket is released.
		c.Close()
		return true
	}

	// Clear the probe deadline so later reads are not cut short.
	err = c.TcpConn.SetReadDeadline(time.Time{})
	if err != nil {
		c.Close()
		return true
	}
	return false
}