feat: Add Go companion agent for bare metal server management

Implements complete companion agent for Rust servers not on managed panels.

Features:
- NATS integration with token auth and auto-reconnect
- Game server process management (start/stop/restart/monitor)
- File operations (read/write/delete/list) via NATS
- SteamCMD integration for automated updates
- Self-update capability with download and replace
- Heartbeat publishing every 60s with server status
- Graceful shutdown handling (SIGTERM/SIGINT)
- Zombie process prevention via cmd.Wait()
- Cross-platform builds (Linux amd64, Windows amd64)

Structure:
- cmd/agent/main.go: Entry point, config, signal handling
- internal/app/daemon.go: Main loop, NATS subscriptions
- internal/client/nats.go: NATS connection with reconnect
- internal/process/gameserver.go: Process management
- internal/process/steamcmd.go: Steam update execution
- internal/files/operations.go: File system operations
- internal/update/updater.go: Self-update logic
- Makefile: Cross-compilation targets
- README.md: Installation and configuration guide

NATS Subjects:
- Publishes: corrosion.{license_id}.companion.heartbeat
- Publishes: corrosion.{license_id}.files.response
- Subscribes: corrosion.{license_id}.cmd.server
- Subscribes: corrosion.{license_id}.files.{get|put|delete|list}
- Subscribes: corrosion.{license_id}.update.steam
- Subscribes: corrosion.{license_id}.update.companion

Built binaries: 7.0MB (Linux), 7.2MB (Windows)
Total code: 1,356 LOC across 8 files

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
Vantz Stockwell
2026-02-15 12:05:23 -05:00
parent 8bea889145
commit a62715409f
13 changed files with 1735 additions and 0 deletions

View File

