refactor(store): propagate context.Context through all Store methods
Every Store interface method (except Close) now takes context.Context as first parameter. All 54 db.Query/Exec/QueryRow calls in SQLStore replaced with their *Context variants. DB operations now respect cancellation and deadlines. Context sources by caller: - Engine dbWriter/poll/pruner: engine ctx from Start() - HTTP handlers: r.Context() - config.Apply/Export: caller-provided ctx - TUI/main.go init: context.Background() RunCheck and all sub-checks (HTTP/ping/port/DNS) accept parent ctx. HTTP checks now inherit shutdown cancellation instead of rooting in context.Background(). dbWrite.exec takes ctx so the writer goroutine can cancel stuck DB operations. DeleteSite/ImportData use BeginTx(ctx) instead of Begin().
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
@@ -22,8 +23,8 @@ type kcMockStore struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *kcMockStore) GetAllUsers() ([]models.User, error) { return m.users, m.err }
|
||||
func (m *kcMockStore) DeleteUser(int) error { return nil }
|
||||
func (m *kcMockStore) GetAllUsers(_ context.Context) ([]models.User, error) { return m.users, m.err }
|
||||
func (m *kcMockStore) DeleteUser(_ context.Context, _ int) error { return nil }
|
||||
|
||||
func testKey(t *testing.T) (string, ssh.PublicKey) {
|
||||
t.Helper()
|
||||
@@ -103,7 +104,7 @@ func TestUserInvalidatingStore_DeleteDropsKeyCache(t *testing.T) {
|
||||
|
||||
// Revoke the user; DB unreachable immediately after. The cached key must
|
||||
// be gone the moment the delete returns.
|
||||
if err := s.DeleteUser(1); err != nil {
|
||||
if err := s.DeleteUser(context.Background(), 1); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ms.users = nil
|
||||
|
||||
+27
-25
@@ -141,7 +141,7 @@ func openStore(dbType, dsn string) store.Store {
|
||||
} else {
|
||||
fmt.Println("WARNING: No UPTOP_ENCRYPTION_KEY set. Alert credentials stored unencrypted.")
|
||||
}
|
||||
if err := ss.Init(); err != nil {
|
||||
if err := ss.Init(context.Background()); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "database init error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
@@ -171,7 +171,7 @@ func runApply(args []string) {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
changes, err := config.Apply(s, f, config.ApplyOpts{
|
||||
changes, err := config.Apply(context.Background(), s, f, config.ApplyOpts{
|
||||
DryRun: *dryRun,
|
||||
Prune: *prune,
|
||||
})
|
||||
@@ -192,7 +192,7 @@ func runExport(args []string) {
|
||||
|
||||
s := openStore(*dbType, *dsn)
|
||||
|
||||
f, err := config.Export(s)
|
||||
f, err := config.Export(context.Background(), s)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: %v\n", err)
|
||||
os.Exit(1)
|
||||
@@ -231,12 +231,12 @@ func runMigrateSecrets(args []string) {
|
||||
fmt.Fprintf(os.Stderr, "database error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if err := ss.Init(); err != nil {
|
||||
if err := ss.Init(context.Background()); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "database init error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
alerts, err := ss.GetAllAlerts()
|
||||
alerts, err := ss.GetAllAlerts(context.Background())
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error loading alerts: %v\n", err)
|
||||
os.Exit(1)
|
||||
@@ -245,7 +245,7 @@ func runMigrateSecrets(args []string) {
|
||||
ss.SetEncryptor(enc)
|
||||
migrated := 0
|
||||
for _, a := range alerts {
|
||||
if err := ss.UpdateAlert(a.ID, a.Name, a.Type, a.Settings); err != nil {
|
||||
if err := ss.UpdateAlert(context.Background(), a.ID, a.Name, a.Type, a.Settings); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error migrating alert %q: %v\n", a.Name, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
@@ -378,7 +378,7 @@ func runServe(args []string) {
|
||||
|
||||
kc := newKeyCache(ss)
|
||||
var s store.Store = &userInvalidatingStore{Store: ss, kc: kc}
|
||||
if err := s.Init(); err != nil {
|
||||
if err := s.Init(context.Background()); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "database init error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
@@ -395,7 +395,7 @@ func runServe(args []string) {
|
||||
os.Exit(1)
|
||||
}
|
||||
backup := importer.ConvertKuma(kb)
|
||||
if err := s.ImportData(backup); err != nil {
|
||||
if err := s.ImportData(context.Background(), backup); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "import failed: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
@@ -515,21 +515,22 @@ func startSSHServer(port int, db store.Store, eng *monitor.Engine, kc *keyCache)
|
||||
}
|
||||
|
||||
func seedDemoData(s store.Store) {
|
||||
existing, _ := s.GetSites()
|
||||
ctx := context.Background()
|
||||
existing, _ := s.GetSites(ctx)
|
||||
if len(existing) > 0 {
|
||||
return
|
||||
}
|
||||
fmt.Println("Seeding demo data...")
|
||||
|
||||
if err := s.AddAlert("Discord Ops", "discord", map[string]string{"url": "https://discord.com/api/webhooks/demo/token"}); err != nil {
|
||||
if err := s.AddAlert(ctx, "Discord Ops", "discord", map[string]string{"url": "https://discord.com/api/webhooks/demo/token"}); err != nil {
|
||||
log.Printf("demo seed: add alert: %v", err)
|
||||
return
|
||||
}
|
||||
if err := s.AddAlert("Slack Infra", "slack", map[string]string{"url": "https://hooks.slack.com/services/DEMO/WEBHOOK"}); err != nil {
|
||||
if err := s.AddAlert(ctx, "Slack Infra", "slack", map[string]string{"url": "https://hooks.slack.com/services/DEMO/WEBHOOK"}); err != nil {
|
||||
log.Printf("demo seed: add alert: %v", err)
|
||||
return
|
||||
}
|
||||
if err := s.AddAlert("Email Oncall", "email", map[string]string{
|
||||
if err := s.AddAlert(ctx, "Email Oncall", "email", map[string]string{
|
||||
"host": "smtp.example.com", "port": "587",
|
||||
"user": "oncall@example.com", "pass": "replace-me",
|
||||
"from": "oncall@example.com", "to": "team@example.com",
|
||||
@@ -538,7 +539,7 @@ func seedDemoData(s store.Store) {
|
||||
return
|
||||
}
|
||||
|
||||
alerts, _ := s.GetAllAlerts()
|
||||
alerts, _ := s.GetAllAlerts(ctx)
|
||||
alertID := 0
|
||||
if len(alerts) > 0 {
|
||||
alertID = alerts[0].ID
|
||||
@@ -557,7 +558,7 @@ func seedDemoData(s store.Store) {
|
||||
{Name: "SSH Server", Type: "port", Interval: 60, AlertID: alertID, Hostname: "10.0.0.1", Port: 22, Timeout: 5, ExpiryThreshold: 7},
|
||||
}
|
||||
for _, site := range demoSites {
|
||||
if err := s.AddSite(site); err != nil {
|
||||
if err := s.AddSite(ctx, site); err != nil {
|
||||
log.Printf("demo seed: add site %q: %v", site.Name, err)
|
||||
}
|
||||
}
|
||||
@@ -576,7 +577,7 @@ func newKeyCache(db store.Store) *keyCache {
|
||||
}
|
||||
|
||||
func (c *keyCache) refresh() {
|
||||
users, err := c.db.GetAllUsers()
|
||||
users, err := c.db.GetAllUsers(context.Background())
|
||||
if err != nil {
|
||||
// Keep the previous key set: a transient DB error must not lock every
|
||||
// admin out. Revocation still fails closed because Invalidate clears
|
||||
@@ -637,31 +638,32 @@ type userInvalidatingStore struct {
|
||||
kc *keyCache
|
||||
}
|
||||
|
||||
func (s *userInvalidatingStore) AddUser(username, publicKey, role string) error {
|
||||
err := s.Store.AddUser(username, publicKey, role)
|
||||
func (s *userInvalidatingStore) AddUser(ctx context.Context, username, publicKey, role string) error {
|
||||
err := s.Store.AddUser(ctx, username, publicKey, role)
|
||||
s.kc.Invalidate()
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *userInvalidatingStore) UpdateUser(id int, username, publicKey, role string) error {
|
||||
err := s.Store.UpdateUser(id, username, publicKey, role)
|
||||
func (s *userInvalidatingStore) UpdateUser(ctx context.Context, id int, username, publicKey, role string) error {
|
||||
err := s.Store.UpdateUser(ctx, id, username, publicKey, role)
|
||||
s.kc.Invalidate()
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *userInvalidatingStore) DeleteUser(id int) error {
|
||||
err := s.Store.DeleteUser(id)
|
||||
func (s *userInvalidatingStore) DeleteUser(ctx context.Context, id int) error {
|
||||
err := s.Store.DeleteUser(ctx, id)
|
||||
s.kc.Invalidate()
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *userInvalidatingStore) ImportData(data models.Backup) error {
|
||||
err := s.Store.ImportData(data)
|
||||
func (s *userInvalidatingStore) ImportData(ctx context.Context, data models.Backup) error {
|
||||
err := s.Store.ImportData(ctx, data)
|
||||
s.kc.Invalidate()
|
||||
return err
|
||||
}
|
||||
|
||||
func seedKeysFromEnv(s store.Store) {
|
||||
ctx := context.Background()
|
||||
var keys []string
|
||||
|
||||
if v := os.Getenv("UPTOP_ADMIN_KEY"); v != "" {
|
||||
@@ -687,7 +689,7 @@ func seedKeysFromEnv(s store.Store) {
|
||||
return
|
||||
}
|
||||
|
||||
existing, err := s.GetAllUsers()
|
||||
existing, err := s.GetAllUsers(ctx)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "warning: could not check existing users: %v\n", err)
|
||||
return
|
||||
@@ -705,7 +707,7 @@ func seedKeysFromEnv(s store.Store) {
|
||||
}
|
||||
|
||||
username := usernameFromKey(key, i, len(existing)+added)
|
||||
if err := s.AddUser(username, key, "admin"); err != nil {
|
||||
if err := s.AddUser(ctx, username, key, "admin"); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "warning: failed to seed user %q: %v\n", username, err)
|
||||
continue
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user