Files
mengyaconnect/mengyaconnect-backend/main.go
2026-03-12 15:01:48 +08:00

716 lines
20 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"golang.org/x/crypto/ssh"
)
// ─── Persisted data types ────────────────────────────────────────

// SSHProfile is a saved SSH connection profile. Each profile is written as its
// own JSON file by the /api/ssh handlers; credentials are stored as given
// (plain JSON, file mode 0600).
type SSHProfile struct {
	Name       string `json:"name,omitempty"` // file name without the .json extension; blanked before the profile is written, so it never appears in the file body
	Alias      string `json:"alias"`
	Host       string `json:"host"`
	Port       int    `json:"port"`
	Username   string `json:"username"`
	Password   string `json:"password,omitempty"`
	PrivateKey string `json:"privateKey,omitempty"`
	Passphrase string `json:"passphrase,omitempty"`
}
// Command is one saved quick command: an alias plus the command text.
// The whole list is persisted as a single JSON array (see readCommands and
// writeCommands).
type Command struct {
	Alias   string `json:"alias"`
	Command string `json:"command"`
}
// ScriptInfo describes a stored script. Name is the file name inside
// scriptDir(); Content is the file body and is omitted from JSON when empty
// (the list endpoint fills in names only).
type ScriptInfo struct {
	Name    string `json:"name"`
	Content string `json:"content,omitempty"`
}
// Configuration and data-directory helpers live in config.go.

// wsMessage is the JSON frame exchanged in both directions on the SSH
// WebSocket. Type selects which of the other fields are meaningful:
//
//	client → server
//	  "connect"  Host, Port, Username, Password, PrivateKey, Passphrase, Cols, Rows
//	  "input"    Data (written to the shell's stdin)
//	  "resize"   Cols, Rows
//	  "ping", "close"
//	server → client
//	  "status"   Status, Message
//	  "error"    Message
//	  "output"   Data (shell stdout/stderr)
//	  "pong"
//
// Every field except Type is omitted from the encoded JSON when it is zero.
type wsMessage struct {
	Type       string `json:"type"`
	Host       string `json:"host,omitempty"`
	Port       int    `json:"port,omitempty"`
	Username   string `json:"username,omitempty"`
	Password   string `json:"password,omitempty"`
	PrivateKey string `json:"privateKey,omitempty"`
	Passphrase string `json:"passphrase,omitempty"`
	Data       string `json:"data,omitempty"`
	Cols       int    `json:"cols,omitempty"`
	Rows       int    `json:"rows,omitempty"`
	Status     string `json:"status,omitempty"`
	Message    string `json:"message,omitempty"`
}
// wsWriteTimeout bounds how long a single WebSocket write may block. Without
// it, a client that stops reading would hold wsWriter.mu forever and stall
// every goroutine sharing the writer (output pumps, session watcher and the
// read loop's replies).
const wsWriteTimeout = 10 * time.Second

// wsWriter serialises writes to one WebSocket connection. gorilla/websocket
// permits only one concurrent writer, while the SSH bridge writes from several
// goroutines.
type wsWriter struct {
	conn *websocket.Conn
	mu   sync.Mutex
}