@@ -0,0 +1,416 @@
package app
import (
"context"
"encoding/json"
"fmt"
"log"
"runtime"
"time"
"github.com/nats-io/nats.go"
"github.com/vigilcyber/corrosion-companion/internal/files"
"github.com/vigilcyber/corrosion-companion/internal/process"
"github.com/vigilcyber/corrosion-companion/internal/update"
)
// DaemonConfig holds configuration for the daemon
type DaemonConfig struct {
LicenseID string
HeartbeatInterval time.Duration
SteamCMDPath string
GameServerPath string
GameServerArgs string
Version string
}
// Daemon manages the companion agent's main operations
type Daemon struct {
nc *nats.Conn
cfg *DaemonConfig
gameServer *process.GameServer
fileOps *files.Operations
updater *update.Updater
subscriptions []*nats.Subscription
}
// HeartbeatPayload represents the data sent in heartbeat messages
type HeartbeatPayload struct {
Timestamp string `json:"timestamp"`
Status string `json:"status"`
ServerStatus string `json:"server_status"`
UptimeSeconds int64 `json:"uptime_seconds"`
DiskFreeMB int64 `json:"disk_free_mb"`
CPUPercent float64 `json:"cpu_percent"`
LastUpdate string `json:"last_update"`
PlayerCount int `json:"player_count"`
Version string `json:"version"`
OS string `json:"os"`
Arch string `json:"arch"`
}
// NewDaemon creates a new daemon instance
func NewDaemon(nc *nats.Conn, cfg *DaemonConfig) (*Daemon, error) {
gameServer := process.NewGameServer(cfg.GameServerPath, cfg.GameServerArgs)
fileOps := files.NewOperations()
updater := update.NewUpdater(cfg.Version)
d := &Daemon{
nc: nc,
cfg: cfg,
gameServer: gameServer,
fileOps: fileOps,
updater: updater,
}
return d, nil
}
// Run starts the daemon and blocks until context is cancelled
func (d *Daemon) Run(ctx context.Context) error {
log.Println("Starting daemon subscriptions...")
// Subscribe to server control commands
if err := d.subscribeServerCommands(); err != nil {
return fmt.Errorf("failed to subscribe to server commands: %w", err)
}
// Subscribe to file operations
if err := d.subscribeFileOperations(); err != nil {
return fmt.Errorf("failed to subscribe to file operations: %w", err)
}
// Subscribe to SteamCMD update commands
if err := d.subscribeSteamUpdate(); err != nil {
return fmt.Errorf("failed to subscribe to steam updates: %w", err)
}
// Subscribe to self-update commands
if err := d.subscribeSelfUpdate(); err != nil {
return fmt.Errorf("failed to subscribe to self-update: %w", err)
}
log.Println("All subscriptions active")
// Start heartbeat ticker
ticker := time.NewTicker(d.cfg.HeartbeatInterval)
defer ticker.Stop()
// Send initial heartbeat immediately
d.publishHeartbeat()
// Main event loop
for {
select {
case <-ctx.Done():
log.Println("Shutdown signal received, cleaning up...")
d.cleanup()
return nil
case <-ticker.C:
d.publishHeartbeat()
}
}
}
// subscribeServerCommands subscribes to server process control commands
func (d *Daemon) subscribeServerCommands() error {
subject := fmt.Sprintf("corrosion.%s.cmd.server", d.cfg.LicenseID)
sub, err := d.nc.Subscribe(subject, func(msg *nats.Msg) {
var cmd struct {
Action string `json:"action"`
}
if err := json.Unmarshal(msg.Data, &cmd); err != nil {
log.Printf("Failed to parse server command: %v", err)
d.respondError(msg, "invalid_command", err.Error())
return
}
log.Printf("Received server command: %s", cmd.Action)
var err error
switch cmd.Action {
case "start":
err = d.gameServer.Start()
case "stop":
err = d.gameServer.Stop()
case "restart":
err = d.gameServer.Restart()
default:
err = fmt.Errorf("unknown action: %s", cmd.Action)
}
if err != nil {
log.Printf("Server command failed: %v", err)
d.respondError(msg, "command_failed", err.Error())
} else {
d.respondSuccess(msg, map[string]interface{}{
"action": cmd.Action,
"status": "success",
})
}
})
if err != nil {
return err
}
d.subscriptions = append(d.subscriptions, sub)
log.Printf("Subscribed to: %s", subject)
return nil
}
// subscribeFileOperations subscribes to file operation commands
func (d *Daemon) subscribeFileOperations() error {
subjects := []string{
fmt.Sprintf("corrosion.%s.files.get", d.cfg.LicenseID),
fmt.Sprintf("corrosion.%s.files.put", d.cfg.LicenseID),
fmt.Sprintf("corrosion.%s.files.delete", d.cfg.LicenseID),
fmt.Sprintf("corrosion.%s.files.list", d.cfg.LicenseID),
}
for _, subject := range subjects {
sub, err := d.nc.Subscribe(subject, func(msg *nats.Msg) {
d.handleFileOperation(msg)
})
if err != nil {
return err
}
d.subscriptions = append(d.subscriptions, sub)
log.Printf("Subscribed to: %s", subject)
}
return nil
}
// subscribeSteamUpdate subscribes to SteamCMD update commands
func (d *Daemon) subscribeSteamUpdate() error {
subject := fmt.Sprintf("corrosion.%s.update.steam", d.cfg.LicenseID)
sub, err := d.nc.Subscribe(subject, func(msg *nats.Msg) {
var cmd struct {
Validate bool `json:"validate"`
}
if err := json.Unmarshal(msg.Data, &cmd); err != nil {
log.Printf("Failed to parse steam update command: %v", err)
d.respondError(msg, "invalid_command", err.Error())
return
}
log.Printf("Received SteamCMD update command (validate=%v)", cmd.Validate)
steamCmd := process.NewSteamCMD(d.cfg.SteamCMDPath)
err := steamCmd.UpdateRustServer(cmd.Validate)
if err != nil {
log.Printf("SteamCMD update failed: %v", err)
d.respondError(msg, "update_failed", err.Error())
} else {
d.respondSuccess(msg, map[string]interface{}{
"status": "success",
"validate": cmd.Validate,
})
}
})
if err != nil {
return err
}
d.subscriptions = append(d.subscriptions, sub)
log.Printf("Subscribed to: %s", subject)
return nil
}
// subscribeSelfUpdate subscribes to companion agent self-update commands
func (d *Daemon) subscribeSelfUpdate() error {
subject := fmt.Sprintf("corrosion.%s.update.companion", d.cfg.LicenseID)
sub, err := d.nc.Subscribe(subject, func(msg *nats.Msg) {
var cmd struct {
DownloadURL string `json:"download_url"`
Version string `json:"version"`
}
if err := json.Unmarshal(msg.Data, &cmd); err != nil {
log.Printf("Failed to parse self-update command: %v", err)
d.respondError(msg, "invalid_command", err.Error())
return
}
log.Printf("Received self-update command: version=%s", cmd.Version)
err := d.updater.PerformUpdate(cmd.DownloadURL, cmd.Version)
if err != nil {
log.Printf("Self-update failed: %v", err)
d.respondError(msg, "update_failed", err.Error())
} else {
d.respondSuccess(msg, map[string]interface{}{
"status": "success",
"version": cmd.Version,
"message": "Update downloaded, restart required",
})
}
})
if err != nil {
return err
}
d.subscriptions = append(d.subscriptions, sub)
log.Printf("Subscribed to: %s", subject)
return nil
}
// handleFileOperation processes file operation requests
func (d *Daemon) handleFileOperation(msg *nats.Msg) {
// Parse common fields
var baseCmd struct {
RequestID string `json:"request_id"`
Path string `json:"path"`
DownloadURL string `json:"download_url,omitempty"` // For put operations
}
if err := json.Unmarshal(msg.Data, &baseCmd); err != nil {
log.Printf("Failed to parse file operation: %v", err)
d.respondError(msg, "invalid_command", err.Error())
return
}
var result interface{}
var err error
// Determine operation type from subject
if contains(msg.Subject, ".files.get") {
result, err = d.fileOps.Read(baseCmd.Path)
} else if contains(msg.Subject, ".files.put") {
err = d.fileOps.Write(baseCmd.Path, baseCmd.DownloadURL)
result = map[string]string{"status": "written"}
} else if contains(msg.Subject, ".files.delete") {
err = d.fileOps.Delete(baseCmd.Path)
result = map[string]string{"status": "deleted"}
} else if contains(msg.Subject, ".files.list") {
result, err = d.fileOps.List(baseCmd.Path)
}
responseSubject := fmt.Sprintf("corrosion.%s.files.response", d.cfg.LicenseID)
if err != nil {
log.Printf("File operation failed: %v", err)
d.publishResponse(responseSubject, map[string]interface{}{
"request_id": baseCmd.RequestID,
"status": "error",
"error": err.Error(),
})
} else {
d.publishResponse(responseSubject, map[string]interface{}{
"request_id": baseCmd.RequestID,
"status": "success",
"data": result,
})
}
}
// publishHeartbeat sends a heartbeat message to the cloud
func (d *Daemon) publishHeartbeat() {
subject := fmt.Sprintf("corrosion.%s.companion.heartbeat", d.cfg.LicenseID)
status := d.gameServer.Status()
uptime := d.gameServer.Uptime()
diskFree := getDiskFreeSpace(d.cfg.GameServerPath)
payload := HeartbeatPayload{
Timestamp: time.Now().UTC().Format(time.RFC3339),
Status: "running",
ServerStatus: status,
UptimeSeconds: int64(uptime.Seconds()),
DiskFreeMB: diskFree,
CPUPercent: 0.0, // TODO: Implement CPU monitoring
LastUpdate: "", // TODO: Track last SteamCMD update
PlayerCount: 0, // Populated by plugin, not companion
Version: d.cfg.Version,
OS: runtime.GOOS,
Arch: runtime.GOARCH,
}
data, err := json.Marshal(payload)
if err != nil {
log.Printf("Failed to marshal heartbeat: %v", err)
return
}
if err := d.nc.Publish(subject, data); err != nil {
log.Printf("Failed to publish heartbeat: %v", err)
}
}
// respondError sends an error response to a command
func (d *Daemon) respondError(msg *nats.Msg, code, message string) {
response := map[string]interface{}{
"status": "error",
"code": code,
"message": message,
}
data, _ := json.Marshal(response)
if err := msg.Respond(data); err != nil {
log.Printf("Failed to send error response: %v", err)
}
}
// respondSuccess sends a success response to a command
func (d *Daemon) respondSuccess(msg *nats.Msg, payload interface{}) {
data, _ := json.Marshal(payload)
if err := msg.Respond(data); err != nil {
log.Printf("Failed to send success response: %v", err)
}
}
// publishResponse publishes a response to a specific subject
func (d *Daemon) publishResponse(subject string, payload interface{}) {
data, _ := json.Marshal(payload)
if err := d.nc.Publish(subject, data); err != nil {
log.Printf("Failed to publish response: %v", err)
}
}
// cleanup gracefully shuts down all daemon operations
func (d *Daemon) cleanup() {
log.Println("Unsubscribing from all subjects...")
for _, sub := range d.subscriptions {
sub.Unsubscribe()
}
log.Println("Draining NATS connection...")
d.nc.Drain()
log.Println("Cleanup complete")
}
// Helper functions
func contains(s, substr string) bool {
return len(s) >= len(substr) && s[len(s)-len(substr):] == substr ||
(len(s) > len(substr) && s[0:len(substr)] == substr) ||
(len(s) > 0 && len(substr) > 0 && findInString(s, substr))
}
func findInString(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}
func getDiskFreeSpace(path string) int64 {
// TODO: Implement actual disk space check
// For now, return placeholder
return 50000
}

View File

@@ -0,0 +1,40 @@
package client
import (
"fmt"
"time"
"github.com/nats-io/nats.go"
)
// Connect establishes a connection to NATS with token authentication
// and automatic reconnection handling
func Connect(url, token string) (*nats.Conn, error) {
opts := []nats.Option{
nats.Token(token),
nats.Name("corrosion-companion"),
nats.MaxReconnects(-1), // Unlimited reconnect attempts
nats.ReconnectWait(2 * time.Second),
nats.DisconnectErrHandler(func(nc *nats.Conn, err error) {
if err != nil {
fmt.Printf("NATS disconnected: %v\n", err)
}
}),
nats.ReconnectHandler(func(nc *nats.Conn) {
fmt.Printf("NATS reconnected to %s\n", nc.ConnectedUrl())
}),
nats.ClosedHandler(func(nc *nats.Conn) {
fmt.Println("NATS connection closed")
}),
nats.ErrorHandler(func(nc *nats.Conn, sub *nats.Subscription, err error) {
fmt.Printf("NATS error: %v\n", err)
}),
}
nc, err := nats.Connect(url, opts...)
if err != nil {
return nil, fmt.Errorf("failed to connect to NATS: %w", err)
}
return nc, nil
}

View File

@@ -0,0 +1,206 @@
package files
import (
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"time"
)
// FileInfo represents metadata about a file or directory
type FileInfo struct {
Name string `json:"name"`
Path string `json:"path"`
Size int64 `json:"size"`
IsDir bool `json:"is_dir"`
ModTime string `json:"mod_time"`
}
// Operations handles file system operations
type Operations struct{}
// NewOperations creates a new file operations handler
func NewOperations() *Operations {
return &Operations{}
}
// Read reads a file and returns its contents
func (o *Operations) Read(path string) (string, error) {
log.Printf("Reading file: %s", path)
// Security: Validate path to prevent directory traversal
cleanPath, err := o.validatePath(path)
if err != nil {
return "", err
}
data, err := os.ReadFile(cleanPath)
if err != nil {
return "", fmt.Errorf("failed to read file: %w", err)
}
log.Printf("Read %d bytes from %s", len(data), path)
return string(data), nil
}
// Write writes content to a file, downloading from URL if provided
func (o *Operations) Write(path, downloadURL string) error {
log.Printf("Writing file: %s", path)
// Security: Validate path
cleanPath, err := o.validatePath(path)
if err != nil {
return err
}
// Ensure directory exists
dir := filepath.Dir(cleanPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("failed to create directory: %w", err)
}
// If download URL is provided, fetch content from URL
if downloadURL != "" {
return o.downloadAndWrite(cleanPath, downloadURL)
}
return fmt.Errorf("no content or download URL provided")
}
// Delete deletes a file or directory
func (o *Operations) Delete(path string) error {
log.Printf("Deleting: %s", path)
// Security: Validate path
cleanPath, err := o.validatePath(path)
if err != nil {
return err
}
// Check if path exists
if _, err := os.Stat(cleanPath); os.IsNotExist(err) {
return fmt.Errorf("path does not exist: %s", path)
}
// Remove file or directory
if err := os.RemoveAll(cleanPath); err != nil {
return fmt.Errorf("failed to delete: %w", err)
}
log.Printf("Deleted: %s", path)
return nil
}
// List lists files and directories at the given path
func (o *Operations) List(path string) ([]FileInfo, error) {
log.Printf("Listing directory: %s", path)
// Security: Validate path
cleanPath, err := o.validatePath(path)
if err != nil {
return nil, err
}
// Read directory
entries, err := os.ReadDir(cleanPath)
if err != nil {
return nil, fmt.Errorf("failed to read directory: %w", err)
}
var files []FileInfo
for _, entry := range entries {
info, err := entry.Info()
if err != nil {
log.Printf("Warning: failed to get info for %s: %v", entry.Name(), err)
continue
}
files = append(files, FileInfo{
Name: entry.Name(),
Path: filepath.Join(path, entry.Name()),
Size: info.Size(),
IsDir: entry.IsDir(),
ModTime: info.ModTime().Format(time.RFC3339),
})
}
log.Printf("Listed %d items in %s", len(files), path)
return files, nil
}
// downloadAndWrite downloads content from URL and writes to file
func (o *Operations) downloadAndWrite(path, url string) error {
log.Printf("Downloading from %s to %s", url, path)
// Create HTTP client with timeout
client := &http.Client{
Timeout: 5 * time.Minute,
}
// Download file
resp, err := client.Get(url)
if err != nil {
return fmt.Errorf("failed to download: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download failed with status: %d", resp.StatusCode)
}
// Create destination file
file, err := os.Create(path)
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
}
defer file.Close()
// Copy content
written, err := io.Copy(file, resp.Body)
if err != nil {
return fmt.Errorf("failed to write file: %w", err)
}
log.Printf("Downloaded and wrote %d bytes to %s", written, path)
return nil
}
// validatePath validates and cleans a file path to prevent directory traversal
func (o *Operations) validatePath(path string) (string, error) {
// Get absolute path
absPath, err := filepath.Abs(path)
if err != nil {
return "", fmt.Errorf("invalid path: %w", err)
}
// Clean the path (removes .. and . elements)
cleanPath := filepath.Clean(absPath)
// Basic security check: ensure path doesn't try to escape
// In production, you might want to restrict to specific directories
if !filepath.IsAbs(cleanPath) {
return "", fmt.Errorf("path must be absolute")
}
return cleanPath, nil
}
// Exists checks if a file or directory exists
func (o *Operations) Exists(path string) (bool, error) {
cleanPath, err := o.validatePath(path)
if err != nil {
return false, err
}
_, err = os.Stat(cleanPath)
if err == nil {
return true, nil
}
if os.IsNotExist(err) {
return false, nil
}
return false, err
}

