refactor(store): propagate context.Context through all Store methods

Every Store interface method (except Close) now takes context.Context
as first parameter. All 54 db.Query/Exec/QueryRow calls in SQLStore
replaced with their *Context variants. DB operations now respect
cancellation and deadlines.

Context sources by caller:
- Engine dbWriter/poll/pruner: engine ctx from Start()
- HTTP handlers: r.Context()
- config.Apply/Export: caller-provided ctx
- TUI/main.go init: context.Background()

RunCheck and all sub-checks (HTTP/ping/port/DNS) accept parent ctx.
HTTP checks now inherit shutdown cancellation instead of rooting in
context.Background(). dbWrite.exec takes ctx so the writer goroutine
can cancel stuck DB operations.

DeleteSite/ImportData use BeginTx(ctx) instead of Begin().
This commit is contained in:
2026-06-11 14:40:30 -04:00
parent 5d5153351e
commit 70a83a1da9
28 changed files with 813 additions and 677 deletions
+4 -3
View File
@@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"crypto/ed25519" "crypto/ed25519"
"crypto/rand" "crypto/rand"
"errors" "errors"
@@ -22,8 +23,8 @@ type kcMockStore struct {
err error err error
} }
func (m *kcMockStore) GetAllUsers() ([]models.User, error) { return m.users, m.err } func (m *kcMockStore) GetAllUsers(_ context.Context) ([]models.User, error) { return m.users, m.err }
func (m *kcMockStore) DeleteUser(int) error { return nil } func (m *kcMockStore) DeleteUser(_ context.Context, _ int) error { return nil }
func testKey(t *testing.T) (string, ssh.PublicKey) { func testKey(t *testing.T) (string, ssh.PublicKey) {
t.Helper() t.Helper()
@@ -103,7 +104,7 @@ func TestUserInvalidatingStore_DeleteDropsKeyCache(t *testing.T) {
// Revoke the user; DB unreachable immediately after. The cached key must // Revoke the user; DB unreachable immediately after. The cached key must
// be gone the moment the delete returns. // be gone the moment the delete returns.
if err := s.DeleteUser(1); err != nil { if err := s.DeleteUser(context.Background(), 1); err != nil {
t.Fatal(err) t.Fatal(err)
} }
ms.users = nil ms.users = nil
+27 -25
View File
@@ -141,7 +141,7 @@ func openStore(dbType, dsn string) store.Store {
} else { } else {
fmt.Println("WARNING: No UPTOP_ENCRYPTION_KEY set. Alert credentials stored unencrypted.") fmt.Println("WARNING: No UPTOP_ENCRYPTION_KEY set. Alert credentials stored unencrypted.")
} }
if err := ss.Init(); err != nil { if err := ss.Init(context.Background()); err != nil {
fmt.Fprintf(os.Stderr, "database init error: %v\n", err) fmt.Fprintf(os.Stderr, "database init error: %v\n", err)
os.Exit(1) os.Exit(1)
} }
@@ -171,7 +171,7 @@ func runApply(args []string) {
os.Exit(1) os.Exit(1)
} }
changes, err := config.Apply(s, f, config.ApplyOpts{ changes, err := config.Apply(context.Background(), s, f, config.ApplyOpts{
DryRun: *dryRun, DryRun: *dryRun,
Prune: *prune, Prune: *prune,
}) })
@@ -192,7 +192,7 @@ func runExport(args []string) {
s := openStore(*dbType, *dsn) s := openStore(*dbType, *dsn)
f, err := config.Export(s) f, err := config.Export(context.Background(), s)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "error: %v\n", err) fmt.Fprintf(os.Stderr, "error: %v\n", err)
os.Exit(1) os.Exit(1)
@@ -231,12 +231,12 @@ func runMigrateSecrets(args []string) {
fmt.Fprintf(os.Stderr, "database error: %v\n", err) fmt.Fprintf(os.Stderr, "database error: %v\n", err)
os.Exit(1) os.Exit(1)
} }
if err := ss.Init(); err != nil { if err := ss.Init(context.Background()); err != nil {
fmt.Fprintf(os.Stderr, "database init error: %v\n", err) fmt.Fprintf(os.Stderr, "database init error: %v\n", err)
os.Exit(1) os.Exit(1)
} }
alerts, err := ss.GetAllAlerts() alerts, err := ss.GetAllAlerts(context.Background())
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "error loading alerts: %v\n", err) fmt.Fprintf(os.Stderr, "error loading alerts: %v\n", err)
os.Exit(1) os.Exit(1)
@@ -245,7 +245,7 @@ func runMigrateSecrets(args []string) {
ss.SetEncryptor(enc) ss.SetEncryptor(enc)
migrated := 0 migrated := 0
for _, a := range alerts { for _, a := range alerts {
if err := ss.UpdateAlert(a.ID, a.Name, a.Type, a.Settings); err != nil { if err := ss.UpdateAlert(context.Background(), a.ID, a.Name, a.Type, a.Settings); err != nil {
fmt.Fprintf(os.Stderr, "error migrating alert %q: %v\n", a.Name, err) fmt.Fprintf(os.Stderr, "error migrating alert %q: %v\n", a.Name, err)
os.Exit(1) os.Exit(1)
} }
@@ -378,7 +378,7 @@ func runServe(args []string) {
kc := newKeyCache(ss) kc := newKeyCache(ss)
var s store.Store = &userInvalidatingStore{Store: ss, kc: kc} var s store.Store = &userInvalidatingStore{Store: ss, kc: kc}
if err := s.Init(); err != nil { if err := s.Init(context.Background()); err != nil {
fmt.Fprintf(os.Stderr, "database init error: %v\n", err) fmt.Fprintf(os.Stderr, "database init error: %v\n", err)
os.Exit(1) os.Exit(1)
} }
@@ -395,7 +395,7 @@ func runServe(args []string) {
os.Exit(1) os.Exit(1)
} }
backup := importer.ConvertKuma(kb) backup := importer.ConvertKuma(kb)
if err := s.ImportData(backup); err != nil { if err := s.ImportData(context.Background(), backup); err != nil {
fmt.Fprintf(os.Stderr, "import failed: %v\n", err) fmt.Fprintf(os.Stderr, "import failed: %v\n", err)
os.Exit(1) os.Exit(1)
} }
@@ -515,21 +515,22 @@ func startSSHServer(port int, db store.Store, eng *monitor.Engine, kc *keyCache)
} }
func seedDemoData(s store.Store) { func seedDemoData(s store.Store) {
existing, _ := s.GetSites() ctx := context.Background()
existing, _ := s.GetSites(ctx)
if len(existing) > 0 { if len(existing) > 0 {
return return
} }
fmt.Println("Seeding demo data...") fmt.Println("Seeding demo data...")
if err := s.AddAlert("Discord Ops", "discord", map[string]string{"url": "https://discord.com/api/webhooks/demo/token"}); err != nil { if err := s.AddAlert(ctx, "Discord Ops", "discord", map[string]string{"url": "https://discord.com/api/webhooks/demo/token"}); err != nil {
log.Printf("demo seed: add alert: %v", err) log.Printf("demo seed: add alert: %v", err)
return return
} }
if err := s.AddAlert("Slack Infra", "slack", map[string]string{"url": "https://hooks.slack.com/services/DEMO/WEBHOOK"}); err != nil { if err := s.AddAlert(ctx, "Slack Infra", "slack", map[string]string{"url": "https://hooks.slack.com/services/DEMO/WEBHOOK"}); err != nil {
log.Printf("demo seed: add alert: %v", err) log.Printf("demo seed: add alert: %v", err)
return return
} }
if err := s.AddAlert("Email Oncall", "email", map[string]string{ if err := s.AddAlert(ctx, "Email Oncall", "email", map[string]string{
"host": "smtp.example.com", "port": "587", "host": "smtp.example.com", "port": "587",
"user": "oncall@example.com", "pass": "replace-me", "user": "oncall@example.com", "pass": "replace-me",
"from": "oncall@example.com", "to": "team@example.com", "from": "oncall@example.com", "to": "team@example.com",
@@ -538,7 +539,7 @@ func seedDemoData(s store.Store) {
return return
} }
alerts, _ := s.GetAllAlerts() alerts, _ := s.GetAllAlerts(ctx)
alertID := 0 alertID := 0
if len(alerts) > 0 { if len(alerts) > 0 {
alertID = alerts[0].ID alertID = alerts[0].ID
@@ -557,7 +558,7 @@ func seedDemoData(s store.Store) {
{Name: "SSH Server", Type: "port", Interval: 60, AlertID: alertID, Hostname: "10.0.0.1", Port: 22, Timeout: 5, ExpiryThreshold: 7}, {Name: "SSH Server", Type: "port", Interval: 60, AlertID: alertID, Hostname: "10.0.0.1", Port: 22, Timeout: 5, ExpiryThreshold: 7},
} }
for _, site := range demoSites { for _, site := range demoSites {
if err := s.AddSite(site); err != nil { if err := s.AddSite(ctx, site); err != nil {
log.Printf("demo seed: add site %q: %v", site.Name, err) log.Printf("demo seed: add site %q: %v", site.Name, err)
} }
} }
@@ -576,7 +577,7 @@ func newKeyCache(db store.Store) *keyCache {
} }
func (c *keyCache) refresh() { func (c *keyCache) refresh() {
users, err := c.db.GetAllUsers() users, err := c.db.GetAllUsers(context.Background())
if err != nil { if err != nil {
// Keep the previous key set: a transient DB error must not lock every // Keep the previous key set: a transient DB error must not lock every
// admin out. Revocation still fails closed because Invalidate clears // admin out. Revocation still fails closed because Invalidate clears
@@ -637,31 +638,32 @@ type userInvalidatingStore struct {
kc *keyCache kc *keyCache
} }
func (s *userInvalidatingStore) AddUser(username, publicKey, role string) error { func (s *userInvalidatingStore) AddUser(ctx context.Context, username, publicKey, role string) error {
err := s.Store.AddUser(username, publicKey, role) err := s.Store.AddUser(ctx, username, publicKey, role)
s.kc.Invalidate() s.kc.Invalidate()
return err return err
} }
func (s *userInvalidatingStore) UpdateUser(id int, username, publicKey, role string) error { func (s *userInvalidatingStore) UpdateUser(ctx context.Context, id int, username, publicKey, role string) error {
err := s.Store.UpdateUser(id, username, publicKey, role) err := s.Store.UpdateUser(ctx, id, username, publicKey, role)
s.kc.Invalidate() s.kc.Invalidate()
return err return err
} }
func (s *userInvalidatingStore) DeleteUser(id int) error { func (s *userInvalidatingStore) DeleteUser(ctx context.Context, id int) error {
err := s.Store.DeleteUser(id) err := s.Store.DeleteUser(ctx, id)
s.kc.Invalidate() s.kc.Invalidate()
return err return err
} }
func (s *userInvalidatingStore) ImportData(data models.Backup) error { func (s *userInvalidatingStore) ImportData(ctx context.Context, data models.Backup) error {
err := s.Store.ImportData(data) err := s.Store.ImportData(ctx, data)
s.kc.Invalidate() s.kc.Invalidate()
return err return err
} }
func seedKeysFromEnv(s store.Store) { func seedKeysFromEnv(s store.Store) {
ctx := context.Background()
var keys []string var keys []string
if v := os.Getenv("UPTOP_ADMIN_KEY"); v != "" { if v := os.Getenv("UPTOP_ADMIN_KEY"); v != "" {
@@ -687,7 +689,7 @@ func seedKeysFromEnv(s store.Store) {
return return
} }
existing, err := s.GetAllUsers() existing, err := s.GetAllUsers(ctx)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "warning: could not check existing users: %v\n", err) fmt.Fprintf(os.Stderr, "warning: could not check existing users: %v\n", err)
return return
@@ -705,7 +707,7 @@ func seedKeysFromEnv(s store.Store) {
} }
username := usernameFromKey(key, i, len(existing)+added) username := usernameFromKey(key, i, len(existing)+added)
if err := s.AddUser(username, key, "admin"); err != nil { if err := s.AddUser(ctx, username, key, "admin"); err != nil {
fmt.Fprintf(os.Stderr, "warning: failed to seed user %q: %v\n", username, err) fmt.Fprintf(os.Stderr, "warning: failed to seed user %q: %v\n", username, err)
continue continue
} }
+72 -48
View File
@@ -20,64 +20,88 @@ type mockStore struct {
sites []models.Site sites []models.Site
} }
func (m *mockStore) Init() error { return nil } func (m *mockStore) Init(_ context.Context) error { return nil }
func (m *mockStore) GetSites() ([]models.Site, error) { return m.sites, nil } func (m *mockStore) GetSites(_ context.Context) ([]models.Site, error) { return m.sites, nil }
func (m *mockStore) AddSite(models.Site) error { return nil } func (m *mockStore) AddSite(_ context.Context, _ models.Site) error { return nil }
func (m *mockStore) UpdateSite(models.Site) error { return nil } func (m *mockStore) UpdateSite(_ context.Context, _ models.Site) error { return nil }
func (m *mockStore) UpdateSitePaused(int, bool) error { return nil } func (m *mockStore) UpdateSitePaused(_ context.Context, _ int, _ bool) error { return nil }
func (m *mockStore) DeleteSite(int) error { return nil } func (m *mockStore) DeleteSite(_ context.Context, _ int) error { return nil }
func (m *mockStore) GetAllAlerts() ([]models.AlertConfig, error) { return nil, nil } func (m *mockStore) GetAllAlerts(_ context.Context) ([]models.AlertConfig, error) { return nil, nil }
func (m *mockStore) GetAlert(int) (models.AlertConfig, error) { return models.AlertConfig{}, nil } func (m *mockStore) GetAlert(_ context.Context, _ int) (models.AlertConfig, error) {
func (m *mockStore) AddAlert(string, string, map[string]string) error { return nil }
func (m *mockStore) UpdateAlert(int, string, string, map[string]string) error { return nil }
func (m *mockStore) DeleteAlert(int) error { return nil }
func (m *mockStore) GetAllUsers() ([]models.User, error) { return nil, nil }
func (m *mockStore) AddUser(string, string, string) error { return nil }
func (m *mockStore) UpdateUser(int, string, string, string) error { return nil }
func (m *mockStore) DeleteUser(int) error { return nil }
func (m *mockStore) SaveCheck(int, int64, bool) error { return nil }
func (m *mockStore) SaveCheckFromNode(int, string, int64, bool) error { return nil }
func (m *mockStore) LoadAllHistory(int) (map[int][]models.CheckRecord, error) { return nil, nil }
func (m *mockStore) ExportData() (models.Backup, error) { return models.Backup{}, nil }
func (m *mockStore) ImportData(models.Backup) error { return nil }
func (m *mockStore) GetSiteByName(string) (models.Site, error) { return models.Site{}, nil }
func (m *mockStore) GetAlertByName(string) (models.AlertConfig, error) {
return models.AlertConfig{}, nil return models.AlertConfig{}, nil
} }
func (m *mockStore) AddSiteReturningID(models.Site) (int, error) { return 0, nil } func (m *mockStore) AddAlert(_ context.Context, _ string, _ string, _ map[string]string) error {
func (m *mockStore) AddAlertReturningID(string, string, map[string]string) (int, error) { return nil
}
func (m *mockStore) UpdateAlert(_ context.Context, _ int, _ string, _ string, _ map[string]string) error {
return nil
}
func (m *mockStore) DeleteAlert(_ context.Context, _ int) error { return nil }
func (m *mockStore) GetAllUsers(_ context.Context) ([]models.User, error) { return nil, nil }
func (m *mockStore) AddUser(_ context.Context, _ string, _ string, _ string) error { return nil }
func (m *mockStore) UpdateUser(_ context.Context, _ int, _ string, _ string, _ string) error {
return nil
}
func (m *mockStore) DeleteUser(_ context.Context, _ int) error { return nil }
func (m *mockStore) SaveCheck(_ context.Context, _ int, _ int64, _ bool) error { return nil }
func (m *mockStore) SaveCheckFromNode(_ context.Context, _ int, _ string, _ int64, _ bool) error {
return nil
}
func (m *mockStore) LoadAllHistory(_ context.Context, _ int) (map[int][]models.CheckRecord, error) {
return nil, nil
}
func (m *mockStore) ExportData(_ context.Context) (models.Backup, error) { return models.Backup{}, nil }
func (m *mockStore) ImportData(_ context.Context, _ models.Backup) error { return nil }
func (m *mockStore) GetSiteByName(_ context.Context, _ string) (models.Site, error) {
return models.Site{}, nil
}
func (m *mockStore) GetAlertByName(_ context.Context, _ string) (models.AlertConfig, error) {
return models.AlertConfig{}, nil
}
func (m *mockStore) AddSiteReturningID(_ context.Context, _ models.Site) (int, error) { return 0, nil }
func (m *mockStore) AddAlertReturningID(_ context.Context, _ string, _ string, _ map[string]string) (int, error) {
return 0, nil return 0, nil
} }
func (m *mockStore) RegisterNode(models.ProbeNode) error { return nil } func (m *mockStore) RegisterNode(_ context.Context, _ models.ProbeNode) error { return nil }
func (m *mockStore) GetNode(string) (models.ProbeNode, error) { return models.ProbeNode{}, nil } func (m *mockStore) GetNode(_ context.Context, _ string) (models.ProbeNode, error) {
func (m *mockStore) GetAllNodes() ([]models.ProbeNode, error) { return nil, nil } return models.ProbeNode{}, nil
func (m *mockStore) UpdateNodeLastSeen(string) error { return nil } }
func (m *mockStore) DeleteNode(string) error { return nil } func (m *mockStore) GetAllNodes(_ context.Context) ([]models.ProbeNode, error) { return nil, nil }
func (m *mockStore) LoadAlertHealth() (map[int]models.AlertHealthRecord, error) { func (m *mockStore) UpdateNodeLastSeen(_ context.Context, _ string) error { return nil }
func (m *mockStore) DeleteNode(_ context.Context, _ string) error { return nil }
func (m *mockStore) LoadAlertHealth(_ context.Context) (map[int]models.AlertHealthRecord, error) {
return nil, nil return nil, nil
} }
func (m *mockStore) SaveAlertHealth(models.AlertHealthRecord) error { return nil } func (m *mockStore) SaveAlertHealth(_ context.Context, _ models.AlertHealthRecord) error { return nil }
func (m *mockStore) SaveLog(string) error { return nil } func (m *mockStore) SaveLog(_ context.Context, _ string) error { return nil }
func (m *mockStore) PruneLogs() error { return nil } func (m *mockStore) PruneLogs(_ context.Context) error { return nil }
func (m *mockStore) PruneCheckHistory() error { return nil } func (m *mockStore) PruneCheckHistory(_ context.Context) error { return nil }
func (m *mockStore) PruneStateChanges() error { return nil } func (m *mockStore) PruneStateChanges(_ context.Context) error { return nil }
func (m *mockStore) LoadLogs(int) ([]string, error) { return nil, nil } func (m *mockStore) LoadLogs(_ context.Context, _ int) ([]string, error) { return nil, nil }
func (m *mockStore) GetActiveMaintenanceWindows() ([]models.MaintenanceWindow, error) { func (m *mockStore) GetActiveMaintenanceWindows(_ context.Context) ([]models.MaintenanceWindow, error) {
return nil, nil return nil, nil
} }
func (m *mockStore) GetAllMaintenanceWindows(int) ([]models.MaintenanceWindow, error) { func (m *mockStore) GetAllMaintenanceWindows(_ context.Context, _ int) ([]models.MaintenanceWindow, error) {
return nil, nil return nil, nil
} }
func (m *mockStore) AddMaintenanceWindow(models.MaintenanceWindow) error { return nil } func (m *mockStore) AddMaintenanceWindow(_ context.Context, _ models.MaintenanceWindow) error {
func (m *mockStore) EndMaintenanceWindow(int) error { return nil } return nil
func (m *mockStore) DeleteMaintenanceWindow(int) error { return nil } }
func (m *mockStore) PruneExpiredMaintenanceWindows(time.Duration) (int64, error) { return 0, nil } func (m *mockStore) EndMaintenanceWindow(_ context.Context, _ int) error { return nil }
func (m *mockStore) IsMonitorInMaintenance(int) (bool, error) { return false, nil } func (m *mockStore) DeleteMaintenanceWindow(_ context.Context, _ int) error { return nil }
func (m *mockStore) GetPreference(string) (string, error) { return "", nil } func (m *mockStore) PruneExpiredMaintenanceWindows(_ context.Context, _ time.Duration) (int64, error) {
func (m *mockStore) SetPreference(string, string) error { return nil } return 0, nil
func (m *mockStore) SaveStateChange(int, string, string, string) error { return nil } }
func (m *mockStore) GetStateChanges(int, int) ([]models.StateChange, error) { return nil, nil } func (m *mockStore) IsMonitorInMaintenance(_ context.Context, _ int) (bool, error) { return false, nil }
func (m *mockStore) GetStateChangesSince(int, time.Time) ([]models.StateChange, error) { func (m *mockStore) GetPreference(_ context.Context, _ string) (string, error) { return "", nil }
func (m *mockStore) SetPreference(_ context.Context, _ string, _ string) error { return nil }
func (m *mockStore) SaveStateChange(_ context.Context, _ int, _ string, _ string, _ string) error {
return nil
}
func (m *mockStore) GetStateChanges(_ context.Context, _ int, _ int) ([]models.StateChange, error) {
return nil, nil
}
func (m *mockStore) GetStateChangesSince(_ context.Context, _ int, _ time.Time) ([]models.StateChange, error) {
return nil, nil return nil, nil
} }
func (m *mockStore) Close() error { return nil } func (m *mockStore) Close() error { return nil }
+1 -1
View File
@@ -152,7 +152,7 @@ loop:
defer wg.Done() defer wg.Done()
defer func() { <-sem }() defer func() { <-sem }()
cr := monitor.RunCheck(s, strict, insecure, false, allowPrivate) cr := monitor.RunCheck(ctx, s, strict, insecure, false, allowPrivate)
mu.Lock() mu.Lock()
results = append(results, probeResultItem{ results = append(results, probeResultItem{
SiteID: s.ID, SiteID: s.ID,
+18 -16
View File
@@ -1,11 +1,13 @@
package config package config
import ( import (
"context"
"fmt" "fmt"
"gitea.lerkolabs.com/lerkolabs/uptop/internal/models"
"gitea.lerkolabs.com/lerkolabs/uptop/internal/store"
"reflect" "reflect"
"strings" "strings"
"gitea.lerkolabs.com/lerkolabs/uptop/internal/models"
"gitea.lerkolabs.com/lerkolabs/uptop/internal/store"
) )
type ApplyOpts struct { type ApplyOpts struct {
@@ -20,17 +22,17 @@ type Change struct {
Details string Details string
} }
func Apply(s store.Store, f *File, opts ApplyOpts) ([]Change, error) { func Apply(ctx context.Context, s store.Store, f *File, opts ApplyOpts) ([]Change, error) {
if err := Validate(f); err != nil { if err := Validate(f); err != nil {
return nil, err return nil, err
} }
existingAlerts, err := s.GetAllAlerts() existingAlerts, err := s.GetAllAlerts(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("load alerts: %w", err) return nil, fmt.Errorf("load alerts: %w", err)
} }
existingSites, err := s.GetSites() existingSites, err := s.GetSites(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("load sites: %w", err) return nil, fmt.Errorf("load sites: %w", err)
} }
@@ -59,7 +61,7 @@ func Apply(s store.Store, f *File, opts ApplyOpts) ([]Change, error) {
if !exists { if !exists {
changes = append(changes, Change{Action: "create", Kind: "alert", Name: a.Name, Details: a.Type}) changes = append(changes, Change{Action: "create", Kind: "alert", Name: a.Name, Details: a.Type})
if !opts.DryRun { if !opts.DryRun {
id, err := s.AddAlertReturningID(a.Name, a.Type, a.Settings) id, err := s.AddAlertReturningID(ctx, a.Name, a.Type, a.Settings)
if err != nil { if err != nil {
return changes, fmt.Errorf("create alert %q: %w", a.Name, err) return changes, fmt.Errorf("create alert %q: %w", a.Name, err)
} }
@@ -70,7 +72,7 @@ func Apply(s store.Store, f *File, opts ApplyOpts) ([]Change, error) {
if diff := diffAlert(existing, a); diff != "" { if diff := diffAlert(existing, a); diff != "" {
changes = append(changes, Change{Action: "update", Kind: "alert", Name: a.Name, Details: diff}) changes = append(changes, Change{Action: "update", Kind: "alert", Name: a.Name, Details: diff})
if !opts.DryRun { if !opts.DryRun {
if err := s.UpdateAlert(existing.ID, a.Name, a.Type, a.Settings); err != nil { if err := s.UpdateAlert(ctx, existing.ID, a.Name, a.Type, a.Settings); err != nil {
return changes, fmt.Errorf("update alert %q: %w", a.Name, err) return changes, fmt.Errorf("update alert %q: %w", a.Name, err)
} }
} }
@@ -102,7 +104,7 @@ func Apply(s store.Store, f *File, opts ApplyOpts) ([]Change, error) {
if !exists { if !exists {
changes = append(changes, Change{Action: "create", Kind: "monitor", Name: g.Name, Details: "group"}) changes = append(changes, Change{Action: "create", Kind: "monitor", Name: g.Name, Details: "group"})
if !opts.DryRun { if !opts.DryRun {
id, err := s.AddSiteReturningID(site) id, err := s.AddSiteReturningID(ctx, site)
if err != nil { if err != nil {
return changes, fmt.Errorf("create group %q: %w", g.Name, err) return changes, fmt.Errorf("create group %q: %w", g.Name, err)
} }
@@ -114,7 +116,7 @@ func Apply(s store.Store, f *File, opts ApplyOpts) ([]Change, error) {
if diff := diffSite(normalizeSite(existing), site); diff != "" { if diff := diffSite(normalizeSite(existing), site); diff != "" {
changes = append(changes, Change{Action: "update", Kind: "monitor", Name: g.Name, Details: diff}) changes = append(changes, Change{Action: "update", Kind: "monitor", Name: g.Name, Details: diff})
if !opts.DryRun { if !opts.DryRun {
if err := s.UpdateSite(site); err != nil { if err := s.UpdateSite(ctx, site); err != nil {
return changes, fmt.Errorf("update group %q: %w", g.Name, err) return changes, fmt.Errorf("update group %q: %w", g.Name, err)
} }
} }
@@ -125,7 +127,7 @@ func Apply(s store.Store, f *File, opts ApplyOpts) ([]Change, error) {
for _, g := range groups { for _, g := range groups {
parentID := groupMap[g.Name] parentID := groupMap[g.Name]
for _, child := range g.Monitors { for _, child := range g.Monitors {
c, err := applyMonitor(s, child, alertMap, existingSitesByName, parentID, opts.DryRun) c, err := applyMonitor(ctx, s, child, alertMap, existingSitesByName, parentID, opts.DryRun)
if err != nil { if err != nil {
return changes, err return changes, err
} }
@@ -134,7 +136,7 @@ func Apply(s store.Store, f *File, opts ApplyOpts) ([]Change, error) {
} }
for _, m := range topLevel { for _, m := range topLevel {
c, err := applyMonitor(s, m, alertMap, existingSitesByName, 0, opts.DryRun) c, err := applyMonitor(ctx, s, m, alertMap, existingSitesByName, 0, opts.DryRun)
if err != nil { if err != nil {
return changes, err return changes, err
} }
@@ -155,7 +157,7 @@ func Apply(s store.Store, f *File, opts ApplyOpts) ([]Change, error) {
childDeletes = append(childDeletes, c) childDeletes = append(childDeletes, c)
} }
if !opts.DryRun { if !opts.DryRun {
if err := s.DeleteSite(es.ID); err != nil { if err := s.DeleteSite(ctx, es.ID); err != nil {
return changes, fmt.Errorf("delete monitor %q: %w", es.Name, err) return changes, fmt.Errorf("delete monitor %q: %w", es.Name, err)
} }
} }
@@ -169,7 +171,7 @@ func Apply(s store.Store, f *File, opts ApplyOpts) ([]Change, error) {
} }
changes = append(changes, Change{Action: "delete", Kind: "alert", Name: ea.Name, Details: ea.Type}) changes = append(changes, Change{Action: "delete", Kind: "alert", Name: ea.Name, Details: ea.Type})
if !opts.DryRun { if !opts.DryRun {
if err := s.DeleteAlert(ea.ID); err != nil { if err := s.DeleteAlert(ctx, ea.ID); err != nil {
return changes, fmt.Errorf("delete alert %q: %w", ea.Name, err) return changes, fmt.Errorf("delete alert %q: %w", ea.Name, err)
} }
} }
@@ -179,7 +181,7 @@ func Apply(s store.Store, f *File, opts ApplyOpts) ([]Change, error) {
return changes, nil return changes, nil
} }
func applyMonitor(s store.Store, m Monitor, alertMap map[string]int, existing map[string]models.Site, parentID int, dryRun bool) ([]Change, error) { func applyMonitor(ctx context.Context, s store.Store, m Monitor, alertMap map[string]int, existing map[string]models.Site, parentID int, dryRun bool) ([]Change, error) {
alertID, err := resolveAlertID(alertMap, m.Alert) alertID, err := resolveAlertID(alertMap, m.Alert)
if err != nil { if err != nil {
return nil, fmt.Errorf("monitor %q: %w", m.Name, err) return nil, fmt.Errorf("monitor %q: %w", m.Name, err)
@@ -191,7 +193,7 @@ func applyMonitor(s store.Store, m Monitor, alertMap map[string]int, existing ma
if !exists { if !exists {
changes = append(changes, Change{Action: "create", Kind: "monitor", Name: m.Name, Details: m.Type}) changes = append(changes, Change{Action: "create", Kind: "monitor", Name: m.Name, Details: m.Type})
if !dryRun { if !dryRun {
if _, err := s.AddSiteReturningID(site); err != nil { if _, err := s.AddSiteReturningID(ctx, site); err != nil {
return changes, fmt.Errorf("create monitor %q: %w", m.Name, err) return changes, fmt.Errorf("create monitor %q: %w", m.Name, err)
} }
} }
@@ -200,7 +202,7 @@ func applyMonitor(s store.Store, m Monitor, alertMap map[string]int, existing ma
if diff := diffSite(normalizeSite(ex), site); diff != "" { if diff := diffSite(normalizeSite(ex), site); diff != "" {
changes = append(changes, Change{Action: "update", Kind: "monitor", Name: m.Name, Details: diff}) changes = append(changes, Change{Action: "update", Kind: "monitor", Name: m.Name, Details: diff})
if !dryRun { if !dryRun {
if err := s.UpdateSite(site); err != nil { if err := s.UpdateSite(ctx, site); err != nil {
return changes, fmt.Errorf("update monitor %q: %w", m.Name, err) return changes, fmt.Errorf("update monitor %q: %w", m.Name, err)
} }
} }
+29 -27
View File
@@ -1,10 +1,12 @@
package config package config
import ( import (
"gitea.lerkolabs.com/lerkolabs/uptop/internal/models" "context"
"gitea.lerkolabs.com/lerkolabs/uptop/internal/store"
"strings" "strings"
"testing" "testing"
"gitea.lerkolabs.com/lerkolabs/uptop/internal/models"
"gitea.lerkolabs.com/lerkolabs/uptop/internal/store"
) )
func newTestStore(t *testing.T) store.Store { func newTestStore(t *testing.T) store.Store {
@@ -13,7 +15,7 @@ func newTestStore(t *testing.T) store.Store {
if err != nil { if err != nil {
t.Fatalf("NewSQLiteStore: %v", err) t.Fatalf("NewSQLiteStore: %v", err)
} }
if err := s.Init(); err != nil { if err := s.Init(context.Background()); err != nil {
t.Fatalf("Init: %v", err) t.Fatalf("Init: %v", err)
} }
return s return s
@@ -31,7 +33,7 @@ func TestApplyCreateFromScratch(t *testing.T) {
}, },
} }
changes, err := Apply(s, f, ApplyOpts{}) changes, err := Apply(context.Background(), s, f, ApplyOpts{})
if err != nil { if err != nil {
t.Fatalf("Apply: %v", err) t.Fatalf("Apply: %v", err)
} }
@@ -46,12 +48,12 @@ func TestApplyCreateFromScratch(t *testing.T) {
t.Fatalf("expected 3 creates, got %d", creates) t.Fatalf("expected 3 creates, got %d", creates)
} }
sites, _ := s.GetSites() sites, _ := s.GetSites(context.Background())
if len(sites) != 2 { if len(sites) != 2 {
t.Fatalf("expected 2 sites, got %d", len(sites)) t.Fatalf("expected 2 sites, got %d", len(sites))
} }
alerts, _ := s.GetAllAlerts() alerts, _ := s.GetAllAlerts(context.Background())
if len(alerts) != 1 { if len(alerts) != 1 {
t.Fatalf("expected 1 alert, got %d", len(alerts)) t.Fatalf("expected 1 alert, got %d", len(alerts))
} }
@@ -68,11 +70,11 @@ func TestApplyIdempotent(t *testing.T) {
}, },
} }
if _, err := Apply(s, f, ApplyOpts{}); err != nil { if _, err := Apply(context.Background(), s, f, ApplyOpts{}); err != nil {
t.Fatalf("first Apply: %v", err) t.Fatalf("first Apply: %v", err)
} }
changes, err := Apply(s, f, ApplyOpts{}) changes, err := Apply(context.Background(), s, f, ApplyOpts{})
if err != nil { if err != nil {
t.Fatalf("second Apply: %v", err) t.Fatalf("second Apply: %v", err)
} }
@@ -90,12 +92,12 @@ func TestApplyUpdate(t *testing.T) {
}, },
} }
if _, err := Apply(s, f, ApplyOpts{}); err != nil { if _, err := Apply(context.Background(), s, f, ApplyOpts{}); err != nil {
t.Fatalf("first Apply: %v", err) t.Fatalf("first Apply: %v", err)
} }
f.Monitors[0].Interval = 60 f.Monitors[0].Interval = 60
changes, err := Apply(s, f, ApplyOpts{}) changes, err := Apply(context.Background(), s, f, ApplyOpts{})
if err != nil { if err != nil {
t.Fatalf("second Apply: %v", err) t.Fatalf("second Apply: %v", err)
} }
@@ -104,7 +106,7 @@ func TestApplyUpdate(t *testing.T) {
t.Fatalf("expected 1 update, got %+v", changes) t.Fatalf("expected 1 update, got %+v", changes)
} }
sites, _ := s.GetSites() sites, _ := s.GetSites(context.Background())
if sites[0].Interval != 60 { if sites[0].Interval != 60 {
t.Fatalf("expected interval 60, got %d", sites[0].Interval) t.Fatalf("expected interval 60, got %d", sites[0].Interval)
} }
@@ -112,8 +114,8 @@ func TestApplyUpdate(t *testing.T) {
func TestApplyPrune(t *testing.T) { func TestApplyPrune(t *testing.T) {
s := newTestStore(t) s := newTestStore(t)
s.AddSite(models.Site{Name: "Keep", URL: "https://keep.com", Type: "http", Interval: 30, ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"}) s.AddSite(context.Background(), models.Site{Name: "Keep", URL: "https://keep.com", Type: "http", Interval: 30, ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"})
s.AddSite(models.Site{Name: "Remove", URL: "https://remove.com", Type: "http", Interval: 30, ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"}) s.AddSite(context.Background(), models.Site{Name: "Remove", URL: "https://remove.com", Type: "http", Interval: 30, ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"})
f := &File{ f := &File{
Monitors: []Monitor{ Monitors: []Monitor{
@@ -121,7 +123,7 @@ func TestApplyPrune(t *testing.T) {
}, },
} }
changes, err := Apply(s, f, ApplyOpts{Prune: true}) changes, err := Apply(context.Background(), s, f, ApplyOpts{Prune: true})
if err != nil { if err != nil {
t.Fatalf("Apply: %v", err) t.Fatalf("Apply: %v", err)
} }
@@ -136,7 +138,7 @@ func TestApplyPrune(t *testing.T) {
t.Fatalf("expected 1 delete, got %d", deleteCount) t.Fatalf("expected 1 delete, got %d", deleteCount)
} }
sites, _ := s.GetSites() sites, _ := s.GetSites(context.Background())
if len(sites) != 1 || sites[0].Name != "Keep" { if len(sites) != 1 || sites[0].Name != "Keep" {
t.Fatalf("expected only 'Keep', got %+v", sites) t.Fatalf("expected only 'Keep', got %+v", sites)
} }
@@ -150,7 +152,7 @@ func TestApplyDryRun(t *testing.T) {
}, },
} }
changes, err := Apply(s, f, ApplyOpts{DryRun: true}) changes, err := Apply(context.Background(), s, f, ApplyOpts{DryRun: true})
if err != nil { if err != nil {
t.Fatalf("Apply: %v", err) t.Fatalf("Apply: %v", err)
} }
@@ -159,7 +161,7 @@ func TestApplyDryRun(t *testing.T) {
t.Fatalf("expected 1 create in dry-run, got %+v", changes) t.Fatalf("expected 1 create in dry-run, got %+v", changes)
} }
sites, _ := s.GetSites() sites, _ := s.GetSites(context.Background())
if len(sites) != 0 { if len(sites) != 0 {
t.Fatalf("expected 0 sites after dry-run, got %d", len(sites)) t.Fatalf("expected 0 sites after dry-run, got %d", len(sites))
} }
@@ -179,7 +181,7 @@ func TestApplyGroupHierarchy(t *testing.T) {
}, },
} }
changes, err := Apply(s, f, ApplyOpts{}) changes, err := Apply(context.Background(), s, f, ApplyOpts{})
if err != nil { if err != nil {
t.Fatalf("Apply: %v", err) t.Fatalf("Apply: %v", err)
} }
@@ -188,7 +190,7 @@ func TestApplyGroupHierarchy(t *testing.T) {
t.Fatalf("expected 3 creates, got %d", len(changes)) t.Fatalf("expected 3 creates, got %d", len(changes))
} }
sites, _ := s.GetSites() sites, _ := s.GetSites(context.Background())
var group models.Site var group models.Site
for _, s := range sites { for _, s := range sites {
if s.Type == "group" { if s.Type == "group" {
@@ -223,12 +225,12 @@ func TestApplyAlertReference(t *testing.T) {
}, },
} }
if _, err := Apply(s, f, ApplyOpts{}); err != nil { if _, err := Apply(context.Background(), s, f, ApplyOpts{}); err != nil {
t.Fatalf("Apply: %v", err) t.Fatalf("Apply: %v", err)
} }
sites, _ := s.GetSites() sites, _ := s.GetSites(context.Background())
alerts, _ := s.GetAllAlerts() alerts, _ := s.GetAllAlerts(context.Background())
if sites[0].AlertID != alerts[0].ID { if sites[0].AlertID != alerts[0].ID {
t.Fatalf("expected alert_id %d, got %d", alerts[0].ID, sites[0].AlertID) t.Fatalf("expected alert_id %d, got %d", alerts[0].ID, sites[0].AlertID)
@@ -243,7 +245,7 @@ func TestApplyInvalidAlertRef(t *testing.T) {
}, },
} }
_, err := Apply(s, f, ApplyOpts{}) _, err := Apply(context.Background(), s, f, ApplyOpts{})
if err == nil || !strings.Contains(err.Error(), "not found") { if err == nil || !strings.Contains(err.Error(), "not found") {
t.Fatalf("expected alert not found error, got %v", err) t.Fatalf("expected alert not found error, got %v", err)
} }
@@ -258,7 +260,7 @@ func TestApplyDuplicateNames(t *testing.T) {
}, },
} }
_, err := Apply(s, f, ApplyOpts{}) _, err := Apply(context.Background(), s, f, ApplyOpts{})
if err == nil || !strings.Contains(err.Error(), "duplicate") { if err == nil || !strings.Contains(err.Error(), "duplicate") {
t.Fatalf("expected duplicate error, got %v", err) t.Fatalf("expected duplicate error, got %v", err)
} }
@@ -266,7 +268,7 @@ func TestApplyDuplicateNames(t *testing.T) {
func TestApplyExistingAlertReference(t *testing.T) { func TestApplyExistingAlertReference(t *testing.T) {
s := newTestStore(t) s := newTestStore(t)
s.AddAlert("Existing", "webhook", map[string]string{"url": "https://example.com"}) s.AddAlert(context.Background(), "Existing", "webhook", map[string]string{"url": "https://example.com"})
f := &File{ f := &File{
Monitors: []Monitor{ Monitors: []Monitor{
@@ -274,7 +276,7 @@ func TestApplyExistingAlertReference(t *testing.T) {
}, },
} }
changes, err := Apply(s, f, ApplyOpts{}) changes, err := Apply(context.Background(), s, f, ApplyOpts{})
if err != nil { if err != nil {
t.Fatalf("Apply: %v", err) t.Fatalf("Apply: %v", err)
} }
@@ -283,7 +285,7 @@ func TestApplyExistingAlertReference(t *testing.T) {
t.Fatalf("expected 1 create, got %+v", changes) t.Fatalf("expected 1 create, got %+v", changes)
} }
sites, _ := s.GetSites() sites, _ := s.GetSites(context.Background())
if sites[0].AlertID == 0 { if sites[0].AlertID == 0 {
t.Fatal("expected non-zero alert_id for existing alert reference") t.Fatal("expected non-zero alert_id for existing alert reference")
} }
+4 -3
View File
@@ -1,6 +1,7 @@
package config package config
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"sort" "sort"
@@ -11,13 +12,13 @@ import (
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
func Export(s store.Store) (*File, error) { func Export(ctx context.Context, s store.Store) (*File, error) {
dbAlerts, err := s.GetAllAlerts() dbAlerts, err := s.GetAllAlerts(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("load alerts: %w", err) return nil, fmt.Errorf("load alerts: %w", err)
} }
dbSites, err := s.GetSites() dbSites, err := s.GetSites(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("load sites: %w", err) return nil, fmt.Errorf("load sites: %w", err)
} }
+21 -19
View File
@@ -1,13 +1,15 @@
package config package config
import ( import (
"gitea.lerkolabs.com/lerkolabs/uptop/internal/models" "context"
"testing" "testing"
"gitea.lerkolabs.com/lerkolabs/uptop/internal/models"
) )
func TestExportEmpty(t *testing.T) { func TestExportEmpty(t *testing.T) {
s := newTestStore(t) s := newTestStore(t)
f, err := Export(s) f, err := Export(context.Background(), s)
if err != nil { if err != nil {
t.Fatalf("Export: %v", err) t.Fatalf("Export: %v", err)
} }
@@ -18,11 +20,11 @@ func TestExportEmpty(t *testing.T) {
func TestExportAlertNames(t *testing.T) { func TestExportAlertNames(t *testing.T) {
s := newTestStore(t) s := newTestStore(t)
s.AddAlert("Discord", "discord", map[string]string{"url": "https://example.com"}) s.AddAlert(context.Background(), "Discord", "discord", map[string]string{"url": "https://example.com"})
alerts, _ := s.GetAllAlerts() alerts, _ := s.GetAllAlerts(context.Background())
s.AddSite(models.Site{Name: "Web", URL: "https://example.com", Type: "http", Interval: 30, AlertID: alerts[0].ID, ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"}) s.AddSite(context.Background(), models.Site{Name: "Web", URL: "https://example.com", Type: "http", Interval: 30, AlertID: alerts[0].ID, ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"})
f, err := Export(s) f, err := Export(context.Background(), s)
if err != nil { if err != nil {
t.Fatalf("Export: %v", err) t.Fatalf("Export: %v", err)
} }
@@ -37,11 +39,11 @@ func TestExportAlertNames(t *testing.T) {
func TestExportGroupHierarchy(t *testing.T) { func TestExportGroupHierarchy(t *testing.T) {
s := newTestStore(t) s := newTestStore(t)
groupID, _ := s.AddSiteReturningID(models.Site{Name: "Prod", Type: "group", ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"}) groupID, _ := s.AddSiteReturningID(context.Background(), models.Site{Name: "Prod", Type: "group", ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"})
s.AddSite(models.Site{Name: "Prod Web", URL: "https://prod.example.com", Type: "http", Interval: 15, ParentID: groupID, ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"}) s.AddSite(context.Background(), models.Site{Name: "Prod Web", URL: "https://prod.example.com", Type: "http", Interval: 15, ParentID: groupID, ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"})
s.AddSite(models.Site{Name: "Top Level", URL: "https://example.com", Type: "http", Interval: 30, ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"}) s.AddSite(context.Background(), models.Site{Name: "Top Level", URL: "https://example.com", Type: "http", Interval: 30, ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"})
f, err := Export(s) f, err := Export(context.Background(), s)
if err != nil { if err != nil {
t.Fatalf("Export: %v", err) t.Fatalf("Export: %v", err)
} }
@@ -70,12 +72,12 @@ func TestExportGroupHierarchy(t *testing.T) {
func TestExportOmitsDefaults(t *testing.T) { func TestExportOmitsDefaults(t *testing.T) {
s := newTestStore(t) s := newTestStore(t)
s.AddSite(models.Site{ s.AddSite(context.Background(), models.Site{
Name: "Web", URL: "https://example.com", Type: "http", Interval: 30, Name: "Web", URL: "https://example.com", Type: "http", Interval: 30,
Method: "GET", AcceptedCodes: "200-299", ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299", ExpiryThreshold: 7,
}) })
f, err := Export(s) f, err := Export(context.Background(), s)
if err != nil { if err != nil {
t.Fatalf("Export: %v", err) t.Fatalf("Export: %v", err)
} }
@@ -94,18 +96,18 @@ func TestExportOmitsDefaults(t *testing.T) {
func TestExportRoundTrip(t *testing.T) { func TestExportRoundTrip(t *testing.T) {
s1 := newTestStore(t) s1 := newTestStore(t)
s1.AddAlert("Discord", "discord", map[string]string{"url": "https://example.com"}) s1.AddAlert(context.Background(), "Discord", "discord", map[string]string{"url": "https://example.com"})
alerts, _ := s1.GetAllAlerts() alerts, _ := s1.GetAllAlerts(context.Background())
s1.AddSite(models.Site{Name: "Web", URL: "https://example.com", Type: "http", Interval: 30, AlertID: alerts[0].ID, ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"}) s1.AddSite(context.Background(), models.Site{Name: "Web", URL: "https://example.com", Type: "http", Interval: 30, AlertID: alerts[0].ID, ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"})
s1.AddSite(models.Site{Name: "Ping", Type: "ping", Hostname: "10.0.0.1", Interval: 60, ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"}) s1.AddSite(context.Background(), models.Site{Name: "Ping", Type: "ping", Hostname: "10.0.0.1", Interval: 60, ExpiryThreshold: 7, Method: "GET", AcceptedCodes: "200-299"})
exported, err := Export(s1) exported, err := Export(context.Background(), s1)
if err != nil { if err != nil {
t.Fatalf("Export: %v", err) t.Fatalf("Export: %v", err)
} }
s2 := newTestStore(t) s2 := newTestStore(t)
changes, err := Apply(s2, exported, ApplyOpts{}) changes, err := Apply(context.Background(), s2, exported, ApplyOpts{})
if err != nil { if err != nil {
t.Fatalf("Apply: %v", err) t.Fatalf("Apply: %v", err)
} }
@@ -120,7 +122,7 @@ func TestExportRoundTrip(t *testing.T) {
t.Fatalf("expected 3 creates, got %d", creates) t.Fatalf("expected 3 creates, got %d", creates)
} }
reexported, err := Export(s2) reexported, err := Export(context.Background(), s2)
if err != nil { if err != nil {
t.Fatalf("re-Export: %v", err) t.Fatalf("re-Export: %v", err)
} }
+72 -50
View File
@@ -16,66 +16,88 @@ type mockStore struct {
sites []models.Site sites []models.Site
} }
func (m *mockStore) Init() error { return nil } func (m *mockStore) Init(_ context.Context) error { return nil }
func (m *mockStore) GetSites() ([]models.Site, error) { return m.sites, nil } func (m *mockStore) GetSites(_ context.Context) ([]models.Site, error) { return m.sites, nil }
func (m *mockStore) AddSite(models.Site) error { return nil } func (m *mockStore) AddSite(_ context.Context, _ models.Site) error { return nil }
func (m *mockStore) UpdateSite(models.Site) error { return nil } func (m *mockStore) UpdateSite(_ context.Context, _ models.Site) error { return nil }
func (m *mockStore) UpdateSitePaused(int, bool) error { return nil } func (m *mockStore) UpdateSitePaused(_ context.Context, _ int, _ bool) error { return nil }
func (m *mockStore) DeleteSite(int) error { return nil } func (m *mockStore) DeleteSite(_ context.Context, _ int) error { return nil }
func (m *mockStore) GetAllAlerts() ([]models.AlertConfig, error) { return nil, nil } func (m *mockStore) GetAllAlerts(_ context.Context) ([]models.AlertConfig, error) { return nil, nil }
func (m *mockStore) GetAlert(int) (models.AlertConfig, error) { return models.AlertConfig{}, nil } func (m *mockStore) GetAlert(_ context.Context, _ int) (models.AlertConfig, error) {
func (m *mockStore) AddAlert(string, string, map[string]string) error { return nil }
func (m *mockStore) UpdateAlert(int, string, string, map[string]string) error { return nil }
func (m *mockStore) DeleteAlert(int) error { return nil }
func (m *mockStore) GetAllUsers() ([]models.User, error) { return nil, nil }
func (m *mockStore) AddUser(string, string, string) error { return nil }
func (m *mockStore) UpdateUser(int, string, string, string) error { return nil }
func (m *mockStore) DeleteUser(int) error { return nil }
func (m *mockStore) SaveCheck(int, int64, bool) error { return nil }
func (m *mockStore) LoadAllHistory(int) (map[int][]models.CheckRecord, error) {
return nil, nil
}
func (m *mockStore) ExportData() (models.Backup, error) { return models.Backup{}, nil }
func (m *mockStore) ImportData(models.Backup) error { return nil }
func (m *mockStore) GetSiteByName(string) (models.Site, error) { return models.Site{}, nil }
func (m *mockStore) GetAlertByName(string) (models.AlertConfig, error) {
return models.AlertConfig{}, nil return models.AlertConfig{}, nil
} }
func (m *mockStore) AddSiteReturningID(models.Site) (int, error) { return 0, nil } func (m *mockStore) AddAlert(_ context.Context, _ string, _ string, _ map[string]string) error {
func (m *mockStore) AddAlertReturningID(string, string, map[string]string) (int, error) { return nil
}
func (m *mockStore) UpdateAlert(_ context.Context, _ int, _ string, _ string, _ map[string]string) error {
return nil
}
func (m *mockStore) DeleteAlert(_ context.Context, _ int) error { return nil }
func (m *mockStore) GetAllUsers(_ context.Context) ([]models.User, error) { return nil, nil }
func (m *mockStore) AddUser(_ context.Context, _ string, _ string, _ string) error { return nil }
func (m *mockStore) UpdateUser(_ context.Context, _ int, _ string, _ string, _ string) error {
return nil
}
func (m *mockStore) DeleteUser(_ context.Context, _ int) error { return nil }
func (m *mockStore) SaveCheck(_ context.Context, _ int, _ int64, _ bool) error { return nil }
func (m *mockStore) LoadAllHistory(_ context.Context, _ int) (map[int][]models.CheckRecord, error) {
return nil, nil
}
func (m *mockStore) ExportData(_ context.Context) (models.Backup, error) { return models.Backup{}, nil }
func (m *mockStore) ImportData(_ context.Context, _ models.Backup) error { return nil }
func (m *mockStore) GetSiteByName(_ context.Context, _ string) (models.Site, error) {
return models.Site{}, nil
}
func (m *mockStore) GetAlertByName(_ context.Context, _ string) (models.AlertConfig, error) {
return models.AlertConfig{}, nil
}
func (m *mockStore) AddSiteReturningID(_ context.Context, _ models.Site) (int, error) { return 0, nil }
func (m *mockStore) AddAlertReturningID(_ context.Context, _ string, _ string, _ map[string]string) (int, error) {
return 0, nil return 0, nil
} }
func (m *mockStore) SaveCheckFromNode(int, string, int64, bool) error { return nil } func (m *mockStore) SaveCheckFromNode(_ context.Context, _ int, _ string, _ int64, _ bool) error {
func (m *mockStore) RegisterNode(models.ProbeNode) error { return nil } return nil
func (m *mockStore) GetNode(string) (models.ProbeNode, error) { return models.ProbeNode{}, nil } }
func (m *mockStore) GetAllNodes() ([]models.ProbeNode, error) { return nil, nil } func (m *mockStore) RegisterNode(_ context.Context, _ models.ProbeNode) error { return nil }
func (m *mockStore) UpdateNodeLastSeen(string) error { return nil } func (m *mockStore) GetNode(_ context.Context, _ string) (models.ProbeNode, error) {
func (m *mockStore) DeleteNode(string) error { return nil } return models.ProbeNode{}, nil
func (m *mockStore) LoadAlertHealth() (map[int]models.AlertHealthRecord, error) { }
func (m *mockStore) GetAllNodes(_ context.Context) ([]models.ProbeNode, error) { return nil, nil }
func (m *mockStore) UpdateNodeLastSeen(_ context.Context, _ string) error { return nil }
func (m *mockStore) DeleteNode(_ context.Context, _ string) error { return nil }
func (m *mockStore) LoadAlertHealth(_ context.Context) (map[int]models.AlertHealthRecord, error) {
return nil, nil return nil, nil
} }
func (m *mockStore) SaveAlertHealth(models.AlertHealthRecord) error { return nil } func (m *mockStore) SaveAlertHealth(_ context.Context, _ models.AlertHealthRecord) error { return nil }
func (m *mockStore) SaveLog(string) error { return nil } func (m *mockStore) SaveLog(_ context.Context, _ string) error { return nil }
func (m *mockStore) PruneLogs() error { return nil } func (m *mockStore) PruneLogs(_ context.Context) error { return nil }
func (m *mockStore) PruneCheckHistory() error { return nil } func (m *mockStore) PruneCheckHistory(_ context.Context) error { return nil }
func (m *mockStore) PruneStateChanges() error { return nil } func (m *mockStore) PruneStateChanges(_ context.Context) error { return nil }
func (m *mockStore) LoadLogs(int) ([]string, error) { return nil, nil } func (m *mockStore) LoadLogs(_ context.Context, _ int) ([]string, error) { return nil, nil }
func (m *mockStore) GetActiveMaintenanceWindows() ([]models.MaintenanceWindow, error) { func (m *mockStore) GetActiveMaintenanceWindows(_ context.Context) ([]models.MaintenanceWindow, error) {
return nil, nil return nil, nil
} }
func (m *mockStore) GetAllMaintenanceWindows(int) ([]models.MaintenanceWindow, error) { func (m *mockStore) GetAllMaintenanceWindows(_ context.Context, _ int) ([]models.MaintenanceWindow, error) {
return nil, nil return nil, nil
} }
func (m *mockStore) AddMaintenanceWindow(models.MaintenanceWindow) error { return nil } func (m *mockStore) AddMaintenanceWindow(_ context.Context, _ models.MaintenanceWindow) error {
func (m *mockStore) EndMaintenanceWindow(int) error { return nil } return nil
func (m *mockStore) DeleteMaintenanceWindow(int) error { return nil } }
func (m *mockStore) PruneExpiredMaintenanceWindows(time.Duration) (int64, error) { return 0, nil } func (m *mockStore) EndMaintenanceWindow(_ context.Context, _ int) error { return nil }
func (m *mockStore) IsMonitorInMaintenance(int) (bool, error) { return false, nil } func (m *mockStore) DeleteMaintenanceWindow(_ context.Context, _ int) error { return nil }
func (m *mockStore) GetPreference(string) (string, error) { return "", nil } func (m *mockStore) PruneExpiredMaintenanceWindows(_ context.Context, _ time.Duration) (int64, error) {
func (m *mockStore) SetPreference(string, string) error { return nil } return 0, nil
func (m *mockStore) SaveStateChange(int, string, string, string) error { return nil } }
func (m *mockStore) GetStateChanges(int, int) ([]models.StateChange, error) { return nil, nil } func (m *mockStore) IsMonitorInMaintenance(_ context.Context, _ int) (bool, error) { return false, nil }
func (m *mockStore) GetStateChangesSince(int, time.Time) ([]models.StateChange, error) { func (m *mockStore) GetPreference(_ context.Context, _ string) (string, error) { return "", nil }
func (m *mockStore) SetPreference(_ context.Context, _ string, _ string) error { return nil }
func (m *mockStore) SaveStateChange(_ context.Context, _ int, _ string, _ string, _ string) error {
return nil
}
func (m *mockStore) GetStateChanges(_ context.Context, _ int, _ int) ([]models.StateChange, error) {
return nil, nil
}
func (m *mockStore) GetStateChangesSince(_ context.Context, _ int, _ time.Time) ([]models.StateChange, error) {
return nil, nil return nil, nil
} }
func (m *mockStore) Close() error { return nil } func (m *mockStore) Close() error { return nil }
+10 -10
View File
@@ -35,7 +35,7 @@ type CheckResult struct {
ErrorReason string ErrorReason string
} }
func RunCheck(site models.Site, strict, insecure *http.Client, globalInsecure bool, allowPrivate ...bool) CheckResult { func RunCheck(ctx context.Context, site models.Site, strict, insecure *http.Client, globalInsecure bool, allowPrivate ...bool) CheckResult {
private := len(allowPrivate) > 0 && allowPrivate[0] private := len(allowPrivate) > 0 && allowPrivate[0]
if site.Type != "http" && site.Type != "dns" && !private { if site.Type != "http" && site.Type != "dns" && !private {
@@ -56,26 +56,26 @@ func RunCheck(site models.Site, strict, insecure *http.Client, globalInsecure bo
switch site.Type { switch site.Type {
case "http": case "http":
return runHTTPCheck(site, strict, insecure, globalInsecure) return runHTTPCheck(ctx, site, strict, insecure, globalInsecure)
case "ping": case "ping":
return runPingCheck(site) return runPingCheck(ctx, site)
case "port": case "port":
return runPortCheck(site) return runPortCheck(ctx, site)
case "dns": case "dns":
return runDNSCheck(site) return runDNSCheck(ctx, site)
default: default:
return CheckResult{SiteID: site.ID, Status: "DOWN", ErrorReason: "unsupported monitor type: " + site.Type} return CheckResult{SiteID: site.ID, Status: "DOWN", ErrorReason: "unsupported monitor type: " + site.Type}
} }
} }
func runHTTPCheck(site models.Site, strict, insecure *http.Client, globalInsecure bool) CheckResult { func runHTTPCheck(ctx context.Context, site models.Site, strict, insecure *http.Client, globalInsecure bool) CheckResult {
method := site.Method method := site.Method
if method == "" { if method == "" {
method = "GET" method = "GET"
} }
timeout := siteTimeout(site) timeout := siteTimeout(site)
ctx, cancel := context.WithTimeout(context.Background(), timeout) ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel() defer cancel()
req, err := http.NewRequestWithContext(ctx, method, site.URL, nil) req, err := http.NewRequestWithContext(ctx, method, site.URL, nil)
@@ -128,7 +128,7 @@ func runHTTPCheck(site models.Site, strict, insecure *http.Client, globalInsecur
return result return result
} }
func runPingCheck(site models.Site) CheckResult { func runPingCheck(_ context.Context, site models.Site) CheckResult {
host := site.Hostname host := site.Hostname
if host == "" { if host == "" {
host = site.URL host = site.URL
@@ -157,7 +157,7 @@ func runPingCheck(site models.Site) CheckResult {
return CheckResult{SiteID: site.ID, Status: "UP", LatencyNs: stats.AvgRtt.Nanoseconds()} return CheckResult{SiteID: site.ID, Status: "UP", LatencyNs: stats.AvgRtt.Nanoseconds()}
} }
func runPortCheck(site models.Site) CheckResult { func runPortCheck(_ context.Context, site models.Site) CheckResult {
host := site.Hostname host := site.Hostname
if host == "" { if host == "" {
host = site.URL host = site.URL
@@ -176,7 +176,7 @@ func runPortCheck(site models.Site) CheckResult {
return CheckResult{SiteID: site.ID, Status: "UP", LatencyNs: latency.Nanoseconds()} return CheckResult{SiteID: site.ID, Status: "UP", LatencyNs: latency.Nanoseconds()}
} }
func runDNSCheck(site models.Site) CheckResult { func runDNSCheck(_ context.Context, site models.Site) CheckResult {
host := site.Hostname host := site.Hostname
if host == "" { if host == "" {
host = site.URL host = site.URL
+11 -10
View File
@@ -1,6 +1,7 @@
package monitor package monitor
import ( import (
"context"
"crypto/tls" "crypto/tls"
"net" "net"
"net/http" "net/http"
@@ -19,7 +20,7 @@ func TestRunCheck_HTTP_Success(t *testing.T) {
defer srv.Close() defer srv.Close()
site := models.Site{ID: 1, Type: "http", URL: srv.URL} site := models.Site{ID: 1, Type: "http", URL: srv.URL}
result := RunCheck(site, http.DefaultClient, http.DefaultClient, false) result := RunCheck(context.Background(), site, http.DefaultClient, http.DefaultClient, false)
if result.Status != "UP" { if result.Status != "UP" {
t.Errorf("expected UP, got %s", result.Status) t.Errorf("expected UP, got %s", result.Status)
@@ -39,7 +40,7 @@ func TestRunCheck_HTTP_ServerError(t *testing.T) {
defer srv.Close() defer srv.Close()
site := models.Site{ID: 1, Type: "http", URL: srv.URL} site := models.Site{ID: 1, Type: "http", URL: srv.URL}
result := RunCheck(site, http.DefaultClient, http.DefaultClient, false) result := RunCheck(context.Background(), site, http.DefaultClient, http.DefaultClient, false)
if result.Status != "DOWN" { if result.Status != "DOWN" {
t.Errorf("expected DOWN, got %s", result.Status) t.Errorf("expected DOWN, got %s", result.Status)
@@ -60,7 +61,7 @@ func TestRunCheck_HTTP_CustomAcceptedCodes(t *testing.T) {
}} }}
site := models.Site{ID: 1, Type: "http", URL: srv.URL, AcceptedCodes: "200-399"} site := models.Site{ID: 1, Type: "http", URL: srv.URL, AcceptedCodes: "200-399"}
result := RunCheck(site, client, client, false) result := RunCheck(context.Background(), site, client, client, false)
if result.Status != "UP" { if result.Status != "UP" {
t.Errorf("expected UP with accepted 200-399, got %s", result.Status) t.Errorf("expected UP with accepted 200-399, got %s", result.Status)
@@ -76,7 +77,7 @@ func TestRunCheck_HTTP_MethodRespected(t *testing.T) {
defer srv.Close() defer srv.Close()
site := models.Site{ID: 1, Type: "http", URL: srv.URL, Method: "HEAD"} site := models.Site{ID: 1, Type: "http", URL: srv.URL, Method: "HEAD"}
RunCheck(site, http.DefaultClient, http.DefaultClient, false) RunCheck(context.Background(), site, http.DefaultClient, http.DefaultClient, false)
if receivedMethod != "HEAD" { if receivedMethod != "HEAD" {
t.Errorf("expected HEAD, got %s", receivedMethod) t.Errorf("expected HEAD, got %s", receivedMethod)
@@ -91,7 +92,7 @@ func TestRunCheck_HTTP_Timeout(t *testing.T) {
defer srv.Close() defer srv.Close()
site := models.Site{ID: 1, Type: "http", URL: srv.URL, Timeout: 1} site := models.Site{ID: 1, Type: "http", URL: srv.URL, Timeout: 1}
result := RunCheck(site, http.DefaultClient, http.DefaultClient, false) result := RunCheck(context.Background(), site, http.DefaultClient, http.DefaultClient, false)
if result.Status != "DOWN" { if result.Status != "DOWN" {
t.Errorf("expected DOWN on timeout, got %s", result.Status) t.Errorf("expected DOWN on timeout, got %s", result.Status)
@@ -109,7 +110,7 @@ func TestRunCheck_HTTP_SSLFields(t *testing.T) {
} }
site := models.Site{ID: 1, Type: "http", URL: srv.URL, CheckSSL: true, IgnoreTLS: true} site := models.Site{ID: 1, Type: "http", URL: srv.URL, CheckSSL: true, IgnoreTLS: true}
result := RunCheck(site, http.DefaultClient, insecureClient, false) result := RunCheck(context.Background(), site, http.DefaultClient, insecureClient, false)
if result.Status != "UP" { if result.Status != "UP" {
t.Errorf("expected UP, got %s", result.Status) t.Errorf("expected UP, got %s", result.Status)
@@ -133,7 +134,7 @@ func TestRunCheck_Port_Open(t *testing.T) {
port, _ := strconv.Atoi(portStr) port, _ := strconv.Atoi(portStr)
site := models.Site{ID: 1, Type: "port", Hostname: "127.0.0.1", Port: port, Timeout: 2} site := models.Site{ID: 1, Type: "port", Hostname: "127.0.0.1", Port: port, Timeout: 2}
result := RunCheck(site, nil, nil, false, true) result := RunCheck(context.Background(), site, nil, nil, false, true)
if result.Status != "UP" { if result.Status != "UP" {
t.Errorf("expected UP, got %s", result.Status) t.Errorf("expected UP, got %s", result.Status)
@@ -153,7 +154,7 @@ func TestRunCheck_Port_Closed(t *testing.T) {
ln.Close() ln.Close()
site := models.Site{ID: 1, Type: "port", Hostname: "127.0.0.1", Port: port, Timeout: 1} site := models.Site{ID: 1, Type: "port", Hostname: "127.0.0.1", Port: port, Timeout: 1}
result := RunCheck(site, nil, nil, false, true) result := RunCheck(context.Background(), site, nil, nil, false, true)
if result.Status != "DOWN" { if result.Status != "DOWN" {
t.Errorf("expected DOWN, got %s", result.Status) t.Errorf("expected DOWN, got %s", result.Status)
@@ -171,7 +172,7 @@ func TestRunCheck_Port_BlocksPrivateByDefault(t *testing.T) {
port, _ := strconv.Atoi(portStr) port, _ := strconv.Atoi(portStr)
site := models.Site{ID: 1, Type: "port", Hostname: "127.0.0.1", Port: port, Timeout: 2} site := models.Site{ID: 1, Type: "port", Hostname: "127.0.0.1", Port: port, Timeout: 2}
result := RunCheck(site, nil, nil, false) result := RunCheck(context.Background(), site, nil, nil, false)
if result.Status != "DOWN" { if result.Status != "DOWN" {
t.Errorf("expected DOWN when private targets blocked, got %s", result.Status) t.Errorf("expected DOWN when private targets blocked, got %s", result.Status)
@@ -180,7 +181,7 @@ func TestRunCheck_Port_BlocksPrivateByDefault(t *testing.T) {
func TestRunCheck_UnknownType(t *testing.T) { func TestRunCheck_UnknownType(t *testing.T) {
site := models.Site{ID: 1, Type: "invalid"} site := models.Site{ID: 1, Type: "invalid"}
result := RunCheck(site, nil, nil, false) result := RunCheck(context.Background(), site, nil, nil, false)
if result.Status != "DOWN" { if result.Status != "DOWN" {
t.Errorf("expected DOWN for unknown type, got %s", result.Status) t.Errorf("expected DOWN for unknown type, got %s", result.Status)
+17 -11
View File
@@ -1,6 +1,8 @@
package monitor package monitor
import ( import (
"context"
"gitea.lerkolabs.com/lerkolabs/uptop/internal/models" "gitea.lerkolabs.com/lerkolabs/uptop/internal/models"
"gitea.lerkolabs.com/lerkolabs/uptop/internal/store" "gitea.lerkolabs.com/lerkolabs/uptop/internal/store"
) )
@@ -10,14 +12,14 @@ import (
// serializing all writes through one connection and surfacing errors instead of // serializing all writes through one connection and surfacing errors instead of
// discarding them. desc names the write for diagnostics on drop/failure. // discarding them. desc names the write for diagnostics on drop/failure.
type dbWrite interface { type dbWrite interface {
exec(s store.Store) error exec(ctx context.Context, s store.Store) error
desc() string desc() string
} }
type writeLog struct{ message string } type writeLog struct{ message string }
func (w writeLog) exec(s store.Store) error { return s.SaveLog(w.message) } func (w writeLog) exec(ctx context.Context, s store.Store) error { return s.SaveLog(ctx, w.message) }
func (w writeLog) desc() string { return "log" } func (w writeLog) desc() string { return "log" }
type writeCheck struct { type writeCheck struct {
siteID int siteID int
@@ -25,8 +27,10 @@ type writeCheck struct {
isUp bool isUp bool
} }
func (w writeCheck) exec(s store.Store) error { return s.SaveCheck(w.siteID, w.latencyNs, w.isUp) } func (w writeCheck) exec(ctx context.Context, s store.Store) error {
func (w writeCheck) desc() string { return "check" } return s.SaveCheck(ctx, w.siteID, w.latencyNs, w.isUp)
}
func (w writeCheck) desc() string { return "check" }
type writeStateChange struct { type writeStateChange struct {
siteID int siteID int
@@ -35,15 +39,17 @@ type writeStateChange struct {
reason string reason string
} }
func (w writeStateChange) exec(s store.Store) error { func (w writeStateChange) exec(ctx context.Context, s store.Store) error {
return s.SaveStateChange(w.siteID, w.fromStatus, w.toStatus, w.reason) return s.SaveStateChange(ctx, w.siteID, w.fromStatus, w.toStatus, w.reason)
} }
func (w writeStateChange) desc() string { return "state-change" } func (w writeStateChange) desc() string { return "state-change" }
type writeAlertHealth struct{ rec models.AlertHealthRecord } type writeAlertHealth struct{ rec models.AlertHealthRecord }
func (w writeAlertHealth) exec(s store.Store) error { return s.SaveAlertHealth(w.rec) } func (w writeAlertHealth) exec(ctx context.Context, s store.Store) error {
func (w writeAlertHealth) desc() string { return "alert-health" } return s.SaveAlertHealth(ctx, w.rec)
}
func (w writeAlertHealth) desc() string { return "alert-health" }
type writeProbeCheck struct { type writeProbeCheck struct {
siteID int siteID int
@@ -52,7 +58,7 @@ type writeProbeCheck struct {
isUp bool isUp bool
} }
func (w writeProbeCheck) exec(s store.Store) error { func (w writeProbeCheck) exec(ctx context.Context, s store.Store) error {
return s.SaveCheckFromNode(w.siteID, w.nodeID, w.latencyNs, w.isUp) return s.SaveCheckFromNode(ctx, w.siteID, w.nodeID, w.latencyNs, w.isUp)
} }
func (w writeProbeCheck) desc() string { return "probe-check" } func (w writeProbeCheck) desc() string { return "probe-check" }
+5 -2
View File
@@ -1,6 +1,9 @@
package monitor package monitor
import "time" import (
"context"
"time"
)
const maxHistoryLen = 60 const maxHistoryLen = 60
@@ -12,7 +15,7 @@ type SiteHistory struct {
} }
func (e *Engine) InitHistory() { func (e *Engine) InitHistory() {
all, err := e.db.LoadAllHistory(maxHistoryLen) all, err := e.db.LoadAllHistory(context.Background(), maxHistoryLen)
if err != nil { if err != nil {
e.AddLog("Failed to load check history: " + err.Error()) e.AddLog("Failed to load check history: " + err.Error())
return return
+30 -30
View File
@@ -185,16 +185,16 @@ func (e *Engine) dbWriter(ctx context.Context) {
pruneTicker := time.NewTicker(dbPruneInterval) pruneTicker := time.NewTicker(dbPruneInterval)
defer pruneTicker.Stop() defer pruneTicker.Stop()
e.prune() e.prune(ctx)
for { for {
select { select {
case w := <-e.dbWrites: case w := <-e.dbWrites:
if err := w.exec(e.db); err != nil { if err := w.exec(ctx, e.db); err != nil {
e.appendLog(fmt.Sprintf("db %s write failed: %v", w.desc(), err)) e.appendLog(fmt.Sprintf("db %s write failed: %v", w.desc(), err))
} }
case <-pruneTicker.C: case <-pruneTicker.C:
e.prune() e.prune(ctx)
case <-ctx.Done(): case <-ctx.Done():
e.drainWrites() e.drainWrites()
return return
@@ -207,7 +207,7 @@ func (e *Engine) drainWrites() {
for { for {
select { select {
case w := <-e.dbWrites: case w := <-e.dbWrites:
if err := w.exec(e.db); err != nil { if err := w.exec(context.Background(), e.db); err != nil {
e.appendLog(fmt.Sprintf("db %s write failed (drain): %v", w.desc(), err)) e.appendLog(fmt.Sprintf("db %s write failed (drain): %v", w.desc(), err))
} }
default: default:
@@ -216,14 +216,14 @@ func (e *Engine) drainWrites() {
} }
} }
func (e *Engine) prune() { func (e *Engine) prune(ctx context.Context) {
if err := e.db.PruneLogs(); err != nil { if err := e.db.PruneLogs(ctx); err != nil {
e.appendLog(fmt.Sprintf("log prune failed: %v", err)) e.appendLog(fmt.Sprintf("log prune failed: %v", err))
} }
if err := e.db.PruneCheckHistory(); err != nil { if err := e.db.PruneCheckHistory(ctx); err != nil {
e.appendLog(fmt.Sprintf("check-history prune failed: %v", err)) e.appendLog(fmt.Sprintf("check-history prune failed: %v", err))
} }
if err := e.db.PruneStateChanges(); err != nil { if err := e.db.PruneStateChanges(ctx); err != nil {
e.appendLog(fmt.Sprintf("state-change prune failed: %v", err)) e.appendLog(fmt.Sprintf("state-change prune failed: %v", err))
} }
} }
@@ -242,7 +242,7 @@ func (e *Engine) Stop() {
} }
func (e *Engine) InitLogs() { func (e *Engine) InitLogs() {
logs, err := e.db.LoadLogs(maxLogEntries) logs, err := e.db.LoadLogs(context.Background(), maxLogEntries)
if err != nil { if err != nil {
return return
} }
@@ -257,7 +257,7 @@ func (e *Engine) InitLogs() {
// InitAlertHealth restores persisted alert send health so the dashboard shows real // InitAlertHealth restores persisted alert send health so the dashboard shows real
// "last sent" / health state on startup instead of resetting every channel to "never". // "last sent" / health state on startup instead of resetting every channel to "never".
func (e *Engine) InitAlertHealth() { func (e *Engine) InitAlertHealth() {
records, err := e.db.LoadAlertHealth() records, err := e.db.LoadAlertHealth(context.Background())
if err != nil { if err != nil {
return return
} }
@@ -416,9 +416,9 @@ func (e *Engine) Start(ctx context.Context) {
default: default:
} }
e.refreshMaintenanceCache() e.refreshMaintenanceCache(ctx)
sites, err := e.db.GetSites() sites, err := e.db.GetSites(ctx)
if err != nil { if err != nil {
e.AddLog(fmt.Sprintf("Failed to load sites: %v", err)) e.AddLog(fmt.Sprintf("Failed to load sites: %v", err))
select { select {
@@ -475,20 +475,20 @@ func (e *Engine) maintenancePruner(ctx context.Context) {
ticker := time.NewTicker(maintPruneInterval) ticker := time.NewTicker(maintPruneInterval)
defer ticker.Stop() defer ticker.Stop()
e.pruneMaintenanceWindows() e.pruneMaintenanceWindows(ctx)
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
e.pruneMaintenanceWindows() e.pruneMaintenanceWindows(ctx)
case <-ctx.Done(): case <-ctx.Done():
return return
} }
} }
} }
func (e *Engine) pruneMaintenanceWindows() { func (e *Engine) pruneMaintenanceWindows(ctx context.Context) {
pruned, err := e.db.PruneExpiredMaintenanceWindows(e.maintRetention) pruned, err := e.db.PruneExpiredMaintenanceWindows(ctx, e.maintRetention)
if err != nil { if err != nil {
e.AddLog(fmt.Sprintf("Maintenance prune error: %v", err)) e.AddLog(fmt.Sprintf("Maintenance prune error: %v", err))
return return
@@ -588,7 +588,7 @@ func (e *Engine) monitorRoutine(ctx context.Context, id int) {
return return
} }
e.checkByID(id) e.checkByID(ctx, id)
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@@ -634,7 +634,7 @@ func (e *Engine) monitorRoutine(ctx context.Context, id int) {
return return
case <-recheckCh: case <-recheckCh:
} }
e.checkByID(id) e.checkByID(ctx, id)
} }
} }
@@ -657,7 +657,7 @@ func (e *Engine) applyState(id int, mutate func(s *models.Site)) (models.Site, b
return cur, true return cur, true
} }
func (e *Engine) checkByID(id int) { func (e *Engine) checkByID(ctx context.Context, id int) {
if !e.IsActive() { if !e.IsActive() {
return return
} }
@@ -671,11 +671,11 @@ func (e *Engine) checkByID(id int) {
switch site.Type { switch site.Type {
case "push": case "push":
e.checkPush(site) e.checkPush(ctx, site)
case "group": case "group":
e.checkGroup(site) e.checkGroup(ctx, site)
default: default:
result := RunCheck(site, e.strictClient, e.insecureClient, e.insecureSkipVerify, e.allowPrivateTargets) result := RunCheck(ctx, site, e.strictClient, e.insecureClient, e.insecureSkipVerify, e.allowPrivateTargets)
updatedSite := site updatedSite := site
updatedSite.HasSSL = result.HasSSL updatedSite.HasSSL = result.HasSSL
updatedSite.CertExpiry = result.CertExpiry updatedSite.CertExpiry = result.CertExpiry
@@ -685,7 +685,7 @@ func (e *Engine) checkByID(id int) {
} }
} }
func (e *Engine) checkPush(site models.Site) { func (e *Engine) checkPush(_ context.Context, site models.Site) {
if site.Status == "PENDING" { if site.Status == "PENDING" {
return return
} }
@@ -875,7 +875,7 @@ func (e *Engine) handleStatusChange(snap models.Site, rawStatus string, code int
} }
func (e *Engine) triggerAlert(alertID int, title, message string) { func (e *Engine) triggerAlert(alertID int, title, message string) {
cfg, err := e.db.GetAlert(alertID) cfg, err := e.db.GetAlert(context.Background(), alertID)
if err != nil { if err != nil {
e.AddLog(fmt.Sprintf("Failed to load alert config %d: %v", alertID, err)) e.AddLog(fmt.Sprintf("Failed to load alert config %d: %v", alertID, err))
return return
@@ -928,7 +928,7 @@ func (e *Engine) GetAlertHealth(alertID int) AlertHealth {
} }
func (e *Engine) TestAlert(alertID int) error { func (e *Engine) TestAlert(alertID int) error {
cfg, err := e.db.GetAlert(alertID) cfg, err := e.db.GetAlert(context.Background(), alertID)
if err != nil { if err != nil {
return fmt.Errorf("failed to load alert: %w", err) return fmt.Errorf("failed to load alert: %w", err)
} }
@@ -954,8 +954,8 @@ func (e *Engine) isInMaintenance(monitorID int) bool {
return e.maintCache[monitorID] return e.maintCache[monitorID]
} }
func (e *Engine) refreshMaintenanceCache() { func (e *Engine) refreshMaintenanceCache(ctx context.Context) {
windows, err := e.db.GetActiveMaintenanceWindows() windows, err := e.db.GetActiveMaintenanceWindows(ctx)
if err != nil { if err != nil {
return return
} }
@@ -994,7 +994,7 @@ func (e *Engine) GetDisplayStatus(site models.Site) string {
return site.Status return site.Status
} }
func (e *Engine) checkGroup(site models.Site) { func (e *Engine) checkGroup(_ context.Context, site models.Site) {
e.mu.RLock() e.mu.RLock()
status := "UP" status := "UP"
hasChildren := false hasChildren := false
@@ -1095,7 +1095,7 @@ func (e *Engine) GetProbeResults(siteID int) map[string]NodeResult {
} }
func (e *Engine) GetStateChanges(siteID int, limit int) []models.StateChange { func (e *Engine) GetStateChanges(siteID int, limit int) []models.StateChange {
changes, err := e.db.GetStateChanges(siteID, limit) changes, err := e.db.GetStateChanges(context.Background(), siteID, limit)
if err != nil { if err != nil {
return nil return nil
} }
@@ -1103,7 +1103,7 @@ func (e *Engine) GetStateChanges(siteID int, limit int) []models.StateChange {
} }
func (e *Engine) GetStateChangesSince(siteID int, since time.Time) []models.StateChange { func (e *Engine) GetStateChangesSince(siteID int, since time.Time) []models.StateChange {
changes, err := e.db.GetStateChangesSince(siteID, since) changes, err := e.db.GetStateChangesSince(context.Background(), siteID, since)
if err != nil { if err != nil {
return nil return nil
} }
+75 -65
View File
@@ -38,37 +38,43 @@ func newMockStore() *mockStore {
} }
} }
func (m *mockStore) Init() error { return nil } func (m *mockStore) Init(context.Context) error { return nil }
func (m *mockStore) GetSites() ([]models.Site, error) { return m.sites, nil } func (m *mockStore) GetSites(context.Context) ([]models.Site, error) { return m.sites, nil }
func (m *mockStore) AddSite(models.Site) error { return nil } func (m *mockStore) AddSite(context.Context, models.Site) error { return nil }
func (m *mockStore) UpdateSite(models.Site) error { return nil } func (m *mockStore) UpdateSite(context.Context, models.Site) error { return nil }
func (m *mockStore) UpdateSitePaused(int, bool) error { return nil } func (m *mockStore) UpdateSitePaused(context.Context, int, bool) error { return nil }
func (m *mockStore) DeleteSite(int) error { return nil } func (m *mockStore) DeleteSite(context.Context, int) error { return nil }
func (m *mockStore) AddAlert(string, string, map[string]string) error { return nil } func (m *mockStore) AddAlert(context.Context, string, string, map[string]string) error { return nil }
func (m *mockStore) UpdateAlert(int, string, string, map[string]string) error { return nil } func (m *mockStore) UpdateAlert(context.Context, int, string, string, map[string]string) error {
func (m *mockStore) DeleteAlert(int) error { return nil } return nil
func (m *mockStore) GetAllUsers() ([]models.User, error) { return nil, nil } }
func (m *mockStore) AddUser(string, string, string) error { return nil } func (m *mockStore) DeleteAlert(context.Context, int) error { return nil }
func (m *mockStore) UpdateUser(int, string, string, string) error { return nil } func (m *mockStore) GetAllUsers(context.Context) ([]models.User, error) { return nil, nil }
func (m *mockStore) DeleteUser(int) error { return nil } func (m *mockStore) AddUser(context.Context, string, string, string) error { return nil }
func (m *mockStore) ExportData() (models.Backup, error) { return models.Backup{}, nil } func (m *mockStore) UpdateUser(context.Context, int, string, string, string) error { return nil }
func (m *mockStore) ImportData(models.Backup) error { return nil } func (m *mockStore) DeleteUser(context.Context, int) error { return nil }
func (m *mockStore) GetSiteByName(string) (models.Site, error) { return models.Site{}, nil } func (m *mockStore) ExportData(context.Context) (models.Backup, error) { return models.Backup{}, nil }
func (m *mockStore) AddSiteReturningID(models.Site) (int, error) { return 0, nil } func (m *mockStore) ImportData(context.Context, models.Backup) error { return nil }
func (m *mockStore) AddAlertReturningID(string, string, map[string]string) (int, error) { func (m *mockStore) GetSiteByName(context.Context, string) (models.Site, error) {
return models.Site{}, nil
}
func (m *mockStore) AddSiteReturningID(context.Context, models.Site) (int, error) { return 0, nil }
func (m *mockStore) AddAlertReturningID(context.Context, string, string, map[string]string) (int, error) {
return 0, nil return 0, nil
} }
func (m *mockStore) SaveCheckFromNode(int, string, int64, bool) error { return nil } func (m *mockStore) SaveCheckFromNode(context.Context, int, string, int64, bool) error { return nil }
func (m *mockStore) RegisterNode(models.ProbeNode) error { return nil } func (m *mockStore) RegisterNode(context.Context, models.ProbeNode) error { return nil }
func (m *mockStore) GetNode(string) (models.ProbeNode, error) { return models.ProbeNode{}, nil } func (m *mockStore) GetNode(context.Context, string) (models.ProbeNode, error) {
func (m *mockStore) GetAllNodes() ([]models.ProbeNode, error) { return nil, nil } return models.ProbeNode{}, nil
func (m *mockStore) UpdateNodeLastSeen(string) error { return nil } }
func (m *mockStore) DeleteNode(string) error { return nil } func (m *mockStore) GetAllNodes(context.Context) ([]models.ProbeNode, error) { return nil, nil }
func (m *mockStore) LoadAlertHealth() (map[int]models.AlertHealthRecord, error) { func (m *mockStore) UpdateNodeLastSeen(context.Context, string) error { return nil }
func (m *mockStore) DeleteNode(context.Context, string) error { return nil }
func (m *mockStore) LoadAlertHealth(context.Context) (map[int]models.AlertHealthRecord, error) {
return nil, nil return nil, nil
} }
func (m *mockStore) SaveAlertHealth(models.AlertHealthRecord) error { return nil } func (m *mockStore) SaveAlertHealth(context.Context, models.AlertHealthRecord) error { return nil }
func (m *mockStore) GetActiveMaintenanceWindows() ([]models.MaintenanceWindow, error) { func (m *mockStore) GetActiveMaintenanceWindows(context.Context) ([]models.MaintenanceWindow, error) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
var windows []models.MaintenanceWindow var windows []models.MaintenanceWindow
@@ -77,23 +83,27 @@ func (m *mockStore) GetActiveMaintenanceWindows() ([]models.MaintenanceWindow, e
} }
return windows, nil return windows, nil
} }
func (m *mockStore) GetAllMaintenanceWindows(int) ([]models.MaintenanceWindow, error) { func (m *mockStore) GetAllMaintenanceWindows(context.Context, int) ([]models.MaintenanceWindow, error) {
return nil, nil return nil, nil
} }
func (m *mockStore) AddMaintenanceWindow(models.MaintenanceWindow) error { return nil } func (m *mockStore) AddMaintenanceWindow(context.Context, models.MaintenanceWindow) error { return nil }
func (m *mockStore) EndMaintenanceWindow(int) error { return nil } func (m *mockStore) EndMaintenanceWindow(context.Context, int) error { return nil }
func (m *mockStore) DeleteMaintenanceWindow(int) error { return nil } func (m *mockStore) DeleteMaintenanceWindow(context.Context, int) error { return nil }
func (m *mockStore) PruneExpiredMaintenanceWindows(time.Duration) (int64, error) { return 0, nil } func (m *mockStore) PruneExpiredMaintenanceWindows(context.Context, time.Duration) (int64, error) {
func (m *mockStore) GetPreference(string) (string, error) { return "", nil } return 0, nil
func (m *mockStore) SetPreference(string, string) error { return nil } }
func (m *mockStore) SaveStateChange(int, string, string, string) error { return nil } func (m *mockStore) GetPreference(context.Context, string) (string, error) { return "", nil }
func (m *mockStore) GetStateChanges(int, int) ([]models.StateChange, error) { return nil, nil } func (m *mockStore) SetPreference(context.Context, string, string) error { return nil }
func (m *mockStore) GetStateChangesSince(int, time.Time) ([]models.StateChange, error) { func (m *mockStore) SaveStateChange(context.Context, int, string, string, string) error { return nil }
func (m *mockStore) GetStateChanges(context.Context, int, int) ([]models.StateChange, error) {
return nil, nil
}
func (m *mockStore) GetStateChangesSince(context.Context, int, time.Time) ([]models.StateChange, error) {
return nil, nil return nil, nil
} }
func (m *mockStore) Close() error { return nil } func (m *mockStore) Close() error { return nil }
func (m *mockStore) GetAllAlerts() ([]models.AlertConfig, error) { func (m *mockStore) GetAllAlerts(context.Context) ([]models.AlertConfig, error) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
var result []models.AlertConfig var result []models.AlertConfig
@@ -103,7 +113,7 @@ func (m *mockStore) GetAllAlerts() ([]models.AlertConfig, error) {
return result, nil return result, nil
} }
func (m *mockStore) GetAlert(id int) (models.AlertConfig, error) { func (m *mockStore) GetAlert(_ context.Context, id int) (models.AlertConfig, error) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
m.getAlertCalls = append(m.getAlertCalls, id) m.getAlertCalls = append(m.getAlertCalls, id)
@@ -113,7 +123,7 @@ func (m *mockStore) GetAlert(id int) (models.AlertConfig, error) {
return models.AlertConfig{}, fmt.Errorf("alert %d not found", id) return models.AlertConfig{}, fmt.Errorf("alert %d not found", id)
} }
func (m *mockStore) GetAlertByName(name string) (models.AlertConfig, error) { func (m *mockStore) GetAlertByName(_ context.Context, name string) (models.AlertConfig, error) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
for _, a := range m.alerts { for _, a := range m.alerts {
@@ -124,37 +134,37 @@ func (m *mockStore) GetAlertByName(name string) (models.AlertConfig, error) {
return models.AlertConfig{}, fmt.Errorf("alert %q not found", name) return models.AlertConfig{}, fmt.Errorf("alert %q not found", name)
} }
func (m *mockStore) IsMonitorInMaintenance(id int) (bool, error) { func (m *mockStore) IsMonitorInMaintenance(_ context.Context, id int) (bool, error) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
return m.maintenance[id], nil return m.maintenance[id], nil
} }
func (m *mockStore) SaveCheck(siteID int, latencyNs int64, isUp bool) error { func (m *mockStore) SaveCheck(_ context.Context, siteID int, latencyNs int64, isUp bool) error {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
m.savedChecks = append(m.savedChecks, savedCheck{siteID, latencyNs, isUp}) m.savedChecks = append(m.savedChecks, savedCheck{siteID, latencyNs, isUp})
return nil return nil
} }
func (m *mockStore) SaveLog(msg string) error { func (m *mockStore) SaveLog(_ context.Context, msg string) error {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
m.savedLogs = append(m.savedLogs, msg) m.savedLogs = append(m.savedLogs, msg)
return nil return nil
} }
func (m *mockStore) LoadLogs(limit int) ([]string, error) { func (m *mockStore) LoadLogs(_ context.Context, limit int) ([]string, error) {
return m.logs, nil return m.logs, nil
} }
func (m *mockStore) LoadAllHistory(limit int) (map[int][]models.CheckRecord, error) { func (m *mockStore) LoadAllHistory(_ context.Context, limit int) (map[int][]models.CheckRecord, error) {
return m.history, nil return m.history, nil
} }
func (m *mockStore) PruneLogs() error { return nil } func (m *mockStore) PruneLogs(context.Context) error { return nil }
func (m *mockStore) PruneCheckHistory() error { return nil } func (m *mockStore) PruneCheckHistory(context.Context) error { return nil }
func (m *mockStore) PruneStateChanges() error { return nil } func (m *mockStore) PruneStateChanges(context.Context) error { return nil }
// --- Helpers --- // --- Helpers ---
@@ -336,7 +346,7 @@ func TestHandleStatusChange_AlertSuppressedMaintenance(t *testing.T) {
e := newTestEngine(ms) e := newTestEngine(ms)
site := models.Site{ID: 1, Name: "test", Status: "UP", MaxRetries: 0, AlertID: 1} site := models.Site{ID: 1, Name: "test", Status: "UP", MaxRetries: 0, AlertID: 1}
injectSite(e, site) injectSite(e, site)
e.refreshMaintenanceCache() e.refreshMaintenanceCache(context.Background())
e.handleStatusChange(site, "DOWN", 0, 0, "test error") e.handleStatusChange(site, "DOWN", 0, 0, "test error")
@@ -368,7 +378,7 @@ func TestHandleStatusChange_RecoverySuppressedMaintenance(t *testing.T) {
e := newTestEngine(ms) e := newTestEngine(ms)
site := models.Site{ID: 1, Name: "test", Status: "DOWN", AlertID: 1} site := models.Site{ID: 1, Name: "test", Status: "DOWN", AlertID: 1}
injectSite(e, site) injectSite(e, site)
e.refreshMaintenanceCache() e.refreshMaintenanceCache(context.Background())
e.handleStatusChange(site, "UP", 200, 0, "") e.handleStatusChange(site, "UP", 200, 0, "")
@@ -456,7 +466,7 @@ func TestHandleStatusChange_SSLWarningSuppressedMaint(t *testing.T) {
CertExpiry: time.Now().Add(15 * 24 * time.Hour), CertExpiry: time.Now().Add(15 * 24 * time.Hour),
} }
injectSite(e, site) injectSite(e, site)
e.refreshMaintenanceCache() e.refreshMaintenanceCache(context.Background())
e.handleStatusChange(site, "UP", 200, 0, "") e.handleStatusChange(site, "UP", 200, 0, "")
@@ -563,7 +573,7 @@ func TestCheckPush_DeadlineMissed(t *testing.T) {
} }
injectSite(e, site) injectSite(e, site)
e.checkPush(site) e.checkPush(context.Background(), site)
s, _ := getSite(e, 1) s, _ := getSite(e, 1)
if s.Status != "DOWN" { if s.Status != "DOWN" {
@@ -581,7 +591,7 @@ func TestCheckPush_OverdueBecomesLate(t *testing.T) {
} }
injectSite(e, site) injectSite(e, site)
e.checkPush(site) e.checkPush(context.Background(), site)
s, _ := getSite(e, 1) s, _ := getSite(e, 1)
if s.Status != "LATE" { if s.Status != "LATE" {
@@ -601,7 +611,7 @@ func TestCheckPush_OverdueBecomesStale(t *testing.T) {
} }
injectSite(e, site) injectSite(e, site)
e.checkPush(site) e.checkPush(context.Background(), site)
s, _ := getSite(e, 1) s, _ := getSite(e, 1)
if s.Status != "STALE" { if s.Status != "STALE" {
@@ -618,7 +628,7 @@ func TestCheckPush_WithinDeadline(t *testing.T) {
} }
injectSite(e, site) injectSite(e, site)
e.checkPush(site) e.checkPush(context.Background(), site)
s, _ := getSite(e, 1) s, _ := getSite(e, 1)
if s.Status != "UP" { if s.Status != "UP" {
@@ -635,7 +645,7 @@ func TestCheckPush_PendingStaysPending(t *testing.T) {
} }
injectSite(e, site) injectSite(e, site)
e.checkPush(site) e.checkPush(context.Background(), site)
s, _ := getSite(e, 1) s, _ := getSite(e, 1)
if s.Status != "PENDING" { if s.Status != "PENDING" {
@@ -655,7 +665,7 @@ func TestCheckGroup_AllChildrenUp(t *testing.T) {
injectSite(e, child1) injectSite(e, child1)
injectSite(e, child2) injectSite(e, child2)
e.checkGroup(group) e.checkGroup(context.Background(), group)
s, _ := getSite(e, 1) s, _ := getSite(e, 1)
if s.Status != "UP" { if s.Status != "UP" {
@@ -673,7 +683,7 @@ func TestCheckGroup_OneChildDown(t *testing.T) {
injectSite(e, child1) injectSite(e, child1)
injectSite(e, child2) injectSite(e, child2)
e.checkGroup(group) e.checkGroup(context.Background(), group)
s, _ := getSite(e, 1) s, _ := getSite(e, 1)
if s.Status != "DOWN" { if s.Status != "DOWN" {
@@ -691,7 +701,7 @@ func TestCheckGroup_PausedChildIgnored(t *testing.T) {
injectSite(e, child1) injectSite(e, child1)
injectSite(e, child2) injectSite(e, child2)
e.checkGroup(group) e.checkGroup(context.Background(), group)
s, _ := getSite(e, 1) s, _ := getSite(e, 1)
if s.Status != "UP" { if s.Status != "UP" {
@@ -709,9 +719,9 @@ func TestCheckGroup_MaintenanceChildIgnored(t *testing.T) {
injectSite(e, group) injectSite(e, group)
injectSite(e, child1) injectSite(e, child1)
injectSite(e, child2) injectSite(e, child2)
e.refreshMaintenanceCache() e.refreshMaintenanceCache(context.Background())
e.checkGroup(group) e.checkGroup(context.Background(), group)
s, _ := getSite(e, 1) s, _ := getSite(e, 1)
if s.Status != "UP" { if s.Status != "UP" {
@@ -725,7 +735,7 @@ func TestCheckGroup_NoChildren(t *testing.T) {
group := models.Site{ID: 1, Name: "group", Type: "group", Status: "UP"} group := models.Site{ID: 1, Name: "group", Type: "group", Status: "UP"}
injectSite(e, group) injectSite(e, group)
e.checkGroup(group) e.checkGroup(context.Background(), group)
s, _ := getSite(e, 1) s, _ := getSite(e, 1)
if s.Status != "PENDING" { if s.Status != "PENDING" {
@@ -1241,7 +1251,7 @@ func TestCheckGroup_AllPausedNoAutoFreeze(t *testing.T) {
injectSite(e, child1) injectSite(e, child1)
injectSite(e, child2) injectSite(e, child2)
e.checkGroup(group) e.checkGroup(context.Background(), group)
s, _ := getSite(e, 1) s, _ := getSite(e, 1)
if s.Paused { if s.Paused {
@@ -1361,7 +1371,7 @@ func TestIsInMaintenance_UsesCache(t *testing.T) {
child := models.Site{ID: 20, Name: "child", Type: "http", ParentID: 10, Status: "UP"} child := models.Site{ID: 20, Name: "child", Type: "http", ParentID: 10, Status: "UP"}
injectSite(e, group) injectSite(e, group)
injectSite(e, child) injectSite(e, child)
e.refreshMaintenanceCache() e.refreshMaintenanceCache(context.Background())
if !e.isInMaintenance(10) { if !e.isInMaintenance(10) {
t.Error("group should be in maintenance (direct)") t.Error("group should be in maintenance (direct)")
@@ -1381,7 +1391,7 @@ func TestIsInMaintenance_GlobalMaintenance(t *testing.T) {
e := newTestEngine(ms) e := newTestEngine(ms)
site := models.Site{ID: 1, Name: "test", Type: "http", Status: "UP"} site := models.Site{ID: 1, Name: "test", Type: "http", Status: "UP"}
injectSite(e, site) injectSite(e, site)
e.refreshMaintenanceCache() e.refreshMaintenanceCache(context.Background())
if !e.isInMaintenance(1) { if !e.isInMaintenance(1) {
t.Error("all monitors should be in maintenance during global window") t.Error("all monitors should be in maintenance during global window")
+7 -7
View File
@@ -255,7 +255,7 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) *http.Server {
http.Error(w, "Unauthorized: UPTOP_CLUSTER_SECRET required", http.StatusUnauthorized) http.Error(w, "Unauthorized: UPTOP_CLUSTER_SECRET required", http.StatusUnauthorized)
return return
} }
data, err := s.ExportData() data, err := s.ExportData(r.Context())
if err != nil { if err != nil {
log.Printf("Export failed: %v", err) log.Printf("Export failed: %v", err)
http.Error(w, "Export failed", http.StatusInternalServerError) http.Error(w, "Export failed", http.StatusInternalServerError)
@@ -285,7 +285,7 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) *http.Server {
http.Error(w, "Invalid JSON", http.StatusBadRequest) http.Error(w, "Invalid JSON", http.StatusBadRequest)
return return
} }
if err := s.ImportData(data); err != nil { if err := s.ImportData(r.Context(), data); err != nil {
log.Printf("Import failed: %v", err) log.Printf("Import failed: %v", err)
http.Error(w, "Import failed", http.StatusInternalServerError) http.Error(w, "Import failed", http.StatusInternalServerError)
return return
@@ -311,7 +311,7 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) *http.Server {
return return
} }
backup := importer.ConvertKuma(&kb) backup := importer.ConvertKuma(&kb)
if err := s.ImportData(backup); err != nil { if err := s.ImportData(r.Context(), backup); err != nil {
log.Printf("Kuma import failed: %v", err) log.Printf("Kuma import failed: %v", err)
http.Error(w, "Import failed", http.StatusInternalServerError) http.Error(w, "Import failed", http.StatusInternalServerError)
return return
@@ -344,7 +344,7 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) *http.Server {
http.Error(w, "id is required", http.StatusBadRequest) http.Error(w, "id is required", http.StatusBadRequest)
return return
} }
if err := s.RegisterNode(models.ProbeNode{ if err := s.RegisterNode(r.Context(), models.ProbeNode{
ID: req.ID, Name: req.Name, Region: req.Region, Version: req.Version, ID: req.ID, Name: req.Name, Region: req.Region, Version: req.Version,
}); err != nil { }); err != nil {
log.Printf("Probe register failed: %v", err) log.Printf("Probe register failed: %v", err)
@@ -367,7 +367,7 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) *http.Server {
nodeID := r.URL.Query().Get("node_id") nodeID := r.URL.Query().Get("node_id")
var nodeRegion string var nodeRegion string
if nodeID != "" { if nodeID != "" {
if node, err := s.GetNode(nodeID); err == nil { if node, err := s.GetNode(r.Context(), nodeID); err == nil {
nodeRegion = node.Region nodeRegion = node.Region
} }
} }
@@ -427,7 +427,7 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) *http.Server {
eng.EnqueueProbeCheck(result.SiteID, req.NodeID, result.LatencyNs, result.IsUp) eng.EnqueueProbeCheck(result.SiteID, req.NodeID, result.LatencyNs, result.IsUp)
eng.IngestProbeResult(req.NodeID, result.SiteID, result.LatencyNs, result.IsUp, result.ErrorReason) eng.IngestProbeResult(req.NodeID, result.SiteID, result.LatencyNs, result.IsUp, result.ErrorReason)
} }
if err := s.UpdateNodeLastSeen(req.NodeID); err != nil { if err := s.UpdateNodeLastSeen(r.Context(), req.NodeID); err != nil {
log.Printf("Failed to update node last seen: %v", err) log.Printf("Failed to update node last seen: %v", err)
} }
_ = json.NewEncoder(w).Encode(map[string]bool{"ok": true}) //nolint:errcheck _ = json.NewEncoder(w).Encode(map[string]bool{"ok": true}) //nolint:errcheck
@@ -453,7 +453,7 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) *http.Server {
mux.HandleFunc("/status", RateLimit(statusRL, func(w http.ResponseWriter, r *http.Request) { renderStatusPage(w, cfg.Title, eng) })) mux.HandleFunc("/status", RateLimit(statusRL, func(w http.ResponseWriter, r *http.Request) { renderStatusPage(w, cfg.Title, eng) }))
mux.HandleFunc("/status/json", RateLimit(statusRL, func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/status/json", RateLimit(statusRL, func(w http.ResponseWriter, r *http.Request) {
state := eng.GetLiveState() state := eng.GetLiveState()
activeWindows, _ := s.GetActiveMaintenanceWindows() activeWindows, _ := s.GetActiveMaintenanceWindows(r.Context())
maintSet := make(map[int]bool) maintSet := make(map[int]bool)
allInMaint := false allInMaint := false
for _, mw := range activeWindows { for _, mw := range activeWindows {
+71 -51
View File
@@ -33,80 +33,100 @@ func newMockStore() *mockStore {
} }
} }
func (m *mockStore) Init() error { return nil } func (m *mockStore) Init(_ context.Context) error { return nil }
func (m *mockStore) GetSites() ([]models.Site, error) { return m.sites, nil } func (m *mockStore) GetSites(_ context.Context) ([]models.Site, error) { return m.sites, nil }
func (m *mockStore) AddSite(models.Site) error { return nil } func (m *mockStore) AddSite(_ context.Context, _ models.Site) error { return nil }
func (m *mockStore) UpdateSite(models.Site) error { return nil } func (m *mockStore) UpdateSite(_ context.Context, _ models.Site) error { return nil }
func (m *mockStore) UpdateSitePaused(int, bool) error { return nil } func (m *mockStore) UpdateSitePaused(_ context.Context, _ int, _ bool) error { return nil }
func (m *mockStore) DeleteSite(int) error { return nil } func (m *mockStore) DeleteSite(_ context.Context, _ int) error { return nil }
func (m *mockStore) GetAllAlerts() ([]models.AlertConfig, error) { return m.alerts, nil } func (m *mockStore) GetAllAlerts(_ context.Context) ([]models.AlertConfig, error) {
func (m *mockStore) GetAlert(int) (models.AlertConfig, error) { return models.AlertConfig{}, nil } return m.alerts, nil
func (m *mockStore) AddAlert(string, string, map[string]string) error { return nil }
func (m *mockStore) UpdateAlert(int, string, string, map[string]string) error { return nil }
func (m *mockStore) DeleteAlert(int) error { return nil }
func (m *mockStore) GetAllUsers() ([]models.User, error) { return nil, nil }
func (m *mockStore) AddUser(string, string, string) error { return nil }
func (m *mockStore) UpdateUser(int, string, string, string) error { return nil }
func (m *mockStore) DeleteUser(int) error { return nil }
func (m *mockStore) SaveCheck(int, int64, bool) error { return nil }
func (m *mockStore) SaveCheckFromNode(siteID int, nodeID string, latencyNs int64, isUp bool) error {
return nil
} }
func (m *mockStore) LoadAllHistory(int) (map[int][]models.CheckRecord, error) { func (m *mockStore) GetAlert(_ context.Context, _ int) (models.AlertConfig, error) {
return nil, nil
}
func (m *mockStore) GetSiteByName(string) (models.Site, error) { return models.Site{}, nil }
func (m *mockStore) GetAlertByName(string) (models.AlertConfig, error) {
return models.AlertConfig{}, nil return models.AlertConfig{}, nil
} }
func (m *mockStore) AddSiteReturningID(models.Site) (int, error) { return 0, nil } func (m *mockStore) AddAlert(_ context.Context, _ string, _ string, _ map[string]string) error {
func (m *mockStore) AddAlertReturningID(string, string, map[string]string) (int, error) { return nil
}
func (m *mockStore) UpdateAlert(_ context.Context, _ int, _ string, _ string, _ map[string]string) error {
return nil
}
func (m *mockStore) DeleteAlert(_ context.Context, _ int) error { return nil }
func (m *mockStore) GetAllUsers(_ context.Context) ([]models.User, error) { return nil, nil }
func (m *mockStore) AddUser(_ context.Context, _ string, _ string, _ string) error { return nil }
func (m *mockStore) UpdateUser(_ context.Context, _ int, _ string, _ string, _ string) error {
return nil
}
func (m *mockStore) DeleteUser(_ context.Context, _ int) error { return nil }
func (m *mockStore) SaveCheck(_ context.Context, _ int, _ int64, _ bool) error { return nil }
func (m *mockStore) SaveCheckFromNode(_ context.Context, siteID int, nodeID string, latencyNs int64, isUp bool) error {
return nil
}
func (m *mockStore) LoadAllHistory(_ context.Context, _ int) (map[int][]models.CheckRecord, error) {
return nil, nil
}
func (m *mockStore) GetSiteByName(_ context.Context, _ string) (models.Site, error) {
return models.Site{}, nil
}
func (m *mockStore) GetAlertByName(_ context.Context, _ string) (models.AlertConfig, error) {
return models.AlertConfig{}, nil
}
func (m *mockStore) AddSiteReturningID(_ context.Context, _ models.Site) (int, error) { return 0, nil }
func (m *mockStore) AddAlertReturningID(_ context.Context, _ string, _ string, _ map[string]string) (int, error) {
return 0, nil return 0, nil
} }
func (m *mockStore) GetAllNodes() ([]models.ProbeNode, error) { return nil, nil } func (m *mockStore) GetAllNodes(_ context.Context) ([]models.ProbeNode, error) { return nil, nil }
func (m *mockStore) UpdateNodeLastSeen(string) error { return nil } func (m *mockStore) UpdateNodeLastSeen(_ context.Context, _ string) error { return nil }
func (m *mockStore) DeleteNode(string) error { return nil } func (m *mockStore) DeleteNode(_ context.Context, _ string) error { return nil }
func (m *mockStore) LoadAlertHealth() (map[int]models.AlertHealthRecord, error) { func (m *mockStore) LoadAlertHealth(_ context.Context) (map[int]models.AlertHealthRecord, error) {
return nil, nil return nil, nil
} }
func (m *mockStore) SaveAlertHealth(models.AlertHealthRecord) error { return nil } func (m *mockStore) SaveAlertHealth(_ context.Context, _ models.AlertHealthRecord) error { return nil }
func (m *mockStore) SaveLog(string) error { return nil } func (m *mockStore) SaveLog(_ context.Context, _ string) error { return nil }
func (m *mockStore) PruneLogs() error { return nil } func (m *mockStore) PruneLogs(_ context.Context) error { return nil }
func (m *mockStore) PruneCheckHistory() error { return nil } func (m *mockStore) PruneCheckHistory(_ context.Context) error { return nil }
func (m *mockStore) PruneStateChanges() error { return nil } func (m *mockStore) PruneStateChanges(_ context.Context) error { return nil }
func (m *mockStore) LoadLogs(int) ([]string, error) { return nil, nil } func (m *mockStore) LoadLogs(_ context.Context, _ int) ([]string, error) { return nil, nil }
func (m *mockStore) GetAllMaintenanceWindows(int) ([]models.MaintenanceWindow, error) { func (m *mockStore) GetAllMaintenanceWindows(_ context.Context, _ int) ([]models.MaintenanceWindow, error) {
return nil, nil return nil, nil
} }
func (m *mockStore) AddMaintenanceWindow(models.MaintenanceWindow) error { return nil } func (m *mockStore) AddMaintenanceWindow(_ context.Context, _ models.MaintenanceWindow) error {
func (m *mockStore) EndMaintenanceWindow(int) error { return nil } return nil
func (m *mockStore) DeleteMaintenanceWindow(int) error { return nil } }
func (m *mockStore) PruneExpiredMaintenanceWindows(time.Duration) (int64, error) { return 0, nil } func (m *mockStore) EndMaintenanceWindow(_ context.Context, _ int) error { return nil }
func (m *mockStore) IsMonitorInMaintenance(int) (bool, error) { return false, nil } func (m *mockStore) DeleteMaintenanceWindow(_ context.Context, _ int) error { return nil }
func (m *mockStore) GetPreference(string) (string, error) { return "", nil } func (m *mockStore) PruneExpiredMaintenanceWindows(_ context.Context, _ time.Duration) (int64, error) {
func (m *mockStore) SetPreference(string, string) error { return nil } return 0, nil
func (m *mockStore) SaveStateChange(int, string, string, string) error { return nil } }
func (m *mockStore) GetStateChanges(int, int) ([]models.StateChange, error) { return nil, nil } func (m *mockStore) IsMonitorInMaintenance(_ context.Context, _ int) (bool, error) { return false, nil }
func (m *mockStore) GetStateChangesSince(int, time.Time) ([]models.StateChange, error) { func (m *mockStore) GetPreference(_ context.Context, _ string) (string, error) { return "", nil }
func (m *mockStore) SetPreference(_ context.Context, _ string, _ string) error { return nil }
func (m *mockStore) SaveStateChange(_ context.Context, _ int, _ string, _ string, _ string) error {
return nil
}
func (m *mockStore) GetStateChanges(_ context.Context, _ int, _ int) ([]models.StateChange, error) {
return nil, nil
}
func (m *mockStore) GetStateChangesSince(_ context.Context, _ int, _ time.Time) ([]models.StateChange, error) {
return nil, nil return nil, nil
} }
func (m *mockStore) Close() error { return nil } func (m *mockStore) Close() error { return nil }
func (m *mockStore) ExportData() (models.Backup, error) { func (m *mockStore) ExportData(_ context.Context) (models.Backup, error) {
return models.Backup{ return models.Backup{
Sites: m.sites, Sites: m.sites,
Alerts: m.alerts, Alerts: m.alerts,
}, nil }, nil
} }
func (m *mockStore) ImportData(data models.Backup) error { func (m *mockStore) ImportData(_ context.Context, data models.Backup) error {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
m.importedData = &data m.importedData = &data
return nil return nil
} }
func (m *mockStore) RegisterNode(node models.ProbeNode) error { func (m *mockStore) RegisterNode(_ context.Context, node models.ProbeNode) error {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
m.registeredNodes = append(m.registeredNodes, node) m.registeredNodes = append(m.registeredNodes, node)
@@ -114,7 +134,7 @@ func (m *mockStore) RegisterNode(node models.ProbeNode) error {
return nil return nil
} }
func (m *mockStore) GetNode(id string) (models.ProbeNode, error) { func (m *mockStore) GetNode(_ context.Context, id string) (models.ProbeNode, error) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
if n, ok := m.nodes[id]; ok { if n, ok := m.nodes[id]; ok {
@@ -123,7 +143,7 @@ func (m *mockStore) GetNode(id string) (models.ProbeNode, error) {
return models.ProbeNode{}, fmt.Errorf("not found") return models.ProbeNode{}, fmt.Errorf("not found")
} }
func (m *mockStore) GetActiveMaintenanceWindows() ([]models.MaintenanceWindow, error) { func (m *mockStore) GetActiveMaintenanceWindows(_ context.Context) ([]models.MaintenanceWindow, error) {
return m.maintWindows, nil return m.maintWindows, nil
} }
+110 -109
View File
@@ -1,6 +1,7 @@
package store package store
import ( import (
"context"
"crypto/rand" "crypto/rand"
"database/sql" "database/sql"
"encoding/hex" "encoding/hex"
@@ -73,14 +74,14 @@ func (s *SQLStore) Close() error {
return s.db.Close() return s.db.Close()
} }
func (s *SQLStore) Init() error { func (s *SQLStore) Init(ctx context.Context) error {
for _, stmt := range s.dialect.CreateTablesSQL() { for _, stmt := range s.dialect.CreateTablesSQL() {
if _, err := s.db.Exec(stmt); err != nil { if _, err := s.db.ExecContext(ctx, stmt); err != nil {
return err return err
} }
} }
for _, m := range s.dialect.MigrationsSQL() { for _, m := range s.dialect.MigrationsSQL() {
if _, err := s.db.Exec(m); err != nil { if _, err := s.db.ExecContext(ctx, m); err != nil {
errMsg := err.Error() errMsg := err.Error()
if strings.Contains(errMsg, "already exists") || strings.Contains(errMsg, "duplicate column") { if strings.Contains(errMsg, "already exists") || strings.Contains(errMsg, "duplicate column") {
continue continue
@@ -91,13 +92,13 @@ func (s *SQLStore) Init() error {
return nil return nil
} }
func (s *SQLStore) GetSites() ([]models.Site, error) { func (s *SQLStore) GetSites(ctx context.Context) ([]models.Site, error) {
bf := s.dialect.BoolFalse() bf := s.dialect.BoolFalse()
query := fmt.Sprintf( //nolint:gosec // bf is a dialect boolean literal, not user input query := fmt.Sprintf( //nolint:gosec // bf is a dialect boolean literal, not user input
"SELECT id, COALESCE(name, url), url, COALESCE(type, 'http'), COALESCE(token, ''), interval, alert_id, check_ssl, threshold, max_retries, COALESCE(hostname, ''), COALESCE(port, 0), COALESCE(timeout, 0), COALESCE(method, 'GET'), COALESCE(description, ''), COALESCE(parent_id, 0), COALESCE(accepted_codes, '200-299'), COALESCE(dns_resolve_type, ''), COALESCE(dns_server, ''), COALESCE(ignore_tls, %s), COALESCE(paused, %s), COALESCE(regions, '') FROM sites", "SELECT id, COALESCE(name, url), url, COALESCE(type, 'http'), COALESCE(token, ''), interval, alert_id, check_ssl, threshold, max_retries, COALESCE(hostname, ''), COALESCE(port, 0), COALESCE(timeout, 0), COALESCE(method, 'GET'), COALESCE(description, ''), COALESCE(parent_id, 0), COALESCE(accepted_codes, '200-299'), COALESCE(dns_resolve_type, ''), COALESCE(dns_server, ''), COALESCE(ignore_tls, %s), COALESCE(paused, %s), COALESCE(regions, '') FROM sites",
bf, bf, bf, bf,
) )
rows, err := s.db.Query(query) rows, err := s.db.QueryContext(ctx, query)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -116,7 +117,7 @@ func (s *SQLStore) GetSites() ([]models.Site, error) {
return sites, rows.Err() return sites, rows.Err()
} }
func (s *SQLStore) AddSite(site models.Site) error { func (s *SQLStore) AddSite(ctx context.Context, site models.Site) error {
token := "" token := ""
if site.Type == "push" { if site.Type == "push" {
var err error var err error
@@ -125,15 +126,15 @@ func (s *SQLStore) AddSite(site models.Site) error {
return fmt.Errorf("generate push token: %w", err) return fmt.Errorf("generate push token: %w", err)
} }
} }
_, err := s.db.Exec(s.q("INSERT INTO sites (name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused, regions) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"), _, err := s.db.ExecContext(ctx, s.q("INSERT INTO sites (name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused, regions) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"),
site.Name, site.URL, site.Type, token, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries, site.Name, site.URL, site.Type, token, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries,
site.Hostname, site.Port, site.Timeout, site.Method, site.Description, site.ParentID, site.AcceptedCodes, site.DNSResolveType, site.DNSServer, site.IgnoreTLS, site.Paused, site.Regions) site.Hostname, site.Port, site.Timeout, site.Method, site.Description, site.ParentID, site.AcceptedCodes, site.DNSResolveType, site.DNSServer, site.IgnoreTLS, site.Paused, site.Regions)
return err return err
} }
func (s *SQLStore) UpdateSite(site models.Site) error { func (s *SQLStore) UpdateSite(ctx context.Context, site models.Site) error {
var existingToken string var existingToken string
_ = s.db.QueryRow(s.q("SELECT token FROM sites WHERE id=?"), site.ID).Scan(&existingToken) //nolint:errcheck _ = s.db.QueryRowContext(ctx, s.q("SELECT token FROM sites WHERE id=?"), site.ID).Scan(&existingToken) //nolint:errcheck
if site.Type == "push" && existingToken == "" { if site.Type == "push" && existingToken == "" {
var err error var err error
existingToken, err = generateToken() existingToken, err = generateToken()
@@ -141,19 +142,19 @@ func (s *SQLStore) UpdateSite(site models.Site) error {
return fmt.Errorf("generate push token: %w", err) return fmt.Errorf("generate push token: %w", err)
} }
} }
_, err := s.db.Exec(s.q("UPDATE sites SET name=?, url=?, type=?, token=?, interval=?, alert_id=?, check_ssl=?, threshold=?, max_retries=?, hostname=?, port=?, timeout=?, method=?, description=?, parent_id=?, accepted_codes=?, dns_resolve_type=?, dns_server=?, ignore_tls=?, paused=?, regions=? WHERE id=?"), _, err := s.db.ExecContext(ctx, s.q("UPDATE sites SET name=?, url=?, type=?, token=?, interval=?, alert_id=?, check_ssl=?, threshold=?, max_retries=?, hostname=?, port=?, timeout=?, method=?, description=?, parent_id=?, accepted_codes=?, dns_resolve_type=?, dns_server=?, ignore_tls=?, paused=?, regions=? WHERE id=?"),
site.Name, site.URL, site.Type, existingToken, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries, site.Name, site.URL, site.Type, existingToken, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries,
site.Hostname, site.Port, site.Timeout, site.Method, site.Description, site.ParentID, site.AcceptedCodes, site.DNSResolveType, site.DNSServer, site.IgnoreTLS, site.Paused, site.Regions, site.ID) site.Hostname, site.Port, site.Timeout, site.Method, site.Description, site.ParentID, site.AcceptedCodes, site.DNSResolveType, site.DNSServer, site.IgnoreTLS, site.Paused, site.Regions, site.ID)
return err return err
} }
func (s *SQLStore) UpdateSitePaused(id int, paused bool) error { func (s *SQLStore) UpdateSitePaused(ctx context.Context, id int, paused bool) error {
_, err := s.db.Exec(s.q("UPDATE sites SET paused=? WHERE id=?"), paused, id) _, err := s.db.ExecContext(ctx, s.q("UPDATE sites SET paused=? WHERE id=?"), paused, id)
return err return err
} }
func (s *SQLStore) DeleteSite(id int) error { func (s *SQLStore) DeleteSite(ctx context.Context, id int) error {
tx, err := s.db.Begin() tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return err
} }
@@ -165,7 +166,7 @@ func (s *SQLStore) DeleteSite(id int) error {
"DELETE FROM state_changes WHERE site_id = ?", "DELETE FROM state_changes WHERE site_id = ?",
"DELETE FROM sites WHERE id = ?", "DELETE FROM sites WHERE id = ?",
} { } {
if _, err := tx.Exec(s.q(q), id); err != nil { if _, err := tx.ExecContext(ctx, s.q(q), id); err != nil {
return err return err
} }
} }
@@ -177,14 +178,14 @@ func (s *SQLStore) DeleteSite(id int) error {
return nil return nil
} }
func (s *SQLStore) GetSiteByName(name string) (models.Site, error) { func (s *SQLStore) GetSiteByName(ctx context.Context, name string) (models.Site, error) {
bf := s.dialect.BoolFalse() bf := s.dialect.BoolFalse()
query := fmt.Sprintf( //nolint:gosec // bf is a dialect boolean literal, not user input query := fmt.Sprintf( //nolint:gosec // bf is a dialect boolean literal, not user input
"SELECT id, COALESCE(name, url), url, COALESCE(type, 'http'), COALESCE(token, ''), interval, alert_id, check_ssl, threshold, max_retries, COALESCE(hostname, ''), COALESCE(port, 0), COALESCE(timeout, 0), COALESCE(method, 'GET'), COALESCE(description, ''), COALESCE(parent_id, 0), COALESCE(accepted_codes, '200-299'), COALESCE(dns_resolve_type, ''), COALESCE(dns_server, ''), COALESCE(ignore_tls, %s), COALESCE(paused, %s), COALESCE(regions, '') FROM sites WHERE name = %s", "SELECT id, COALESCE(name, url), url, COALESCE(type, 'http'), COALESCE(token, ''), interval, alert_id, check_ssl, threshold, max_retries, COALESCE(hostname, ''), COALESCE(port, 0), COALESCE(timeout, 0), COALESCE(method, 'GET'), COALESCE(description, ''), COALESCE(parent_id, 0), COALESCE(accepted_codes, '200-299'), COALESCE(dns_resolve_type, ''), COALESCE(dns_server, ''), COALESCE(ignore_tls, %s), COALESCE(paused, %s), COALESCE(regions, '') FROM sites WHERE name = %s",
bf, bf, s.q("?"), bf, bf, s.q("?"),
) )
var st models.Site var st models.Site
err := s.db.QueryRow(query, name).Scan(&st.ID, &st.Name, &st.URL, &st.Type, &st.Token, &st.Interval, &st.AlertID, err := s.db.QueryRowContext(ctx, query, name).Scan(&st.ID, &st.Name, &st.URL, &st.Type, &st.Token, &st.Interval, &st.AlertID,
&st.CheckSSL, &st.ExpiryThreshold, &st.MaxRetries, &st.Hostname, &st.Port, &st.Timeout, &st.CheckSSL, &st.ExpiryThreshold, &st.MaxRetries, &st.Hostname, &st.Port, &st.Timeout,
&st.Method, &st.Description, &st.ParentID, &st.AcceptedCodes, &st.DNSResolveType, &st.Method, &st.Description, &st.ParentID, &st.AcceptedCodes, &st.DNSResolveType,
&st.DNSServer, &st.IgnoreTLS, &st.Paused, &st.Regions) &st.DNSServer, &st.IgnoreTLS, &st.Paused, &st.Regions)
@@ -211,10 +212,10 @@ func (s *SQLStore) marshalSettings(settings map[string]string) (string, error) {
return s.encryptSettings(string(jsonBytes)) return s.encryptSettings(string(jsonBytes))
} }
func (s *SQLStore) GetAlertByName(name string) (models.AlertConfig, error) { func (s *SQLStore) GetAlertByName(ctx context.Context, name string) (models.AlertConfig, error) {
var a models.AlertConfig var a models.AlertConfig
var settingsRaw string var settingsRaw string
err := s.db.QueryRow(s.q("SELECT id, name, type, settings FROM alerts WHERE name = ?"), name).Scan(&a.ID, &a.Name, &a.Type, &settingsRaw) err := s.db.QueryRowContext(ctx, s.q("SELECT id, name, type, settings FROM alerts WHERE name = ?"), name).Scan(&a.ID, &a.Name, &a.Type, &settingsRaw)
if err != nil { if err != nil {
return a, err return a, err
} }
@@ -225,7 +226,7 @@ func (s *SQLStore) GetAlertByName(name string) (models.AlertConfig, error) {
return a, nil return a, nil
} }
func (s *SQLStore) AddSiteReturningID(site models.Site) (int, error) { func (s *SQLStore) AddSiteReturningID(ctx context.Context, site models.Site) (int, error) {
token := "" token := ""
if site.Type == "push" { if site.Type == "push" {
var err error var err error
@@ -236,12 +237,12 @@ func (s *SQLStore) AddSiteReturningID(site models.Site) (int, error) {
} }
if s.dollar { if s.dollar {
var id int var id int
err := s.db.QueryRow(s.q("INSERT INTO sites (name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused, regions) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) RETURNING id"), err := s.db.QueryRowContext(ctx, s.q("INSERT INTO sites (name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused, regions) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) RETURNING id"),
site.Name, site.URL, site.Type, token, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries, site.Name, site.URL, site.Type, token, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries,
site.Hostname, site.Port, site.Timeout, site.Method, site.Description, site.ParentID, site.AcceptedCodes, site.DNSResolveType, site.DNSServer, site.IgnoreTLS, site.Paused, site.Regions).Scan(&id) site.Hostname, site.Port, site.Timeout, site.Method, site.Description, site.ParentID, site.AcceptedCodes, site.DNSResolveType, site.DNSServer, site.IgnoreTLS, site.Paused, site.Regions).Scan(&id)
return id, err return id, err
} }
result, err := s.db.Exec(s.q("INSERT INTO sites (name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused, regions) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"), result, err := s.db.ExecContext(ctx, s.q("INSERT INTO sites (name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused, regions) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"),
site.Name, site.URL, site.Type, token, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries, site.Name, site.URL, site.Type, token, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries,
site.Hostname, site.Port, site.Timeout, site.Method, site.Description, site.ParentID, site.AcceptedCodes, site.DNSResolveType, site.DNSServer, site.IgnoreTLS, site.Paused, site.Regions) site.Hostname, site.Port, site.Timeout, site.Method, site.Description, site.ParentID, site.AcceptedCodes, site.DNSResolveType, site.DNSServer, site.IgnoreTLS, site.Paused, site.Regions)
if err != nil { if err != nil {
@@ -251,17 +252,17 @@ func (s *SQLStore) AddSiteReturningID(site models.Site) (int, error) {
return int(id), err return int(id), err
} }
func (s *SQLStore) AddAlertReturningID(name, aType string, settings map[string]string) (int, error) { func (s *SQLStore) AddAlertReturningID(ctx context.Context, name, aType string, settings map[string]string) (int, error) {
stored, err := s.marshalSettings(settings) stored, err := s.marshalSettings(settings)
if err != nil { if err != nil {
return 0, err return 0, err
} }
if s.dollar { if s.dollar {
var id int var id int
err := s.db.QueryRow(s.q("INSERT INTO alerts (name, type, settings) VALUES (?, ?, ?) RETURNING id"), name, aType, stored).Scan(&id) err := s.db.QueryRowContext(ctx, s.q("INSERT INTO alerts (name, type, settings) VALUES (?, ?, ?) RETURNING id"), name, aType, stored).Scan(&id)
return id, err return id, err
} }
result, err := s.db.Exec(s.q("INSERT INTO alerts (name, type, settings) VALUES (?, ?, ?)"), name, aType, stored) result, err := s.db.ExecContext(ctx, s.q("INSERT INTO alerts (name, type, settings) VALUES (?, ?, ?)"), name, aType, stored)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -269,8 +270,8 @@ func (s *SQLStore) AddAlertReturningID(name, aType string, settings map[string]s
return int(id), err return int(id), err
} }
func (s *SQLStore) GetAllAlerts() ([]models.AlertConfig, error) { func (s *SQLStore) GetAllAlerts(ctx context.Context) ([]models.AlertConfig, error) {
rows, err := s.db.Query("SELECT id, name, type, settings FROM alerts") rows, err := s.db.QueryContext(ctx, "SELECT id, name, type, settings FROM alerts")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -291,10 +292,10 @@ func (s *SQLStore) GetAllAlerts() ([]models.AlertConfig, error) {
return alerts, rows.Err() return alerts, rows.Err()
} }
func (s *SQLStore) GetAlert(id int) (models.AlertConfig, error) { func (s *SQLStore) GetAlert(ctx context.Context, id int) (models.AlertConfig, error) {
var a models.AlertConfig var a models.AlertConfig
var settingsRaw string var settingsRaw string
err := s.db.QueryRow(s.q("SELECT id, name, type, settings FROM alerts WHERE id = ?"), id).Scan(&a.ID, &a.Name, &a.Type, &settingsRaw) err := s.db.QueryRowContext(ctx, s.q("SELECT id, name, type, settings FROM alerts WHERE id = ?"), id).Scan(&a.ID, &a.Name, &a.Type, &settingsRaw)
if err != nil { if err != nil {
return a, err return a, err
} }
@@ -305,26 +306,26 @@ func (s *SQLStore) GetAlert(id int) (models.AlertConfig, error) {
return a, nil return a, nil
} }
func (s *SQLStore) AddAlert(name, aType string, settings map[string]string) error { func (s *SQLStore) AddAlert(ctx context.Context, name, aType string, settings map[string]string) error {
stored, err := s.marshalSettings(settings) stored, err := s.marshalSettings(settings)
if err != nil { if err != nil {
return err return err
} }
_, err = s.db.Exec(s.q("INSERT INTO alerts (name, type, settings) VALUES (?, ?, ?)"), name, aType, stored) _, err = s.db.ExecContext(ctx, s.q("INSERT INTO alerts (name, type, settings) VALUES (?, ?, ?)"), name, aType, stored)
return err return err
} }
func (s *SQLStore) UpdateAlert(id int, name, aType string, settings map[string]string) error { func (s *SQLStore) UpdateAlert(ctx context.Context, id int, name, aType string, settings map[string]string) error {
stored, err := s.marshalSettings(settings) stored, err := s.marshalSettings(settings)
if err != nil { if err != nil {
return err return err
} }
_, err = s.db.Exec(s.q("UPDATE alerts SET name=?, type=?, settings=? WHERE id=?"), name, aType, stored, id) _, err = s.db.ExecContext(ctx, s.q("UPDATE alerts SET name=?, type=?, settings=? WHERE id=?"), name, aType, stored, id)
return err return err
} }
func (s *SQLStore) DeleteAlert(id int) error { func (s *SQLStore) DeleteAlert(ctx context.Context, id int) error {
_, err := s.db.Exec(s.q("DELETE FROM alerts WHERE id=?"), id) _, err := s.db.ExecContext(ctx, s.q("DELETE FROM alerts WHERE id=?"), id)
if err != nil { if err != nil {
return err return err
} }
@@ -332,8 +333,8 @@ func (s *SQLStore) DeleteAlert(id int) error {
return nil return nil
} }
func (s *SQLStore) GetAllUsers() ([]models.User, error) { func (s *SQLStore) GetAllUsers(ctx context.Context) ([]models.User, error) {
rows, err := s.db.Query("SELECT id, username, public_key, role FROM users") rows, err := s.db.QueryContext(ctx, "SELECT id, username, public_key, role FROM users")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -349,29 +350,29 @@ func (s *SQLStore) GetAllUsers() ([]models.User, error) {
return users, rows.Err() return users, rows.Err()
} }
func (s *SQLStore) AddUser(username, publicKey, role string) error { func (s *SQLStore) AddUser(ctx context.Context, username, publicKey, role string) error {
_, err := s.db.Exec(s.q("INSERT INTO users (username, public_key, role) VALUES (?, ?, ?)"), username, publicKey, role) _, err := s.db.ExecContext(ctx, s.q("INSERT INTO users (username, public_key, role) VALUES (?, ?, ?)"), username, publicKey, role)
return err return err
} }
func (s *SQLStore) UpdateUser(id int, username, publicKey, role string) error { func (s *SQLStore) UpdateUser(ctx context.Context, id int, username, publicKey, role string) error {
_, err := s.db.Exec(s.q("UPDATE users SET username=?, public_key=?, role=? WHERE id=?"), username, publicKey, role, id) _, err := s.db.ExecContext(ctx, s.q("UPDATE users SET username=?, public_key=?, role=? WHERE id=?"), username, publicKey, role, id)
return err return err
} }
func (s *SQLStore) DeleteUser(id int) error { func (s *SQLStore) DeleteUser(ctx context.Context, id int) error {
_, err := s.db.Exec(s.q("DELETE FROM users WHERE id=?"), id) _, err := s.db.ExecContext(ctx, s.q("DELETE FROM users WHERE id=?"), id)
return err return err
} }
func (s *SQLStore) SaveStateChange(siteID int, fromStatus, toStatus, errorReason string) error { func (s *SQLStore) SaveStateChange(ctx context.Context, siteID int, fromStatus, toStatus, errorReason string) error {
_, err := s.db.Exec(s.q("INSERT INTO state_changes (site_id, from_status, to_status, error_reason) VALUES (?, ?, ?, ?)"), _, err := s.db.ExecContext(ctx, s.q("INSERT INTO state_changes (site_id, from_status, to_status, error_reason) VALUES (?, ?, ?, ?)"),
siteID, fromStatus, toStatus, errorReason) siteID, fromStatus, toStatus, errorReason)
return err return err
} }
func (s *SQLStore) GetStateChanges(siteID int, limit int) ([]models.StateChange, error) { func (s *SQLStore) GetStateChanges(ctx context.Context, siteID int, limit int) ([]models.StateChange, error) {
rows, err := s.db.Query(s.q("SELECT id, site_id, from_status, to_status, error_reason, changed_at FROM state_changes WHERE site_id = ? ORDER BY changed_at DESC LIMIT ?"), siteID, limit) rows, err := s.db.QueryContext(ctx, s.q("SELECT id, site_id, from_status, to_status, error_reason, changed_at FROM state_changes WHERE site_id = ? ORDER BY changed_at DESC LIMIT ?"), siteID, limit)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -387,8 +388,8 @@ func (s *SQLStore) GetStateChanges(siteID int, limit int) ([]models.StateChange,
return changes, rows.Err() return changes, rows.Err()
} }
func (s *SQLStore) GetStateChangesSince(siteID int, since time.Time) ([]models.StateChange, error) { func (s *SQLStore) GetStateChangesSince(ctx context.Context, siteID int, since time.Time) ([]models.StateChange, error) {
rows, err := s.db.Query(s.q("SELECT id, site_id, from_status, to_status, error_reason, changed_at FROM state_changes WHERE site_id = ? AND changed_at >= ? ORDER BY changed_at DESC"), siteID, since) rows, err := s.db.QueryContext(ctx, s.q("SELECT id, site_id, from_status, to_status, error_reason, changed_at FROM state_changes WHERE site_id = ? AND changed_at >= ? ORDER BY changed_at DESC"), siteID, since)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -404,59 +405,59 @@ func (s *SQLStore) GetStateChangesSince(siteID int, since time.Time) ([]models.S
return changes, rows.Err() return changes, rows.Err()
} }
func (s *SQLStore) SaveCheck(siteID int, latencyNs int64, isUp bool) error { func (s *SQLStore) SaveCheck(ctx context.Context, siteID int, latencyNs int64, isUp bool) error {
return s.SaveCheckFromNode(siteID, "", latencyNs, isUp) return s.SaveCheckFromNode(ctx, siteID, "", latencyNs, isUp)
} }
// SaveCheckFromNode inserts a single check row. Retention is handled out of // SaveCheckFromNode inserts a single check row. Retention is handled out of
// band by PruneCheckHistory on a timer, not per-insert, to keep the write hot // band by PruneCheckHistory on a timer, not per-insert, to keep the write hot
// path a plain INSERT. // path a plain INSERT.
func (s *SQLStore) SaveCheckFromNode(siteID int, nodeID string, latencyNs int64, isUp bool) error { func (s *SQLStore) SaveCheckFromNode(ctx context.Context, siteID int, nodeID string, latencyNs int64, isUp bool) error {
_, err := s.db.Exec(s.q("INSERT INTO check_history (site_id, node_id, latency_ns, is_up) VALUES (?, ?, ?, ?)"), siteID, nodeID, latencyNs, isUp) _, err := s.db.ExecContext(ctx, s.q("INSERT INTO check_history (site_id, node_id, latency_ns, is_up) VALUES (?, ?, ?, ?)"), siteID, nodeID, latencyNs, isUp)
return err return err
} }
// PruneCheckHistory trims check_history to the newest maxCheckHistory rows per // PruneCheckHistory trims check_history to the newest maxCheckHistory rows per
// site, across all sites, in one pass. Intended to run periodically. // site, across all sites, in one pass. Intended to run periodically.
func (s *SQLStore) PruneCheckHistory() error { func (s *SQLStore) PruneCheckHistory(ctx context.Context) error {
q := fmt.Sprintf(`DELETE FROM check_history WHERE id IN ( q := fmt.Sprintf(`DELETE FROM check_history WHERE id IN (
SELECT id FROM ( SELECT id FROM (
SELECT id, ROW_NUMBER() OVER (PARTITION BY site_id ORDER BY checked_at DESC, id DESC) AS rn SELECT id, ROW_NUMBER() OVER (PARTITION BY site_id ORDER BY checked_at DESC, id DESC) AS rn
FROM check_history FROM check_history
) ranked WHERE rn > %d ) ranked WHERE rn > %d
)`, maxCheckHistory) )`, maxCheckHistory)
_, err := s.db.Exec(s.q(q)) _, err := s.db.ExecContext(ctx, s.q(q))
return err return err
} }
// PruneStateChanges trims state_changes to the newest maxStateChangesPerSite // PruneStateChanges trims state_changes to the newest maxStateChangesPerSite
// rows per site. Generous so realistic SLA windows are unaffected; bounds the // rows per site. Generous so realistic SLA windows are unaffected; bounds the
// otherwise unbounded growth of a flapping monitor's history. // otherwise unbounded growth of a flapping monitor's history.
func (s *SQLStore) PruneStateChanges() error { func (s *SQLStore) PruneStateChanges(ctx context.Context) error {
q := fmt.Sprintf(`DELETE FROM state_changes WHERE id IN ( q := fmt.Sprintf(`DELETE FROM state_changes WHERE id IN (
SELECT id FROM ( SELECT id FROM (
SELECT id, ROW_NUMBER() OVER (PARTITION BY site_id ORDER BY changed_at DESC, id DESC) AS rn SELECT id, ROW_NUMBER() OVER (PARTITION BY site_id ORDER BY changed_at DESC, id DESC) AS rn
FROM state_changes FROM state_changes
) ranked WHERE rn > %d ) ranked WHERE rn > %d
)`, maxStateChangesPerSite) )`, maxStateChangesPerSite)
_, err := s.db.Exec(s.q(q)) _, err := s.db.ExecContext(ctx, s.q(q))
return err return err
} }
func (s *SQLStore) RegisterNode(node models.ProbeNode) error { func (s *SQLStore) RegisterNode(ctx context.Context, node models.ProbeNode) error {
_, err := s.db.Exec(s.dialect.UpsertNodeSQL(), node.ID, node.Name, node.Region, node.Version) _, err := s.db.ExecContext(ctx, s.dialect.UpsertNodeSQL(), node.ID, node.Name, node.Region, node.Version)
return err return err
} }
func (s *SQLStore) GetNode(id string) (models.ProbeNode, error) { func (s *SQLStore) GetNode(ctx context.Context, id string) (models.ProbeNode, error) {
var n models.ProbeNode var n models.ProbeNode
err := s.db.QueryRow(s.q("SELECT id, name, region, last_seen, version FROM nodes WHERE id = ?"), id). err := s.db.QueryRowContext(ctx, s.q("SELECT id, name, region, last_seen, version FROM nodes WHERE id = ?"), id).
Scan(&n.ID, &n.Name, &n.Region, &n.LastSeen, &n.Version) Scan(&n.ID, &n.Name, &n.Region, &n.LastSeen, &n.Version)
return n, err return n, err
} }
func (s *SQLStore) GetAllNodes() ([]models.ProbeNode, error) { func (s *SQLStore) GetAllNodes(ctx context.Context) ([]models.ProbeNode, error) {
rows, err := s.db.Query("SELECT id, name, region, last_seen, version FROM nodes ORDER BY region, name") rows, err := s.db.QueryContext(ctx, "SELECT id, name, region, last_seen, version FROM nodes ORDER BY region, name")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -472,18 +473,18 @@ func (s *SQLStore) GetAllNodes() ([]models.ProbeNode, error) {
return nodes, rows.Err() return nodes, rows.Err()
} }
func (s *SQLStore) UpdateNodeLastSeen(id string) error { func (s *SQLStore) UpdateNodeLastSeen(ctx context.Context, id string) error {
_, err := s.db.Exec(s.q("UPDATE nodes SET last_seen = CURRENT_TIMESTAMP WHERE id = ?"), id) _, err := s.db.ExecContext(ctx, s.q("UPDATE nodes SET last_seen = CURRENT_TIMESTAMP WHERE id = ?"), id)
return err return err
} }
func (s *SQLStore) DeleteNode(id string) error { func (s *SQLStore) DeleteNode(ctx context.Context, id string) error {
_, err := s.db.Exec(s.q("DELETE FROM nodes WHERE id = ?"), id) _, err := s.db.ExecContext(ctx, s.q("DELETE FROM nodes WHERE id = ?"), id)
return err return err
} }
func (s *SQLStore) LoadAlertHealth() (map[int]models.AlertHealthRecord, error) { func (s *SQLStore) LoadAlertHealth(ctx context.Context) (map[int]models.AlertHealthRecord, error) {
rows, err := s.db.Query("SELECT alert_id, last_send_at, last_send_ok, last_error, send_count, fail_count FROM alert_health") rows, err := s.db.QueryContext(ctx, "SELECT alert_id, last_send_at, last_send_ok, last_error, send_count, fail_count FROM alert_health")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -503,35 +504,35 @@ func (s *SQLStore) LoadAlertHealth() (map[int]models.AlertHealthRecord, error) {
return out, rows.Err() return out, rows.Err()
} }
func (s *SQLStore) SaveAlertHealth(h models.AlertHealthRecord) error { func (s *SQLStore) SaveAlertHealth(ctx context.Context, h models.AlertHealthRecord) error {
var lastSend interface{} var lastSend interface{}
if !h.LastSendAt.IsZero() { if !h.LastSendAt.IsZero() {
lastSend = h.LastSendAt lastSend = h.LastSendAt
} }
_, err := s.db.Exec(s.dialect.UpsertAlertHealthSQL(), _, err := s.db.ExecContext(ctx, s.dialect.UpsertAlertHealthSQL(),
h.AlertID, lastSend, h.LastSendOK, h.LastError, h.SendCount, h.FailCount) h.AlertID, lastSend, h.LastSendOK, h.LastError, h.SendCount, h.FailCount)
return err return err
} }
// SaveLog inserts a single log row. Retention is handled by PruneLogs on a // SaveLog inserts a single log row. Retention is handled by PruneLogs on a
// timer, not per-insert. // timer, not per-insert.
func (s *SQLStore) SaveLog(message string) error { func (s *SQLStore) SaveLog(ctx context.Context, message string) error {
_, err := s.db.Exec(s.q("INSERT INTO logs (message) VALUES (?)"), message) _, err := s.db.ExecContext(ctx, s.q("INSERT INTO logs (message) VALUES (?)"), message)
return err return err
} }
// PruneLogs trims the logs table to the newest maxLogRows rows. The id DESC // PruneLogs trims the logs table to the newest maxLogRows rows. The id DESC
// tiebreak keeps ordering deterministic when rows share a created_at second. // tiebreak keeps ordering deterministic when rows share a created_at second.
func (s *SQLStore) PruneLogs() error { func (s *SQLStore) PruneLogs(ctx context.Context) error {
q := fmt.Sprintf(`DELETE FROM logs WHERE id NOT IN ( q := fmt.Sprintf(`DELETE FROM logs WHERE id NOT IN (
SELECT id FROM logs ORDER BY created_at DESC, id DESC LIMIT %d SELECT id FROM logs ORDER BY created_at DESC, id DESC LIMIT %d
)`, maxLogRows) )`, maxLogRows)
_, err := s.db.Exec(s.q(q)) _, err := s.db.ExecContext(ctx, s.q(q))
return err return err
} }
func (s *SQLStore) LoadLogs(limit int) ([]string, error) { func (s *SQLStore) LoadLogs(ctx context.Context, limit int) ([]string, error) {
rows, err := s.db.Query(s.q("SELECT message FROM logs ORDER BY created_at DESC LIMIT ?"), limit) rows, err := s.db.QueryContext(ctx, s.q("SELECT message FROM logs ORDER BY created_at DESC LIMIT ?"), limit)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -547,9 +548,9 @@ func (s *SQLStore) LoadLogs(limit int) ([]string, error) {
return logs, rows.Err() return logs, rows.Err()
} }
func (s *SQLStore) LoadAllHistory(limit int) (map[int][]models.CheckRecord, error) { func (s *SQLStore) LoadAllHistory(ctx context.Context, limit int) (map[int][]models.CheckRecord, error) {
result := make(map[int][]models.CheckRecord) result := make(map[int][]models.CheckRecord)
rows, err := s.db.Query(s.q(` rows, err := s.db.QueryContext(ctx, s.q(`
SELECT site_id, latency_ns, is_up FROM ( SELECT site_id, latency_ns, is_up FROM (
SELECT site_id, latency_ns, is_up, SELECT site_id, latency_ns, is_up,
ROW_NUMBER() OVER (PARTITION BY site_id ORDER BY checked_at DESC) AS rn ROW_NUMBER() OVER (PARTITION BY site_id ORDER BY checked_at DESC) AS rn
@@ -587,8 +588,8 @@ func (s *SQLStore) scanMaintenanceWindow(rows *sql.Rows) (models.MaintenanceWind
return mw, nil return mw, nil
} }
func (s *SQLStore) GetActiveMaintenanceWindows() ([]models.MaintenanceWindow, error) { func (s *SQLStore) GetActiveMaintenanceWindows(ctx context.Context) ([]models.MaintenanceWindow, error) {
rows, err := s.db.Query(s.q("SELECT id, monitor_id, title, description, type, start_time, end_time, created_by, created_at FROM maintenance_windows WHERE start_time <= CURRENT_TIMESTAMP AND (end_time IS NULL OR end_time > CURRENT_TIMESTAMP) ORDER BY start_time DESC")) rows, err := s.db.QueryContext(ctx, s.q("SELECT id, monitor_id, title, description, type, start_time, end_time, created_by, created_at FROM maintenance_windows WHERE start_time <= CURRENT_TIMESTAMP AND (end_time IS NULL OR end_time > CURRENT_TIMESTAMP) ORDER BY start_time DESC"))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -604,8 +605,8 @@ func (s *SQLStore) GetActiveMaintenanceWindows() ([]models.MaintenanceWindow, er
return windows, rows.Err() return windows, rows.Err()
} }
func (s *SQLStore) GetAllMaintenanceWindows(limit int) ([]models.MaintenanceWindow, error) { func (s *SQLStore) GetAllMaintenanceWindows(ctx context.Context, limit int) ([]models.MaintenanceWindow, error) {
rows, err := s.db.Query(s.q("SELECT id, monitor_id, title, description, type, start_time, end_time, created_by, created_at FROM maintenance_windows ORDER BY created_at DESC LIMIT ?"), limit) rows, err := s.db.QueryContext(ctx, s.q("SELECT id, monitor_id, title, description, type, start_time, end_time, created_by, created_at FROM maintenance_windows ORDER BY created_at DESC LIMIT ?"), limit)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -621,22 +622,22 @@ func (s *SQLStore) GetAllMaintenanceWindows(limit int) ([]models.MaintenanceWind
return windows, rows.Err() return windows, rows.Err()
} }
func (s *SQLStore) AddMaintenanceWindow(mw models.MaintenanceWindow) error { func (s *SQLStore) AddMaintenanceWindow(ctx context.Context, mw models.MaintenanceWindow) error {
if mw.StartTime.IsZero() { if mw.StartTime.IsZero() {
mw.StartTime = time.Now() mw.StartTime = time.Now()
} }
_, err := s.db.Exec(s.q("INSERT INTO maintenance_windows (monitor_id, title, description, type, start_time, end_time, created_by) VALUES (?, ?, ?, ?, ?, ?, ?)"), _, err := s.db.ExecContext(ctx, s.q("INSERT INTO maintenance_windows (monitor_id, title, description, type, start_time, end_time, created_by) VALUES (?, ?, ?, ?, ?, ?, ?)"),
mw.MonitorID, mw.Title, mw.Description, mw.Type, mw.StartTime, sql.NullTime{Time: mw.EndTime, Valid: !mw.EndTime.IsZero()}, mw.CreatedBy) mw.MonitorID, mw.Title, mw.Description, mw.Type, mw.StartTime, sql.NullTime{Time: mw.EndTime, Valid: !mw.EndTime.IsZero()}, mw.CreatedBy)
return err return err
} }
func (s *SQLStore) EndMaintenanceWindow(id int) error { func (s *SQLStore) EndMaintenanceWindow(ctx context.Context, id int) error {
_, err := s.db.Exec(s.q("UPDATE maintenance_windows SET end_time = CURRENT_TIMESTAMP WHERE id = ?"), id) _, err := s.db.ExecContext(ctx, s.q("UPDATE maintenance_windows SET end_time = CURRENT_TIMESTAMP WHERE id = ?"), id)
return err return err
} }
func (s *SQLStore) DeleteMaintenanceWindow(id int) error { func (s *SQLStore) DeleteMaintenanceWindow(ctx context.Context, id int) error {
_, err := s.db.Exec(s.q("DELETE FROM maintenance_windows WHERE id = ?"), id) _, err := s.db.ExecContext(ctx, s.q("DELETE FROM maintenance_windows WHERE id = ?"), id)
if err != nil { if err != nil {
return err return err
} }
@@ -644,9 +645,9 @@ func (s *SQLStore) DeleteMaintenanceWindow(id int) error {
return nil return nil
} }
func (s *SQLStore) PruneExpiredMaintenanceWindows(retention time.Duration) (int64, error) { func (s *SQLStore) PruneExpiredMaintenanceWindows(ctx context.Context, retention time.Duration) (int64, error) {
cutoff := time.Now().Add(-retention) cutoff := time.Now().Add(-retention)
result, err := s.db.Exec( result, err := s.db.ExecContext(ctx,
s.q("DELETE FROM maintenance_windows WHERE end_time IS NOT NULL AND end_time < ?"), s.q("DELETE FROM maintenance_windows WHERE end_time IS NOT NULL AND end_time < ?"),
cutoff, cutoff,
) )
@@ -656,9 +657,9 @@ func (s *SQLStore) PruneExpiredMaintenanceWindows(retention time.Duration) (int6
return result.RowsAffected() return result.RowsAffected()
} }
func (s *SQLStore) IsMonitorInMaintenance(monitorID int) (bool, error) { func (s *SQLStore) IsMonitorInMaintenance(ctx context.Context, monitorID int) (bool, error) {
var count int var count int
err := s.db.QueryRow(s.q(`SELECT COUNT(*) FROM maintenance_windows err := s.db.QueryRowContext(ctx, s.q(`SELECT COUNT(*) FROM maintenance_windows
WHERE type = 'maintenance' WHERE type = 'maintenance'
AND start_time <= CURRENT_TIMESTAMP AND start_time <= CURRENT_TIMESTAMP
AND (end_time IS NULL OR end_time > CURRENT_TIMESTAMP) AND (end_time IS NULL OR end_time > CURRENT_TIMESTAMP)
@@ -671,46 +672,46 @@ func (s *SQLStore) IsMonitorInMaintenance(monitorID int) (bool, error) {
return count > 0, nil return count > 0, nil
} }
func (s *SQLStore) GetPreference(key string) (string, error) { func (s *SQLStore) GetPreference(ctx context.Context, key string) (string, error) {
var value string var value string
err := s.db.QueryRow(s.q("SELECT value FROM preferences WHERE key = ?"), key).Scan(&value) err := s.db.QueryRowContext(ctx, s.q("SELECT value FROM preferences WHERE key = ?"), key).Scan(&value)
if err != nil { if err != nil {
return "", err return "", err
} }
return value, nil return value, nil
} }
func (s *SQLStore) SetPreference(key, value string) error { func (s *SQLStore) SetPreference(ctx context.Context, key, value string) error {
if s.dollar { if s.dollar {
_, err := s.db.Exec(s.q("INSERT INTO preferences (key, value) VALUES (?, ?) ON CONFLICT (key) DO UPDATE SET value = ?"), key, value, value) _, err := s.db.ExecContext(ctx, s.q("INSERT INTO preferences (key, value) VALUES (?, ?) ON CONFLICT (key) DO UPDATE SET value = ?"), key, value, value)
return err return err
} }
_, err := s.db.Exec("INSERT OR REPLACE INTO preferences (key, value) VALUES (?, ?)", key, value) _, err := s.db.ExecContext(ctx, "INSERT OR REPLACE INTO preferences (key, value) VALUES (?, ?)", key, value)
return err return err
} }
func (s *SQLStore) ExportData() (models.Backup, error) { func (s *SQLStore) ExportData(ctx context.Context) (models.Backup, error) {
sites, err := s.GetSites() sites, err := s.GetSites(ctx)
if err != nil { if err != nil {
return models.Backup{}, err return models.Backup{}, err
} }
alerts, err := s.GetAllAlerts() alerts, err := s.GetAllAlerts(ctx)
if err != nil { if err != nil {
return models.Backup{}, err return models.Backup{}, err
} }
users, err := s.GetAllUsers() users, err := s.GetAllUsers(ctx)
if err != nil { if err != nil {
return models.Backup{}, err return models.Backup{}, err
} }
windows, err := s.GetAllMaintenanceWindows(maxMaintenanceExport) windows, err := s.GetAllMaintenanceWindows(ctx, maxMaintenanceExport)
if err != nil { if err != nil {
return models.Backup{}, err return models.Backup{}, err
} }
return models.Backup{Sites: sites, Alerts: alerts, Users: users, MaintenanceWindows: windows}, nil return models.Backup{Sites: sites, Alerts: alerts, Users: users, MaintenanceWindows: windows}, nil
} }
func (s *SQLStore) ImportData(data models.Backup) error { func (s *SQLStore) ImportData(ctx context.Context, data models.Backup) error {
tx, err := s.db.Begin() tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return err
} }
@@ -719,7 +720,7 @@ func (s *SQLStore) ImportData(data models.Backup) error {
s.dialect.ImportWipe(tx) s.dialect.ImportWipe(tx)
for _, u := range data.Users { for _, u := range data.Users {
if _, err := tx.Exec(s.q("INSERT INTO users (username, public_key, role) VALUES (?, ?, ?)"), u.Username, u.PublicKey, u.Role); err != nil { if _, err := tx.ExecContext(ctx, s.q("INSERT INTO users (username, public_key, role) VALUES (?, ?, ?)"), u.Username, u.PublicKey, u.Role); err != nil {
return err return err
} }
} }
@@ -730,12 +731,12 @@ func (s *SQLStore) ImportData(data models.Backup) error {
if err != nil { if err != nil {
return err return err
} }
if _, err := tx.Exec(s.q("INSERT INTO alerts (id, name, type, settings) VALUES (?, ?, ?, ?)"), a.ID, a.Name, a.Type, settingsStr); err != nil { if _, err := tx.ExecContext(ctx, s.q("INSERT INTO alerts (id, name, type, settings) VALUES (?, ?, ?, ?)"), a.ID, a.Name, a.Type, settingsStr); err != nil {
return err return err
} }
} }
for _, st := range data.Sites { for _, st := range data.Sites {
if _, err := tx.Exec(s.q("INSERT INTO sites (id, name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused, regions) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"), if _, err := tx.ExecContext(ctx, s.q("INSERT INTO sites (id, name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused, regions) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"),
st.ID, st.Name, st.URL, st.Type, st.Token, st.Interval, st.AlertID, st.CheckSSL, st.ExpiryThreshold, st.MaxRetries, st.ID, st.Name, st.URL, st.Type, st.Token, st.Interval, st.AlertID, st.CheckSSL, st.ExpiryThreshold, st.MaxRetries,
st.Hostname, st.Port, st.Timeout, st.Method, st.Description, st.ParentID, st.AcceptedCodes, st.DNSResolveType, st.DNSServer, st.IgnoreTLS, st.Paused, st.Regions); err != nil { st.Hostname, st.Port, st.Timeout, st.Method, st.Description, st.ParentID, st.AcceptedCodes, st.DNSResolveType, st.DNSServer, st.IgnoreTLS, st.Paused, st.Regions); err != nil {
return err return err
@@ -743,7 +744,7 @@ func (s *SQLStore) ImportData(data models.Backup) error {
} }
for _, mw := range data.MaintenanceWindows { for _, mw := range data.MaintenanceWindows {
if _, err := tx.Exec(s.q("INSERT INTO maintenance_windows (id, monitor_id, title, description, type, start_time, end_time, created_by) VALUES (?, ?, ?, ?, ?, ?, ?, ?)"), if _, err := tx.ExecContext(ctx, s.q("INSERT INTO maintenance_windows (id, monitor_id, title, description, type, start_time, end_time, created_by) VALUES (?, ?, ?, ?, ?, ?, ?, ?)"),
mw.ID, mw.MonitorID, mw.Title, mw.Description, mw.Type, mw.StartTime, sql.NullTime{Time: mw.EndTime, Valid: !mw.EndTime.IsZero()}, mw.CreatedBy); err != nil { mw.ID, mw.MonitorID, mw.Title, mw.Description, mw.Type, mw.StartTime, sql.NullTime{Time: mw.EndTime, Valid: !mw.EndTime.IsZero()}, mw.CreatedBy); err != nil {
return err return err
} }
+66 -65
View File
@@ -1,6 +1,7 @@
package store package store
import ( import (
"context"
"fmt" "fmt"
"strings" "strings"
"testing" "testing"
@@ -15,7 +16,7 @@ func newTestStore(t *testing.T) *SQLStore {
if err != nil { if err != nil {
t.Fatalf("NewSQLiteStore: %v", err) t.Fatalf("NewSQLiteStore: %v", err)
} }
if err := s.Init(); err != nil { if err := s.Init(context.Background()); err != nil {
t.Fatalf("Init: %v", err) t.Fatalf("Init: %v", err)
} }
return s return s
@@ -24,7 +25,7 @@ func newTestStore(t *testing.T) *SQLStore {
func TestSiteCRUD(t *testing.T) { func TestSiteCRUD(t *testing.T) {
s := newTestStore(t) s := newTestStore(t)
sites, err := s.GetSites() sites, err := s.GetSites(context.Background())
if err != nil { if err != nil {
t.Fatalf("GetSites: %v", err) t.Fatalf("GetSites: %v", err)
} }
@@ -32,11 +33,11 @@ func TestSiteCRUD(t *testing.T) {
t.Fatalf("expected 0 sites, got %d", len(sites)) t.Fatalf("expected 0 sites, got %d", len(sites))
} }
if err := s.AddSite(models.Site{Name: "Test", URL: "https://example.com", Type: "http", Interval: 30}); err != nil { if err := s.AddSite(context.Background(), models.Site{Name: "Test", URL: "https://example.com", Type: "http", Interval: 30}); err != nil {
t.Fatalf("AddSite: %v", err) t.Fatalf("AddSite: %v", err)
} }
sites, err = s.GetSites() sites, err = s.GetSites(context.Background())
if err != nil { if err != nil {
t.Fatalf("GetSites: %v", err) t.Fatalf("GetSites: %v", err)
} }
@@ -48,11 +49,11 @@ func TestSiteCRUD(t *testing.T) {
} }
sites[0].Name = "Updated" sites[0].Name = "Updated"
if err := s.UpdateSite(sites[0]); err != nil { if err := s.UpdateSite(context.Background(), sites[0]); err != nil {
t.Fatalf("UpdateSite: %v", err) t.Fatalf("UpdateSite: %v", err)
} }
sites, err = s.GetSites() sites, err = s.GetSites(context.Background())
if err != nil { if err != nil {
t.Fatalf("GetSites: %v", err) t.Fatalf("GetSites: %v", err)
} }
@@ -60,11 +61,11 @@ func TestSiteCRUD(t *testing.T) {
t.Errorf("expected name 'Updated', got '%s'", sites[0].Name) t.Errorf("expected name 'Updated', got '%s'", sites[0].Name)
} }
if err := s.DeleteSite(sites[0].ID); err != nil { if err := s.DeleteSite(context.Background(), sites[0].ID); err != nil {
t.Fatalf("DeleteSite: %v", err) t.Fatalf("DeleteSite: %v", err)
} }
sites, err = s.GetSites() sites, err = s.GetSites(context.Background())
if err != nil { if err != nil {
t.Fatalf("GetSites: %v", err) t.Fatalf("GetSites: %v", err)
} }
@@ -76,11 +77,11 @@ func TestSiteCRUD(t *testing.T) {
func TestAlertCRUD(t *testing.T) { func TestAlertCRUD(t *testing.T) {
s := newTestStore(t) s := newTestStore(t)
if err := s.AddAlert("Discord", "discord", map[string]string{"url": "https://example.com/hook"}); err != nil { if err := s.AddAlert(context.Background(), "Discord", "discord", map[string]string{"url": "https://example.com/hook"}); err != nil {
t.Fatalf("AddAlert: %v", err) t.Fatalf("AddAlert: %v", err)
} }
alerts, err := s.GetAllAlerts() alerts, err := s.GetAllAlerts(context.Background())
if err != nil { if err != nil {
t.Fatalf("GetAllAlerts: %v", err) t.Fatalf("GetAllAlerts: %v", err)
} }
@@ -94,7 +95,7 @@ func TestAlertCRUD(t *testing.T) {
t.Errorf("settings url mismatch") t.Errorf("settings url mismatch")
} }
a, err := s.GetAlert(alerts[0].ID) a, err := s.GetAlert(context.Background(), alerts[0].ID)
if err != nil { if err != nil {
t.Fatalf("GetAlert: %v", err) t.Fatalf("GetAlert: %v", err)
} }
@@ -102,11 +103,11 @@ func TestAlertCRUD(t *testing.T) {
t.Errorf("expected name 'Discord', got '%s'", a.Name) t.Errorf("expected name 'Discord', got '%s'", a.Name)
} }
if err := s.UpdateAlert(a.ID, "Slack", "slack", map[string]string{"url": "https://slack.com/hook"}); err != nil { if err := s.UpdateAlert(context.Background(), a.ID, "Slack", "slack", map[string]string{"url": "https://slack.com/hook"}); err != nil {
t.Fatalf("UpdateAlert: %v", err) t.Fatalf("UpdateAlert: %v", err)
} }
a, err = s.GetAlert(a.ID) a, err = s.GetAlert(context.Background(), a.ID)
if err != nil { if err != nil {
t.Fatalf("GetAlert: %v", err) t.Fatalf("GetAlert: %v", err)
} }
@@ -114,11 +115,11 @@ func TestAlertCRUD(t *testing.T) {
t.Errorf("expected type 'slack', got '%s'", a.Type) t.Errorf("expected type 'slack', got '%s'", a.Type)
} }
if err := s.DeleteAlert(a.ID); err != nil { if err := s.DeleteAlert(context.Background(), a.ID); err != nil {
t.Fatalf("DeleteAlert: %v", err) t.Fatalf("DeleteAlert: %v", err)
} }
alerts, err = s.GetAllAlerts() alerts, err = s.GetAllAlerts(context.Background())
if err != nil { if err != nil {
t.Fatalf("GetAllAlerts: %v", err) t.Fatalf("GetAllAlerts: %v", err)
} }
@@ -130,11 +131,11 @@ func TestAlertCRUD(t *testing.T) {
func TestUserCRUD(t *testing.T) { func TestUserCRUD(t *testing.T) {
s := newTestStore(t) s := newTestStore(t)
if err := s.AddUser("admin", "ssh-ed25519 AAAA...", "admin"); err != nil { if err := s.AddUser(context.Background(), "admin", "ssh-ed25519 AAAA...", "admin"); err != nil {
t.Fatalf("AddUser: %v", err) t.Fatalf("AddUser: %v", err)
} }
users, err := s.GetAllUsers() users, err := s.GetAllUsers(context.Background())
if err != nil { if err != nil {
t.Fatalf("GetAllUsers: %v", err) t.Fatalf("GetAllUsers: %v", err)
} }
@@ -145,11 +146,11 @@ func TestUserCRUD(t *testing.T) {
t.Errorf("expected username 'admin', got '%s'", users[0].Username) t.Errorf("expected username 'admin', got '%s'", users[0].Username)
} }
if err := s.UpdateUser(users[0].ID, "root", "ssh-ed25519 BBBB...", "admin"); err != nil { if err := s.UpdateUser(context.Background(), users[0].ID, "root", "ssh-ed25519 BBBB...", "admin"); err != nil {
t.Fatalf("UpdateUser: %v", err) t.Fatalf("UpdateUser: %v", err)
} }
users, err = s.GetAllUsers() users, err = s.GetAllUsers(context.Background())
if err != nil { if err != nil {
t.Fatalf("GetAllUsers: %v", err) t.Fatalf("GetAllUsers: %v", err)
} }
@@ -157,11 +158,11 @@ func TestUserCRUD(t *testing.T) {
t.Errorf("expected username 'root', got '%s'", users[0].Username) t.Errorf("expected username 'root', got '%s'", users[0].Username)
} }
if err := s.DeleteUser(users[0].ID); err != nil { if err := s.DeleteUser(context.Background(), users[0].ID); err != nil {
t.Fatalf("DeleteUser: %v", err) t.Fatalf("DeleteUser: %v", err)
} }
users, err = s.GetAllUsers() users, err = s.GetAllUsers(context.Background())
if err != nil { if err != nil {
t.Fatalf("GetAllUsers: %v", err) t.Fatalf("GetAllUsers: %v", err)
} }
@@ -173,11 +174,11 @@ func TestUserCRUD(t *testing.T) {
func TestPushTokenGeneration(t *testing.T) { func TestPushTokenGeneration(t *testing.T) {
s := newTestStore(t) s := newTestStore(t)
if err := s.AddSite(models.Site{Name: "Push Monitor", Type: "push", Interval: 60}); err != nil { if err := s.AddSite(context.Background(), models.Site{Name: "Push Monitor", Type: "push", Interval: 60}); err != nil {
t.Fatalf("AddSite: %v", err) t.Fatalf("AddSite: %v", err)
} }
sites, err := s.GetSites() sites, err := s.GetSites(context.Background())
if err != nil { if err != nil {
t.Fatalf("GetSites: %v", err) t.Fatalf("GetSites: %v", err)
} }
@@ -195,17 +196,17 @@ func TestPushTokenGeneration(t *testing.T) {
func TestImportExport(t *testing.T) { func TestImportExport(t *testing.T) {
s := newTestStore(t) s := newTestStore(t)
if err := s.AddAlert("Test Alert", "webhook", map[string]string{"url": "https://example.com"}); err != nil { if err := s.AddAlert(context.Background(), "Test Alert", "webhook", map[string]string{"url": "https://example.com"}); err != nil {
t.Fatalf("AddAlert: %v", err) t.Fatalf("AddAlert: %v", err)
} }
if err := s.AddSite(models.Site{Name: "Site1", URL: "https://example.com", Type: "http", Interval: 30}); err != nil { if err := s.AddSite(context.Background(), models.Site{Name: "Site1", URL: "https://example.com", Type: "http", Interval: 30}); err != nil {
t.Fatalf("AddSite: %v", err) t.Fatalf("AddSite: %v", err)
} }
if err := s.AddUser("user1", "ssh-ed25519 KEY", "user"); err != nil { if err := s.AddUser(context.Background(), "user1", "ssh-ed25519 KEY", "user"); err != nil {
t.Fatalf("AddUser: %v", err) t.Fatalf("AddUser: %v", err)
} }
backup, err := s.ExportData() backup, err := s.ExportData(context.Background())
if err != nil { if err != nil {
t.Fatalf("ExportData: %v", err) t.Fatalf("ExportData: %v", err)
} }
@@ -214,19 +215,19 @@ func TestImportExport(t *testing.T) {
} }
s2 := newTestStore(t) s2 := newTestStore(t)
if err := s2.ImportData(backup); err != nil { if err := s2.ImportData(context.Background(), backup); err != nil {
t.Fatalf("ImportData: %v", err) t.Fatalf("ImportData: %v", err)
} }
sites, err := s2.GetSites() sites, err := s2.GetSites(context.Background())
if err != nil { if err != nil {
t.Fatalf("GetSites: %v", err) t.Fatalf("GetSites: %v", err)
} }
alerts, err := s2.GetAllAlerts() alerts, err := s2.GetAllAlerts(context.Background())
if err != nil { if err != nil {
t.Fatalf("GetAllAlerts: %v", err) t.Fatalf("GetAllAlerts: %v", err)
} }
users, err := s2.GetAllUsers() users, err := s2.GetAllUsers(context.Background())
if err != nil { if err != nil {
t.Fatalf("GetAllUsers: %v", err) t.Fatalf("GetAllUsers: %v", err)
} }
@@ -238,27 +239,27 @@ func TestImportExport(t *testing.T) {
func TestImportData_WipesHistory(t *testing.T) { func TestImportData_WipesHistory(t *testing.T) {
s := newTestStore(t) s := newTestStore(t)
if err := s.AddSite(models.Site{Name: "OldSite", URL: "https://old.com", Type: "http", Interval: 30}); err != nil { if err := s.AddSite(context.Background(), models.Site{Name: "OldSite", URL: "https://old.com", Type: "http", Interval: 30}); err != nil {
t.Fatalf("AddSite: %v", err) t.Fatalf("AddSite: %v", err)
} }
if err := s.SaveCheck(1, 5000, true); err != nil { if err := s.SaveCheck(context.Background(), 1, 5000, true); err != nil {
t.Fatalf("SaveCheck: %v", err) t.Fatalf("SaveCheck: %v", err)
} }
if err := s.SaveStateChange(1, "UP", "DOWN", "timeout"); err != nil { if err := s.SaveStateChange(context.Background(), 1, "UP", "DOWN", "timeout"); err != nil {
t.Fatalf("SaveStateChange: %v", err) t.Fatalf("SaveStateChange: %v", err)
} }
if err := s.SaveAlertHealth(models.AlertHealthRecord{AlertID: 1, LastSendOK: true, SendCount: 1}); err != nil { if err := s.SaveAlertHealth(context.Background(), models.AlertHealthRecord{AlertID: 1, LastSendOK: true, SendCount: 1}); err != nil {
t.Fatalf("SaveAlertHealth: %v", err) t.Fatalf("SaveAlertHealth: %v", err)
} }
backup := models.Backup{ backup := models.Backup{
Sites: []models.Site{{ID: 1, Name: "NewSite", URL: "https://new.com", Type: "http", Interval: 60}}, Sites: []models.Site{{ID: 1, Name: "NewSite", URL: "https://new.com", Type: "http", Interval: 60}},
} }
if err := s.ImportData(backup); err != nil { if err := s.ImportData(context.Background(), backup); err != nil {
t.Fatalf("ImportData: %v", err) t.Fatalf("ImportData: %v", err)
} }
history, err := s.LoadAllHistory(100) history, err := s.LoadAllHistory(context.Background(), 100)
if err != nil { if err != nil {
t.Fatalf("LoadAllHistory: %v", err) t.Fatalf("LoadAllHistory: %v", err)
} }
@@ -266,7 +267,7 @@ func TestImportData_WipesHistory(t *testing.T) {
t.Errorf("expected empty check_history after import, got %d sites with history", len(history)) t.Errorf("expected empty check_history after import, got %d sites with history", len(history))
} }
changes, err := s.GetStateChanges(1, 100) changes, err := s.GetStateChanges(context.Background(), 1, 100)
if err != nil { if err != nil {
t.Fatalf("GetStateChanges: %v", err) t.Fatalf("GetStateChanges: %v", err)
} }
@@ -278,17 +279,17 @@ func TestImportData_WipesHistory(t *testing.T) {
func TestCheckHistory(t *testing.T) { func TestCheckHistory(t *testing.T) {
s := newTestStore(t) s := newTestStore(t)
if err := s.SaveCheck(1, 5000000, true); err != nil { if err := s.SaveCheck(context.Background(), 1, 5000000, true); err != nil {
t.Fatalf("SaveCheck: %v", err) t.Fatalf("SaveCheck: %v", err)
} }
if err := s.SaveCheck(1, 10000000, false); err != nil { if err := s.SaveCheck(context.Background(), 1, 10000000, false); err != nil {
t.Fatalf("SaveCheck: %v", err) t.Fatalf("SaveCheck: %v", err)
} }
if err := s.SaveCheck(2, 3000000, true); err != nil { if err := s.SaveCheck(context.Background(), 2, 3000000, true); err != nil {
t.Fatalf("SaveCheck site 2: %v", err) t.Fatalf("SaveCheck site 2: %v", err)
} }
history, err := s.LoadAllHistory(10) history, err := s.LoadAllHistory(context.Background(), 10)
if err != nil { if err != nil {
t.Fatalf("LoadAllHistory: %v", err) t.Fatalf("LoadAllHistory: %v", err)
} }
@@ -314,16 +315,16 @@ func TestDeleteSiteCascade(t *testing.T) {
s := newTestStore(t) s := newTestStore(t)
site := models.Site{Name: "Cascade Test", URL: "https://example.com", Interval: 30} site := models.Site{Name: "Cascade Test", URL: "https://example.com", Interval: 30}
if err := s.AddSite(site); err != nil { if err := s.AddSite(context.Background(), site); err != nil {
t.Fatalf("AddSite: %v", err) t.Fatalf("AddSite: %v", err)
} }
sites, _ := s.GetSites() sites, _ := s.GetSites(context.Background())
siteID := sites[0].ID siteID := sites[0].ID
if err := s.SaveCheck(siteID, 1000, true); err != nil { if err := s.SaveCheck(context.Background(), siteID, 1000, true); err != nil {
t.Fatalf("SaveCheck: %v", err) t.Fatalf("SaveCheck: %v", err)
} }
if err := s.SaveStateChange(siteID, "UP", "DOWN", "timeout"); err != nil { if err := s.SaveStateChange(context.Background(), siteID, "UP", "DOWN", "timeout"); err != nil {
t.Fatalf("SaveStateChange: %v", err) t.Fatalf("SaveStateChange: %v", err)
} }
mw := models.MaintenanceWindow{ mw := models.MaintenanceWindow{
@@ -332,25 +333,25 @@ func TestDeleteSiteCascade(t *testing.T) {
Type: "maintenance", Type: "maintenance",
StartTime: time.Now(), StartTime: time.Now(),
} }
if err := s.AddMaintenanceWindow(mw); err != nil { if err := s.AddMaintenanceWindow(context.Background(), mw); err != nil {
t.Fatalf("AddMaintenanceWindow: %v", err) t.Fatalf("AddMaintenanceWindow: %v", err)
} }
if err := s.DeleteSite(siteID); err != nil { if err := s.DeleteSite(context.Background(), siteID); err != nil {
t.Fatalf("DeleteSite: %v", err) t.Fatalf("DeleteSite: %v", err)
} }
history, _ := s.LoadAllHistory(100) history, _ := s.LoadAllHistory(context.Background(), 100)
if len(history[siteID]) != 0 { if len(history[siteID]) != 0 {
t.Errorf("expected 0 check_history rows, got %d", len(history[siteID])) t.Errorf("expected 0 check_history rows, got %d", len(history[siteID]))
} }
changes, _ := s.GetStateChanges(siteID, 100) changes, _ := s.GetStateChanges(context.Background(), siteID, 100)
if len(changes) != 0 { if len(changes) != 0 {
t.Errorf("expected 0 state_changes rows, got %d", len(changes)) t.Errorf("expected 0 state_changes rows, got %d", len(changes))
} }
windows, _ := s.GetActiveMaintenanceWindows() windows, _ := s.GetActiveMaintenanceWindows(context.Background())
for _, w := range windows { for _, w := range windows {
if w.MonitorID == siteID { if w.MonitorID == siteID {
t.Errorf("orphaned maintenance window found: id=%d", w.ID) t.Errorf("orphaned maintenance window found: id=%d", w.ID)
@@ -362,15 +363,15 @@ func TestPruneLogs(t *testing.T) {
s := newTestStore(t) s := newTestStore(t)
for i := 0; i < maxLogRows+50; i++ { for i := 0; i < maxLogRows+50; i++ {
if err := s.SaveLog(fmt.Sprintf("log %d", i)); err != nil { if err := s.SaveLog(context.Background(), fmt.Sprintf("log %d", i)); err != nil {
t.Fatalf("SaveLog: %v", err) t.Fatalf("SaveLog: %v", err)
} }
} }
if err := s.PruneLogs(); err != nil { if err := s.PruneLogs(context.Background()); err != nil {
t.Fatalf("PruneLogs: %v", err) t.Fatalf("PruneLogs: %v", err)
} }
logs, err := s.LoadLogs(maxLogRows * 2) logs, err := s.LoadLogs(context.Background(), maxLogRows*2)
if err != nil { if err != nil {
t.Fatalf("LoadLogs: %v", err) t.Fatalf("LoadLogs: %v", err)
} }
@@ -395,21 +396,21 @@ func TestPruneCheckHistory(t *testing.T) {
s := newTestStore(t) s := newTestStore(t)
for i := 0; i < maxCheckHistory+5; i++ { for i := 0; i < maxCheckHistory+5; i++ {
if err := s.SaveCheck(1, int64(i), true); err != nil { if err := s.SaveCheck(context.Background(), 1, int64(i), true); err != nil {
t.Fatalf("SaveCheck site 1: %v", err) t.Fatalf("SaveCheck site 1: %v", err)
} }
} }
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
if err := s.SaveCheck(2, int64(i), true); err != nil { if err := s.SaveCheck(context.Background(), 2, int64(i), true); err != nil {
t.Fatalf("SaveCheck site 2: %v", err) t.Fatalf("SaveCheck site 2: %v", err)
} }
} }
if err := s.PruneCheckHistory(); err != nil { if err := s.PruneCheckHistory(context.Background()); err != nil {
t.Fatalf("PruneCheckHistory: %v", err) t.Fatalf("PruneCheckHistory: %v", err)
} }
history, err := s.LoadAllHistory(maxCheckHistory * 2) history, err := s.LoadAllHistory(context.Background(), maxCheckHistory*2)
if err != nil { if err != nil {
t.Fatalf("LoadAllHistory: %v", err) t.Fatalf("LoadAllHistory: %v", err)
} }
@@ -434,7 +435,7 @@ func TestPruneExpiredMaintenanceWindows(t *testing.T) {
StartTime: now.Add(-11 * 24 * time.Hour), StartTime: now.Add(-11 * 24 * time.Hour),
EndTime: now.Add(-10 * 24 * time.Hour), EndTime: now.Add(-10 * 24 * time.Hour),
} }
if err := s.AddMaintenanceWindow(old); err != nil { if err := s.AddMaintenanceWindow(context.Background(), old); err != nil {
t.Fatalf("AddMaintenanceWindow (old): %v", err) t.Fatalf("AddMaintenanceWindow (old): %v", err)
} }
@@ -446,7 +447,7 @@ func TestPruneExpiredMaintenanceWindows(t *testing.T) {
StartTime: now.Add(-2 * 24 * time.Hour), StartTime: now.Add(-2 * 24 * time.Hour),
EndTime: now.Add(-1 * 24 * time.Hour), EndTime: now.Add(-1 * 24 * time.Hour),
} }
if err := s.AddMaintenanceWindow(recent); err != nil { if err := s.AddMaintenanceWindow(context.Background(), recent); err != nil {
t.Fatalf("AddMaintenanceWindow (recent): %v", err) t.Fatalf("AddMaintenanceWindow (recent): %v", err)
} }
@@ -457,11 +458,11 @@ func TestPruneExpiredMaintenanceWindows(t *testing.T) {
Type: "maintenance", Type: "maintenance",
StartTime: now.Add(-1 * time.Hour), StartTime: now.Add(-1 * time.Hour),
} }
if err := s.AddMaintenanceWindow(ongoing); err != nil { if err := s.AddMaintenanceWindow(context.Background(), ongoing); err != nil {
t.Fatalf("AddMaintenanceWindow (ongoing): %v", err) t.Fatalf("AddMaintenanceWindow (ongoing): %v", err)
} }
pruned, err := s.PruneExpiredMaintenanceWindows(7 * 24 * time.Hour) pruned, err := s.PruneExpiredMaintenanceWindows(context.Background(), 7*24*time.Hour)
if err != nil { if err != nil {
t.Fatalf("PruneExpiredMaintenanceWindows: %v", err) t.Fatalf("PruneExpiredMaintenanceWindows: %v", err)
} }
@@ -469,7 +470,7 @@ func TestPruneExpiredMaintenanceWindows(t *testing.T) {
t.Errorf("expected 1 pruned, got %d", pruned) t.Errorf("expected 1 pruned, got %d", pruned)
} }
all, err := s.GetAllMaintenanceWindows(100) all, err := s.GetAllMaintenanceWindows(context.Background(), 100)
if err != nil { if err != nil {
t.Fatalf("GetAllMaintenanceWindows: %v", err) t.Fatalf("GetAllMaintenanceWindows: %v", err)
} }
@@ -498,7 +499,7 @@ func TestImportData_EncryptsAlertSettings(t *testing.T) {
{ID: 1, Name: "tg", Type: "telegram", Settings: map[string]string{"token": "123:SECRET", "chat_id": "42"}}, {ID: 1, Name: "tg", Type: "telegram", Settings: map[string]string{"token": "123:SECRET", "chat_id": "42"}},
}, },
} }
if err := s.ImportData(backup); err != nil { if err := s.ImportData(context.Background(), backup); err != nil {
t.Fatalf("ImportData: %v", err) t.Fatalf("ImportData: %v", err)
} }
@@ -513,7 +514,7 @@ func TestImportData_EncryptsAlertSettings(t *testing.T) {
t.Errorf("plaintext secret found in stored column: %q", raw) t.Errorf("plaintext secret found in stored column: %q", raw)
} }
alerts, err := s.GetAllAlerts() alerts, err := s.GetAllAlerts(context.Background())
if err != nil { if err != nil {
t.Fatalf("GetAllAlerts: %v", err) t.Fatalf("GetAllAlerts: %v", err)
} }
+49 -48
View File
@@ -1,84 +1,85 @@
package store package store
import ( import (
"context"
"time" "time"
"gitea.lerkolabs.com/lerkolabs/uptop/internal/models" "gitea.lerkolabs.com/lerkolabs/uptop/internal/models"
) )
type Store interface { type Store interface {
Init() error Init(ctx context.Context) error
// Sites // Sites
GetSites() ([]models.Site, error) GetSites(ctx context.Context) ([]models.Site, error)
AddSite(site models.Site) error AddSite(ctx context.Context, site models.Site) error
UpdateSite(site models.Site) error UpdateSite(ctx context.Context, site models.Site) error
UpdateSitePaused(id int, paused bool) error UpdateSitePaused(ctx context.Context, id int, paused bool) error
DeleteSite(id int) error DeleteSite(ctx context.Context, id int) error
// Alerts // Alerts
GetAllAlerts() ([]models.AlertConfig, error) GetAllAlerts(ctx context.Context) ([]models.AlertConfig, error)
GetAlert(id int) (models.AlertConfig, error) GetAlert(ctx context.Context, id int) (models.AlertConfig, error)
AddAlert(name, aType string, settings map[string]string) error AddAlert(ctx context.Context, name, aType string, settings map[string]string) error
UpdateAlert(id int, name, aType string, settings map[string]string) error UpdateAlert(ctx context.Context, id int, name, aType string, settings map[string]string) error
DeleteAlert(id int) error DeleteAlert(ctx context.Context, id int) error
// Declarative config support // Declarative config support
GetSiteByName(name string) (models.Site, error) GetSiteByName(ctx context.Context, name string) (models.Site, error)
GetAlertByName(name string) (models.AlertConfig, error) GetAlertByName(ctx context.Context, name string) (models.AlertConfig, error)
AddSiteReturningID(site models.Site) (int, error) AddSiteReturningID(ctx context.Context, site models.Site) (int, error)
AddAlertReturningID(name, aType string, settings map[string]string) (int, error) AddAlertReturningID(ctx context.Context, name, aType string, settings map[string]string) (int, error)
// Users // Users
GetAllUsers() ([]models.User, error) GetAllUsers(ctx context.Context) ([]models.User, error)
AddUser(username, publicKey, role string) error AddUser(ctx context.Context, username, publicKey, role string) error
UpdateUser(id int, username, publicKey, role string) error UpdateUser(ctx context.Context, id int, username, publicKey, role string) error
DeleteUser(id int) error DeleteUser(ctx context.Context, id int) error
// History // History
SaveCheck(siteID int, latencyNs int64, isUp bool) error SaveCheck(ctx context.Context, siteID int, latencyNs int64, isUp bool) error
SaveCheckFromNode(siteID int, nodeID string, latencyNs int64, isUp bool) error SaveCheckFromNode(ctx context.Context, siteID int, nodeID string, latencyNs int64, isUp bool) error
LoadAllHistory(limit int) (map[int][]models.CheckRecord, error) LoadAllHistory(ctx context.Context, limit int) (map[int][]models.CheckRecord, error)
PruneCheckHistory() error PruneCheckHistory(ctx context.Context) error
// State Changes // State Changes
SaveStateChange(siteID int, fromStatus, toStatus, errorReason string) error SaveStateChange(ctx context.Context, siteID int, fromStatus, toStatus, errorReason string) error
GetStateChanges(siteID int, limit int) ([]models.StateChange, error) GetStateChanges(ctx context.Context, siteID int, limit int) ([]models.StateChange, error)
GetStateChangesSince(siteID int, since time.Time) ([]models.StateChange, error) GetStateChangesSince(ctx context.Context, siteID int, since time.Time) ([]models.StateChange, error)
PruneStateChanges() error PruneStateChanges(ctx context.Context) error
// Nodes // Nodes
RegisterNode(node models.ProbeNode) error RegisterNode(ctx context.Context, node models.ProbeNode) error
GetNode(id string) (models.ProbeNode, error) GetNode(ctx context.Context, id string) (models.ProbeNode, error)
GetAllNodes() ([]models.ProbeNode, error) GetAllNodes(ctx context.Context) ([]models.ProbeNode, error)
UpdateNodeLastSeen(id string) error UpdateNodeLastSeen(ctx context.Context, id string) error
DeleteNode(id string) error DeleteNode(ctx context.Context, id string) error
// Alert Health // Alert Health
LoadAlertHealth() (map[int]models.AlertHealthRecord, error) LoadAlertHealth(ctx context.Context) (map[int]models.AlertHealthRecord, error)
SaveAlertHealth(h models.AlertHealthRecord) error SaveAlertHealth(ctx context.Context, h models.AlertHealthRecord) error
// Logs // Logs
SaveLog(message string) error SaveLog(ctx context.Context, message string) error
LoadLogs(limit int) ([]string, error) LoadLogs(ctx context.Context, limit int) ([]string, error)
PruneLogs() error PruneLogs(ctx context.Context) error
// Maintenance Windows // Maintenance Windows
GetActiveMaintenanceWindows() ([]models.MaintenanceWindow, error) GetActiveMaintenanceWindows(ctx context.Context) ([]models.MaintenanceWindow, error)
GetAllMaintenanceWindows(limit int) ([]models.MaintenanceWindow, error) GetAllMaintenanceWindows(ctx context.Context, limit int) ([]models.MaintenanceWindow, error)
AddMaintenanceWindow(mw models.MaintenanceWindow) error AddMaintenanceWindow(ctx context.Context, mw models.MaintenanceWindow) error
EndMaintenanceWindow(id int) error EndMaintenanceWindow(ctx context.Context, id int) error
DeleteMaintenanceWindow(id int) error DeleteMaintenanceWindow(ctx context.Context, id int) error
PruneExpiredMaintenanceWindows(retention time.Duration) (int64, error) PruneExpiredMaintenanceWindows(ctx context.Context, retention time.Duration) (int64, error)
IsMonitorInMaintenance(monitorID int) (bool, error) IsMonitorInMaintenance(ctx context.Context, monitorID int) (bool, error)
// Preferences // Preferences
GetPreference(key string) (string, error) GetPreference(ctx context.Context, key string) (string, error)
SetPreference(key, value string) error SetPreference(ctx context.Context, key, value string) error
// Backup & Restore // Backup & Restore
ExportData() (models.Backup, error) ExportData(ctx context.Context) (models.Backup, error)
ImportData(data models.Backup) error ImportData(ctx context.Context, data models.Backup) error
// Lifecycle // Lifecycle
Close() error Close() error
+7 -5
View File
@@ -1,6 +1,7 @@
package tui package tui
import ( import (
"context"
"encoding/json" "encoding/json"
"sort" "sort"
"strings" "strings"
@@ -13,7 +14,7 @@ import (
func loadCollapsed(s store.Store) map[int]bool { func loadCollapsed(s store.Store) map[int]bool {
m := make(map[int]bool) m := make(map[int]bool)
raw, err := s.GetPreference("collapsed_groups") raw, err := s.GetPreference(context.Background(), "collapsed_groups")
if err != nil || raw == "" { if err != nil || raw == "" {
return m return m
} }
@@ -130,21 +131,22 @@ func (m *Model) loadTabDataCmd() tea.Cmd {
st := m.store st := m.store
isAdmin := m.isAdmin isAdmin := m.isAdmin
return func() tea.Msg { return func() tea.Msg {
alerts, err := st.GetAllAlerts() ctx := context.Background()
alerts, err := st.GetAllAlerts(ctx)
if err != nil { if err != nil {
return tabDataMsg{seq: seq, err: err} return tabDataMsg{seq: seq, err: err}
} }
var users []models.User var users []models.User
if isAdmin { if isAdmin {
if users, err = st.GetAllUsers(); err != nil { if users, err = st.GetAllUsers(ctx); err != nil {
return tabDataMsg{seq: seq, err: err} return tabDataMsg{seq: seq, err: err}
} }
} }
nodes, err := st.GetAllNodes() nodes, err := st.GetAllNodes(ctx)
if err != nil { if err != nil {
return tabDataMsg{seq: seq, err: err} return tabDataMsg{seq: seq, err: err}
} }
maint, err := st.GetAllMaintenanceWindows(100) maint, err := st.GetAllMaintenanceWindows(ctx, 100)
if err != nil { if err != nil {
return tabDataMsg{seq: seq, err: err} return tabDataMsg{seq: seq, err: err}
} }
+3 -2
View File
@@ -1,6 +1,7 @@
package tui package tui
import ( import (
"context"
"fmt" "fmt"
neturl "net/url" neturl "net/url"
"sort" "sort"
@@ -528,10 +529,10 @@ func (m *Model) submitAlertForm() tea.Cmd {
m.state = stateDashboard m.state = stateDashboard
if id > 0 { if id > 0 {
return writeCmd("Update alert", func() error { return writeCmd("Update alert", func() error {
return st.UpdateAlert(id, name, aType, settings) return st.UpdateAlert(context.Background(), id, name, aType, settings)
}) })
} }
return writeCmd("Add alert", func() error { return writeCmd("Add alert", func() error {
return st.AddAlert(name, aType, settings) return st.AddAlert(context.Background(), name, aType, settings)
}) })
} }
+2 -1
View File
@@ -1,6 +1,7 @@
package tui package tui
import ( import (
"context"
"fmt" "fmt"
"strconv" "strconv"
"time" "time"
@@ -240,6 +241,6 @@ func (m *Model) submitMaintForm() tea.Cmd {
st := m.store st := m.store
m.state = stateDashboard m.state = stateDashboard
return writeCmd("Add maintenance window", func() error { return writeCmd("Add maintenance window", func() error {
return st.AddMaintenanceWindow(mw) return st.AddMaintenanceWindow(context.Background(), mw)
}) })
} }
+3 -2
View File
@@ -1,6 +1,7 @@
package tui package tui
import ( import (
"context"
"fmt" "fmt"
"net/url" "net/url"
"strconv" "strconv"
@@ -562,7 +563,7 @@ func (m *Model) submitSiteForm() tea.Cmd {
// follows in the Cmd. New sites enter the engine via its poll loop // follows in the Cmd. New sites enter the engine via its poll loop
// once the insert lands. // once the insert lands.
m.engine.UpdateSiteConfig(site) m.engine.UpdateSiteConfig(site)
return writeCmd("Update site", func() error { return st.UpdateSite(site) }) return writeCmd("Update site", func() error { return st.UpdateSite(context.Background(), site) })
} }
return writeCmd("Add site", func() error { return st.AddSite(site) }) return writeCmd("Add site", func() error { return st.AddSite(context.Background(), site) })
} }
+3 -2
View File
@@ -1,6 +1,7 @@
package tui package tui
import ( import (
"context"
"fmt" "fmt"
tea "github.com/charmbracelet/bubbletea" tea "github.com/charmbracelet/bubbletea"
@@ -118,10 +119,10 @@ func (m *Model) submitUserForm() tea.Cmd {
m.state = stateUsers m.state = stateUsers
if id > 0 { if id > 0 {
return writeCmd("Update user", func() error { return writeCmd("Update user", func() error {
return st.UpdateUser(id, username, key, role) return st.UpdateUser(context.Background(), id, username, key, role)
}) })
} }
return writeCmd("Add user", func() error { return writeCmd("Add user", func() error {
return st.AddUser(username, key, role) return st.AddUser(context.Background(), username, key, role)
}) })
} }
+2 -1
View File
@@ -1,6 +1,7 @@
package tui package tui
import ( import (
"context"
"os" "os"
"time" "time"
@@ -180,7 +181,7 @@ func InitialModel(isAdmin bool, s store.Store, eng *monitor.Engine, version stri
spring := harmonica.NewSpring(harmonica.FPS(10), 6.0, 0.4) spring := harmonica.NewSpring(harmonica.FPS(10), 6.0, 0.4)
collapsed := loadCollapsed(s) collapsed := loadCollapsed(s)
themeName, _ := s.GetPreference("theme") themeName, _ := s.GetPreference(context.Background(), "theme")
theme := themeByName(themeName) theme := themeByName(themeName)
themeIdx := 0 themeIdx := 0
for i, t := range themes { for i, t := range themes {
+9 -8
View File
@@ -1,6 +1,7 @@
package tui package tui
import ( import (
"context"
"fmt" "fmt"
"time" "time"
@@ -78,17 +79,17 @@ func (m *Model) handleConfirmDelete(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd var cmd tea.Cmd
switch m.deleteTab { switch m.deleteTab {
case 0: case 0:
cmd = writeCmd("Delete site", func() error { return st.DeleteSite(id) }) cmd = writeCmd("Delete site", func() error { return st.DeleteSite(context.Background(), id) })
m.engine.RemoveSite(id) m.engine.RemoveSite(id)
m.adjustCursor(len(m.sites) - 1) m.adjustCursor(len(m.sites) - 1)
case 1: case 1:
cmd = writeCmd("Delete alert", func() error { return st.DeleteAlert(id) }) cmd = writeCmd("Delete alert", func() error { return st.DeleteAlert(context.Background(), id) })
m.adjustCursor(len(m.alerts) - 1) m.adjustCursor(len(m.alerts) - 1)
case 4: case 4:
cmd = writeCmd("Delete maintenance window", func() error { return st.DeleteMaintenanceWindow(id) }) cmd = writeCmd("Delete maintenance window", func() error { return st.DeleteMaintenanceWindow(context.Background(), id) })
m.adjustCursor(len(m.maintenanceWindows) - 1) m.adjustCursor(len(m.maintenanceWindows) - 1)
case 5: case 5:
cmd = writeCmd("Delete user", func() error { return st.DeleteUser(id) }) cmd = writeCmd("Delete user", func() error { return st.DeleteUser(context.Background(), id) })
m.adjustCursor(len(m.users) - 1) m.adjustCursor(len(m.users) - 1)
} }
m.refreshLive() m.refreshLive()
@@ -566,7 +567,7 @@ func (m *Model) handleDashboardKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
st := m.store st := m.store
m.refreshLive() m.refreshLive()
return m, writeCmd("Save collapsed groups", func() error { return m, writeCmd("Save collapsed groups", func() error {
return st.SetPreference("collapsed_groups", payload) return st.SetPreference(context.Background(), "collapsed_groups", payload)
}) })
} }
case "p": case "p":
@@ -576,7 +577,7 @@ func (m *Model) handleDashboardKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
st := m.store st := m.store
m.refreshLive() m.refreshLive()
return m, writeCmd("Update pause state", func() error { return m, writeCmd("Update pause state", func() error {
return st.UpdateSitePaused(id, paused) return st.UpdateSitePaused(context.Background(), id, paused)
}) })
} }
case "i": case "i":
@@ -596,7 +597,7 @@ func (m *Model) handleDashboardKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
id := mw.ID id := mw.ID
m.refreshLive() m.refreshLive()
return m, writeCmd("End maintenance", func() error { return m, writeCmd("End maintenance", func() error {
return st.EndMaintenanceWindow(id) return st.EndMaintenanceWindow(context.Background(), id)
}) })
} }
} }
@@ -607,7 +608,7 @@ func (m *Model) handleDashboardKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
st := m.store st := m.store
name := m.theme.Name name := m.theme.Name
return m, writeCmd("Save theme", func() error { return m, writeCmd("Save theme", func() error {
return st.SetPreference("theme", name) return st.SetPreference(context.Background(), "theme", name)
}) })
case "d", "backspace": case "d", "backspace":
return m.handleDeleteItem() return m.handleDeleteItem()
+85 -56
View File
@@ -1,6 +1,7 @@
package tui package tui
import ( import (
"context"
"strings" "strings"
"testing" "testing"
"time" "time"
@@ -23,80 +24,108 @@ type tuiMockStore struct {
deleteSiteCalls int // counts DeleteSite hits (to prove writes run in Cmds) deleteSiteCalls int // counts DeleteSite hits (to prove writes run in Cmds)
} }
func (m *tuiMockStore) GetAllAlerts() ([]models.AlertConfig, error) { return m.alerts, nil } func (m *tuiMockStore) GetAllAlerts(_ context.Context) ([]models.AlertConfig, error) {
func (m *tuiMockStore) GetAllUsers() ([]models.User, error) { return m.users, nil } return m.alerts, nil
func (m *tuiMockStore) GetAllNodes() ([]models.ProbeNode, error) { return m.nodes, nil } }
func (m *tuiMockStore) GetStateChanges(int, int) ([]models.StateChange, error) { func (m *tuiMockStore) GetAllUsers(_ context.Context) ([]models.User, error) { return m.users, nil }
func (m *tuiMockStore) GetAllNodes(_ context.Context) ([]models.ProbeNode, error) {
return m.nodes, nil
}
func (m *tuiMockStore) GetStateChanges(_ context.Context, _ int, _ int) ([]models.StateChange, error) {
m.stateChangeCalls++ m.stateChangeCalls++
return m.stateChanges, nil return m.stateChanges, nil
} }
func (m *tuiMockStore) GetAllMaintenanceWindows(int) ([]models.MaintenanceWindow, error) { func (m *tuiMockStore) GetAllMaintenanceWindows(_ context.Context, _ int) ([]models.MaintenanceWindow, error) {
return m.maint, nil return m.maint, nil
} }
func (m *tuiMockStore) Init() error { return nil } func (m *tuiMockStore) Init(_ context.Context) error { return nil }
func (m *tuiMockStore) GetSites() ([]models.Site, error) { return nil, nil } func (m *tuiMockStore) GetSites(_ context.Context) ([]models.Site, error) { return nil, nil }
func (m *tuiMockStore) AddSite(models.Site) error { return nil } func (m *tuiMockStore) AddSite(_ context.Context, _ models.Site) error { return nil }
func (m *tuiMockStore) UpdateSite(models.Site) error { return nil } func (m *tuiMockStore) UpdateSite(_ context.Context, _ models.Site) error { return nil }
func (m *tuiMockStore) UpdateSitePaused(int, bool) error { return nil } func (m *tuiMockStore) UpdateSitePaused(_ context.Context, _ int, _ bool) error { return nil }
func (m *tuiMockStore) DeleteSite(int) error { func (m *tuiMockStore) DeleteSite(_ context.Context, _ int) error {
m.deleteSiteCalls++ m.deleteSiteCalls++
return nil return nil
} }
func (m *tuiMockStore) GetAlert(int) (models.AlertConfig, error) { return models.AlertConfig{}, nil } func (m *tuiMockStore) GetAlert(_ context.Context, _ int) (models.AlertConfig, error) {
func (m *tuiMockStore) AddAlert(string, string, map[string]string) error { return nil }
func (m *tuiMockStore) UpdateAlert(int, string, string, map[string]string) error { return nil }
func (m *tuiMockStore) DeleteAlert(int) error { return nil }
func (m *tuiMockStore) GetSiteByName(string) (models.Site, error) { return models.Site{}, nil }
func (m *tuiMockStore) GetAlertByName(string) (models.AlertConfig, error) {
return models.AlertConfig{}, nil return models.AlertConfig{}, nil
} }
func (m *tuiMockStore) AddSiteReturningID(models.Site) (int, error) { return 0, nil } func (m *tuiMockStore) AddAlert(_ context.Context, _ string, _ string, _ map[string]string) error {
func (m *tuiMockStore) AddAlertReturningID(string, string, map[string]string) (int, error) {
return 0, nil
}
func (m *tuiMockStore) AddUser(string, string, string) error { return nil }
func (m *tuiMockStore) UpdateUser(int, string, string, string) error { return nil }
func (m *tuiMockStore) DeleteUser(int) error { return nil }
func (m *tuiMockStore) SaveCheck(int, int64, bool) error { return nil }
func (m *tuiMockStore) SaveCheckFromNode(int, string, int64, bool) error {
return nil return nil
} }
func (m *tuiMockStore) LoadAllHistory(int) (map[int][]models.CheckRecord, error) { func (m *tuiMockStore) UpdateAlert(_ context.Context, _ int, _ string, _ string, _ map[string]string) error {
return nil, nil return nil
} }
func (m *tuiMockStore) PruneCheckHistory() error { return nil } func (m *tuiMockStore) DeleteAlert(_ context.Context, _ int) error { return nil }
func (m *tuiMockStore) SaveStateChange(int, string, string, string) error { return nil } func (m *tuiMockStore) GetSiteByName(_ context.Context, _ string) (models.Site, error) {
func (m *tuiMockStore) GetStateChangesSince(int, time.Time) ([]models.StateChange, error) { return models.Site{}, nil
return nil, nil
} }
func (m *tuiMockStore) PruneStateChanges() error { return nil } func (m *tuiMockStore) GetAlertByName(_ context.Context, _ string) (models.AlertConfig, error) {
func (m *tuiMockStore) RegisterNode(models.ProbeNode) error { return nil } return models.AlertConfig{}, nil
func (m *tuiMockStore) GetNode(string) (models.ProbeNode, error) { return models.ProbeNode{}, nil }
func (m *tuiMockStore) UpdateNodeLastSeen(string) error { return nil }
func (m *tuiMockStore) DeleteNode(string) error { return nil }
func (m *tuiMockStore) LoadAlertHealth() (map[int]models.AlertHealthRecord, error) {
return nil, nil
} }
func (m *tuiMockStore) SaveAlertHealth(models.AlertHealthRecord) error { return nil } func (m *tuiMockStore) AddSiteReturningID(_ context.Context, _ models.Site) (int, error) {
func (m *tuiMockStore) SaveLog(string) error { return nil }
func (m *tuiMockStore) LoadLogs(int) ([]string, error) { return nil, nil }
func (m *tuiMockStore) PruneLogs() error { return nil }
func (m *tuiMockStore) GetActiveMaintenanceWindows() ([]models.MaintenanceWindow, error) {
return nil, nil
}
func (m *tuiMockStore) AddMaintenanceWindow(models.MaintenanceWindow) error { return nil }
func (m *tuiMockStore) EndMaintenanceWindow(int) error { return nil }
func (m *tuiMockStore) DeleteMaintenanceWindow(int) error { return nil }
func (m *tuiMockStore) PruneExpiredMaintenanceWindows(time.Duration) (int64, error) {
return 0, nil return 0, nil
} }
func (m *tuiMockStore) IsMonitorInMaintenance(int) (bool, error) { return false, nil } func (m *tuiMockStore) AddAlertReturningID(_ context.Context, _ string, _ string, _ map[string]string) (int, error) {
func (m *tuiMockStore) GetPreference(string) (string, error) { return "", nil } return 0, nil
func (m *tuiMockStore) SetPreference(string, string) error { return nil } }
func (m *tuiMockStore) ExportData() (models.Backup, error) { return models.Backup{}, nil } func (m *tuiMockStore) AddUser(_ context.Context, _ string, _ string, _ string) error { return nil }
func (m *tuiMockStore) ImportData(models.Backup) error { return nil } func (m *tuiMockStore) UpdateUser(_ context.Context, _ int, _ string, _ string, _ string) error {
func (m *tuiMockStore) Close() error { return nil } return nil
}
func (m *tuiMockStore) DeleteUser(_ context.Context, _ int) error { return nil }
func (m *tuiMockStore) SaveCheck(_ context.Context, _ int, _ int64, _ bool) error { return nil }
func (m *tuiMockStore) SaveCheckFromNode(_ context.Context, _ int, _ string, _ int64, _ bool) error {
return nil
}
func (m *tuiMockStore) LoadAllHistory(_ context.Context, _ int) (map[int][]models.CheckRecord, error) {
return nil, nil
}
func (m *tuiMockStore) PruneCheckHistory(_ context.Context) error { return nil }
func (m *tuiMockStore) SaveStateChange(_ context.Context, _ int, _ string, _ string, _ string) error {
return nil
}
func (m *tuiMockStore) GetStateChangesSince(_ context.Context, _ int, _ time.Time) ([]models.StateChange, error) {
return nil, nil
}
func (m *tuiMockStore) PruneStateChanges(_ context.Context) error { return nil }
func (m *tuiMockStore) RegisterNode(_ context.Context, _ models.ProbeNode) error { return nil }
func (m *tuiMockStore) GetNode(_ context.Context, _ string) (models.ProbeNode, error) {
return models.ProbeNode{}, nil
}
func (m *tuiMockStore) UpdateNodeLastSeen(_ context.Context, _ string) error { return nil }
func (m *tuiMockStore) DeleteNode(_ context.Context, _ string) error { return nil }
func (m *tuiMockStore) LoadAlertHealth(_ context.Context) (map[int]models.AlertHealthRecord, error) {
return nil, nil
}
func (m *tuiMockStore) SaveAlertHealth(_ context.Context, _ models.AlertHealthRecord) error {
return nil
}
func (m *tuiMockStore) SaveLog(_ context.Context, _ string) error { return nil }
func (m *tuiMockStore) LoadLogs(_ context.Context, _ int) ([]string, error) { return nil, nil }
func (m *tuiMockStore) PruneLogs(_ context.Context) error { return nil }
func (m *tuiMockStore) GetActiveMaintenanceWindows(_ context.Context) ([]models.MaintenanceWindow, error) {
return nil, nil
}
func (m *tuiMockStore) AddMaintenanceWindow(_ context.Context, _ models.MaintenanceWindow) error {
return nil
}
func (m *tuiMockStore) EndMaintenanceWindow(_ context.Context, _ int) error { return nil }
func (m *tuiMockStore) DeleteMaintenanceWindow(_ context.Context, _ int) error { return nil }
func (m *tuiMockStore) PruneExpiredMaintenanceWindows(_ context.Context, _ time.Duration) (int64, error) {
return 0, nil
}
func (m *tuiMockStore) IsMonitorInMaintenance(_ context.Context, _ int) (bool, error) {
return false, nil
}
func (m *tuiMockStore) GetPreference(_ context.Context, _ string) (string, error) { return "", nil }
func (m *tuiMockStore) SetPreference(_ context.Context, _ string, _ string) error { return nil }
func (m *tuiMockStore) ExportData(_ context.Context) (models.Backup, error) {
return models.Backup{}, nil
}
func (m *tuiMockStore) ImportData(_ context.Context, _ models.Backup) error { return nil }
func (m *tuiMockStore) Close() error { return nil }
func newTestModel(ms *tuiMockStore) Model { func newTestModel(ms *tuiMockStore) Model {
return Model{ return Model{