// send encodes msg as JSON and writes it as one WebSocket message, giving up
// after wsWriteTimeout. Errors are deliberately dropped: a failed write means
// the peer is gone or unresponsive, and gorilla/websocket treats write errors
// as permanent, so later sends fail fast instead of blocking.
func (w *wsWriter) send(msg wsMessage) {
	w.mu.Lock()
	defer w.mu.Unlock()
	_ = w.conn.SetWriteDeadline(time.Now().Add(wsWriteTimeout))
	_ = w.conn.WriteJSON(msg)
}
// main builds the Gin engine, registers the HTTP and WebSocket routes and
// serves them until the listener fails.
//
// Environment variables read here:
//   - GIN_MODE: forwarded to gin.SetMode when non-empty.
//   - ALLOWED_ORIGINS: parsed by parseListEnv; WebSocket upgrades are accepted
//     only when isOriginAllowed approves the request's Origin header.
//   - ADDR / PORT: listen address, resolved as
//     getEnv("ADDR", ":"+getEnv("PORT", "8080")).
func main() {
	if ginMode := os.Getenv("GIN_MODE"); ginMode != "" {
		gin.SetMode(ginMode)
	}

	engine := gin.New()
	engine.Use(gin.Logger(), gin.Recovery(), corsMiddleware())

	origins := parseListEnv("ALLOWED_ORIGINS")
	wsUpgrader := websocket.Upgrader{
		ReadBufferSize:  1024,
		WriteBufferSize: 1024,
		CheckOrigin: func(req *http.Request) bool {
			return isOriginAllowed(req.Header.Get("Origin"), origins)
		},
	}

	// Liveness probe.
	engine.GET("/health", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{
			"status": "ok",
			"time":   time.Now().Format(time.RFC3339),
		})
	})

	api := engine.Group("/api")

	// Interactive SSH terminal over WebSocket.
	api.GET("/ws/ssh", func(c *gin.Context) {
		handleSSHWebSocket(c, wsUpgrader)
	})

	// Saved SSH profiles.
	api.GET("/ssh", handleListSSH)
	api.POST("/ssh", handleCreateSSH)
	api.PUT("/ssh/:name", handleUpdateSSH)
	api.DELETE("/ssh/:name", handleDeleteSSH)

	// Quick commands, addressed by list position.
	api.GET("/commands", handleListCommands)
	api.POST("/commands", handleCreateCommand)
	api.PUT("/commands/:index", handleUpdateCommand)
	api.DELETE("/commands/:index", handleDeleteCommand)

	// Stored scripts.
	api.GET("/scripts", handleListScripts)
	api.GET("/scripts/:name", handleGetScript)
	api.POST("/scripts", handleCreateScript)
	api.PUT("/scripts/:name", handleUpdateScript)
	api.DELETE("/scripts/:name", handleDeleteScript)

	listenAddr := getEnv("ADDR", ":"+getEnv("PORT", "8080"))
	srv := &http.Server{
		Addr:              listenAddr,
		Handler:           engine,
		ReadHeaderTimeout: 10 * time.Second,
	}

	log.Printf("SSH WebSocket server listening on %s", listenAddr)
	err := srv.ListenAndServe()
	if err == nil || errors.Is(err, http.ErrServerClosed) {
		return
	}
	log.Fatalf("server error: %v", err)
}
// handleSSHWebSocket upgrades the request to a WebSocket and bridges it to one
// interactive SSH shell at a time.
//
// The client drives the bridge with wsMessage frames:
//   - "connect": dial and start a shell (see startSSHSession). Rejected while a
//     shell is still running; accepted again once the previous shell exited.
//   - "input":   write Data to the shell's stdin.
//   - "resize":  forward a terminal size change (only when Rows and Cols > 0).
//   - "ping":    answered with "pong".
//   - "close":   end the WebSocket, which also tears down the SSH session.
//
// Unknown types are ignored. Shell output is streamed back as "output" frames.
func handleSSHWebSocket(c *gin.Context, upgrader websocket.Upgrader) {
	conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		// Upgrade has already replied to the client with an HTTP error.
		return
	}
	defer conn.Close()
	conn.SetReadLimit(1 << 20) // cap inbound frames at 1 MiB

	writer := &wsWriter{conn: conn}
	writer.send(wsMessage{Type: "status", Status: "connected", Message: "WebSocket connected"})

	// State of the current SSH session. Only this goroutine reads or writes
	// these variables; the per-session watcher goroutine below works on its
	// own captured handles, so no locking is needed.
	var (
		sshClient   *ssh.Client
		sshSession  *ssh.Session
		sshStdin    io.WriteCloser
		cancelFn    context.CancelFunc
		sessionDone chan struct{} // closed by the watcher once the shell exits
	)
	// cleanup closes the current session (if any) and forgets it so that a
	// later "connect" can start a fresh one. Closing handles the watcher has
	// already closed is harmless because the errors are ignored.
	cleanup := func() {
		if cancelFn != nil {
			cancelFn()
		}
		if sshSession != nil {
			_ = sshSession.Close()
		}
		if sshClient != nil {
			_ = sshClient.Close()
		}
		sshClient, sshSession, sshStdin, cancelFn, sessionDone = nil, nil, nil, nil, nil
	}
	defer cleanup()

	for {
		var msg wsMessage
		if err := conn.ReadJSON(&msg); err != nil {
			writer.send(wsMessage{Type: "status", Status: "closed", Message: "WebSocket closed"})
			return
		}

		switch msg.Type {
		case "connect":
			if sshSession != nil {
				select {
				case <-sessionDone:
					// The previous shell has exited: release it and reconnect.
					cleanup()
				default:
					writer.send(wsMessage{Type: "error", Message: "SSH session already exists"})
					continue
				}
			}
			client, session, stdin, stdout, stderr, err := startSSHSession(msg)
			if err != nil {
				writer.send(wsMessage{Type: "error", Message: err.Error()})
				continue
			}
			ctx, cancel := context.WithCancel(context.Background())
			done := make(chan struct{})
			sshClient, sshSession, sshStdin, cancelFn, sessionDone = client, session, stdin, cancel, done

			go streamToWebSocket(ctx, writer, stdout)
			go streamToWebSocket(ctx, writer, stderr)
			// Watcher: when the remote shell exits, notify the client, release
			// this session's resources and mark it finished. It uses the
			// captured client/session/cancel rather than cleanup, which mutates
			// the shared variables owned by this loop.
			go func() {
				defer close(done)
				_ = session.Wait()
				writer.send(wsMessage{Type: "status", Status: "closed", Message: "SSH session closed"})
				cancel()
				_ = session.Close()
				_ = client.Close()
			}()
			writer.send(wsMessage{Type: "status", Status: "ready", Message: "SSH connected"})

		case "input":
			if sshStdin == nil {
				writer.send(wsMessage{Type: "error", Message: "SSH session not ready"})
				continue
			}
			if msg.Data != "" {
				_, _ = sshStdin.Write([]byte(msg.Data))
			}

		case "resize":
			if sshSession != nil && msg.Rows > 0 && msg.Cols > 0 {
				_ = sshSession.WindowChange(msg.Rows, msg.Cols)
			}

		case "ping":
			writer.send(wsMessage{Type: "pong"})

		case "close":
			writer.send(wsMessage{Type: "status", Status: "closing", Message: "Closing SSH session"})
			return
		}
	}
}
// startSSHSession dials msg.Host and starts an interactive login shell on an
// xterm-256color pseudo-terminal.
//
// Defaults: port 22 when msg.Port is 0, and an 80x24 terminal when msg.Cols /
// msg.Rows are not positive. Authentication methods come from
// buildAuthMethods. IPv6 literals are accepted bare ("::1") or bracketed
// ("[::1]").
//
// On success it returns the client, the session and the shell's stdin, stdout
// and stderr. On failure every resource opened so far has been closed and only
// the error is non-nil.
//
// NOTE(review): host keys are not verified (ssh.InsecureIgnoreHostKey), which
// leaves connections open to man-in-the-middle attacks. Kept to preserve
// current behaviour; consider a known_hosts-based HostKeyCallback.
func startSSHSession(msg wsMessage) (*ssh.Client, *ssh.Session, io.WriteCloser, io.Reader, io.Reader, error) {
	fail := func(err error) (*ssh.Client, *ssh.Session, io.WriteCloser, io.Reader, io.Reader, error) {
		return nil, nil, nil, nil, nil, err
	}

	host := strings.TrimSpace(msg.Host)
	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		host = host[1 : len(host)-1]
	}
	if host == "" {
		return fail(errors.New("host is required"))
	}
	port := msg.Port
	if port == 0 {
		port = 22
	}
	if port < 0 || port > 65535 {
		return fail(fmt.Errorf("invalid port %d", port))
	}
	user := strings.TrimSpace(msg.Username)
	if user == "" {
		return fail(errors.New("username is required"))
	}
	auths, err := buildAuthMethods(msg)
	if err != nil {
		return fail(err)
	}

	cfg := &ssh.ClientConfig{
		User:            user,
		Auth:            auths,
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         12 * time.Second,
	}
	portStr := strconv.Itoa(port)
	addr := host + ":" + portStr
	if strings.Contains(host, ":") {
		// IPv6 literal: bracket it so the port separator is unambiguous
		// (the same rule net.JoinHostPort applies).
		addr = "[" + host + "]:" + portStr
	}
	client, err := ssh.Dial("tcp", addr, cfg)
	if err != nil {
		return fail(fmt.Errorf("ssh dial failed: %w", err))
	}
	session, err := client.NewSession()
	if err != nil {
		_ = client.Close()
		return fail(fmt.Errorf("ssh session failed: %w", err))
	}
	// abort releases the session and the connection, then yields err.
	abort := func(err error) error {
		_ = session.Close()
		_ = client.Close()
		return err
	}

	rows, cols := msg.Rows, msg.Cols
	if rows <= 0 {
		rows = 24
	}
	if cols <= 0 {
		cols = 80
	}
	modes := ssh.TerminalModes{
		ssh.ECHO:          1,
		ssh.TTY_OP_ISPEED: 14400,
		ssh.TTY_OP_OSPEED: 14400,
	}
	if err := session.RequestPty("xterm-256color", rows, cols, modes); err != nil {
		return fail(abort(fmt.Errorf("request pty failed: %w", err)))
	}
	stdin, err := session.StdinPipe()
	if err != nil {
		return fail(abort(fmt.Errorf("stdin pipe failed: %w", err)))
	}
	stdout, err := session.StdoutPipe()
	if err != nil {
		return fail(abort(fmt.Errorf("stdout pipe failed: %w", err)))
	}
	stderr, err := session.StderrPipe()
	if err != nil {
		return fail(abort(fmt.Errorf("stderr pipe failed: %w", err)))
	}
	if err := session.Shell(); err != nil {
		return fail(abort(fmt.Errorf("shell start failed: %w", err)))
	}
	return client, session, stdin, stdout, stderr, nil
}
// buildAuthMethods converts the credentials carried by msg into SSH auth
// methods. Public-key auth (from a non-blank PrivateKey) comes first, then
// password auth when Password is non-empty. It fails when the key cannot be
// parsed or when neither credential is present.
func buildAuthMethods(msg wsMessage) ([]ssh.AuthMethod, error) {
	hasKey := strings.TrimSpace(msg.PrivateKey) != ""
	hasPassword := msg.Password != ""
	if !hasKey && !hasPassword {
		return nil, errors.New("no auth method provided")
	}

	methods := make([]ssh.AuthMethod, 0, 2)
	if hasKey {
		signer, err := parsePrivateKey(msg.PrivateKey, msg.Passphrase)
		if err != nil {
			return nil, fmt.Errorf("private key error: %w", err)
		}
		methods = append(methods, ssh.PublicKeys(signer))
	}
	if hasPassword {
		methods = append(methods, ssh.Password(msg.Password))
	}
	return methods, nil
}
// parsePrivateKey parses a PEM/OpenSSH private key (surrounding whitespace is
// ignored) into an ssh.Signer.
//
// The key is parsed without a passphrase first; the supplied passphrase is
// used only when the library reports the key is encrypted
// (*ssh.PassphraseMissingError). A passphrase supplied alongside an
// unencrypted key is therefore ignored, instead of failing with
// "not an encrypted key" / "key is not password protected".
func parsePrivateKey(key, passphrase string) (ssh.Signer, error) {
	pemBytes := []byte(strings.TrimSpace(key))
	signer, err := ssh.ParsePrivateKey(pemBytes)
	if err == nil {
		return signer, nil
	}
	var missing *ssh.PassphraseMissingError
	if passphrase != "" && errors.As(err, &missing) {
		return ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(passphrase))
	}
	return nil, err
}
// streamToWebSocket forwards everything read from reader (a shell's stdout or
// stderr) to the client as "output" messages. It stops when the reader returns
// an error (e.g. io.EOF once the session closes) or when ctx is cancelled.
// ctx is only checked between reads; a blocked Read is released by closing
// the SSH session.
//
// A Read may end in the middle of a multi-byte UTF-8 character. Sending that
// fragment as a JSON string would make encoding/json replace it with U+FFFD
// and garble e.g. CJK output, so an incomplete trailing sequence is held back
// and prepended to the next chunk. Everything pending is flushed on error.
func streamToWebSocket(ctx context.Context, writer *wsWriter, reader io.Reader) {
	buf := make([]byte, 8192)
	carried := 0 // bytes at the start of buf held over from the previous read
	for {
		select {
		case <-ctx.Done():
			return
		default:
		}
		n, err := reader.Read(buf[carried:])
		total := carried + n
		cut := total
		if err == nil {
			cut = utf8CompleteLen(buf[:total])
		}
		if cut > 0 {
			writer.send(wsMessage{Type: "output", Data: string(buf[:cut])})
		}
		if err != nil {
			return
		}
		// At most three bytes remain, so buf always has room for the next read.
		carried = copy(buf, buf[cut:total])
	}
}