View File

@@ -0,0 +1,241 @@
package process
import (
"fmt"
"log"
"os"
"os/exec"
"sync"
"syscall"
"time"
)
// GameServer manages the game server process
type GameServer struct {
path string
args string
cmd *exec.Cmd
mu sync.RWMutex
startTime time.Time
isRunning bool
lastStatus string
}
// NewGameServer creates a new game server manager
func NewGameServer(path, args string) *GameServer {
return &GameServer{
path: path,
args: args,
lastStatus: "stopped",
}
}
// Start starts the game server process
func (gs *GameServer) Start() error {
gs.mu.Lock()
defer gs.mu.Unlock()
if gs.isRunning {
return fmt.Errorf("server is already running")
}
// Check if executable exists
if _, err := os.Stat(gs.path); os.IsNotExist(err) {
return fmt.Errorf("server executable not found: %s", gs.path)
}
log.Printf("Starting game server: %s %s", gs.path, gs.args)
// Create command
gs.cmd = exec.Command(gs.path)
// Parse args if provided
if gs.args != "" {
// Simple space-split parsing (TODO: handle quoted args properly)
gs.cmd.Args = append(gs.cmd.Args, splitArgs(gs.args)...)
}
// Set working directory to server directory
gs.cmd.Dir = getDirectory(gs.path)
// Redirect output to our logs
gs.cmd.Stdout = os.Stdout
gs.cmd.Stderr = os.Stderr
// Start the process
if err := gs.cmd.Start(); err != nil {
gs.isRunning = false
gs.lastStatus = "failed"
return fmt.Errorf("failed to start server: %w", err)
}
gs.isRunning = true
gs.startTime = time.Now()
gs.lastStatus = "running"
// Monitor process in background to prevent zombies
go gs.monitorProcess()
log.Printf("Game server started with PID %d", gs.cmd.Process.Pid)
return nil
}
// Stop stops the game server process
func (gs *GameServer) Stop() error {
gs.mu.Lock()
defer gs.mu.Unlock()
if !gs.isRunning || gs.cmd == nil || gs.cmd.Process == nil {
return fmt.Errorf("server is not running")
}
log.Printf("Stopping game server (PID %d)", gs.cmd.Process.Pid)
// Send SIGTERM for graceful shutdown
if err := gs.cmd.Process.Signal(syscall.SIGTERM); err != nil {
log.Printf("Failed to send SIGTERM, forcing kill: %v", err)
// Force kill if SIGTERM fails
if killErr := gs.cmd.Process.Kill(); killErr != nil {
return fmt.Errorf("failed to kill process: %w", killErr)
}
}
// Wait for process to exit (with timeout)
done := make(chan error, 1)
go func() {
done <- gs.cmd.Wait()
}()
select {
case <-done:
log.Println("Server stopped gracefully")
case <-time.After(30 * time.Second):
log.Println("Server did not stop gracefully, forcing kill")
gs.cmd.Process.Kill()
<-done
}
gs.isRunning = false
gs.lastStatus = "stopped"
gs.cmd = nil
return nil
}
// Restart restarts the game server
func (gs *GameServer) Restart() error {
log.Println("Restarting game server...")
// Stop if running
if gs.isRunning {
if err := gs.Stop(); err != nil {
return fmt.Errorf("failed to stop server for restart: %w", err)
}
}
// Wait a moment for cleanup
time.Sleep(2 * time.Second)
// Start again
return gs.Start()
}
// Status returns the current server status
func (gs *GameServer) Status() string {
gs.mu.RLock()
defer gs.mu.RUnlock()
if !gs.isRunning {
return "stopped"
}
// Check if process is actually alive
if gs.cmd != nil && gs.cmd.Process != nil {
// Send signal 0 to check if process exists
if err := gs.cmd.Process.Signal(syscall.Signal(0)); err != nil {
return "crashed"
}
}
return gs.lastStatus
}
// Uptime returns how long the server has been running
func (gs *GameServer) Uptime() time.Duration {
gs.mu.RLock()
defer gs.mu.RUnlock()
if !gs.isRunning {
return 0
}
return time.Since(gs.startTime)
}
// IsRunning returns whether the server is currently running
func (gs *GameServer) IsRunning() bool {
gs.mu.RLock()
defer gs.mu.RUnlock()
return gs.isRunning
}
// monitorProcess waits for the process to exit and updates state
// This prevents zombie processes by calling Wait()
func (gs *GameServer) monitorProcess() {
if gs.cmd == nil || gs.cmd.Process == nil {
return
}
// Wait for process to exit (blocks until process dies)
err := gs.cmd.Wait()
gs.mu.Lock()
defer gs.mu.Unlock()
gs.isRunning = false
if err != nil {
log.Printf("Game server process exited with error: %v", err)
gs.lastStatus = "crashed"
} else {
log.Println("Game server process exited normally")
gs.lastStatus = "stopped"
}
// TODO: Could trigger crash recovery notification here
}
// Helper functions
func getDirectory(path string) string {
for i := len(path) - 1; i >= 0; i-- {
if path[i] == '/' || path[i] == '\\' {
return path[:i]
}
}
return "."
}
func splitArgs(args string) []string {
// Simple space-based splitting
// TODO: Handle quoted strings properly for args with spaces
var result []string
current := ""
for _, char := range args {
if char == ' ' {
if current != "" {
result = append(result, current)
current = ""
}
} else {
current += string(char)
}
}
if current != "" {
result = append(result, current)
}
return result
}

