package ssh import ( "encoding/binary" "errors" "flag" "fmt" "io" "net" "strconv" "strings" "sync" "sync/atomic" "time" gerror "github.com/fatedier/golib/errors" "golang.org/x/crypto/ssh" v1 "github.com/fatedier/frp/pkg/config/v1" "github.com/fatedier/frp/pkg/util/log" ) const ( // ssh protocol define // https://datatracker.ietf.org/doc/html/rfc4254#page-16 ChannelTypeServerOpenChannel = "forwarded-tcpip" RequestTypeForward = "tcpip-forward" // golang ssh package define. // https://pkg.go.dev/golang.org/x/crypto/ssh RequestTypeHeartbeat = "keepalive@openssh.com" ) // 当 proxy 失败会返回该错误 type VProxyError struct{} // ssh protocol define // https://datatracker.ietf.org/doc/html/rfc4254#page-16 // parse ssh client cmds input type forwardedTCPPayload struct { Addr string Port uint32 // can be default empty value but do not delete it // because ssh protocol shoule be reserved OriginAddr string OriginPort uint32 } // custom define // parse ssh client cmds input type CmdPayload struct { Address string Port uint32 } // custom define // with frp control cmds type ExtraPayload struct { Type string // TODO port can be set by extra message and priority to ssh raw cmd Address string Port uint32 } type Service struct { tcpConn net.Conn cfg *ssh.ServerConfig sshConn *ssh.ServerConn gChannel <-chan ssh.NewChannel gReq <-chan *ssh.Request addrPayloadCh chan CmdPayload extraPayloadCh chan ExtraPayload proxyPayloadCh chan v1.ProxyConfigurer replyCh chan interface{} closeCh chan struct{} exit int32 } func NewSSHService( tcpConn net.Conn, cfg *ssh.ServerConfig, proxyPayloadCh chan v1.ProxyConfigurer, replyCh chan interface{}, ) (ss *Service, err error) { ss = &Service{ tcpConn: tcpConn, cfg: cfg, addrPayloadCh: make(chan CmdPayload), extraPayloadCh: make(chan ExtraPayload), proxyPayloadCh: proxyPayloadCh, replyCh: replyCh, closeCh: make(chan struct{}), exit: 0, } ss.sshConn, ss.gChannel, ss.gReq, err = ssh.NewServerConn(tcpConn, cfg) if err != nil { log.Error("ssh handshake error: %v", err) return nil, err } log.Info("ssh connection success") return ss, nil } func (ss *Service) Run() { go ss.loopGenerateProxy() go ss.loopParseCmdPayload() go ss.loopParseExtraPayload() go ss.loopReply() } func (ss *Service) Exit() <-chan struct{} { return ss.closeCh } func (ss *Service) Close() { if atomic.LoadInt32(&ss.exit) == 1 { return } select { case <-ss.closeCh: return default: } close(ss.closeCh) close(ss.addrPayloadCh) close(ss.extraPayloadCh) _ = ss.sshConn.Wait() ss.sshConn.Close() ss.tcpConn.Close() atomic.StoreInt32(&ss.exit, 1) log.Info("ssh service close") } func (ss *Service) loopParseCmdPayload() { for { select { case req, ok := <-ss.gReq: if !ok { log.Info("global request is close") ss.Close() return } switch req.Type { case RequestTypeForward: var addrPayload CmdPayload if err := ssh.Unmarshal(req.Payload, &addrPayload); err != nil { log.Error("ssh unmarshal error: %v", err) return } _ = gerror.PanicToError(func() { ss.addrPayloadCh <- addrPayload }) default: if req.Type == RequestTypeHeartbeat { log.Debug("ssh heartbeat data") } else { log.Info("default req, data: %v", req) } } if req.WantReply { err := req.Reply(true, nil) if err != nil { log.Error("reply to ssh client error: %v", err) } } case <-ss.closeCh: log.Info("loop parse cmd payload close") return } } } func (ss *Service) loopSendHeartbeat(ch ssh.Channel) { tk := time.NewTicker(time.Second * 60) defer tk.Stop() for { select { case <-tk.C: ok, err := ch.SendRequest("heartbeat", false, nil) if err != nil { log.Error("channel send req error: %v", err) if err == io.EOF { ss.Close() return } continue } log.Debug("heartbeat send success, ok: %v", ok) case <-ss.closeCh: return } } } func (ss *Service) loopParseExtraPayload() { log.Info("loop parse extra payload start") for newChannel := range ss.gChannel { ch, req, err := newChannel.Accept() if err != nil { log.Error("channel accept error: %v", err) return } go ss.loopSendHeartbeat(ch) go func(req <-chan *ssh.Request) { for r := range req { if len(r.Payload) <= 4 { log.Info("r.payload is less than 4") continue } if !strings.Contains(string(r.Payload), "tcp") && !strings.Contains(string(r.Payload), "http") { log.Info("ssh protocol exchange data") continue } // [4byte data_len|data] end := 4 + binary.BigEndian.Uint32(r.Payload[:4]) if end > uint32(len(r.Payload)) { end = uint32(len(r.Payload)) } p := string(r.Payload[4:end]) msg, err := parseSSHExtraMessage(p) if err != nil { log.Error("parse ssh extra message error: %v, payload: %v", err, r.Payload) continue } _ = gerror.PanicToError(func() { ss.extraPayloadCh <- msg }) return } }(req) } } func (ss *Service) SSHConn() *ssh.ServerConn { return ss.sshConn } func (ss *Service) TCPConn() net.Conn { return ss.tcpConn } func (ss *Service) loopReply() { for { select { case <-ss.closeCh: log.Info("loop reply close") return case req := <-ss.replyCh: switch req.(type) { case *VProxyError: log.Error("run frp proxy error, close ssh service") ss.Close() default: // TODO } } } } func (ss *Service) loopGenerateProxy() { log.Info("loop generate proxy start") for { if atomic.LoadInt32(&ss.exit) == 1 { return } wg := new(sync.WaitGroup) wg.Add(2) var p1 CmdPayload var p2 ExtraPayload go func() { defer wg.Done() for { select { case <-ss.closeCh: return case p1 = <-ss.addrPayloadCh: return } } }() go func() { defer wg.Done() for { select { case <-ss.closeCh: return case p2 = <-ss.extraPayloadCh: return } } }() wg.Wait() if atomic.LoadInt32(&ss.exit) == 1 { return } switch p2.Type { case "http": case "tcp": ss.proxyPayloadCh <- &v1.TCPProxyConfig{ ProxyBaseConfig: v1.ProxyBaseConfig{ Name: fmt.Sprintf("ssh-proxy-%v-%v", ss.tcpConn.RemoteAddr().String(), time.Now().UnixNano()), Type: p2.Type, ProxyBackend: v1.ProxyBackend{ LocalIP: p1.Address, }, }, RemotePort: int(p1.Port), } default: log.Warn("invalid frp proxy type: %v", p2.Type) } } } func parseSSHExtraMessage(s string) (p ExtraPayload, err error) { sn := len(s) log.Info("parse ssh extra message: %v", s) ss := strings.Fields(s) if len(ss) == 0 { if sn != 0 { ss = append(ss, s) } else { return p, fmt.Errorf("invalid ssh input, args: %v", ss) } } for i, v := range ss { ss[i] = strings.TrimSpace(v) } if ss[0] != "tcp" && ss[0] != "http" { return p, fmt.Errorf("only support tcp/http now") } switch ss[0] { case "tcp": tcpCmd, err := ParseTCPCommand(ss) if err != nil { return ExtraPayload{}, fmt.Errorf("invalid ssh input: %v", err) } port, _ := strconv.Atoi(tcpCmd.Port) p = ExtraPayload{ Type: "tcp", Address: tcpCmd.Address, Port: uint32(port), } case "http": httpCmd, err := ParseHTTPCommand(ss) if err != nil { return ExtraPayload{}, fmt.Errorf("invalid ssh input: %v", err) } _ = httpCmd p = ExtraPayload{ Type: "http", } } return p, nil } type HTTPCommand struct { Domain string BasicAuthUser string BasicAuthPass string } func ParseHTTPCommand(params []string) (*HTTPCommand, error) { if len(params) < 2 { return nil, errors.New("invalid HTTP command") } var ( basicAuth string domainURL string basicAuthUser string basicAuthPass string ) fs := flag.NewFlagSet("http", flag.ContinueOnError) fs.StringVar(&basicAuth, "basic-auth", "", "") fs.StringVar(&domainURL, "domain", "", "") fs.SetOutput(&nullWriter{}) // Disables usage output err := fs.Parse(params[2:]) if err != nil { if !errors.Is(err, flag.ErrHelp) { return nil, err } } if basicAuth != "" { authParts := strings.SplitN(basicAuth, ":", 2) basicAuthUser = authParts[0] if len(authParts) > 1 { basicAuthPass = authParts[1] } } httpCmd := &HTTPCommand{ Domain: domainURL, BasicAuthUser: basicAuthUser, BasicAuthPass: basicAuthPass, } return httpCmd, nil } type TCPCommand struct { Address string Port string } func ParseTCPCommand(params []string) (*TCPCommand, error) { if len(params) == 0 || params[0] != "tcp" { return nil, errors.New("invalid TCP command") } if len(params) == 1 { return &TCPCommand{}, nil } var ( address string port string ) fs := flag.NewFlagSet("tcp", flag.ContinueOnError) fs.StringVar(&address, "address", "", "The IP address to listen on") fs.StringVar(&port, "port", "", "The port to listen on") fs.SetOutput(&nullWriter{}) // Disables usage output args := params[1:] err := fs.Parse(args) if err != nil { if !errors.Is(err, flag.ErrHelp) { return nil, err } } parsedAddr, err := net.ResolveIPAddr("ip", address) if err != nil { return nil, err } if _, err := net.LookupPort("tcp", port); err != nil { return nil, err } tcpCmd := &TCPCommand{ Address: parsedAddr.String(), Port: port, } return tcpCmd, nil } type nullWriter struct{} func (w *nullWriter) Write(p []byte) (n int, err error) { return len(p), nil }