diff --git a/pkg/util/vhost/http.go b/pkg/util/vhost/http.go index bc458a5c..8b449eda 100644 --- a/pkg/util/vhost/http.go +++ b/pkg/util/vhost/http.go @@ -16,6 +16,7 @@ package vhost import ( "context" + "crypto/tls" "encoding/base64" "errors" "fmt" @@ -49,6 +50,21 @@ type HTTPReverseProxy struct { responseHeaderTimeout time.Duration } +type grpcSwitchH2Transport struct { + h2t *http2.Transport +} + +var _ http.RoundTripper = (*grpcSwitchH2Transport)(nil) + +// RoundTrip implements http.RoundTripper. +func (d *grpcSwitchH2Transport) RoundTrip(req *http.Request) (*http.Response, error) { + if req.Header.Get("Content-Type") != "application/grpc" { + return nil, http.ErrSkipAltProtocol + } + + return d.h2t.RoundTrip(req) +} + func NewHTTPReverseProxy(option HTTPReverseProxyOptions, vhostRouter *Routers) *HTTPReverseProxy { if option.ResponseHeaderTimeoutS <= 0 { option.ResponseHeaderTimeoutS = 60 @@ -57,6 +73,40 @@ func NewHTTPReverseProxy(option HTTPReverseProxyOptions, vhostRouter *Routers) * responseHeaderTimeout: time.Duration(option.ResponseHeaderTimeoutS) * time.Second, vhostRouter: vhostRouter, } + + h2t := &http2.Transport{ + AllowHTTP: true, + DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { + return rp.CreateConnection(ctx.Value(RouteInfoKey).(*RequestRouteInfo), true) + }, + } + h1t := &http.Transport{ + ResponseHeaderTimeout: rp.responseHeaderTimeout, + IdleConnTimeout: 60 * time.Second, + MaxIdleConnsPerHost: 5, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return rp.CreateConnection(ctx.Value(RouteInfoKey).(*RequestRouteInfo), true) + }, + ForceAttemptHTTP2: true, + Proxy: func(req *http.Request) (*url.URL, error) { + // Use proxy mode if there is host in HTTP first request line. + // GET http://example.com/ HTTP/1.1 + // Host: example.com + // + // Normal: + // GET / HTTP/1.1 + // Host: example.com + urlHost := req.Context().Value(RouteInfoKey).(*RequestRouteInfo).URLHost + if urlHost != "" { + return req.URL, nil + } + return nil, nil + }, + } + // alrough register http protocol with a h2t, but it's only used for grpc. + // for normal http request, it still uses h1t. + h1t.RegisterProtocol("http", &grpcSwitchH2Transport{h2t: h2t}) + proxy := &httputil.ReverseProxy{ // Modify incoming requests by route policies. Rewrite: func(r *httputil.ProxyRequest) { @@ -103,29 +153,7 @@ func NewHTTPReverseProxy(option HTTPReverseProxyOptions, vhostRouter *Routers) * } return nil }, - // Create a connection to one proxy routed by route policy. - Transport: &http.Transport{ - ResponseHeaderTimeout: rp.responseHeaderTimeout, - IdleConnTimeout: 60 * time.Second, - MaxIdleConnsPerHost: 5, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return rp.CreateConnection(ctx.Value(RouteInfoKey).(*RequestRouteInfo), true) - }, - Proxy: func(req *http.Request) (*url.URL, error) { - // Use proxy mode if there is host in HTTP first request line. - // GET http://example.com/ HTTP/1.1 - // Host: example.com - // - // Normal: - // GET / HTTP/1.1 - // Host: example.com - urlHost := req.Context().Value(RouteInfoKey).(*RequestRouteInfo).URLHost - if urlHost != "" { - return req.URL, nil - } - return nil, nil - }, - }, + Transport: h1t, BufferPool: pool.NewBuffer(32 * 1024), ErrorLog: stdlog.New(log.NewWriteLogger(log.WarnLevel, 2), "", 0), ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {