summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--main.go61
-rw-r--r--shell.go99
2 files changed, 68 insertions, 92 deletions
diff --git a/main.go b/main.go
index 1bbe285..c407351 100644
--- a/main.go
+++ b/main.go
@@ -4,64 +4,47 @@ import (
"fmt"
"net"
+ "github.com/gliderlabs/ssh"
"github.com/magiconair/properties"
- "golang.org/x/crypto/ssh"
+ gossh "golang.org/x/crypto/ssh"
)
func main() {
// 解析 server.properties
conf := properties.MustLoadFile("server.properties", properties.UTF8)
- var SSH_SERVER struct {
- Host string
- Port string
- User string
- Pass string
- config *ssh.ServerConfig
- }
-
- SSH_SERVER.Host = conf.MustGetString("server-ip")
- SSH_SERVER.Port = conf.MustGetString("server-port")
- SSH_SERVER.User = conf.MustGetString("term-user")
- SSH_SERVER.Pass = conf.MustGetString("term-pass")
-
- // 创建 ssh 密码认证
- SSH_SERVER.config = &ssh.ServerConfig{
- PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
- if conn.User() == SSH_SERVER.User && string(password) == SSH_SERVER.Pass {
- return nil, nil
- }
-
- return nil, ssh.ErrNoAuth
- },
+ SSH_SERVER := struct {
+ Host string
+ Port string
+ User string
+ Pass string
+ }{
+ Host: conf.MustGetString("server-ip"),
+ Port: conf.MustGetString("server-port"),
+ User: conf.MustGetString("term-user"),
+ Pass: conf.MustGetString("term-pass"),
}
// 创建 ssh 服务器密钥
- privateKeySigner, err := ssh.ParsePrivateKey(privatePEM)
+ privateKeySigner, err := gossh.ParsePrivateKey(privatePEM)
if err != nil {
panic(fmt.Errorf("不能解析私钥: %v", err))
}
- SSH_SERVER.config.AddHostKey(privateKeySigner)
-
// 在指定端口开启服务
address := net.JoinHostPort(SSH_SERVER.Host, SSH_SERVER.Port)
- listener, err := net.Listen("tcp", address)
- if err != nil {
- panic(fmt.Errorf("不能在 %s 上创建服务: %v", address, err))
+ s := &ssh.Server{
+ Addr: address,
+ Handler: shell,
+ PasswordHandler: func(ctx ssh.Context, password string) bool {
+ return ctx.User() == SSH_SERVER.User && password == SSH_SERVER.Pass
+ },
}
+ s.AddHostKey(privateKeySigner)
fmt.Println("Server Address:", address)
-
- // 连接到系统 shell
- for {
- conn, err := listener.Accept()
- if err != nil {
- fmt.Println("Can not accept connection:", err)
- }
-
- go shell(conn, SSH_SERVER.config)
+ if err := s.ListenAndServe(); err != nil {
+ panic(fmt.Errorf("不能启动服务器: %v", err))
}
-
}
diff --git a/shell.go b/shell.go
index 3492b33..d750136 100644
--- a/shell.go
+++ b/shell.go
@@ -2,69 +2,62 @@ package main
import (
"fmt"
- "net"
+ "io"
"os"
"os/exec"
- "golang.org/x/crypto/ssh"
+ "github.com/creack/pty"
+ "github.com/gliderlabs/ssh"
)
-func shell(conn net.Conn, config *ssh.ServerConfig) {
- sshConn, chans, reqs, err := ssh.NewServerConn(conn, config)
- if err != nil {
- fmt.Println("不能创建连接:", err)
+func shell(s ssh.Session) {
+ ptyReq, winCh, isPty := s.Pty()
+
+ if !isPty {
+ fmt.Fprintln(s, "Must be PTY")
+ s.Exit(1)
return
}
- defer sshConn.Close()
-
- fmt.Println("New connection from", sshConn.RemoteAddr(), "with client version", sshConn.ClientVersion())
-
- go ssh.DiscardRequests(reqs)
-
- for newChannel := range chans {
- if newChannel.ChannelType() != "session" {
- newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
- continue
- }
-
- channel, requests, err := newChannel.Accept()
- if err != nil {
- fmt.Println("Can not accept channel:", err)
- continue
- }
- defer channel.Close()
- shell := os.Getenv("SHELL")
- if shell == "" {
- shell = "cmd.exe"
- }
+ shell := os.Getenv("SHELL")
+ if shell == "" {
+ shell = "/bin/sh"
+ }
- command := exec.Command(shell)
- command.Stdin = channel
- command.Stdout = channel
- command.Stderr = channel
+ cmd := exec.Command(shell)
+ cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
- if err := command.Start(); err != nil {
- fmt.Println("Failed to start shell:", err)
- return
+ ptmx, err := pty.Start(cmd)
+ if err != nil {
+ fmt.Fprintln(s, "Can not start shell")
+ fmt.Println(err)
+ s.Exit(1)
+ return
+ }
+ defer func() { _ = ptmx.Close() }()
+
+ go func() {
+ for win := range winCh {
+ pty.Setsize(ptmx, &pty.Winsize{
+ Rows: uint16(win.Height),
+ Cols: uint16(win.Width),
+ })
}
-
- go func() {
- if err := command.Wait(); err != nil {
- fmt.Println("Run shell failed:", err)
- }
- channel.Close()
- }()
-
- go func() {
- for req := range requests {
- switch req.Type {
- case "shell":
- req.Reply(true, nil)
- default:
- req.Reply(false, nil)
- }
- }
- }()
+ }()
+
+ go func() {
+ io.Copy(s, ptmx)
+ s.Close()
+ }()
+
+ go func() {
+ io.Copy(ptmx, s)
+ ptmx.Close()
+ }()
+
+ if err := cmd.Wait(); err != nil {
+ s.Exit(1)
+ } else {
+ s.Exit(0)
}
}