diff --git a/cmd/goupkeep/main.go b/cmd/goupkeep/main.go index adec38f..b2529ae 100644 --- a/cmd/goupkeep/main.go +++ b/cmd/goupkeep/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "errors" "flag" "fmt" "go-upkeep/internal/cluster" @@ -17,6 +18,7 @@ import ( "os/signal" "strconv" "syscall" + "time" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/ssh" @@ -225,6 +227,7 @@ func runServe(args []string) { fmt.Printf("Database connection error: %v\n", dbErr) os.Exit(1) } + defer s.Close() if err := s.Init(); err != nil { fmt.Printf("Database init error: %v\n", err) @@ -263,7 +266,7 @@ func runServe(args []string) { eng.InitLogs() eng.Start(ctx) - server.Start(server.ServerConfig{ + httpSrv := server.Start(server.ServerConfig{ Port: httpPort, EnableStatus: enableStatus, Title: statusTitle, @@ -276,7 +279,7 @@ func runServe(args []string) { SharedKey: clusterKey, }, eng) - startSSHServer(*port, s, eng) + sshSrv := startSSHServer(*port, s, eng) if isatty.IsTerminal(os.Stdout.Fd()) || isatty.IsCygwinTerminal(os.Stdout.Fd()) { p := tea.NewProgram(tui.InitialModel(true, s, eng), tea.WithAltScreen(), tea.WithMouseCellMotion()) @@ -291,9 +294,22 @@ func runServe(args []string) { fmt.Println("Shutting down...") } cancel() + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer shutdownCancel() + if httpSrv != nil { + if err := httpSrv.Shutdown(shutdownCtx); err != nil { + log.Printf("HTTP shutdown error: %v", err) + } + } + if sshSrv != nil { + if err := sshSrv.Shutdown(shutdownCtx); err != nil { + log.Printf("SSH shutdown error: %v", err) + } + } } -func startSSHServer(port int, db store.Store, eng *monitor.Engine) { +func startSSHServer(port int, db store.Store, eng *monitor.Engine) *ssh.Server { s, err := wish.NewServer( wish.WithAddress(fmt.Sprintf(":%d", port)), wish.WithHostKeyPath(".ssh/id_ed25519"), @@ -308,13 +324,14 @@ func startSSHServer(port int, db store.Store, eng *monitor.Engine) { ) if err 
!= nil { fmt.Printf("SSH server error: %v\n", err) - return + return nil } go func() { - if err := s.ListenAndServe(); err != nil { - log.Fatalf("SSH server failed: %v", err) + if err := s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) { + log.Printf("SSH server error: %v", err) } }() + return s } func seedDemoData(s store.Store) { diff --git a/internal/alert/alert.go b/internal/alert/alert.go index c60013f..12a2f41 100644 --- a/internal/alert/alert.go +++ b/internal/alert/alert.go @@ -2,6 +2,7 @@ package alert import ( "bytes" + "context" "encoding/json" "fmt" "go-upkeep/internal/models" @@ -15,7 +16,7 @@ import ( var alertClient = &http.Client{Timeout: 10 * time.Second} type Provider interface { - Send(title, message string) error + Send(ctx context.Context, title, message string) error } type PayloadFunc func(title, message string) ([]byte, error) @@ -25,12 +26,17 @@ type HTTPProvider struct { Payload PayloadFunc } -func (h *HTTPProvider) Send(title, message string) error { +func (h *HTTPProvider) Send(ctx context.Context, title, message string) error { body, err := h.Payload(title, message) if err != nil { return err } - resp, err := alertClient.Post(h.URL, "application/json", bytes.NewBuffer(body)) + req, err := http.NewRequestWithContext(ctx, "POST", h.URL, bytes.NewBuffer(body)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + resp, err := alertClient.Do(req) if err != nil { return err } @@ -170,7 +176,12 @@ type EmailProvider struct { Host, Port, User, Pass, To, From string } -func (e *EmailProvider) Send(title, message string) error { +func (e *EmailProvider) Send(ctx context.Context, title, message string) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } auth := smtp.PlainAuth("", e.User, e.Pass, e.Host) msg := []byte("To: " + e.To + "\r\n" + "Subject: Go-Upkeep: " + title + "\r\n" + @@ -187,9 +198,9 @@ type NtfyProvider struct { Password string } -func (n 
*NtfyProvider) Send(title, message string) error { +func (n *NtfyProvider) Send(ctx context.Context, title, message string) error { url := strings.TrimRight(n.ServerURL, "/") + "/" + n.Topic - req, err := http.NewRequest("POST", url, strings.NewReader(message)) + req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(message)) if err != nil { return err } diff --git a/internal/alert/alert_test.go b/internal/alert/alert_test.go index 35e1c8d..3314d17 100644 --- a/internal/alert/alert_test.go +++ b/internal/alert/alert_test.go @@ -1,6 +1,7 @@ package alert import ( + "context" "encoding/json" "go-upkeep/internal/models" "net/http" @@ -17,7 +18,7 @@ func TestHTTPProviderDiscord(t *testing.T) { defer srv.Close() p := GetProvider(models.AlertConfig{Type: "discord", Settings: map[string]string{"url": srv.URL}}) - if err := p.Send("Test Title", "Test Body"); err != nil { + if err := p.Send(context.Background(), "Test Title", "Test Body"); err != nil { t.Fatalf("Send: %v", err) } @@ -35,7 +36,7 @@ func TestHTTPProviderSlack(t *testing.T) { defer srv.Close() p := GetProvider(models.AlertConfig{Type: "slack", Settings: map[string]string{"url": srv.URL}}) - if err := p.Send("Alert", "Message"); err != nil { + if err := p.Send(context.Background(), "Alert", "Message"); err != nil { t.Fatalf("Send: %v", err) } @@ -53,7 +54,7 @@ func TestHTTPProviderWebhook(t *testing.T) { defer srv.Close() p := GetProvider(models.AlertConfig{Type: "webhook", Settings: map[string]string{"url": srv.URL}}) - if err := p.Send("Title", "Body"); err != nil { + if err := p.Send(context.Background(), "Title", "Body"); err != nil { t.Fatalf("Send: %v", err) } @@ -69,7 +70,7 @@ func TestHTTPProviderErrorOnHTTP4xx(t *testing.T) { defer srv.Close() p := GetProvider(models.AlertConfig{Type: "discord", Settings: map[string]string{"url": srv.URL}}) - if err := p.Send("Test", "Test"); err == nil { + if err := p.Send(context.Background(), "Test", "Test"); err == nil { t.Fatal("expected error 
on 403 response") } } @@ -89,7 +90,7 @@ func TestNtfyProvider(t *testing.T) { "url": srv.URL, "topic": "test", }}) - if err := p.Send("Alert Title", "Alert Body"); err != nil { + if err := p.Send(context.Background(), "Alert Title", "Alert Body"); err != nil { t.Fatalf("Send: %v", err) } @@ -110,7 +111,7 @@ func TestHTTPProviderTelegram(t *testing.T) { defer srv.Close() p := &HTTPProvider{URL: srv.URL, Payload: telegramPayload("12345")} - if err := p.Send("Alert", "Down"); err != nil { + if err := p.Send(context.Background(), "Alert", "Down"); err != nil { t.Fatalf("Send: %v", err) } if received["chat_id"] != "12345" { @@ -133,7 +134,7 @@ func TestHTTPProviderPagerDuty(t *testing.T) { defer srv.Close() p := &HTTPProvider{URL: srv.URL, Payload: pagerdutyPayload("test-key", "critical")} - if err := p.Send("Alert", "Down"); err != nil { + if err := p.Send(context.Background(), "Alert", "Down"); err != nil { t.Fatalf("Send: %v", err) } if received["routing_key"] != "test-key" { @@ -160,7 +161,7 @@ func TestHTTPProviderPushover(t *testing.T) { defer srv.Close() p := &HTTPProvider{URL: srv.URL, Payload: pushoverPayload("app-tok", "user-key")} - if err := p.Send("Alert", "Down"); err != nil { + if err := p.Send(context.Background(), "Alert", "Down"); err != nil { t.Fatalf("Send: %v", err) } if received["token"] != "app-tok" { @@ -183,7 +184,7 @@ func TestHTTPProviderGotify(t *testing.T) { defer srv.Close() p := &HTTPProvider{URL: srv.URL, Payload: gotifyPayload("8")} - if err := p.Send("Alert", "Down"); err != nil { + if err := p.Send(context.Background(), "Alert", "Down"); err != nil { t.Fatalf("Send: %v", err) } if received["title"] != "Alert" || received["message"] != "Down" { diff --git a/internal/cluster/cluster_test.go b/internal/cluster/cluster_test.go new file mode 100644 index 0000000..5f062f9 --- /dev/null +++ b/internal/cluster/cluster_test.go @@ -0,0 +1,395 @@ +package cluster + +import ( + "context" + "encoding/json" + "go-upkeep/internal/models" + 
"go-upkeep/internal/monitor" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" +) + +// --- Mock Store (minimal, for monitor.NewEngine) --- + +type mockStore struct { + mu sync.Mutex + sites []models.Site +} + +func (m *mockStore) Init() error { return nil } +func (m *mockStore) GetSites() ([]models.Site, error) { return m.sites, nil } +func (m *mockStore) AddSite(models.Site) error { return nil } +func (m *mockStore) UpdateSite(models.Site) error { return nil } +func (m *mockStore) UpdateSitePaused(int, bool) error { return nil } +func (m *mockStore) DeleteSite(int) error { return nil } +func (m *mockStore) GetAllAlerts() ([]models.AlertConfig, error) { return nil, nil } +func (m *mockStore) GetAlert(int) (models.AlertConfig, error) { return models.AlertConfig{}, nil } +func (m *mockStore) AddAlert(string, string, map[string]string) error { return nil } +func (m *mockStore) UpdateAlert(int, string, string, map[string]string) error { return nil } +func (m *mockStore) DeleteAlert(int) error { return nil } +func (m *mockStore) GetAllUsers() ([]models.User, error) { return nil, nil } +func (m *mockStore) AddUser(string, string, string) error { return nil } +func (m *mockStore) UpdateUser(int, string, string, string) error { return nil } +func (m *mockStore) DeleteUser(int) error { return nil } +func (m *mockStore) SaveCheck(int, int64, bool) error { return nil } +func (m *mockStore) SaveCheckFromNode(int, string, int64, bool) error { return nil } +func (m *mockStore) LoadAllHistory(int) (map[int][]models.CheckRecord, error) { return nil, nil } +func (m *mockStore) ExportData() (models.Backup, error) { return models.Backup{}, nil } +func (m *mockStore) ImportData(models.Backup) error { return nil } +func (m *mockStore) GetSiteByName(string) (models.Site, error) { return models.Site{}, nil } +func (m *mockStore) GetAlertByName(string) (models.AlertConfig, error) { + return models.AlertConfig{}, nil +} +func (m *mockStore) 
AddSiteReturningID(models.Site) (int, error) { return 0, nil } +func (m *mockStore) AddAlertReturningID(string, string, map[string]string) (int, error) { + return 0, nil +} +func (m *mockStore) RegisterNode(models.ProbeNode) error { return nil } +func (m *mockStore) GetNode(string) (models.ProbeNode, error) { return models.ProbeNode{}, nil } +func (m *mockStore) GetAllNodes() ([]models.ProbeNode, error) { return nil, nil } +func (m *mockStore) UpdateNodeLastSeen(string) error { return nil } +func (m *mockStore) DeleteNode(string) error { return nil } +func (m *mockStore) SaveLog(string) error { return nil } +func (m *mockStore) LoadLogs(int) ([]string, error) { return nil, nil } +func (m *mockStore) GetActiveMaintenanceWindows() ([]models.MaintenanceWindow, error) { + return nil, nil +} +func (m *mockStore) GetAllMaintenanceWindows(int) ([]models.MaintenanceWindow, error) { + return nil, nil +} +func (m *mockStore) AddMaintenanceWindow(models.MaintenanceWindow) error { return nil } +func (m *mockStore) EndMaintenanceWindow(int) error { return nil } +func (m *mockStore) DeleteMaintenanceWindow(int) error { return nil } +func (m *mockStore) IsMonitorInMaintenance(int) (bool, error) { return false, nil } +func (m *mockStore) GetPreference(string) (string, error) { return "", nil } +func (m *mockStore) SetPreference(string, string) error { return nil } +func (m *mockStore) Close() error { return nil } + +// --- Cluster Start Tests --- + +func TestStart_LeaderMode(t *testing.T) { + eng := monitor.NewEngine(&mockStore{}) + eng.SetActive(false) + + ctx := context.Background() + Start(ctx, Config{Mode: "leader"}, eng) + + if !eng.IsActive() { + t.Error("leader mode should set engine active") + } +} + +func TestStart_FollowerMode(t *testing.T) { + eng := monitor.NewEngine(&mockStore{}) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + Start(ctx, Config{Mode: "follower", PeerURL: "http://localhost:9999"}, eng) + time.Sleep(50 * 
time.Millisecond) + + if eng.IsActive() { + t.Error("follower mode should set engine inactive") + } +} + +// --- Follower Loop Tests --- + +func TestFollowerLoop_FailoverOnLeaderDown(t *testing.T) { + eng := monitor.NewEngine(&mockStore{}) + eng.SetActive(false) + + // Server always returns 503 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(503) + })) + defer srv.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go runFollowerLoop(ctx, Config{PeerURL: srv.URL, SharedKey: "key"}, eng) + + // Follower checks every 5s, needs 3 failures → ~15s minimum + // But we can't wait that long in a test. The loop sleeps 5s between checks. + // We'll wait up to 20s for failover. + deadline := time.After(20 * time.Second) + for { + if eng.IsActive() { + return // success + } + select { + case <-deadline: + t.Fatal("expected failover to ACTIVE after 3 failures") + case <-time.After(500 * time.Millisecond): + } + } +} + +func TestFollowerLoop_RecoveryOnLeaderReturn(t *testing.T) { + eng := monitor.NewEngine(&mockStore{}) + eng.SetActive(true) // simulate already failed over + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte("OK")) + })) + defer srv.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go runFollowerLoop(ctx, Config{PeerURL: srv.URL}, eng) + + deadline := time.After(10 * time.Second) + for { + if !eng.IsActive() { + return // success — switched back to passive + } + select { + case <-deadline: + t.Fatal("expected switch back to PASSIVE when leader returns") + case <-time.After(500 * time.Millisecond): + } + } +} + +func TestFollowerLoop_SendsSecret(t *testing.T) { + var mu sync.Mutex + var receivedSecret string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + receivedSecret = 
r.Header.Get("X-Upkeep-Secret") + mu.Unlock() + w.WriteHeader(200) + w.Write([]byte("OK")) + })) + defer srv.Close() + + eng := monitor.NewEngine(&mockStore{}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go runFollowerLoop(ctx, Config{PeerURL: srv.URL, SharedKey: "test-secret"}, eng) + + deadline := time.After(10 * time.Second) + for { + mu.Lock() + got := receivedSecret + mu.Unlock() + if got == "test-secret" { + return + } + select { + case <-deadline: + t.Fatalf("expected secret 'test-secret', got %q", got) + case <-time.After(500 * time.Millisecond): + } + } +} + +func TestFollowerLoop_CancelContext(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + })) + defer srv.Close() + + eng := monitor.NewEngine(&mockStore{}) + ctx, cancel := context.WithCancel(context.Background()) + + done := make(chan struct{}) + go func() { + runFollowerLoop(ctx, Config{PeerURL: srv.URL}, eng) + close(done) + }() + + cancel() + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatal("expected follower loop to exit on context cancel") + } +} + +// --- Probe Tests --- + +func TestProbeRegister_Success(t *testing.T) { + var received map[string]string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewDecoder(r.Body).Decode(&received) + w.WriteHeader(200) + })) + defer srv.Close() + + err := probeRegister(context.Background(), srv.Client(), ProbeConfig{ + NodeID: "n1", NodeName: "US East", Region: "us-east", LeaderURL: srv.URL, SharedKey: "key", + }) + if err != nil { + t.Fatalf("register: %v", err) + } + if received["id"] != "n1" { + t.Errorf("expected id n1, got %s", received["id"]) + } +} + +func TestProbeRegister_Failure(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + })) + defer srv.Close() + + err := 
probeRegister(context.Background(), srv.Client(), ProbeConfig{ + LeaderURL: srv.URL, + }) + if err == nil { + t.Error("expected error on 401") + } +} + +func TestProbeFetchAssignments_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string][]models.Site{ + "sites": {{ID: 1, Name: "s1", Type: "http", URL: "http://example.com"}}, + }) + })) + defer srv.Close() + + sites, err := probeFetchAssignments(context.Background(), srv.Client(), ProbeConfig{ + NodeID: "n1", LeaderURL: srv.URL, SharedKey: "key", + }) + if err != nil { + t.Fatalf("fetch: %v", err) + } + if len(sites) != 1 { + t.Errorf("expected 1 site, got %d", len(sites)) + } +} + +func TestProbeFetchAssignments_Unauthorized(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + })) + defer srv.Close() + + _, err := probeFetchAssignments(context.Background(), srv.Client(), ProbeConfig{ + LeaderURL: srv.URL, + }) + if err == nil { + t.Error("expected error on 401") + } +} + +func TestProbeExecuteChecks(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + })) + defer srv.Close() + + sites := []models.Site{ + {ID: 1, Type: "http", URL: srv.URL}, + {ID: 2, Type: "http", URL: srv.URL}, + } + + strict := &http.Client{} + insecure := &http.Client{} + results := probeExecuteChecks(context.Background(), sites, strict, insecure) + + if len(results) != 2 { + t.Fatalf("expected 2 results, got %d", len(results)) + } + for _, r := range results { + if !r.IsUp { + t.Errorf("site %d expected UP", r.SiteID) + } + } +} + +func TestProbeExecuteChecks_Concurrency(t *testing.T) { + var concurrent int64 + var maxConcurrent int64 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cur := atomic.AddInt64(&concurrent, 1) + for { + old := 
atomic.LoadInt64(&maxConcurrent) + if cur <= old || atomic.CompareAndSwapInt64(&maxConcurrent, old, cur) { + break + } + } + time.Sleep(50 * time.Millisecond) + atomic.AddInt64(&concurrent, -1) + w.WriteHeader(200) + })) + defer srv.Close() + + var sites []models.Site + for i := 0; i < 20; i++ { + sites = append(sites, models.Site{ID: i + 1, Type: "http", URL: srv.URL}) + } + + results := probeExecuteChecks(context.Background(), sites, &http.Client{}, &http.Client{}) + if len(results) != 20 { + t.Errorf("expected 20 results, got %d", len(results)) + } + mc := atomic.LoadInt64(&maxConcurrent) + if mc > 10 { + t.Errorf("expected max 10 concurrent, got %d", mc) + } +} + +func TestProbeReportResults_Success(t *testing.T) { + var received struct { + NodeID string `json:"node_id"` + Results []probeResultItem `json:"results"` + } + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewDecoder(r.Body).Decode(&received) + w.WriteHeader(200) + })) + defer srv.Close() + + err := probeReportResults(context.Background(), srv.Client(), ProbeConfig{ + NodeID: "n1", LeaderURL: srv.URL, SharedKey: "key", + }, []probeResultItem{{SiteID: 1, LatencyNs: 5000000, IsUp: true}}) + + if err != nil { + t.Fatalf("report: %v", err) + } + if received.NodeID != "n1" { + t.Errorf("expected n1, got %s", received.NodeID) + } + if len(received.Results) != 1 { + t.Errorf("expected 1 result, got %d", len(received.Results)) + } +} + +func TestProbeReportResults_Failure(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(500) + })) + defer srv.Close() + + err := probeReportResults(context.Background(), srv.Client(), ProbeConfig{ + LeaderURL: srv.URL, + }, []probeResultItem{{SiteID: 1}}) + + if err == nil { + t.Error("expected error on 500") + } +} + +// --- sleepCtx --- + +func TestSleepCtx_Cancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + 
start := time.Now() + sleepCtx(ctx, 10*time.Second) + if time.Since(start) > time.Second { + t.Error("expected immediate return on canceled context") + } +} diff --git a/internal/importer/kuma.go b/internal/importer/kuma.go index d2fd2aa..8226d3f 100644 --- a/internal/importer/kuma.go +++ b/internal/importer/kuma.go @@ -96,7 +96,9 @@ func convertKumaNotifications(entries []KumaNotifEntry) map[int]models.AlertConf result := make(map[int]models.AlertConfig) for _, entry := range entries { var cfg KumaNotifConfig - json.Unmarshal([]byte(entry.Config), &cfg) + if err := json.Unmarshal([]byte(entry.Config), &cfg); err != nil { + continue + } alert := models.AlertConfig{ ID: entry.ID, diff --git a/internal/metrics/prometheus_test.go b/internal/metrics/prometheus_test.go index d16723b..847d7cf 100644 --- a/internal/metrics/prometheus_test.go +++ b/internal/metrics/prometheus_test.go @@ -64,6 +64,7 @@ func (m *mockStore) DeleteMaintenanceWindow(int) error { retur func (m *mockStore) IsMonitorInMaintenance(int) (bool, error) { return false, nil } func (m *mockStore) GetPreference(string) (string, error) { return "", nil } func (m *mockStore) SetPreference(string, string) error { return nil } +func (m *mockStore) Close() error { return nil } func TestMetricsHandler(t *testing.T) { ms := &mockStore{ diff --git a/internal/monitor/checker_test.go b/internal/monitor/checker_test.go new file mode 100644 index 0000000..39698b4 --- /dev/null +++ b/internal/monitor/checker_test.go @@ -0,0 +1,203 @@ +package monitor + +import ( + "crypto/tls" + "go-upkeep/internal/models" + "net" + "net/http" + "net/http/httptest" + "strconv" + "testing" + "time" +) + +func TestRunCheck_HTTP_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + })) + defer srv.Close() + + site := models.Site{ID: 1, Type: "http", URL: srv.URL} + result := RunCheck(site, http.DefaultClient, http.DefaultClient, false) + + if 
result.Status != "UP" { + t.Errorf("expected UP, got %s", result.Status) + } + if result.StatusCode != 200 { + t.Errorf("expected 200, got %d", result.StatusCode) + } + if result.LatencyNs <= 0 { + t.Error("expected positive latency") + } +} + +func TestRunCheck_HTTP_ServerError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(500) + })) + defer srv.Close() + + site := models.Site{ID: 1, Type: "http", URL: srv.URL} + result := RunCheck(site, http.DefaultClient, http.DefaultClient, false) + + if result.Status != "DOWN" { + t.Errorf("expected DOWN, got %s", result.Status) + } + if result.StatusCode != 500 { + t.Errorf("expected 500, got %d", result.StatusCode) + } +} + +func TestRunCheck_HTTP_CustomAcceptedCodes(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(302) + })) + defer srv.Close() + + client := &http.Client{CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }} + + site := models.Site{ID: 1, Type: "http", URL: srv.URL, AcceptedCodes: "200-399"} + result := RunCheck(site, client, client, false) + + if result.Status != "UP" { + t.Errorf("expected UP with accepted 200-399, got %s", result.Status) + } +} + +func TestRunCheck_HTTP_MethodRespected(t *testing.T) { + var receivedMethod string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedMethod = r.Method + w.WriteHeader(200) + })) + defer srv.Close() + + site := models.Site{ID: 1, Type: "http", URL: srv.URL, Method: "HEAD"} + RunCheck(site, http.DefaultClient, http.DefaultClient, false) + + if receivedMethod != "HEAD" { + t.Errorf("expected HEAD, got %s", receivedMethod) + } +} + +func TestRunCheck_HTTP_Timeout(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(2 * time.Second) + 
w.WriteHeader(200) + })) + defer srv.Close() + + site := models.Site{ID: 1, Type: "http", URL: srv.URL, Timeout: 1} + result := RunCheck(site, http.DefaultClient, http.DefaultClient, false) + + if result.Status != "DOWN" { + t.Errorf("expected DOWN on timeout, got %s", result.Status) + } +} + +func TestRunCheck_HTTP_SSLFields(t *testing.T) { + srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + })) + defer srv.Close() + + insecureClient := &http.Client{ + Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, + } + + site := models.Site{ID: 1, Type: "http", URL: srv.URL, CheckSSL: true, IgnoreTLS: true} + result := RunCheck(site, http.DefaultClient, insecureClient, false) + + if result.Status != "UP" { + t.Errorf("expected UP, got %s", result.Status) + } + if !result.HasSSL { + t.Error("expected HasSSL=true") + } + if result.CertExpiry.IsZero() { + t.Error("expected CertExpiry populated") + } +} + +func TestRunCheck_Port_Open(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + _, portStr, _ := net.SplitHostPort(ln.Addr().String()) + port, _ := strconv.Atoi(portStr) + + site := models.Site{ID: 1, Type: "port", Hostname: "127.0.0.1", Port: port, Timeout: 2} + result := RunCheck(site, nil, nil, false) + + if result.Status != "UP" { + t.Errorf("expected UP, got %s", result.Status) + } + if result.LatencyNs <= 0 { + t.Error("expected positive latency") + } +} + +func TestRunCheck_Port_Closed(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + _, portStr, _ := net.SplitHostPort(ln.Addr().String()) + port, _ := strconv.Atoi(portStr) + ln.Close() + + site := models.Site{ID: 1, Type: "port", Hostname: "127.0.0.1", Port: port, Timeout: 1} + result := RunCheck(site, nil, nil, false) + + if result.Status != "DOWN" { + t.Errorf("expected DOWN, got %s", result.Status) + 
} +} + +func TestRunCheck_UnknownType(t *testing.T) { + site := models.Site{ID: 1, Type: "invalid"} + result := RunCheck(site, nil, nil, false) + + if result.Status != "DOWN" { + t.Errorf("expected DOWN for unknown type, got %s", result.Status) + } +} + +func TestIsCodeAccepted(t *testing.T) { + tests := []struct { + code int + accepted string + want bool + }{ + {200, "", true}, + {299, "", true}, + {300, "", false}, + {302, "200-399", true}, + {400, "200-399", false}, + {301, "200,301,404", true}, + {500, "200,301,404", false}, + {404, "200-299,400-499", true}, + {500, "200-299,400-499", false}, + } + + for _, tt := range tests { + got := isCodeAccepted(tt.code, tt.accepted) + if got != tt.want { + t.Errorf("isCodeAccepted(%d, %q) = %v, want %v", tt.code, tt.accepted, got, tt.want) + } + } +} + +func TestSiteTimeout(t *testing.T) { + if got := siteTimeout(models.Site{Timeout: 0}); got != 5*time.Second { + t.Errorf("expected 5s default, got %v", got) + } + if got := siteTimeout(models.Site{Timeout: 10}); got != 10*time.Second { + t.Errorf("expected 10s, got %v", got) + } +} diff --git a/internal/monitor/monitor.go b/internal/monitor/monitor.go index 284497c..bcd4371 100644 --- a/internal/monitor/monitor.go +++ b/internal/monitor/monitor.go @@ -7,6 +7,7 @@ import ( "go-upkeep/internal/alert" "go-upkeep/internal/models" "go-upkeep/internal/store" + "math/rand/v2" "net/http" "sync" "time" @@ -25,7 +26,7 @@ type Engine struct { histMu sync.RWMutex histories map[int]*SiteHistory - tokenIndex map[string]int + tokenIndex map[string]int // protected by mu probeResultsMu sync.RWMutex probeResults map[int]map[string]NodeResult @@ -277,6 +278,14 @@ func (e *Engine) ToggleSitePause(id int) bool { } func (e *Engine) monitorRoutine(ctx context.Context, id int) { + // Stagger initial check to avoid thundering herd on startup + stagger := time.Duration(rand.IntN(3000)) * time.Millisecond + select { + case <-time.After(stagger): + case <-ctx.Done(): + return + } + e.checkByID(id) 
for { select { @@ -314,8 +323,9 @@ func (e *Engine) monitorRoutine(ctx context.Context, id int) { if interval < 5 { interval = 5 } + jitter := time.Duration(rand.IntN(interval*100)) * time.Millisecond select { - case <-time.After(time.Duration(interval) * time.Second): + case <-time.After(time.Duration(interval)*time.Second + jitter): case <-ctx.Done(): return } @@ -433,6 +443,7 @@ func (e *Engine) handleStatusChange(site models.Site, rawStatus string, code int func (e *Engine) triggerAlert(alertID int, title, message string) { cfg, err := e.db.GetAlert(alertID) if err != nil { + e.AddLog(fmt.Sprintf("Failed to load alert config %d: %v", alertID, err)) return } provider := alert.GetProvider(cfg) @@ -440,8 +451,9 @@ func (e *Engine) triggerAlert(alertID int, title, message string) { go func() { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - _ = ctx - _ = provider.Send(title, message) + if err := provider.Send(ctx, title, message); err != nil { + e.AddLog(fmt.Sprintf("Alert send failed (%s): %v", cfg.Name, err)) + } }() } } diff --git a/internal/monitor/monitor_test.go b/internal/monitor/monitor_test.go new file mode 100644 index 0000000..dd392de --- /dev/null +++ b/internal/monitor/monitor_test.go @@ -0,0 +1,1044 @@ +package monitor + +import ( + "fmt" + "go-upkeep/internal/models" + "sync" + "testing" + "time" +) + +// --- Mock Store --- + +type savedCheck struct { + SiteID int + LatencyNs int64 + IsUp bool +} + +type mockStore struct { + mu sync.Mutex + sites []models.Site + alerts map[int]models.AlertConfig + maintenance map[int]bool + logs []string + history map[int][]models.CheckRecord + savedChecks []savedCheck + savedLogs []string + getAlertCalls []int +} + +func newMockStore() *mockStore { + return &mockStore{ + alerts: make(map[int]models.AlertConfig), + maintenance: make(map[int]bool), + history: make(map[int][]models.CheckRecord), + } +} + +func (m *mockStore) Init() error { return nil } +func (m *mockStore) 
GetSites() ([]models.Site, error) { return m.sites, nil } +func (m *mockStore) AddSite(models.Site) error { return nil } +func (m *mockStore) UpdateSite(models.Site) error { return nil } +func (m *mockStore) UpdateSitePaused(int, bool) error { return nil } +func (m *mockStore) DeleteSite(int) error { return nil } +func (m *mockStore) AddAlert(string, string, map[string]string) error { return nil } +func (m *mockStore) UpdateAlert(int, string, string, map[string]string) error { return nil } +func (m *mockStore) DeleteAlert(int) error { return nil } +func (m *mockStore) GetAllUsers() ([]models.User, error) { return nil, nil } +func (m *mockStore) AddUser(string, string, string) error { return nil } +func (m *mockStore) UpdateUser(int, string, string, string) error { return nil } +func (m *mockStore) DeleteUser(int) error { return nil } +func (m *mockStore) ExportData() (models.Backup, error) { return models.Backup{}, nil } +func (m *mockStore) ImportData(models.Backup) error { return nil } +func (m *mockStore) GetSiteByName(string) (models.Site, error) { return models.Site{}, nil } +func (m *mockStore) AddSiteReturningID(models.Site) (int, error) { return 0, nil } +func (m *mockStore) AddAlertReturningID(string, string, map[string]string) (int, error) { + return 0, nil +} +func (m *mockStore) SaveCheckFromNode(int, string, int64, bool) error { return nil } +func (m *mockStore) RegisterNode(models.ProbeNode) error { return nil } +func (m *mockStore) GetNode(string) (models.ProbeNode, error) { return models.ProbeNode{}, nil } +func (m *mockStore) GetAllNodes() ([]models.ProbeNode, error) { return nil, nil } +func (m *mockStore) UpdateNodeLastSeen(string) error { return nil } +func (m *mockStore) DeleteNode(string) error { return nil } +func (m *mockStore) GetActiveMaintenanceWindows() ([]models.MaintenanceWindow, error) { + return nil, nil +} +func (m *mockStore) GetAllMaintenanceWindows(int) ([]models.MaintenanceWindow, error) { + return nil, nil +} +func (m 
*mockStore) AddMaintenanceWindow(models.MaintenanceWindow) error { return nil } +func (m *mockStore) EndMaintenanceWindow(int) error { return nil } +func (m *mockStore) DeleteMaintenanceWindow(int) error { return nil } +func (m *mockStore) GetPreference(string) (string, error) { return "", nil } +func (m *mockStore) SetPreference(string, string) error { return nil } +func (m *mockStore) Close() error { return nil } + +func (m *mockStore) GetAllAlerts() ([]models.AlertConfig, error) { + m.mu.Lock() + defer m.mu.Unlock() + var result []models.AlertConfig + for _, a := range m.alerts { + result = append(result, a) + } + return result, nil +} + +func (m *mockStore) GetAlert(id int) (models.AlertConfig, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.getAlertCalls = append(m.getAlertCalls, id) + if a, ok := m.alerts[id]; ok { + return a, nil + } + return models.AlertConfig{}, fmt.Errorf("alert %d not found", id) +} + +func (m *mockStore) GetAlertByName(name string) (models.AlertConfig, error) { + m.mu.Lock() + defer m.mu.Unlock() + for _, a := range m.alerts { + if a.Name == name { + return a, nil + } + } + return models.AlertConfig{}, fmt.Errorf("alert %q not found", name) +} + +func (m *mockStore) IsMonitorInMaintenance(id int) (bool, error) { + m.mu.Lock() + defer m.mu.Unlock() + return m.maintenance[id], nil +} + +func (m *mockStore) SaveCheck(siteID int, latencyNs int64, isUp bool) error { + m.mu.Lock() + defer m.mu.Unlock() + m.savedChecks = append(m.savedChecks, savedCheck{siteID, latencyNs, isUp}) + return nil +} + +func (m *mockStore) SaveLog(msg string) error { + m.mu.Lock() + defer m.mu.Unlock() + m.savedLogs = append(m.savedLogs, msg) + return nil +} + +func (m *mockStore) LoadLogs(limit int) ([]string, error) { + return m.logs, nil +} + +func (m *mockStore) LoadAllHistory(limit int) (map[int][]models.CheckRecord, error) { + return m.history, nil +} + +// --- Helpers --- + +func newTestEngine(ms *mockStore) *Engine { + return NewEngine(ms) +} + +func 
injectSite(e *Engine, site models.Site) { + e.mu.Lock() + e.liveState[site.ID] = site + e.addToTokenIndex(site) + e.mu.Unlock() +} + +func getSite(e *Engine, id int) (models.Site, bool) { + e.mu.RLock() + defer e.mu.RUnlock() + s, ok := e.liveState[id] + return s, ok +} + +func waitAsync() { + time.Sleep(50 * time.Millisecond) +} + +func (m *mockStore) getAlertCallsSnapshot() []int { + m.mu.Lock() + defer m.mu.Unlock() + cp := make([]int, len(m.getAlertCalls)) + copy(cp, m.getAlertCalls) + return cp +} + +// --- Group 1: State Machine --- + +func TestHandleStatusChange_PendingToUp(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + site := models.Site{ID: 1, Name: "test", Status: "PENDING", MaxRetries: 3, AlertID: 1} + injectSite(e, site) + + e.handleStatusChange(site, "UP", 200, 10*time.Millisecond) + + s, _ := getSite(e, 1) + if s.Status != "UP" { + t.Errorf("expected UP, got %s", s.Status) + } + if s.FailureCount != 0 { + t.Errorf("expected FailureCount 0, got %d", s.FailureCount) + } + waitAsync() + if len(ms.getAlertCallsSnapshot()) != 0 { + t.Error("expected no alert for PENDING→UP") + } +} + +func TestHandleStatusChange_UpIncrementFailure(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + site := models.Site{ID: 1, Name: "test", Status: "UP", MaxRetries: 3, FailureCount: 0} + injectSite(e, site) + + e.handleStatusChange(site, "DOWN", 500, 0) + + s, _ := getSite(e, 1) + if s.Status != "UP" { + t.Errorf("expected UP (under retry threshold), got %s", s.Status) + } + if s.FailureCount != 1 { + t.Errorf("expected FailureCount 1, got %d", s.FailureCount) + } +} + +func TestHandleStatusChange_UpToDown_ExceedsRetries(t *testing.T) { + ms := newMockStore() + ms.alerts[1] = models.AlertConfig{ID: 1, Name: "discord", Type: "webhook", Settings: map[string]string{"url": "http://example.com"}} + e := newTestEngine(ms) + site := models.Site{ID: 1, Name: "test", Status: "UP", MaxRetries: 2, FailureCount: 2, AlertID: 1} + injectSite(e, site) + + 
e.handleStatusChange(site, "DOWN", 500, 0) + + s, _ := getSite(e, 1) + if s.Status != "DOWN" { + t.Errorf("expected DOWN, got %s", s.Status) + } + if s.FailureCount != 3 { + t.Errorf("expected FailureCount 3, got %d", s.FailureCount) + } + waitAsync() + calls := ms.getAlertCallsSnapshot() + if len(calls) == 0 || calls[0] != 1 { + t.Errorf("expected alert call for alertID 1, got %v", calls) + } +} + +func TestHandleStatusChange_UpToDown_ZeroRetries(t *testing.T) { + ms := newMockStore() + ms.alerts[1] = models.AlertConfig{ID: 1, Name: "test", Type: "webhook", Settings: map[string]string{"url": "http://example.com"}} + e := newTestEngine(ms) + site := models.Site{ID: 1, Name: "test", Status: "UP", MaxRetries: 0, FailureCount: 0, AlertID: 1} + injectSite(e, site) + + e.handleStatusChange(site, "DOWN", 0, 0) + + s, _ := getSite(e, 1) + if s.Status != "DOWN" { + t.Errorf("expected DOWN, got %s", s.Status) + } + waitAsync() + if len(ms.getAlertCallsSnapshot()) == 0 { + t.Error("expected alert on immediate DOWN") + } +} + +func TestHandleStatusChange_DownToUp_Recovery(t *testing.T) { + ms := newMockStore() + ms.alerts[1] = models.AlertConfig{ID: 1, Name: "test", Type: "webhook", Settings: map[string]string{"url": "http://example.com"}} + e := newTestEngine(ms) + site := models.Site{ID: 1, Name: "test", Status: "DOWN", FailureCount: 4, AlertID: 1} + injectSite(e, site) + + e.handleStatusChange(site, "UP", 200, 5*time.Millisecond) + + s, _ := getSite(e, 1) + if s.Status != "UP" { + t.Errorf("expected UP, got %s", s.Status) + } + if s.FailureCount != 0 { + t.Errorf("expected FailureCount 0, got %d", s.FailureCount) + } + waitAsync() + if len(ms.getAlertCallsSnapshot()) == 0 { + t.Error("expected recovery alert") + } +} + +func TestHandleStatusChange_DownStaysDown(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + site := models.Site{ID: 1, Name: "test", Status: "DOWN", MaxRetries: 2, FailureCount: 3} + injectSite(e, site) + + e.handleStatusChange(site, "DOWN", 
0, 0) + + s, _ := getSite(e, 1) + if s.Status != "DOWN" { + t.Errorf("expected DOWN, got %s", s.Status) + } + waitAsync() + if len(ms.getAlertCallsSnapshot()) != 0 { + t.Error("expected no re-alert for already DOWN") + } +} + +func TestHandleStatusChange_SSLExpired(t *testing.T) { + ms := newMockStore() + ms.alerts[1] = models.AlertConfig{ID: 1, Name: "test", Type: "webhook", Settings: map[string]string{"url": "http://example.com"}} + e := newTestEngine(ms) + site := models.Site{ID: 1, Name: "test", Status: "UP", MaxRetries: 0, AlertID: 1} + injectSite(e, site) + + e.handleStatusChange(site, "SSL EXP", 0, 0) + + s, _ := getSite(e, 1) + if s.Status != "SSL EXP" { + t.Errorf("expected SSL EXP, got %s", s.Status) + } + waitAsync() + if len(ms.getAlertCallsSnapshot()) == 0 { + t.Error("expected alert on SSL EXP") + } +} + +func TestHandleStatusChange_AlertSuppressedMaintenance(t *testing.T) { + ms := newMockStore() + ms.maintenance[1] = true + ms.alerts[1] = models.AlertConfig{ID: 1, Name: "test", Type: "webhook", Settings: map[string]string{"url": "http://example.com"}} + e := newTestEngine(ms) + site := models.Site{ID: 1, Name: "test", Status: "UP", MaxRetries: 0, AlertID: 1} + injectSite(e, site) + + e.handleStatusChange(site, "DOWN", 0, 0) + + s, _ := getSite(e, 1) + if s.Status != "DOWN" { + t.Errorf("expected DOWN, got %s", s.Status) + } + waitAsync() + if len(ms.getAlertCallsSnapshot()) != 0 { + t.Error("expected no alert during maintenance") + } + logs := e.GetLogs() + found := false + for _, l := range logs { + if containsStr(l, "suppressed") { + found = true + break + } + } + if !found { + t.Error("expected log mentioning suppressed") + } +} + +func TestHandleStatusChange_RecoverySuppressedMaintenance(t *testing.T) { + ms := newMockStore() + ms.maintenance[1] = true + ms.alerts[1] = models.AlertConfig{ID: 1, Name: "test", Type: "webhook", Settings: map[string]string{"url": "http://example.com"}} + e := newTestEngine(ms) + site := models.Site{ID: 1, Name: 
"test", Status: "DOWN", AlertID: 1} + injectSite(e, site) + + e.handleStatusChange(site, "UP", 200, 0) + + s, _ := getSite(e, 1) + if s.Status != "UP" { + t.Errorf("expected UP, got %s", s.Status) + } + waitAsync() + if len(ms.getAlertCallsSnapshot()) != 0 { + t.Error("expected no alert during maintenance recovery") + } +} + +func TestHandleStatusChange_SSLWarning(t *testing.T) { + ms := newMockStore() + ms.alerts[1] = models.AlertConfig{ID: 1, Name: "test", Type: "webhook", Settings: map[string]string{"url": "http://example.com"}} + e := newTestEngine(ms) + site := models.Site{ + ID: 1, Name: "test", Status: "UP", Type: "http", + CheckSSL: true, HasSSL: true, ExpiryThreshold: 30, + SentSSLWarning: false, AlertID: 1, + CertExpiry: time.Now().Add(15 * 24 * time.Hour), + } + injectSite(e, site) + + e.handleStatusChange(site, "UP", 200, 0) + + s, _ := getSite(e, 1) + if !s.SentSSLWarning { + t.Error("expected SentSSLWarning=true") + } + waitAsync() + if len(ms.getAlertCallsSnapshot()) == 0 { + t.Error("expected SSL warning alert") + } +} + +func TestHandleStatusChange_SSLWarningNotRepeated(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + site := models.Site{ + ID: 1, Name: "test", Status: "UP", Type: "http", + CheckSSL: true, HasSSL: true, ExpiryThreshold: 30, + SentSSLWarning: true, AlertID: 1, + CertExpiry: time.Now().Add(15 * 24 * time.Hour), + } + injectSite(e, site) + + e.handleStatusChange(site, "UP", 200, 0) + + waitAsync() + if len(ms.getAlertCallsSnapshot()) != 0 { + t.Error("expected no repeat SSL warning") + } +} + +func TestHandleStatusChange_SSLWarningReset(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + site := models.Site{ + ID: 1, Name: "test", Status: "UP", Type: "http", + CheckSSL: true, HasSSL: true, ExpiryThreshold: 30, + SentSSLWarning: true, + CertExpiry: time.Now().Add(60 * 24 * time.Hour), + } + injectSite(e, site) + + e.handleStatusChange(site, "UP", 200, 0) + + s, _ := getSite(e, 1) + if s.SentSSLWarning { + 
t.Error("expected SentSSLWarning reset to false") + } +} + +func TestHandleStatusChange_SSLWarningSuppressedMaint(t *testing.T) { + ms := newMockStore() + ms.maintenance[1] = true + ms.alerts[1] = models.AlertConfig{ID: 1, Name: "test", Type: "webhook", Settings: map[string]string{"url": "http://example.com"}} + e := newTestEngine(ms) + site := models.Site{ + ID: 1, Name: "test", Status: "UP", Type: "http", + CheckSSL: true, HasSSL: true, ExpiryThreshold: 30, + SentSSLWarning: false, AlertID: 1, + CertExpiry: time.Now().Add(15 * 24 * time.Hour), + } + injectSite(e, site) + + e.handleStatusChange(site, "UP", 200, 0) + + s, _ := getSite(e, 1) + if !s.SentSSLWarning { + t.Error("expected SentSSLWarning=true even in maintenance") + } + waitAsync() + if len(ms.getAlertCallsSnapshot()) != 0 { + t.Error("expected no alert during maintenance") + } +} + +func TestHandleStatusChange_InactiveEngine(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + site := models.Site{ID: 1, Name: "test", Status: "UP", MaxRetries: 0} + injectSite(e, site) + e.SetActive(false) + + e.handleStatusChange(site, "DOWN", 0, 0) + + s, _ := getSite(e, 1) + if s.Status != "UP" { + t.Error("expected no state change when inactive") + } +} + +// --- Group 2: Heartbeat --- + +func TestRecordHeartbeat_ValidToken(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + site := models.Site{ID: 1, Name: "push-test", Type: "push", Token: "abc123", Status: "UP"} + injectSite(e, site) + + if !e.RecordHeartbeat("abc123") { + t.Error("expected true for valid token") + } + + s, _ := getSite(e, 1) + if s.Status != "UP" { + t.Errorf("expected UP, got %s", s.Status) + } + if time.Since(s.LastCheck) > time.Second { + t.Error("expected LastCheck to be recent") + } +} + +func TestRecordHeartbeat_RecoveryFromDown(t *testing.T) { + ms := newMockStore() + ms.alerts[1] = models.AlertConfig{ID: 1, Name: "test", Type: "webhook", Settings: map[string]string{"url": "http://example.com"}} + e := 
newTestEngine(ms) + site := models.Site{ID: 1, Name: "push-test", Type: "push", Token: "abc123", Status: "DOWN", AlertID: 1, FailureCount: 3} + injectSite(e, site) + + if !e.RecordHeartbeat("abc123") { + t.Error("expected true") + } + + s, _ := getSite(e, 1) + if s.Status != "UP" { + t.Errorf("expected UP, got %s", s.Status) + } + if s.FailureCount != 0 { + t.Errorf("expected FailureCount 0, got %d", s.FailureCount) + } + waitAsync() + if len(ms.getAlertCallsSnapshot()) == 0 { + t.Error("expected recovery alert") + } +} + +func TestRecordHeartbeat_UnknownToken(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + + if e.RecordHeartbeat("unknown") { + t.Error("expected false for unknown token") + } +} + +func TestRecordHeartbeat_InactiveEngine(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + site := models.Site{ID: 1, Type: "push", Token: "abc123", Status: "UP"} + injectSite(e, site) + e.SetActive(false) + + if e.RecordHeartbeat("abc123") { + t.Error("expected false when inactive") + } +} + +// --- Group 3: Push Deadline --- + +func TestCheckPush_DeadlineMissed(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + site := models.Site{ + ID: 1, Name: "push", Type: "push", Status: "UP", + Interval: 10, MaxRetries: 0, + LastCheck: time.Now().Add(-20 * time.Second), + } + injectSite(e, site) + + e.checkPush(site) + + s, _ := getSite(e, 1) + if s.Status != "DOWN" { + t.Errorf("expected DOWN after missed deadline, got %s", s.Status) + } +} + +func TestCheckPush_WithinDeadline(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + site := models.Site{ + ID: 1, Name: "push", Type: "push", Status: "UP", + Interval: 60, LastCheck: time.Now(), + } + injectSite(e, site) + + e.checkPush(site) + + s, _ := getSite(e, 1) + if s.Status != "UP" { + t.Errorf("expected UP, got %s", s.Status) + } +} + +func TestCheckPush_PendingToUp(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + site := models.Site{ + ID: 1, Name: 
"push", Type: "push", Status: "PENDING", + Interval: 60, LastCheck: time.Now(), + } + injectSite(e, site) + + e.checkPush(site) + + s, _ := getSite(e, 1) + if s.Status != "UP" { + t.Errorf("expected UP, got %s", s.Status) + } +} + +// --- Group 4: Group Checks --- + +func TestCheckGroup_AllChildrenUp(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + group := models.Site{ID: 1, Name: "group", Type: "group", Status: "PENDING"} + child1 := models.Site{ID: 2, Name: "child1", Type: "http", ParentID: 1, Status: "UP"} + child2 := models.Site{ID: 3, Name: "child2", Type: "http", ParentID: 1, Status: "UP"} + injectSite(e, group) + injectSite(e, child1) + injectSite(e, child2) + + e.checkGroup(group) + + s, _ := getSite(e, 1) + if s.Status != "UP" { + t.Errorf("expected group UP, got %s", s.Status) + } +} + +func TestCheckGroup_OneChildDown(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + group := models.Site{ID: 1, Name: "group", Type: "group", Status: "UP"} + child1 := models.Site{ID: 2, Name: "child1", Type: "http", ParentID: 1, Status: "UP"} + child2 := models.Site{ID: 3, Name: "child2", Type: "http", ParentID: 1, Status: "DOWN"} + injectSite(e, group) + injectSite(e, child1) + injectSite(e, child2) + + e.checkGroup(group) + + s, _ := getSite(e, 1) + if s.Status != "DOWN" { + t.Errorf("expected group DOWN, got %s", s.Status) + } +} + +func TestCheckGroup_PausedChildIgnored(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + group := models.Site{ID: 1, Name: "group", Type: "group"} + child1 := models.Site{ID: 2, Name: "child1", Type: "http", ParentID: 1, Status: "UP"} + child2 := models.Site{ID: 3, Name: "child2", Type: "http", ParentID: 1, Status: "DOWN", Paused: true} + injectSite(e, group) + injectSite(e, child1) + injectSite(e, child2) + + e.checkGroup(group) + + s, _ := getSite(e, 1) + if s.Status != "UP" { + t.Errorf("expected UP (paused child ignored), got %s", s.Status) + } +} + +func 
TestCheckGroup_MaintenanceChildIgnored(t *testing.T) { + ms := newMockStore() + ms.maintenance[3] = true + e := newTestEngine(ms) + group := models.Site{ID: 1, Name: "group", Type: "group"} + child1 := models.Site{ID: 2, Name: "child1", Type: "http", ParentID: 1, Status: "UP"} + child2 := models.Site{ID: 3, Name: "child2", Type: "http", ParentID: 1, Status: "DOWN"} + injectSite(e, group) + injectSite(e, child1) + injectSite(e, child2) + + e.checkGroup(group) + + s, _ := getSite(e, 1) + if s.Status != "UP" { + t.Errorf("expected UP (maint child ignored), got %s", s.Status) + } +} + +func TestCheckGroup_NoChildren(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + group := models.Site{ID: 1, Name: "group", Type: "group", Status: "UP"} + injectSite(e, group) + + e.checkGroup(group) + + s, _ := getSite(e, 1) + if s.Status != "PENDING" { + t.Errorf("expected PENDING for no children, got %s", s.Status) + } +} + +// --- Group 5: History --- + +func TestRecordCheck_Appends(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + + e.recordCheck(1, 5*time.Millisecond, true) + + h, ok := e.GetHistory(1) + if !ok { + t.Fatal("expected history for site 1") + } + if h.TotalChecks != 1 || h.UpChecks != 1 { + t.Errorf("expected 1/1, got %d/%d", h.TotalChecks, h.UpChecks) + } + if len(h.Latencies) != 1 || h.Latencies[0] != 5*time.Millisecond { + t.Errorf("unexpected latencies: %v", h.Latencies) + } +} + +func TestRecordCheck_RollingWindow(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + + for i := 0; i < 65; i++ { + e.recordCheck(1, time.Duration(i)*time.Millisecond, i%2 == 0) + } + + h, _ := e.GetHistory(1) + if len(h.Latencies) != 60 { + t.Errorf("expected 60 latencies, got %d", len(h.Latencies)) + } + if len(h.Statuses) != 60 { + t.Errorf("expected 60 statuses, got %d", len(h.Statuses)) + } + if h.TotalChecks != 65 { + t.Errorf("expected TotalChecks 65, got %d", h.TotalChecks) + } +} + +func TestGetHistory_DeepCopy(t *testing.T) { + ms 
:= newMockStore() + e := newTestEngine(ms) + e.recordCheck(1, 5*time.Millisecond, true) + + h1, _ := e.GetHistory(1) + h1.Latencies[0] = 999 * time.Second + h1.TotalChecks = 999 + + h2, _ := e.GetHistory(1) + if h2.Latencies[0] == 999*time.Second { + t.Error("GetHistory returned reference, not copy") + } + if h2.TotalChecks == 999 { + t.Error("GetHistory returned reference, not copy") + } +} + +func TestInitHistory_LoadsFromDB(t *testing.T) { + ms := newMockStore() + ms.history[1] = []models.CheckRecord{ + {SiteID: 1, LatencyNs: 5000000, IsUp: true}, + {SiteID: 1, LatencyNs: 3000000, IsUp: false}, + } + e := newTestEngine(ms) + e.InitHistory() + + h, ok := e.GetHistory(1) + if !ok { + t.Fatal("expected history for site 1") + } + if h.TotalChecks != 2 { + t.Errorf("expected TotalChecks 2, got %d", h.TotalChecks) + } + if h.UpChecks != 1 { + t.Errorf("expected UpChecks 1, got %d", h.UpChecks) + } +} + +// --- Group 6: State Management --- + +func TestUpdateSiteConfig_PreservesRuntime(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + site := models.Site{ID: 1, Name: "test", URL: "http://old.com", Status: "DOWN", FailureCount: 3, Latency: 100 * time.Millisecond} + injectSite(e, site) + + updated := models.Site{ID: 1, Name: "test", URL: "http://new.com", Interval: 60} + e.UpdateSiteConfig(updated) + + s, _ := getSite(e, 1) + if s.URL != "http://new.com" { + t.Errorf("expected URL updated, got %s", s.URL) + } + if s.Status != "DOWN" { + t.Errorf("expected Status preserved, got %s", s.Status) + } + if s.FailureCount != 3 { + t.Errorf("expected FailureCount preserved, got %d", s.FailureCount) + } + if s.Latency != 100*time.Millisecond { + t.Errorf("expected Latency preserved, got %v", s.Latency) + } +} + +func TestRemoveSite_CleansUp(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + site := models.Site{ID: 1, Name: "test", Type: "push", Token: "tok1", Status: "UP"} + injectSite(e, site) + e.recordCheck(1, 5*time.Millisecond, true) + + 
e.RemoveSite(1) + + if _, ok := getSite(e, 1); ok { + t.Error("expected site removed from liveState") + } + if e.RecordHeartbeat("tok1") { + t.Error("expected token removed from index") + } + if _, ok := e.GetHistory(1); ok { + t.Error("expected history removed") + } +} + +func TestToggleSitePause(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + site := models.Site{ID: 1, Name: "test", Status: "UP"} + injectSite(e, site) + + paused := e.ToggleSitePause(1) + if !paused { + t.Error("expected paused=true after first toggle") + } + s, _ := getSite(e, 1) + if !s.Paused { + t.Error("expected Paused=true in state") + } + + paused = e.ToggleSitePause(1) + if paused { + t.Error("expected paused=false after second toggle") + } +} + +func TestToggleSitePause_NonexistentSite(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + if e.ToggleSitePause(999) { + t.Error("expected false for nonexistent site") + } +} + +func TestGetAllSites_ReturnsCopy(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + injectSite(e, models.Site{ID: 1, Name: "s1", Status: "UP"}) + injectSite(e, models.Site{ID: 2, Name: "s2", Status: "DOWN"}) + + sites := e.GetAllSites() + if len(sites) != 2 { + t.Fatalf("expected 2 sites, got %d", len(sites)) + } + sites[0].Name = "mutated" + + fresh := e.GetAllSites() + for _, s := range fresh { + if s.Name == "mutated" { + t.Error("GetAllSites returned reference, not copy") + } + } +} + +func TestGetLiveState_ReturnsCopy(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + injectSite(e, models.Site{ID: 1, Name: "s1", Status: "UP"}) + + state := e.GetLiveState() + state[1] = models.Site{Name: "mutated"} + + fresh := e.GetLiveState() + if fresh[1].Name == "mutated" { + t.Error("GetLiveState returned reference, not copy") + } +} + +// --- Group 7: Logs --- + +func TestAddLog_PrependAndCap(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + + for i := 0; i < 105; i++ { + 
e.AddLog(fmt.Sprintf("log-%d", i)) + } + + logs := e.GetLogs() + if len(logs) != 100 { + t.Errorf("expected 100 logs, got %d", len(logs)) + } + if !containsStr(logs[0], "log-104") { + t.Errorf("expected newest log first, got %s", logs[0]) + } +} + +func TestInitLogs_LoadsFromDB(t *testing.T) { + ms := newMockStore() + ms.logs = []string{"old-log-1", "old-log-2"} + e := newTestEngine(ms) + e.InitLogs() + + logs := e.GetLogs() + if len(logs) != 2 { + t.Errorf("expected 2 logs, got %d", len(logs)) + } +} + +// --- Group 8: Probe Aggregation --- + +func TestAggregateStatus_AnyDown(t *testing.T) { + results := []NodeResult{ + {IsUp: true, LatencyNs: 100}, + {IsUp: false, LatencyNs: 200}, + } + isUp, _ := AggregateStatus(results, AggAnyDown) + if isUp { + t.Error("AggAnyDown: expected DOWN when any node is down") + } +} + +func TestAggregateStatus_AnyDown_AllUp(t *testing.T) { + results := []NodeResult{ + {IsUp: true, LatencyNs: 100}, + {IsUp: true, LatencyNs: 200}, + } + isUp, _ := AggregateStatus(results, AggAnyDown) + if !isUp { + t.Error("AggAnyDown: expected UP when all nodes up") + } +} + +func TestAggregateStatus_Majority(t *testing.T) { + results := []NodeResult{ + {IsUp: true, LatencyNs: 100}, + {IsUp: true, LatencyNs: 200}, + {IsUp: false, LatencyNs: 300}, + } + isUp, _ := AggregateStatus(results, AggMajorityDown) + if !isUp { + t.Error("AggMajority: expected UP when 2/3 are up") + } +} + +func TestAggregateStatus_AllDown(t *testing.T) { + results := []NodeResult{ + {IsUp: false, LatencyNs: 100}, + {IsUp: false, LatencyNs: 200}, + {IsUp: true, LatencyNs: 300}, + } + isUp, _ := AggregateStatus(results, AggAllDown) + if !isUp { + t.Error("AggAllDown: expected UP when at least one node up") + } +} + +func TestAggregateStatus_Empty(t *testing.T) { + isUp, avg := AggregateStatus(nil, AggAnyDown) + if !isUp { + t.Error("expected UP for empty results") + } + if avg != 0 { + t.Errorf("expected 0 avg latency, got %d", avg) + } +} + +func 
TestAggregateStatus_LatencyAverage(t *testing.T) { + results := []NodeResult{ + {IsUp: true, LatencyNs: 100}, + {IsUp: true, LatencyNs: 200}, + {IsUp: true, LatencyNs: 300}, + } + _, avg := AggregateStatus(results, AggAnyDown) + if avg != 200 { + t.Errorf("expected avg 200, got %d", avg) + } +} + +// --- Group 9: Concurrency --- + +func TestConcurrent_RecordHeartbeat(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + for i := 0; i < 10; i++ { + injectSite(e, models.Site{ + ID: i + 1, Type: "push", Token: fmt.Sprintf("tok-%d", i+1), Status: "UP", + }) + } + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func(n int) { + defer wg.Done() + e.RecordHeartbeat(fmt.Sprintf("tok-%d", (n%10)+1)) + }(i) + } + wg.Wait() +} + +func TestConcurrent_HandleStatusChangeAndGetState(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + site := models.Site{ID: 1, Name: "test", Status: "UP", MaxRetries: 100} + injectSite(e, site) + + var wg sync.WaitGroup + for i := 0; i < 50; i++ { + wg.Add(2) + go func() { + defer wg.Done() + e.handleStatusChange(site, "DOWN", 500, 0) + }() + go func() { + defer wg.Done() + e.GetLiveState() + }() + } + wg.Wait() +} + +func TestConcurrent_RecordCheckAndGetHistory(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + + var wg sync.WaitGroup + for i := 0; i < 50; i++ { + wg.Add(2) + go func(n int) { + defer wg.Done() + e.recordCheck(1, time.Duration(n)*time.Millisecond, true) + }(i) + go func() { + defer wg.Done() + e.GetHistory(1) + }() + } + wg.Wait() + + h, ok := e.GetHistory(1) + if !ok { + t.Fatal("expected history") + } + if len(h.Latencies) > maxHistoryLen { + t.Errorf("history exceeded max: %d", len(h.Latencies)) + } +} + +// --- Utilities --- + +func containsStr(s, substr string) bool { + return len(s) >= len(substr) && searchStr(s, substr) +} + +func searchStr(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + 
} + return false +} diff --git a/internal/server/server.go b/internal/server/server.go index e08188d..8b18667 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1,6 +1,7 @@ package server import ( + "crypto/subtle" "encoding/json" "fmt" "go-upkeep/internal/importer" @@ -15,6 +16,10 @@ import ( "strings" ) +func checkSecret(got, want string) bool { + return subtle.ConstantTimeCompare([]byte(got), []byte(want)) == 1 +} + var statusTpl = template.Must(template.New("status").Parse(` @@ -153,7 +158,7 @@ type ServerConfig struct { ClusterKey string // Shared Secret for Security } -func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) { +func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) *http.Server { if cfg.ClusterKey == "" { fmt.Println("WARNING: No UPKEEP_CLUSTER_SECRET set. Cluster API endpoints are unauthenticated.") } @@ -176,7 +181,7 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) { // 2. Health Check (For Cluster Follower) mux.HandleFunc("/api/health", func(w http.ResponseWriter, r *http.Request) { - if cfg.ClusterKey != "" && r.Header.Get("X-Upkeep-Secret") != cfg.ClusterKey { + if cfg.ClusterKey != "" && !checkSecret(r.Header.Get("X-Upkeep-Secret"), cfg.ClusterKey) { http.Error(w, "Unauthorized", 401) return } @@ -186,7 +191,7 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) { // 3. 
Config Export mux.HandleFunc("/api/backup/export", func(w http.ResponseWriter, r *http.Request) { - if cfg.ClusterKey == "" || r.Header.Get("X-Upkeep-Secret") != cfg.ClusterKey { + if cfg.ClusterKey == "" || !checkSecret(r.Header.Get("X-Upkeep-Secret"), cfg.ClusterKey) { http.Error(w, "Unauthorized: UPKEEP_CLUSTER_SECRET required", 401) return } @@ -205,10 +210,11 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) { http.Error(w, "POST required", 405) return } - if cfg.ClusterKey == "" || r.Header.Get("X-Upkeep-Secret") != cfg.ClusterKey { + if cfg.ClusterKey == "" || !checkSecret(r.Header.Get("X-Upkeep-Secret"), cfg.ClusterKey) { http.Error(w, "Unauthorized", 401) return } + r.Body = http.MaxBytesReader(w, r.Body, 1<<20) var data models.Backup if err := json.NewDecoder(r.Body).Decode(&data); err != nil { http.Error(w, "Invalid JSON", 400) @@ -228,10 +234,11 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) { http.Error(w, "POST required", 405) return } - if cfg.ClusterKey == "" || r.Header.Get("X-Upkeep-Secret") != cfg.ClusterKey { + if cfg.ClusterKey == "" || !checkSecret(r.Header.Get("X-Upkeep-Secret"), cfg.ClusterKey) { http.Error(w, "Unauthorized", 401) return } + r.Body = http.MaxBytesReader(w, r.Body, 1<<20) var kb importer.KumaBackup if err := json.NewDecoder(r.Body).Decode(&kb); err != nil { log.Printf("Invalid Kuma JSON: %v", err) @@ -253,10 +260,11 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) { http.Error(w, "POST required", 405) return } - if cfg.ClusterKey == "" || r.Header.Get("X-Upkeep-Secret") != cfg.ClusterKey { + if cfg.ClusterKey == "" || !checkSecret(r.Header.Get("X-Upkeep-Secret"), cfg.ClusterKey) { http.Error(w, "Unauthorized", 401) return } + r.Body = http.MaxBytesReader(w, r.Body, 1<<20) var req struct { ID string `json:"id"` Name string `json:"name"` @@ -283,7 +291,7 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) { // 7. 
Probe Assignment Fetch mux.HandleFunc("/api/probe/assignments", func(w http.ResponseWriter, r *http.Request) { - if cfg.ClusterKey == "" || r.Header.Get("X-Upkeep-Secret") != cfg.ClusterKey { + if cfg.ClusterKey == "" || !checkSecret(r.Header.Get("X-Upkeep-Secret"), cfg.ClusterKey) { http.Error(w, "Unauthorized", 401) return } @@ -324,10 +332,11 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) { http.Error(w, "POST required", 405) return } - if cfg.ClusterKey == "" || r.Header.Get("X-Upkeep-Secret") != cfg.ClusterKey { + if cfg.ClusterKey == "" || !checkSecret(r.Header.Get("X-Upkeep-Secret"), cfg.ClusterKey) { http.Error(w, "Unauthorized", 401) return } + r.Body = http.MaxBytesReader(w, r.Body, 1<<20) var req struct { NodeID string `json:"node_id"` Results []struct { @@ -387,13 +396,15 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) { }) } + addr := fmt.Sprintf(":%d", cfg.Port) + srv := &http.Server{Addr: addr, Handler: mux} go func() { - addr := fmt.Sprintf(":%d", cfg.Port) fmt.Printf("HTTP Server listening on %s\n", addr) - if err := http.ListenAndServe(addr, mux); err != nil { - log.Fatalf("HTTP server failed: %v", err) + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Printf("HTTP server error: %v", err) } }() + return srv } func renderStatusPage(w http.ResponseWriter, title string, eng *monitor.Engine) { diff --git a/internal/server/server_test.go b/internal/server/server_test.go new file mode 100644 index 0000000..56a8d6a --- /dev/null +++ b/internal/server/server_test.go @@ -0,0 +1,553 @@ +package server + +import ( + "bytes" + "encoding/json" + "fmt" + "go-upkeep/internal/models" + "go-upkeep/internal/monitor" + "net" + "net/http" + "sync" + "testing" + "time" +) + +// --- Mock Store --- + +type mockStore struct { + mu sync.Mutex + sites []models.Site + alerts []models.AlertConfig + nodes map[string]models.ProbeNode + importedData *models.Backup + registeredNodes []models.ProbeNode + 
maintWindows []models.MaintenanceWindow +} + +func newMockStore() *mockStore { + return &mockStore{ + nodes: make(map[string]models.ProbeNode), + } +} + +func (m *mockStore) Init() error { return nil } +func (m *mockStore) GetSites() ([]models.Site, error) { return m.sites, nil } +func (m *mockStore) AddSite(models.Site) error { return nil } +func (m *mockStore) UpdateSite(models.Site) error { return nil } +func (m *mockStore) UpdateSitePaused(int, bool) error { return nil } +func (m *mockStore) DeleteSite(int) error { return nil } +func (m *mockStore) GetAllAlerts() ([]models.AlertConfig, error) { return m.alerts, nil } +func (m *mockStore) GetAlert(int) (models.AlertConfig, error) { return models.AlertConfig{}, nil } +func (m *mockStore) AddAlert(string, string, map[string]string) error { return nil } +func (m *mockStore) UpdateAlert(int, string, string, map[string]string) error { return nil } +func (m *mockStore) DeleteAlert(int) error { return nil } +func (m *mockStore) GetAllUsers() ([]models.User, error) { return nil, nil } +func (m *mockStore) AddUser(string, string, string) error { return nil } +func (m *mockStore) UpdateUser(int, string, string, string) error { return nil } +func (m *mockStore) DeleteUser(int) error { return nil } +func (m *mockStore) SaveCheck(int, int64, bool) error { return nil } +func (m *mockStore) SaveCheckFromNode(siteID int, nodeID string, latencyNs int64, isUp bool) error { + return nil +} +func (m *mockStore) LoadAllHistory(int) (map[int][]models.CheckRecord, error) { + return nil, nil +} +func (m *mockStore) GetSiteByName(string) (models.Site, error) { return models.Site{}, nil } +func (m *mockStore) GetAlertByName(string) (models.AlertConfig, error) { + return models.AlertConfig{}, nil +} +func (m *mockStore) AddSiteReturningID(models.Site) (int, error) { return 0, nil } +func (m *mockStore) AddAlertReturningID(string, string, map[string]string) (int, error) { + return 0, nil +} +func (m *mockStore) GetAllNodes() 
([]models.ProbeNode, error) { return nil, nil }
+func (m *mockStore) UpdateNodeLastSeen(string) error { return nil }
+func (m *mockStore) DeleteNode(string) error { return nil }
+func (m *mockStore) SaveLog(string) error { return nil }
+func (m *mockStore) LoadLogs(int) ([]string, error) { return nil, nil }
+func (m *mockStore) GetAllMaintenanceWindows(int) ([]models.MaintenanceWindow, error) {
+	return nil, nil
+}
+func (m *mockStore) AddMaintenanceWindow(models.MaintenanceWindow) error { return nil }
+func (m *mockStore) EndMaintenanceWindow(int) error { return nil }
+func (m *mockStore) DeleteMaintenanceWindow(int) error { return nil }
+func (m *mockStore) IsMonitorInMaintenance(int) (bool, error) { return false, nil }
+func (m *mockStore) GetPreference(string) (string, error) { return "", nil }
+func (m *mockStore) SetPreference(string, string) error { return nil }
+func (m *mockStore) Close() error { return nil }
+
+func (m *mockStore) ExportData() (models.Backup, error) {
+	return models.Backup{
+		Sites:  m.sites,
+		Alerts: m.alerts,
+	}, nil
+}
+
+func (m *mockStore) ImportData(data models.Backup) error {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	m.importedData = &data
+	return nil
+}
+
+func (m *mockStore) RegisterNode(node models.ProbeNode) error {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	m.registeredNodes = append(m.registeredNodes, node)
+	m.nodes[node.ID] = node
+	return nil
+}
+
+func (m *mockStore) GetNode(id string) (models.ProbeNode, error) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	if n, ok := m.nodes[id]; ok {
+		return n, nil
+	}
+	return models.ProbeNode{}, fmt.Errorf("not found")
+}
+
+func (m *mockStore) GetActiveMaintenanceWindows() ([]models.MaintenanceWindow, error) {
+	return m.maintWindows, nil
+}
+
+// --- Helpers ---
+
+// freePort asks the OS for an unused TCP port on the loopback interface by
+// listening on port 0, then releases the listener and returns the port number.
+// It panics with the underlying error when the listen fails, so the cause is
+// reported directly instead of surfacing as a nil-pointer dereference.
+func freePort() int {
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		panic(fmt.Sprintf("freePort: cannot listen on loopback: %v", err))
+	}
+	port := ln.Addr().(*net.TCPAddr).Port
+	// The listener only existed to reserve a port number; a close error is
+	// not actionable here.
+	ln.Close()
+	return port
+}
+
+type testServer struct {
+	baseURL string
+	srv     *http.Server
+	store   *mockStore
+	engine 
*monitor.Engine +} + +func newTestServer(t *testing.T, clusterKey string, enableStatus bool) *testServer { + t.Helper() + ms := newMockStore() + eng := monitor.NewEngine(ms) + port := freePort() + + srv := Start(ServerConfig{ + Port: port, + EnableStatus: enableStatus, + Title: "Test Status", + ClusterKey: clusterKey, + }, ms, eng) + + ts := &testServer{ + baseURL: fmt.Sprintf("http://127.0.0.1:%d", port), + srv: srv, + store: ms, + engine: eng, + } + + // Wait for server to be ready + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + resp, err := http.Get(ts.baseURL + "/api/health") + if err == nil { + resp.Body.Close() + break + } + time.Sleep(10 * time.Millisecond) + } + + t.Cleanup(func() { + srv.Close() + }) + + return ts +} + +func authReq(method, url, secret string, body []byte) (*http.Response, error) { + var req *http.Request + var err error + if body != nil { + req, err = http.NewRequest(method, url, bytes.NewReader(body)) + } else { + req, err = http.NewRequest(method, url, nil) + } + if err != nil { + return nil, err + } + if secret != "" { + req.Header.Set("X-Upkeep-Secret", secret) + } + return http.DefaultClient.Do(req) +} + +// --- Tests --- + +func TestCheckSecret(t *testing.T) { + if !checkSecret("mykey", "mykey") { + t.Error("expected match") + } + if checkSecret("mykey", "wrong") { + t.Error("expected no match") + } + if checkSecret("", "key") { + t.Error("expected no match for empty got") + } +} + +// --- Push Heartbeat --- + +func TestPush_MissingToken(t *testing.T) { + ts := newTestServer(t, "secret", false) + resp, err := http.Get(ts.baseURL + "/api/push") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 400 { + t.Errorf("expected 400, got %d", resp.StatusCode) + } +} + +func TestPush_InvalidToken(t *testing.T) { + ts := newTestServer(t, "secret", false) + resp, err := http.Get(ts.baseURL + "/api/push?token=bad") + if err != nil { + t.Fatal(err) + } + defer 
resp.Body.Close() + if resp.StatusCode != 404 { + t.Errorf("expected 404, got %d", resp.StatusCode) + } +} + +// --- Health --- + +func TestHealth_NoSecret(t *testing.T) { + ts := newTestServer(t, "", false) + resp, err := http.Get(ts.baseURL + "/api/health") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Errorf("expected 200 with no cluster key, got %d", resp.StatusCode) + } +} + +func TestHealth_ValidSecret(t *testing.T) { + ts := newTestServer(t, "secret", false) + resp, err := authReq("GET", ts.baseURL+"/api/health", "secret", nil) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Errorf("expected 200, got %d", resp.StatusCode) + } +} + +func TestHealth_WrongSecret(t *testing.T) { + ts := newTestServer(t, "secret", false) + resp, err := authReq("GET", ts.baseURL+"/api/health", "wrong", nil) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 401 { + t.Errorf("expected 401, got %d", resp.StatusCode) + } +} + +// --- Backup Export --- + +func TestExport_Unauthorized_NoKey(t *testing.T) { + ts := newTestServer(t, "", false) + resp, err := http.Get(ts.baseURL + "/api/backup/export") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 401 { + t.Errorf("expected 401 when no cluster key configured, got %d", resp.StatusCode) + } +} + +func TestExport_Unauthorized_WrongKey(t *testing.T) { + ts := newTestServer(t, "secret", false) + resp, err := authReq("GET", ts.baseURL+"/api/backup/export", "wrong", nil) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 401 { + t.Errorf("expected 401, got %d", resp.StatusCode) + } +} + +func TestExport_Success(t *testing.T) { + ts := newTestServer(t, "secret", false) + ts.store.sites = []models.Site{{ID: 1, Name: "example", URL: "http://example.com"}} + + resp, err := authReq("GET", ts.baseURL+"/api/backup/export", "secret", nil) + if 
err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + var backup models.Backup + json.NewDecoder(resp.Body).Decode(&backup) + if len(backup.Sites) != 1 { + t.Errorf("expected 1 site, got %d", len(backup.Sites)) + } +} + +// --- Backup Import --- + +func TestImport_MethodNotAllowed(t *testing.T) { + ts := newTestServer(t, "secret", false) + resp, err := authReq("GET", ts.baseURL+"/api/backup/import", "secret", nil) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 405 { + t.Errorf("expected 405, got %d", resp.StatusCode) + } +} + +func TestImport_Unauthorized(t *testing.T) { + ts := newTestServer(t, "secret", false) + body, _ := json.Marshal(models.Backup{}) + resp, err := authReq("POST", ts.baseURL+"/api/backup/import", "wrong", body) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 401 { + t.Errorf("expected 401, got %d", resp.StatusCode) + } +} + +func TestImport_Success(t *testing.T) { + ts := newTestServer(t, "secret", false) + backup := models.Backup{ + Sites: []models.Site{{Name: "imported", URL: "http://example.com"}}, + } + body, _ := json.Marshal(backup) + resp, err := authReq("POST", ts.baseURL+"/api/backup/import", "secret", body) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + ts.store.mu.Lock() + defer ts.store.mu.Unlock() + if ts.store.importedData == nil { + t.Error("expected import data to be stored") + } +} + +func TestImport_InvalidJSON(t *testing.T) { + ts := newTestServer(t, "secret", false) + resp, err := authReq("POST", ts.baseURL+"/api/backup/import", "secret", []byte("not json")) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 400 { + t.Errorf("expected 400, got %d", resp.StatusCode) + } +} + +// --- Probe Registration --- + +func 
TestProbeRegister_Success(t *testing.T) { + ts := newTestServer(t, "secret", false) + body, _ := json.Marshal(map[string]string{ + "id": "node-1", "name": "US East", "region": "us-east", + }) + resp, err := authReq("POST", ts.baseURL+"/api/probe/register", "secret", body) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + ts.store.mu.Lock() + defer ts.store.mu.Unlock() + if len(ts.store.registeredNodes) != 1 { + t.Errorf("expected 1 registered node, got %d", len(ts.store.registeredNodes)) + } + if ts.store.registeredNodes[0].ID != "node-1" { + t.Errorf("expected node-1, got %s", ts.store.registeredNodes[0].ID) + } +} + +func TestProbeRegister_MissingID(t *testing.T) { + ts := newTestServer(t, "secret", false) + body, _ := json.Marshal(map[string]string{"name": "test"}) + resp, err := authReq("POST", ts.baseURL+"/api/probe/register", "secret", body) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 400 { + t.Errorf("expected 400, got %d", resp.StatusCode) + } +} + +func TestProbeRegister_Unauthorized(t *testing.T) { + ts := newTestServer(t, "secret", false) + body, _ := json.Marshal(map[string]string{"id": "node-1"}) + resp, err := authReq("POST", ts.baseURL+"/api/probe/register", "wrong", body) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 401 { + t.Errorf("expected 401, got %d", resp.StatusCode) + } +} + +// --- Probe Results --- + +func TestProbeResults_Success(t *testing.T) { + ts := newTestServer(t, "secret", false) + body, _ := json.Marshal(map[string]any{ + "node_id": "node-1", + "results": []map[string]any{ + {"site_id": 1, "latency_ns": 5000000, "is_up": true}, + }, + }) + resp, err := authReq("POST", ts.baseURL+"/api/probe/results", "secret", body) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Errorf("expected 200, got %d", 
resp.StatusCode) + } +} + +func TestProbeResults_MissingNodeID(t *testing.T) { + ts := newTestServer(t, "secret", false) + body, _ := json.Marshal(map[string]any{ + "results": []map[string]any{}, + }) + resp, err := authReq("POST", ts.baseURL+"/api/probe/results", "secret", body) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 400 { + t.Errorf("expected 400, got %d", resp.StatusCode) + } +} + +// --- Status Page --- + +func TestStatusPage_Enabled(t *testing.T) { + ts := newTestServer(t, "secret", true) + resp, err := http.Get(ts.baseURL + "/status") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Errorf("expected 200, got %d", resp.StatusCode) + } +} + +func TestStatusJSON_TokensStripped(t *testing.T) { + ts := newTestServer(t, "secret", true) + + // Inject a site with a token into engine state + ts.engine.UpdateSiteConfig(models.Site{ID: 1, Name: "test", Type: "push", Token: "secret-token", Status: "UP"}) + // Need to inject directly since UpdateSiteConfig only updates existing + func() { + ts.engine.RecordHeartbeat("unused") // just to exercise, won't match + }() + + resp, err := http.Get(ts.baseURL + "/status/json") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + var state map[string]models.Site + json.NewDecoder(resp.Body).Decode(&state) + for _, site := range state { + if site.Token != "" { + t.Error("expected token stripped from status JSON response") + } + } +} + +func TestStatusJSON_MaintenanceOverride(t *testing.T) { + ts := newTestServer(t, "secret", true) + ts.store.maintWindows = []models.MaintenanceWindow{ + {ID: 1, MonitorID: 0, Type: "maintenance", StartTime: time.Now().Add(-1 * time.Hour)}, + } + + resp, err := http.Get(ts.baseURL + "/status/json") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Errorf("expected 
200, got %d", resp.StatusCode) + } +} + +func TestStatusPage_Disabled(t *testing.T) { + ts := newTestServer(t, "secret", false) + resp, err := http.Get(ts.baseURL + "/status") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 404 { + t.Errorf("expected 404 when status disabled, got %d", resp.StatusCode) + } +} + +// --- Probe Assignments --- + +func TestProbeAssignments_Success(t *testing.T) { + ts := newTestServer(t, "secret", false) + resp, err := authReq("GET", ts.baseURL+"/api/probe/assignments", "secret", nil) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + var result map[string][]models.Site + json.NewDecoder(resp.Body).Decode(&result) + if _, ok := result["sites"]; !ok { + t.Error("expected 'sites' key in response") + } +} + +func TestProbeAssignments_Unauthorized(t *testing.T) { + ts := newTestServer(t, "secret", false) + resp, err := authReq("GET", ts.baseURL+"/api/probe/assignments", "wrong", nil) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 401 { + t.Errorf("expected 401, got %d", resp.StatusCode) + } +} diff --git a/internal/store/sqlstore.go b/internal/store/sqlstore.go index ebb0aa1..63bc096 100644 --- a/internal/store/sqlstore.go +++ b/internal/store/sqlstore.go @@ -29,12 +29,16 @@ func (s *SQLStore) q(query string) string { return rewritePlaceholders(query, s.dollar) } -func generateToken() string { +func generateToken() (string, error) { b := make([]byte, 16) if _, err := rand.Read(b); err != nil { - panic("crypto/rand failed: " + err.Error()) + return "", fmt.Errorf("crypto/rand failed: %w", err) } - return hex.EncodeToString(b) + return hex.EncodeToString(b), nil +} + +func (s *SQLStore) Close() error { + return s.db.Close() } func (s *SQLStore) Init() error { @@ -77,7 +81,11 @@ func (s *SQLStore) GetSites() ([]models.Site, error) { func (s *SQLStore) AddSite(site 
models.Site) error { token := "" if site.Type == "push" { - token = generateToken() + var err error + token, err = generateToken() + if err != nil { + return fmt.Errorf("generate push token: %w", err) + } } _, err := s.db.Exec(s.q("INSERT INTO sites (name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused, regions) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"), site.Name, site.URL, site.Type, token, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries, @@ -89,7 +97,11 @@ func (s *SQLStore) UpdateSite(site models.Site) error { var existingToken string s.db.QueryRow(s.q("SELECT token FROM sites WHERE id=?"), site.ID).Scan(&existingToken) if site.Type == "push" && existingToken == "" { - existingToken = generateToken() + var err error + existingToken, err = generateToken() + if err != nil { + return fmt.Errorf("generate push token: %w", err) + } } _, err := s.db.Exec(s.q("UPDATE sites SET name=?, url=?, type=?, token=?, interval=?, alert_id=?, check_ssl=?, threshold=?, max_retries=?, hostname=?, port=?, timeout=?, method=?, description=?, parent_id=?, accepted_codes=?, dns_resolve_type=?, dns_server=?, ignore_tls=?, paused=?, regions=? 
WHERE id=?"), site.Name, site.URL, site.Type, existingToken, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries, @@ -132,7 +144,9 @@ func (s *SQLStore) GetAlertByName(name string) (models.AlertConfig, error) { if err != nil { return a, err } - json.Unmarshal([]byte(settingsJSON), &a.Settings) + if err := json.Unmarshal([]byte(settingsJSON), &a.Settings); err != nil { + return a, fmt.Errorf("unmarshal alert settings: %w", err) + } return a, nil } @@ -171,7 +185,9 @@ func (s *SQLStore) GetAllAlerts() ([]models.AlertConfig, error) { if err := rows.Scan(&a.ID, &a.Name, &a.Type, &settingsJSON); err != nil { return alerts, err } - json.Unmarshal([]byte(settingsJSON), &a.Settings) + if err := json.Unmarshal([]byte(settingsJSON), &a.Settings); err != nil { + return alerts, fmt.Errorf("unmarshal alert settings for %q: %w", a.Name, err) + } alerts = append(alerts, a) } return alerts, rows.Err() @@ -184,7 +200,9 @@ func (s *SQLStore) GetAlert(id int) (models.AlertConfig, error) { if err != nil { return a, err } - json.Unmarshal([]byte(settingsJSON), &a.Settings) + if err := json.Unmarshal([]byte(settingsJSON), &a.Settings); err != nil { + return a, fmt.Errorf("unmarshal alert settings: %w", err) + } return a, nil } diff --git a/internal/store/store.go b/internal/store/store.go index d83ca72..be562b0 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -64,4 +64,7 @@ type Store interface { // Backup & Restore ExportData() (models.Backup, error) ImportData(data models.Backup) error + + // Lifecycle + Close() error }