// utf8CompleteLen returns the length of the longest prefix of b that does not
// end in a truncated UTF-8 sequence. Only a lead byte followed by fewer
// continuation bytes than it announces is excluded; ASCII and invalid bytes
// are kept, so at most three bytes are ever held back.
func utf8CompleteLen(b []byte) int {
	i := len(b) - 1
	// Step back over up to three trailing continuation bytes (10xxxxxx).
	for i >= 0 && len(b)-i <= 3 && b[i]&0xC0 == 0x80 {
		i--
	}
	if i < 0 {
		return len(b)
	}
	var need int
	switch lead := b[i]; {
	case lead&0xE0 == 0xC0:
		need = 2
	case lead&0xF0 == 0xE0:
		need = 3
	case lead&0xF8 == 0xF0:
		need = 4
	default:
		return len(b) // ASCII or an invalid lead byte: nothing to wait for
	}
	if len(b)-i < need {
		return i
	}
	return len(b)
}
// CORS、中间件与环境变量工具函数见 config.go
// ═══════════════════════════════════════════════════════════════════
// SSH 配置 CRUD
// ═══════════════════════════════════════════════════════════════════
// handleListSSH serves GET /api/ssh and lists every saved SSH profile.
//
// Each *.json file in sshDir() is decoded as an SSHProfile, and Name is set to
// the file name without the extension. Unreadable or undecodable files are
// skipped. If the directory cannot be read (e.g. it does not exist yet), an
// empty list is returned. The response always carries a JSON array.
// NOTE(review): stored secrets (password, private key, passphrase) are
// returned exactly as saved.
func handleListSSH(c *gin.Context) {
	dir := sshDir()
	entries, err := os.ReadDir(dir)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{"data": []SSHProfile{}})
		return
	}

	profiles := make([]SSHProfile, 0, len(entries))
	for _, entry := range entries {
		fileName := entry.Name()
		if entry.IsDir() || !strings.HasSuffix(fileName, ".json") {
			continue
		}
		raw, readErr := os.ReadFile(filepath.Join(dir, fileName))
		if readErr != nil {
			continue
		}
		var profile SSHProfile
		if json.Unmarshal(raw, &profile) != nil {
			continue
		}
		profile.Name = strings.TrimSuffix(fileName, ".json")
		profiles = append(profiles, profile)
	}
	c.JSON(http.StatusOK, gin.H{"data": profiles})
}
// handleCreateSSH serves POST /api/ssh and saves a new SSH profile.
//
// alias, host and username are required. The file name comes from "name"
// when given, otherwise from the alias; spaces become hyphens, and the result
// must pass sanitizeName. The profile is written to sshDir()/<name>.json with
// mode 0600, without the name in the file body. The response echoes the
// profile with Name set.
// NOTE(review): an existing profile with the same file name is overwritten.
func handleCreateSSH(c *gin.Context) {
	fail := func(status int, msg string) { c.JSON(status, gin.H{"error": msg}) }

	var profile SSHProfile
	if c.ShouldBindJSON(&profile) != nil {
		fail(http.StatusBadRequest, "invalid request body")
		return
	}
	if profile.Alias == "" || profile.Host == "" || profile.Username == "" {
		fail(http.StatusBadRequest, "alias、host 和 username 为必填项")
		return
	}

	base := profile.Alias
	if profile.Name != "" {
		base = profile.Name
	}
	fileName, err := sanitizeName(strings.ReplaceAll(base, " ", "-"))
	if err != nil {
		fail(http.StatusBadRequest, "invalid name")
		return
	}

	profile.Name = "" // keep the name out of the file body (omitempty)
	body, _ := json.MarshalIndent(profile, "", " ")
	dir := sshDir()
	if os.MkdirAll(dir, 0o750) != nil {
		fail(http.StatusInternalServerError, "failed to create dir")
		return
	}
	if os.WriteFile(filepath.Join(dir, fileName+".json"), body, 0o600) != nil {
		fail(http.StatusInternalServerError, "failed to write file")
		return
	}

	profile.Name = fileName
	c.JSON(http.StatusOK, gin.H{"data": profile})
}
// handleUpdateSSH serves PUT /api/ssh/:name and replaces an existing profile
// with the request body.
//
// :name must pass sanitizeName, and alias, host and username are required.
// It responds 404 when sshDir()/<name>.json does not exist. Otherwise the file
// is rewritten (mode 0600, name omitted from the body), and the saved profile
// is echoed back with Name set.
func handleUpdateSSH(c *gin.Context) {
	fail := func(status int, msg string) { c.JSON(status, gin.H{"error": msg}) }

	name, err := sanitizeName(c.Param("name"))
	if err != nil {
		fail(http.StatusBadRequest, "invalid name")
		return
	}
	var profile SSHProfile
	if c.ShouldBindJSON(&profile) != nil {
		fail(http.StatusBadRequest, "invalid request body")
		return
	}
	if profile.Alias == "" || profile.Host == "" || profile.Username == "" {
		fail(http.StatusBadRequest, "alias、host 和 username 为必填项")
		return
	}

	target := filepath.Join(sshDir(), name+".json")
	if _, statErr := os.Stat(target); os.IsNotExist(statErr) {
		fail(http.StatusNotFound, "not found")
		return
	}
	profile.Name = "" // keep the name out of the file body (omitempty)
	body, _ := json.MarshalIndent(profile, "", " ")
	if os.WriteFile(target, body, 0o600) != nil {
		fail(http.StatusInternalServerError, "failed to write file")
		return
	}

	profile.Name = name
	c.JSON(http.StatusOK, gin.H{"data": profile})
}
// handleDeleteSSH serves DELETE /api/ssh/:name and removes
// sshDir()/<name>.json. It responds 400 for a name rejected by sanitizeName,
// 404 when the file does not exist, and 500 for any other removal error.
func handleDeleteSSH(c *gin.Context) {
	name, err := sanitizeName(c.Param("name"))
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid name"})
		return
	}

	switch removeErr := os.Remove(filepath.Join(sshDir(), name+".json")); {
	case removeErr == nil:
		c.JSON(http.StatusOK, gin.H{"message": "deleted"})
	case os.IsNotExist(removeErr):
		c.JSON(http.StatusNotFound, gin.H{"error": "not found"})
	default:
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to delete"})
	}
}
// ═══════════════════════════════════════════════════════════════════
// 快捷命令 CRUD
// ═══════════════════════════════════════════════════════════════════
// commandsMu serialises access to the quick-command file. The mutating
// handlers below each do a read-modify-write of the whole file, and Gin serves
// requests concurrently. Without this lock, two simultaneous edits can
// silently lose one of them, and a reader can see the file after os.WriteFile
// has truncated it but before it has finished writing. readCommands and
// writeCommands expect the caller to hold it.
var commandsMu sync.Mutex