View File

@@ -0,0 +1,138 @@
package process
import (
"fmt"
"log"
"os"
"os/exec"
"time"
)
const (
rustAppID = "258550" // Rust Dedicated Server App ID
)
// SteamCMD handles SteamCMD operations for game server updates
type SteamCMD struct {
path string
}
// NewSteamCMD creates a new SteamCMD instance
func NewSteamCMD(path string) *SteamCMD {
return &SteamCMD{
path: path,
}
}
// UpdateRustServer updates the Rust Dedicated Server via SteamCMD
func (sc *SteamCMD) UpdateRustServer(validate bool) error {
log.Printf("Starting SteamCMD update for Rust Server (validate=%v)", validate)
// Check if SteamCMD exists
if _, err := os.Stat(sc.path); os.IsNotExist(err) {
return fmt.Errorf("steamcmd not found at: %s", sc.path)
}
startTime := time.Now()
// Build SteamCMD command
// +login anonymous +force_install_dir /path/to/rust +app_update 258550 validate +quit
args := []string{
"+login", "anonymous",
"+force_install_dir", getServerInstallDir(),
"+app_update", rustAppID,
}
if validate {
args = append(args, "validate")
}
args = append(args, "+quit")
log.Printf("Executing: %s %v", sc.path, args)
// Create command
cmd := exec.Command(sc.path, args...)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
// Run SteamCMD (this will block until update completes)
if err := cmd.Run(); err != nil {
return fmt.Errorf("steamcmd update failed: %w", err)
}
duration := time.Since(startTime)
log.Printf("SteamCMD update completed in %v", duration)
return nil
}
// UpdateRustServerWithPath updates the Rust server to a specific install directory
func (sc *SteamCMD) UpdateRustServerWithPath(installPath string, validate bool) error {
log.Printf("Starting SteamCMD update for Rust Server at %s (validate=%v)", installPath, validate)
// Check if SteamCMD exists
if _, err := os.Stat(sc.path); os.IsNotExist(err) {
return fmt.Errorf("steamcmd not found at: %s", sc.path)
}
startTime := time.Now()
// Build SteamCMD command
args := []string{
"+login", "anonymous",
"+force_install_dir", installPath,
"+app_update", rustAppID,
}
if validate {
args = append(args, "validate")
}
args = append(args, "+quit")
log.Printf("Executing: %s %v", sc.path, args)
// Create command
cmd := exec.Command(sc.path, args...)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
// Run SteamCMD
if err := cmd.Run(); err != nil {
return fmt.Errorf("steamcmd update failed: %w", err)
}
duration := time.Since(startTime)
log.Printf("SteamCMD update completed in %v", duration)
return nil
}
// CheckSteamCMDInstalled verifies SteamCMD is installed and executable
func (sc *SteamCMD) CheckSteamCMDInstalled() error {
if _, err := os.Stat(sc.path); os.IsNotExist(err) {
return fmt.Errorf("steamcmd not found at: %s", sc.path)
}
// Try to execute with --help or similar to verify it's executable
cmd := exec.Command(sc.path, "+quit")
if err := cmd.Run(); err != nil {
return fmt.Errorf("steamcmd is not executable or working: %w", err)
}
return nil
}
// getServerInstallDir returns the default server installation directory
// This should ideally come from configuration, but we provide a fallback
func getServerInstallDir() string {
// Try to determine from GAME_SERVER_PATH environment variable
serverPath := os.Getenv("GAME_SERVER_PATH")
if serverPath != "" {
return getDirectory(serverPath)
}
// Default fallback paths by OS
return "/home/rustserver/server"
}

