frp/src/frp/utils/vhost/vhost_https.go

218 lines
6.1 KiB
Go
Raw Normal View History

2016-06-13 14:19:24 +00:00
// Copyright 2016 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package vhost
import (
_ "bufio"
_ "bytes"
_ "crypto/tls"
"errors"
"fmt"
"frp/utils/conn"
"frp/utils/log"
"io"
_ "io/ioutil"
"net"
_ "net/http"
"strings"
_ "sync"
"time"
)
// maxHandshake is the maximum handshake size we support (the protocol
// maximum is 16 MB).
// NOTE(review): not referenced elsewhere in this file.
var maxHandshake int64 = 1 << 16

// VhostHttpsPort is the HTTPS vhost port number. readHandshake appends
// ":<port>" to the SNI host name it returns whenever this is not 443.
var VhostHttpsPort int64 = 443

// typeClientHello is the handshake message type byte of a ClientHello.
const typeClientHello uint8 = 0x01

// TLS extension numbers, as they appear in the 2-byte extension type field.
const (
	extensionServerName          uint16 = 0x0000
	extensionStatusRequest       uint16 = 0x0005
	extensionSupportedCurves     uint16 = 0x000a
	extensionSupportedPoints     uint16 = 0x000b
	extensionSignatureAlgorithms uint16 = 0x000d
	extensionALPN                uint16 = 0x0010
	extensionSCT                 uint16 = 0x0012
	extensionSessionTicket       uint16 = 0x0023
	extensionNextProtoNeg        uint16 = 0x3374 // 13172; not IANA assigned
	extensionRenegotiationInfo   uint16 = 0xff01
)
// HttpsMuxer is the vhost muxer used for HTTPS traffic. It embeds
// *VhostMuxer, so all of VhostMuxer's methods are promoted onto it;
// instances are built by NewHttpsMuxer.
type HttpsMuxer struct{ *VhostMuxer }
/*
RFC document: https://tools.ietf.org/html/rfc5246 (TLS 1.2). readHandshake
below parses the ClientHello message defined in its section 7.4.1.2.
*/
// errMsgToLog formats its arguments like fmt.Sprintf, logs the resulting
// message at warn level and returns the same message as an error.
//
// The formatted message is handed to log.Warn as the argument of a constant
// "%s" format rather than as the format itself: this package's log
// functions are called printf-style (format plus args) elsewhere in this
// file, so a '%' inside the already-formatted text (e.g. from a wrapped
// error) would otherwise be re-interpreted as a verb.
func errMsgToLog(format string, a ...interface{}) error {
	errMsg := fmt.Sprintf(format, a...)
	log.Warn("%s", errMsg)
	return errors.New(errMsg)
}
// readHandshake reads one TLS record from rd, parses it as a ClientHello
// and returns the lower-cased, space-trimmed host_name from its server_name
// extension. When VhostHttpsPort is not 443, ":<port>" is appended to the
// returned name.
//
// Exactly the 5-byte record header plus the record body are consumed from
// rd. Both are read with io.ReadFull because a single Read may return only
// part of the ClientHello (TCP can split it, and clients may send more than
// the 1 KiB a fixed small buffer holds). A ClientHello fragmented across
// several TLS records is rejected.
//
// Every failure is logged at warn level via errMsgToLog and returned as an
// error together with an empty name.
func readHandshake(rd io.Reader) (string, error) {
	const (
		recordHeaderLen     = 5       // content type, version, length
		recordTypeHandshake = 0x16    // TLS ContentType handshake(22)
		maxRecordLen        = 1 << 14 // RFC 5246 6.2.1: length MUST NOT exceed 2^14
	)

	header := make([]byte, recordHeaderLen)
	if _, err := io.ReadFull(rd, header); err != nil {
		return "", errMsgToLog("read err:%v", err)
	}
	// Reject anything that is not a TLS handshake record before trusting the
	// length field, so non-TLS input (e.g. a plain HTTP request) is refused
	// immediately instead of waiting for a bogus record body.
	if header[0] != recordTypeHandshake {
		return "", errMsgToLog("readHandshake: record type[%d] is not handshake", header[0])
	}
	recordLen := int(header[3])<<8 | int(header[4])
	if recordLen > maxRecordLen {
		return "", errMsgToLog("readHandshake: record length[%d] is too long", recordLen)
	}
	// 47 bytes guarantees every fixed-offset access below (up to data[43]
	// and the bytes right after it) is in range.
	if recordHeaderLen+recordLen < 47 {
		return "", errMsgToLog("readHandshake: proto length[%d] is too short", recordHeaderLen+recordLen)
	}
	data := make([]byte, recordHeaderLen+recordLen)
	copy(data, header)
	if _, err := io.ReadFull(rd, data[recordHeaderLen:]); err != nil {
		return "", errMsgToLog("read err:%v", err)
	}

	// Handshake header: msg_type (1 byte) followed by a 3-byte length.
	if data[5] != typeClientHello {
		return "", errMsgToLog("readHandshake: type[%d] is not clientHello", uint16(data[5]))
	}
	helloLen := int(data[6])<<16 | int(data[7])<<8 | int(data[8])
	if 9+helloLen > len(data) {
		return "", errMsgToLog("readHandshake: clientHello length[%d] exceeds record length[%d]", helloLen, recordLen)
	}
	// Ignore any bytes after the ClientHello in the same record so the
	// exact-length checks below apply to the message itself.
	data = data[:9+helloLen]
	if len(data) < 47 {
		return "", errMsgToLog("readHandshake: proto length[%d] is too short", len(data))
	}

	// data[9:11] is the client version and data[11:43] the random; neither
	// is needed to locate the server name.
	sessionIDLen := int(data[43])
	if sessionIDLen > 32 || len(data) < 44+sessionIDLen {
		return "", errMsgToLog("readHandshake: sessionIdLen[%d] is long", sessionIDLen)
	}
	data = data[44+sessionIDLen:]
	if len(data) < 2 {
		return "", errMsgToLog("readHandshake: dataLen[%d] after session is short", len(data))
	}

	// Cipher suites: 2-byte length followed by 2-byte suite identifiers.
	cipherSuiteLen := int(data[0])<<8 | int(data[1])
	if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
		return "", errMsgToLog("readHandshake: dataLen[%d] after cipher suite is short", len(data))
	}
	data = data[2+cipherSuiteLen:]
	if len(data) < 1 {
		return "", errMsgToLog("readHandshake: cipherSuiteLen[%d] is long", cipherSuiteLen)
	}

	// Compression methods: 1-byte length followed by the method bytes.
	compressionMethodsLen := int(data[0])
	if len(data) < 1+compressionMethodsLen {
		return "", errMsgToLog("readHandshake: compressionMethodsLen[%d] is long", compressionMethodsLen)
	}
	data = data[1+compressionMethodsLen:]

	// Extensions may be absent, but then there is no server name to use.
	if len(data) == 0 {
		return "", errMsgToLog("readHandshake: there is no extension data to get servername")
	}
	if len(data) < 2 {
		return "", errMsgToLog("readHandshake: extension dataLen[%d] is too short", len(data))
	}
	extensionsLength := int(data[0])<<8 | int(data[1])
	data = data[2:]
	if extensionsLength != len(data) {
		return "", errMsgToLog("readHandshake: extensionsLen[%d] is not equal to dataLen[%d]", extensionsLength, len(data))
	}

	// Each extension: 2-byte type, 2-byte length, then that many data bytes.
	for len(data) != 0 {
		if len(data) < 4 {
			return "", errMsgToLog("readHandshake: extensionsDataLen[%d] is too short", len(data))
		}
		extension := uint16(data[0])<<8 | uint16(data[1])
		length := int(data[2])<<8 | int(data[3])
		data = data[4:]
		if len(data) < length {
			return "", errMsgToLog("readHandshake: extensionLen[%d] is long", length)
		}

		switch extension {
		case extensionRenegotiationInfo:
			// Accept only a one-byte body equal to 0.
			if length != 1 || data[0] != 0 {
				return "", errMsgToLog("readHandshake: extension reNegotiationInfoLen[%d] is short", length)
			}
		case extensionNextProtoNeg, extensionStatusRequest:
			// Recognised; nothing to validate.
		case extensionServerName:
			d := data[:length]
			if len(d) < 2 {
				return "", errMsgToLog("readHandshake: remiaining dataLen[%d] is short", len(d))
			}
			namesLen := int(d[0])<<8 | int(d[1])
			d = d[2:]
			if len(d) != namesLen {
				return "", errMsgToLog("readHandshake: nameListLen[%d] is not equal to dataLen[%d]", namesLen, len(d))
			}
			// Server name list entries: 1-byte name type, 2-byte length, name.
			for len(d) > 0 {
				if len(d) < 3 {
					return "", errMsgToLog("readHandshake: extension serverNameLen[%d] is short", len(d))
				}
				nameType := d[0]
				nameLen := int(d[1])<<8 | int(d[2])
				d = d[3:]
				if len(d) < nameLen {
					return "", errMsgToLog("readHandshake: nameLen[%d] is not equal to dataLen[%d]", nameLen, len(d))
				}
				if nameType == 0 { // host_name
					suffix := ""
					if VhostHttpsPort != 443 {
						suffix = fmt.Sprintf(":%d", VhostHttpsPort)
					}
					serverName := string(d[:nameLen])
					return strings.ToLower(strings.TrimSpace(serverName)) + suffix, nil
				}
				d = d[nameLen:]
			}
		}
		data = data[length:]
	}
	return "", errMsgToLog("readHandshake: no server_name extension with a host_name found")
}
// GetHttpsHostname extracts the SNI host name from the TLS ClientHello
// arriving on c and returns it as routerName.
//
// The ClientHello is read through the reader returned by newShareConn
// together with sc; callers should keep using sc rather than c afterwards
// (presumably sc replays the bytes consumed here — see newShareConn). sc is
// returned even when parsing fails.
func GetHttpsHostname(c *conn.Conn) (sc net.Conn, routerName string, err error) {
	sc, rd := newShareConn(c.TcpConn)
	host, err := readHandshake(rd)
	if err != nil {
		return sc, "", err
	}
	log.Debug("GetHttpsHostname: %s", host)
	return sc, host, nil
}
// NewHttpsMuxer builds an HttpsMuxer on listener by calling NewVhostMuxer
// with GetHttpsHostname (presumably used to derive each connection's
// routing name) and the given timeout.
//
// On error it returns a nil *HttpsMuxer, following the Go convention that
// the value is meaningless when err is non-nil, instead of a muxer that
// wraps a nil *VhostMuxer.
func NewHttpsMuxer(listener *conn.Listener, timeout time.Duration) (*HttpsMuxer, error) {
	mux, err := NewVhostMuxer(listener, GetHttpsHostname, timeout)
	if err != nil {
		return nil, err
	}
	return &HttpsMuxer{mux}, nil
}