// readCommands loads the quick-command list from cmdFilePath().
//
// A missing file, a blank file, or a file containing JSON null all yield an
// empty non-nil slice, so API responses always carry a JSON array. Any other
// read or decode failure is returned. Callers must hold commandsMu.
func readCommands() ([]Command, error) {
	raw, err := os.ReadFile(cmdFilePath())
	if err != nil {
		if os.IsNotExist(err) {
			return []Command{}, nil
		}
		return nil, err
	}
	if strings.TrimSpace(string(raw)) == "" {
		return []Command{}, nil
	}
	var cmds []Command
	if err := json.Unmarshal(raw, &cmds); err != nil {
		return nil, err
	}
	if cmds == nil {
		cmds = []Command{}
	}
	return cmds, nil
}

// writeCommands stores cmds at cmdFilePath() as indented JSON with mode 0600,
// creating the parent directory (mode 0750) if needed. Callers must hold
// commandsMu.
func writeCommands(cmds []Command) error {
	raw, err := json.MarshalIndent(cmds, "", " ")
	if err != nil {
		return err
	}
	if err := os.MkdirAll(filepath.Dir(cmdFilePath()), 0o750); err != nil {
		return err
	}
	return os.WriteFile(cmdFilePath(), raw, 0o600)
}

// mutateCommands performs one read-modify-write of the command list while
// holding commandsMu. edit receives the current list and returns the new one,
// or a non-zero HTTP status and message to abort without saving. The returned
// status is 0 on success; otherwise status and msg describe the error to send.
func mutateCommands(edit func([]Command) ([]Command, int, string)) ([]Command, int, string) {
	commandsMu.Lock()
	defer commandsMu.Unlock()

	cmds, err := readCommands()
	if err != nil {
		return nil, http.StatusInternalServerError, "failed to read commands"
	}
	cmds, status, msg := edit(cmds)
	if status != 0 {
		return nil, status, msg
	}
	if err := writeCommands(cmds); err != nil {
		return nil, http.StatusInternalServerError, "failed to save commands"
	}
	return cmds, 0, ""
}

