// Mirror of https://github.com/fatedier/frp.git (synced 2025-01-23 09:54:20 +00:00).
// 666 lines, 14 KiB, Go.
// Copyright 2009 The Go Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package websocket
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/rand"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"reflect"
|
|
"runtime"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
var (
	// serverAddr holds the listener address of the shared test server; it is
	// assigned by startServer.
	serverAddr string

	// once is used by the tests as once.Do(startServer) so the shared test
	// server is started a single time per test binary.
	once sync.Once
)
|
|
|
|
func echoServer(ws *Conn) {
|
|
defer ws.Close()
|
|
io.Copy(ws, ws)
|
|
}
|
|
|
|
// Count is the message exchanged with countServer through the JSON codec.
type Count struct {
	S string // string payload
	N int    // counter
}
|
|
|
|
func countServer(ws *Conn) {
|
|
defer ws.Close()
|
|
for {
|
|
var count Count
|
|
err := JSON.Receive(ws, &count)
|
|
if err != nil {
|
|
return
|
|
}
|
|
count.N++
|
|
count.S = strings.Repeat(count.S, count.N)
|
|
err = JSON.Send(ws, count)
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
type testCtrlAndDataHandler struct {
|
|
hybiFrameHandler
|
|
}
|
|
|
|
func (h *testCtrlAndDataHandler) WritePing(b []byte) (int, error) {
|
|
h.hybiFrameHandler.conn.wio.Lock()
|
|
defer h.hybiFrameHandler.conn.wio.Unlock()
|
|
w, err := h.hybiFrameHandler.conn.frameWriterFactory.NewFrameWriter(PingFrame)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
n, err := w.Write(b)
|
|
w.Close()
|
|
return n, err
|
|
}
|
|
|
|
func ctrlAndDataServer(ws *Conn) {
|
|
defer ws.Close()
|
|
h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}}
|
|
ws.frameHandler = h
|
|
|
|
go func() {
|
|
for i := 0; ; i++ {
|
|
var b []byte
|
|
if i%2 != 0 { // with or without payload
|
|
b = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-SERVER", i))
|
|
}
|
|
if _, err := h.WritePing(b); err != nil {
|
|
break
|
|
}
|
|
if _, err := h.WritePong(b); err != nil { // unsolicited pong
|
|
break
|
|
}
|
|
time.Sleep(10 * time.Millisecond)
|
|
}
|
|
}()
|
|
|
|
b := make([]byte, 128)
|
|
for {
|
|
n, err := ws.Read(b)
|
|
if err != nil {
|
|
break
|
|
}
|
|
if _, err := ws.Write(b[:n]); err != nil {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
func subProtocolHandshake(config *Config, req *http.Request) error {
|
|
for _, proto := range config.Protocol {
|
|
if proto == "chat" {
|
|
config.Protocol = []string{proto}
|
|
return nil
|
|
}
|
|
}
|
|
return ErrBadWebSocketProtocol
|
|
}
|
|
|
|
func subProtoServer(ws *Conn) {
|
|
for _, proto := range ws.Config().Protocol {
|
|
io.WriteString(ws, proto)
|
|
}
|
|
}
|
|
|
|
func startServer() {
|
|
http.Handle("/echo", Handler(echoServer))
|
|
http.Handle("/count", Handler(countServer))
|
|
http.Handle("/ctrldata", Handler(ctrlAndDataServer))
|
|
subproto := Server{
|
|
Handshake: subProtocolHandshake,
|
|
Handler: Handler(subProtoServer),
|
|
}
|
|
http.Handle("/subproto", subproto)
|
|
server := httptest.NewServer(nil)
|
|
serverAddr = server.Listener.Addr().String()
|
|
log.Print("Test WebSocket server listening on ", serverAddr)
|
|
}
|
|
|
|
func newConfig(t *testing.T, path string) *Config {
|
|
config, _ := NewConfig(fmt.Sprintf("ws://%s%s", serverAddr, path), "http://localhost")
|
|
return config
|
|
}
|
|
|
|
func TestEcho(t *testing.T) {
|
|
once.Do(startServer)
|
|
|
|
// websocket.Dial()
|
|
client, err := net.Dial("tcp", serverAddr)
|
|
if err != nil {
|
|
t.Fatal("dialing", err)
|
|
}
|
|
conn, err := NewClient(newConfig(t, "/echo"), client)
|
|
if err != nil {
|
|
t.Errorf("WebSocket handshake error: %v", err)
|
|
return
|
|
}
|
|
|
|
msg := []byte("hello, world\n")
|
|
if _, err := conn.Write(msg); err != nil {
|
|
t.Errorf("Write: %v", err)
|
|
}
|
|
var actual_msg = make([]byte, 512)
|
|
n, err := conn.Read(actual_msg)
|
|
if err != nil {
|
|
t.Errorf("Read: %v", err)
|
|
}
|
|
actual_msg = actual_msg[0:n]
|
|
if !bytes.Equal(msg, actual_msg) {
|
|
t.Errorf("Echo: expected %q got %q", msg, actual_msg)
|
|
}
|
|
conn.Close()
|
|
}
|
|
|
|
func TestAddr(t *testing.T) {
|
|
once.Do(startServer)
|
|
|
|
// websocket.Dial()
|
|
client, err := net.Dial("tcp", serverAddr)
|
|
if err != nil {
|
|
t.Fatal("dialing", err)
|
|
}
|
|
conn, err := NewClient(newConfig(t, "/echo"), client)
|
|
if err != nil {
|
|
t.Errorf("WebSocket handshake error: %v", err)
|
|
return
|
|
}
|
|
|
|
ra := conn.RemoteAddr().String()
|
|
if !strings.HasPrefix(ra, "ws://") || !strings.HasSuffix(ra, "/echo") {
|
|
t.Errorf("Bad remote addr: %v", ra)
|
|
}
|
|
la := conn.LocalAddr().String()
|
|
if !strings.HasPrefix(la, "http://") {
|
|
t.Errorf("Bad local addr: %v", la)
|
|
}
|
|
conn.Close()
|
|
}
|
|
|
|
func TestCount(t *testing.T) {
|
|
once.Do(startServer)
|
|
|
|
// websocket.Dial()
|
|
client, err := net.Dial("tcp", serverAddr)
|
|
if err != nil {
|
|
t.Fatal("dialing", err)
|
|
}
|
|
conn, err := NewClient(newConfig(t, "/count"), client)
|
|
if err != nil {
|
|
t.Errorf("WebSocket handshake error: %v", err)
|
|
return
|
|
}
|
|
|
|
var count Count
|
|
count.S = "hello"
|
|
if err := JSON.Send(conn, count); err != nil {
|
|
t.Errorf("Write: %v", err)
|
|
}
|
|
if err := JSON.Receive(conn, &count); err != nil {
|
|
t.Errorf("Read: %v", err)
|
|
}
|
|
if count.N != 1 {
|
|
t.Errorf("count: expected %d got %d", 1, count.N)
|
|
}
|
|
if count.S != "hello" {
|
|
t.Errorf("count: expected %q got %q", "hello", count.S)
|
|
}
|
|
if err := JSON.Send(conn, count); err != nil {
|
|
t.Errorf("Write: %v", err)
|
|
}
|
|
if err := JSON.Receive(conn, &count); err != nil {
|
|
t.Errorf("Read: %v", err)
|
|
}
|
|
if count.N != 2 {
|
|
t.Errorf("count: expected %d got %d", 2, count.N)
|
|
}
|
|
if count.S != "hellohello" {
|
|
t.Errorf("count: expected %q got %q", "hellohello", count.S)
|
|
}
|
|
conn.Close()
|
|
}
|
|
|
|
func TestWithQuery(t *testing.T) {
|
|
once.Do(startServer)
|
|
|
|
client, err := net.Dial("tcp", serverAddr)
|
|
if err != nil {
|
|
t.Fatal("dialing", err)
|
|
}
|
|
|
|
config := newConfig(t, "/echo")
|
|
config.Location, err = url.ParseRequestURI(fmt.Sprintf("ws://%s/echo?q=v", serverAddr))
|
|
if err != nil {
|
|
t.Fatal("location url", err)
|
|
}
|
|
|
|
ws, err := NewClient(config, client)
|
|
if err != nil {
|
|
t.Errorf("WebSocket handshake: %v", err)
|
|
return
|
|
}
|
|
ws.Close()
|
|
}
|
|
|
|
func testWithProtocol(t *testing.T, subproto []string) (string, error) {
|
|
once.Do(startServer)
|
|
|
|
client, err := net.Dial("tcp", serverAddr)
|
|
if err != nil {
|
|
t.Fatal("dialing", err)
|
|
}
|
|
|
|
config := newConfig(t, "/subproto")
|
|
config.Protocol = subproto
|
|
|
|
ws, err := NewClient(config, client)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
msg := make([]byte, 16)
|
|
n, err := ws.Read(msg)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
ws.Close()
|
|
return string(msg[:n]), nil
|
|
}
|
|
|
|
func TestWithProtocol(t *testing.T) {
|
|
proto, err := testWithProtocol(t, []string{"chat"})
|
|
if err != nil {
|
|
t.Errorf("SubProto: unexpected error: %v", err)
|
|
}
|
|
if proto != "chat" {
|
|
t.Errorf("SubProto: expected %q, got %q", "chat", proto)
|
|
}
|
|
}
|
|
|
|
func TestWithTwoProtocol(t *testing.T) {
|
|
proto, err := testWithProtocol(t, []string{"test", "chat"})
|
|
if err != nil {
|
|
t.Errorf("SubProto: unexpected error: %v", err)
|
|
}
|
|
if proto != "chat" {
|
|
t.Errorf("SubProto: expected %q, got %q", "chat", proto)
|
|
}
|
|
}
|
|
|
|
func TestWithBadProtocol(t *testing.T) {
|
|
_, err := testWithProtocol(t, []string{"test"})
|
|
if err != ErrBadStatus {
|
|
t.Errorf("SubProto: expected %v, got %v", ErrBadStatus, err)
|
|
}
|
|
}
|
|
|
|
func TestHTTP(t *testing.T) {
|
|
once.Do(startServer)
|
|
|
|
// If the client did not send a handshake that matches the protocol
|
|
// specification, the server MUST return an HTTP response with an
|
|
// appropriate error code (such as 400 Bad Request)
|
|
resp, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr))
|
|
if err != nil {
|
|
t.Errorf("Get: error %#v", err)
|
|
return
|
|
}
|
|
if resp == nil {
|
|
t.Error("Get: resp is null")
|
|
return
|
|
}
|
|
if resp.StatusCode != http.StatusBadRequest {
|
|
t.Errorf("Get: expected %q got %q", http.StatusBadRequest, resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestTrailingSpaces(t *testing.T) {
|
|
// http://code.google.com/p/go/issues/detail?id=955
|
|
// The last runs of this create keys with trailing spaces that should not be
|
|
// generated by the client.
|
|
once.Do(startServer)
|
|
config := newConfig(t, "/echo")
|
|
for i := 0; i < 30; i++ {
|
|
// body
|
|
ws, err := DialConfig(config)
|
|
if err != nil {
|
|
t.Errorf("Dial #%d failed: %v", i, err)
|
|
break
|
|
}
|
|
ws.Close()
|
|
}
|
|
}
|
|
|
|
func TestDialConfigBadVersion(t *testing.T) {
|
|
once.Do(startServer)
|
|
config := newConfig(t, "/echo")
|
|
config.Version = 1234
|
|
|
|
_, err := DialConfig(config)
|
|
|
|
if dialerr, ok := err.(*DialError); ok {
|
|
if dialerr.Err != ErrBadProtocolVersion {
|
|
t.Errorf("dial expected err %q but got %q", ErrBadProtocolVersion, dialerr.Err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestDialConfigWithDialer(t *testing.T) {
|
|
once.Do(startServer)
|
|
config := newConfig(t, "/echo")
|
|
config.Dialer = &net.Dialer{
|
|
Deadline: time.Now().Add(-time.Minute),
|
|
}
|
|
_, err := DialConfig(config)
|
|
dialerr, ok := err.(*DialError)
|
|
if !ok {
|
|
t.Fatalf("DialError expected, got %#v", err)
|
|
}
|
|
neterr, ok := dialerr.Err.(*net.OpError)
|
|
if !ok {
|
|
t.Fatalf("net.OpError error expected, got %#v", dialerr.Err)
|
|
}
|
|
if !neterr.Timeout() {
|
|
t.Fatalf("expected timeout error, got %#v", neterr)
|
|
}
|
|
}
|
|
|
|
func TestSmallBuffer(t *testing.T) {
|
|
// http://code.google.com/p/go/issues/detail?id=1145
|
|
// Read should be able to handle reading a fragment of a frame.
|
|
once.Do(startServer)
|
|
|
|
// websocket.Dial()
|
|
client, err := net.Dial("tcp", serverAddr)
|
|
if err != nil {
|
|
t.Fatal("dialing", err)
|
|
}
|
|
conn, err := NewClient(newConfig(t, "/echo"), client)
|
|
if err != nil {
|
|
t.Errorf("WebSocket handshake error: %v", err)
|
|
return
|
|
}
|
|
|
|
msg := []byte("hello, world\n")
|
|
if _, err := conn.Write(msg); err != nil {
|
|
t.Errorf("Write: %v", err)
|
|
}
|
|
var small_msg = make([]byte, 8)
|
|
n, err := conn.Read(small_msg)
|
|
if err != nil {
|
|
t.Errorf("Read: %v", err)
|
|
}
|
|
if !bytes.Equal(msg[:len(small_msg)], small_msg) {
|
|
t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg)
|
|
}
|
|
var second_msg = make([]byte, len(msg))
|
|
n, err = conn.Read(second_msg)
|
|
if err != nil {
|
|
t.Errorf("Read: %v", err)
|
|
}
|
|
second_msg = second_msg[0:n]
|
|
if !bytes.Equal(msg[len(small_msg):], second_msg) {
|
|
t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg)
|
|
}
|
|
conn.Close()
|
|
}
|
|
|
|
// parseAuthorityTests pairs input URLs with the authority string that
// TestParseAuthority expects parseAuthority to return for them.
var parseAuthorityTests = []struct {
	in  *url.URL
	out string
}{
	{in: &url.URL{Scheme: "ws", Host: "www.google.com"}, out: "www.google.com:80"},
	{in: &url.URL{Scheme: "wss", Host: "www.google.com"}, out: "www.google.com:443"},
	{in: &url.URL{Scheme: "ws", Host: "www.google.com:80"}, out: "www.google.com:80"},
	{in: &url.URL{Scheme: "wss", Host: "www.google.com:443"}, out: "www.google.com:443"},
	// some invalid ones for parseAuthority. parseAuthority doesn't
	// concern itself with the scheme unless it actually knows about it
	{in: &url.URL{Scheme: "http", Host: "www.google.com"}, out: "www.google.com"},
	{in: &url.URL{Scheme: "http", Host: "www.google.com:80"}, out: "www.google.com:80"},
	{in: &url.URL{Scheme: "asdf", Host: "127.0.0.1"}, out: "127.0.0.1"},
	{in: &url.URL{Scheme: "asdf", Host: "www.google.com"}, out: "www.google.com"},
}
|
|
|
|
func TestParseAuthority(t *testing.T) {
|
|
for _, tt := range parseAuthorityTests {
|
|
out := parseAuthority(tt.in)
|
|
if out != tt.out {
|
|
t.Errorf("got %v; want %v", out, tt.out)
|
|
}
|
|
}
|
|
}
|
|
|
|
// closerConn wraps a net.Conn and records how many times Close was called.
type closerConn struct {
	net.Conn
	closed int // count of the number of times Close was called
}

// Close increments the call counter and then closes the wrapped connection,
// returning that connection's Close error.
func (c *closerConn) Close() error {
	c.closed++
	inner := c.Conn
	return inner.Close()
}
|
|
|
|
func TestClose(t *testing.T) {
|
|
if runtime.GOOS == "plan9" {
|
|
t.Skip("see golang.org/issue/11454")
|
|
}
|
|
|
|
once.Do(startServer)
|
|
|
|
conn, err := net.Dial("tcp", serverAddr)
|
|
if err != nil {
|
|
t.Fatal("dialing", err)
|
|
}
|
|
|
|
cc := closerConn{Conn: conn}
|
|
|
|
client, err := NewClient(newConfig(t, "/echo"), &cc)
|
|
if err != nil {
|
|
t.Fatalf("WebSocket handshake: %v", err)
|
|
}
|
|
|
|
// set the deadline to ten minutes ago, which will have expired by the time
|
|
// client.Close sends the close status frame.
|
|
conn.SetDeadline(time.Now().Add(-10 * time.Minute))
|
|
|
|
if err := client.Close(); err == nil {
|
|
t.Errorf("ws.Close(): expected error, got %v", err)
|
|
}
|
|
if cc.closed < 1 {
|
|
t.Fatalf("ws.Close(): expected underlying ws.rwc.Close to be called > 0 times, got: %v", cc.closed)
|
|
}
|
|
}
|
|
|
|
// originTests pairs requests with the origin URL that TestOrigin expects
// Origin to return for them.
var originTests = []struct {
	req    *http.Request
	origin *url.URL
}{
	{
		req:    &http.Request{Header: http.Header{"Origin": {"http://www.example.com"}}},
		origin: &url.URL{Scheme: "http", Host: "www.example.com"},
	},
	{
		// Request without any headers: a nil origin is expected.
		req:    &http.Request{},
		origin: nil,
	},
}
|
|
|
|
func TestOrigin(t *testing.T) {
|
|
conf := newConfig(t, "/echo")
|
|
conf.Version = ProtocolVersionHybi13
|
|
for i, tt := range originTests {
|
|
origin, err := Origin(conf, tt.req)
|
|
if err != nil {
|
|
t.Error(err)
|
|
continue
|
|
}
|
|
if !reflect.DeepEqual(origin, tt.origin) {
|
|
t.Errorf("#%d: got origin %v; want %v", i, origin, tt.origin)
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestCtrlAndData(t *testing.T) {
|
|
once.Do(startServer)
|
|
|
|
c, err := net.Dial("tcp", serverAddr)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
ws, err := NewClient(newConfig(t, "/ctrldata"), c)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer ws.Close()
|
|
|
|
h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}}
|
|
ws.frameHandler = h
|
|
|
|
b := make([]byte, 128)
|
|
for i := 0; i < 2; i++ {
|
|
data := []byte(fmt.Sprintf("#%d-DATA-FRAME-FROM-CLIENT", i))
|
|
if _, err := ws.Write(data); err != nil {
|
|
t.Fatalf("#%d: %v", i, err)
|
|
}
|
|
var ctrl []byte
|
|
if i%2 != 0 { // with or without payload
|
|
ctrl = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-CLIENT", i))
|
|
}
|
|
if _, err := h.WritePing(ctrl); err != nil {
|
|
t.Fatalf("#%d: %v", i, err)
|
|
}
|
|
n, err := ws.Read(b)
|
|
if err != nil {
|
|
t.Fatalf("#%d: %v", i, err)
|
|
}
|
|
if !bytes.Equal(b[:n], data) {
|
|
t.Fatalf("#%d: got %v; want %v", i, b[:n], data)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestCodec_ReceiveLimited(t *testing.T) {
|
|
const limit = 2048
|
|
var payloads [][]byte
|
|
for _, size := range []int{
|
|
1024,
|
|
2048,
|
|
4096, // receive of this message would be interrupted due to limit
|
|
2048, // this one is to make sure next receive recovers discarding leftovers
|
|
} {
|
|
b := make([]byte, size)
|
|
rand.Read(b)
|
|
payloads = append(payloads, b)
|
|
}
|
|
handlerDone := make(chan struct{})
|
|
limitedHandler := func(ws *Conn) {
|
|
defer close(handlerDone)
|
|
ws.MaxPayloadBytes = limit
|
|
defer ws.Close()
|
|
for i, p := range payloads {
|
|
t.Logf("payload #%d (size %d, exceeds limit: %v)", i, len(p), len(p) > limit)
|
|
var recv []byte
|
|
err := Message.Receive(ws, &recv)
|
|
switch err {
|
|
case nil:
|
|
case ErrFrameTooLarge:
|
|
if len(p) <= limit {
|
|
t.Fatalf("unexpected frame size limit: expected %d bytes of payload having limit at %d", len(p), limit)
|
|
}
|
|
continue
|
|
default:
|
|
t.Fatalf("unexpected error: %v (want either nil or ErrFrameTooLarge)", err)
|
|
}
|
|
if len(recv) > limit {
|
|
t.Fatalf("received %d bytes of payload having limit at %d", len(recv), limit)
|
|
}
|
|
if !bytes.Equal(p, recv) {
|
|
t.Fatalf("received payload differs:\ngot:\t%v\nwant:\t%v", recv, p)
|
|
}
|
|
}
|
|
}
|
|
server := httptest.NewServer(Handler(limitedHandler))
|
|
defer server.CloseClientConnections()
|
|
defer server.Close()
|
|
addr := server.Listener.Addr().String()
|
|
ws, err := Dial("ws://"+addr+"/", "", "http://localhost/")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer ws.Close()
|
|
for i, p := range payloads {
|
|
if err := Message.Send(ws, p); err != nil {
|
|
t.Fatalf("payload #%d (size %d): %v", i, len(p), err)
|
|
}
|
|
}
|
|
<-handlerDone
|
|
}
|