diff --git a/utils/io/io.go b/utils/io/io.go index 68d932c7..ace7fffd 100644 --- a/utils/io/io.go +++ b/utils/io/io.go @@ -18,8 +18,6 @@ import ( "io" "sync" - "github.com/golang/snappy" - "github.com/fatedier/frp/utils/crypto" "github.com/fatedier/frp/utils/pool" ) @@ -55,8 +53,13 @@ func WithEncryption(rwc io.ReadWriteCloser, key []byte) (io.ReadWriteCloser, err } func WithCompression(rwc io.ReadWriteCloser) io.ReadWriteCloser { - return WrapReadWriteCloser(snappy.NewReader(rwc), snappy.NewWriter(rwc), func() error { - return rwc.Close() + sr := pool.GetSnappyReader(rwc) + sw := pool.GetSnappyWriter(rwc) + return WrapReadWriteCloser(sr, sw, func() error { + err := rwc.Close() + pool.PutSnappyReader(sr) + pool.PutSnappyWriter(sw) + return err }) } @@ -64,13 +67,18 @@ type ReadWriteCloser struct { r io.Reader w io.Writer closeFn func() error + + closed bool + mu sync.Mutex } +// closeFn will be called only once func WrapReadWriteCloser(r io.Reader, w io.Writer, closeFn func() error) io.ReadWriteCloser { return &ReadWriteCloser{ r: r, w: w, closeFn: closeFn, + closed: false, } } @@ -83,6 +91,14 @@ func (rwc *ReadWriteCloser) Write(p []byte) (n int, err error) { } func (rwc *ReadWriteCloser) Close() (errRet error) { + rwc.mu.Lock() + if rwc.closed { + rwc.mu.Unlock() + return + } + rwc.closed = true + rwc.mu.Unlock() + var err error if rc, ok := rwc.r.(io.Closer); ok { err = rc.Close() diff --git a/utils/pool/snappy.go b/utils/pool/snappy.go new file mode 100644 index 00000000..7eb92266 --- /dev/null +++ b/utils/pool/snappy.go @@ -0,0 +1,57 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pool + +import ( + "io" + "sync" + + "github.com/golang/snappy" +) + +var ( + snappyReaderPool sync.Pool + snappyWriterPool sync.Pool +) + +func GetSnappyReader(r io.Reader) *snappy.Reader { + var x interface{} + x = snappyReaderPool.Get() + if x == nil { + return snappy.NewReader(r) + } + sr := x.(*snappy.Reader) + sr.Reset(r) + return sr +} + +func PutSnappyReader(sr *snappy.Reader) { + snappyReaderPool.Put(sr) +} + +func GetSnappyWriter(w io.Writer) *snappy.Writer { + var x interface{} + x = snappyWriterPool.Get() + if x == nil { + return snappy.NewWriter(w) + } + sw := x.(*snappy.Writer) + sw.Reset(w) + return sw +} + +func PutSnappyWriter(sw *snappy.Writer) { + snappyWriterPool.Put(sw) +}