// commandIndexParam parses the :index route parameter as a non-negative int.
// On failure it replies 400 and returns ok == false.
func commandIndexParam(c *gin.Context) (idx int, ok bool) {
	idx, err := strconv.Atoi(c.Param("index"))
	if err != nil || idx < 0 {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid index"})
		return 0, false
	}
	return idx, true
}

// handleListCommands serves GET /api/commands and returns the full list.
func handleListCommands(c *gin.Context) {
	commandsMu.Lock()
	cmds, err := readCommands()
	commandsMu.Unlock()
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to read commands"})
		return
	}
	c.JSON(http.StatusOK, gin.H{"data": cmds})
}

// handleCreateCommand serves POST /api/commands. It appends the posted
// command (alias and command are required) and responds with the full
// updated list.
func handleCreateCommand(c *gin.Context) {
	var cmd Command
	if err := c.ShouldBindJSON(&cmd); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
		return
	}
	if cmd.Alias == "" || cmd.Command == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "alias 和 command 为必填项"})
		return
	}
	cmds, status, msg := mutateCommands(func(list []Command) ([]Command, int, string) {
		return append(list, cmd), 0, ""
	})
	if status != 0 {
		c.JSON(status, gin.H{"error": msg})
		return
	}
	c.JSON(http.StatusOK, gin.H{"data": cmds})
}

