refactor(store): propagate context.Context through all Store methods

Every Store interface method (except Close) now takes context.Context
as first parameter. All 54 db.Query/Exec/QueryRow calls in SQLStore
replaced with their *Context variants. DB operations now respect
cancellation and deadlines.

Context sources by caller:
- Engine dbWriter/poll/pruner: engine ctx from Start()
- HTTP handlers: r.Context()
- config.Apply/Export: caller-provided ctx
- TUI/main.go init: context.Background()

RunCheck and all sub-checks (HTTP/ping/port/DNS) accept parent ctx.
HTTP checks now inherit shutdown cancellation instead of rooting in
context.Background(). dbWrite.exec takes ctx so the writer goroutine
can cancel stuck DB operations.

DeleteSite/ImportData use BeginTx(ctx) instead of Begin().
This commit is contained in:
2026-06-11 14:40:30 -04:00
parent 5d5153351e
commit 70a83a1da9
28 changed files with 813 additions and 677 deletions
+66 -65
View File
@@ -1,6 +1,7 @@
package store
import (
"context"
"fmt"
"strings"
"testing"
@@ -15,7 +16,7 @@ func newTestStore(t *testing.T) *SQLStore {
if err != nil {
t.Fatalf("NewSQLiteStore: %v", err)
}
if err := s.Init(); err != nil {
if err := s.Init(context.Background()); err != nil {
t.Fatalf("Init: %v", err)
}
return s
@@ -24,7 +25,7 @@ func newTestStore(t *testing.T) *SQLStore {
func TestSiteCRUD(t *testing.T) {
s := newTestStore(t)
sites, err := s.GetSites()
sites, err := s.GetSites(context.Background())
if err != nil {
t.Fatalf("GetSites: %v", err)
}
@@ -32,11 +33,11 @@ func TestSiteCRUD(t *testing.T) {
t.Fatalf("expected 0 sites, got %d", len(sites))
}
if err := s.AddSite(models.Site{Name: "Test", URL: "https://example.com", Type: "http", Interval: 30}); err != nil {
if err := s.AddSite(context.Background(), models.Site{Name: "Test", URL: "https://example.com", Type: "http", Interval: 30}); err != nil {
t.Fatalf("AddSite: %v", err)
}
sites, err = s.GetSites()
sites, err = s.GetSites(context.Background())
if err != nil {
t.Fatalf("GetSites: %v", err)
}
@@ -48,11 +49,11 @@ func TestSiteCRUD(t *testing.T) {
}
sites[0].Name = "Updated"
if err := s.UpdateSite(sites[0]); err != nil {
if err := s.UpdateSite(context.Background(), sites[0]); err != nil {
t.Fatalf("UpdateSite: %v", err)
}
sites, err = s.GetSites()
sites, err = s.GetSites(context.Background())
if err != nil {
t.Fatalf("GetSites: %v", err)
}
@@ -60,11 +61,11 @@ func TestSiteCRUD(t *testing.T) {
t.Errorf("expected name 'Updated', got '%s'", sites[0].Name)
}
if err := s.DeleteSite(sites[0].ID); err != nil {
if err := s.DeleteSite(context.Background(), sites[0].ID); err != nil {
t.Fatalf("DeleteSite: %v", err)
}
sites, err = s.GetSites()
sites, err = s.GetSites(context.Background())
if err != nil {
t.Fatalf("GetSites: %v", err)
}
@@ -76,11 +77,11 @@ func TestSiteCRUD(t *testing.T) {
func TestAlertCRUD(t *testing.T) {
s := newTestStore(t)
if err := s.AddAlert("Discord", "discord", map[string]string{"url": "https://example.com/hook"}); err != nil {
if err := s.AddAlert(context.Background(), "Discord", "discord", map[string]string{"url": "https://example.com/hook"}); err != nil {
t.Fatalf("AddAlert: %v", err)
}
alerts, err := s.GetAllAlerts()
alerts, err := s.GetAllAlerts(context.Background())
if err != nil {
t.Fatalf("GetAllAlerts: %v", err)
}
@@ -94,7 +95,7 @@ func TestAlertCRUD(t *testing.T) {
t.Errorf("settings url mismatch")
}
a, err := s.GetAlert(alerts[0].ID)
a, err := s.GetAlert(context.Background(), alerts[0].ID)
if err != nil {
t.Fatalf("GetAlert: %v", err)
}
@@ -102,11 +103,11 @@ func TestAlertCRUD(t *testing.T) {
t.Errorf("expected name 'Discord', got '%s'", a.Name)
}
if err := s.UpdateAlert(a.ID, "Slack", "slack", map[string]string{"url": "https://slack.com/hook"}); err != nil {
if err := s.UpdateAlert(context.Background(), a.ID, "Slack", "slack", map[string]string{"url": "https://slack.com/hook"}); err != nil {
t.Fatalf("UpdateAlert: %v", err)
}
a, err = s.GetAlert(a.ID)
a, err = s.GetAlert(context.Background(), a.ID)
if err != nil {
t.Fatalf("GetAlert: %v", err)
}
@@ -114,11 +115,11 @@ func TestAlertCRUD(t *testing.T) {
t.Errorf("expected type 'slack', got '%s'", a.Type)
}
if err := s.DeleteAlert(a.ID); err != nil {
if err := s.DeleteAlert(context.Background(), a.ID); err != nil {
t.Fatalf("DeleteAlert: %v", err)
}
alerts, err = s.GetAllAlerts()
alerts, err = s.GetAllAlerts(context.Background())
if err != nil {
t.Fatalf("GetAllAlerts: %v", err)
}
@@ -130,11 +131,11 @@ func TestAlertCRUD(t *testing.T) {
func TestUserCRUD(t *testing.T) {
s := newTestStore(t)
if err := s.AddUser("admin", "ssh-ed25519 AAAA...", "admin"); err != nil {
if err := s.AddUser(context.Background(), "admin", "ssh-ed25519 AAAA...", "admin"); err != nil {
t.Fatalf("AddUser: %v", err)
}
users, err := s.GetAllUsers()
users, err := s.GetAllUsers(context.Background())
if err != nil {
t.Fatalf("GetAllUsers: %v", err)
}
@@ -145,11 +146,11 @@ func TestUserCRUD(t *testing.T) {
t.Errorf("expected username 'admin', got '%s'", users[0].Username)
}
if err := s.UpdateUser(users[0].ID, "root", "ssh-ed25519 BBBB...", "admin"); err != nil {
if err := s.UpdateUser(context.Background(), users[0].ID, "root", "ssh-ed25519 BBBB...", "admin"); err != nil {
t.Fatalf("UpdateUser: %v", err)
}
users, err = s.GetAllUsers()
users, err = s.GetAllUsers(context.Background())
if err != nil {
t.Fatalf("GetAllUsers: %v", err)
}
@@ -157,11 +158,11 @@ func TestUserCRUD(t *testing.T) {
t.Errorf("expected username 'root', got '%s'", users[0].Username)
}
if err := s.DeleteUser(users[0].ID); err != nil {
if err := s.DeleteUser(context.Background(), users[0].ID); err != nil {
t.Fatalf("DeleteUser: %v", err)
}
users, err = s.GetAllUsers()
users, err = s.GetAllUsers(context.Background())
if err != nil {
t.Fatalf("GetAllUsers: %v", err)
}
@@ -173,11 +174,11 @@ func TestUserCRUD(t *testing.T) {
func TestPushTokenGeneration(t *testing.T) {
s := newTestStore(t)
if err := s.AddSite(models.Site{Name: "Push Monitor", Type: "push", Interval: 60}); err != nil {
if err := s.AddSite(context.Background(), models.Site{Name: "Push Monitor", Type: "push", Interval: 60}); err != nil {
t.Fatalf("AddSite: %v", err)
}
sites, err := s.GetSites()
sites, err := s.GetSites(context.Background())
if err != nil {
t.Fatalf("GetSites: %v", err)
}
@@ -195,17 +196,17 @@ func TestPushTokenGeneration(t *testing.T) {
func TestImportExport(t *testing.T) {
s := newTestStore(t)
if err := s.AddAlert("Test Alert", "webhook", map[string]string{"url": "https://example.com"}); err != nil {
if err := s.AddAlert(context.Background(), "Test Alert", "webhook", map[string]string{"url": "https://example.com"}); err != nil {
t.Fatalf("AddAlert: %v", err)
}
if err := s.AddSite(models.Site{Name: "Site1", URL: "https://example.com", Type: "http", Interval: 30}); err != nil {
if err := s.AddSite(context.Background(), models.Site{Name: "Site1", URL: "https://example.com", Type: "http", Interval: 30}); err != nil {
t.Fatalf("AddSite: %v", err)
}
if err := s.AddUser("user1", "ssh-ed25519 KEY", "user"); err != nil {
if err := s.AddUser(context.Background(), "user1", "ssh-ed25519 KEY", "user"); err != nil {
t.Fatalf("AddUser: %v", err)
}
backup, err := s.ExportData()
backup, err := s.ExportData(context.Background())
if err != nil {
t.Fatalf("ExportData: %v", err)
}
@@ -214,19 +215,19 @@ func TestImportExport(t *testing.T) {
}
s2 := newTestStore(t)
if err := s2.ImportData(backup); err != nil {
if err := s2.ImportData(context.Background(), backup); err != nil {
t.Fatalf("ImportData: %v", err)
}
sites, err := s2.GetSites()
sites, err := s2.GetSites(context.Background())
if err != nil {
t.Fatalf("GetSites: %v", err)
}
alerts, err := s2.GetAllAlerts()
alerts, err := s2.GetAllAlerts(context.Background())
if err != nil {
t.Fatalf("GetAllAlerts: %v", err)
}
users, err := s2.GetAllUsers()
users, err := s2.GetAllUsers(context.Background())
if err != nil {
t.Fatalf("GetAllUsers: %v", err)
}
@@ -238,27 +239,27 @@ func TestImportExport(t *testing.T) {
func TestImportData_WipesHistory(t *testing.T) {
s := newTestStore(t)
if err := s.AddSite(models.Site{Name: "OldSite", URL: "https://old.com", Type: "http", Interval: 30}); err != nil {
if err := s.AddSite(context.Background(), models.Site{Name: "OldSite", URL: "https://old.com", Type: "http", Interval: 30}); err != nil {
t.Fatalf("AddSite: %v", err)
}
if err := s.SaveCheck(1, 5000, true); err != nil {
if err := s.SaveCheck(context.Background(), 1, 5000, true); err != nil {
t.Fatalf("SaveCheck: %v", err)
}
if err := s.SaveStateChange(1, "UP", "DOWN", "timeout"); err != nil {
if err := s.SaveStateChange(context.Background(), 1, "UP", "DOWN", "timeout"); err != nil {
t.Fatalf("SaveStateChange: %v", err)
}
if err := s.SaveAlertHealth(models.AlertHealthRecord{AlertID: 1, LastSendOK: true, SendCount: 1}); err != nil {
if err := s.SaveAlertHealth(context.Background(), models.AlertHealthRecord{AlertID: 1, LastSendOK: true, SendCount: 1}); err != nil {
t.Fatalf("SaveAlertHealth: %v", err)
}
backup := models.Backup{
Sites: []models.Site{{ID: 1, Name: "NewSite", URL: "https://new.com", Type: "http", Interval: 60}},
}
if err := s.ImportData(backup); err != nil {
if err := s.ImportData(context.Background(), backup); err != nil {
t.Fatalf("ImportData: %v", err)
}
history, err := s.LoadAllHistory(100)
history, err := s.LoadAllHistory(context.Background(), 100)
if err != nil {
t.Fatalf("LoadAllHistory: %v", err)
}
@@ -266,7 +267,7 @@ func TestImportData_WipesHistory(t *testing.T) {
t.Errorf("expected empty check_history after import, got %d sites with history", len(history))
}
changes, err := s.GetStateChanges(1, 100)
changes, err := s.GetStateChanges(context.Background(), 1, 100)
if err != nil {
t.Fatalf("GetStateChanges: %v", err)
}
@@ -278,17 +279,17 @@ func TestImportData_WipesHistory(t *testing.T) {
func TestCheckHistory(t *testing.T) {
s := newTestStore(t)
if err := s.SaveCheck(1, 5000000, true); err != nil {
if err := s.SaveCheck(context.Background(), 1, 5000000, true); err != nil {
t.Fatalf("SaveCheck: %v", err)
}
if err := s.SaveCheck(1, 10000000, false); err != nil {
if err := s.SaveCheck(context.Background(), 1, 10000000, false); err != nil {
t.Fatalf("SaveCheck: %v", err)
}
if err := s.SaveCheck(2, 3000000, true); err != nil {
if err := s.SaveCheck(context.Background(), 2, 3000000, true); err != nil {
t.Fatalf("SaveCheck site 2: %v", err)
}
history, err := s.LoadAllHistory(10)
history, err := s.LoadAllHistory(context.Background(), 10)
if err != nil {
t.Fatalf("LoadAllHistory: %v", err)
}
@@ -314,16 +315,16 @@ func TestDeleteSiteCascade(t *testing.T) {
s := newTestStore(t)
site := models.Site{Name: "Cascade Test", URL: "https://example.com", Interval: 30}
if err := s.AddSite(site); err != nil {
if err := s.AddSite(context.Background(), site); err != nil {
t.Fatalf("AddSite: %v", err)
}
sites, _ := s.GetSites()
sites, _ := s.GetSites(context.Background())
siteID := sites[0].ID
if err := s.SaveCheck(siteID, 1000, true); err != nil {
if err := s.SaveCheck(context.Background(), siteID, 1000, true); err != nil {
t.Fatalf("SaveCheck: %v", err)
}
if err := s.SaveStateChange(siteID, "UP", "DOWN", "timeout"); err != nil {
if err := s.SaveStateChange(context.Background(), siteID, "UP", "DOWN", "timeout"); err != nil {
t.Fatalf("SaveStateChange: %v", err)
}
mw := models.MaintenanceWindow{
@@ -332,25 +333,25 @@ func TestDeleteSiteCascade(t *testing.T) {
Type: "maintenance",
StartTime: time.Now(),
}
if err := s.AddMaintenanceWindow(mw); err != nil {
if err := s.AddMaintenanceWindow(context.Background(), mw); err != nil {
t.Fatalf("AddMaintenanceWindow: %v", err)
}
if err := s.DeleteSite(siteID); err != nil {
if err := s.DeleteSite(context.Background(), siteID); err != nil {
t.Fatalf("DeleteSite: %v", err)
}
history, _ := s.LoadAllHistory(100)
history, _ := s.LoadAllHistory(context.Background(), 100)
if len(history[siteID]) != 0 {
t.Errorf("expected 0 check_history rows, got %d", len(history[siteID]))
}
changes, _ := s.GetStateChanges(siteID, 100)
changes, _ := s.GetStateChanges(context.Background(), siteID, 100)
if len(changes) != 0 {
t.Errorf("expected 0 state_changes rows, got %d", len(changes))
}
windows, _ := s.GetActiveMaintenanceWindows()
windows, _ := s.GetActiveMaintenanceWindows(context.Background())
for _, w := range windows {
if w.MonitorID == siteID {
t.Errorf("orphaned maintenance window found: id=%d", w.ID)
@@ -362,15 +363,15 @@ func TestPruneLogs(t *testing.T) {
s := newTestStore(t)
for i := 0; i < maxLogRows+50; i++ {
if err := s.SaveLog(fmt.Sprintf("log %d", i)); err != nil {
if err := s.SaveLog(context.Background(), fmt.Sprintf("log %d", i)); err != nil {
t.Fatalf("SaveLog: %v", err)
}
}
if err := s.PruneLogs(); err != nil {
if err := s.PruneLogs(context.Background()); err != nil {
t.Fatalf("PruneLogs: %v", err)
}
logs, err := s.LoadLogs(maxLogRows * 2)
logs, err := s.LoadLogs(context.Background(), maxLogRows*2)
if err != nil {
t.Fatalf("LoadLogs: %v", err)
}
@@ -395,21 +396,21 @@ func TestPruneCheckHistory(t *testing.T) {
s := newTestStore(t)
for i := 0; i < maxCheckHistory+5; i++ {
if err := s.SaveCheck(1, int64(i), true); err != nil {
if err := s.SaveCheck(context.Background(), 1, int64(i), true); err != nil {
t.Fatalf("SaveCheck site 1: %v", err)
}
}
for i := 0; i < 3; i++ {
if err := s.SaveCheck(2, int64(i), true); err != nil {
if err := s.SaveCheck(context.Background(), 2, int64(i), true); err != nil {
t.Fatalf("SaveCheck site 2: %v", err)
}
}
if err := s.PruneCheckHistory(); err != nil {
if err := s.PruneCheckHistory(context.Background()); err != nil {
t.Fatalf("PruneCheckHistory: %v", err)
}
history, err := s.LoadAllHistory(maxCheckHistory * 2)
history, err := s.LoadAllHistory(context.Background(), maxCheckHistory*2)
if err != nil {
t.Fatalf("LoadAllHistory: %v", err)
}
@@ -434,7 +435,7 @@ func TestPruneExpiredMaintenanceWindows(t *testing.T) {
StartTime: now.Add(-11 * 24 * time.Hour),
EndTime: now.Add(-10 * 24 * time.Hour),
}
if err := s.AddMaintenanceWindow(old); err != nil {
if err := s.AddMaintenanceWindow(context.Background(), old); err != nil {
t.Fatalf("AddMaintenanceWindow (old): %v", err)
}
@@ -446,7 +447,7 @@ func TestPruneExpiredMaintenanceWindows(t *testing.T) {
StartTime: now.Add(-2 * 24 * time.Hour),
EndTime: now.Add(-1 * 24 * time.Hour),
}
if err := s.AddMaintenanceWindow(recent); err != nil {
if err := s.AddMaintenanceWindow(context.Background(), recent); err != nil {
t.Fatalf("AddMaintenanceWindow (recent): %v", err)
}
@@ -457,11 +458,11 @@ func TestPruneExpiredMaintenanceWindows(t *testing.T) {
Type: "maintenance",
StartTime: now.Add(-1 * time.Hour),
}
if err := s.AddMaintenanceWindow(ongoing); err != nil {
if err := s.AddMaintenanceWindow(context.Background(), ongoing); err != nil {
t.Fatalf("AddMaintenanceWindow (ongoing): %v", err)
}
pruned, err := s.PruneExpiredMaintenanceWindows(7 * 24 * time.Hour)
pruned, err := s.PruneExpiredMaintenanceWindows(context.Background(), 7*24*time.Hour)
if err != nil {
t.Fatalf("PruneExpiredMaintenanceWindows: %v", err)
}
@@ -469,7 +470,7 @@ func TestPruneExpiredMaintenanceWindows(t *testing.T) {
t.Errorf("expected 1 pruned, got %d", pruned)
}
all, err := s.GetAllMaintenanceWindows(100)
all, err := s.GetAllMaintenanceWindows(context.Background(), 100)
if err != nil {
t.Fatalf("GetAllMaintenanceWindows: %v", err)
}
@@ -498,7 +499,7 @@ func TestImportData_EncryptsAlertSettings(t *testing.T) {
{ID: 1, Name: "tg", Type: "telegram", Settings: map[string]string{"token": "123:SECRET", "chat_id": "42"}},
},
}
if err := s.ImportData(backup); err != nil {
if err := s.ImportData(context.Background(), backup); err != nil {
t.Fatalf("ImportData: %v", err)
}
@@ -513,7 +514,7 @@ func TestImportData_EncryptsAlertSettings(t *testing.T) {
t.Errorf("plaintext secret found in stored column: %q", raw)
}
alerts, err := s.GetAllAlerts()
alerts, err := s.GetAllAlerts(context.Background())
if err != nil {
t.Fatalf("GetAllAlerts: %v", err)
}