diff --git a/src/frp/models/client/client.go b/src/frp/models/client/client.go index 6165eee1..edb76719 100644 --- a/src/frp/models/client/client.go +++ b/src/frp/models/client/client.go @@ -32,7 +32,7 @@ type ProxyClient struct { LocalIp string LocalPort int64 Type string - UseEncryption bool + UseEncryption int } func (p *ProxyClient) GetLocalConn() (c *conn.Conn, err error) { @@ -89,11 +89,7 @@ func (p *ProxyClient) StartTunnel(serverAddr string, serverPort int64) (err erro // l means local, r means remote log.Debug("Join two connections, (l[%s] r[%s]) (l[%s] r[%s])", localConn.GetLocalAddr(), localConn.GetRemoteAddr(), remoteConn.GetLocalAddr(), remoteConn.GetRemoteAddr()) - if p.UseEncryption { - go conn.JoinMore(localConn, remoteConn, p.AuthToken) - } else { - go conn.Join(localConn, remoteConn) - } + go conn.JoinMore(localConn, remoteConn, p.AuthToken, p.UseEncryption) return nil } diff --git a/src/frp/models/client/config.go b/src/frp/models/client/config.go index 1d19c331..a12ffde7 100644 --- a/src/frp/models/client/config.go +++ b/src/frp/models/client/config.go @@ -122,10 +122,15 @@ func LoadConf(confFile string) (err error) { } // use_encryption - proxyClient.UseEncryption = false + proxyClient.UseEncryption = 0 useEncryptionStr, ok := section["use_encryption"] - if ok && useEncryptionStr == "true" { - proxyClient.UseEncryption = true + if ok { + tmpRes, err := strconv.Atoi(useEncryptionStr) + if err != nil { + proxyClient.UseEncryption = 0 + } + + proxyClient.UseEncryption = tmpRes } ProxyClients[proxyClient.Name] = proxyClient diff --git a/src/frp/models/msg/msg.go b/src/frp/models/msg/msg.go index d1b57ad1..d88f8ae9 100644 --- a/src/frp/models/msg/msg.go +++ b/src/frp/models/msg/msg.go @@ -24,7 +24,7 @@ type ControlReq struct { Type int64 `json:"type"` ProxyName string `json:"proxy_name,omitempty"` AuthKey string `json:"auth_key, omitempty"` - UseEncryption bool `json:"use_encryption, omitempty"` + UseEncryption int `json:"use_encryption, omitempty"` Timestamp int64 `json:"timestamp, omitempty"` } diff --git a/src/frp/models/server/server.go b/src/frp/models/server/server.go index b9affdfb..ae9ff024 100644 --- a/src/frp/models/server/server.go +++ b/src/frp/models/server/server.go @@ -38,7 +38,7 @@ type ProxyServer struct { CustomDomains []string // configure in frpc.ini - UseEncryption bool + UseEncryption int Status int64 CtlConn *conn.Conn // control connection with frpc @@ -144,11 +144,7 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) { log.Debug("Join two connections, (l[%s] r[%s]) (l[%s] r[%s])", workConn.GetLocalAddr(), workConn.GetRemoteAddr(), userConn.GetLocalAddr(), userConn.GetRemoteAddr()) - if p.UseEncryption { - go conn.JoinMore(userConn, workConn, p.AuthToken) - } else { - go conn.Join(userConn, workConn) - } + go conn.JoinMore(userConn, workConn, p.AuthToken, p.UseEncryption) }() } }(listener) diff --git a/src/frp/utils/conn/conn.go b/src/frp/utils/conn/conn.go index 1c6eeb1e..c41bc2ac 100644 --- a/src/frp/utils/conn/conn.go +++ b/src/frp/utils/conn/conn.go @@ -16,6 +16,8 @@ package conn import ( "bufio" + "bytes" + "encoding/binary" "fmt" "io" "net" @@ -192,48 +194,70 @@ func Join(c1 *Conn, c2 *Conn) { // messages from c1 to c2 will be encrypted // and from c2 to c1 will be decrypted -func JoinMore(c1 *Conn, c2 *Conn, cryptKey string) { +func JoinMore(c1 *Conn, c2 *Conn, cryptKey string, ptype int) { var wait sync.WaitGroup - encryptPipe := func(from *Conn, to *Conn, key string) { + encryptPipe := func(from *Conn, to *Conn, key string, ttype int) { defer from.Close() defer to.Close() defer wait.Done() // we don't care about errors here - PipeEncrypt(from.TcpConn, to.TcpConn, key) + PipeEncrypt(from.TcpConn, to.TcpConn, key, ttype) } - decryptPipe := func(to *Conn, from *Conn, key string) { + decryptPipe := func(to *Conn, from *Conn, key string, ttype int) { defer from.Close() defer to.Close() defer wait.Done() // we don't care about errors here - PipeDecrypt(to.TcpConn, from.TcpConn, key) + PipeDecrypt(to.TcpConn, from.TcpConn, key, ttype) } wait.Add(2) - go encryptPipe(c1, c2, cryptKey) - go decryptPipe(c2, c1, cryptKey) + go encryptPipe(c1, c2, cryptKey, ptype) + + go decryptPipe(c2, c1, cryptKey, ptype) wait.Wait() log.Debug("One tunnel stopped") return } +func unpkgMsg(data []byte) (int, []byte, []byte) { + if len(data) < 4 { + return -1, nil, nil + } + llen := int(binary.BigEndian.Uint32(data[0:4])) + // no complete + if len(data) < llen+4 { + return -1, nil, nil + } + + return 0, data[4 : llen+4], data[llen+4:] +} + // decrypt msg from reader, then write into writer -func PipeDecrypt(r net.Conn, w net.Conn, key string) error { +func PipeDecrypt(r net.Conn, w net.Conn, key string, ptype int) error { laes := new(pcrypto.Pcrypto) - if err := laes.Init([]byte(key)); err != nil { + if err := laes.Init([]byte(key), ptype); err != nil { log.Error("Pcrypto Init error: %v", err) return fmt.Errorf("Pcrypto Init error: %v", err) } + buf := make([]byte, 10*1024) + var left []byte nreader := bufio.NewReader(r) for { - buf, err := nreader.ReadBytes('\n') + n, err := nreader.Read(buf) if err != nil { return err } + left := append(left, buf[:n]...) + cnt, buf, left := unpkgMsg(left) + + if cnt < 0 { + continue + } res, err := laes.Decrypt(buf) if err != nil { @@ -249,10 +273,18 @@ func PipeDecrypt(r net.Conn, w net.Conn, key string) error { return nil } +func pkgMsg(data []byte) []byte { + llen := uint32(len(data)) + buf := new(bytes.Buffer) + binary.Write(buf, binary.BigEndian, llen) + buf.Write(data) + return buf.Bytes() +} + // recvive msg from reader, then encrypt msg into write -func PipeEncrypt(r net.Conn, w net.Conn, key string) error { +func PipeEncrypt(r net.Conn, w net.Conn, key string, ptype int) error { laes := new(pcrypto.Pcrypto) - if err := laes.Init([]byte(key)); err != nil { + if err := laes.Init([]byte(key), ptype); err != nil { log.Error("Pcrypto Init error: %v", err) return fmt.Errorf("Pcrypto Init error: %v", err) } @@ -271,11 +303,12 @@ func PipeEncrypt(r net.Conn, w net.Conn, key string) error { return fmt.Errorf("Encrypt error: %v", err) } - res = append(res, '\n') + res = pkgMsg(res) _, err = w.Write(res) if err != nil { return err } } + return nil } diff --git a/src/frp/utils/pcrypto/pcrypto.go b/src/frp/utils/pcrypto/pcrypto.go index a4772a82..6c646d10 100644 --- a/src/frp/utils/pcrypto/pcrypto.go +++ b/src/frp/utils/pcrypto/pcrypto.go @@ -20,7 +20,6 @@ import ( "crypto/aes" "crypto/cipher" "crypto/md5" - "encoding/base64" "encoding/hex" "errors" "fmt" @@ -30,69 +29,80 @@ import ( type Pcrypto struct { pkey []byte paes cipher.Block + // 0: nono; 1:compress; 2: encrypt; 3: compress and encrypt + ptyp int } -func (pc *Pcrypto) Init(key []byte) error { +func (pc *Pcrypto) Init(key []byte, ptyp int) error { var err error pc.pkey = pKCS7Padding(key, aes.BlockSize) pc.paes, err = aes.NewCipher(pc.pkey) + if ptyp == 1 || ptyp == 2 || ptyp == 3 { + pc.ptyp = ptyp + } else { + pc.ptyp = 0 + } return err } func (pc *Pcrypto) Encrypt(src []byte) ([]byte, error) { - // gzip var zbuf bytes.Buffer - zwr, err := gzip.NewWriterLevel(&zbuf, -1) - if err != nil { - return nil, err + + // gzip + if pc.ptyp == 1 || pc.ptyp == 3 { + zwr, err := gzip.NewWriterLevel(&zbuf, gzip.DefaultCompression) + if err != nil { + return nil, err + } + defer zwr.Close() + zwr.Write(src) + zwr.Flush() + src = zbuf.Bytes() } - defer zwr.Close() - zwr.Write(src) - zwr.Flush() // aes - src = pKCS7Padding(zbuf.Bytes(), aes.BlockSize) - blockMode := cipher.NewCBCEncrypter(pc.paes, pc.pkey) - crypted := make([]byte, len(src)) - blockMode.CryptBlocks(crypted, src) + if pc.ptyp == 2 || pc.ptyp == 3 { + src = pKCS7Padding(src, aes.BlockSize) + blockMode := cipher.NewCBCEncrypter(pc.paes, pc.pkey) + crypted := make([]byte, len(src)) + blockMode.CryptBlocks(crypted, src) + src = crypted + } - // base64 - return []byte(base64.StdEncoding.EncodeToString(crypted)), nil + return src, nil } func (pc *Pcrypto) Decrypt(str []byte) ([]byte, error) { - // base64 - data, err := base64.StdEncoding.DecodeString(string(str)) - if err != nil { - return nil, err - } - // aes - decryptText, err := hex.DecodeString(fmt.Sprintf("%x", data)) - if err != nil { - return nil, err + if pc.ptyp == 2 || pc.ptyp == 3 { + decryptText, err := hex.DecodeString(fmt.Sprintf("%x", str)) + if err != nil { + return nil, err + } + + if len(decryptText)%aes.BlockSize != 0 { + return nil, errors.New("crypto/cipher: ciphertext is not a multiple of the block size") + } + + blockMode := cipher.NewCBCDecrypter(pc.paes, pc.pkey) + + blockMode.CryptBlocks(decryptText, decryptText) + str = pKCS7UnPadding(decryptText) } - if len(decryptText)%aes.BlockSize != 0 { - return nil, errors.New("crypto/cipher: ciphertext is not a multiple of the block size") - } - - blockMode := cipher.NewCBCDecrypter(pc.paes, pc.pkey) - - blockMode.CryptBlocks(decryptText, decryptText) - decryptText = pKCS7UnPadding(decryptText) - // gunzip - zbuf := bytes.NewBuffer(decryptText) - zrd, err := gzip.NewReader(zbuf) - if err != nil { - return nil, err + if pc.ptyp == 1 || pc.ptyp == 3 { + zbuf := bytes.NewBuffer(str) + zrd, err := gzip.NewReader(zbuf) + if err != nil { + return nil, err + } + defer zrd.Close() + str, _ = ioutil.ReadAll(zrd) } - defer zrd.Close() - data, _ = ioutil.ReadAll(zrd) - return data, nil + return str, nil } func pKCS7Padding(ciphertext []byte, blockSize int) []byte { diff --git a/src/frp/utils/pcrypto/pcrypto_test.go b/src/frp/utils/pcrypto/pcrypto_test.go index e86762fe..e240ba32 100644 --- a/src/frp/utils/pcrypto/pcrypto_test.go +++ b/src/frp/utils/pcrypto/pcrypto_test.go @@ -20,28 +20,78 @@ import ( ) func TestEncrypt(t *testing.T) { + return pp := new(Pcrypto) - pp.Init([]byte("Hana")) - res, err := pp.Encrypt([]byte("Just One Test!")) + pp.Init([]byte("Hana"), 1) + res, err := pp.Encrypt([]byte("Test Encrypt!")) if err != nil { t.Fatal(err) } - fmt.Printf("[%x]\n", res) + fmt.Printf("Encrypt: len %d, [%x]\n", len(res), res) } func TestDecrypt(t *testing.T) { - pp := new(Pcrypto) - pp.Init([]byte("Hana")) - res, err := pp.Encrypt([]byte("Just One Test!")) - if err != nil { - t.Fatal(err) + fmt.Println("*****************************************************") + { + pp := new(Pcrypto) + pp.Init([]byte("Hana"), 0) + res, err := pp.Encrypt([]byte("Test Decrypt! 0")) + if err != nil { + t.Fatal(err) + } + + res, err = pp.Decrypt(res) + if err != nil { + t.Fatal(err) + } + + fmt.Printf("[%s]\n", string(res)) + } + { + pp := new(Pcrypto) + pp.Init([]byte("Hana"), 1) + res, err := pp.Encrypt([]byte("Test Decrypt! 1")) + if err != nil { + t.Fatal(err) + } + + res, err = pp.Decrypt(res) + if err != nil { + t.Fatal(err) + } + + fmt.Printf("[%s]\n", string(res)) + } + { + pp := new(Pcrypto) + pp.Init([]byte("Hana"), 2) + res, err := pp.Encrypt([]byte("Test Decrypt! 2")) + if err != nil { + t.Fatal(err) + } + + res, err = pp.Decrypt(res) + if err != nil { + t.Fatal(err) + } + + fmt.Printf("[%s]\n", string(res)) + } + { + pp := new(Pcrypto) + pp.Init([]byte("Hana"), 3) + res, err := pp.Encrypt([]byte("Test Decrypt! 3")) + if err != nil { + t.Fatal(err) + } + + res, err = pp.Decrypt(res) + if err != nil { + t.Fatal(err) + } + + fmt.Printf("[%s]\n", string(res)) } - res, err = pp.Decrypt(res) - if err != nil { - t.Fatal(err) - } - - fmt.Printf("[%s]\n", string(res)) }