// handleUpdateCommand serves PUT /api/commands/:index. It replaces the command
// at that position (404 if out of range) and responds with the full updated
// list.
func handleUpdateCommand(c *gin.Context) {
	idx, ok := commandIndexParam(c)
	if !ok {
		return
	}
	var cmd Command
	if err := c.ShouldBindJSON(&cmd); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
		return
	}
	if cmd.Alias == "" || cmd.Command == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "alias 和 command 为必填项"})
		return
	}
	cmds, status, msg := mutateCommands(func(list []Command) ([]Command, int, string) {
		if idx >= len(list) {
			return nil, http.StatusNotFound, "index out of range"
		}
		list[idx] = cmd
		return list, 0, ""
	})
	if status != 0 {
		c.JSON(status, gin.H{"error": msg})
		return
	}
	c.JSON(http.StatusOK, gin.H{"data": cmds})
}

// handleDeleteCommand serves DELETE /api/commands/:index. It removes the
// command at that position (404 if out of range) and responds with the
// remaining list.
func handleDeleteCommand(c *gin.Context) {
	idx, ok := commandIndexParam(c)
	if !ok {
		return
	}
	cmds, status, msg := mutateCommands(func(list []Command) ([]Command, int, string) {
		if idx >= len(list) {
			return nil, http.StatusNotFound, "index out of range"
		}
		return append(list[:idx], list[idx+1:]...), 0, ""
	})
	if status != 0 {
		c.JSON(status, gin.H{"error": msg})
		return
	}
	c.JSON(http.StatusOK, gin.H{"data": cmds})
}
// ═══════════════════════════════════════════════════════════════════
// 脚本 CRUD
// ═══════════════════════════════════════════════════════════════════
// handleListScripts serves GET /api/scripts. It returns one ScriptInfo (name
// only, no content) per non-directory entry in scriptDir(). An unreadable or
// missing directory yields an empty list; the response is always a JSON
// array.
func handleListScripts(c *gin.Context) {
	scripts := []ScriptInfo{}
	if entries, err := os.ReadDir(scriptDir()); err == nil {
		for _, entry := range entries {
			if entry.IsDir() {
				continue
			}
			scripts = append(scripts, ScriptInfo{Name: entry.Name()})
		}
	}
	c.JSON(http.StatusOK, gin.H{"data": scripts})
}
// handleGetScript serves GET /api/scripts/:name and returns the script's name
// and full content. It responds 400 for a name rejected by sanitizeName, 404
// when the file does not exist, and 500 for any other read error.
func handleGetScript(c *gin.Context) {
	name, err := sanitizeName(c.Param("name"))
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid name"})
		return
	}

	content, readErr := os.ReadFile(filepath.Join(scriptDir(), name))
	switch {
	case readErr == nil:
		c.JSON(http.StatusOK, gin.H{"data": ScriptInfo{Name: name, Content: string(content)}})
	case os.IsNotExist(readErr):
		c.JSON(http.StatusNotFound, gin.H{"error": "not found"})
	default:
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to read"})
	}
}
// handleCreateScript serves POST /api/scripts. It writes Content to
// scriptDir()/<name> with mode 0640, creating the directory if needed, and
// echoes the saved script. name is required and must pass sanitizeName.
// NOTE(review): an existing script with the same name is overwritten.
func handleCreateScript(c *gin.Context) {
	fail := func(status int, msg string) { c.JSON(status, gin.H{"error": msg}) }

	var req ScriptInfo
	if c.ShouldBindJSON(&req) != nil {
		fail(http.StatusBadRequest, "invalid request body")
		return
	}
	if req.Name == "" {
		fail(http.StatusBadRequest, "name 为必填项")
		return
	}
	name, err := sanitizeName(req.Name)
	if err != nil {
		fail(http.StatusBadRequest, "invalid name")
		return
	}

	dir := scriptDir()
	if os.MkdirAll(dir, 0o750) != nil {
		fail(http.StatusInternalServerError, "failed to create dir")
		return
	}
	if os.WriteFile(filepath.Join(dir, name), []byte(req.Content), 0o640) != nil {
		fail(http.StatusInternalServerError, "failed to write")
		return
	}
	c.JSON(http.StatusOK, gin.H{"data": ScriptInfo{Name: name, Content: req.Content}})
}
// handleUpdateScript serves PUT /api/scripts/:name and replaces the content
// of an existing script (mode 0640). Any "name" field in the body is ignored,
// so renaming is not supported. It responds 404 when the script does not
// exist; the existence check runs before the body is parsed.
func handleUpdateScript(c *gin.Context) {
	fail := func(status int, msg string) { c.JSON(status, gin.H{"error": msg}) }

	name, err := sanitizeName(c.Param("name"))
	if err != nil {
		fail(http.StatusBadRequest, "invalid name")
		return
	}
	target := filepath.Join(scriptDir(), name)
	if _, statErr := os.Stat(target); os.IsNotExist(statErr) {
		fail(http.StatusNotFound, "not found")
		return
	}

	var req ScriptInfo
	if c.ShouldBindJSON(&req) != nil {
		fail(http.StatusBadRequest, "invalid request body")
		return
	}
	if os.WriteFile(target, []byte(req.Content), 0o640) != nil {
		fail(http.StatusInternalServerError, "failed to write")
		return
	}
	c.JSON(http.StatusOK, gin.H{"data": ScriptInfo{Name: name, Content: req.Content}})
}
// handleDeleteScript serves DELETE /api/scripts/:name and removes
// scriptDir()/<name>. It responds 400 for a name rejected by sanitizeName, 404
// when the file does not exist, and 500 for any other removal error.
func handleDeleteScript(c *gin.Context) {
	name, err := sanitizeName(c.Param("name"))
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid name"})
		return
	}

	switch removeErr := os.Remove(filepath.Join(scriptDir(), name)); {
	case removeErr == nil:
		c.JSON(http.StatusOK, gin.H{"message": "deleted"})
	case os.IsNotExist(removeErr):
		c.JSON(http.StatusNotFound, gin.H{"error": "not found"})
	default:
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to delete"})
	}
}