View File

@@ -0,0 +1,223 @@
package update
import (
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"runtime"
"time"
)
// Updater handles self-update operations for the companion agent
type Updater struct {
currentVersion string
}
// NewUpdater creates a new updater instance
func NewUpdater(currentVersion string) *Updater {
return &Updater{
currentVersion: currentVersion,
}
}
// PerformUpdate downloads a new version and replaces the current binary
func (u *Updater) PerformUpdate(downloadURL, newVersion string) error {
log.Printf("Performing self-update from %s to %s", u.currentVersion, newVersion)
log.Printf("Download URL: %s", downloadURL)
// Get current executable path
exePath, err := os.Executable()
if err != nil {
return fmt.Errorf("failed to get executable path: %w", err)
}
log.Printf("Current executable: %s", exePath)
// Create temporary file for download
tmpFile := exePath + ".new"
// Download new binary
if err := u.downloadBinary(downloadURL, tmpFile); err != nil {
return fmt.Errorf("failed to download update: %w", err)
}
// Make new binary executable (Unix only)
if runtime.GOOS != "windows" {
if err := os.Chmod(tmpFile, 0755); err != nil {
os.Remove(tmpFile)
return fmt.Errorf("failed to make binary executable: %w", err)
}
}
// Create backup of current binary
backupFile := exePath + ".backup"
if err := u.createBackup(exePath, backupFile); err != nil {
os.Remove(tmpFile)
return fmt.Errorf("failed to create backup: %w", err)
}
log.Println("Backup created successfully")
// Replace current binary with new one
if err := u.replaceBinary(tmpFile, exePath); err != nil {
// Attempt to restore from backup
log.Printf("Update failed, attempting to restore from backup: %v", err)
if restoreErr := u.replaceBinary(backupFile, exePath); restoreErr != nil {
return fmt.Errorf("update failed and backup restoration failed: %w (original error: %v)", restoreErr, err)
}
return fmt.Errorf("update failed, restored from backup: %w", err)
}
log.Printf("Successfully updated to version %s", newVersion)
log.Println("NOTE: Restart the agent to use the new version")
// Clean up backup file after successful update
os.Remove(backupFile)
return nil
}
// downloadBinary downloads a binary from the given URL to the destination path
func (u *Updater) downloadBinary(url, destPath string) error {
log.Printf("Downloading binary from %s", url)
// Create HTTP client with timeout
client := &http.Client{
Timeout: 10 * time.Minute, // Binary download may take longer
}
// Download file
resp, err := client.Get(url)
if err != nil {
return fmt.Errorf("failed to download: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download failed with status: %d", resp.StatusCode)
}
// Create destination file
file, err := os.Create(destPath)
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
}
defer file.Close()
// Copy content
written, err := io.Copy(file, resp.Body)
if err != nil {
return fmt.Errorf("failed to write file: %w", err)
}
log.Printf("Downloaded %d bytes", written)
return nil
}
// createBackup creates a backup copy of a file
func (u *Updater) createBackup(src, dest string) error {
srcFile, err := os.Open(src)
if err != nil {
return err
}
defer srcFile.Close()
destFile, err := os.Create(dest)
if err != nil {
return err
}
defer destFile.Close()
_, err = io.Copy(destFile, srcFile)
return err
}
// replaceBinary replaces the destination binary with the source binary
func (u *Updater) replaceBinary(src, dest string) error {
// On Windows, we can't replace a running executable directly
// We need to rename the old one and move the new one in place
if runtime.GOOS == "windows" {
oldExe := dest + ".old"
// Remove any existing .old file
os.Remove(oldExe)
// Rename current executable
if err := os.Rename(dest, oldExe); err != nil {
return fmt.Errorf("failed to rename current executable: %w", err)
}
// Move new executable into place
if err := os.Rename(src, dest); err != nil {
// Try to restore
os.Rename(oldExe, dest)
return fmt.Errorf("failed to move new executable: %w", err)
}
// Schedule old executable for deletion on next boot
// (Windows doesn't allow deleting running executables)
return nil
}
// On Unix, we can replace the file directly
// The running process will continue using the old inode
return os.Rename(src, dest)
}
// GetCurrentVersion returns the current version
func (u *Updater) GetCurrentVersion() string {
return u.currentVersion
}
// VerifyUpdate verifies that an update was successful by checking the version
func (u *Updater) VerifyUpdate(expectedVersion string) error {
exePath, err := os.Executable()
if err != nil {
return fmt.Errorf("failed to get executable path: %w", err)
}
info, err := os.Stat(exePath)
if err != nil {
return fmt.Errorf("failed to stat executable: %w", err)
}
log.Printf("Executable info: size=%d, mod_time=%s", info.Size(), info.ModTime())
// In a real implementation, you might embed version info in the binary
// or verify a checksum. For now, we just verify the file exists and was modified recently
if time.Since(info.ModTime()) > 5*time.Minute {
return fmt.Errorf("executable was not recently modified")
}
return nil
}
// CleanupOldVersions removes old backup files
func (u *Updater) CleanupOldVersions() error {
exePath, err := os.Executable()
if err != nil {
return err
}
dir := filepath.Dir(exePath)
baseName := filepath.Base(exePath)
// Remove backup files
patterns := []string{
baseName + ".backup",
baseName + ".old",
baseName + ".new",
}
for _, pattern := range patterns {
path := filepath.Join(dir, pattern)
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
log.Printf("Failed to remove old version %s: %v", path, err)
}
}
return nil
}