From 70a83a1da9cc612608930098a7839d16863591c6 Mon Sep 17 00:00:00 2001 From: Tyler Koenig Date: Thu, 11 Jun 2026 14:40:30 -0400 Subject: [PATCH 1/2] refactor(store): propagate context.Context through all Store methods Every Store interface method (except Close) now takes context.Context as first parameter. All 54 db.Query/Exec/QueryRow calls in SQLStore replaced with their *Context variants. DB operations now respect cancellation and deadlines. Context sources by caller: - Engine dbWriter/poll/pruner: engine ctx from Start() - HTTP handlers: r.Context() - config.Apply/Export: caller-provided ctx - TUI/main.go init: context.Background() RunCheck and all sub-checks (HTTP/ping/port/DNS) accept parent ctx. HTTP checks now inherit shutdown cancellation instead of rooting in context.Background(). dbWrite.exec takes ctx so the writer goroutine can cancel stuck DB operations. DeleteSite/ImportData use BeginTx(ctx) instead of Begin(). --- cmd/uptop/keycache_test.go | 7 +- cmd/uptop/main.go | 52 +++---- internal/cluster/cluster_test.go | 120 +++++++++------ internal/cluster/probe.go | 2 +- internal/config/apply.go | 34 +++-- internal/config/apply_test.go | 56 +++---- internal/config/export.go | 7 +- internal/config/export_test.go | 40 ++--- internal/metrics/prometheus_test.go | 122 +++++++++------- internal/monitor/checker.go | 20 +-- internal/monitor/checker_test.go | 21 +-- internal/monitor/dbwriter.go | 28 ++-- internal/monitor/history.go | 7 +- internal/monitor/monitor.go | 60 ++++---- internal/monitor/monitor_test.go | 140 +++++++++--------- internal/server/server.go | 14 +- internal/server/server_test.go | 122 +++++++++------- internal/store/sqlstore.go | 219 ++++++++++++++-------------- internal/store/sqlstore_test.go | 131 ++++++++--------- internal/store/store.go | 97 ++++++------ internal/tui/data.go | 12 +- internal/tui/tab_alerts.go | 5 +- internal/tui/tab_maint.go | 3 +- internal/tui/tab_sites.go | 5 +- internal/tui/tab_users.go | 5 +- internal/tui/tui.go | 3 +- internal/tui/update.go | 17 ++- internal/tui/update_test.go | 141 +++++++++++------- 28 files changed, 813 insertions(+), 677 deletions(-) diff --git a/cmd/uptop/keycache_test.go b/cmd/uptop/keycache_test.go index 23df5ac..dc10640 100644 --- a/cmd/uptop/keycache_test.go +++ b/cmd/uptop/keycache_test.go @@ -1,6 +1,7 @@ package main import ( + "context" "crypto/ed25519" "crypto/rand" "errors" @@ -22,8 +23,8 @@ type kcMockStore struct { err error } -func (m *kcMockStore) GetAllUsers() ([]models.User, error) { return m.users, m.err } -func (m *kcMockStore) DeleteUser(int) error { return nil } +func (m *kcMockStore) GetAllUsers(_ context.Context) ([]models.User, error) { return m.users, m.err } +func (m *kcMockStore) DeleteUser(_ context.Context, _ int) error { return nil } func testKey(t *testing.T) (string, ssh.PublicKey) { t.Helper() @@ -103,7 +104,7 @@ func TestUserInvalidatingStore_DeleteDropsKeyCache(t *testing.T) { // Revoke the user; DB unreachable immediately after. The cached key must // be gone the moment the delete returns. - if err := s.DeleteUser(1); err != nil { + if err := s.DeleteUser(context.Background(), 1); err != nil { t.Fatal(err) } ms.users = nil diff --git a/cmd/uptop/main.go b/cmd/uptop/main.go index 1d5410b..619ea14 100644 --- a/cmd/uptop/main.go +++ b/cmd/uptop/main.go @@ -141,7 +141,7 @@ func openStore(dbType, dsn string) store.Store { } else { fmt.Println("WARNING: No UPTOP_ENCRYPTION_KEY set. Alert credentials stored unencrypted.") } - if err := ss.Init(); err != nil { + if err := ss.Init(context.Background()); err != nil { fmt.Fprintf(os.Stderr, "database init error: %v\n", err) os.Exit(1) } @@ -171,7 +171,7 @@ func runApply(args []string) { os.Exit(1) } - changes, err := config.Apply(s, f, config.ApplyOpts{ + changes, err := config.Apply(context.Background(), s, f, config.ApplyOpts{ DryRun: *dryRun, Prune: *prune, }) @@ -192,7 +192,7 @@ func runExport(args []string) { s := openStore(*dbType, *dsn) - f, err := config.Export(s) + f, err := config.Export(context.Background(), s) if err != nil { fmt.Fprintf(os.Stderr, "error: %v\n", err) os.Exit(1) @@ -231,12 +231,12 @@ func runMigrateSecrets(args []string) { fmt.Fprintf(os.Stderr, "database error: %v\n", err) os.Exit(1) } - if err := ss.Init(); err != nil { + if err := ss.Init(context.Background()); err != nil { fmt.Fprintf(os.Stderr, "database init error: %v\n", err) os.Exit(1) } - alerts, err := ss.GetAllAlerts() + alerts, err := ss.GetAllAlerts(context.Background()) if err != nil { fmt.Fprintf(os.Stderr, "error loading alerts: %v\n", err) os.Exit(1) @@ -245,7 +245,7 @@ func runMigrateSecrets(args []string) { ss.SetEncryptor(enc) migrated := 0 for _, a := range alerts { - if err := ss.UpdateAlert(a.ID, a.Name, a.Type, a.Settings); err != nil { + if err := ss.UpdateAlert(context.Background(), a.ID, a.Name, a.Type, a.Settings); err != nil { fmt.Fprintf(os.Stderr, "error migrating alert %q: %v\n", a.Name, err) os.Exit(1) } @@ -378,7 +378,7 @@ func runServe(args []string) { kc := newKeyCache(ss) var s store.Store = &userInvalidatingStore{Store: ss, kc: kc} - if err := s.Init(); err != nil { + if err := s.Init(context.Background()); err != nil { fmt.Fprintf(os.Stderr, "database init error: %v\n", err) os.Exit(1) } @@ -395,7 +395,7 @@ func runServe(args []string) { os.Exit(1) } backup := importer.ConvertKuma(kb) - if err := s.ImportData(backup); err != nil { + if err := s.ImportData(context.Background(), backup); err != nil { fmt.Fprintf(os.Stderr, "import failed: %v\n", err) os.Exit(1) } @@ -515,21 +515,22 @@ func startSSHServer(port int, db store.Store, eng *monitor.Engine, kc *keyCache) } func seedDemoData(s store.Store) { - existing, _ := s.GetSites() + ctx := context.Background() + existing, _ := s.GetSites(ctx) if len(existing) > 0 { return } fmt.Println("Seeding demo data...") - if err := s.AddAlert("Discord Ops", "discord", map[string]string{"url": "https://discord.com/api/webhooks/demo/token"}); err != nil { + if err := s.AddAlert(ctx, "Discord Ops", "discord", map[string]string{"url": "https://discord.com/api/webhooks/demo/token"}); err != nil { log.Printf("demo seed: add alert: %v", err) return } - if err := s.AddAlert("Slack Infra", "slack", map[string]string{"url": "https://hooks.slack.com/services/DEMO/WEBHOOK"}); err != nil { + if err := s.AddAlert(ctx, "Slack Infra", "slack", map[string]string{"url": "https://hooks.slack.com/services/DEMO/WEBHOOK"}); err != nil { log.Printf("demo seed: add alert: %v", err) return } - if err := s.AddAlert("Email Oncall", "email", map[string]string{ + if err := s.AddAlert(ctx, "Email Oncall", "email", map[string]string{ "host": "smtp.example.com", "port": "587", "user": "oncall@example.com", "pass": "replace-me", "from": "oncall@example.com", "to": "team@example.com", @@ -538,7 +539,7 @@ func seedDemoData(s store.Store) { return } - alerts, _ := s.GetAllAlerts() + alerts, _ := s.GetAllAlerts(ctx) alertID := 0 if len(alerts) > 0 { alertID = alerts[0].ID @@ -557,7 +558,7 @@ func seedDemoData(s store.Store) { {Name: "SSH Server", Type: "port", Interval: 60, AlertID: alertID, Hostname: "10.0.0.1", Port: 22, Timeout: 5, ExpiryThreshold: 7}, } for _, site := range demoSites { - if err := s.AddSite(site); err != nil { + if err := s.AddSite(ctx, site); err != nil { log.Printf("demo seed: add site %q: %v", site.Name, err) } } @@ -576,7 +577,7 @@ func newKeyCache(db store.Store) *keyCache { } func (c *keyCache) refresh() { - users, err := c.db.GetAllUsers() + users, err := c.db.GetAllUsers(context.Background()) if err != nil { // Keep the previous key set: a transient DB error must not lock every // admin out. Revocation still fails closed because Invalidate clears @@ -637,31 +638,32 @@ type userInvalidatingStore struct { kc *keyCache } -func (s *userInvalidatingStore) AddUser(username, publicKey, role string) error { - err := s.Store.AddUser(username, publicKey, role) +func (s *userInvalidatingStore) AddUser(ctx context.Context, username, publicKey, role string) error { + err := s.Store.AddUser(ctx, username, publicKey, role) s.kc.Invalidate() return err } -func (s *userInvalidatingStore) UpdateUser(id int, username, publicKey, role string) error { - err := s.Store.UpdateUser(id, username, publicKey, role) +func (s *userInvalidatingStore) UpdateUser(ctx context.Context, id int, username, publicKey, role string) error { + err := s.Store.UpdateUser(ctx, id, username, publicKey, role) s.kc.Invalidate() return err } -func (s *userInvalidatingStore) DeleteUser(id int) error { - err := s.Store.DeleteUser(id) +func (s *userInvalidatingStore) DeleteUser(ctx context.Context, id int) error { + err := s.Store.DeleteUser(ctx, id) s.kc.Invalidate() return err } -func (s *userInvalidatingStore) ImportData(data models.Backup) error { - err := s.Store.ImportData(data) +func (s *userInvalidatingStore) ImportData(ctx context.Context, data models.Backup) error { + err := s.Store.ImportData(ctx, data) s.kc.Invalidate() return err } func seedKeysFromEnv(s store.Store) { + ctx := context.Background() var keys []string if v := os.Getenv("UPTOP_ADMIN_KEY"); v != "" { @@ -687,7 +689,7 @@ func seedKeysFromEnv(s store.Store) { return } - existing, err := s.GetAllUsers() + existing, err := s.GetAllUsers(ctx) if err != nil { fmt.Fprintf(os.Stderr, "warning: could not check existing users: %v\n", err) return @@ -705,7 +707,7 @@ func seedKeysFromEnv(s store.Store) { } username := usernameFromKey(key, i, len(existing)+added) - if err := s.AddUser(username, key, "admin"); err != nil { + if err := s.AddUser(ctx, username, key, "admin"); err != nil { fmt.Fprintf(os.Stderr, "warning: failed to seed user %q: %v\n", username, err) continue } diff --git a/internal/cluster/cluster_test.go b/internal/cluster/cluster_test.go index 0a702e5..a8ea1dd 100644 --- a/internal/cluster/cluster_test.go +++ b/internal/cluster/cluster_test.go @@ -20,64 +20,88 @@ type mockStore struct { sites []models.Site } -func (m *mockStore) Init() error { return nil } -func (m *mockStore) GetSites() ([]models.Site, error) { return m.sites, nil } -func (m *mockStore) AddSite(models.Site) error { return nil } -func (m *mockStore) UpdateSite(models.Site) error { return nil } -func (m *mockStore) UpdateSitePaused(int, bool) error { return nil } -func (m *mockStore) DeleteSite(int) error { return nil } -func (m *mockStore) GetAllAlerts() ([]models.AlertConfig, error) { return nil, nil } -func (m *mockStore) GetAlert(int) (models.AlertConfig, error) { return models.AlertConfig{}, nil } -func (m *mockStore) AddAlert(string, string, map[string]string) error { return nil } -func (m *mockStore) UpdateAlert(int, string, string, map[string]string) error { return nil } -func (m *mockStore) DeleteAlert(int) error { return nil } -func (m *mockStore) GetAllUsers() ([]models.User, error) { return nil, nil } -func (m *mockStore) AddUser(string, string, string) error { return nil } -func (m *mockStore) UpdateUser(int, string, string, string) error { return nil } -func (m *mockStore) DeleteUser(int) error { return nil } -func (m *mockStore) SaveCheck(int, int64, bool) error { return nil } -func (m *mockStore) SaveCheckFromNode(int, string, int64, bool) error { return nil } -func (m *mockStore) LoadAllHistory(int) (map[int][]models.CheckRecord, error) { return nil, nil } -func (m *mockStore) ExportData() (models.Backup, error) { return models.Backup{}, nil } -func (m *mockStore) ImportData(models.Backup) error { return nil } -func (m *mockStore) GetSiteByName(string) (models.Site, error) { return models.Site{}, nil } -func (m *mockStore) GetAlertByName(string) (models.AlertConfig, error) { +func (m *mockStore) Init(_ context.Context) error { return nil } +func (m *mockStore) GetSites(_ context.Context) ([]models.Site, error) { return m.sites, nil } +func (m *mockStore) AddSite(_ context.Context, _ models.Site) error { return nil } +func (m *mockStore) UpdateSite(_ context.Context, _ models.Site) error { return nil } +func (m *mockStore) UpdateSitePaused(_ context.Context, _ int, _ bool) error { return nil } +func (m *mockStore) DeleteSite(_ context.Context, _ int) error { return nil } +func (m *mockStore) GetAllAlerts(_ context.Context) ([]models.AlertConfig, error) { return nil, nil } +func (m *mockStore) GetAlert(_ context.Context, _ int) (models.AlertConfig, error) { return models.AlertConfig{}, nil } -func (m *mockStore) AddSiteReturningID(models.Site) (int, error) { return 0, nil } -func (m *mockStore) AddAlertReturningID(string, string, map[string]string) (int, error) { +func (m *mockStore) AddAlert(_ context.Context, _ string, _ string, _ map[string]string) error { + return nil +} +func (m *mockStore) UpdateAlert(_ context.Context, _ int, _ string, _ string, _ map[string]string) error { + return nil +} +func (m *mockStore) DeleteAlert(_ context.Context, _ int) error { return nil } +func (m *mockStore) GetAllUsers(_ context.Context) ([]models.User, error) { return nil, nil } +func (m *mockStore) AddUser(_ context.Context, _ string, _ string, _ string) error { return nil } +func (m *mockStore) UpdateUser(_ context.Context, _ int, _ string, _ string, _ string) error { + return nil +} +func (m *mockStore) DeleteUser(_ context.Context, _ int) error { return nil } +func (m *mockStore) SaveCheck(_ context.Context, _ int, _ int64, _ bool) error { return nil } +func (m *mockStore) SaveCheckFromNode(_ context.Context, _ int, _ string, _ int64, _ bool) error { + return nil +} +func (m *mockStore) LoadAllHistory(_ context.Context, _ int) (map[int][]models.CheckRecord, error) { + return nil, nil +} +func (m *mockStore) ExportData(_ context.Context) (models.Backup, error) { return models.Backup{}, nil } +func (m *mockStore) ImportData(_ context.Context, _ models.Backup) error { return nil } +func (m *mockStore) GetSiteByName(_ context.Context, _ string) (models.Site, error) { + return models.Site{}, nil +} +func (m *mockStore) GetAlertByName(_ context.Context, _ string) (models.AlertConfig, error) { + return models.AlertConfig{}, nil +} +func (m *mockStore) AddSiteReturningID(_ context.Context, _ models.Site) (int, error) { return 0, nil } +func (m *mockStore) AddAlertReturningID(_ context.Context, _ string, _ string, _ map[string]string) (int, error) { return 0, nil } -func (m *mockStore) RegisterNode(models.ProbeNode) error { return nil } -func (m *mockStore) GetNode(string) (models.ProbeNode, error) { return models.ProbeNode{}, nil } -func (m *mockStore) GetAllNodes() ([]models.ProbeNode, error) { return nil, nil } -func (m *mockStore) UpdateNodeLastSeen(string) error { return nil } -func (m *mockStore) DeleteNode(string) error { return nil } -func (m *mockStore) LoadAlertHealth() (map[int]models.AlertHealthRecord, error) { +func (m *mockStore) RegisterNode(_ context.Context, _ models.ProbeNode) error { return nil } +func (m *mockStore) GetNode(_ context.Context, _ string) (models.ProbeNode, error) { + return models.ProbeNode{}, nil +} +func (m *mockStore) GetAllNodes(_ context.Context) ([]models.ProbeNode, error) { return nil, nil } +func (m *mockStore) UpdateNodeLastSeen(_ context.Context, _ string) error { return nil } +func (m *mockStore) DeleteNode(_ context.Context, _ string) error { return nil } +func (m *mockStore) LoadAlertHealth(_ context.Context) (map[int]models.AlertHealthRecord, error) { return nil, nil } -func (m *mockStore) SaveAlertHealth(models.AlertHealthRecord) error { return nil } -func (m *mockStore) SaveLog(string) error { return nil } -func (m *mockStore) PruneLogs() error { return nil } -func (m *mockStore) PruneCheckHistory() error { return nil } -func (m *mockStore) PruneStateChanges() error { return nil } -func (m *mockStore) LoadLogs(int) ([]string, error) { return nil, nil } -func (m *mockStore) GetActiveMaintenanceWindows() ([]models.MaintenanceWindow, error) { +func (m *mockStore) SaveAlertHealth(_ context.Context, _ models.AlertHealthRecord) error { return nil } +func (m *mockStore) SaveLog(_ context.Context, _ string) error { return nil } +func (m *mockStore) PruneLogs(_ context.Context) error { return nil } +func (m *mockStore) PruneCheckHistory(_ context.Context) error { return nil } +func (m *mockStore) PruneStateChanges(_ context.Context) error { return nil } +func (m *mockStore) LoadLogs(_ context.Context, _ int) ([]string, error) { return nil, nil } +func (m *mockStore) GetActiveMaintenanceWindows(_ context.Context) ([]models.MaintenanceWindow, error) { return nil, nil } -func (m *mockStore) GetAllMaintenanceWindows(int) ([]models.MaintenanceWindow, error) { +func (m *mockStore) GetAllMaintenanceWindows(_ context.Context, _ int) ([]models.MaintenanceWindow, error) { return nil, nil } -func (m *mockStore) AddMaintenanceWindow(models.MaintenanceWindow) error { return nil } -func (m *mockStore) EndMaintenanceWindow(int) error { return nil } -func (m *mockStore) DeleteMaintenanceWindow(int) error { return nil } -func (m *mockStore) PruneExpiredMaintenanceWindows(time.Duration) (int64, error) { return 0, nil } -func (m *mockStore) IsMonitorInMaintenance(int) (bool, error) { return false, nil } -func (m *mockStore) GetPreference(string) (string, error) { return "", nil } -func (m *mockStore) SetPreference(string, string) error { return nil } -func (m *mockStore) SaveStateChange(int, string, string, string) error { return nil } -func (m *mockStore) GetStateChanges(int, int) ([]models.StateChange, error) { return nil, nil } -func (m *mockStore) GetStateChangesSince(int, time.Time) ([]models.StateChange, error) { +func (m *mockStore) AddMaintenanceWindow(_ context.Context, _ models.MaintenanceWindow) error { + return nil +} +func (m *mockStore) EndMaintenanceWindow(_ context.Context, _ int) error { return nil } +func (m *mockStore) DeleteMaintenanceWindow(_ context.Context, _ int) error { return nil } +func (m *mockStore) PruneExpiredMaintenanceWindows(_ context.Context, _ time.Duration) (int64, error) { + return 0, nil +} +func (m *mockStore) IsMonitorInMaintenance(_ context.Context, _ int) (bool, error) { return false, nil } +func (m *mockStore) GetPreference(_ context.Context, _ string) (string, error) { return "", nil } +func (m *mockStore) SetPreference(_ context.Context, _ string, _ string) error { return nil } +func (m *mockStore) SaveStateChange(_ context.Context, _ int, _ string, _ string, _ string) error { + return nil +} +func (m *mockStore) GetStateChanges(_ context.Context, _ int, _ int) ([]models.StateChange, error) { + return nil, nil +} +func (m *mockStore) GetStateChangesSince(_ context.Context, _ int, _ time.Time) ([]models.StateChange, error) { return nil, nil } func (m *mockStore) Close() error { return nil } diff --git a/internal/cluster/probe.go b/internal/cluster/probe.go index de6ea2a..41828ae 100644 --- a/internal/cluster/probe.go +++ b/internal/cluster/probe.go @@ -152,7 +152,7 @@ loop: defer wg.Done() defer func() { <-sem }() - cr := monitor.RunCheck(s, strict, insecure, false, allowPrivate) + cr := monitor.RunCheck(ctx, s, strict, insecure, false, allowPrivate) mu.Lock() results = append(results, probeResultItem{ SiteID: s.ID, diff --git a/internal/config/apply.go b/internal/config/apply.go index 616c4ba..e5a9bfa 100644 --- a/internal/config/apply.go +++ b/internal/config/apply.go @@ -1,11 +1,13 @@ package config import ( + "context" "fmt" - "gitea.lerkolabs.com/lerkolabs/uptop/internal/models" - "gitea.lerkolabs.com/lerkolabs/uptop/internal/store" "reflect" "strings" + + "gitea.lerkolabs.com/lerkolabs/uptop/internal/models" + "gitea.lerkolabs.com/lerkolabs/uptop/internal/store" ) type ApplyOpts struct { @@ -20,17 +22,17 @@ type Change struct { Details string } -func Apply(s store.Store, f *File, opts ApplyOpts) ([]Change, error) { +func Apply(ctx context.Context, s store.Store, f *File, opts ApplyOpts) ([]Change, error) { if err := Validate(f); err != nil { return nil, err } - existingAlerts, err := s.GetAllAlerts() + existingAlerts, err := s.GetAllAlerts(ctx) if err != nil { return nil, fmt.Errorf("load alerts: %w", err) } - existingSites, err := s.GetSites() + existingSites, err := s.GetSites(ctx) if err != nil { return nil, fmt.Errorf("load sites: %w", err) } @@ -59,7 +61,7 @@ func Apply(s store.Store, f *File, opts ApplyOpts) ([]Change, error) { if !exists { changes = append(changes, Change{Action: "create", Kind: "alert", Name: a.Name, Details: a.Type}) if !opts.DryRun { - id, err := s.AddAlertReturningID(a.Name, a.Type, a.Settings) + id, err := s.AddAlertReturningID(ctx, a.Name, a.Type, a.Settings) if err != nil { return changes, fmt.Errorf("create alert %q: %w", a.Name, err) } @@ -70,7 +72,7 @@ func Apply(s store.Store, f *File, opts ApplyOpts) ([]Change, error) { if diff := diffAlert(existing, a); diff != "" { changes = append(changes, Change{Action: "update", Kind: "alert", Name: a.Name, Details: diff}) if !opts.DryRun { - if err := s.UpdateAlert(existing.ID, a.Name, a.Type, a.Settings); err != nil { + if err := s.UpdateAlert(ctx, existing.ID, a.Name, a.Type, a.Settings); err != nil { return changes, fmt.Errorf("update alert %q: %w", a.Name, err) } } @@ -102,7 +104,7 @@ func Apply(s store.Store, f *File, opts ApplyOpts) ([]Change, error) { if !exists { changes = append(changes, Change{Action: "create", Kind: "monitor", Name: g.Name, Details: "group"}) if !opts.DryRun { - id, err := s.AddSiteReturningID(site) + id, err := s.AddSiteReturningID(ctx, site) if err != nil { return changes, fmt.Errorf("create group %q: %w", g.Name, err) } @@ -114,7 +116,7 @@ func Apply(s store.Store, f *File, opts ApplyOpts) ([]Change, error) { if diff := diffSite(normalizeSite(existing), site); diff != "" { changes = append(changes, Change{Action: "update", Kind: "monitor", Name: g.Name, Details: diff}) if !opts.DryRun { - if err := s.UpdateSite(site); err != nil { + if err := s.UpdateSite(ctx, site); err != nil { return changes, fmt.Errorf("update group %q: %w", g.Name, err) } } @@ -125,7 +127,7 @@ func Apply(s store.Store, f *File, opts ApplyOpts) ([]Change, error) { for _, g := range groups { parentID := groupMap[g.Name] for _, child := range g.Monitors { - c, err := applyMonitor(s, child, alertMap, existingSitesByName, parentID, opts.DryRun) + c, err := applyMonitor(ctx, s, child, alertMap, existingSitesByName, parentID, opts.DryRun) if err != nil { return changes, err } @@ -134,7 +136,7 @@ func Apply(s store.Store, f *File, opts ApplyOpts) ([]Change, error) { } for _, m := range topLevel { - c, err := applyMonitor(s, m, alertMap, existingSitesByName, 0, opts.DryRun) + c, err := applyMonitor(ctx, s, m, alertMap, existingSitesByName, 0, opts.DryRun) if err != nil { return changes, err } @@ -155,7 +157,7 @@ func Apply(s store.Store, f *File, opts ApplyOpts) ([]Change, error) { childDeletes = append(childDeletes, c) } if !opts.DryRun { - if err := s.DeleteSite(es.ID); err != nil { + if err := s.DeleteSite(ctx, es.ID); err != nil { return changes, fmt.Errorf("delete monitor %q: %w", es.Name, err) } } @@ -169,7 +171,7 @@ func Apply(s store.Store, f *File, opts ApplyOpts) ([]Change, error) { } changes = append(changes, Change{Action: "delete", Kind: "alert", Name: ea.Name, Details: ea.Type}) if !opts.DryRun { - if err := s.DeleteAlert(ea.ID); err != nil { + if err := s.DeleteAlert(ctx, ea.ID); err != nil { return changes, fmt.Errorf("delete alert %q: %w", ea.Name, err) } } @@ -179,7 +181,7 @@ func Apply(s store.Store, f *File, opts ApplyOpts) ([]Change, error) { return changes, nil } -func applyMonitor(s store.Store, m Monitor, alertMap map[string]int, existing map[string]models.Site, parentID int, dryRun bool) ([]Change, error) { +func applyMonitor(ctx context.Context, s store.Store, m Monitor, alertMap map[string]int, existing map[string]models.Site, parentID int, dryRun bool) ([]Change, error) { alertID, err := resolveAlertID(alertMap, m.Alert) if err != nil { return nil, fmt.Errorf("monitor %q: %w", m.Name, err) @@ -191,7 +193,7 @@ func applyMonitor(s store.Store, m Monitor, alertMap map[string]int, existing ma if !exists { changes = append(changes, Change{Action: "create", Kind: "monitor", Name: m.Name, Details: m.Type}) if !dryRun { - if _, err := s.AddSiteReturningID(site); err != nil { + if _, err := s.AddSiteReturningID(ctx, site); err != nil { return changes, fmt.Errorf("create monitor %q: %w", m.Name, err) } } @@ -200,7 +202,7 @@ func applyMonitor(s store.Store, m Monitor, alertMap map[string]int, existing ma if diff := diffSite(normalizeSite(ex), site); diff != "" { changes = append(changes, Change{Action: "update", Kind: "monitor", Name: m.Name, Details: diff}) if !dryRun { - if err := s.UpdateSite(site); err != nil { + if err := s.UpdateSite(ctx, site); err != nil { return changes, fmt.Errorf("update monitor %q: %w", m.Name, err) } } diff --git a/internal/config/apply_test.go b/internal/config/apply_test.go index 8e7b636..6a661b2 100644 --- a/internal/config/apply_test.go +++ b/internal/config/apply_test.go @@ -1,10 +1,12 @@ package config import ( - "gitea.lerkolabs.com/lerkolabs/uptop/internal/models" - "gitea.lerkolabs.com/lerkolabs/uptop/internal/store" + "context" "strings" "testing" + + "gitea.lerkolabs.com/lerkolabs/uptop/internal/models" + "gitea.lerkolabs.com/lerkolabs/uptop/internal/store" ) func newTestStore(t *testing.T) store.Store { @@ -13,7 +15,7 @@ func newTestStore(t *testing.T) store.Store { if err != nil { t.Fatalf("NewSQLiteStore: %v", err) } - if err := s.Init(); err != nil { + if err := s.Init(context.Background()); err != nil { t.Fatalf("Init: %v", err) } return s @@ -31,7 +33,7 @@ func TestApplyCreateFromScratch(t *testing.T) { }, } - changes, err := Apply(s, f, ApplyOpts{}) + changes, err := Apply(context.Background(), s, f, ApplyOpts{}) if err != nil { t.Fatalf("Apply: %v", err) } @@ -46,12 +48,12 @@ func TestApplyCreateFromScratch(t *testing.T) { t.Fatalf("expected 3 creates, got %d", creates) } - sites, _ := s.GetSites() + sites, _ := s.GetSites(context.Background()) if len(sites) != 2 { t.Fatalf("expected 2 sites, got %d", len(sites)) } - alerts, _ := s.GetAllAlerts() + alerts, _ := s.GetAllAlerts(context.Background()) if len(alerts) != 1 { t.Fatalf("expected 1 alert, got %d", len(alerts)) } @@ -68,11 +70,11 @@ func TestApplyIdempotent(t *testing.T) { }, } - if _, err := Apply(s, f, ApplyOpts{}); err != nil { + if _, err := Apply(context.Background(), s, f, ApplyOpts{}); err != nil { t.Fatalf("first Apply: %v", err) } - changes, err := Apply(s, f, ApplyOpts{}) + changes, err := Apply(context.Background(), s, f, ApplyOpts{}) if err != nil { t.Fatalf("second Apply: %v", err) } @@ -90,12 +92,12 @@ func TestApplyUpdate(t *testing.T) { }, } - if _, err := Apply(s, f, ApplyOpts{}); err != nil { + if _, err := Apply(context.Background(), s, f, ApplyOpts{}); err != nil { t.Fatalf("first Apply: %v", err) } f.Monitors[0].Interval = 60 - changes, err := Apply(s, f, ApplyOpts{}) + changes, err := Apply(context.Background(), s, f, ApplyOpts{}) if err != nil { t.Fatalf("second Apply: %v", err) } @@ -104,7 +106,7 @@ func TestApplyUpdate(t *testing.T) { t.Fatalf("expected 1 update, got %+v", changes) } - sites, _ := s.GetSites() + sites, _ := s.GetSites(context.Background()) if sites[0].Interval != 60 { t.Fatalf("expected interval 60, got %d", sites[0].Interval) } @@ -112,8 +114,8 @@ func TestApplyUpdate(t *testing.T) { func TestApplyPrune(t *testing.T) { s := newTestStore(t) - s.AddSite(models.Site{Name: "Keep", URL: "https://keep.com", Type: "http", Interval: 30, ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"}) - s.AddSite(models.Site{Name: "Remove", URL: "https://remove.com", Type: "http", Interval: 30, ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"}) + s.AddSite(context.Background(), models.Site{Name: "Keep", URL: "https://keep.com", Type: "http", Interval: 30, ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"}) + s.AddSite(context.Background(), models.Site{Name: "Remove", URL: "https://remove.com", Type: "http", Interval: 30, ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"}) f := &File{ Monitors: []Monitor{ @@ -121,7 +123,7 @@ func TestApplyPrune(t *testing.T) { }, } - changes, err := Apply(s, f, ApplyOpts{Prune: true}) + changes, err := Apply(context.Background(), s, f, ApplyOpts{Prune: true}) if err != nil { t.Fatalf("Apply: %v", err) } @@ -136,7 +138,7 @@ func TestApplyPrune(t *testing.T) { t.Fatalf("expected 1 delete, got %d", deleteCount) } - sites, _ := s.GetSites() + sites, _ := s.GetSites(context.Background()) if len(sites) != 1 || sites[0].Name != "Keep" { t.Fatalf("expected only 'Keep', got %+v", sites) } @@ -150,7 +152,7 @@ func TestApplyDryRun(t *testing.T) { }, } - changes, err := Apply(s, f, ApplyOpts{DryRun: true}) + changes, err := Apply(context.Background(), s, f, ApplyOpts{DryRun: true}) if err != nil { t.Fatalf("Apply: %v", err) } @@ -159,7 +161,7 @@ func TestApplyDryRun(t *testing.T) { t.Fatalf("expected 1 create in dry-run, got %+v", changes) } - sites, _ := s.GetSites() + sites, _ := s.GetSites(context.Background()) if len(sites) != 0 { t.Fatalf("expected 0 sites after dry-run, got %d", len(sites)) } @@ -179,7 +181,7 @@ func TestApplyGroupHierarchy(t *testing.T) { }, } - changes, err := Apply(s, f, ApplyOpts{}) + changes, err := Apply(context.Background(), s, f, ApplyOpts{}) if err != nil { t.Fatalf("Apply: %v", err) } @@ -188,7 +190,7 @@ func TestApplyGroupHierarchy(t *testing.T) { t.Fatalf("expected 3 creates, got %d", len(changes)) } - sites, _ := s.GetSites() + sites, _ := s.GetSites(context.Background()) var group models.Site for _, s := range sites { if s.Type == "group" { @@ -223,12 +225,12 @@ func TestApplyAlertReference(t *testing.T) { }, } - if _, err := Apply(s, f, ApplyOpts{}); err != nil { + if _, err := Apply(context.Background(), s, f, ApplyOpts{}); err != nil { t.Fatalf("Apply: %v", err) } - sites, _ := s.GetSites() - alerts, _ := s.GetAllAlerts() + sites, _ := s.GetSites(context.Background()) + alerts, _ := s.GetAllAlerts(context.Background()) if sites[0].AlertID != alerts[0].ID { t.Fatalf("expected alert_id %d, got %d", alerts[0].ID, sites[0].AlertID) @@ -243,7 +245,7 @@ func TestApplyInvalidAlertRef(t *testing.T) { }, } - _, err := Apply(s, f, ApplyOpts{}) + _, err := Apply(context.Background(), s, f, ApplyOpts{}) if err == nil || !strings.Contains(err.Error(), "not found") { t.Fatalf("expected alert not found error, got %v", err) } @@ -258,7 +260,7 @@ func TestApplyDuplicateNames(t *testing.T) { }, } - _, err := Apply(s, f, ApplyOpts{}) + _, err := Apply(context.Background(), s, f, ApplyOpts{}) if err == nil || !strings.Contains(err.Error(), "duplicate") { t.Fatalf("expected duplicate error, got %v", err) } @@ -266,7 +268,7 @@ func TestApplyDuplicateNames(t *testing.T) { func TestApplyExistingAlertReference(t *testing.T) { s := newTestStore(t) - s.AddAlert("Existing", "webhook", map[string]string{"url": "https://example.com"}) + s.AddAlert(context.Background(), "Existing", "webhook", map[string]string{"url": "https://example.com"}) f := &File{ Monitors: []Monitor{ @@ -274,7 +276,7 @@ func TestApplyExistingAlertReference(t *testing.T) { }, } - changes, err := Apply(s, f, ApplyOpts{}) + changes, err := Apply(context.Background(), s, f, ApplyOpts{}) if err != nil { t.Fatalf("Apply: %v", err) } @@ -283,7 +285,7 @@ func TestApplyExistingAlertReference(t *testing.T) { t.Fatalf("expected 1 create, got %+v", changes) } - sites, _ := s.GetSites() + sites, _ := s.GetSites(context.Background()) if sites[0].AlertID == 0 { t.Fatal("expected non-zero alert_id for existing alert reference") } diff --git a/internal/config/export.go b/internal/config/export.go index 83a2cc8..6902e84 100644 --- a/internal/config/export.go +++ b/internal/config/export.go @@ -1,6 +1,7 @@ package config import ( + "context" "fmt" "os" "sort" @@ -11,13 +12,13 @@ import ( "gopkg.in/yaml.v3" ) -func Export(s store.Store) (*File, error) { - dbAlerts, err := s.GetAllAlerts() +func Export(ctx context.Context, s store.Store) (*File, error) { + dbAlerts, err := s.GetAllAlerts(ctx) if err != nil { return nil, fmt.Errorf("load alerts: %w", err) } - dbSites, err := s.GetSites() + dbSites, err := s.GetSites(ctx) if err != nil { return nil, fmt.Errorf("load sites: %w", err) } diff --git a/internal/config/export_test.go b/internal/config/export_test.go index 66afce4..da2ce1c 100644 --- a/internal/config/export_test.go +++ b/internal/config/export_test.go @@ -1,13 +1,15 @@ package config import ( - "gitea.lerkolabs.com/lerkolabs/uptop/internal/models" + "context" "testing" + + "gitea.lerkolabs.com/lerkolabs/uptop/internal/models" ) func TestExportEmpty(t *testing.T) { s := newTestStore(t) - f, err := Export(s) + f, err := Export(context.Background(), s) if err != nil { t.Fatalf("Export: %v", err) } @@ -18,11 +20,11 @@ func TestExportEmpty(t *testing.T) { func TestExportAlertNames(t *testing.T) { s := newTestStore(t) - s.AddAlert("Discord", "discord", map[string]string{"url": "https://example.com"}) - alerts, _ := s.GetAllAlerts() - s.AddSite(models.Site{Name: "Web", URL: "https://example.com", Type: "http", Interval: 30, AlertID: alerts[0].ID, ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"}) + s.AddAlert(context.Background(), "Discord", "discord", map[string]string{"url": "https://example.com"}) + alerts, _ := s.GetAllAlerts(context.Background()) + s.AddSite(context.Background(), models.Site{Name: "Web", URL: "https://example.com", Type: "http", Interval: 30, AlertID: alerts[0].ID, ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"}) - f, err := Export(s) + f, err := Export(context.Background(), s) if err != nil { t.Fatalf("Export: %v", err) } @@ -37,11 +39,11 @@ func TestExportAlertNames(t *testing.T) { func TestExportGroupHierarchy(t *testing.T) { s := newTestStore(t) - groupID, _ := s.AddSiteReturningID(models.Site{Name: "Prod", Type: "group", ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"}) - s.AddSite(models.Site{Name: "Prod Web", URL: "https://prod.example.com", Type: "http", Interval: 15, ParentID: groupID, ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"}) - s.AddSite(models.Site{Name: "Top Level", URL: "https://example.com", Type: "http", Interval: 30, ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"}) + groupID, _ := s.AddSiteReturningID(context.Background(), models.Site{Name: "Prod", Type: "group", ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"}) + s.AddSite(context.Background(), models.Site{Name: "Prod Web", URL: "https://prod.example.com", Type: "http", Interval: 15, ParentID: groupID, ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"}) + s.AddSite(context.Background(), models.Site{Name: "Top Level", URL: "https://example.com", Type: "http", Interval: 30, ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"}) - f, err := Export(s) + f, err := Export(context.Background(), s) if err != nil { t.Fatalf("Export: %v", err) } @@ -70,12 +72,12 @@ func TestExportGroupHierarchy(t *testing.T) { func TestExportOmitsDefaults(t *testing.T) { s := newTestStore(t) - s.AddSite(models.Site{ + s.AddSite(context.Background(), models.Site{ Name: "Web", URL: "https://example.com", Type: "http", Interval: 30, Method: "GET", AcceptedCodes: "200-299", ExpiryThreshold: 7, }) - f, err := Export(s) + f, err := Export(context.Background(), s) if err != nil { t.Fatalf("Export: %v", err) } @@ -94,18 +96,18 @@ func TestExportOmitsDefaults(t *testing.T) { func TestExportRoundTrip(t *testing.T) { s1 := newTestStore(t) - s1.AddAlert("Discord", "discord", map[string]string{"url": "https://example.com"}) - alerts, _ := s1.GetAllAlerts() - s1.AddSite(models.Site{Name: "Web", URL: "https://example.com", Type: "http", Interval: 30, AlertID: alerts[0].ID, ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"}) - s1.AddSite(models.Site{Name: "Ping", Type: "ping", Hostname: "10.0.0.1", Interval: 60, ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"}) + s1.AddAlert(context.Background(), "Discord", "discord", map[string]string{"url": "https://example.com"}) + alerts, _ := s1.GetAllAlerts(context.Background()) + s1.AddSite(context.Background(), models.Site{Name: "Web", URL: "https://example.com", Type: "http", Interval: 30, AlertID: alerts[0].ID, ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"}) + s1.AddSite(context.Background(), models.Site{Name: "Ping", Type: "ping", Hostname: "10.0.0.1", Interval: 60, ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"}) - exported, err := Export(s1) + exported, err := Export(context.Background(), s1) if err != nil { t.Fatalf("Export: %v", err) } s2 := newTestStore(t) - changes, err := Apply(s2, exported, ApplyOpts{}) + changes, err := Apply(context.Background(), s2, exported, ApplyOpts{}) if err != nil { t.Fatalf("Apply: %v", err) } @@ -120,7 +122,7 @@ func TestExportRoundTrip(t *testing.T) { t.Fatalf("expected 3 creates, got %d", creates) } - reexported, err := Export(s2) + reexported, err := Export(context.Background(), s2) if err != nil { t.Fatalf("re-Export: %v", err) } diff --git a/internal/metrics/prometheus_test.go b/internal/metrics/prometheus_test.go index 77cbbd8..9352bc4 100644 --- a/internal/metrics/prometheus_test.go +++ b/internal/metrics/prometheus_test.go @@ -16,66 +16,88 @@ type mockStore struct { sites []models.Site } -func (m *mockStore) Init() error { return nil } -func (m *mockStore) GetSites() ([]models.Site, error) { return m.sites, nil } -func (m *mockStore) AddSite(models.Site) error { return nil } -func (m *mockStore) UpdateSite(models.Site) error { return nil } -func (m *mockStore) UpdateSitePaused(int, bool) error { return nil } -func (m *mockStore) DeleteSite(int) error { return nil } -func (m *mockStore) GetAllAlerts() ([]models.AlertConfig, error) { return nil, nil } -func (m *mockStore) GetAlert(int) (models.AlertConfig, error) { return models.AlertConfig{}, nil } -func (m *mockStore) AddAlert(string, string, map[string]string) error { return nil } -func (m *mockStore) UpdateAlert(int, string, string, map[string]string) error { return nil } -func (m *mockStore) DeleteAlert(int) error { return nil } -func (m *mockStore) GetAllUsers() ([]models.User, error) { return nil, nil } -func (m *mockStore) AddUser(string, string, string) error { return nil } -func (m *mockStore) UpdateUser(int, string, string, string) error { return nil } -func (m *mockStore) DeleteUser(int) error { return nil } -func (m *mockStore) SaveCheck(int, int64, bool) error { return nil } -func (m *mockStore) LoadAllHistory(int) (map[int][]models.CheckRecord, error) { - return nil, nil -} -func (m *mockStore) ExportData() (models.Backup, error) { return models.Backup{}, nil } -func (m *mockStore) ImportData(models.Backup) error { return nil } -func (m *mockStore) GetSiteByName(string) (models.Site, error) { return models.Site{}, nil } -func (m *mockStore) GetAlertByName(string) (models.AlertConfig, error) { +func (m *mockStore) Init(_ context.Context) error { return nil } +func (m *mockStore) GetSites(_ context.Context) ([]models.Site, error) { return m.sites, nil } +func (m *mockStore) AddSite(_ context.Context, _ models.Site) error { return nil } +func (m *mockStore) UpdateSite(_ context.Context, _ models.Site) error { return nil } +func (m *mockStore) UpdateSitePaused(_ context.Context, _ int, _ bool) error { return nil } +func (m *mockStore) DeleteSite(_ context.Context, _ int) error { return nil } +func (m *mockStore) GetAllAlerts(_ context.Context) ([]models.AlertConfig, error) { return nil, nil } +func (m *mockStore) GetAlert(_ context.Context, _ int) (models.AlertConfig, error) { return models.AlertConfig{}, nil } -func (m *mockStore) AddSiteReturningID(models.Site) (int, error) { return 0, nil } -func (m *mockStore) AddAlertReturningID(string, string, map[string]string) (int, error) { +func (m *mockStore) AddAlert(_ context.Context, _ string, _ string, _ map[string]string) error { + return nil +} +func (m *mockStore) UpdateAlert(_ context.Context, _ int, _ string, _ string, _ map[string]string) error { + return nil +} +func (m *mockStore) DeleteAlert(_ context.Context, _ int) error { return nil } +func (m *mockStore) GetAllUsers(_ context.Context) ([]models.User, error) { return nil, nil } +func (m *mockStore) AddUser(_ context.Context, _ string, _ string, _ string) error { return nil } +func (m *mockStore) UpdateUser(_ context.Context, _ int, _ string, _ string, _ string) error { + return nil +} +func (m *mockStore) DeleteUser(_ context.Context, _ int) error { return nil } +func (m *mockStore) SaveCheck(_ context.Context, _ int, _ int64, _ bool) error { return nil } +func (m *mockStore) LoadAllHistory(_ context.Context, _ int) (map[int][]models.CheckRecord, error) { + return nil, nil +} +func (m *mockStore) ExportData(_ context.Context) (models.Backup, error) { return models.Backup{}, nil } +func (m *mockStore) ImportData(_ context.Context, _ models.Backup) error { return nil } +func (m *mockStore) GetSiteByName(_ context.Context, _ string) (models.Site, error) { + return models.Site{}, nil +} +func (m *mockStore) GetAlertByName(_ context.Context, _ string) (models.AlertConfig, error) { + return models.AlertConfig{}, nil +} +func (m *mockStore) AddSiteReturningID(_ context.Context, _ models.Site) (int, error) { return 0, nil } +func (m *mockStore) AddAlertReturningID(_ context.Context, _ string, _ string, _ map[string]string) (int, error) { return 0, nil } -func (m *mockStore) SaveCheckFromNode(int, string, int64, bool) error { return nil } -func (m *mockStore) RegisterNode(models.ProbeNode) error { return nil } -func (m *mockStore) GetNode(string) (models.ProbeNode, error) { return models.ProbeNode{}, nil } -func (m *mockStore) GetAllNodes() ([]models.ProbeNode, error) { return nil, nil } -func (m *mockStore) UpdateNodeLastSeen(string) error { return nil } -func (m *mockStore) DeleteNode(string) error { return nil } -func (m *mockStore) LoadAlertHealth() (map[int]models.AlertHealthRecord, error) { +func (m *mockStore) SaveCheckFromNode(_ context.Context, _ int, _ string, _ int64, _ bool) error { + return nil +} +func (m *mockStore) RegisterNode(_ context.Context, _ models.ProbeNode) error { return nil } +func (m *mockStore) GetNode(_ context.Context, _ string) (models.ProbeNode, error) { + return models.ProbeNode{}, nil +} +func (m *mockStore) GetAllNodes(_ context.Context) ([]models.ProbeNode, error) { return nil, nil } +func (m *mockStore) UpdateNodeLastSeen(_ context.Context, _ string) error { return nil } +func (m *mockStore) DeleteNode(_ context.Context, _ string) error { return nil } +func (m *mockStore) LoadAlertHealth(_ context.Context) (map[int]models.AlertHealthRecord, error) { return nil, nil } -func (m *mockStore) SaveAlertHealth(models.AlertHealthRecord) error { return nil } -func (m *mockStore) SaveLog(string) error { return nil } -func (m *mockStore) PruneLogs() error { return nil } -func (m *mockStore) PruneCheckHistory() error { return nil } -func (m *mockStore) PruneStateChanges() error { return nil } -func (m *mockStore) LoadLogs(int) ([]string, error) { return nil, nil } -func (m *mockStore) GetActiveMaintenanceWindows() ([]models.MaintenanceWindow, error) { +func (m *mockStore) SaveAlertHealth(_ context.Context, _ models.AlertHealthRecord) error { return nil } +func (m *mockStore) SaveLog(_ context.Context, _ string) error { return nil } +func (m *mockStore) PruneLogs(_ context.Context) error { return nil } +func (m *mockStore) PruneCheckHistory(_ context.Context) error { return nil } +func (m *mockStore) PruneStateChanges(_ context.Context) error { return nil } +func (m *mockStore) LoadLogs(_ context.Context, _ int) ([]string, error) { return nil, nil } +func (m *mockStore) GetActiveMaintenanceWindows(_ context.Context) ([]models.MaintenanceWindow, error) { return nil, nil } -func (m *mockStore) GetAllMaintenanceWindows(int) ([]models.MaintenanceWindow, error) { +func (m *mockStore) GetAllMaintenanceWindows(_ context.Context, _ int) ([]models.MaintenanceWindow, error) { return nil, nil } -func (m *mockStore) AddMaintenanceWindow(models.MaintenanceWindow) error { return nil } -func (m *mockStore) EndMaintenanceWindow(int) error { return nil } -func (m *mockStore) DeleteMaintenanceWindow(int) error { return nil } -func (m *mockStore) PruneExpiredMaintenanceWindows(time.Duration) (int64, error) { return 0, nil } -func (m *mockStore) IsMonitorInMaintenance(int) (bool, error) { return false, nil } -func (m *mockStore) GetPreference(string) (string, error) { return "", nil } -func (m *mockStore) SetPreference(string, string) error { return nil } -func (m *mockStore) SaveStateChange(int, string, string, string) error { return nil } -func (m *mockStore) GetStateChanges(int, int) ([]models.StateChange, error) { return nil, nil } -func (m *mockStore) GetStateChangesSince(int, time.Time) ([]models.StateChange, error) { +func (m *mockStore) AddMaintenanceWindow(_ context.Context, _ models.MaintenanceWindow) error { + return nil +} +func (m *mockStore) EndMaintenanceWindow(_ context.Context, _ int) error { return nil } +func (m *mockStore) DeleteMaintenanceWindow(_ context.Context, _ int) error { return nil } +func (m *mockStore) PruneExpiredMaintenanceWindows(_ context.Context, _ time.Duration) (int64, error) { + return 0, nil +} +func (m *mockStore) IsMonitorInMaintenance(_ context.Context, _ int) (bool, error) { return false, nil } +func (m *mockStore) GetPreference(_ context.Context, _ string) (string, error) { return "", nil } +func (m *mockStore) SetPreference(_ context.Context, _ string, _ string) error { return nil } +func (m *mockStore) SaveStateChange(_ context.Context, _ int, _ string, _ string, _ string) error { + return nil +} +func (m *mockStore) GetStateChanges(_ context.Context, _ int, _ int) ([]models.StateChange, error) { + return nil, nil +} +func (m *mockStore) GetStateChangesSince(_ context.Context, _ int, _ time.Time) ([]models.StateChange, error) { return nil, nil } func (m *mockStore) Close() error { return nil } diff --git a/internal/monitor/checker.go b/internal/monitor/checker.go index 5ed3f3a..48abc97 100644 --- a/internal/monitor/checker.go +++ b/internal/monitor/checker.go @@ -35,7 +35,7 @@ type CheckResult struct { ErrorReason string } -func RunCheck(site models.Site, strict, insecure *http.Client, globalInsecure bool, allowPrivate ...bool) CheckResult { +func RunCheck(ctx context.Context, site models.Site, strict, insecure *http.Client, globalInsecure bool, allowPrivate ...bool) CheckResult { private := len(allowPrivate) > 0 && allowPrivate[0] if site.Type != "http" && site.Type != "dns" && !private { @@ -56,26 +56,26 @@ func RunCheck(site models.Site, strict, insecure *http.Client, globalInsecure bo switch site.Type { case "http": - return runHTTPCheck(site, strict, insecure, globalInsecure) + return runHTTPCheck(ctx, site, strict, insecure, globalInsecure) case "ping": - return runPingCheck(site) + return runPingCheck(ctx, site) case "port": - return runPortCheck(site) + return runPortCheck(ctx, site) case "dns": - return runDNSCheck(site) + return runDNSCheck(ctx, site) default: return CheckResult{SiteID: site.ID, Status: "DOWN", ErrorReason: "unsupported monitor type: " + site.Type} } } -func runHTTPCheck(site models.Site, strict, insecure *http.Client, globalInsecure bool) CheckResult { +func runHTTPCheck(ctx context.Context, site models.Site, strict, insecure *http.Client, globalInsecure bool) CheckResult { method := site.Method if method == "" { method = "GET" } timeout := siteTimeout(site) - ctx, cancel := context.WithTimeout(context.Background(), timeout) + ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() req, err := http.NewRequestWithContext(ctx, method, site.URL, nil) @@ -128,7 +128,7 @@ func runHTTPCheck(site models.Site, strict, insecure *http.Client, globalInsecur return result } -func runPingCheck(site models.Site) CheckResult { +func runPingCheck(_ context.Context, site models.Site) CheckResult { host := site.Hostname if host == "" { host = site.URL @@ -157,7 +157,7 @@ func runPingCheck(site models.Site) CheckResult { return CheckResult{SiteID: site.ID, Status: "UP", LatencyNs: stats.AvgRtt.Nanoseconds()} } -func runPortCheck(site models.Site) CheckResult { +func runPortCheck(_ context.Context, site models.Site) CheckResult { host := site.Hostname if host == "" { host = site.URL @@ -176,7 +176,7 @@ func runPortCheck(site models.Site) CheckResult { return CheckResult{SiteID: site.ID, Status: "UP", LatencyNs: latency.Nanoseconds()} } -func runDNSCheck(site models.Site) CheckResult { +func runDNSCheck(_ context.Context, site models.Site) CheckResult { host := site.Hostname if host == "" { host = site.URL diff --git a/internal/monitor/checker_test.go b/internal/monitor/checker_test.go index 0e71c28..c2cf15d 100644 --- a/internal/monitor/checker_test.go +++ b/internal/monitor/checker_test.go @@ -1,6 +1,7 @@ package monitor import ( + "context" "crypto/tls" "net" "net/http" @@ -19,7 +20,7 @@ func TestRunCheck_HTTP_Success(t *testing.T) { defer srv.Close() site := models.Site{ID: 1, Type: "http", URL: srv.URL} - result := RunCheck(site, http.DefaultClient, http.DefaultClient, false) + result := RunCheck(context.Background(), site, http.DefaultClient, http.DefaultClient, false) if result.Status != "UP" { t.Errorf("expected UP, got %s", result.Status) @@ -39,7 +40,7 @@ func TestRunCheck_HTTP_ServerError(t *testing.T) { defer srv.Close() site := models.Site{ID: 1, Type: "http", URL: srv.URL} - result := RunCheck(site, http.DefaultClient, http.DefaultClient, false) + result := RunCheck(context.Background(), site, http.DefaultClient, http.DefaultClient, false) if result.Status != "DOWN" { t.Errorf("expected DOWN, got %s", result.Status) @@ -60,7 +61,7 @@ func TestRunCheck_HTTP_CustomAcceptedCodes(t *testing.T) { }} site := models.Site{ID: 1, Type: "http", URL: srv.URL, AcceptedCodes: "200-399"} - result := RunCheck(site, client, client, false) + result := RunCheck(context.Background(), site, client, client, false) if result.Status != "UP" { t.Errorf("expected UP with accepted 200-399, got %s", result.Status) @@ -76,7 +77,7 @@ func TestRunCheck_HTTP_MethodRespected(t *testing.T) { defer srv.Close() site := models.Site{ID: 1, Type: "http", URL: srv.URL, Method: "HEAD"} - RunCheck(site, http.DefaultClient, http.DefaultClient, false) + RunCheck(context.Background(), site, http.DefaultClient, http.DefaultClient, false) if receivedMethod != "HEAD" { t.Errorf("expected HEAD, got %s", receivedMethod) @@ -91,7 +92,7 @@ func TestRunCheck_HTTP_Timeout(t *testing.T) { defer srv.Close() site := models.Site{ID: 1, Type: "http", URL: srv.URL, Timeout: 1} - result := RunCheck(site, http.DefaultClient, http.DefaultClient, false) + result := RunCheck(context.Background(), site, http.DefaultClient, http.DefaultClient, false) if result.Status != "DOWN" { t.Errorf("expected DOWN on timeout, got %s", result.Status) @@ -109,7 +110,7 @@ func TestRunCheck_HTTP_SSLFields(t *testing.T) { } site := models.Site{ID: 1, Type: "http", URL: srv.URL, CheckSSL: true, IgnoreTLS: true} - result := RunCheck(site, http.DefaultClient, insecureClient, false) + result := RunCheck(context.Background(), site, http.DefaultClient, insecureClient, false) if result.Status != "UP" { t.Errorf("expected UP, got %s", result.Status) @@ -133,7 +134,7 @@ func TestRunCheck_Port_Open(t *testing.T) { port, _ := strconv.Atoi(portStr) site := models.Site{ID: 1, Type: "port", Hostname: "127.0.0.1", Port: port, Timeout: 2} - result := RunCheck(site, nil, nil, false, true) + result := RunCheck(context.Background(), site, nil, nil, false, true) if result.Status != "UP" { t.Errorf("expected UP, got %s", result.Status) @@ -153,7 +154,7 @@ func TestRunCheck_Port_Closed(t *testing.T) { ln.Close() site := models.Site{ID: 1, Type: "port", Hostname: "127.0.0.1", Port: port, Timeout: 1} - result := RunCheck(site, nil, nil, false, true) + result := RunCheck(context.Background(), site, nil, nil, false, true) if result.Status != "DOWN" { t.Errorf("expected DOWN, got %s", result.Status) @@ -171,7 +172,7 @@ func TestRunCheck_Port_BlocksPrivateByDefault(t *testing.T) { port, _ := strconv.Atoi(portStr) site := models.Site{ID: 1, Type: "port", Hostname: "127.0.0.1", Port: port, Timeout: 2} - result := RunCheck(site, nil, nil, false) + result := RunCheck(context.Background(), site, nil, nil, false) if result.Status != "DOWN" { t.Errorf("expected DOWN when private targets blocked, got %s", result.Status) @@ -180,7 +181,7 @@ func TestRunCheck_Port_BlocksPrivateByDefault(t *testing.T) { func TestRunCheck_UnknownType(t *testing.T) { site := models.Site{ID: 1, Type: "invalid"} - result := RunCheck(site, nil, nil, false) + result := RunCheck(context.Background(), site, nil, nil, false) if result.Status != "DOWN" { t.Errorf("expected DOWN for unknown type, got %s", result.Status) diff --git a/internal/monitor/dbwriter.go b/internal/monitor/dbwriter.go index ebdbb89..e97f622 100644 --- a/internal/monitor/dbwriter.go +++ b/internal/monitor/dbwriter.go @@ -1,6 +1,8 @@ package monitor import ( + "context" + "gitea.lerkolabs.com/lerkolabs/uptop/internal/models" "gitea.lerkolabs.com/lerkolabs/uptop/internal/store" ) @@ -10,14 +12,14 @@ import ( // serializing all writes through one connection and surfacing errors instead of // discarding them. desc names the write for diagnostics on drop/failure. type dbWrite interface { - exec(s store.Store) error + exec(ctx context.Context, s store.Store) error desc() string } type writeLog struct{ message string } -func (w writeLog) exec(s store.Store) error { return s.SaveLog(w.message) } -func (w writeLog) desc() string { return "log" } +func (w writeLog) exec(ctx context.Context, s store.Store) error { return s.SaveLog(ctx, w.message) } +func (w writeLog) desc() string { return "log" } type writeCheck struct { siteID int @@ -25,8 +27,10 @@ type writeCheck struct { isUp bool } -func (w writeCheck) exec(s store.Store) error { return s.SaveCheck(w.siteID, w.latencyNs, w.isUp) } -func (w writeCheck) desc() string { return "check" } +func (w writeCheck) exec(ctx context.Context, s store.Store) error { + return s.SaveCheck(ctx, w.siteID, w.latencyNs, w.isUp) +} +func (w writeCheck) desc() string { return "check" } type writeStateChange struct { siteID int @@ -35,15 +39,17 @@ type writeStateChange struct { reason string } -func (w writeStateChange) exec(s store.Store) error { - return s.SaveStateChange(w.siteID, w.fromStatus, w.toStatus, w.reason) +func (w writeStateChange) exec(ctx context.Context, s store.Store) error { + return s.SaveStateChange(ctx, w.siteID, w.fromStatus, w.toStatus, w.reason) } func (w writeStateChange) desc() string { return "state-change" } type writeAlertHealth struct{ rec models.AlertHealthRecord } -func (w writeAlertHealth) exec(s store.Store) error { return s.SaveAlertHealth(w.rec) } -func (w writeAlertHealth) desc() string { return "alert-health" } +func (w writeAlertHealth) exec(ctx context.Context, s store.Store) error { + return s.SaveAlertHealth(ctx, w.rec) +} +func (w writeAlertHealth) desc() string { return "alert-health" } type writeProbeCheck struct { siteID int @@ -52,7 +58,7 @@ type writeProbeCheck struct { isUp bool } -func (w writeProbeCheck) exec(s store.Store) error { - return s.SaveCheckFromNode(w.siteID, w.nodeID, w.latencyNs, w.isUp) +func (w writeProbeCheck) exec(ctx context.Context, s store.Store) error { + return s.SaveCheckFromNode(ctx, w.siteID, w.nodeID, w.latencyNs, w.isUp) } func (w writeProbeCheck) desc() string { return "probe-check" } diff --git a/internal/monitor/history.go b/internal/monitor/history.go index ec0e294..2a88154 100644 --- a/internal/monitor/history.go +++ b/internal/monitor/history.go @@ -1,6 +1,9 @@ package monitor -import "time" +import ( + "context" + "time" +) const maxHistoryLen = 60 @@ -12,7 +15,7 @@ type SiteHistory struct { } func (e *Engine) InitHistory() { - all, err := e.db.LoadAllHistory(maxHistoryLen) + all, err := e.db.LoadAllHistory(context.Background(), maxHistoryLen) if err != nil { e.AddLog("Failed to load check history: " + err.Error()) return diff --git a/internal/monitor/monitor.go b/internal/monitor/monitor.go index f48ccdf..c9ae786 100644 --- a/internal/monitor/monitor.go +++ b/internal/monitor/monitor.go @@ -185,16 +185,16 @@ func (e *Engine) dbWriter(ctx context.Context) { pruneTicker := time.NewTicker(dbPruneInterval) defer pruneTicker.Stop() - e.prune() + e.prune(ctx) for { select { case w := <-e.dbWrites: - if err := w.exec(e.db); err != nil { + if err := w.exec(ctx, e.db); err != nil { e.appendLog(fmt.Sprintf("db %s write failed: %v", w.desc(), err)) } case <-pruneTicker.C: - e.prune() + e.prune(ctx) case <-ctx.Done(): e.drainWrites() return @@ -207,7 +207,7 @@ func (e *Engine) drainWrites() { for { select { case w := <-e.dbWrites: - if err := w.exec(e.db); err != nil { + if err := w.exec(context.Background(), e.db); err != nil { e.appendLog(fmt.Sprintf("db %s write failed (drain): %v", w.desc(), err)) } default: @@ -216,14 +216,14 @@ func (e *Engine) drainWrites() { } } -func (e *Engine) prune() { - if err := e.db.PruneLogs(); err != nil { +func (e *Engine) prune(ctx context.Context) { + if err := e.db.PruneLogs(ctx); err != nil { e.appendLog(fmt.Sprintf("log prune failed: %v", err)) } - if err := e.db.PruneCheckHistory(); err != nil { + if err := e.db.PruneCheckHistory(ctx); err != nil { e.appendLog(fmt.Sprintf("check-history prune failed: %v", err)) } - if err := e.db.PruneStateChanges(); err != nil { + if err := e.db.PruneStateChanges(ctx); err != nil { e.appendLog(fmt.Sprintf("state-change prune failed: %v", err)) } } @@ -242,7 +242,7 @@ func (e *Engine) Stop() { } func (e *Engine) InitLogs() { - logs, err := e.db.LoadLogs(maxLogEntries) + logs, err := e.db.LoadLogs(context.Background(), maxLogEntries) if err != nil { return } @@ -257,7 +257,7 @@ func (e *Engine) InitLogs() { // InitAlertHealth restores persisted alert send health so the dashboard shows real // "last sent" / health state on startup instead of resetting every channel to "never". func (e *Engine) InitAlertHealth() { - records, err := e.db.LoadAlertHealth() + records, err := e.db.LoadAlertHealth(context.Background()) if err != nil { return } @@ -416,9 +416,9 @@ func (e *Engine) Start(ctx context.Context) { default: } - e.refreshMaintenanceCache() + e.refreshMaintenanceCache(ctx) - sites, err := e.db.GetSites() + sites, err := e.db.GetSites(ctx) if err != nil { e.AddLog(fmt.Sprintf("Failed to load sites: %v", err)) select { @@ -475,20 +475,20 @@ func (e *Engine) maintenancePruner(ctx context.Context) { ticker := time.NewTicker(maintPruneInterval) defer ticker.Stop() - e.pruneMaintenanceWindows() + e.pruneMaintenanceWindows(ctx) for { select { case <-ticker.C: - e.pruneMaintenanceWindows() + e.pruneMaintenanceWindows(ctx) case <-ctx.Done(): return } } } -func (e *Engine) pruneMaintenanceWindows() { - pruned, err := e.db.PruneExpiredMaintenanceWindows(e.maintRetention) +func (e *Engine) pruneMaintenanceWindows(ctx context.Context) { + pruned, err := e.db.PruneExpiredMaintenanceWindows(ctx, e.maintRetention) if err != nil { e.AddLog(fmt.Sprintf("Maintenance prune error: %v", err)) return @@ -588,7 +588,7 @@ func (e *Engine) monitorRoutine(ctx context.Context, id int) { return } - e.checkByID(id) + e.checkByID(ctx, id) for { select { case <-ctx.Done(): @@ -634,7 +634,7 @@ func (e *Engine) monitorRoutine(ctx context.Context, id int) { return case <-recheckCh: } - e.checkByID(id) + e.checkByID(ctx, id) } } @@ -657,7 +657,7 @@ func (e *Engine) applyState(id int, mutate func(s *models.Site)) (models.Site, b return cur, true } -func (e *Engine) checkByID(id int) { +func (e *Engine) checkByID(ctx context.Context, id int) { if !e.IsActive() { return } @@ -671,11 +671,11 @@ func (e *Engine) checkByID(id int) { switch site.Type { case "push": - e.checkPush(site) + e.checkPush(ctx, site) case "group": - e.checkGroup(site) + e.checkGroup(ctx, site) default: - result := RunCheck(site, e.strictClient, e.insecureClient, e.insecureSkipVerify, e.allowPrivateTargets) + result := RunCheck(ctx, site, e.strictClient, e.insecureClient, e.insecureSkipVerify, e.allowPrivateTargets) updatedSite := site updatedSite.HasSSL = result.HasSSL updatedSite.CertExpiry = result.CertExpiry @@ -685,7 +685,7 @@ func (e *Engine) checkByID(id int) { } } -func (e *Engine) checkPush(site models.Site) { +func (e *Engine) checkPush(_ context.Context, site models.Site) { if site.Status == "PENDING" { return } @@ -875,7 +875,7 @@ func (e *Engine) handleStatusChange(snap models.Site, rawStatus string, code int } func (e *Engine) triggerAlert(alertID int, title, message string) { - cfg, err := e.db.GetAlert(alertID) + cfg, err := e.db.GetAlert(context.Background(), alertID) if err != nil { e.AddLog(fmt.Sprintf("Failed to load alert config %d: %v", alertID, err)) return @@ -928,7 +928,7 @@ func (e *Engine) GetAlertHealth(alertID int) AlertHealth { } func (e *Engine) TestAlert(alertID int) error { - cfg, err := e.db.GetAlert(alertID) + cfg, err := e.db.GetAlert(context.Background(), alertID) if err != nil { return fmt.Errorf("failed to load alert: %w", err) } @@ -954,8 +954,8 @@ func (e *Engine) isInMaintenance(monitorID int) bool { return e.maintCache[monitorID] } -func (e *Engine) refreshMaintenanceCache() { - windows, err := e.db.GetActiveMaintenanceWindows() +func (e *Engine) refreshMaintenanceCache(ctx context.Context) { + windows, err := e.db.GetActiveMaintenanceWindows(ctx) if err != nil { return } @@ -994,7 +994,7 @@ func (e *Engine) GetDisplayStatus(site models.Site) string { return site.Status } -func (e *Engine) checkGroup(site models.Site) { +func (e *Engine) checkGroup(_ context.Context, site models.Site) { e.mu.RLock() status := "UP" hasChildren := false @@ -1095,7 +1095,7 @@ func (e *Engine) GetProbeResults(siteID int) map[string]NodeResult { } func (e *Engine) GetStateChanges(siteID int, limit int) []models.StateChange { - changes, err := e.db.GetStateChanges(siteID, limit) + changes, err := e.db.GetStateChanges(context.Background(), siteID, limit) if err != nil { return nil } @@ -1103,7 +1103,7 @@ func (e *Engine) GetStateChanges(siteID int, limit int) []models.StateChange { } func (e *Engine) GetStateChangesSince(siteID int, since time.Time) []models.StateChange { - changes, err := e.db.GetStateChangesSince(siteID, since) + changes, err := e.db.GetStateChangesSince(context.Background(), siteID, since) if err != nil { return nil } diff --git a/internal/monitor/monitor_test.go b/internal/monitor/monitor_test.go index 1a8d35d..80c5936 100644 --- a/internal/monitor/monitor_test.go +++ b/internal/monitor/monitor_test.go @@ -38,37 +38,43 @@ func newMockStore() *mockStore { } } -func (m *mockStore) Init() error { return nil } -func (m *mockStore) GetSites() ([]models.Site, error) { return m.sites, nil } -func (m *mockStore) AddSite(models.Site) error { return nil } -func (m *mockStore) UpdateSite(models.Site) error { return nil } -func (m *mockStore) UpdateSitePaused(int, bool) error { return nil } -func (m *mockStore) DeleteSite(int) error { return nil } -func (m *mockStore) AddAlert(string, string, map[string]string) error { return nil } -func (m *mockStore) UpdateAlert(int, string, string, map[string]string) error { return nil } -func (m *mockStore) DeleteAlert(int) error { return nil } -func (m *mockStore) GetAllUsers() ([]models.User, error) { return nil, nil } -func (m *mockStore) AddUser(string, string, string) error { return nil } -func (m *mockStore) UpdateUser(int, string, string, string) error { return nil } -func (m *mockStore) DeleteUser(int) error { return nil } -func (m *mockStore) ExportData() (models.Backup, error) { return models.Backup{}, nil } -func (m *mockStore) ImportData(models.Backup) error { return nil } -func (m *mockStore) GetSiteByName(string) (models.Site, error) { return models.Site{}, nil } -func (m *mockStore) AddSiteReturningID(models.Site) (int, error) { return 0, nil } -func (m *mockStore) AddAlertReturningID(string, string, map[string]string) (int, error) { +func (m *mockStore) Init(context.Context) error { return nil } +func (m *mockStore) GetSites(context.Context) ([]models.Site, error) { return m.sites, nil } +func (m *mockStore) AddSite(context.Context, models.Site) error { return nil } +func (m *mockStore) UpdateSite(context.Context, models.Site) error { return nil } +func (m *mockStore) UpdateSitePaused(context.Context, int, bool) error { return nil } +func (m *mockStore) DeleteSite(context.Context, int) error { return nil } +func (m *mockStore) AddAlert(context.Context, string, string, map[string]string) error { return nil } +func (m *mockStore) UpdateAlert(context.Context, int, string, string, map[string]string) error { + return nil +} +func (m *mockStore) DeleteAlert(context.Context, int) error { return nil } +func (m *mockStore) GetAllUsers(context.Context) ([]models.User, error) { return nil, nil } +func (m *mockStore) AddUser(context.Context, string, string, string) error { return nil } +func (m *mockStore) UpdateUser(context.Context, int, string, string, string) error { return nil } +func (m *mockStore) DeleteUser(context.Context, int) error { return nil } +func (m *mockStore) ExportData(context.Context) (models.Backup, error) { return models.Backup{}, nil } +func (m *mockStore) ImportData(context.Context, models.Backup) error { return nil } +func (m *mockStore) GetSiteByName(context.Context, string) (models.Site, error) { + return models.Site{}, nil +} +func (m *mockStore) AddSiteReturningID(context.Context, models.Site) (int, error) { return 0, nil } +func (m *mockStore) AddAlertReturningID(context.Context, string, string, map[string]string) (int, error) { return 0, nil } -func (m *mockStore) SaveCheckFromNode(int, string, int64, bool) error { return nil } -func (m *mockStore) RegisterNode(models.ProbeNode) error { return nil } -func (m *mockStore) GetNode(string) (models.ProbeNode, error) { return models.ProbeNode{}, nil } -func (m *mockStore) GetAllNodes() ([]models.ProbeNode, error) { return nil, nil } -func (m *mockStore) UpdateNodeLastSeen(string) error { return nil } -func (m *mockStore) DeleteNode(string) error { return nil } -func (m *mockStore) LoadAlertHealth() (map[int]models.AlertHealthRecord, error) { +func (m *mockStore) SaveCheckFromNode(context.Context, int, string, int64, bool) error { return nil } +func (m *mockStore) RegisterNode(context.Context, models.ProbeNode) error { return nil } +func (m *mockStore) GetNode(context.Context, string) (models.ProbeNode, error) { + return models.ProbeNode{}, nil +} +func (m *mockStore) GetAllNodes(context.Context) ([]models.ProbeNode, error) { return nil, nil } +func (m *mockStore) UpdateNodeLastSeen(context.Context, string) error { return nil } +func (m *mockStore) DeleteNode(context.Context, string) error { return nil } +func (m *mockStore) LoadAlertHealth(context.Context) (map[int]models.AlertHealthRecord, error) { return nil, nil } -func (m *mockStore) SaveAlertHealth(models.AlertHealthRecord) error { return nil } -func (m *mockStore) GetActiveMaintenanceWindows() ([]models.MaintenanceWindow, error) { +func (m *mockStore) SaveAlertHealth(context.Context, models.AlertHealthRecord) error { return nil } +func (m *mockStore) GetActiveMaintenanceWindows(context.Context) ([]models.MaintenanceWindow, error) { m.mu.Lock() defer m.mu.Unlock() var windows []models.MaintenanceWindow @@ -77,23 +83,27 @@ func (m *mockStore) GetActiveMaintenanceWindows() ([]models.MaintenanceWindow, e } return windows, nil } -func (m *mockStore) GetAllMaintenanceWindows(int) ([]models.MaintenanceWindow, error) { +func (m *mockStore) GetAllMaintenanceWindows(context.Context, int) ([]models.MaintenanceWindow, error) { return nil, nil } -func (m *mockStore) AddMaintenanceWindow(models.MaintenanceWindow) error { return nil } -func (m *mockStore) EndMaintenanceWindow(int) error { return nil } -func (m *mockStore) DeleteMaintenanceWindow(int) error { return nil } -func (m *mockStore) PruneExpiredMaintenanceWindows(time.Duration) (int64, error) { return 0, nil } -func (m *mockStore) GetPreference(string) (string, error) { return "", nil } -func (m *mockStore) SetPreference(string, string) error { return nil } -func (m *mockStore) SaveStateChange(int, string, string, string) error { return nil } -func (m *mockStore) GetStateChanges(int, int) ([]models.StateChange, error) { return nil, nil } -func (m *mockStore) GetStateChangesSince(int, time.Time) ([]models.StateChange, error) { +func (m *mockStore) AddMaintenanceWindow(context.Context, models.MaintenanceWindow) error { return nil } +func (m *mockStore) EndMaintenanceWindow(context.Context, int) error { return nil } +func (m *mockStore) DeleteMaintenanceWindow(context.Context, int) error { return nil } +func (m *mockStore) PruneExpiredMaintenanceWindows(context.Context, time.Duration) (int64, error) { + return 0, nil +} +func (m *mockStore) GetPreference(context.Context, string) (string, error) { return "", nil } +func (m *mockStore) SetPreference(context.Context, string, string) error { return nil } +func (m *mockStore) SaveStateChange(context.Context, int, string, string, string) error { return nil } +func (m *mockStore) GetStateChanges(context.Context, int, int) ([]models.StateChange, error) { + return nil, nil +} +func (m *mockStore) GetStateChangesSince(context.Context, int, time.Time) ([]models.StateChange, error) { return nil, nil } func (m *mockStore) Close() error { return nil } -func (m *mockStore) GetAllAlerts() ([]models.AlertConfig, error) { +func (m *mockStore) GetAllAlerts(context.Context) ([]models.AlertConfig, error) { m.mu.Lock() defer m.mu.Unlock() var result []models.AlertConfig @@ -103,7 +113,7 @@ func (m *mockStore) GetAllAlerts() ([]models.AlertConfig, error) { return result, nil } -func (m *mockStore) GetAlert(id int) (models.AlertConfig, error) { +func (m *mockStore) GetAlert(_ context.Context, id int) (models.AlertConfig, error) { m.mu.Lock() defer m.mu.Unlock() m.getAlertCalls = append(m.getAlertCalls, id) @@ -113,7 +123,7 @@ func (m *mockStore) GetAlert(id int) (models.AlertConfig, error) { return models.AlertConfig{}, fmt.Errorf("alert %d not found", id) } -func (m *mockStore) GetAlertByName(name string) (models.AlertConfig, error) { +func (m *mockStore) GetAlertByName(_ context.Context, name string) (models.AlertConfig, error) { m.mu.Lock() defer m.mu.Unlock() for _, a := range m.alerts { @@ -124,37 +134,37 @@ func (m *mockStore) GetAlertByName(name string) (models.AlertConfig, error) { return models.AlertConfig{}, fmt.Errorf("alert %q not found", name) } -func (m *mockStore) IsMonitorInMaintenance(id int) (bool, error) { +func (m *mockStore) IsMonitorInMaintenance(_ context.Context, id int) (bool, error) { m.mu.Lock() defer m.mu.Unlock() return m.maintenance[id], nil } -func (m *mockStore) SaveCheck(siteID int, latencyNs int64, isUp bool) error { +func (m *mockStore) SaveCheck(_ context.Context, siteID int, latencyNs int64, isUp bool) error { m.mu.Lock() defer m.mu.Unlock() m.savedChecks = append(m.savedChecks, savedCheck{siteID, latencyNs, isUp}) return nil } -func (m *mockStore) SaveLog(msg string) error { +func (m *mockStore) SaveLog(_ context.Context, msg string) error { m.mu.Lock() defer m.mu.Unlock() m.savedLogs = append(m.savedLogs, msg) return nil } -func (m *mockStore) LoadLogs(limit int) ([]string, error) { +func (m *mockStore) LoadLogs(_ context.Context, limit int) ([]string, error) { return m.logs, nil } -func (m *mockStore) LoadAllHistory(limit int) (map[int][]models.CheckRecord, error) { +func (m *mockStore) LoadAllHistory(_ context.Context, limit int) (map[int][]models.CheckRecord, error) { return m.history, nil } -func (m *mockStore) PruneLogs() error { return nil } -func (m *mockStore) PruneCheckHistory() error { return nil } -func (m *mockStore) PruneStateChanges() error { return nil } +func (m *mockStore) PruneLogs(context.Context) error { return nil } +func (m *mockStore) PruneCheckHistory(context.Context) error { return nil } +func (m *mockStore) PruneStateChanges(context.Context) error { return nil } // --- Helpers --- @@ -336,7 +346,7 @@ func TestHandleStatusChange_AlertSuppressedMaintenance(t *testing.T) { e := newTestEngine(ms) site := models.Site{ID: 1, Name: "test", Status: "UP", MaxRetries: 0, AlertID: 1} injectSite(e, site) - e.refreshMaintenanceCache() + e.refreshMaintenanceCache(context.Background()) e.handleStatusChange(site, "DOWN", 0, 0, "test error") @@ -368,7 +378,7 @@ func TestHandleStatusChange_RecoverySuppressedMaintenance(t *testing.T) { e := newTestEngine(ms) site := models.Site{ID: 1, Name: "test", Status: "DOWN", AlertID: 1} injectSite(e, site) - e.refreshMaintenanceCache() + e.refreshMaintenanceCache(context.Background()) e.handleStatusChange(site, "UP", 200, 0, "") @@ -456,7 +466,7 @@ func TestHandleStatusChange_SSLWarningSuppressedMaint(t *testing.T) { CertExpiry: time.Now().Add(15 * 24 * time.Hour), } injectSite(e, site) - e.refreshMaintenanceCache() + e.refreshMaintenanceCache(context.Background()) e.handleStatusChange(site, "UP", 200, 0, "") @@ -563,7 +573,7 @@ func TestCheckPush_DeadlineMissed(t *testing.T) { } injectSite(e, site) - e.checkPush(site) + e.checkPush(context.Background(), site) s, _ := getSite(e, 1) if s.Status != "DOWN" { @@ -581,7 +591,7 @@ func TestCheckPush_OverdueBecomesLate(t *testing.T) { } injectSite(e, site) - e.checkPush(site) + e.checkPush(context.Background(), site) s, _ := getSite(e, 1) if s.Status != "LATE" { @@ -601,7 +611,7 @@ func TestCheckPush_OverdueBecomesStale(t *testing.T) { } injectSite(e, site) - e.checkPush(site) + e.checkPush(context.Background(), site) s, _ := getSite(e, 1) if s.Status != "STALE" { @@ -618,7 +628,7 @@ func TestCheckPush_WithinDeadline(t *testing.T) { } injectSite(e, site) - e.checkPush(site) + e.checkPush(context.Background(), site) s, _ := getSite(e, 1) if s.Status != "UP" { @@ -635,7 +645,7 @@ func TestCheckPush_PendingStaysPending(t *testing.T) { } injectSite(e, site) - e.checkPush(site) + e.checkPush(context.Background(), site) s, _ := getSite(e, 1) if s.Status != "PENDING" { @@ -655,7 +665,7 @@ func TestCheckGroup_AllChildrenUp(t *testing.T) { injectSite(e, child1) injectSite(e, child2) - e.checkGroup(group) + e.checkGroup(context.Background(), group) s, _ := getSite(e, 1) if s.Status != "UP" { @@ -673,7 +683,7 @@ func TestCheckGroup_OneChildDown(t *testing.T) { injectSite(e, child1) injectSite(e, child2) - e.checkGroup(group) + e.checkGroup(context.Background(), group) s, _ := getSite(e, 1) if s.Status != "DOWN" { @@ -691,7 +701,7 @@ func TestCheckGroup_PausedChildIgnored(t *testing.T) { injectSite(e, child1) injectSite(e, child2) - e.checkGroup(group) + e.checkGroup(context.Background(), group) s, _ := getSite(e, 1) if s.Status != "UP" { @@ -709,9 +719,9 @@ func TestCheckGroup_MaintenanceChildIgnored(t *testing.T) { injectSite(e, group) injectSite(e, child1) injectSite(e, child2) - e.refreshMaintenanceCache() + e.refreshMaintenanceCache(context.Background()) - e.checkGroup(group) + e.checkGroup(context.Background(), group) s, _ := getSite(e, 1) if s.Status != "UP" { @@ -725,7 +735,7 @@ func TestCheckGroup_NoChildren(t *testing.T) { group := models.Site{ID: 1, Name: "group", Type: "group", Status: "UP"} injectSite(e, group) - e.checkGroup(group) + e.checkGroup(context.Background(), group) s, _ := getSite(e, 1) if s.Status != "PENDING" { @@ -1241,7 +1251,7 @@ func TestCheckGroup_AllPausedNoAutoFreeze(t *testing.T) { injectSite(e, child1) injectSite(e, child2) - e.checkGroup(group) + e.checkGroup(context.Background(), group) s, _ := getSite(e, 1) if s.Paused { @@ -1361,7 +1371,7 @@ func TestIsInMaintenance_UsesCache(t *testing.T) { child := models.Site{ID: 20, Name: "child", Type: "http", ParentID: 10, Status: "UP"} injectSite(e, group) injectSite(e, child) - e.refreshMaintenanceCache() + e.refreshMaintenanceCache(context.Background()) if !e.isInMaintenance(10) { t.Error("group should be in maintenance (direct)") @@ -1381,7 +1391,7 @@ func TestIsInMaintenance_GlobalMaintenance(t *testing.T) { e := newTestEngine(ms) site := models.Site{ID: 1, Name: "test", Type: "http", Status: "UP"} injectSite(e, site) - e.refreshMaintenanceCache() + e.refreshMaintenanceCache(context.Background()) if !e.isInMaintenance(1) { t.Error("all monitors should be in maintenance during global window") diff --git a/internal/server/server.go b/internal/server/server.go index f791892..6538e41 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -255,7 +255,7 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) *http.Server { http.Error(w, "Unauthorized: UPTOP_CLUSTER_SECRET required", http.StatusUnauthorized) return } - data, err := s.ExportData() + data, err := s.ExportData(r.Context()) if err != nil { log.Printf("Export failed: %v", err) http.Error(w, "Export failed", http.StatusInternalServerError) @@ -285,7 +285,7 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) *http.Server { http.Error(w, "Invalid JSON", http.StatusBadRequest) return } - if err := s.ImportData(data); err != nil { + if err := s.ImportData(r.Context(), data); err != nil { log.Printf("Import failed: %v", err) http.Error(w, "Import failed", http.StatusInternalServerError) return @@ -311,7 +311,7 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) *http.Server { return } backup := importer.ConvertKuma(&kb) - if err := s.ImportData(backup); err != nil { + if err := s.ImportData(r.Context(), backup); err != nil { log.Printf("Kuma import failed: %v", err) http.Error(w, "Import failed", http.StatusInternalServerError) return @@ -344,7 +344,7 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) *http.Server { http.Error(w, "id is required", http.StatusBadRequest) return } - if err := s.RegisterNode(models.ProbeNode{ + if err := s.RegisterNode(r.Context(), models.ProbeNode{ ID: req.ID, Name: req.Name, Region: req.Region, Version: req.Version, }); err != nil { log.Printf("Probe register failed: %v", err) @@ -367,7 +367,7 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) *http.Server { nodeID := r.URL.Query().Get("node_id") var nodeRegion string if nodeID != "" { - if node, err := s.GetNode(nodeID); err == nil { + if node, err := s.GetNode(r.Context(), nodeID); err == nil { nodeRegion = node.Region } } @@ -427,7 +427,7 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) *http.Server { eng.EnqueueProbeCheck(result.SiteID, req.NodeID, result.LatencyNs, result.IsUp) eng.IngestProbeResult(req.NodeID, result.SiteID, result.LatencyNs, result.IsUp, result.ErrorReason) } - if err := s.UpdateNodeLastSeen(req.NodeID); err != nil { + if err := s.UpdateNodeLastSeen(r.Context(), req.NodeID); err != nil { log.Printf("Failed to update node last seen: %v", err) } _ = json.NewEncoder(w).Encode(map[string]bool{"ok": true}) //nolint:errcheck @@ -453,7 +453,7 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) *http.Server { mux.HandleFunc("/status", RateLimit(statusRL, func(w http.ResponseWriter, r *http.Request) { renderStatusPage(w, cfg.Title, eng) })) mux.HandleFunc("/status/json", RateLimit(statusRL, func(w http.ResponseWriter, r *http.Request) { state := eng.GetLiveState() - activeWindows, _ := s.GetActiveMaintenanceWindows() + activeWindows, _ := s.GetActiveMaintenanceWindows(r.Context()) maintSet := make(map[int]bool) allInMaint := false for _, mw := range activeWindows { diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 4609096..ad14ef9 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -33,80 +33,100 @@ func newMockStore() *mockStore { } } -func (m *mockStore) Init() error { return nil } -func (m *mockStore) GetSites() ([]models.Site, error) { return m.sites, nil } -func (m *mockStore) AddSite(models.Site) error { return nil } -func (m *mockStore) UpdateSite(models.Site) error { return nil } -func (m *mockStore) UpdateSitePaused(int, bool) error { return nil } -func (m *mockStore) DeleteSite(int) error { return nil } -func (m *mockStore) GetAllAlerts() ([]models.AlertConfig, error) { return m.alerts, nil } -func (m *mockStore) GetAlert(int) (models.AlertConfig, error) { return models.AlertConfig{}, nil } -func (m *mockStore) AddAlert(string, string, map[string]string) error { return nil } -func (m *mockStore) UpdateAlert(int, string, string, map[string]string) error { return nil } -func (m *mockStore) DeleteAlert(int) error { return nil } -func (m *mockStore) GetAllUsers() ([]models.User, error) { return nil, nil } -func (m *mockStore) AddUser(string, string, string) error { return nil } -func (m *mockStore) UpdateUser(int, string, string, string) error { return nil } -func (m *mockStore) DeleteUser(int) error { return nil } -func (m *mockStore) SaveCheck(int, int64, bool) error { return nil } -func (m *mockStore) SaveCheckFromNode(siteID int, nodeID string, latencyNs int64, isUp bool) error { - return nil +func (m *mockStore) Init(_ context.Context) error { return nil } +func (m *mockStore) GetSites(_ context.Context) ([]models.Site, error) { return m.sites, nil } +func (m *mockStore) AddSite(_ context.Context, _ models.Site) error { return nil } +func (m *mockStore) UpdateSite(_ context.Context, _ models.Site) error { return nil } +func (m *mockStore) UpdateSitePaused(_ context.Context, _ int, _ bool) error { return nil } +func (m *mockStore) DeleteSite(_ context.Context, _ int) error { return nil } +func (m *mockStore) GetAllAlerts(_ context.Context) ([]models.AlertConfig, error) { + return m.alerts, nil } -func (m *mockStore) LoadAllHistory(int) (map[int][]models.CheckRecord, error) { - return nil, nil -} -func (m *mockStore) GetSiteByName(string) (models.Site, error) { return models.Site{}, nil } -func (m *mockStore) GetAlertByName(string) (models.AlertConfig, error) { +func (m *mockStore) GetAlert(_ context.Context, _ int) (models.AlertConfig, error) { return models.AlertConfig{}, nil } -func (m *mockStore) AddSiteReturningID(models.Site) (int, error) { return 0, nil } -func (m *mockStore) AddAlertReturningID(string, string, map[string]string) (int, error) { +func (m *mockStore) AddAlert(_ context.Context, _ string, _ string, _ map[string]string) error { + return nil +} +func (m *mockStore) UpdateAlert(_ context.Context, _ int, _ string, _ string, _ map[string]string) error { + return nil +} +func (m *mockStore) DeleteAlert(_ context.Context, _ int) error { return nil } +func (m *mockStore) GetAllUsers(_ context.Context) ([]models.User, error) { return nil, nil } +func (m *mockStore) AddUser(_ context.Context, _ string, _ string, _ string) error { return nil } +func (m *mockStore) UpdateUser(_ context.Context, _ int, _ string, _ string, _ string) error { + return nil +} +func (m *mockStore) DeleteUser(_ context.Context, _ int) error { return nil } +func (m *mockStore) SaveCheck(_ context.Context, _ int, _ int64, _ bool) error { return nil } +func (m *mockStore) SaveCheckFromNode(_ context.Context, siteID int, nodeID string, latencyNs int64, isUp bool) error { + return nil +} +func (m *mockStore) LoadAllHistory(_ context.Context, _ int) (map[int][]models.CheckRecord, error) { + return nil, nil +} +func (m *mockStore) GetSiteByName(_ context.Context, _ string) (models.Site, error) { + return models.Site{}, nil +} +func (m *mockStore) GetAlertByName(_ context.Context, _ string) (models.AlertConfig, error) { + return models.AlertConfig{}, nil +} +func (m *mockStore) AddSiteReturningID(_ context.Context, _ models.Site) (int, error) { return 0, nil } +func (m *mockStore) AddAlertReturningID(_ context.Context, _ string, _ string, _ map[string]string) (int, error) { return 0, nil } -func (m *mockStore) GetAllNodes() ([]models.ProbeNode, error) { return nil, nil } -func (m *mockStore) UpdateNodeLastSeen(string) error { return nil } -func (m *mockStore) DeleteNode(string) error { return nil } -func (m *mockStore) LoadAlertHealth() (map[int]models.AlertHealthRecord, error) { +func (m *mockStore) GetAllNodes(_ context.Context) ([]models.ProbeNode, error) { return nil, nil } +func (m *mockStore) UpdateNodeLastSeen(_ context.Context, _ string) error { return nil } +func (m *mockStore) DeleteNode(_ context.Context, _ string) error { return nil } +func (m *mockStore) LoadAlertHealth(_ context.Context) (map[int]models.AlertHealthRecord, error) { return nil, nil } -func (m *mockStore) SaveAlertHealth(models.AlertHealthRecord) error { return nil } -func (m *mockStore) SaveLog(string) error { return nil } -func (m *mockStore) PruneLogs() error { return nil } -func (m *mockStore) PruneCheckHistory() error { return nil } -func (m *mockStore) PruneStateChanges() error { return nil } -func (m *mockStore) LoadLogs(int) ([]string, error) { return nil, nil } -func (m *mockStore) GetAllMaintenanceWindows(int) ([]models.MaintenanceWindow, error) { +func (m *mockStore) SaveAlertHealth(_ context.Context, _ models.AlertHealthRecord) error { return nil } +func (m *mockStore) SaveLog(_ context.Context, _ string) error { return nil } +func (m *mockStore) PruneLogs(_ context.Context) error { return nil } +func (m *mockStore) PruneCheckHistory(_ context.Context) error { return nil } +func (m *mockStore) PruneStateChanges(_ context.Context) error { return nil } +func (m *mockStore) LoadLogs(_ context.Context, _ int) ([]string, error) { return nil, nil } +func (m *mockStore) GetAllMaintenanceWindows(_ context.Context, _ int) ([]models.MaintenanceWindow, error) { return nil, nil } -func (m *mockStore) AddMaintenanceWindow(models.MaintenanceWindow) error { return nil } -func (m *mockStore) EndMaintenanceWindow(int) error { return nil } -func (m *mockStore) DeleteMaintenanceWindow(int) error { return nil } -func (m *mockStore) PruneExpiredMaintenanceWindows(time.Duration) (int64, error) { return 0, nil } -func (m *mockStore) IsMonitorInMaintenance(int) (bool, error) { return false, nil } -func (m *mockStore) GetPreference(string) (string, error) { return "", nil } -func (m *mockStore) SetPreference(string, string) error { return nil } -func (m *mockStore) SaveStateChange(int, string, string, string) error { return nil } -func (m *mockStore) GetStateChanges(int, int) ([]models.StateChange, error) { return nil, nil } -func (m *mockStore) GetStateChangesSince(int, time.Time) ([]models.StateChange, error) { +func (m *mockStore) AddMaintenanceWindow(_ context.Context, _ models.MaintenanceWindow) error { + return nil +} +func (m *mockStore) EndMaintenanceWindow(_ context.Context, _ int) error { return nil } +func (m *mockStore) DeleteMaintenanceWindow(_ context.Context, _ int) error { return nil } +func (m *mockStore) PruneExpiredMaintenanceWindows(_ context.Context, _ time.Duration) (int64, error) { + return 0, nil +} +func (m *mockStore) IsMonitorInMaintenance(_ context.Context, _ int) (bool, error) { return false, nil } +func (m *mockStore) GetPreference(_ context.Context, _ string) (string, error) { return "", nil } +func (m *mockStore) SetPreference(_ context.Context, _ string, _ string) error { return nil } +func (m *mockStore) SaveStateChange(_ context.Context, _ int, _ string, _ string, _ string) error { + return nil +} +func (m *mockStore) GetStateChanges(_ context.Context, _ int, _ int) ([]models.StateChange, error) { + return nil, nil +} +func (m *mockStore) GetStateChangesSince(_ context.Context, _ int, _ time.Time) ([]models.StateChange, error) { return nil, nil } func (m *mockStore) Close() error { return nil } -func (m *mockStore) ExportData() (models.Backup, error) { +func (m *mockStore) ExportData(_ context.Context) (models.Backup, error) { return models.Backup{ Sites: m.sites, Alerts: m.alerts, }, nil } -func (m *mockStore) ImportData(data models.Backup) error { +func (m *mockStore) ImportData(_ context.Context, data models.Backup) error { m.mu.Lock() defer m.mu.Unlock() m.importedData = &data return nil } -func (m *mockStore) RegisterNode(node models.ProbeNode) error { +func (m *mockStore) RegisterNode(_ context.Context, node models.ProbeNode) error { m.mu.Lock() defer m.mu.Unlock() m.registeredNodes = append(m.registeredNodes, node) @@ -114,7 +134,7 @@ func (m *mockStore) RegisterNode(node models.ProbeNode) error { return nil } -func (m *mockStore) GetNode(id string) (models.ProbeNode, error) { +func (m *mockStore) GetNode(_ context.Context, id string) (models.ProbeNode, error) { m.mu.Lock() defer m.mu.Unlock() if n, ok := m.nodes[id]; ok { @@ -123,7 +143,7 @@ func (m *mockStore) GetNode(id string) (models.ProbeNode, error) { return models.ProbeNode{}, fmt.Errorf("not found") } -func (m *mockStore) GetActiveMaintenanceWindows() ([]models.MaintenanceWindow, error) { +func (m *mockStore) GetActiveMaintenanceWindows(_ context.Context) ([]models.MaintenanceWindow, error) { return m.maintWindows, nil } diff --git a/internal/store/sqlstore.go b/internal/store/sqlstore.go index 1f0198f..05eff31 100644 --- a/internal/store/sqlstore.go +++ b/internal/store/sqlstore.go @@ -1,6 +1,7 @@ package store import ( + "context" "crypto/rand" "database/sql" "encoding/hex" @@ -73,14 +74,14 @@ func (s *SQLStore) Close() error { return s.db.Close() } -func (s *SQLStore) Init() error { +func (s *SQLStore) Init(ctx context.Context) error { for _, stmt := range s.dialect.CreateTablesSQL() { - if _, err := s.db.Exec(stmt); err != nil { + if _, err := s.db.ExecContext(ctx, stmt); err != nil { return err } } for _, m := range s.dialect.MigrationsSQL() { - if _, err := s.db.Exec(m); err != nil { + if _, err := s.db.ExecContext(ctx, m); err != nil { errMsg := err.Error() if strings.Contains(errMsg, "already exists") || strings.Contains(errMsg, "duplicate column") { continue @@ -91,13 +92,13 @@ func (s *SQLStore) Init() error { return nil } -func (s *SQLStore) GetSites() ([]models.Site, error) { +func (s *SQLStore) GetSites(ctx context.Context) ([]models.Site, error) { bf := s.dialect.BoolFalse() query := fmt.Sprintf( //nolint:gosec // bf is a dialect boolean literal, not user input "SELECT id, COALESCE(name, url), url, COALESCE(type, 'http'), COALESCE(token, ''), interval, alert_id, check_ssl, threshold, max_retries, COALESCE(hostname, ''), COALESCE(port, 0), COALESCE(timeout, 0), COALESCE(method, 'GET'), COALESCE(description, ''), COALESCE(parent_id, 0), COALESCE(accepted_codes, '200-299'), COALESCE(dns_resolve_type, ''), COALESCE(dns_server, ''), COALESCE(ignore_tls, %s), COALESCE(paused, %s), COALESCE(regions, '') FROM sites", bf, bf, ) - rows, err := s.db.Query(query) + rows, err := s.db.QueryContext(ctx, query) if err != nil { return nil, err } @@ -116,7 +117,7 @@ func (s *SQLStore) GetSites() ([]models.Site, error) { return sites, rows.Err() } -func (s *SQLStore) AddSite(site models.Site) error { +func (s *SQLStore) AddSite(ctx context.Context, site models.Site) error { token := "" if site.Type == "push" { var err error @@ -125,15 +126,15 @@ func (s *SQLStore) AddSite(site models.Site) error { return fmt.Errorf("generate push token: %w", err) } } - _, err := s.db.Exec(s.q("INSERT INTO sites (name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused, regions) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"), + _, err := s.db.ExecContext(ctx, s.q("INSERT INTO sites (name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused, regions) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"), site.Name, site.URL, site.Type, token, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries, site.Hostname, site.Port, site.Timeout, site.Method, site.Description, site.ParentID, site.AcceptedCodes, site.DNSResolveType, site.DNSServer, site.IgnoreTLS, site.Paused, site.Regions) return err } -func (s *SQLStore) UpdateSite(site models.Site) error { +func (s *SQLStore) UpdateSite(ctx context.Context, site models.Site) error { var existingToken string - _ = s.db.QueryRow(s.q("SELECT token FROM sites WHERE id=?"), site.ID).Scan(&existingToken) //nolint:errcheck + _ = s.db.QueryRowContext(ctx, s.q("SELECT token FROM sites WHERE id=?"), site.ID).Scan(&existingToken) //nolint:errcheck if site.Type == "push" && existingToken == "" { var err error existingToken, err = generateToken() @@ -141,19 +142,19 @@ func (s *SQLStore) UpdateSite(site models.Site) error { return fmt.Errorf("generate push token: %w", err) } } - _, err := s.db.Exec(s.q("UPDATE sites SET name=?, url=?, type=?, token=?, interval=?, alert_id=?, check_ssl=?, threshold=?, max_retries=?, hostname=?, port=?, timeout=?, method=?, description=?, parent_id=?, accepted_codes=?, dns_resolve_type=?, dns_server=?, ignore_tls=?, paused=?, regions=? WHERE id=?"), + _, err := s.db.ExecContext(ctx, s.q("UPDATE sites SET name=?, url=?, type=?, token=?, interval=?, alert_id=?, check_ssl=?, threshold=?, max_retries=?, hostname=?, port=?, timeout=?, method=?, description=?, parent_id=?, accepted_codes=?, dns_resolve_type=?, dns_server=?, ignore_tls=?, paused=?, regions=? WHERE id=?"), site.Name, site.URL, site.Type, existingToken, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries, site.Hostname, site.Port, site.Timeout, site.Method, site.Description, site.ParentID, site.AcceptedCodes, site.DNSResolveType, site.DNSServer, site.IgnoreTLS, site.Paused, site.Regions, site.ID) return err } -func (s *SQLStore) UpdateSitePaused(id int, paused bool) error { - _, err := s.db.Exec(s.q("UPDATE sites SET paused=? WHERE id=?"), paused, id) +func (s *SQLStore) UpdateSitePaused(ctx context.Context, id int, paused bool) error { + _, err := s.db.ExecContext(ctx, s.q("UPDATE sites SET paused=? WHERE id=?"), paused, id) return err } -func (s *SQLStore) DeleteSite(id int) error { - tx, err := s.db.Begin() +func (s *SQLStore) DeleteSite(ctx context.Context, id int) error { + tx, err := s.db.BeginTx(ctx, nil) if err != nil { return err } @@ -165,7 +166,7 @@ func (s *SQLStore) DeleteSite(id int) error { "DELETE FROM state_changes WHERE site_id = ?", "DELETE FROM sites WHERE id = ?", } { - if _, err := tx.Exec(s.q(q), id); err != nil { + if _, err := tx.ExecContext(ctx, s.q(q), id); err != nil { return err } } @@ -177,14 +178,14 @@ func (s *SQLStore) DeleteSite(id int) error { return nil } -func (s *SQLStore) GetSiteByName(name string) (models.Site, error) { +func (s *SQLStore) GetSiteByName(ctx context.Context, name string) (models.Site, error) { bf := s.dialect.BoolFalse() query := fmt.Sprintf( //nolint:gosec // bf is a dialect boolean literal, not user input "SELECT id, COALESCE(name, url), url, COALESCE(type, 'http'), COALESCE(token, ''), interval, alert_id, check_ssl, threshold, max_retries, COALESCE(hostname, ''), COALESCE(port, 0), COALESCE(timeout, 0), COALESCE(method, 'GET'), COALESCE(description, ''), COALESCE(parent_id, 0), COALESCE(accepted_codes, '200-299'), COALESCE(dns_resolve_type, ''), COALESCE(dns_server, ''), COALESCE(ignore_tls, %s), COALESCE(paused, %s), COALESCE(regions, '') FROM sites WHERE name = %s", bf, bf, s.q("?"), ) var st models.Site - err := s.db.QueryRow(query, name).Scan(&st.ID, &st.Name, &st.URL, &st.Type, &st.Token, &st.Interval, &st.AlertID, + err := s.db.QueryRowContext(ctx, query, name).Scan(&st.ID, &st.Name, &st.URL, &st.Type, &st.Token, &st.Interval, &st.AlertID, &st.CheckSSL, &st.ExpiryThreshold, &st.MaxRetries, &st.Hostname, &st.Port, &st.Timeout, &st.Method, &st.Description, &st.ParentID, &st.AcceptedCodes, &st.DNSResolveType, &st.DNSServer, &st.IgnoreTLS, &st.Paused, &st.Regions) @@ -211,10 +212,10 @@ func (s *SQLStore) marshalSettings(settings map[string]string) (string, error) { return s.encryptSettings(string(jsonBytes)) } -func (s *SQLStore) GetAlertByName(name string) (models.AlertConfig, error) { +func (s *SQLStore) GetAlertByName(ctx context.Context, name string) (models.AlertConfig, error) { var a models.AlertConfig var settingsRaw string - err := s.db.QueryRow(s.q("SELECT id, name, type, settings FROM alerts WHERE name = ?"), name).Scan(&a.ID, &a.Name, &a.Type, &settingsRaw) + err := s.db.QueryRowContext(ctx, s.q("SELECT id, name, type, settings FROM alerts WHERE name = ?"), name).Scan(&a.ID, &a.Name, &a.Type, &settingsRaw) if err != nil { return a, err } @@ -225,7 +226,7 @@ func (s *SQLStore) GetAlertByName(name string) (models.AlertConfig, error) { return a, nil } -func (s *SQLStore) AddSiteReturningID(site models.Site) (int, error) { +func (s *SQLStore) AddSiteReturningID(ctx context.Context, site models.Site) (int, error) { token := "" if site.Type == "push" { var err error @@ -236,12 +237,12 @@ func (s *SQLStore) AddSiteReturningID(site models.Site) (int, error) { } if s.dollar { var id int - err := s.db.QueryRow(s.q("INSERT INTO sites (name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused, regions) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) RETURNING id"), + err := s.db.QueryRowContext(ctx, s.q("INSERT INTO sites (name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused, regions) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) RETURNING id"), site.Name, site.URL, site.Type, token, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries, site.Hostname, site.Port, site.Timeout, site.Method, site.Description, site.ParentID, site.AcceptedCodes, site.DNSResolveType, site.DNSServer, site.IgnoreTLS, site.Paused, site.Regions).Scan(&id) return id, err } - result, err := s.db.Exec(s.q("INSERT INTO sites (name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused, regions) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"), + result, err := s.db.ExecContext(ctx, s.q("INSERT INTO sites (name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused, regions) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"), site.Name, site.URL, site.Type, token, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries, site.Hostname, site.Port, site.Timeout, site.Method, site.Description, site.ParentID, site.AcceptedCodes, site.DNSResolveType, site.DNSServer, site.IgnoreTLS, site.Paused, site.Regions) if err != nil { @@ -251,17 +252,17 @@ func (s *SQLStore) AddSiteReturningID(site models.Site) (int, error) { return int(id), err } -func (s *SQLStore) AddAlertReturningID(name, aType string, settings map[string]string) (int, error) { +func (s *SQLStore) AddAlertReturningID(ctx context.Context, name, aType string, settings map[string]string) (int, error) { stored, err := s.marshalSettings(settings) if err != nil { return 0, err } if s.dollar { var id int - err := s.db.QueryRow(s.q("INSERT INTO alerts (name, type, settings) VALUES (?, ?, ?) RETURNING id"), name, aType, stored).Scan(&id) + err := s.db.QueryRowContext(ctx, s.q("INSERT INTO alerts (name, type, settings) VALUES (?, ?, ?) RETURNING id"), name, aType, stored).Scan(&id) return id, err } - result, err := s.db.Exec(s.q("INSERT INTO alerts (name, type, settings) VALUES (?, ?, ?)"), name, aType, stored) + result, err := s.db.ExecContext(ctx, s.q("INSERT INTO alerts (name, type, settings) VALUES (?, ?, ?)"), name, aType, stored) if err != nil { return 0, err } @@ -269,8 +270,8 @@ func (s *SQLStore) AddAlertReturningID(name, aType string, settings map[string]s return int(id), err } -func (s *SQLStore) GetAllAlerts() ([]models.AlertConfig, error) { - rows, err := s.db.Query("SELECT id, name, type, settings FROM alerts") +func (s *SQLStore) GetAllAlerts(ctx context.Context) ([]models.AlertConfig, error) { + rows, err := s.db.QueryContext(ctx, "SELECT id, name, type, settings FROM alerts") if err != nil { return nil, err } @@ -291,10 +292,10 @@ func (s *SQLStore) GetAllAlerts() ([]models.AlertConfig, error) { return alerts, rows.Err() } -func (s *SQLStore) GetAlert(id int) (models.AlertConfig, error) { +func (s *SQLStore) GetAlert(ctx context.Context, id int) (models.AlertConfig, error) { var a models.AlertConfig var settingsRaw string - err := s.db.QueryRow(s.q("SELECT id, name, type, settings FROM alerts WHERE id = ?"), id).Scan(&a.ID, &a.Name, &a.Type, &settingsRaw) + err := s.db.QueryRowContext(ctx, s.q("SELECT id, name, type, settings FROM alerts WHERE id = ?"), id).Scan(&a.ID, &a.Name, &a.Type, &settingsRaw) if err != nil { return a, err } @@ -305,26 +306,26 @@ func (s *SQLStore) GetAlert(id int) (models.AlertConfig, error) { return a, nil } -func (s *SQLStore) AddAlert(name, aType string, settings map[string]string) error { +func (s *SQLStore) AddAlert(ctx context.Context, name, aType string, settings map[string]string) error { stored, err := s.marshalSettings(settings) if err != nil { return err } - _, err = s.db.Exec(s.q("INSERT INTO alerts (name, type, settings) VALUES (?, ?, ?)"), name, aType, stored) + _, err = s.db.ExecContext(ctx, s.q("INSERT INTO alerts (name, type, settings) VALUES (?, ?, ?)"), name, aType, stored) return err } -func (s *SQLStore) UpdateAlert(id int, name, aType string, settings map[string]string) error { +func (s *SQLStore) UpdateAlert(ctx context.Context, id int, name, aType string, settings map[string]string) error { stored, err := s.marshalSettings(settings) if err != nil { return err } - _, err = s.db.Exec(s.q("UPDATE alerts SET name=?, type=?, settings=? WHERE id=?"), name, aType, stored, id) + _, err = s.db.ExecContext(ctx, s.q("UPDATE alerts SET name=?, type=?, settings=? WHERE id=?"), name, aType, stored, id) return err } -func (s *SQLStore) DeleteAlert(id int) error { - _, err := s.db.Exec(s.q("DELETE FROM alerts WHERE id=?"), id) +func (s *SQLStore) DeleteAlert(ctx context.Context, id int) error { + _, err := s.db.ExecContext(ctx, s.q("DELETE FROM alerts WHERE id=?"), id) if err != nil { return err } @@ -332,8 +333,8 @@ func (s *SQLStore) DeleteAlert(id int) error { return nil } -func (s *SQLStore) GetAllUsers() ([]models.User, error) { - rows, err := s.db.Query("SELECT id, username, public_key, role FROM users") +func (s *SQLStore) GetAllUsers(ctx context.Context) ([]models.User, error) { + rows, err := s.db.QueryContext(ctx, "SELECT id, username, public_key, role FROM users") if err != nil { return nil, err } @@ -349,29 +350,29 @@ func (s *SQLStore) GetAllUsers() ([]models.User, error) { return users, rows.Err() } -func (s *SQLStore) AddUser(username, publicKey, role string) error { - _, err := s.db.Exec(s.q("INSERT INTO users (username, public_key, role) VALUES (?, ?, ?)"), username, publicKey, role) +func (s *SQLStore) AddUser(ctx context.Context, username, publicKey, role string) error { + _, err := s.db.ExecContext(ctx, s.q("INSERT INTO users (username, public_key, role) VALUES (?, ?, ?)"), username, publicKey, role) return err } -func (s *SQLStore) UpdateUser(id int, username, publicKey, role string) error { - _, err := s.db.Exec(s.q("UPDATE users SET username=?, public_key=?, role=? WHERE id=?"), username, publicKey, role, id) +func (s *SQLStore) UpdateUser(ctx context.Context, id int, username, publicKey, role string) error { + _, err := s.db.ExecContext(ctx, s.q("UPDATE users SET username=?, public_key=?, role=? WHERE id=?"), username, publicKey, role, id) return err } -func (s *SQLStore) DeleteUser(id int) error { - _, err := s.db.Exec(s.q("DELETE FROM users WHERE id=?"), id) +func (s *SQLStore) DeleteUser(ctx context.Context, id int) error { + _, err := s.db.ExecContext(ctx, s.q("DELETE FROM users WHERE id=?"), id) return err } -func (s *SQLStore) SaveStateChange(siteID int, fromStatus, toStatus, errorReason string) error { - _, err := s.db.Exec(s.q("INSERT INTO state_changes (site_id, from_status, to_status, error_reason) VALUES (?, ?, ?, ?)"), +func (s *SQLStore) SaveStateChange(ctx context.Context, siteID int, fromStatus, toStatus, errorReason string) error { + _, err := s.db.ExecContext(ctx, s.q("INSERT INTO state_changes (site_id, from_status, to_status, error_reason) VALUES (?, ?, ?, ?)"), siteID, fromStatus, toStatus, errorReason) return err } -func (s *SQLStore) GetStateChanges(siteID int, limit int) ([]models.StateChange, error) { - rows, err := s.db.Query(s.q("SELECT id, site_id, from_status, to_status, error_reason, changed_at FROM state_changes WHERE site_id = ? ORDER BY changed_at DESC LIMIT ?"), siteID, limit) +func (s *SQLStore) GetStateChanges(ctx context.Context, siteID int, limit int) ([]models.StateChange, error) { + rows, err := s.db.QueryContext(ctx, s.q("SELECT id, site_id, from_status, to_status, error_reason, changed_at FROM state_changes WHERE site_id = ? ORDER BY changed_at DESC LIMIT ?"), siteID, limit) if err != nil { return nil, err } @@ -387,8 +388,8 @@ func (s *SQLStore) GetStateChanges(siteID int, limit int) ([]models.StateChange, return changes, rows.Err() } -func (s *SQLStore) GetStateChangesSince(siteID int, since time.Time) ([]models.StateChange, error) { - rows, err := s.db.Query(s.q("SELECT id, site_id, from_status, to_status, error_reason, changed_at FROM state_changes WHERE site_id = ? AND changed_at >= ? ORDER BY changed_at DESC"), siteID, since) +func (s *SQLStore) GetStateChangesSince(ctx context.Context, siteID int, since time.Time) ([]models.StateChange, error) { + rows, err := s.db.QueryContext(ctx, s.q("SELECT id, site_id, from_status, to_status, error_reason, changed_at FROM state_changes WHERE site_id = ? AND changed_at >= ? ORDER BY changed_at DESC"), siteID, since) if err != nil { return nil, err } @@ -404,59 +405,59 @@ func (s *SQLStore) GetStateChangesSince(siteID int, since time.Time) ([]models.S return changes, rows.Err() } -func (s *SQLStore) SaveCheck(siteID int, latencyNs int64, isUp bool) error { - return s.SaveCheckFromNode(siteID, "", latencyNs, isUp) +func (s *SQLStore) SaveCheck(ctx context.Context, siteID int, latencyNs int64, isUp bool) error { + return s.SaveCheckFromNode(ctx, siteID, "", latencyNs, isUp) } // SaveCheckFromNode inserts a single check row. Retention is handled out of // band by PruneCheckHistory on a timer, not per-insert, to keep the write hot // path a plain INSERT. -func (s *SQLStore) SaveCheckFromNode(siteID int, nodeID string, latencyNs int64, isUp bool) error { - _, err := s.db.Exec(s.q("INSERT INTO check_history (site_id, node_id, latency_ns, is_up) VALUES (?, ?, ?, ?)"), siteID, nodeID, latencyNs, isUp) +func (s *SQLStore) SaveCheckFromNode(ctx context.Context, siteID int, nodeID string, latencyNs int64, isUp bool) error { + _, err := s.db.ExecContext(ctx, s.q("INSERT INTO check_history (site_id, node_id, latency_ns, is_up) VALUES (?, ?, ?, ?)"), siteID, nodeID, latencyNs, isUp) return err } // PruneCheckHistory trims check_history to the newest maxCheckHistory rows per // site, across all sites, in one pass. Intended to run periodically. -func (s *SQLStore) PruneCheckHistory() error { +func (s *SQLStore) PruneCheckHistory(ctx context.Context) error { q := fmt.Sprintf(`DELETE FROM check_history WHERE id IN ( SELECT id FROM ( SELECT id, ROW_NUMBER() OVER (PARTITION BY site_id ORDER BY checked_at DESC, id DESC) AS rn FROM check_history ) ranked WHERE rn > %d )`, maxCheckHistory) - _, err := s.db.Exec(s.q(q)) + _, err := s.db.ExecContext(ctx, s.q(q)) return err } // PruneStateChanges trims state_changes to the newest maxStateChangesPerSite // rows per site. Generous so realistic SLA windows are unaffected; bounds the // otherwise unbounded growth of a flapping monitor's history. -func (s *SQLStore) PruneStateChanges() error { +func (s *SQLStore) PruneStateChanges(ctx context.Context) error { q := fmt.Sprintf(`DELETE FROM state_changes WHERE id IN ( SELECT id FROM ( SELECT id, ROW_NUMBER() OVER (PARTITION BY site_id ORDER BY changed_at DESC, id DESC) AS rn FROM state_changes ) ranked WHERE rn > %d )`, maxStateChangesPerSite) - _, err := s.db.Exec(s.q(q)) + _, err := s.db.ExecContext(ctx, s.q(q)) return err } -func (s *SQLStore) RegisterNode(node models.ProbeNode) error { - _, err := s.db.Exec(s.dialect.UpsertNodeSQL(), node.ID, node.Name, node.Region, node.Version) +func (s *SQLStore) RegisterNode(ctx context.Context, node models.ProbeNode) error { + _, err := s.db.ExecContext(ctx, s.dialect.UpsertNodeSQL(), node.ID, node.Name, node.Region, node.Version) return err } -func (s *SQLStore) GetNode(id string) (models.ProbeNode, error) { +func (s *SQLStore) GetNode(ctx context.Context, id string) (models.ProbeNode, error) { var n models.ProbeNode - err := s.db.QueryRow(s.q("SELECT id, name, region, last_seen, version FROM nodes WHERE id = ?"), id). + err := s.db.QueryRowContext(ctx, s.q("SELECT id, name, region, last_seen, version FROM nodes WHERE id = ?"), id). Scan(&n.ID, &n.Name, &n.Region, &n.LastSeen, &n.Version) return n, err } -func (s *SQLStore) GetAllNodes() ([]models.ProbeNode, error) { - rows, err := s.db.Query("SELECT id, name, region, last_seen, version FROM nodes ORDER BY region, name") +func (s *SQLStore) GetAllNodes(ctx context.Context) ([]models.ProbeNode, error) { + rows, err := s.db.QueryContext(ctx, "SELECT id, name, region, last_seen, version FROM nodes ORDER BY region, name") if err != nil { return nil, err } @@ -472,18 +473,18 @@ func (s *SQLStore) GetAllNodes() ([]models.ProbeNode, error) { return nodes, rows.Err() } -func (s *SQLStore) UpdateNodeLastSeen(id string) error { - _, err := s.db.Exec(s.q("UPDATE nodes SET last_seen = CURRENT_TIMESTAMP WHERE id = ?"), id) +func (s *SQLStore) UpdateNodeLastSeen(ctx context.Context, id string) error { + _, err := s.db.ExecContext(ctx, s.q("UPDATE nodes SET last_seen = CURRENT_TIMESTAMP WHERE id = ?"), id) return err } -func (s *SQLStore) DeleteNode(id string) error { - _, err := s.db.Exec(s.q("DELETE FROM nodes WHERE id = ?"), id) +func (s *SQLStore) DeleteNode(ctx context.Context, id string) error { + _, err := s.db.ExecContext(ctx, s.q("DELETE FROM nodes WHERE id = ?"), id) return err } -func (s *SQLStore) LoadAlertHealth() (map[int]models.AlertHealthRecord, error) { - rows, err := s.db.Query("SELECT alert_id, last_send_at, last_send_ok, last_error, send_count, fail_count FROM alert_health") +func (s *SQLStore) LoadAlertHealth(ctx context.Context) (map[int]models.AlertHealthRecord, error) { + rows, err := s.db.QueryContext(ctx, "SELECT alert_id, last_send_at, last_send_ok, last_error, send_count, fail_count FROM alert_health") if err != nil { return nil, err } @@ -503,35 +504,35 @@ func (s *SQLStore) LoadAlertHealth() (map[int]models.AlertHealthRecord, error) { return out, rows.Err() } -func (s *SQLStore) SaveAlertHealth(h models.AlertHealthRecord) error { +func (s *SQLStore) SaveAlertHealth(ctx context.Context, h models.AlertHealthRecord) error { var lastSend interface{} if !h.LastSendAt.IsZero() { lastSend = h.LastSendAt } - _, err := s.db.Exec(s.dialect.UpsertAlertHealthSQL(), + _, err := s.db.ExecContext(ctx, s.dialect.UpsertAlertHealthSQL(), h.AlertID, lastSend, h.LastSendOK, h.LastError, h.SendCount, h.FailCount) return err } // SaveLog inserts a single log row. Retention is handled by PruneLogs on a // timer, not per-insert. -func (s *SQLStore) SaveLog(message string) error { - _, err := s.db.Exec(s.q("INSERT INTO logs (message) VALUES (?)"), message) +func (s *SQLStore) SaveLog(ctx context.Context, message string) error { + _, err := s.db.ExecContext(ctx, s.q("INSERT INTO logs (message) VALUES (?)"), message) return err } // PruneLogs trims the logs table to the newest maxLogRows rows. The id DESC // tiebreak keeps ordering deterministic when rows share a created_at second. -func (s *SQLStore) PruneLogs() error { +func (s *SQLStore) PruneLogs(ctx context.Context) error { q := fmt.Sprintf(`DELETE FROM logs WHERE id NOT IN ( SELECT id FROM logs ORDER BY created_at DESC, id DESC LIMIT %d )`, maxLogRows) - _, err := s.db.Exec(s.q(q)) + _, err := s.db.ExecContext(ctx, s.q(q)) return err } -func (s *SQLStore) LoadLogs(limit int) ([]string, error) { - rows, err := s.db.Query(s.q("SELECT message FROM logs ORDER BY created_at DESC LIMIT ?"), limit) +func (s *SQLStore) LoadLogs(ctx context.Context, limit int) ([]string, error) { + rows, err := s.db.QueryContext(ctx, s.q("SELECT message FROM logs ORDER BY created_at DESC LIMIT ?"), limit) if err != nil { return nil, err } @@ -547,9 +548,9 @@ func (s *SQLStore) LoadLogs(limit int) ([]string, error) { return logs, rows.Err() } -func (s *SQLStore) LoadAllHistory(limit int) (map[int][]models.CheckRecord, error) { +func (s *SQLStore) LoadAllHistory(ctx context.Context, limit int) (map[int][]models.CheckRecord, error) { result := make(map[int][]models.CheckRecord) - rows, err := s.db.Query(s.q(` + rows, err := s.db.QueryContext(ctx, s.q(` SELECT site_id, latency_ns, is_up FROM ( SELECT site_id, latency_ns, is_up, ROW_NUMBER() OVER (PARTITION BY site_id ORDER BY checked_at DESC) AS rn @@ -587,8 +588,8 @@ func (s *SQLStore) scanMaintenanceWindow(rows *sql.Rows) (models.MaintenanceWind return mw, nil } -func (s *SQLStore) GetActiveMaintenanceWindows() ([]models.MaintenanceWindow, error) { - rows, err := s.db.Query(s.q("SELECT id, monitor_id, title, description, type, start_time, end_time, created_by, created_at FROM maintenance_windows WHERE start_time <= CURRENT_TIMESTAMP AND (end_time IS NULL OR end_time > CURRENT_TIMESTAMP) ORDER BY start_time DESC")) +func (s *SQLStore) GetActiveMaintenanceWindows(ctx context.Context) ([]models.MaintenanceWindow, error) { + rows, err := s.db.QueryContext(ctx, s.q("SELECT id, monitor_id, title, description, type, start_time, end_time, created_by, created_at FROM maintenance_windows WHERE start_time <= CURRENT_TIMESTAMP AND (end_time IS NULL OR end_time > CURRENT_TIMESTAMP) ORDER BY start_time DESC")) if err != nil { return nil, err } @@ -604,8 +605,8 @@ func (s *SQLStore) GetActiveMaintenanceWindows() ([]models.MaintenanceWindow, er return windows, rows.Err() } -func (s *SQLStore) GetAllMaintenanceWindows(limit int) ([]models.MaintenanceWindow, error) { - rows, err := s.db.Query(s.q("SELECT id, monitor_id, title, description, type, start_time, end_time, created_by, created_at FROM maintenance_windows ORDER BY created_at DESC LIMIT ?"), limit) +func (s *SQLStore) GetAllMaintenanceWindows(ctx context.Context, limit int) ([]models.MaintenanceWindow, error) { + rows, err := s.db.QueryContext(ctx, s.q("SELECT id, monitor_id, title, description, type, start_time, end_time, created_by, created_at FROM maintenance_windows ORDER BY created_at DESC LIMIT ?"), limit) if err != nil { return nil, err } @@ -621,22 +622,22 @@ func (s *SQLStore) GetAllMaintenanceWindows(limit int) ([]models.MaintenanceWind return windows, rows.Err() } -func (s *SQLStore) AddMaintenanceWindow(mw models.MaintenanceWindow) error { +func (s *SQLStore) AddMaintenanceWindow(ctx context.Context, mw models.MaintenanceWindow) error { if mw.StartTime.IsZero() { mw.StartTime = time.Now() } - _, err := s.db.Exec(s.q("INSERT INTO maintenance_windows (monitor_id, title, description, type, start_time, end_time, created_by) VALUES (?, ?, ?, ?, ?, ?, ?)"), + _, err := s.db.ExecContext(ctx, s.q("INSERT INTO maintenance_windows (monitor_id, title, description, type, start_time, end_time, created_by) VALUES (?, ?, ?, ?, ?, ?, ?)"), mw.MonitorID, mw.Title, mw.Description, mw.Type, mw.StartTime, sql.NullTime{Time: mw.EndTime, Valid: !mw.EndTime.IsZero()}, mw.CreatedBy) return err } -func (s *SQLStore) EndMaintenanceWindow(id int) error { - _, err := s.db.Exec(s.q("UPDATE maintenance_windows SET end_time = CURRENT_TIMESTAMP WHERE id = ?"), id) +func (s *SQLStore) EndMaintenanceWindow(ctx context.Context, id int) error { + _, err := s.db.ExecContext(ctx, s.q("UPDATE maintenance_windows SET end_time = CURRENT_TIMESTAMP WHERE id = ?"), id) return err } -func (s *SQLStore) DeleteMaintenanceWindow(id int) error { - _, err := s.db.Exec(s.q("DELETE FROM maintenance_windows WHERE id = ?"), id) +func (s *SQLStore) DeleteMaintenanceWindow(ctx context.Context, id int) error { + _, err := s.db.ExecContext(ctx, s.q("DELETE FROM maintenance_windows WHERE id = ?"), id) if err != nil { return err } @@ -644,9 +645,9 @@ func (s *SQLStore) DeleteMaintenanceWindow(id int) error { return nil } -func (s *SQLStore) PruneExpiredMaintenanceWindows(retention time.Duration) (int64, error) { +func (s *SQLStore) PruneExpiredMaintenanceWindows(ctx context.Context, retention time.Duration) (int64, error) { cutoff := time.Now().Add(-retention) - result, err := s.db.Exec( + result, err := s.db.ExecContext(ctx, s.q("DELETE FROM maintenance_windows WHERE end_time IS NOT NULL AND end_time < ?"), cutoff, ) @@ -656,9 +657,9 @@ func (s *SQLStore) PruneExpiredMaintenanceWindows(retention time.Duration) (int6 return result.RowsAffected() } -func (s *SQLStore) IsMonitorInMaintenance(monitorID int) (bool, error) { +func (s *SQLStore) IsMonitorInMaintenance(ctx context.Context, monitorID int) (bool, error) { var count int - err := s.db.QueryRow(s.q(`SELECT COUNT(*) FROM maintenance_windows + err := s.db.QueryRowContext(ctx, s.q(`SELECT COUNT(*) FROM maintenance_windows WHERE type = 'maintenance' AND start_time <= CURRENT_TIMESTAMP AND (end_time IS NULL OR end_time > CURRENT_TIMESTAMP) @@ -671,46 +672,46 @@ func (s *SQLStore) IsMonitorInMaintenance(monitorID int) (bool, error) { return count > 0, nil } -func (s *SQLStore) GetPreference(key string) (string, error) { +func (s *SQLStore) GetPreference(ctx context.Context, key string) (string, error) { var value string - err := s.db.QueryRow(s.q("SELECT value FROM preferences WHERE key = ?"), key).Scan(&value) + err := s.db.QueryRowContext(ctx, s.q("SELECT value FROM preferences WHERE key = ?"), key).Scan(&value) if err != nil { return "", err } return value, nil } -func (s *SQLStore) SetPreference(key, value string) error { +func (s *SQLStore) SetPreference(ctx context.Context, key, value string) error { if s.dollar { - _, err := s.db.Exec(s.q("INSERT INTO preferences (key, value) VALUES (?, ?) ON CONFLICT (key) DO UPDATE SET value = ?"), key, value, value) + _, err := s.db.ExecContext(ctx, s.q("INSERT INTO preferences (key, value) VALUES (?, ?) ON CONFLICT (key) DO UPDATE SET value = ?"), key, value, value) return err } - _, err := s.db.Exec("INSERT OR REPLACE INTO preferences (key, value) VALUES (?, ?)", key, value) + _, err := s.db.ExecContext(ctx, "INSERT OR REPLACE INTO preferences (key, value) VALUES (?, ?)", key, value) return err } -func (s *SQLStore) ExportData() (models.Backup, error) { - sites, err := s.GetSites() +func (s *SQLStore) ExportData(ctx context.Context) (models.Backup, error) { + sites, err := s.GetSites(ctx) if err != nil { return models.Backup{}, err } - alerts, err := s.GetAllAlerts() + alerts, err := s.GetAllAlerts(ctx) if err != nil { return models.Backup{}, err } - users, err := s.GetAllUsers() + users, err := s.GetAllUsers(ctx) if err != nil { return models.Backup{}, err } - windows, err := s.GetAllMaintenanceWindows(maxMaintenanceExport) + windows, err := s.GetAllMaintenanceWindows(ctx, maxMaintenanceExport) if err != nil { return models.Backup{}, err } return models.Backup{Sites: sites, Alerts: alerts, Users: users, MaintenanceWindows: windows}, nil } -func (s *SQLStore) ImportData(data models.Backup) error { - tx, err := s.db.Begin() +func (s *SQLStore) ImportData(ctx context.Context, data models.Backup) error { + tx, err := s.db.BeginTx(ctx, nil) if err != nil { return err } @@ -719,7 +720,7 @@ func (s *SQLStore) ImportData(data models.Backup) error { s.dialect.ImportWipe(tx) for _, u := range data.Users { - if _, err := tx.Exec(s.q("INSERT INTO users (username, public_key, role) VALUES (?, ?, ?)"), u.Username, u.PublicKey, u.Role); err != nil { + if _, err := tx.ExecContext(ctx, s.q("INSERT INTO users (username, public_key, role) VALUES (?, ?, ?)"), u.Username, u.PublicKey, u.Role); err != nil { return err } } @@ -730,12 +731,12 @@ func (s *SQLStore) ImportData(data models.Backup) error { if err != nil { return err } - if _, err := tx.Exec(s.q("INSERT INTO alerts (id, name, type, settings) VALUES (?, ?, ?, ?)"), a.ID, a.Name, a.Type, settingsStr); err != nil { + if _, err := tx.ExecContext(ctx, s.q("INSERT INTO alerts (id, name, type, settings) VALUES (?, ?, ?, ?)"), a.ID, a.Name, a.Type, settingsStr); err != nil { return err } } for _, st := range data.Sites { - if _, err := tx.Exec(s.q("INSERT INTO sites (id, name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused, regions) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"), + if _, err := tx.ExecContext(ctx, s.q("INSERT INTO sites (id, name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused, regions) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"), st.ID, st.Name, st.URL, st.Type, st.Token, st.Interval, st.AlertID, st.CheckSSL, st.ExpiryThreshold, st.MaxRetries, st.Hostname, st.Port, st.Timeout, st.Method, st.Description, st.ParentID, st.AcceptedCodes, st.DNSResolveType, st.DNSServer, st.IgnoreTLS, st.Paused, st.Regions); err != nil { return err @@ -743,7 +744,7 @@ func (s *SQLStore) ImportData(data models.Backup) error { } for _, mw := range data.MaintenanceWindows { - if _, err := tx.Exec(s.q("INSERT INTO maintenance_windows (id, monitor_id, title, description, type, start_time, end_time, created_by) VALUES (?, ?, ?, ?, ?, ?, ?, ?)"), + if _, err := tx.ExecContext(ctx, s.q("INSERT INTO maintenance_windows (id, monitor_id, title, description, type, start_time, end_time, created_by) VALUES (?, ?, ?, ?, ?, ?, ?, ?)"), mw.ID, mw.MonitorID, mw.Title, mw.Description, mw.Type, mw.StartTime, sql.NullTime{Time: mw.EndTime, Valid: !mw.EndTime.IsZero()}, mw.CreatedBy); err != nil { return err } diff --git a/internal/store/sqlstore_test.go b/internal/store/sqlstore_test.go index 5876de3..39fc08c 100644 --- a/internal/store/sqlstore_test.go +++ b/internal/store/sqlstore_test.go @@ -1,6 +1,7 @@ package store import ( + "context" "fmt" "strings" "testing" @@ -15,7 +16,7 @@ func newTestStore(t *testing.T) *SQLStore { if err != nil { t.Fatalf("NewSQLiteStore: %v", err) } - if err := s.Init(); err != nil { + if err := s.Init(context.Background()); err != nil { t.Fatalf("Init: %v", err) } return s @@ -24,7 +25,7 @@ func newTestStore(t *testing.T) *SQLStore { func TestSiteCRUD(t *testing.T) { s := newTestStore(t) - sites, err := s.GetSites() + sites, err := s.GetSites(context.Background()) if err != nil { t.Fatalf("GetSites: %v", err) } @@ -32,11 +33,11 @@ func TestSiteCRUD(t *testing.T) { t.Fatalf("expected 0 sites, got %d", len(sites)) } - if err := s.AddSite(models.Site{Name: "Test", URL: "https://example.com", Type: "http", Interval: 30}); err != nil { + if err := s.AddSite(context.Background(), models.Site{Name: "Test", URL: "https://example.com", Type: "http", Interval: 30}); err != nil { t.Fatalf("AddSite: %v", err) } - sites, err = s.GetSites() + sites, err = s.GetSites(context.Background()) if err != nil { t.Fatalf("GetSites: %v", err) } @@ -48,11 +49,11 @@ func TestSiteCRUD(t *testing.T) { } sites[0].Name = "Updated" - if err := s.UpdateSite(sites[0]); err != nil { + if err := s.UpdateSite(context.Background(), sites[0]); err != nil { t.Fatalf("UpdateSite: %v", err) } - sites, err = s.GetSites() + sites, err = s.GetSites(context.Background()) if err != nil { t.Fatalf("GetSites: %v", err) } @@ -60,11 +61,11 @@ func TestSiteCRUD(t *testing.T) { t.Errorf("expected name 'Updated', got '%s'", sites[0].Name) } - if err := s.DeleteSite(sites[0].ID); err != nil { + if err := s.DeleteSite(context.Background(), sites[0].ID); err != nil { t.Fatalf("DeleteSite: %v", err) } - sites, err = s.GetSites() + sites, err = s.GetSites(context.Background()) if err != nil { t.Fatalf("GetSites: %v", err) } @@ -76,11 +77,11 @@ func TestSiteCRUD(t *testing.T) { func TestAlertCRUD(t *testing.T) { s := newTestStore(t) - if err := s.AddAlert("Discord", "discord", map[string]string{"url": "https://example.com/hook"}); err != nil { + if err := s.AddAlert(context.Background(), "Discord", "discord", map[string]string{"url": "https://example.com/hook"}); err != nil { t.Fatalf("AddAlert: %v", err) } - alerts, err := s.GetAllAlerts() + alerts, err := s.GetAllAlerts(context.Background()) if err != nil { t.Fatalf("GetAllAlerts: %v", err) } @@ -94,7 +95,7 @@ func TestAlertCRUD(t *testing.T) { t.Errorf("settings url mismatch") } - a, err := s.GetAlert(alerts[0].ID) + a, err := s.GetAlert(context.Background(), alerts[0].ID) if err != nil { t.Fatalf("GetAlert: %v", err) } @@ -102,11 +103,11 @@ func TestAlertCRUD(t *testing.T) { t.Errorf("expected name 'Discord', got '%s'", a.Name) } - if err := s.UpdateAlert(a.ID, "Slack", "slack", map[string]string{"url": "https://slack.com/hook"}); err != nil { + if err := s.UpdateAlert(context.Background(), a.ID, "Slack", "slack", map[string]string{"url": "https://slack.com/hook"}); err != nil { t.Fatalf("UpdateAlert: %v", err) } - a, err = s.GetAlert(a.ID) + a, err = s.GetAlert(context.Background(), a.ID) if err != nil { t.Fatalf("GetAlert: %v", err) } @@ -114,11 +115,11 @@ func TestAlertCRUD(t *testing.T) { t.Errorf("expected type 'slack', got '%s'", a.Type) } - if err := s.DeleteAlert(a.ID); err != nil { + if err := s.DeleteAlert(context.Background(), a.ID); err != nil { t.Fatalf("DeleteAlert: %v", err) } - alerts, err = s.GetAllAlerts() + alerts, err = s.GetAllAlerts(context.Background()) if err != nil { t.Fatalf("GetAllAlerts: %v", err) } @@ -130,11 +131,11 @@ func TestAlertCRUD(t *testing.T) { func TestUserCRUD(t *testing.T) { s := newTestStore(t) - if err := s.AddUser("admin", "ssh-ed25519 AAAA...", "admin"); err != nil { + if err := s.AddUser(context.Background(), "admin", "ssh-ed25519 AAAA...", "admin"); err != nil { t.Fatalf("AddUser: %v", err) } - users, err := s.GetAllUsers() + users, err := s.GetAllUsers(context.Background()) if err != nil { t.Fatalf("GetAllUsers: %v", err) } @@ -145,11 +146,11 @@ func TestUserCRUD(t *testing.T) { t.Errorf("expected username 'admin', got '%s'", users[0].Username) } - if err := s.UpdateUser(users[0].ID, "root", "ssh-ed25519 BBBB...", "admin"); err != nil { + if err := s.UpdateUser(context.Background(), users[0].ID, "root", "ssh-ed25519 BBBB...", "admin"); err != nil { t.Fatalf("UpdateUser: %v", err) } - users, err = s.GetAllUsers() + users, err = s.GetAllUsers(context.Background()) if err != nil { t.Fatalf("GetAllUsers: %v", err) } @@ -157,11 +158,11 @@ func TestUserCRUD(t *testing.T) { t.Errorf("expected username 'root', got '%s'", users[0].Username) } - if err := s.DeleteUser(users[0].ID); err != nil { + if err := s.DeleteUser(context.Background(), users[0].ID); err != nil { t.Fatalf("DeleteUser: %v", err) } - users, err = s.GetAllUsers() + users, err = s.GetAllUsers(context.Background()) if err != nil { t.Fatalf("GetAllUsers: %v", err) } @@ -173,11 +174,11 @@ func TestUserCRUD(t *testing.T) { func TestPushTokenGeneration(t *testing.T) { s := newTestStore(t) - if err := s.AddSite(models.Site{Name: "Push Monitor", Type: "push", Interval: 60}); err != nil { + if err := s.AddSite(context.Background(), models.Site{Name: "Push Monitor", Type: "push", Interval: 60}); err != nil { t.Fatalf("AddSite: %v", err) } - sites, err := s.GetSites() + sites, err := s.GetSites(context.Background()) if err != nil { t.Fatalf("GetSites: %v", err) } @@ -195,17 +196,17 @@ func TestPushTokenGeneration(t *testing.T) { func TestImportExport(t *testing.T) { s := newTestStore(t) - if err := s.AddAlert("Test Alert", "webhook", map[string]string{"url": "https://example.com"}); err != nil { + if err := s.AddAlert(context.Background(), "Test Alert", "webhook", map[string]string{"url": "https://example.com"}); err != nil { t.Fatalf("AddAlert: %v", err) } - if err := s.AddSite(models.Site{Name: "Site1", URL: "https://example.com", Type: "http", Interval: 30}); err != nil { + if err := s.AddSite(context.Background(), models.Site{Name: "Site1", URL: "https://example.com", Type: "http", Interval: 30}); err != nil { t.Fatalf("AddSite: %v", err) } - if err := s.AddUser("user1", "ssh-ed25519 KEY", "user"); err != nil { + if err := s.AddUser(context.Background(), "user1", "ssh-ed25519 KEY", "user"); err != nil { t.Fatalf("AddUser: %v", err) } - backup, err := s.ExportData() + backup, err := s.ExportData(context.Background()) if err != nil { t.Fatalf("ExportData: %v", err) } @@ -214,19 +215,19 @@ func TestImportExport(t *testing.T) { } s2 := newTestStore(t) - if err := s2.ImportData(backup); err != nil { + if err := s2.ImportData(context.Background(), backup); err != nil { t.Fatalf("ImportData: %v", err) } - sites, err := s2.GetSites() + sites, err := s2.GetSites(context.Background()) if err != nil { t.Fatalf("GetSites: %v", err) } - alerts, err := s2.GetAllAlerts() + alerts, err := s2.GetAllAlerts(context.Background()) if err != nil { t.Fatalf("GetAllAlerts: %v", err) } - users, err := s2.GetAllUsers() + users, err := s2.GetAllUsers(context.Background()) if err != nil { t.Fatalf("GetAllUsers: %v", err) } @@ -238,27 +239,27 @@ func TestImportExport(t *testing.T) { func TestImportData_WipesHistory(t *testing.T) { s := newTestStore(t) - if err := s.AddSite(models.Site{Name: "OldSite", URL: "https://old.com", Type: "http", Interval: 30}); err != nil { + if err := s.AddSite(context.Background(), models.Site{Name: "OldSite", URL: "https://old.com", Type: "http", Interval: 30}); err != nil { t.Fatalf("AddSite: %v", err) } - if err := s.SaveCheck(1, 5000, true); err != nil { + if err := s.SaveCheck(context.Background(), 1, 5000, true); err != nil { t.Fatalf("SaveCheck: %v", err) } - if err := s.SaveStateChange(1, "UP", "DOWN", "timeout"); err != nil { + if err := s.SaveStateChange(context.Background(), 1, "UP", "DOWN", "timeout"); err != nil { t.Fatalf("SaveStateChange: %v", err) } - if err := s.SaveAlertHealth(models.AlertHealthRecord{AlertID: 1, LastSendOK: true, SendCount: 1}); err != nil { + if err := s.SaveAlertHealth(context.Background(), models.AlertHealthRecord{AlertID: 1, LastSendOK: true, SendCount: 1}); err != nil { t.Fatalf("SaveAlertHealth: %v", err) } backup := models.Backup{ Sites: []models.Site{{ID: 1, Name: "NewSite", URL: "https://new.com", Type: "http", Interval: 60}}, } - if err := s.ImportData(backup); err != nil { + if err := s.ImportData(context.Background(), backup); err != nil { t.Fatalf("ImportData: %v", err) } - history, err := s.LoadAllHistory(100) + history, err := s.LoadAllHistory(context.Background(), 100) if err != nil { t.Fatalf("LoadAllHistory: %v", err) } @@ -266,7 +267,7 @@ func TestImportData_WipesHistory(t *testing.T) { t.Errorf("expected empty check_history after import, got %d sites with history", len(history)) } - changes, err := s.GetStateChanges(1, 100) + changes, err := s.GetStateChanges(context.Background(), 1, 100) if err != nil { t.Fatalf("GetStateChanges: %v", err) } @@ -278,17 +279,17 @@ func TestImportData_WipesHistory(t *testing.T) { func TestCheckHistory(t *testing.T) { s := newTestStore(t) - if err := s.SaveCheck(1, 5000000, true); err != nil { + if err := s.SaveCheck(context.Background(), 1, 5000000, true); err != nil { t.Fatalf("SaveCheck: %v", err) } - if err := s.SaveCheck(1, 10000000, false); err != nil { + if err := s.SaveCheck(context.Background(), 1, 10000000, false); err != nil { t.Fatalf("SaveCheck: %v", err) } - if err := s.SaveCheck(2, 3000000, true); err != nil { + if err := s.SaveCheck(context.Background(), 2, 3000000, true); err != nil { t.Fatalf("SaveCheck site 2: %v", err) } - history, err := s.LoadAllHistory(10) + history, err := s.LoadAllHistory(context.Background(), 10) if err != nil { t.Fatalf("LoadAllHistory: %v", err) } @@ -314,16 +315,16 @@ func TestDeleteSiteCascade(t *testing.T) { s := newTestStore(t) site := models.Site{Name: "Cascade Test", URL: "https://example.com", Interval: 30} - if err := s.AddSite(site); err != nil { + if err := s.AddSite(context.Background(), site); err != nil { t.Fatalf("AddSite: %v", err) } - sites, _ := s.GetSites() + sites, _ := s.GetSites(context.Background()) siteID := sites[0].ID - if err := s.SaveCheck(siteID, 1000, true); err != nil { + if err := s.SaveCheck(context.Background(), siteID, 1000, true); err != nil { t.Fatalf("SaveCheck: %v", err) } - if err := s.SaveStateChange(siteID, "UP", "DOWN", "timeout"); err != nil { + if err := s.SaveStateChange(context.Background(), siteID, "UP", "DOWN", "timeout"); err != nil { t.Fatalf("SaveStateChange: %v", err) } mw := models.MaintenanceWindow{ @@ -332,25 +333,25 @@ func TestDeleteSiteCascade(t *testing.T) { Type: "maintenance", StartTime: time.Now(), } - if err := s.AddMaintenanceWindow(mw); err != nil { + if err := s.AddMaintenanceWindow(context.Background(), mw); err != nil { t.Fatalf("AddMaintenanceWindow: %v", err) } - if err := s.DeleteSite(siteID); err != nil { + if err := s.DeleteSite(context.Background(), siteID); err != nil { t.Fatalf("DeleteSite: %v", err) } - history, _ := s.LoadAllHistory(100) + history, _ := s.LoadAllHistory(context.Background(), 100) if len(history[siteID]) != 0 { t.Errorf("expected 0 check_history rows, got %d", len(history[siteID])) } - changes, _ := s.GetStateChanges(siteID, 100) + changes, _ := s.GetStateChanges(context.Background(), siteID, 100) if len(changes) != 0 { t.Errorf("expected 0 state_changes rows, got %d", len(changes)) } - windows, _ := s.GetActiveMaintenanceWindows() + windows, _ := s.GetActiveMaintenanceWindows(context.Background()) for _, w := range windows { if w.MonitorID == siteID { t.Errorf("orphaned maintenance window found: id=%d", w.ID) @@ -362,15 +363,15 @@ func TestPruneLogs(t *testing.T) { s := newTestStore(t) for i := 0; i < maxLogRows+50; i++ { - if err := s.SaveLog(fmt.Sprintf("log %d", i)); err != nil { + if err := s.SaveLog(context.Background(), fmt.Sprintf("log %d", i)); err != nil { t.Fatalf("SaveLog: %v", err) } } - if err := s.PruneLogs(); err != nil { + if err := s.PruneLogs(context.Background()); err != nil { t.Fatalf("PruneLogs: %v", err) } - logs, err := s.LoadLogs(maxLogRows * 2) + logs, err := s.LoadLogs(context.Background(), maxLogRows*2) if err != nil { t.Fatalf("LoadLogs: %v", err) } @@ -395,21 +396,21 @@ func TestPruneCheckHistory(t *testing.T) { s := newTestStore(t) for i := 0; i < maxCheckHistory+5; i++ { - if err := s.SaveCheck(1, int64(i), true); err != nil { + if err := s.SaveCheck(context.Background(), 1, int64(i), true); err != nil { t.Fatalf("SaveCheck site 1: %v", err) } } for i := 0; i < 3; i++ { - if err := s.SaveCheck(2, int64(i), true); err != nil { + if err := s.SaveCheck(context.Background(), 2, int64(i), true); err != nil { t.Fatalf("SaveCheck site 2: %v", err) } } - if err := s.PruneCheckHistory(); err != nil { + if err := s.PruneCheckHistory(context.Background()); err != nil { t.Fatalf("PruneCheckHistory: %v", err) } - history, err := s.LoadAllHistory(maxCheckHistory * 2) + history, err := s.LoadAllHistory(context.Background(), maxCheckHistory*2) if err != nil { t.Fatalf("LoadAllHistory: %v", err) } @@ -434,7 +435,7 @@ func TestPruneExpiredMaintenanceWindows(t *testing.T) { StartTime: now.Add(-11 * 24 * time.Hour), EndTime: now.Add(-10 * 24 * time.Hour), } - if err := s.AddMaintenanceWindow(old); err != nil { + if err := s.AddMaintenanceWindow(context.Background(), old); err != nil { t.Fatalf("AddMaintenanceWindow (old): %v", err) } @@ -446,7 +447,7 @@ func TestPruneExpiredMaintenanceWindows(t *testing.T) { StartTime: now.Add(-2 * 24 * time.Hour), EndTime: now.Add(-1 * 24 * time.Hour), } - if err := s.AddMaintenanceWindow(recent); err != nil { + if err := s.AddMaintenanceWindow(context.Background(), recent); err != nil { t.Fatalf("AddMaintenanceWindow (recent): %v", err) } @@ -457,11 +458,11 @@ func TestPruneExpiredMaintenanceWindows(t *testing.T) { Type: "maintenance", StartTime: now.Add(-1 * time.Hour), } - if err := s.AddMaintenanceWindow(ongoing); err != nil { + if err := s.AddMaintenanceWindow(context.Background(), ongoing); err != nil { t.Fatalf("AddMaintenanceWindow (ongoing): %v", err) } - pruned, err := s.PruneExpiredMaintenanceWindows(7 * 24 * time.Hour) + pruned, err := s.PruneExpiredMaintenanceWindows(context.Background(), 7*24*time.Hour) if err != nil { t.Fatalf("PruneExpiredMaintenanceWindows: %v", err) } @@ -469,7 +470,7 @@ func TestPruneExpiredMaintenanceWindows(t *testing.T) { t.Errorf("expected 1 pruned, got %d", pruned) } - all, err := s.GetAllMaintenanceWindows(100) + all, err := s.GetAllMaintenanceWindows(context.Background(), 100) if err != nil { t.Fatalf("GetAllMaintenanceWindows: %v", err) } @@ -498,7 +499,7 @@ func TestImportData_EncryptsAlertSettings(t *testing.T) { {ID: 1, Name: "tg", Type: "telegram", Settings: map[string]string{"token": "123:SECRET", "chat_id": "42"}}, }, } - if err := s.ImportData(backup); err != nil { + if err := s.ImportData(context.Background(), backup); err != nil { t.Fatalf("ImportData: %v", err) } @@ -513,7 +514,7 @@ func TestImportData_EncryptsAlertSettings(t *testing.T) { t.Errorf("plaintext secret found in stored column: %q", raw) } - alerts, err := s.GetAllAlerts() + alerts, err := s.GetAllAlerts(context.Background()) if err != nil { t.Fatalf("GetAllAlerts: %v", err) } diff --git a/internal/store/store.go b/internal/store/store.go index 6d1a29d..ee70a6b 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -1,84 +1,85 @@ package store import ( + "context" "time" "gitea.lerkolabs.com/lerkolabs/uptop/internal/models" ) type Store interface { - Init() error + Init(ctx context.Context) error // Sites - GetSites() ([]models.Site, error) - AddSite(site models.Site) error - UpdateSite(site models.Site) error - UpdateSitePaused(id int, paused bool) error - DeleteSite(id int) error + GetSites(ctx context.Context) ([]models.Site, error) + AddSite(ctx context.Context, site models.Site) error + UpdateSite(ctx context.Context, site models.Site) error + UpdateSitePaused(ctx context.Context, id int, paused bool) error + DeleteSite(ctx context.Context, id int) error // Alerts - GetAllAlerts() ([]models.AlertConfig, error) - GetAlert(id int) (models.AlertConfig, error) - AddAlert(name, aType string, settings map[string]string) error - UpdateAlert(id int, name, aType string, settings map[string]string) error - DeleteAlert(id int) error + GetAllAlerts(ctx context.Context) ([]models.AlertConfig, error) + GetAlert(ctx context.Context, id int) (models.AlertConfig, error) + AddAlert(ctx context.Context, name, aType string, settings map[string]string) error + UpdateAlert(ctx context.Context, id int, name, aType string, settings map[string]string) error + DeleteAlert(ctx context.Context, id int) error // Declarative config support - GetSiteByName(name string) (models.Site, error) - GetAlertByName(name string) (models.AlertConfig, error) - AddSiteReturningID(site models.Site) (int, error) - AddAlertReturningID(name, aType string, settings map[string]string) (int, error) + GetSiteByName(ctx context.Context, name string) (models.Site, error) + GetAlertByName(ctx context.Context, name string) (models.AlertConfig, error) + AddSiteReturningID(ctx context.Context, site models.Site) (int, error) + AddAlertReturningID(ctx context.Context, name, aType string, settings map[string]string) (int, error) // Users - GetAllUsers() ([]models.User, error) - AddUser(username, publicKey, role string) error - UpdateUser(id int, username, publicKey, role string) error - DeleteUser(id int) error + GetAllUsers(ctx context.Context) ([]models.User, error) + AddUser(ctx context.Context, username, publicKey, role string) error + UpdateUser(ctx context.Context, id int, username, publicKey, role string) error + DeleteUser(ctx context.Context, id int) error // History - SaveCheck(siteID int, latencyNs int64, isUp bool) error - SaveCheckFromNode(siteID int, nodeID string, latencyNs int64, isUp bool) error - LoadAllHistory(limit int) (map[int][]models.CheckRecord, error) - PruneCheckHistory() error + SaveCheck(ctx context.Context, siteID int, latencyNs int64, isUp bool) error + SaveCheckFromNode(ctx context.Context, siteID int, nodeID string, latencyNs int64, isUp bool) error + LoadAllHistory(ctx context.Context, limit int) (map[int][]models.CheckRecord, error) + PruneCheckHistory(ctx context.Context) error // State Changes - SaveStateChange(siteID int, fromStatus, toStatus, errorReason string) error - GetStateChanges(siteID int, limit int) ([]models.StateChange, error) - GetStateChangesSince(siteID int, since time.Time) ([]models.StateChange, error) - PruneStateChanges() error + SaveStateChange(ctx context.Context, siteID int, fromStatus, toStatus, errorReason string) error + GetStateChanges(ctx context.Context, siteID int, limit int) ([]models.StateChange, error) + GetStateChangesSince(ctx context.Context, siteID int, since time.Time) ([]models.StateChange, error) + PruneStateChanges(ctx context.Context) error // Nodes - RegisterNode(node models.ProbeNode) error - GetNode(id string) (models.ProbeNode, error) - GetAllNodes() ([]models.ProbeNode, error) - UpdateNodeLastSeen(id string) error - DeleteNode(id string) error + RegisterNode(ctx context.Context, node models.ProbeNode) error + GetNode(ctx context.Context, id string) (models.ProbeNode, error) + GetAllNodes(ctx context.Context) ([]models.ProbeNode, error) + UpdateNodeLastSeen(ctx context.Context, id string) error + DeleteNode(ctx context.Context, id string) error // Alert Health - LoadAlertHealth() (map[int]models.AlertHealthRecord, error) - SaveAlertHealth(h models.AlertHealthRecord) error + LoadAlertHealth(ctx context.Context) (map[int]models.AlertHealthRecord, error) + SaveAlertHealth(ctx context.Context, h models.AlertHealthRecord) error // Logs - SaveLog(message string) error - LoadLogs(limit int) ([]string, error) - PruneLogs() error + SaveLog(ctx context.Context, message string) error + LoadLogs(ctx context.Context, limit int) ([]string, error) + PruneLogs(ctx context.Context) error // Maintenance Windows - GetActiveMaintenanceWindows() ([]models.MaintenanceWindow, error) - GetAllMaintenanceWindows(limit int) ([]models.MaintenanceWindow, error) - AddMaintenanceWindow(mw models.MaintenanceWindow) error - EndMaintenanceWindow(id int) error - DeleteMaintenanceWindow(id int) error - PruneExpiredMaintenanceWindows(retention time.Duration) (int64, error) - IsMonitorInMaintenance(monitorID int) (bool, error) + GetActiveMaintenanceWindows(ctx context.Context) ([]models.MaintenanceWindow, error) + GetAllMaintenanceWindows(ctx context.Context, limit int) ([]models.MaintenanceWindow, error) + AddMaintenanceWindow(ctx context.Context, mw models.MaintenanceWindow) error + EndMaintenanceWindow(ctx context.Context, id int) error + DeleteMaintenanceWindow(ctx context.Context, id int) error + PruneExpiredMaintenanceWindows(ctx context.Context, retention time.Duration) (int64, error) + IsMonitorInMaintenance(ctx context.Context, monitorID int) (bool, error) // Preferences - GetPreference(key string) (string, error) - SetPreference(key, value string) error + GetPreference(ctx context.Context, key string) (string, error) + SetPreference(ctx context.Context, key, value string) error // Backup & Restore - ExportData() (models.Backup, error) - ImportData(data models.Backup) error + ExportData(ctx context.Context) (models.Backup, error) + ImportData(ctx context.Context, data models.Backup) error // Lifecycle Close() error diff --git a/internal/tui/data.go b/internal/tui/data.go index 46ba093..857fb07 100644 --- a/internal/tui/data.go +++ b/internal/tui/data.go @@ -1,6 +1,7 @@ package tui import ( + "context" "encoding/json" "sort" "strings" @@ -13,7 +14,7 @@ import ( func loadCollapsed(s store.Store) map[int]bool { m := make(map[int]bool) - raw, err := s.GetPreference("collapsed_groups") + raw, err := s.GetPreference(context.Background(), "collapsed_groups") if err != nil || raw == "" { return m } @@ -130,21 +131,22 @@ func (m *Model) loadTabDataCmd() tea.Cmd { st := m.store isAdmin := m.isAdmin return func() tea.Msg { - alerts, err := st.GetAllAlerts() + ctx := context.Background() + alerts, err := st.GetAllAlerts(ctx) if err != nil { return tabDataMsg{seq: seq, err: err} } var users []models.User if isAdmin { - if users, err = st.GetAllUsers(); err != nil { + if users, err = st.GetAllUsers(ctx); err != nil { return tabDataMsg{seq: seq, err: err} } } - nodes, err := st.GetAllNodes() + nodes, err := st.GetAllNodes(ctx) if err != nil { return tabDataMsg{seq: seq, err: err} } - maint, err := st.GetAllMaintenanceWindows(100) + maint, err := st.GetAllMaintenanceWindows(ctx, 100) if err != nil { return tabDataMsg{seq: seq, err: err} } diff --git a/internal/tui/tab_alerts.go b/internal/tui/tab_alerts.go index f7c7e1a..3d96c63 100644 --- a/internal/tui/tab_alerts.go +++ b/internal/tui/tab_alerts.go @@ -1,6 +1,7 @@ package tui import ( + "context" "fmt" neturl "net/url" "sort" @@ -528,10 +529,10 @@ func (m *Model) submitAlertForm() tea.Cmd { m.state = stateDashboard if id > 0 { return writeCmd("Update alert", func() error { - return st.UpdateAlert(id, name, aType, settings) + return st.UpdateAlert(context.Background(), id, name, aType, settings) }) } return writeCmd("Add alert", func() error { - return st.AddAlert(name, aType, settings) + return st.AddAlert(context.Background(), name, aType, settings) }) } diff --git a/internal/tui/tab_maint.go b/internal/tui/tab_maint.go index 132cad1..43dd67c 100644 --- a/internal/tui/tab_maint.go +++ b/internal/tui/tab_maint.go @@ -1,6 +1,7 @@ package tui import ( + "context" "fmt" "strconv" "time" @@ -240,6 +241,6 @@ func (m *Model) submitMaintForm() tea.Cmd { st := m.store m.state = stateDashboard return writeCmd("Add maintenance window", func() error { - return st.AddMaintenanceWindow(mw) + return st.AddMaintenanceWindow(context.Background(), mw) }) } diff --git a/internal/tui/tab_sites.go b/internal/tui/tab_sites.go index 4086512..6c970ee 100644 --- a/internal/tui/tab_sites.go +++ b/internal/tui/tab_sites.go @@ -1,6 +1,7 @@ package tui import ( + "context" "fmt" "net/url" "strconv" @@ -562,7 +563,7 @@ func (m *Model) submitSiteForm() tea.Cmd { // follows in the Cmd. New sites enter the engine via its poll loop // once the insert lands. m.engine.UpdateSiteConfig(site) - return writeCmd("Update site", func() error { return st.UpdateSite(site) }) + return writeCmd("Update site", func() error { return st.UpdateSite(context.Background(), site) }) } - return writeCmd("Add site", func() error { return st.AddSite(site) }) + return writeCmd("Add site", func() error { return st.AddSite(context.Background(), site) }) } diff --git a/internal/tui/tab_users.go b/internal/tui/tab_users.go index 16b45e3..3666357 100644 --- a/internal/tui/tab_users.go +++ b/internal/tui/tab_users.go @@ -1,6 +1,7 @@ package tui import ( + "context" "fmt" tea "github.com/charmbracelet/bubbletea" @@ -118,10 +119,10 @@ func (m *Model) submitUserForm() tea.Cmd { m.state = stateUsers if id > 0 { return writeCmd("Update user", func() error { - return st.UpdateUser(id, username, key, role) + return st.UpdateUser(context.Background(), id, username, key, role) }) } return writeCmd("Add user", func() error { - return st.AddUser(username, key, role) + return st.AddUser(context.Background(), username, key, role) }) } diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 4644bc0..7fab229 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -1,6 +1,7 @@ package tui import ( + "context" "os" "time" @@ -180,7 +181,7 @@ func InitialModel(isAdmin bool, s store.Store, eng *monitor.Engine, version stri spring := harmonica.NewSpring(harmonica.FPS(10), 6.0, 0.4) collapsed := loadCollapsed(s) - themeName, _ := s.GetPreference("theme") + themeName, _ := s.GetPreference(context.Background(), "theme") theme := themeByName(themeName) themeIdx := 0 for i, t := range themes { diff --git a/internal/tui/update.go b/internal/tui/update.go index 605658c..b855a9f 100644 --- a/internal/tui/update.go +++ b/internal/tui/update.go @@ -1,6 +1,7 @@ package tui import ( + "context" "fmt" "time" @@ -78,17 +79,17 @@ func (m *Model) handleConfirmDelete(msg tea.Msg) (tea.Model, tea.Cmd) { var cmd tea.Cmd switch m.deleteTab { case 0: - cmd = writeCmd("Delete site", func() error { return st.DeleteSite(id) }) + cmd = writeCmd("Delete site", func() error { return st.DeleteSite(context.Background(), id) }) m.engine.RemoveSite(id) m.adjustCursor(len(m.sites) - 1) case 1: - cmd = writeCmd("Delete alert", func() error { return st.DeleteAlert(id) }) + cmd = writeCmd("Delete alert", func() error { return st.DeleteAlert(context.Background(), id) }) m.adjustCursor(len(m.alerts) - 1) case 4: - cmd = writeCmd("Delete maintenance window", func() error { return st.DeleteMaintenanceWindow(id) }) + cmd = writeCmd("Delete maintenance window", func() error { return st.DeleteMaintenanceWindow(context.Background(), id) }) m.adjustCursor(len(m.maintenanceWindows) - 1) case 5: - cmd = writeCmd("Delete user", func() error { return st.DeleteUser(id) }) + cmd = writeCmd("Delete user", func() error { return st.DeleteUser(context.Background(), id) }) m.adjustCursor(len(m.users) - 1) } m.refreshLive() @@ -566,7 +567,7 @@ func (m *Model) handleDashboardKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { st := m.store m.refreshLive() return m, writeCmd("Save collapsed groups", func() error { - return st.SetPreference("collapsed_groups", payload) + return st.SetPreference(context.Background(), "collapsed_groups", payload) }) } case "p": @@ -576,7 +577,7 @@ func (m *Model) handleDashboardKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { st := m.store m.refreshLive() return m, writeCmd("Update pause state", func() error { - return st.UpdateSitePaused(id, paused) + return st.UpdateSitePaused(context.Background(), id, paused) }) } case "i": @@ -596,7 +597,7 @@ func (m *Model) handleDashboardKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { id := mw.ID m.refreshLive() return m, writeCmd("End maintenance", func() error { - return st.EndMaintenanceWindow(id) + return st.EndMaintenanceWindow(context.Background(), id) }) } } @@ -607,7 +608,7 @@ func (m *Model) handleDashboardKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { st := m.store name := m.theme.Name return m, writeCmd("Save theme", func() error { - return st.SetPreference("theme", name) + return st.SetPreference(context.Background(), "theme", name) }) case "d", "backspace": return m.handleDeleteItem() diff --git a/internal/tui/update_test.go b/internal/tui/update_test.go index ea15ba4..cc82a49 100644 --- a/internal/tui/update_test.go +++ b/internal/tui/update_test.go @@ -1,6 +1,7 @@ package tui import ( + "context" "strings" "testing" "time" @@ -23,80 +24,108 @@ type tuiMockStore struct { deleteSiteCalls int // counts DeleteSite hits (to prove writes run in Cmds) } -func (m *tuiMockStore) GetAllAlerts() ([]models.AlertConfig, error) { return m.alerts, nil } -func (m *tuiMockStore) GetAllUsers() ([]models.User, error) { return m.users, nil } -func (m *tuiMockStore) GetAllNodes() ([]models.ProbeNode, error) { return m.nodes, nil } -func (m *tuiMockStore) GetStateChanges(int, int) ([]models.StateChange, error) { +func (m *tuiMockStore) GetAllAlerts(_ context.Context) ([]models.AlertConfig, error) { + return m.alerts, nil +} +func (m *tuiMockStore) GetAllUsers(_ context.Context) ([]models.User, error) { return m.users, nil } +func (m *tuiMockStore) GetAllNodes(_ context.Context) ([]models.ProbeNode, error) { + return m.nodes, nil +} +func (m *tuiMockStore) GetStateChanges(_ context.Context, _ int, _ int) ([]models.StateChange, error) { m.stateChangeCalls++ return m.stateChanges, nil } -func (m *tuiMockStore) GetAllMaintenanceWindows(int) ([]models.MaintenanceWindow, error) { +func (m *tuiMockStore) GetAllMaintenanceWindows(_ context.Context, _ int) ([]models.MaintenanceWindow, error) { return m.maint, nil } -func (m *tuiMockStore) Init() error { return nil } -func (m *tuiMockStore) GetSites() ([]models.Site, error) { return nil, nil } -func (m *tuiMockStore) AddSite(models.Site) error { return nil } -func (m *tuiMockStore) UpdateSite(models.Site) error { return nil } -func (m *tuiMockStore) UpdateSitePaused(int, bool) error { return nil } -func (m *tuiMockStore) DeleteSite(int) error { +func (m *tuiMockStore) Init(_ context.Context) error { return nil } +func (m *tuiMockStore) GetSites(_ context.Context) ([]models.Site, error) { return nil, nil } +func (m *tuiMockStore) AddSite(_ context.Context, _ models.Site) error { return nil } +func (m *tuiMockStore) UpdateSite(_ context.Context, _ models.Site) error { return nil } +func (m *tuiMockStore) UpdateSitePaused(_ context.Context, _ int, _ bool) error { return nil } +func (m *tuiMockStore) DeleteSite(_ context.Context, _ int) error { m.deleteSiteCalls++ return nil } -func (m *tuiMockStore) GetAlert(int) (models.AlertConfig, error) { return models.AlertConfig{}, nil } -func (m *tuiMockStore) AddAlert(string, string, map[string]string) error { return nil } -func (m *tuiMockStore) UpdateAlert(int, string, string, map[string]string) error { return nil } -func (m *tuiMockStore) DeleteAlert(int) error { return nil } -func (m *tuiMockStore) GetSiteByName(string) (models.Site, error) { return models.Site{}, nil } -func (m *tuiMockStore) GetAlertByName(string) (models.AlertConfig, error) { +func (m *tuiMockStore) GetAlert(_ context.Context, _ int) (models.AlertConfig, error) { return models.AlertConfig{}, nil } -func (m *tuiMockStore) AddSiteReturningID(models.Site) (int, error) { return 0, nil } -func (m *tuiMockStore) AddAlertReturningID(string, string, map[string]string) (int, error) { - return 0, nil -} -func (m *tuiMockStore) AddUser(string, string, string) error { return nil } -func (m *tuiMockStore) UpdateUser(int, string, string, string) error { return nil } -func (m *tuiMockStore) DeleteUser(int) error { return nil } -func (m *tuiMockStore) SaveCheck(int, int64, bool) error { return nil } -func (m *tuiMockStore) SaveCheckFromNode(int, string, int64, bool) error { +func (m *tuiMockStore) AddAlert(_ context.Context, _ string, _ string, _ map[string]string) error { return nil } -func (m *tuiMockStore) LoadAllHistory(int) (map[int][]models.CheckRecord, error) { - return nil, nil +func (m *tuiMockStore) UpdateAlert(_ context.Context, _ int, _ string, _ string, _ map[string]string) error { + return nil } -func (m *tuiMockStore) PruneCheckHistory() error { return nil } -func (m *tuiMockStore) SaveStateChange(int, string, string, string) error { return nil } -func (m *tuiMockStore) GetStateChangesSince(int, time.Time) ([]models.StateChange, error) { - return nil, nil +func (m *tuiMockStore) DeleteAlert(_ context.Context, _ int) error { return nil } +func (m *tuiMockStore) GetSiteByName(_ context.Context, _ string) (models.Site, error) { + return models.Site{}, nil } -func (m *tuiMockStore) PruneStateChanges() error { return nil } -func (m *tuiMockStore) RegisterNode(models.ProbeNode) error { return nil } -func (m *tuiMockStore) GetNode(string) (models.ProbeNode, error) { return models.ProbeNode{}, nil } -func (m *tuiMockStore) UpdateNodeLastSeen(string) error { return nil } -func (m *tuiMockStore) DeleteNode(string) error { return nil } -func (m *tuiMockStore) LoadAlertHealth() (map[int]models.AlertHealthRecord, error) { - return nil, nil +func (m *tuiMockStore) GetAlertByName(_ context.Context, _ string) (models.AlertConfig, error) { + return models.AlertConfig{}, nil } -func (m *tuiMockStore) SaveAlertHealth(models.AlertHealthRecord) error { return nil } -func (m *tuiMockStore) SaveLog(string) error { return nil } -func (m *tuiMockStore) LoadLogs(int) ([]string, error) { return nil, nil } -func (m *tuiMockStore) PruneLogs() error { return nil } -func (m *tuiMockStore) GetActiveMaintenanceWindows() ([]models.MaintenanceWindow, error) { - return nil, nil -} -func (m *tuiMockStore) AddMaintenanceWindow(models.MaintenanceWindow) error { return nil } -func (m *tuiMockStore) EndMaintenanceWindow(int) error { return nil } -func (m *tuiMockStore) DeleteMaintenanceWindow(int) error { return nil } -func (m *tuiMockStore) PruneExpiredMaintenanceWindows(time.Duration) (int64, error) { +func (m *tuiMockStore) AddSiteReturningID(_ context.Context, _ models.Site) (int, error) { return 0, nil } -func (m *tuiMockStore) IsMonitorInMaintenance(int) (bool, error) { return false, nil } -func (m *tuiMockStore) GetPreference(string) (string, error) { return "", nil } -func (m *tuiMockStore) SetPreference(string, string) error { return nil } -func (m *tuiMockStore) ExportData() (models.Backup, error) { return models.Backup{}, nil } -func (m *tuiMockStore) ImportData(models.Backup) error { return nil } -func (m *tuiMockStore) Close() error { return nil } +func (m *tuiMockStore) AddAlertReturningID(_ context.Context, _ string, _ string, _ map[string]string) (int, error) { + return 0, nil +} +func (m *tuiMockStore) AddUser(_ context.Context, _ string, _ string, _ string) error { return nil } +func (m *tuiMockStore) UpdateUser(_ context.Context, _ int, _ string, _ string, _ string) error { + return nil +} +func (m *tuiMockStore) DeleteUser(_ context.Context, _ int) error { return nil } +func (m *tuiMockStore) SaveCheck(_ context.Context, _ int, _ int64, _ bool) error { return nil } +func (m *tuiMockStore) SaveCheckFromNode(_ context.Context, _ int, _ string, _ int64, _ bool) error { + return nil +} +func (m *tuiMockStore) LoadAllHistory(_ context.Context, _ int) (map[int][]models.CheckRecord, error) { + return nil, nil +} +func (m *tuiMockStore) PruneCheckHistory(_ context.Context) error { return nil } +func (m *tuiMockStore) SaveStateChange(_ context.Context, _ int, _ string, _ string, _ string) error { + return nil +} +func (m *tuiMockStore) GetStateChangesSince(_ context.Context, _ int, _ time.Time) ([]models.StateChange, error) { + return nil, nil +} +func (m *tuiMockStore) PruneStateChanges(_ context.Context) error { return nil } +func (m *tuiMockStore) RegisterNode(_ context.Context, _ models.ProbeNode) error { return nil } +func (m *tuiMockStore) GetNode(_ context.Context, _ string) (models.ProbeNode, error) { + return models.ProbeNode{}, nil +} +func (m *tuiMockStore) UpdateNodeLastSeen(_ context.Context, _ string) error { return nil } +func (m *tuiMockStore) DeleteNode(_ context.Context, _ string) error { return nil } +func (m *tuiMockStore) LoadAlertHealth(_ context.Context) (map[int]models.AlertHealthRecord, error) { + return nil, nil +} +func (m *tuiMockStore) SaveAlertHealth(_ context.Context, _ models.AlertHealthRecord) error { + return nil +} +func (m *tuiMockStore) SaveLog(_ context.Context, _ string) error { return nil } +func (m *tuiMockStore) LoadLogs(_ context.Context, _ int) ([]string, error) { return nil, nil } +func (m *tuiMockStore) PruneLogs(_ context.Context) error { return nil } +func (m *tuiMockStore) GetActiveMaintenanceWindows(_ context.Context) ([]models.MaintenanceWindow, error) { + return nil, nil +} +func (m *tuiMockStore) AddMaintenanceWindow(_ context.Context, _ models.MaintenanceWindow) error { + return nil +} +func (m *tuiMockStore) EndMaintenanceWindow(_ context.Context, _ int) error { return nil } +func (m *tuiMockStore) DeleteMaintenanceWindow(_ context.Context, _ int) error { return nil } +func (m *tuiMockStore) PruneExpiredMaintenanceWindows(_ context.Context, _ time.Duration) (int64, error) { + return 0, nil +} +func (m *tuiMockStore) IsMonitorInMaintenance(_ context.Context, _ int) (bool, error) { + return false, nil +} +func (m *tuiMockStore) GetPreference(_ context.Context, _ string) (string, error) { return "", nil } +func (m *tuiMockStore) SetPreference(_ context.Context, _ string, _ string) error { return nil } +func (m *tuiMockStore) ExportData(_ context.Context) (models.Backup, error) { + return models.Backup{}, nil +} +func (m *tuiMockStore) ImportData(_ context.Context, _ models.Backup) error { return nil } +func (m *tuiMockStore) Close() error { return nil } func newTestModel(ms *tuiMockStore) Model { return Model{ -- 2.52.0 From c3ae0bd80a2f4ed90541e7c5d081da1ed1424c9c Mon Sep 17 00:00:00 2001 From: Tyler Koenig Date: Thu, 11 Jun 2026 14:41:03 -0400 Subject: [PATCH 2/2] fix(store): migrate Postgres timestamps to TIMESTAMPTZ MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit All 8 TIMESTAMP columns in Postgres CREATE TABLE statements changed to TIMESTAMPTZ. Migration ALTER TYPE statements added for existing databases (converts assuming stored values are UTC). Prevents timezone-shifted instants on non-UTC Postgres servers, which would skew SLA math and maintenance-window checks. SQLite unaffected — DATETIME is typeless. --- internal/store/postgres.go | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/internal/store/postgres.go b/internal/store/postgres.go index 4b4a5b0..1be31ca 100644 --- a/internal/store/postgres.go +++ b/internal/store/postgres.go @@ -42,20 +42,20 @@ func (d *PostgresDialect) CreateTablesSQL() []string { `CREATE TABLE IF NOT EXISTS check_history ( id SERIAL PRIMARY KEY, site_id INTEGER NOT NULL, latency_ns BIGINT, - is_up BOOLEAN, checked_at TIMESTAMP DEFAULT NOW() + is_up BOOLEAN, checked_at TIMESTAMPTZ DEFAULT NOW() )`, `CREATE INDEX IF NOT EXISTS idx_check_history_site ON check_history(site_id, checked_at DESC)`, `CREATE TABLE IF NOT EXISTS nodes ( id TEXT PRIMARY KEY, name TEXT NOT NULL, region TEXT DEFAULT '', - last_seen TIMESTAMP DEFAULT NOW(), + last_seen TIMESTAMPTZ DEFAULT NOW(), version TEXT DEFAULT '' )`, `CREATE TABLE IF NOT EXISTS logs ( id SERIAL PRIMARY KEY, message TEXT NOT NULL, - created_at TIMESTAMP DEFAULT NOW() + created_at TIMESTAMPTZ DEFAULT NOW() )`, `CREATE TABLE IF NOT EXISTS maintenance_windows ( id SERIAL PRIMARY KEY, @@ -63,10 +63,10 @@ func (d *PostgresDialect) CreateTablesSQL() []string { title TEXT NOT NULL, description TEXT DEFAULT '', type TEXT DEFAULT 'maintenance', - start_time TIMESTAMP NOT NULL, - end_time TIMESTAMP, + start_time TIMESTAMPTZ NOT NULL, + end_time TIMESTAMPTZ, created_by TEXT DEFAULT '', - created_at TIMESTAMP DEFAULT NOW() + created_at TIMESTAMPTZ DEFAULT NOW() )`, `CREATE TABLE IF NOT EXISTS preferences ( key TEXT PRIMARY KEY, @@ -78,12 +78,12 @@ func (d *PostgresDialect) CreateTablesSQL() []string { from_status TEXT NOT NULL, to_status TEXT NOT NULL, error_reason TEXT DEFAULT '', - changed_at TIMESTAMP DEFAULT NOW() + changed_at TIMESTAMPTZ DEFAULT NOW() )`, `CREATE INDEX IF NOT EXISTS idx_state_changes_site ON state_changes(site_id, changed_at DESC)`, `CREATE TABLE IF NOT EXISTS alert_health ( alert_id INTEGER PRIMARY KEY, - last_send_at TIMESTAMP, + last_send_at TIMESTAMPTZ, last_send_ok BOOLEAN DEFAULT FALSE, last_error TEXT DEFAULT '', send_count INTEGER DEFAULT 0, @@ -107,6 +107,14 @@ func (d *PostgresDialect) MigrationsSQL() []string { "ALTER TABLE sites ADD COLUMN IF NOT EXISTS paused BOOLEAN DEFAULT FALSE", "ALTER TABLE check_history ADD COLUMN IF NOT EXISTS node_id TEXT DEFAULT ''", "ALTER TABLE sites ADD COLUMN IF NOT EXISTS regions TEXT DEFAULT ''", + "ALTER TABLE check_history ALTER COLUMN checked_at TYPE TIMESTAMPTZ USING checked_at AT TIME ZONE 'UTC'", + "ALTER TABLE nodes ALTER COLUMN last_seen TYPE TIMESTAMPTZ USING last_seen AT TIME ZONE 'UTC'", + "ALTER TABLE logs ALTER COLUMN created_at TYPE TIMESTAMPTZ USING created_at AT TIME ZONE 'UTC'", + "ALTER TABLE maintenance_windows ALTER COLUMN start_time TYPE TIMESTAMPTZ USING start_time AT TIME ZONE 'UTC'", + "ALTER TABLE maintenance_windows ALTER COLUMN end_time TYPE TIMESTAMPTZ USING end_time AT TIME ZONE 'UTC'", + "ALTER TABLE maintenance_windows ALTER COLUMN created_at TYPE TIMESTAMPTZ USING created_at AT TIME ZONE 'UTC'", + "ALTER TABLE state_changes ALTER COLUMN changed_at TYPE TIMESTAMPTZ USING changed_at AT TIME ZONE 'UTC'", + "ALTER TABLE alert_health ALTER COLUMN last_send_at TYPE TIMESTAMPTZ USING last_send_at AT TIME ZONE 'UTC'", } } -- 2.52.0