diff options
| -rw-r--r-- | main.go | 61 | ||||
| -rw-r--r-- | shell.go | 99 |
2 files changed, 68 insertions, 92 deletions
@@ -4,64 +4,47 @@ import ( "fmt" "net" + "github.com/gliderlabs/ssh" "github.com/magiconair/properties" - "golang.org/x/crypto/ssh" + gossh "golang.org/x/crypto/ssh" ) func main() { // 解析 server.properties conf := properties.MustLoadFile("server.properties", properties.UTF8) - var SSH_SERVER struct { - Host string - Port string - User string - Pass string - config *ssh.ServerConfig - } - - SSH_SERVER.Host = conf.MustGetString("server-ip") - SSH_SERVER.Port = conf.MustGetString("server-port") - SSH_SERVER.User = conf.MustGetString("term-user") - SSH_SERVER.Pass = conf.MustGetString("term-pass") - - // 创建 ssh 密码认证 - SSH_SERVER.config = &ssh.ServerConfig{ - PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { - if conn.User() == SSH_SERVER.User && string(password) == SSH_SERVER.Pass { - return nil, nil - } - - return nil, ssh.ErrNoAuth - }, + SSH_SERVER := struct { + Host string + Port string + User string + Pass string + }{ + Host: conf.MustGetString("server-ip"), + Port: conf.MustGetString("server-port"), + User: conf.MustGetString("term-user"), + Pass: conf.MustGetString("term-pass"), } // 创建 ssh 服务器密钥 - privateKeySigner, err := ssh.ParsePrivateKey(privatePEM) + privateKeySigner, err := gossh.ParsePrivateKey(privatePEM) if err != nil { panic(fmt.Errorf("不能解析私钥: %v", err)) } - SSH_SERVER.config.AddHostKey(privateKeySigner) - // 在指定端口开启服务 address := net.JoinHostPort(SSH_SERVER.Host, SSH_SERVER.Port) - listener, err := net.Listen("tcp", address) - if err != nil { - panic(fmt.Errorf("不能在 %s 上创建服务: %v", address, err)) + s := &ssh.Server{ + Addr: address, + Handler: shell, + PasswordHandler: func(ctx ssh.Context, password string) bool { + return ctx.User() == SSH_SERVER.User && password == SSH_SERVER.Pass + }, } + s.AddHostKey(privateKeySigner) fmt.Println("Server Address:", address) - - // 连接到系统 shell - for { - conn, err := listener.Accept() - if err != nil { - fmt.Println("Can not accept connection:", err) - } - - go shell(conn, SSH_SERVER.config) + if err := s.ListenAndServe(); err != nil { + panic(fmt.Errorf("不能启动服务器: %v", err)) } - } @@ -2,69 +2,62 @@ package main import ( "fmt" - "net" + "io" "os" "os/exec" - "golang.org/x/crypto/ssh" + "github.com/creack/pty" + "github.com/gliderlabs/ssh" ) -func shell(conn net.Conn, config *ssh.ServerConfig) { - sshConn, chans, reqs, err := ssh.NewServerConn(conn, config) - if err != nil { - fmt.Println("不能创建连接:", err) +func shell(s ssh.Session) { + ptyReq, winCh, isPty := s.Pty() + + if !isPty { + fmt.Fprintln(s, "Must be PTY") + s.Exit(1) return } - defer sshConn.Close() - - fmt.Println("New connection from", sshConn.RemoteAddr(), "with client version", sshConn.ClientVersion()) - - go ssh.DiscardRequests(reqs) - - for newChannel := range chans { - if newChannel.ChannelType() != "session" { - newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") - continue - } - - channel, requests, err := newChannel.Accept() - if err != nil { - fmt.Println("Can not accept channel:", err) - continue - } - defer channel.Close() - shell := os.Getenv("SHELL") - if shell == "" { - shell = "cmd.exe" - } + shell := os.Getenv("SHELL") + if shell == "" { + shell = "/bin/sh" + } - command := exec.Command(shell) - command.Stdin = channel - command.Stdout = channel - command.Stderr = channel + cmd := exec.Command(shell) + cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term)) - if err := command.Start(); err != nil { - fmt.Println("Failed to start shell:", err) - return + ptmx, err := pty.Start(cmd) + if err != nil { + fmt.Fprintln(s, "Can not start shell") + fmt.Println(err) + s.Exit(1) + return + } + defer func() { _ = ptmx.Close() }() + + go func() { + for win := range winCh { + pty.Setsize(ptmx, &pty.Winsize{ + Rows: uint16(win.Height), + Cols: uint16(win.Width), + }) } - - go func() { - if err := command.Wait(); err != nil { - fmt.Println("Run shell failed:", err) - } - channel.Close() - }() - - go func() { - for req := range requests { - switch req.Type { - case "shell": - req.Reply(true, nil) - default: - req.Reply(false, nil) - } - } - }() + }() + + go func() { + io.Copy(s, ptmx) + s.Close() + }() + + go func() { + io.Copy(ptmx, s) + ptmx.Close() + }() + + if err := cmd.Wait(); err != nil { + s.Exit(1) + } else { + s.Exit(0) } } |
