diff --git a/cmd/goupkeep/main.go b/cmd/goupkeep/main.go index adec38f..b2529ae 100644 --- a/cmd/goupkeep/main.go +++ b/cmd/goupkeep/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "errors" "flag" "fmt" "go-upkeep/internal/cluster" @@ -17,6 +18,7 @@ import ( "os/signal" "strconv" "syscall" + "time" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/ssh" @@ -225,6 +227,7 @@ func runServe(args []string) { fmt.Printf("Database connection error: %v\n", dbErr) os.Exit(1) } + defer s.Close() if err := s.Init(); err != nil { fmt.Printf("Database init error: %v\n", err) @@ -263,7 +266,7 @@ func runServe(args []string) { eng.InitLogs() eng.Start(ctx) - server.Start(server.ServerConfig{ + httpSrv := server.Start(server.ServerConfig{ Port: httpPort, EnableStatus: enableStatus, Title: statusTitle, @@ -276,7 +279,7 @@ func runServe(args []string) { SharedKey: clusterKey, }, eng) - startSSHServer(*port, s, eng) + sshSrv := startSSHServer(*port, s, eng) if isatty.IsTerminal(os.Stdout.Fd()) || isatty.IsCygwinTerminal(os.Stdout.Fd()) { p := tea.NewProgram(tui.InitialModel(true, s, eng), tea.WithAltScreen(), tea.WithMouseCellMotion()) @@ -291,9 +294,22 @@ func runServe(args []string) { fmt.Println("Shutting down...") } cancel() + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer shutdownCancel() + if httpSrv != nil { + if err := httpSrv.Shutdown(shutdownCtx); err != nil { + log.Printf("HTTP shutdown error: %v", err) + } + } + if sshSrv != nil { + if err := sshSrv.Shutdown(shutdownCtx); err != nil { + log.Printf("SSH shutdown error: %v", err) + } + } } -func startSSHServer(port int, db store.Store, eng *monitor.Engine) { +func startSSHServer(port int, db store.Store, eng *monitor.Engine) *ssh.Server { s, err := wish.NewServer( wish.WithAddress(fmt.Sprintf(":%d", port)), wish.WithHostKeyPath(".ssh/id_ed25519"), @@ -308,13 +324,14 @@ func startSSHServer(port int, db store.Store, eng *monitor.Engine) { ) if err != nil { fmt.Printf("SSH server error: %v\n", err) - return + return nil } go func() { - if err := s.ListenAndServe(); err != nil { - log.Fatalf("SSH server failed: %v", err) + if err := s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) { + log.Printf("SSH server error: %v", err) } }() + return s } func seedDemoData(s store.Store) {