From 4d5116644f51b669a003001fb2149284ddef922c Mon Sep 17 00:00:00 2001 From: Tyler Koenig Date: Fri, 15 May 2026 00:00:02 -0400 Subject: [PATCH 1/9] fix(core): correctness and robustness fixes across all subsystems - Move status page template to package-level template.Must (panic on parse error at init instead of nil deref at runtime) - Fix XSS in import error responses (log detail server-side, return generic message to client) - Handle ListenAndServe errors in HTTP and SSH servers - Use defer resp.Body.Close() in all alert providers, check json.Marshal errors - Share HTTP clients across checks instead of creating per-request - Use http.NewRequestWithContext for per-site timeout control - Support HTTP method field (was always GET despite DB storing method) - Implement AcceptedCodes validation (was hardcoded >= 400 despite DB storing accepted code ranges) - Add defer tx.Rollback() to ImportData for transaction safety --- .gitignore | 3 +- cmd/goupkeep/main.go | 6 +- internal/alert/alert.go | 23 ++- internal/monitor/monitor.go | 63 +++++++-- internal/server/server.go | 273 ++++++++++++++++++------------------ internal/store/postgres.go | 1 + internal/store/sqlite.go | 2 +- 7 files changed, 218 insertions(+), 153 deletions(-) diff --git a/.gitignore b/.gitignore index e1f6a8e..1fd01d6 100644 --- a/.gitignore +++ b/.gitignore @@ -38,4 +38,5 @@ tmp # Old repo /go-upkeep/ -*.local.json \ No newline at end of file +*.local.json +*.local.md \ No newline at end of file diff --git a/cmd/goupkeep/main.go b/cmd/goupkeep/main.go index 26c01b3..3b7d85d 100644 --- a/cmd/goupkeep/main.go +++ b/cmd/goupkeep/main.go @@ -161,7 +161,11 @@ func startSSHServer(port int) { fmt.Printf("SSH server error: %v\n", err) return } - go func() { s.ListenAndServe() }() + go func() { + if err := s.ListenAndServe(); err != nil { + log.Fatalf("SSH server failed: %v", err) + } + }() } func seedDemoData(s store.Store) { diff --git a/internal/alert/alert.go b/internal/alert/alert.go index 71b7570..16af56c 100644 --- a/internal/alert/alert.go +++ b/internal/alert/alert.go @@ -61,12 +61,15 @@ type DiscordProvider struct{ URL string } func (d *DiscordProvider) Send(title, message string) error { payload := map[string]string{"content": fmt.Sprintf("**%s**\n%s", title, message)} - jsonValue, _ := json.Marshal(payload) + jsonValue, err := json.Marshal(payload) + if err != nil { + return err + } resp, err := alertClient.Post(d.URL, "application/json", bytes.NewBuffer(jsonValue)) if err != nil { return err } - resp.Body.Close() + defer resp.Body.Close() return nil } @@ -75,12 +78,15 @@ type SlackProvider struct{ URL string } func (s *SlackProvider) Send(title, message string) error { payload := map[string]string{"text": fmt.Sprintf("*%s*\n%s", title, message)} - jsonValue, _ := json.Marshal(payload) + jsonValue, err := json.Marshal(payload) + if err != nil { + return err + } resp, err := alertClient.Post(s.URL, "application/json", bytes.NewBuffer(jsonValue)) if err != nil { return err } - resp.Body.Close() + defer resp.Body.Close() return nil } @@ -93,12 +99,15 @@ func (w *WebhookProvider) Send(title, message string) error { "message": message, "status": "alert", } - jsonValue, _ := json.Marshal(payload) + jsonValue, err := json.Marshal(payload) + if err != nil { + return err + } resp, err := alertClient.Post(w.URL, "application/json", bytes.NewBuffer(jsonValue)) if err != nil { return err } - resp.Body.Close() + defer resp.Body.Close() return nil } @@ -139,6 +148,6 @@ func (n *NtfyProvider) Send(title, message string) error { if err != nil { return err } - resp.Body.Close() + defer resp.Body.Close() return nil } diff --git a/internal/monitor/monitor.go b/internal/monitor/monitor.go index 1bb02fe..3f2a869 100644 --- a/internal/monitor/monitor.go +++ b/internal/monitor/monitor.go @@ -1,6 +1,7 @@ package monitor import ( + "context" "crypto/tls" "fmt" "go-upkeep/internal/alert" @@ -9,6 +10,7 @@ import ( "net" "net/http" "strconv" + "strings" "sync" "time" @@ -52,6 +54,13 @@ var ( activeMutex sync.RWMutex insecureSkipVerify bool + + strictClient = &http.Client{ + Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: false}}, + } + insecureClient = &http.Client{ + Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, + } ) func SetInsecureSkipVerify(skip bool) { @@ -258,15 +267,51 @@ func checkPush(site models.Site) { } } -func checkHTTP(site models.Site) { - start := time.Now() - timeout := time.Duration(site.Timeout) * time.Second - if timeout <= 0 { - timeout = 5 * time.Second +func isCodeAccepted(code int, accepted string) bool { + if accepted == "" { + return code >= 200 && code < 300 } - skipTLS := insecureSkipVerify || site.IgnoreTLS - client := &http.Client{Timeout: timeout, Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: skipTLS}}} - resp, err := client.Get(site.URL) + for _, part := range strings.Split(accepted, ",") { + part = strings.TrimSpace(part) + if strings.Contains(part, "-") { + bounds := strings.SplitN(part, "-", 2) + lo, err1 := strconv.Atoi(strings.TrimSpace(bounds[0])) + hi, err2 := strconv.Atoi(strings.TrimSpace(bounds[1])) + if err1 == nil && err2 == nil && code >= lo && code <= hi { + return true + } + } else { + if v, err := strconv.Atoi(part); err == nil && code == v { + return true + } + } + } + return false +} + +func checkHTTP(site models.Site) { + method := site.Method + if method == "" { + method = "GET" + } + + timeout := siteTimeout(site) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, method, site.URL, nil) + if err != nil { + handleStatusChange(site, "DOWN", 0, 0) + return + } + + client := strictClient + if insecureSkipVerify || site.IgnoreTLS { + client = insecureClient + } + + start := time.Now() + resp, err := client.Do(req) latency := time.Since(start) rawStatus := "UP" @@ -279,7 +324,7 @@ func checkHTTP(site models.Site) { } else { defer resp.Body.Close() rawCode = resp.StatusCode - if resp.StatusCode >= 400 { + if !isCodeAccepted(rawCode, site.AcceptedCodes) { rawStatus = "DOWN" } if site.CheckSSL && resp.TLS != nil && len(resp.TLS.PeerCertificates) > 0 { diff --git a/internal/server/server.go b/internal/server/server.go index f814cc3..3cf6228 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -8,10 +8,139 @@ import ( "go-upkeep/internal/monitor" "go-upkeep/internal/store" "html/template" + "log" "net/http" "sort" ) +var statusTpl = template.Must(template.New("status").Parse(` + + + + {{.Title}} + + + + +
+

{{.Title}}

+
+
+
+
Powered by Go-Upkeep
+
+ + +`)) + type ServerConfig struct { Port int EnableStatus bool @@ -76,7 +205,8 @@ func Start(cfg ServerConfig) { return } if err := store.Get().ImportData(data); err != nil { - http.Error(w, "Import Failed: "+err.Error(), 500) + log.Printf("Import failed: %v", err) + http.Error(w, "Import failed", 500) return } w.Write([]byte("Import Successful")) @@ -94,12 +224,14 @@ func Start(cfg ServerConfig) { } var kb importer.KumaBackup if err := json.NewDecoder(r.Body).Decode(&kb); err != nil { - http.Error(w, "Invalid Kuma JSON: "+err.Error(), 400) + log.Printf("Invalid Kuma JSON: %v", err) + http.Error(w, "Invalid Kuma JSON", 400) return } backup := importer.ConvertKuma(&kb) if err := store.Get().ImportData(backup); err != nil { - http.Error(w, "Import Failed: "+err.Error(), 500) + log.Printf("Kuma import failed: %v", err) + http.Error(w, "Import failed", 500) return } w.Write([]byte(fmt.Sprintf("Imported %d monitors, %d alerts from Kuma v%s", len(backup.Sites), len(backup.Alerts), kb.Version))) @@ -119,7 +251,9 @@ func Start(cfg ServerConfig) { go func() { addr := fmt.Sprintf(":%d", cfg.Port) fmt.Printf("HTTP Server listening on %s\n", addr) - http.ListenAndServe(addr, mux) + if err := http.ListenAndServe(addr, mux); err != nil { + log.Fatalf("HTTP server failed: %v", err) + } }() } @@ -143,138 +277,9 @@ func renderStatusPage(w http.ResponseWriter, title string) { return sites[i].Name < sites[j].Name }) - const tpl = ` - - - - {{.Title}} - - - - -
-

{{.Title}}

-
-
-
-
Powered by Go-Upkeep
-
- - - ` - - t, _ := template.New("status").Parse(tpl) data := struct { Title string Sites []models.Site }{Title: title, Sites: sites} - t.Execute(w, data) + statusTpl.Execute(w, data) } diff --git a/internal/store/postgres.go b/internal/store/postgres.go index c5201d0..94c046b 100644 --- a/internal/store/postgres.go +++ b/internal/store/postgres.go @@ -239,6 +239,7 @@ func (p *PostgresStore) ImportData(data models.Backup) error { if err != nil { return err } + defer tx.Rollback() tx.Exec("TRUNCATE TABLE sites RESTART IDENTITY CASCADE") tx.Exec("TRUNCATE TABLE alerts RESTART IDENTITY CASCADE") diff --git a/internal/store/sqlite.go b/internal/store/sqlite.go index e83d96e..1b1d5fd 100644 --- a/internal/store/sqlite.go +++ b/internal/store/sqlite.go @@ -258,8 +258,8 @@ func (s *SQLiteStore) ImportData(data models.Backup) error { if err != nil { return err } + defer tx.Rollback() - // Wipe Existing tx.Exec("DELETE FROM sites") tx.Exec("DELETE FROM sqlite_sequence WHERE name='sites'") tx.Exec("DELETE FROM alerts") From ab75f61c6bc33b07f54b2815b89aed16251802f3 Mon Sep 17 00:00:00 2001 From: Tyler Koenig Date: Fri, 15 May 2026 00:31:44 -0400 Subject: [PATCH 2/9] refactor(store): unify SQLite and Postgres into dialect-based SQLStore MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extract shared SQLStore with Dialect interface for the ~5% that differs between backends (DDL, placeholders, sequence resets). - New dialect.go: Dialect interface + placeholder rewriter (? → $N) - New sqlstore.go: single implementation of all 19 Store methods - sqlite.go: reduced from 286 to 83 lines (SQLiteDialect only) - postgres.go: reduced from 266 to 78 lines (PostgresDialect only) - main.go: use NewSQLiteStore/NewPostgresStore constructors Zero CRUD logic duplication. Every future schema change written once. --- cmd/goupkeep/main.go | 11 +- internal/store/dialect.go | 36 +++++ internal/store/postgres.go | 262 +++++---------------------------- internal/store/sqlite.go | 290 ++++++------------------------------- internal/store/sqlstore.go | 243 +++++++++++++++++++++++++++++++ 5 files changed, 368 insertions(+), 474 deletions(-) create mode 100644 internal/store/dialect.go create mode 100644 internal/store/sqlstore.go diff --git a/cmd/goupkeep/main.go b/cmd/goupkeep/main.go index 3b7d85d..cfd5c2f 100644 --- a/cmd/goupkeep/main.go +++ b/cmd/goupkeep/main.go @@ -80,16 +80,21 @@ func main() { flag.Parse() var s store.Store + var dbErr error if *flagDBType == "postgres" { - s = &store.PostgresStore{ConnStr: *flagDSN} + s, dbErr = store.NewPostgresStore(*flagDSN) fmt.Printf("Using PostgreSQL: %s\n", *flagDSN) } else { - s = &store.SQLiteStore{DBPath: *flagDSN} + s, dbErr = store.NewSQLiteStore(*flagDSN) fmt.Printf("Using SQLite: %s\n", *flagDSN) } + if dbErr != nil { + fmt.Printf("Database connection error: %v\n", dbErr) + os.Exit(1) + } if err := s.Init(); err != nil { - fmt.Printf("Database Init Error: %v\n", err) + fmt.Printf("Database init error: %v\n", err) os.Exit(1) } store.SetGlobal(s) diff --git a/internal/store/dialect.go b/internal/store/dialect.go new file mode 100644 index 0000000..4e1ba04 --- /dev/null +++ b/internal/store/dialect.go @@ -0,0 +1,36 @@ +package store + +import "database/sql" + +type Dialect interface { + DriverName() string + CreateTablesSQL() []string + MigrationsSQL() []string + BoolFalse() string + ResetSequenceOnEmpty(db *sql.DB, table string) + ImportWipe(tx *sql.Tx) + ImportResetSequences(tx *sql.Tx) +} + +// rewritePlaceholders converts ? markers to $1, $2, etc. for Postgres. +// For SQLite (or any dialect not needing rewrite), returns the input unchanged. +func rewritePlaceholders(query string, dollarStyle bool) string { + if !dollarStyle { + return query + } + buf := make([]byte, 0, len(query)+32) + n := 0 + for i := 0; i < len(query); i++ { + if query[i] == '?' { + n++ + buf = append(buf, '$') + if n >= 10 { + buf = append(buf, byte('0'+n/10)) + } + buf = append(buf, byte('0'+n%10)) + } else { + buf = append(buf, query[i]) + } + } + return string(buf) +} diff --git a/internal/store/postgres.go b/internal/store/postgres.go index 94c046b..78fcc8d 100644 --- a/internal/store/postgres.go +++ b/internal/store/postgres.go @@ -2,77 +2,53 @@ package store import ( "database/sql" - "encoding/json" - "go-upkeep/internal/models" _ "github.com/lib/pq" ) -type PostgresStore struct { - ConnStr string - db *sql.DB +type PostgresDialect struct{} + +func NewPostgresStore(connStr string) (*SQLStore, error) { + return NewSQLStore("postgres", connStr, &PostgresDialect{}) } -func (p *PostgresStore) Init() error { - var err error - p.db, err = sql.Open("postgres", p.ConnStr) - if err != nil { - return err - } +func (d *PostgresDialect) DriverName() string { return "postgres" } +func (d *PostgresDialect) BoolFalse() string { return "FALSE" } - queries := []string{ +func (d *PostgresDialect) CreateTablesSQL() []string { + return []string{ `CREATE TABLE IF NOT EXISTS alerts ( id SERIAL PRIMARY KEY, - name TEXT, - type TEXT, - settings TEXT - );`, + name TEXT, type TEXT, settings TEXT + )`, `CREATE TABLE IF NOT EXISTS sites ( id SERIAL PRIMARY KEY, - name TEXT DEFAULT 'New Monitor', - url TEXT, - type TEXT DEFAULT 'http', - token TEXT, - interval INTEGER, - alert_id INTEGER, - check_ssl BOOLEAN DEFAULT FALSE, - threshold INTEGER DEFAULT 7, - max_retries INTEGER DEFAULT 0, - hostname TEXT DEFAULT '', - port INTEGER DEFAULT 0, - timeout INTEGER DEFAULT 0, - method TEXT DEFAULT 'GET', - description TEXT DEFAULT '', - parent_id INTEGER DEFAULT 0, - accepted_codes TEXT DEFAULT '200-299', - dns_resolve_type TEXT DEFAULT '', - dns_server TEXT DEFAULT '', - ignore_tls BOOLEAN DEFAULT FALSE, - paused BOOLEAN DEFAULT FALSE - );`, + name TEXT DEFAULT 'New Monitor', url TEXT, type TEXT DEFAULT 'http', + token TEXT, interval INTEGER, alert_id INTEGER, + check_ssl BOOLEAN DEFAULT FALSE, threshold INTEGER DEFAULT 7, + max_retries INTEGER DEFAULT 0, hostname TEXT DEFAULT '', + port INTEGER DEFAULT 0, timeout INTEGER DEFAULT 0, + method TEXT DEFAULT 'GET', description TEXT DEFAULT '', + parent_id INTEGER DEFAULT 0, accepted_codes TEXT DEFAULT '200-299', + dns_resolve_type TEXT DEFAULT '', dns_server TEXT DEFAULT '', + ignore_tls BOOLEAN DEFAULT FALSE, paused BOOLEAN DEFAULT FALSE + )`, `CREATE TABLE IF NOT EXISTS users ( id SERIAL PRIMARY KEY, - username TEXT NOT NULL, - public_key TEXT NOT NULL, + username TEXT NOT NULL, public_key TEXT NOT NULL, role TEXT DEFAULT 'user' - );`, + )`, `CREATE TABLE IF NOT EXISTS check_history ( id SERIAL PRIMARY KEY, - site_id INTEGER NOT NULL, - latency_ns BIGINT, - is_up BOOLEAN, - checked_at TIMESTAMP DEFAULT NOW() - );`, - } - for _, q := range queries { - if _, err := p.db.Exec(q); err != nil { - return err - } + site_id INTEGER NOT NULL, latency_ns BIGINT, + is_up BOOLEAN, checked_at TIMESTAMP DEFAULT NOW() + )`, + `CREATE INDEX IF NOT EXISTS idx_check_history_site ON check_history(site_id, checked_at DESC)`, } +} - p.db.Exec("CREATE INDEX IF NOT EXISTS idx_check_history_site ON check_history(site_id, checked_at DESC)") - - migrations := []string{ +func (d *PostgresDialect) MigrationsSQL() []string { + return []string{ "ALTER TABLE sites ADD COLUMN IF NOT EXISTS hostname TEXT DEFAULT ''", "ALTER TABLE sites ADD COLUMN IF NOT EXISTS port INTEGER DEFAULT 0", "ALTER TABLE sites ADD COLUMN IF NOT EXISTS timeout INTEGER DEFAULT 0", @@ -85,182 +61,18 @@ func (p *PostgresStore) Init() error { "ALTER TABLE sites ADD COLUMN IF NOT EXISTS ignore_tls BOOLEAN DEFAULT FALSE", "ALTER TABLE sites ADD COLUMN IF NOT EXISTS paused BOOLEAN DEFAULT FALSE", } - for _, m := range migrations { - p.db.Exec(m) - } - - return nil } -// ... [CRUD Methods are identical to Phase 4, keeping them concise here] ... -func (p *PostgresStore) GetSites() []models.Site { - rows, err := p.db.Query("SELECT id, COALESCE(name, url), url, COALESCE(type, 'http'), COALESCE(token, ''), interval, alert_id, check_ssl, threshold, max_retries, COALESCE(hostname, ''), COALESCE(port, 0), COALESCE(timeout, 0), COALESCE(method, 'GET'), COALESCE(description, ''), COALESCE(parent_id, 0), COALESCE(accepted_codes, '200-299'), COALESCE(dns_resolve_type, ''), COALESCE(dns_server, ''), COALESCE(ignore_tls, FALSE), COALESCE(paused, FALSE) FROM sites") - if err != nil { - return []models.Site{} - } - defer rows.Close() - var sites []models.Site - for rows.Next() { - var s models.Site - rows.Scan(&s.ID, &s.Name, &s.URL, &s.Type, &s.Token, &s.Interval, &s.AlertID, &s.CheckSSL, &s.ExpiryThreshold, &s.MaxRetries, - &s.Hostname, &s.Port, &s.Timeout, &s.Method, &s.Description, &s.ParentID, &s.AcceptedCodes, &s.DNSResolveType, &s.DNSServer, &s.IgnoreTLS, &s.Paused) - sites = append(sites, s) - } - return sites -} -func (p *PostgresStore) AddSite(site models.Site) { - token := "" - if site.Type == "push" { - token = generateToken() - } - p.db.Exec("INSERT INTO sites (name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20)", - site.Name, site.URL, site.Type, token, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries, - site.Hostname, site.Port, site.Timeout, site.Method, site.Description, site.ParentID, site.AcceptedCodes, site.DNSResolveType, site.DNSServer, site.IgnoreTLS, site.Paused) -} -func (p *PostgresStore) UpdateSite(site models.Site) { - var existingToken string - p.db.QueryRow("SELECT token FROM sites WHERE id=$1", site.ID).Scan(&existingToken) - if site.Type == "push" && existingToken == "" { - existingToken = generateToken() - } - p.db.Exec("UPDATE sites SET name=$1, url=$2, type=$3, token=$4, interval=$5, alert_id=$6, check_ssl=$7, threshold=$8, max_retries=$9, hostname=$10, port=$11, timeout=$12, method=$13, description=$14, parent_id=$15, accepted_codes=$16, dns_resolve_type=$17, dns_server=$18, ignore_tls=$19, paused=$20 WHERE id=$21", - site.Name, site.URL, site.Type, existingToken, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries, - site.Hostname, site.Port, site.Timeout, site.Method, site.Description, site.ParentID, site.AcceptedCodes, site.DNSResolveType, site.DNSServer, site.IgnoreTLS, site.Paused, site.ID) -} -func (p *PostgresStore) UpdateSitePaused(id int, paused bool) { - p.db.Exec("UPDATE sites SET paused=$1 WHERE id=$2", paused, id) -} -func (p *PostgresStore) DeleteSite(id int) { p.db.Exec("DELETE FROM sites WHERE id=$1", id) } -func (p *PostgresStore) GetAllAlerts() []models.AlertConfig { - rows, err := p.db.Query("SELECT id, name, type, settings FROM alerts") - if err != nil { - return []models.AlertConfig{} - } - defer rows.Close() - var alerts []models.AlertConfig - for rows.Next() { - var a models.AlertConfig - var settingsJSON string - rows.Scan(&a.ID, &a.Name, &a.Type, &settingsJSON) - json.Unmarshal([]byte(settingsJSON), &a.Settings) - alerts = append(alerts, a) - } - return alerts -} -func (p *PostgresStore) GetAlert(id int) (models.AlertConfig, bool) { - var a models.AlertConfig - var settingsJSON string - err := p.db.QueryRow("SELECT id, name, type, settings FROM alerts WHERE id = $1", id).Scan(&a.ID, &a.Name, &a.Type, &settingsJSON) - if err != nil { - return a, false - } - json.Unmarshal([]byte(settingsJSON), &a.Settings) - return a, true -} -func (p *PostgresStore) AddAlert(name, aType string, settings map[string]string) { - jsonBytes, _ := json.Marshal(settings) - p.db.Exec("INSERT INTO alerts (name, type, settings) VALUES ($1, $2, $3)", name, aType, string(jsonBytes)) -} -func (p *PostgresStore) UpdateAlert(id int, name, aType string, settings map[string]string) { - jsonBytes, _ := json.Marshal(settings) - p.db.Exec("UPDATE alerts SET name=$1, type=$2, settings=$3 WHERE id=$4", name, aType, string(jsonBytes), id) -} -func (p *PostgresStore) DeleteAlert(id int) { p.db.Exec("DELETE FROM alerts WHERE id=$1", id) } -func (p *PostgresStore) GetAllUsers() []models.User { - rows, err := p.db.Query("SELECT id, username, public_key, role FROM users") - if err != nil { - return []models.User{} - } - defer rows.Close() - var users []models.User - for rows.Next() { - var u models.User - rows.Scan(&u.ID, &u.Username, &u.PublicKey, &u.Role) - users = append(users, u) - } - return users -} -func (p *PostgresStore) AddUser(username, publicKey, role string) error { - _, err := p.db.Exec("INSERT INTO users (username, public_key, role) VALUES ($1, $2, $3)", username, publicKey, role) - return err -} -func (p *PostgresStore) UpdateUser(id int, username, publicKey, role string) error { - _, err := p.db.Exec("UPDATE users SET username=$1, public_key=$2, role=$3 WHERE id=$4", username, publicKey, role, id) - return err -} -func (p *PostgresStore) DeleteUser(id int) error { - _, err := p.db.Exec("DELETE FROM users WHERE id=$1", id) - return err -} - -func (p *PostgresStore) SaveCheck(siteID int, latencyNs int64, isUp bool) { - p.db.Exec("INSERT INTO check_history (site_id, latency_ns, is_up) VALUES ($1, $2, $3)", siteID, latencyNs, isUp) - p.db.Exec(`DELETE FROM check_history WHERE site_id = $1 AND id NOT IN ( - SELECT id FROM check_history WHERE site_id = $1 ORDER BY checked_at DESC LIMIT 1000 - )`, siteID) -} - -func (p *PostgresStore) LoadAllHistory(limit int) map[int][]models.CheckRecord { - result := make(map[int][]models.CheckRecord) - rows, err := p.db.Query(` - SELECT site_id, latency_ns, is_up FROM ( - SELECT site_id, latency_ns, is_up, - ROW_NUMBER() OVER (PARTITION BY site_id ORDER BY checked_at DESC) AS rn - FROM check_history - ) sub WHERE rn <= $1`, limit) - if err != nil { - return result - } - defer rows.Close() - for rows.Next() { - var r models.CheckRecord - rows.Scan(&r.SiteID, &r.LatencyNs, &r.IsUp) - result[r.SiteID] = append(result[r.SiteID], r) - } - for id, records := range result { - for i, j := 0, len(records)-1; i < j; i, j = i+1, j-1 { - records[i], records[j] = records[j], records[i] - } - result[id] = records - } - return result -} - -func (p *PostgresStore) ExportData() models.Backup { - return models.Backup{ - Sites: p.GetSites(), - Alerts: p.GetAllAlerts(), - Users: p.GetAllUsers(), - } -} - -func (p *PostgresStore) ImportData(data models.Backup) error { - tx, err := p.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() +func (d *PostgresDialect) ResetSequenceOnEmpty(db *sql.DB, table string) {} +func (d *PostgresDialect) ImportWipe(tx *sql.Tx) { tx.Exec("TRUNCATE TABLE sites RESTART IDENTITY CASCADE") tx.Exec("TRUNCATE TABLE alerts RESTART IDENTITY CASCADE") tx.Exec("TRUNCATE TABLE users RESTART IDENTITY CASCADE") - - for _, u := range data.Users { - tx.Exec("INSERT INTO users (username, public_key, role) VALUES ($1, $2, $3)", u.Username, u.PublicKey, u.Role) - } - for _, a := range data.Alerts { - jsonBytes, _ := json.Marshal(a.Settings) - tx.Exec("INSERT INTO alerts (id, name, type, settings) VALUES ($1, $2, $3, $4)", a.ID, a.Name, a.Type, string(jsonBytes)) - } - for _, st := range data.Sites { - tx.Exec("INSERT INTO sites (id, name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21)", - st.ID, st.Name, st.URL, st.Type, st.Token, st.Interval, st.AlertID, st.CheckSSL, st.ExpiryThreshold, st.MaxRetries, - st.Hostname, st.Port, st.Timeout, st.Method, st.Description, st.ParentID, st.AcceptedCodes, st.DNSResolveType, st.DNSServer, st.IgnoreTLS, st.Paused) - } - - tx.Exec("SELECT setval('sites_id_seq', (SELECT MAX(id) FROM sites))") - tx.Exec("SELECT setval('alerts_id_seq', (SELECT MAX(id) FROM alerts))") - tx.Exec("SELECT setval('users_id_seq', (SELECT MAX(id) FROM users))") - - return tx.Commit() +} + +func (d *PostgresDialect) ImportResetSequences(tx *sql.Tx) { + tx.Exec("SELECT setval('sites_id_seq', (SELECT COALESCE(MAX(id), 1) FROM sites))") + tx.Exec("SELECT setval('alerts_id_seq', (SELECT COALESCE(MAX(id), 1) FROM alerts))") + tx.Exec("SELECT setval('users_id_seq', (SELECT COALESCE(MAX(id), 1) FROM users))") } diff --git a/internal/store/sqlite.go b/internal/store/sqlite.go index 1b1d5fd..dbeb74d 100644 --- a/internal/store/sqlite.go +++ b/internal/store/sqlite.go @@ -1,77 +1,54 @@ package store import ( - "crypto/rand" "database/sql" - "encoding/hex" - "encoding/json" - "go-upkeep/internal/models" _ "github.com/mattn/go-sqlite3" ) -type SQLiteStore struct { - DBPath string - db *sql.DB +type SQLiteDialect struct{} + +func NewSQLiteStore(path string) (*SQLStore, error) { + return NewSQLStore("sqlite3", path, &SQLiteDialect{}) } -func (s *SQLiteStore) Init() error { - var err error - s.db, err = sql.Open("sqlite3", s.DBPath) - if err != nil { - return err - } +func (d *SQLiteDialect) DriverName() string { return "sqlite3" } +func (d *SQLiteDialect) BoolFalse() string { return "0" } - createTables := ` - CREATE TABLE IF NOT EXISTS alerts ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT, - type TEXT, - settings TEXT - ); - CREATE TABLE IF NOT EXISTS sites ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT DEFAULT 'New Monitor', - url TEXT, - type TEXT DEFAULT 'http', - token TEXT, - interval INTEGER, - alert_id INTEGER, - check_ssl BOOLEAN DEFAULT 0, - threshold INTEGER DEFAULT 7, - max_retries INTEGER DEFAULT 0, - hostname TEXT DEFAULT '', - port INTEGER DEFAULT 0, - timeout INTEGER DEFAULT 0, - method TEXT DEFAULT 'GET', - description TEXT DEFAULT '', - parent_id INTEGER DEFAULT 0, - accepted_codes TEXT DEFAULT '200-299', - dns_resolve_type TEXT DEFAULT '', - dns_server TEXT DEFAULT '', - ignore_tls BOOLEAN DEFAULT 0, - paused BOOLEAN DEFAULT 0 - ); - CREATE TABLE IF NOT EXISTS users ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - username TEXT NOT NULL, - public_key TEXT NOT NULL, - role TEXT DEFAULT 'user' - ); - CREATE TABLE IF NOT EXISTS check_history ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - site_id INTEGER NOT NULL, - latency_ns INTEGER, - is_up BOOLEAN, - checked_at DATETIME DEFAULT CURRENT_TIMESTAMP - ); - CREATE INDEX IF NOT EXISTS idx_check_history_site ON check_history(site_id, checked_at DESC);` - _, err = s.db.Exec(createTables) - if err != nil { - return err +func (d *SQLiteDialect) CreateTablesSQL() []string { + return []string{ + `CREATE TABLE IF NOT EXISTS alerts ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT, type TEXT, settings TEXT + )`, + `CREATE TABLE IF NOT EXISTS sites ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT DEFAULT 'New Monitor', url TEXT, type TEXT DEFAULT 'http', + token TEXT, interval INTEGER, alert_id INTEGER, + check_ssl BOOLEAN DEFAULT 0, threshold INTEGER DEFAULT 7, + max_retries INTEGER DEFAULT 0, hostname TEXT DEFAULT '', + port INTEGER DEFAULT 0, timeout INTEGER DEFAULT 0, + method TEXT DEFAULT 'GET', description TEXT DEFAULT '', + parent_id INTEGER DEFAULT 0, accepted_codes TEXT DEFAULT '200-299', + dns_resolve_type TEXT DEFAULT '', dns_server TEXT DEFAULT '', + ignore_tls BOOLEAN DEFAULT 0, paused BOOLEAN DEFAULT 0 + )`, + `CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + username TEXT NOT NULL, public_key TEXT NOT NULL, + role TEXT DEFAULT 'user' + )`, + `CREATE TABLE IF NOT EXISTS check_history ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + site_id INTEGER NOT NULL, latency_ns INTEGER, + is_up BOOLEAN, checked_at DATETIME DEFAULT CURRENT_TIMESTAMP + )`, + `CREATE INDEX IF NOT EXISTS idx_check_history_site ON check_history(site_id, checked_at DESC)`, } +} - migrations := []string{ +func (d *SQLiteDialect) MigrationsSQL() []string { + return []string{ "ALTER TABLE sites ADD COLUMN hostname TEXT DEFAULT ''", "ALTER TABLE sites ADD COLUMN port INTEGER DEFAULT 0", "ALTER TABLE sites ADD COLUMN timeout INTEGER DEFAULT 0", @@ -84,202 +61,23 @@ func (s *SQLiteStore) Init() error { "ALTER TABLE sites ADD COLUMN ignore_tls BOOLEAN DEFAULT 0", "ALTER TABLE sites ADD COLUMN paused BOOLEAN DEFAULT 0", } - for _, m := range migrations { - s.db.Exec(m) - } - - return nil } -func generateToken() string { - b := make([]byte, 16) - if _, err := rand.Read(b); err != nil { - panic("crypto/rand failed: " + err.Error()) - } - return hex.EncodeToString(b) -} - -func (s *SQLiteStore) GetSites() []models.Site { - rows, err := s.db.Query("SELECT id, COALESCE(name, url), url, COALESCE(type, 'http'), COALESCE(token, ''), interval, alert_id, check_ssl, threshold, max_retries, COALESCE(hostname, ''), COALESCE(port, 0), COALESCE(timeout, 0), COALESCE(method, 'GET'), COALESCE(description, ''), COALESCE(parent_id, 0), COALESCE(accepted_codes, '200-299'), COALESCE(dns_resolve_type, ''), COALESCE(dns_server, ''), COALESCE(ignore_tls, 0), COALESCE(paused, 0) FROM sites") - if err != nil { - return []models.Site{} - } - defer rows.Close() - var sites []models.Site - for rows.Next() { - var st models.Site - rows.Scan(&st.ID, &st.Name, &st.URL, &st.Type, &st.Token, &st.Interval, &st.AlertID, &st.CheckSSL, &st.ExpiryThreshold, &st.MaxRetries, &st.Hostname, &st.Port, &st.Timeout, &st.Method, &st.Description, &st.ParentID, &st.AcceptedCodes, &st.DNSResolveType, &st.DNSServer, &st.IgnoreTLS, &st.Paused) - sites = append(sites, st) - } - return sites -} -func (s *SQLiteStore) AddSite(site models.Site) { - token := "" - if site.Type == "push" { - token = generateToken() - } - s.db.Exec("INSERT INTO sites (name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", - site.Name, site.URL, site.Type, token, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries, - site.Hostname, site.Port, site.Timeout, site.Method, site.Description, site.ParentID, site.AcceptedCodes, site.DNSResolveType, site.DNSServer, site.IgnoreTLS, site.Paused) -} -func (s *SQLiteStore) UpdateSite(site models.Site) { - var existingToken string - s.db.QueryRow("SELECT token FROM sites WHERE id=?", site.ID).Scan(&existingToken) - if site.Type == "push" && existingToken == "" { - existingToken = generateToken() - } - s.db.Exec("UPDATE sites SET name=?, url=?, type=?, token=?, interval=?, alert_id=?, check_ssl=?, threshold=?, max_retries=?, hostname=?, port=?, timeout=?, method=?, description=?, parent_id=?, accepted_codes=?, dns_resolve_type=?, dns_server=?, ignore_tls=?, paused=? WHERE id=?", - site.Name, site.URL, site.Type, existingToken, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries, - site.Hostname, site.Port, site.Timeout, site.Method, site.Description, site.ParentID, site.AcceptedCodes, site.DNSResolveType, site.DNSServer, site.IgnoreTLS, site.Paused, site.ID) -} -func (s *SQLiteStore) UpdateSitePaused(id int, paused bool) { - s.db.Exec("UPDATE sites SET paused=? WHERE id=?", paused, id) -} -func (s *SQLiteStore) DeleteSite(id int) { - s.db.Exec("DELETE FROM sites WHERE id=?", id) +func (d *SQLiteDialect) ResetSequenceOnEmpty(db *sql.DB, table string) { var count int - s.db.QueryRow("SELECT COUNT(*) FROM sites").Scan(&count) + db.QueryRow("SELECT COUNT(*) FROM " + table).Scan(&count) if count == 0 { - s.db.Exec("DELETE FROM sqlite_sequence WHERE name='sites'") - } -} -func (s *SQLiteStore) GetAllAlerts() []models.AlertConfig { - rows, err := s.db.Query("SELECT id, name, type, settings FROM alerts") - if err != nil { - return []models.AlertConfig{} - } - defer rows.Close() - var alerts []models.AlertConfig - for rows.Next() { - var a models.AlertConfig - var settingsJSON string - rows.Scan(&a.ID, &a.Name, &a.Type, &settingsJSON) - json.Unmarshal([]byte(settingsJSON), &a.Settings) - alerts = append(alerts, a) - } - return alerts -} -func (s *SQLiteStore) GetAlert(id int) (models.AlertConfig, bool) { - var a models.AlertConfig - var settingsJSON string - err := s.db.QueryRow("SELECT id, name, type, settings FROM alerts WHERE id = ?", id).Scan(&a.ID, &a.Name, &a.Type, &settingsJSON) - if err != nil { - return a, false - } - json.Unmarshal([]byte(settingsJSON), &a.Settings) - return a, true -} -func (s *SQLiteStore) AddAlert(name, aType string, settings map[string]string) { - jsonBytes, _ := json.Marshal(settings) - s.db.Exec("INSERT INTO alerts (name, type, settings) VALUES (?, ?, ?)", name, aType, string(jsonBytes)) -} -func (s *SQLiteStore) UpdateAlert(id int, name, aType string, settings map[string]string) { - jsonBytes, _ := json.Marshal(settings) - s.db.Exec("UPDATE alerts SET name=?, type=?, settings=? WHERE id=?", name, aType, string(jsonBytes), id) -} -func (s *SQLiteStore) DeleteAlert(id int) { - s.db.Exec("DELETE FROM alerts WHERE id=?", id) - var count int - s.db.QueryRow("SELECT COUNT(*) FROM alerts").Scan(&count) - if count == 0 { - s.db.Exec("DELETE FROM sqlite_sequence WHERE name='alerts'") - } -} -func (s *SQLiteStore) GetAllUsers() []models.User { - rows, err := s.db.Query("SELECT id, username, public_key, role FROM users") - if err != nil { - return []models.User{} - } - defer rows.Close() - var users []models.User - for rows.Next() { - var u models.User - rows.Scan(&u.ID, &u.Username, &u.PublicKey, &u.Role) - users = append(users, u) - } - return users -} -func (s *SQLiteStore) AddUser(username, publicKey, role string) error { - _, err := s.db.Exec("INSERT INTO users (username, public_key, role) VALUES (?, ?, ?)", username, publicKey, role) - return err -} -func (s *SQLiteStore) UpdateUser(id int, username, publicKey, role string) error { - _, err := s.db.Exec("UPDATE users SET username=?, public_key=?, role=? WHERE id=?", username, publicKey, role, id) - return err -} -func (s *SQLiteStore) DeleteUser(id int) error { - _, err := s.db.Exec("DELETE FROM users WHERE id=?", id) - return err -} - -func (s *SQLiteStore) SaveCheck(siteID int, latencyNs int64, isUp bool) { - s.db.Exec("INSERT INTO check_history (site_id, latency_ns, is_up) VALUES (?, ?, ?)", siteID, latencyNs, isUp) - s.db.Exec(`DELETE FROM check_history WHERE site_id = ? AND id NOT IN ( - SELECT id FROM check_history WHERE site_id = ? ORDER BY checked_at DESC LIMIT 1000 - )`, siteID, siteID) -} - -func (s *SQLiteStore) LoadAllHistory(limit int) map[int][]models.CheckRecord { - result := make(map[int][]models.CheckRecord) - rows, err := s.db.Query(` - SELECT site_id, latency_ns, is_up FROM ( - SELECT site_id, latency_ns, is_up, - ROW_NUMBER() OVER (PARTITION BY site_id ORDER BY checked_at DESC) AS rn - FROM check_history - ) WHERE rn <= ?`, limit) - if err != nil { - return result - } - defer rows.Close() - for rows.Next() { - var r models.CheckRecord - rows.Scan(&r.SiteID, &r.LatencyNs, &r.IsUp) - result[r.SiteID] = append(result[r.SiteID], r) - } - for id, records := range result { - for i, j := 0, len(records)-1; i < j; i, j = i+1, j-1 { - records[i], records[j] = records[j], records[i] - } - result[id] = records - } - return result -} - -func (s *SQLiteStore) ExportData() models.Backup { - return models.Backup{ - Sites: s.GetSites(), - Alerts: s.GetAllAlerts(), - Users: s.GetAllUsers(), + db.Exec("DELETE FROM sqlite_sequence WHERE name=?", table) } } -func (s *SQLiteStore) ImportData(data models.Backup) error { - tx, err := s.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - +func (d *SQLiteDialect) ImportWipe(tx *sql.Tx) { tx.Exec("DELETE FROM sites") tx.Exec("DELETE FROM sqlite_sequence WHERE name='sites'") tx.Exec("DELETE FROM alerts") tx.Exec("DELETE FROM sqlite_sequence WHERE name='alerts'") tx.Exec("DELETE FROM users") tx.Exec("DELETE FROM sqlite_sequence WHERE name='users'") - - // Insert New - for _, u := range data.Users { - tx.Exec("INSERT INTO users (username, public_key, role) VALUES (?, ?, ?)", u.Username, u.PublicKey, u.Role) - } - for _, a := range data.Alerts { - jsonBytes, _ := json.Marshal(a.Settings) - tx.Exec("INSERT INTO alerts (id, name, type, settings) VALUES (?, ?, ?, ?)", a.ID, a.Name, a.Type, string(jsonBytes)) - } - for _, st := range data.Sites { - tx.Exec("INSERT INTO sites (id, name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", - st.ID, st.Name, st.URL, st.Type, st.Token, st.Interval, st.AlertID, st.CheckSSL, st.ExpiryThreshold, st.MaxRetries, - st.Hostname, st.Port, st.Timeout, st.Method, st.Description, st.ParentID, st.AcceptedCodes, st.DNSResolveType, st.DNSServer, st.IgnoreTLS, st.Paused) - } - - return tx.Commit() } + +func (d *SQLiteDialect) ImportResetSequences(tx *sql.Tx) {} diff --git a/internal/store/sqlstore.go b/internal/store/sqlstore.go new file mode 100644 index 0000000..a715a02 --- /dev/null +++ b/internal/store/sqlstore.go @@ -0,0 +1,243 @@ +package store + +import ( + "crypto/rand" + "database/sql" + "encoding/hex" + "encoding/json" + "fmt" + "go-upkeep/internal/models" +) + +type SQLStore struct { + db *sql.DB + dialect Dialect + dollar bool +} + +func NewSQLStore(driverName, dsn string, dialect Dialect) (*SQLStore, error) { + db, err := sql.Open(driverName, dsn) + if err != nil { + return nil, err + } + _, isDollar := dialect.(*PostgresDialect) + return &SQLStore{db: db, dialect: dialect, dollar: isDollar}, nil +} + +func (s *SQLStore) q(query string) string { + return rewritePlaceholders(query, s.dollar) +} + +func generateToken() string { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + panic("crypto/rand failed: " + err.Error()) + } + return hex.EncodeToString(b) +} + +func (s *SQLStore) Init() error { + for _, stmt := range s.dialect.CreateTablesSQL() { + if _, err := s.db.Exec(stmt); err != nil { + return err + } + } + for _, m := range s.dialect.MigrationsSQL() { + s.db.Exec(m) + } + return nil +} + +func (s *SQLStore) GetSites() []models.Site { + bf := s.dialect.BoolFalse() + query := fmt.Sprintf( + "SELECT id, COALESCE(name, url), url, COALESCE(type, 'http'), COALESCE(token, ''), interval, alert_id, check_ssl, threshold, max_retries, COALESCE(hostname, ''), COALESCE(port, 0), COALESCE(timeout, 0), COALESCE(method, 'GET'), COALESCE(description, ''), COALESCE(parent_id, 0), COALESCE(accepted_codes, '200-299'), COALESCE(dns_resolve_type, ''), COALESCE(dns_server, ''), COALESCE(ignore_tls, %s), COALESCE(paused, %s) FROM sites", + bf, bf, + ) + rows, err := s.db.Query(query) + if err != nil { + return []models.Site{} + } + defer rows.Close() + var sites []models.Site + for rows.Next() { + var st models.Site + rows.Scan(&st.ID, &st.Name, &st.URL, &st.Type, &st.Token, &st.Interval, &st.AlertID, + &st.CheckSSL, &st.ExpiryThreshold, &st.MaxRetries, &st.Hostname, &st.Port, &st.Timeout, + &st.Method, &st.Description, &st.ParentID, &st.AcceptedCodes, &st.DNSResolveType, + &st.DNSServer, &st.IgnoreTLS, &st.Paused) + sites = append(sites, st) + } + return sites +} + +func (s *SQLStore) AddSite(site models.Site) { + token := "" + if site.Type == "push" { + token = generateToken() + } + s.db.Exec(s.q("INSERT INTO sites (name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"), + site.Name, site.URL, site.Type, token, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries, + site.Hostname, site.Port, site.Timeout, site.Method, site.Description, site.ParentID, site.AcceptedCodes, site.DNSResolveType, site.DNSServer, site.IgnoreTLS, site.Paused) +} + +func (s *SQLStore) UpdateSite(site models.Site) { + var existingToken string + s.db.QueryRow(s.q("SELECT token FROM sites WHERE id=?"), site.ID).Scan(&existingToken) + if site.Type == "push" && existingToken == "" { + existingToken = generateToken() + } + s.db.Exec(s.q("UPDATE sites SET name=?, url=?, type=?, token=?, interval=?, alert_id=?, check_ssl=?, threshold=?, max_retries=?, hostname=?, port=?, timeout=?, method=?, description=?, parent_id=?, accepted_codes=?, dns_resolve_type=?, dns_server=?, ignore_tls=?, paused=? WHERE id=?"), + site.Name, site.URL, site.Type, existingToken, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries, + site.Hostname, site.Port, site.Timeout, site.Method, site.Description, site.ParentID, site.AcceptedCodes, site.DNSResolveType, site.DNSServer, site.IgnoreTLS, site.Paused, site.ID) +} + +func (s *SQLStore) UpdateSitePaused(id int, paused bool) { + s.db.Exec(s.q("UPDATE sites SET paused=? WHERE id=?"), paused, id) +} + +func (s *SQLStore) DeleteSite(id int) { + s.db.Exec(s.q("DELETE FROM sites WHERE id=?"), id) + s.dialect.ResetSequenceOnEmpty(s.db, "sites") +} + +func (s *SQLStore) GetAllAlerts() []models.AlertConfig { + rows, err := s.db.Query("SELECT id, name, type, settings FROM alerts") + if err != nil { + return []models.AlertConfig{} + } + defer rows.Close() + var alerts []models.AlertConfig + for rows.Next() { + var a models.AlertConfig + var settingsJSON string + rows.Scan(&a.ID, &a.Name, &a.Type, &settingsJSON) + json.Unmarshal([]byte(settingsJSON), &a.Settings) + alerts = append(alerts, a) + } + return alerts +} + +func (s *SQLStore) GetAlert(id int) (models.AlertConfig, bool) { + var a models.AlertConfig + var settingsJSON string + err := s.db.QueryRow(s.q("SELECT id, name, type, settings FROM alerts WHERE id = ?"), id).Scan(&a.ID, &a.Name, &a.Type, &settingsJSON) + if err != nil { + return a, false + } + json.Unmarshal([]byte(settingsJSON), &a.Settings) + return a, true +} + +func (s *SQLStore) AddAlert(name, aType string, settings map[string]string) { + jsonBytes, _ := json.Marshal(settings) + s.db.Exec(s.q("INSERT INTO alerts (name, type, settings) VALUES (?, ?, ?)"), name, aType, string(jsonBytes)) +} + +func (s *SQLStore) UpdateAlert(id int, name, aType string, settings map[string]string) { + jsonBytes, _ := json.Marshal(settings) + s.db.Exec(s.q("UPDATE alerts SET name=?, type=?, settings=? WHERE id=?"), name, aType, string(jsonBytes), id) +} + +func (s *SQLStore) DeleteAlert(id int) { + s.db.Exec(s.q("DELETE FROM alerts WHERE id=?"), id) + s.dialect.ResetSequenceOnEmpty(s.db, "alerts") +} + +func (s *SQLStore) GetAllUsers() []models.User { + rows, err := s.db.Query("SELECT id, username, public_key, role FROM users") + if err != nil { + return []models.User{} + } + defer rows.Close() + var users []models.User + for rows.Next() { + var u models.User + rows.Scan(&u.ID, &u.Username, &u.PublicKey, &u.Role) + users = append(users, u) + } + return users +} + +func (s *SQLStore) AddUser(username, publicKey, role string) error { + _, err := s.db.Exec(s.q("INSERT INTO users (username, public_key, role) VALUES (?, ?, ?)"), username, publicKey, role) + return err +} + +func (s *SQLStore) UpdateUser(id int, username, publicKey, role string) error { + _, err := s.db.Exec(s.q("UPDATE users SET username=?, public_key=?, role=? WHERE id=?"), username, publicKey, role, id) + return err +} + +func (s *SQLStore) DeleteUser(id int) error { + _, err := s.db.Exec(s.q("DELETE FROM users WHERE id=?"), id) + return err +} + +func (s *SQLStore) SaveCheck(siteID int, latencyNs int64, isUp bool) { + s.db.Exec(s.q("INSERT INTO check_history (site_id, latency_ns, is_up) VALUES (?, ?, ?)"), siteID, latencyNs, isUp) + s.db.Exec(s.q(`DELETE FROM check_history WHERE site_id = ? AND id NOT IN ( + SELECT id FROM check_history WHERE site_id = ? ORDER BY checked_at DESC LIMIT 1000 + )`), siteID, siteID) +} + +func (s *SQLStore) LoadAllHistory(limit int) map[int][]models.CheckRecord { + result := make(map[int][]models.CheckRecord) + rows, err := s.db.Query(s.q(` + SELECT site_id, latency_ns, is_up FROM ( + SELECT site_id, latency_ns, is_up, + ROW_NUMBER() OVER (PARTITION BY site_id ORDER BY checked_at DESC) AS rn + FROM check_history + ) sub WHERE rn <= ?`), limit) + if err != nil { + return result + } + defer rows.Close() + for rows.Next() { + var r models.CheckRecord + rows.Scan(&r.SiteID, &r.LatencyNs, &r.IsUp) + result[r.SiteID] = append(result[r.SiteID], r) + } + for id, records := range result { + for i, j := 0, len(records)-1; i < j; i, j = i+1, j-1 { + records[i], records[j] = records[j], records[i] + } + result[id] = records + } + return result +} + +func (s *SQLStore) ExportData() models.Backup { + return models.Backup{ + Sites: s.GetSites(), + Alerts: s.GetAllAlerts(), + Users: s.GetAllUsers(), + } +} + +func (s *SQLStore) ImportData(data models.Backup) error { + tx, err := s.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + s.dialect.ImportWipe(tx) + + for _, u := range data.Users { + tx.Exec(s.q("INSERT INTO users (username, public_key, role) VALUES (?, ?, ?)"), u.Username, u.PublicKey, u.Role) + } + for _, a := range data.Alerts { + jsonBytes, _ := json.Marshal(a.Settings) + tx.Exec(s.q("INSERT INTO alerts (id, name, type, settings) VALUES (?, ?, ?, ?)"), a.ID, a.Name, a.Type, string(jsonBytes)) + } + for _, st := range data.Sites { + tx.Exec(s.q("INSERT INTO sites (id, name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"), + st.ID, st.Name, st.URL, st.Type, st.Token, st.Interval, st.AlertID, st.CheckSSL, st.ExpiryThreshold, st.MaxRetries, + st.Hostname, st.Port, st.Timeout, st.Method, st.Description, st.ParentID, st.AcceptedCodes, st.DNSResolveType, st.DNSServer, st.IgnoreTLS, st.Paused) + } + + s.dialect.ImportResetSequences(tx) + + return tx.Commit() +} From d4f4012c8a812d4399835ffead72463c1f66f3f5 Mon Sep 17 00:00:00 2001 From: Tyler Koenig Date: Fri, 15 May 2026 00:37:20 -0400 Subject: [PATCH 3/9] refactor(store): add error returns to all Store interface methods MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Every Store method now returns an error. Callers handle errors gracefully — TUI logs to event log, server returns HTTP 500, monitor engine logs and retries. All rows.Scan() errors are now checked in sqlstore.go instead of silently appending corrupt data. - GetSites, GetAllAlerts, GetAllUsers return ([]T, error) - GetAlert returns (AlertConfig, error) instead of (AlertConfig, bool) - AddSite, UpdateSite, DeleteSite, etc. all return error - SaveCheck, LoadAllHistory, ExportData return error - ~25 caller sites updated across tui, server, monitor, main --- cmd/goupkeep/main.go | 10 ++- internal/monitor/history.go | 8 +- internal/monitor/monitor.go | 11 ++- internal/server/server.go | 7 +- internal/store/sqlstore.go | 146 ++++++++++++++++++++++++------------ internal/store/store.go | 28 +++---- internal/tui/tab_alerts.go | 9 ++- internal/tui/tab_sites.go | 22 ++++-- internal/tui/tab_users.go | 9 ++- internal/tui/tui.go | 28 ++++--- 10 files changed, 185 insertions(+), 93 deletions(-) diff --git a/cmd/goupkeep/main.go b/cmd/goupkeep/main.go index cfd5c2f..77cdadd 100644 --- a/cmd/goupkeep/main.go +++ b/cmd/goupkeep/main.go @@ -174,7 +174,8 @@ func startSSHServer(port int) { } func seedDemoData(s store.Store) { - if existing := s.GetSites(); len(existing) > 0 { + existing, _ := s.GetSites() + if len(existing) > 0 { return } fmt.Println("Seeding demo data...") @@ -187,7 +188,7 @@ func seedDemoData(s store.Store) { "from": "oncall@example.com", "to": "team@example.com", }) - alerts := s.GetAllAlerts() + alerts, _ := s.GetAllAlerts() alertID := 0 if len(alerts) > 0 { alertID = alerts[0].ID @@ -206,7 +207,10 @@ func seedDemoData(s store.Store) { } func isKeyAllowed(incomingKey ssh.PublicKey) bool { - users := store.Get().GetAllUsers() + users, err := store.Get().GetAllUsers() + if err != nil { + return false + } for _, u := range users { allowedKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(u.PublicKey)) if err != nil { diff --git a/internal/monitor/history.go b/internal/monitor/history.go index 8642255..43ef0ee 100644 --- a/internal/monitor/history.go +++ b/internal/monitor/history.go @@ -25,7 +25,11 @@ func InitHistoryFromStore() { if s == nil { return } - all := s.LoadAllHistory(maxHistoryLen) + all, err := s.LoadAllHistory(maxHistoryLen) + if err != nil { + AddLog("Failed to load check history: " + err.Error()) + return + } historyMu.Lock() defer historyMu.Unlock() for siteID, records := range all { @@ -71,7 +75,7 @@ func RecordCheck(siteID int, latency time.Duration, isUp bool) { } if s := store.Get(); s != nil { - go s.SaveCheck(siteID, latency.Nanoseconds(), isUp) + go func() { _ = s.SaveCheck(siteID, latency.Nanoseconds(), isUp) }() } } diff --git a/internal/monitor/monitor.go b/internal/monitor/monitor.go index 3f2a869..27925ed 100644 --- a/internal/monitor/monitor.go +++ b/internal/monitor/monitor.go @@ -128,7 +128,12 @@ func StartEngine() { continue } - sites := s_instance.GetSites() + sites, err := s_instance.GetSites() + if err != nil { + AddLog(fmt.Sprintf("Failed to load sites: %v", err)) + time.Sleep(5 * time.Second) + continue + } for _, s := range sites { Mutex.RLock() _, exists := LiveState[s.ID] @@ -406,8 +411,8 @@ func triggerAlert(alertID int, title, message string) { if s_instance == nil { return } - cfg, ok := s_instance.GetAlert(alertID) - if !ok { + cfg, err := s_instance.GetAlert(alertID) + if err != nil { return } provider := alert.GetProvider(cfg) diff --git a/internal/server/server.go b/internal/server/server.go index 3cf6228..bb97a88 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -185,7 +185,12 @@ func Start(cfg ServerConfig) { http.Error(w, "Unauthorized: UPKEEP_CLUSTER_SECRET required", 401) return } - data := store.Get().ExportData() + data, err := store.Get().ExportData() + if err != nil { + log.Printf("Export failed: %v", err) + http.Error(w, "Export failed", 500) + return + } json.NewEncoder(w).Encode(data) }) diff --git a/internal/store/sqlstore.go b/internal/store/sqlstore.go index a715a02..3142399 100644 --- a/internal/store/sqlstore.go +++ b/internal/store/sqlstore.go @@ -48,7 +48,7 @@ func (s *SQLStore) Init() error { return nil } -func (s *SQLStore) GetSites() []models.Site { +func (s *SQLStore) GetSites() ([]models.Site, error) { bf := s.dialect.BoolFalse() query := fmt.Sprintf( "SELECT id, COALESCE(name, url), url, COALESCE(type, 'http'), COALESCE(token, ''), interval, alert_id, check_ssl, threshold, max_retries, COALESCE(hostname, ''), COALESCE(port, 0), COALESCE(timeout, 0), COALESCE(method, 'GET'), COALESCE(description, ''), COALESCE(parent_id, 0), COALESCE(accepted_codes, '200-299'), COALESCE(dns_resolve_type, ''), COALESCE(dns_server, ''), COALESCE(ignore_tls, %s), COALESCE(paused, %s) FROM sites", @@ -56,107 +56,132 @@ func (s *SQLStore) GetSites() []models.Site { ) rows, err := s.db.Query(query) if err != nil { - return []models.Site{} + return nil, err } defer rows.Close() var sites []models.Site for rows.Next() { var st models.Site - rows.Scan(&st.ID, &st.Name, &st.URL, &st.Type, &st.Token, &st.Interval, &st.AlertID, + if err := rows.Scan(&st.ID, &st.Name, &st.URL, &st.Type, &st.Token, &st.Interval, &st.AlertID, &st.CheckSSL, &st.ExpiryThreshold, &st.MaxRetries, &st.Hostname, &st.Port, &st.Timeout, &st.Method, &st.Description, &st.ParentID, &st.AcceptedCodes, &st.DNSResolveType, - &st.DNSServer, &st.IgnoreTLS, &st.Paused) + &st.DNSServer, &st.IgnoreTLS, &st.Paused); err != nil { + return sites, err + } sites = append(sites, st) } - return sites + return sites, rows.Err() } -func (s *SQLStore) AddSite(site models.Site) { +func (s *SQLStore) AddSite(site models.Site) error { token := "" if site.Type == "push" { token = generateToken() } - s.db.Exec(s.q("INSERT INTO sites (name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"), + _, err := s.db.Exec(s.q("INSERT INTO sites (name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"), site.Name, site.URL, site.Type, token, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries, site.Hostname, site.Port, site.Timeout, site.Method, site.Description, site.ParentID, site.AcceptedCodes, site.DNSResolveType, site.DNSServer, site.IgnoreTLS, site.Paused) + return err } -func (s *SQLStore) UpdateSite(site models.Site) { +func (s *SQLStore) UpdateSite(site models.Site) error { var existingToken string s.db.QueryRow(s.q("SELECT token FROM sites WHERE id=?"), site.ID).Scan(&existingToken) if site.Type == "push" && existingToken == "" { existingToken = generateToken() } - s.db.Exec(s.q("UPDATE sites SET name=?, url=?, type=?, token=?, interval=?, alert_id=?, check_ssl=?, threshold=?, max_retries=?, hostname=?, port=?, timeout=?, method=?, description=?, parent_id=?, accepted_codes=?, dns_resolve_type=?, dns_server=?, ignore_tls=?, paused=? WHERE id=?"), + _, err := s.db.Exec(s.q("UPDATE sites SET name=?, url=?, type=?, token=?, interval=?, alert_id=?, check_ssl=?, threshold=?, max_retries=?, hostname=?, port=?, timeout=?, method=?, description=?, parent_id=?, accepted_codes=?, dns_resolve_type=?, dns_server=?, ignore_tls=?, paused=? WHERE id=?"), site.Name, site.URL, site.Type, existingToken, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries, site.Hostname, site.Port, site.Timeout, site.Method, site.Description, site.ParentID, site.AcceptedCodes, site.DNSResolveType, site.DNSServer, site.IgnoreTLS, site.Paused, site.ID) + return err } -func (s *SQLStore) UpdateSitePaused(id int, paused bool) { - s.db.Exec(s.q("UPDATE sites SET paused=? WHERE id=?"), paused, id) +func (s *SQLStore) UpdateSitePaused(id int, paused bool) error { + _, err := s.db.Exec(s.q("UPDATE sites SET paused=? WHERE id=?"), paused, id) + return err } -func (s *SQLStore) DeleteSite(id int) { - s.db.Exec(s.q("DELETE FROM sites WHERE id=?"), id) +func (s *SQLStore) DeleteSite(id int) error { + _, err := s.db.Exec(s.q("DELETE FROM sites WHERE id=?"), id) + if err != nil { + return err + } s.dialect.ResetSequenceOnEmpty(s.db, "sites") + return nil } -func (s *SQLStore) GetAllAlerts() []models.AlertConfig { +func (s *SQLStore) GetAllAlerts() ([]models.AlertConfig, error) { rows, err := s.db.Query("SELECT id, name, type, settings FROM alerts") if err != nil { - return []models.AlertConfig{} + return nil, err } defer rows.Close() var alerts []models.AlertConfig for rows.Next() { var a models.AlertConfig var settingsJSON string - rows.Scan(&a.ID, &a.Name, &a.Type, &settingsJSON) + if err := rows.Scan(&a.ID, &a.Name, &a.Type, &settingsJSON); err != nil { + return alerts, err + } json.Unmarshal([]byte(settingsJSON), &a.Settings) alerts = append(alerts, a) } - return alerts + return alerts, rows.Err() } -func (s *SQLStore) GetAlert(id int) (models.AlertConfig, bool) { +func (s *SQLStore) GetAlert(id int) (models.AlertConfig, error) { var a models.AlertConfig var settingsJSON string err := s.db.QueryRow(s.q("SELECT id, name, type, settings FROM alerts WHERE id = ?"), id).Scan(&a.ID, &a.Name, &a.Type, &settingsJSON) if err != nil { - return a, false + return a, err } json.Unmarshal([]byte(settingsJSON), &a.Settings) - return a, true + return a, nil } -func (s *SQLStore) AddAlert(name, aType string, settings map[string]string) { - jsonBytes, _ := json.Marshal(settings) - s.db.Exec(s.q("INSERT INTO alerts (name, type, settings) VALUES (?, ?, ?)"), name, aType, string(jsonBytes)) +func (s *SQLStore) AddAlert(name, aType string, settings map[string]string) error { + jsonBytes, err := json.Marshal(settings) + if err != nil { + return err + } + _, err = s.db.Exec(s.q("INSERT INTO alerts (name, type, settings) VALUES (?, ?, ?)"), name, aType, string(jsonBytes)) + return err } -func (s *SQLStore) UpdateAlert(id int, name, aType string, settings map[string]string) { - jsonBytes, _ := json.Marshal(settings) - s.db.Exec(s.q("UPDATE alerts SET name=?, type=?, settings=? WHERE id=?"), name, aType, string(jsonBytes), id) +func (s *SQLStore) UpdateAlert(id int, name, aType string, settings map[string]string) error { + jsonBytes, err := json.Marshal(settings) + if err != nil { + return err + } + _, err = s.db.Exec(s.q("UPDATE alerts SET name=?, type=?, settings=? WHERE id=?"), name, aType, string(jsonBytes), id) + return err } -func (s *SQLStore) DeleteAlert(id int) { - s.db.Exec(s.q("DELETE FROM alerts WHERE id=?"), id) +func (s *SQLStore) DeleteAlert(id int) error { + _, err := s.db.Exec(s.q("DELETE FROM alerts WHERE id=?"), id) + if err != nil { + return err + } s.dialect.ResetSequenceOnEmpty(s.db, "alerts") + return nil } -func (s *SQLStore) GetAllUsers() []models.User { +func (s *SQLStore) GetAllUsers() ([]models.User, error) { rows, err := s.db.Query("SELECT id, username, public_key, role FROM users") if err != nil { - return []models.User{} + return nil, err } defer rows.Close() var users []models.User for rows.Next() { var u models.User - rows.Scan(&u.ID, &u.Username, &u.PublicKey, &u.Role) + if err := rows.Scan(&u.ID, &u.Username, &u.PublicKey, &u.Role); err != nil { + return users, err + } users = append(users, u) } - return users + return users, rows.Err() } func (s *SQLStore) AddUser(username, publicKey, role string) error { @@ -174,14 +199,18 @@ func (s *SQLStore) DeleteUser(id int) error { return err } -func (s *SQLStore) SaveCheck(siteID int, latencyNs int64, isUp bool) { - s.db.Exec(s.q("INSERT INTO check_history (site_id, latency_ns, is_up) VALUES (?, ?, ?)"), siteID, latencyNs, isUp) - s.db.Exec(s.q(`DELETE FROM check_history WHERE site_id = ? AND id NOT IN ( +func (s *SQLStore) SaveCheck(siteID int, latencyNs int64, isUp bool) error { + _, err := s.db.Exec(s.q("INSERT INTO check_history (site_id, latency_ns, is_up) VALUES (?, ?, ?)"), siteID, latencyNs, isUp) + if err != nil { + return err + } + _, err = s.db.Exec(s.q(`DELETE FROM check_history WHERE site_id = ? AND id NOT IN ( SELECT id FROM check_history WHERE site_id = ? ORDER BY checked_at DESC LIMIT 1000 )`), siteID, siteID) + return err } -func (s *SQLStore) LoadAllHistory(limit int) map[int][]models.CheckRecord { +func (s *SQLStore) LoadAllHistory(limit int) (map[int][]models.CheckRecord, error) { result := make(map[int][]models.CheckRecord) rows, err := s.db.Query(s.q(` SELECT site_id, latency_ns, is_up FROM ( @@ -190,12 +219,14 @@ func (s *SQLStore) LoadAllHistory(limit int) map[int][]models.CheckRecord { FROM check_history ) sub WHERE rn <= ?`), limit) if err != nil { - return result + return result, err } defer rows.Close() for rows.Next() { var r models.CheckRecord - rows.Scan(&r.SiteID, &r.LatencyNs, &r.IsUp) + if err := rows.Scan(&r.SiteID, &r.LatencyNs, &r.IsUp); err != nil { + return result, err + } result[r.SiteID] = append(result[r.SiteID], r) } for id, records := range result { @@ -204,15 +235,23 @@ func (s *SQLStore) LoadAllHistory(limit int) map[int][]models.CheckRecord { } result[id] = records } - return result + return result, rows.Err() } -func (s *SQLStore) ExportData() models.Backup { - return models.Backup{ - Sites: s.GetSites(), - Alerts: s.GetAllAlerts(), - Users: s.GetAllUsers(), +func (s *SQLStore) ExportData() (models.Backup, error) { + sites, err := s.GetSites() + if err != nil { + return models.Backup{}, err } + alerts, err := s.GetAllAlerts() + if err != nil { + return models.Backup{}, err + } + users, err := s.GetAllUsers() + if err != nil { + return models.Backup{}, err + } + return models.Backup{Sites: sites, Alerts: alerts, Users: users}, nil } func (s *SQLStore) ImportData(data models.Backup) error { @@ -225,16 +264,25 @@ func (s *SQLStore) ImportData(data models.Backup) error { s.dialect.ImportWipe(tx) for _, u := range data.Users { - tx.Exec(s.q("INSERT INTO users (username, public_key, role) VALUES (?, ?, ?)"), u.Username, u.PublicKey, u.Role) + if _, err := tx.Exec(s.q("INSERT INTO users (username, public_key, role) VALUES (?, ?, ?)"), u.Username, u.PublicKey, u.Role); err != nil { + return err + } } for _, a := range data.Alerts { - jsonBytes, _ := json.Marshal(a.Settings) - tx.Exec(s.q("INSERT INTO alerts (id, name, type, settings) VALUES (?, ?, ?, ?)"), a.ID, a.Name, a.Type, string(jsonBytes)) + jsonBytes, err := json.Marshal(a.Settings) + if err != nil { + return err + } + if _, err := tx.Exec(s.q("INSERT INTO alerts (id, name, type, settings) VALUES (?, ?, ?, ?)"), a.ID, a.Name, a.Type, string(jsonBytes)); err != nil { + return err + } } for _, st := range data.Sites { - tx.Exec(s.q("INSERT INTO sites (id, name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"), + if _, err := tx.Exec(s.q("INSERT INTO sites (id, name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"), st.ID, st.Name, st.URL, st.Type, st.Token, st.Interval, st.AlertID, st.CheckSSL, st.ExpiryThreshold, st.MaxRetries, - st.Hostname, st.Port, st.Timeout, st.Method, st.Description, st.ParentID, st.AcceptedCodes, st.DNSResolveType, st.DNSServer, st.IgnoreTLS, st.Paused) + st.Hostname, st.Port, st.Timeout, st.Method, st.Description, st.ParentID, st.AcceptedCodes, st.DNSResolveType, st.DNSServer, st.IgnoreTLS, st.Paused); err != nil { + return err + } } s.dialect.ImportResetSequences(tx) diff --git a/internal/store/store.go b/internal/store/store.go index af3cd71..d119597 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -8,31 +8,31 @@ type Store interface { Init() error // Sites - GetSites() []models.Site - AddSite(site models.Site) - UpdateSite(site models.Site) - UpdateSitePaused(id int, paused bool) - DeleteSite(id int) + GetSites() ([]models.Site, error) + AddSite(site models.Site) error + UpdateSite(site models.Site) error + UpdateSitePaused(id int, paused bool) error + DeleteSite(id int) error // Alerts - GetAllAlerts() []models.AlertConfig - GetAlert(id int) (models.AlertConfig, bool) - AddAlert(name, aType string, settings map[string]string) - UpdateAlert(id int, name, aType string, settings map[string]string) - DeleteAlert(id int) + GetAllAlerts() ([]models.AlertConfig, error) + GetAlert(id int) (models.AlertConfig, error) + AddAlert(name, aType string, settings map[string]string) error + UpdateAlert(id int, name, aType string, settings map[string]string) error + DeleteAlert(id int) error // Users - GetAllUsers() []models.User + GetAllUsers() ([]models.User, error) AddUser(username, publicKey, role string) error UpdateUser(id int, username, publicKey, role string) error DeleteUser(id int) error // History - SaveCheck(siteID int, latencyNs int64, isUp bool) - LoadAllHistory(limit int) map[int][]models.CheckRecord + SaveCheck(siteID int, latencyNs int64, isUp bool) error + LoadAllHistory(limit int) (map[int][]models.CheckRecord, error) // Backup & Restore - ExportData() models.Backup + ExportData() (models.Backup, error) ImportData(data models.Backup) error } diff --git a/internal/tui/tab_alerts.go b/internal/tui/tab_alerts.go index 8a0447b..72bd19d 100644 --- a/internal/tui/tab_alerts.go +++ b/internal/tui/tab_alerts.go @@ -2,6 +2,7 @@ package tui import ( "fmt" + "go-upkeep/internal/monitor" "go-upkeep/internal/store" tea "github.com/charmbracelet/bubbletea" @@ -277,9 +278,13 @@ func (m *Model) submitAlertForm() { } if m.editID > 0 { - store.Get().UpdateAlert(m.editID, d.Name, d.AlertType, settings) + if err := store.Get().UpdateAlert(m.editID, d.Name, d.AlertType, settings); err != nil { + monitor.AddLog("Update alert failed: " + err.Error()) + } } else { - store.Get().AddAlert(d.Name, d.AlertType, settings) + if err := store.Get().AddAlert(d.Name, d.AlertType, settings); err != nil { + monitor.AddLog("Add alert failed: " + err.Error()) + } } m.state = stateDashboard } diff --git a/internal/tui/tab_sites.go b/internal/tui/tab_sites.go index 2656613..f1a2aa0 100644 --- a/internal/tui/tab_sites.go +++ b/internal/tui/tab_sites.go @@ -361,12 +361,14 @@ func (m *Model) initSiteHuhForm() tea.Cmd { } alertOpts := []huh.Option[string]{huh.NewOption("None", "0")} - if store.Get() != nil { - for _, a := range store.Get().GetAllAlerts() { - alertOpts = append(alertOpts, huh.NewOption( - fmt.Sprintf("%s (%s)", a.Name, a.Type), - strconv.Itoa(a.ID), - )) + if s := store.Get(); s != nil { + if alerts, err := s.GetAllAlerts(); err == nil { + for _, a := range alerts { + alertOpts = append(alertOpts, huh.NewOption( + fmt.Sprintf("%s (%s)", a.Name, a.Type), + strconv.Itoa(a.ID), + )) + } } } @@ -558,10 +560,14 @@ func (m *Model) submitSiteForm() { } if m.editID > 0 { - store.Get().UpdateSite(site) + if err := store.Get().UpdateSite(site); err != nil { + monitor.AddLog("Update site failed: " + err.Error()) + } monitor.UpdateSiteConfig(site) } else { - store.Get().AddSite(site) + if err := store.Get().AddSite(site); err != nil { + monitor.AddLog("Add site failed: " + err.Error()) + } } m.state = stateDashboard } diff --git a/internal/tui/tab_users.go b/internal/tui/tab_users.go index 4858b7c..77d4182 100644 --- a/internal/tui/tab_users.go +++ b/internal/tui/tab_users.go @@ -2,6 +2,7 @@ package tui import ( "fmt" + "go-upkeep/internal/monitor" "go-upkeep/internal/store" tea "github.com/charmbracelet/bubbletea" @@ -145,9 +146,13 @@ func (m *Model) initUserHuhForm() tea.Cmd { func (m *Model) submitUserForm() { d := m.userFormData if m.editID > 0 { - store.Get().UpdateUser(m.editID, d.Username, d.PublicKey, d.Role) + if err := store.Get().UpdateUser(m.editID, d.Username, d.PublicKey, d.Role); err != nil { + monitor.AddLog("Update user failed: " + err.Error()) + } } else { - store.Get().AddUser(d.Username, d.PublicKey, d.Role) + if err := store.Get().AddUser(d.Username, d.PublicKey, d.Role); err != nil { + monitor.AddLog("Add user failed: " + err.Error()) + } } m.state = stateUsers } diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 4972a3a..533c993 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -107,17 +107,23 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if keyMsg, ok := msg.(tea.KeyMsg); ok { switch keyMsg.String() { case "y", "Y": - if store.Get() != nil { + if s := store.Get(); s != nil { switch m.deleteTab { case 0: - store.Get().DeleteSite(m.deleteID) + if err := s.DeleteSite(m.deleteID); err != nil { + monitor.AddLog("Delete site failed: " + err.Error()) + } monitor.RemoveSite(m.deleteID) m.adjustCursor(len(m.sites) - 1) case 1: - store.Get().DeleteAlert(m.deleteID) + if err := s.DeleteAlert(m.deleteID); err != nil { + monitor.AddLog("Delete alert failed: " + err.Error()) + } m.adjustCursor(len(m.alerts) - 1) case 3: - store.Get().DeleteUser(m.deleteID) + if err := s.DeleteUser(m.deleteID); err != nil { + monitor.AddLog("Delete user failed: " + err.Error()) + } m.adjustCursor(len(m.users) - 1) } } @@ -313,8 +319,8 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { site := m.sites[m.cursor] monitor.ToggleSitePause(site.ID) site.Paused = !site.Paused - if store.Get() != nil { - store.Get().UpdateSitePaused(site.ID, site.Paused) + if s := store.Get(); s != nil { + _ = s.UpdateSitePaused(site.ID, site.Paused) } m.refreshData() } @@ -464,10 +470,14 @@ func (m *Model) refreshData() { } ordered = append(ordered, ungrouped...) m.sites = ordered - if store.Get() != nil { - m.alerts = store.Get().GetAllAlerts() + if s := store.Get(); s != nil { + if alerts, err := s.GetAllAlerts(); err == nil { + m.alerts = alerts + } if m.isAdmin { - m.users = store.Get().GetAllUsers() + if users, err := s.GetAllUsers(); err == nil { + m.users = users + } } } m.logViewport.SetContent(strings.Join(monitor.GetLogs(), "\n")) From a6bb9a7affdef583bb8075e63071c0f5c932a096 Mon Sep 17 00:00:00 2001 From: Tyler Koenig Date: Fri, 15 May 2026 00:45:07 -0400 Subject: [PATCH 4/9] refactor(core): remove store global singleton, thread store explicitly MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove store.Get()/SetGlobal()/Current. Store is now passed explicitly to all consumers via constructor parameters and function arguments. - TUI Model holds store field, set via InitialModel(isAdmin, store) - monitor.StartEngine(s) and InitHistoryFromStore(s) accept store - server.Start(cfg, s) closes over store in HTTP handlers - main.go threads store to SSH server, TUI, monitor, server - isKeyAllowed receives store as parameter No more hidden dependency on package-level mutable state in store pkg. Monitor package still uses package-level state (LiveState, etc.) — will be encapsulated into Engine struct in Phase 7. --- cmd/goupkeep/main.go | 22 +++++++------- internal/monitor/history.go | 10 ++----- internal/monitor/monitor.go | 18 +++++------ internal/server/server.go | 8 ++--- internal/store/store.go | 10 ------- internal/tui/tab_alerts.go | 5 ++-- internal/tui/tab_sites.go | 19 +++++------- internal/tui/tab_users.go | 5 ++-- internal/tui/tui.go | 59 ++++++++++++++++--------------------- 9 files changed, 62 insertions(+), 94 deletions(-) diff --git a/cmd/goupkeep/main.go b/cmd/goupkeep/main.go index 77cdadd..921a3c7 100644 --- a/cmd/goupkeep/main.go +++ b/cmd/goupkeep/main.go @@ -97,8 +97,6 @@ func main() { fmt.Printf("Database init error: %v\n", err) os.Exit(1) } - store.SetGlobal(s) - if *demo { seedDemoData(s) } @@ -117,15 +115,15 @@ func main() { fmt.Printf("Imported %d monitors and %d alerts from Uptime Kuma v%s\n", len(backup.Sites), len(backup.Alerts), kb.Version) } - monitor.InitHistoryFromStore() - monitor.StartEngine() + monitor.InitHistoryFromStore(s) + monitor.StartEngine(s) server.Start(server.ServerConfig{ Port: httpPort, EnableStatus: enableStatus, Title: statusTitle, ClusterKey: clusterKey, - }) + }, s) cluster.Start(cluster.Config{ Mode: clusterMode, @@ -133,10 +131,10 @@ func main() { SharedKey: clusterKey, }) - startSSHServer(*port) + startSSHServer(*port, s) if isatty.IsTerminal(os.Stdout.Fd()) || isatty.IsCygwinTerminal(os.Stdout.Fd()) { - p := tea.NewProgram(tui.InitialModel(true), tea.WithAltScreen(), tea.WithMouseCellMotion()) + p := tea.NewProgram(tui.InitialModel(true, s), tea.WithAltScreen(), tea.WithMouseCellMotion()) if _, err := p.Run(); err != nil { fmt.Printf("Error: %v\n", err) } @@ -149,16 +147,16 @@ func main() { } } -func startSSHServer(port int) { +func startSSHServer(port int, db store.Store) { s, err := wish.NewServer( wish.WithAddress(fmt.Sprintf(":%d", port)), wish.WithHostKeyPath(".ssh/id_ed25519"), wish.WithPublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool { - return isKeyAllowed(key) + return isKeyAllowed(db, key) }), wish.WithMiddleware( bm.Middleware(func(s ssh.Session) (tea.Model, []tea.ProgramOption) { - return tui.InitialModel(false), []tea.ProgramOption{tea.WithAltScreen(), tea.WithMouseCellMotion()} + return tui.InitialModel(false, db), []tea.ProgramOption{tea.WithAltScreen(), tea.WithMouseCellMotion()} }), ), ) @@ -206,8 +204,8 @@ func seedDemoData(s store.Store) { s.AddSite(models.Site{Name: "SSH Server", Type: "port", Interval: 60, AlertID: alertID, Hostname: "10.0.0.1", Port: 22, Timeout: 5, ExpiryThreshold: 7}) } -func isKeyAllowed(incomingKey ssh.PublicKey) bool { - users, err := store.Get().GetAllUsers() +func isKeyAllowed(db store.Store, incomingKey ssh.PublicKey) bool { + users, err := db.GetAllUsers() if err != nil { return false } diff --git a/internal/monitor/history.go b/internal/monitor/history.go index 43ef0ee..dd3f375 100644 --- a/internal/monitor/history.go +++ b/internal/monitor/history.go @@ -20,11 +20,7 @@ var ( historyMu sync.RWMutex ) -func InitHistoryFromStore() { - s := store.Get() - if s == nil { - return - } +func InitHistoryFromStore(s store.Store) { all, err := s.LoadAllHistory(maxHistoryLen) if err != nil { AddLog("Failed to load check history: " + err.Error()) @@ -74,8 +70,8 @@ func RecordCheck(siteID int, latency time.Duration, isUp bool) { h.Statuses = h.Statuses[len(h.Statuses)-maxHistoryLen:] } - if s := store.Get(); s != nil { - go func() { _ = s.SaveCheck(siteID, latency.Nanoseconds(), isUp) }() + if db != nil { + go func() { _ = db.SaveCheck(siteID, latency.Nanoseconds(), isUp) }() } } diff --git a/internal/monitor/monitor.go b/internal/monitor/monitor.go index 27925ed..db11caf 100644 --- a/internal/monitor/monitor.go +++ b/internal/monitor/monitor.go @@ -55,6 +55,8 @@ var ( insecureSkipVerify bool + db store.Store + strictClient = &http.Client{ Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: false}}, } @@ -119,16 +121,11 @@ func RecordHeartbeat(token string) bool { return true } -func StartEngine() { +func StartEngine(s store.Store) { + db = s go func() { for { - s_instance := store.Get() - if s_instance == nil { - time.Sleep(1 * time.Second) - continue - } - - sites, err := s_instance.GetSites() + sites, err := db.GetSites() if err != nil { AddLog(fmt.Sprintf("Failed to load sites: %v", err)) time.Sleep(5 * time.Second) @@ -407,11 +404,10 @@ func handleStatusChange(site models.Site, rawStatus string, code int, latency ti } func triggerAlert(alertID int, title, message string) { - s_instance := store.Get() - if s_instance == nil { + if db == nil { return } - cfg, err := s_instance.GetAlert(alertID) + cfg, err := db.GetAlert(alertID) if err != nil { return } diff --git a/internal/server/server.go b/internal/server/server.go index bb97a88..b6a2b52 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -148,7 +148,7 @@ type ServerConfig struct { ClusterKey string // Shared Secret for Security } -func Start(cfg ServerConfig) { +func Start(cfg ServerConfig, s store.Store) { if cfg.ClusterKey == "" { fmt.Println("WARNING: No UPKEEP_CLUSTER_SECRET set. Cluster API endpoints are unauthenticated.") } @@ -185,7 +185,7 @@ func Start(cfg ServerConfig) { http.Error(w, "Unauthorized: UPKEEP_CLUSTER_SECRET required", 401) return } - data, err := store.Get().ExportData() + data, err := s.ExportData() if err != nil { log.Printf("Export failed: %v", err) http.Error(w, "Export failed", 500) @@ -209,7 +209,7 @@ func Start(cfg ServerConfig) { http.Error(w, "Invalid JSON", 400) return } - if err := store.Get().ImportData(data); err != nil { + if err := s.ImportData(data); err != nil { log.Printf("Import failed: %v", err) http.Error(w, "Import failed", 500) return @@ -234,7 +234,7 @@ func Start(cfg ServerConfig) { return } backup := importer.ConvertKuma(&kb) - if err := store.Get().ImportData(backup); err != nil { + if err := s.ImportData(backup); err != nil { log.Printf("Kuma import failed: %v", err) http.Error(w, "Import failed", 500) return diff --git a/internal/store/store.go b/internal/store/store.go index d119597..35afa0b 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -35,13 +35,3 @@ type Store interface { ExportData() (models.Backup, error) ImportData(data models.Backup) error } - -var Current Store - -func SetGlobal(s Store) { - Current = s -} - -func Get() Store { - return Current -} diff --git a/internal/tui/tab_alerts.go b/internal/tui/tab_alerts.go index 72bd19d..0d203f1 100644 --- a/internal/tui/tab_alerts.go +++ b/internal/tui/tab_alerts.go @@ -3,7 +3,6 @@ package tui import ( "fmt" "go-upkeep/internal/monitor" - "go-upkeep/internal/store" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/huh" @@ -278,11 +277,11 @@ func (m *Model) submitAlertForm() { } if m.editID > 0 { - if err := store.Get().UpdateAlert(m.editID, d.Name, d.AlertType, settings); err != nil { + if err := m.store.UpdateAlert(m.editID, d.Name, d.AlertType, settings); err != nil { monitor.AddLog("Update alert failed: " + err.Error()) } } else { - if err := store.Get().AddAlert(d.Name, d.AlertType, settings); err != nil { + if err := m.store.AddAlert(d.Name, d.AlertType, settings); err != nil { monitor.AddLog("Add alert failed: " + err.Error()) } } diff --git a/internal/tui/tab_sites.go b/internal/tui/tab_sites.go index f1a2aa0..1644a35 100644 --- a/internal/tui/tab_sites.go +++ b/internal/tui/tab_sites.go @@ -4,7 +4,6 @@ import ( "fmt" "go-upkeep/internal/models" "go-upkeep/internal/monitor" - "go-upkeep/internal/store" "net/url" "strconv" "strings" @@ -361,14 +360,12 @@ func (m *Model) initSiteHuhForm() tea.Cmd { } alertOpts := []huh.Option[string]{huh.NewOption("None", "0")} - if s := store.Get(); s != nil { - if alerts, err := s.GetAllAlerts(); err == nil { - for _, a := range alerts { - alertOpts = append(alertOpts, huh.NewOption( - fmt.Sprintf("%s (%s)", a.Name, a.Type), - strconv.Itoa(a.ID), - )) - } + if alerts, err := m.store.GetAllAlerts(); err == nil { + for _, a := range alerts { + alertOpts = append(alertOpts, huh.NewOption( + fmt.Sprintf("%s (%s)", a.Name, a.Type), + strconv.Itoa(a.ID), + )) } } @@ -560,12 +557,12 @@ func (m *Model) submitSiteForm() { } if m.editID > 0 { - if err := store.Get().UpdateSite(site); err != nil { + if err := m.store.UpdateSite(site); err != nil { monitor.AddLog("Update site failed: " + err.Error()) } monitor.UpdateSiteConfig(site) } else { - if err := store.Get().AddSite(site); err != nil { + if err := m.store.AddSite(site); err != nil { monitor.AddLog("Add site failed: " + err.Error()) } } diff --git a/internal/tui/tab_users.go b/internal/tui/tab_users.go index 77d4182..d82e5fb 100644 --- a/internal/tui/tab_users.go +++ b/internal/tui/tab_users.go @@ -3,7 +3,6 @@ package tui import ( "fmt" "go-upkeep/internal/monitor" - "go-upkeep/internal/store" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/huh" @@ -146,11 +145,11 @@ func (m *Model) initUserHuhForm() tea.Cmd { func (m *Model) submitUserForm() { d := m.userFormData if m.editID > 0 { - if err := store.Get().UpdateUser(m.editID, d.Username, d.PublicKey, d.Role); err != nil { + if err := m.store.UpdateUser(m.editID, d.Username, d.PublicKey, d.Role); err != nil { monitor.AddLog("Update user failed: " + err.Error()) } } else { - if err := store.Get().AddUser(d.Username, d.PublicKey, d.Role); err != nil { + if err := m.store.AddUser(d.Username, d.PublicKey, d.Role); err != nil { monitor.AddLog("Add user failed: " + err.Error()) } } diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 533c993..4324c65 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -68,6 +68,7 @@ type Model struct { deleteTab int collapsed map[int]bool + store store.Store // harmonica animation state pulseSpring harmonica.Spring @@ -80,7 +81,7 @@ type Model struct { users []models.User } -func InitialModel(isAdmin bool) Model { +func InitialModel(isAdmin bool, s store.Store) Model { vpLogs := viewport.New(100, 20) vpLogs.SetContent("Waiting for logs...") z := zone.New() @@ -90,6 +91,7 @@ func InitialModel(isAdmin bool) Model { logViewport: vpLogs, maxTableRows: 5, isAdmin: isAdmin, + store: s, zones: z, pulseSpring: spring, collapsed: make(map[int]bool), @@ -107,25 +109,23 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if keyMsg, ok := msg.(tea.KeyMsg); ok { switch keyMsg.String() { case "y", "Y": - if s := store.Get(); s != nil { - switch m.deleteTab { - case 0: - if err := s.DeleteSite(m.deleteID); err != nil { - monitor.AddLog("Delete site failed: " + err.Error()) - } - monitor.RemoveSite(m.deleteID) - m.adjustCursor(len(m.sites) - 1) - case 1: - if err := s.DeleteAlert(m.deleteID); err != nil { - monitor.AddLog("Delete alert failed: " + err.Error()) - } - m.adjustCursor(len(m.alerts) - 1) - case 3: - if err := s.DeleteUser(m.deleteID); err != nil { - monitor.AddLog("Delete user failed: " + err.Error()) - } - m.adjustCursor(len(m.users) - 1) + switch m.deleteTab { + case 0: + if err := m.store.DeleteSite(m.deleteID); err != nil { + monitor.AddLog("Delete site failed: " + err.Error()) } + monitor.RemoveSite(m.deleteID) + m.adjustCursor(len(m.sites) - 1) + case 1: + if err := m.store.DeleteAlert(m.deleteID); err != nil { + monitor.AddLog("Delete alert failed: " + err.Error()) + } + m.adjustCursor(len(m.alerts) - 1) + case 3: + if err := m.store.DeleteUser(m.deleteID); err != nil { + monitor.AddLog("Delete user failed: " + err.Error()) + } + m.adjustCursor(len(m.users) - 1) } m.refreshData() m.state = stateDashboard @@ -319,9 +319,7 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { site := m.sites[m.cursor] monitor.ToggleSitePause(site.ID) site.Paused = !site.Paused - if s := store.Get(); s != nil { - _ = s.UpdateSitePaused(site.ID, site.Paused) - } + _ = m.store.UpdateSitePaused(site.ID, site.Paused) m.refreshData() } case "d", "backspace": @@ -470,23 +468,18 @@ func (m *Model) refreshData() { } ordered = append(ordered, ungrouped...) m.sites = ordered - if s := store.Get(); s != nil { - if alerts, err := s.GetAllAlerts(); err == nil { - m.alerts = alerts - } - if m.isAdmin { - if users, err := s.GetAllUsers(); err == nil { - m.users = users - } + if alerts, err := m.store.GetAllAlerts(); err == nil { + m.alerts = alerts + } + if m.isAdmin { + if users, err := m.store.GetAllUsers(); err == nil { + m.users = users } } m.logViewport.SetContent(strings.Join(monitor.GetLogs(), "\n")) } func (m *Model) submitForm() { - if store.Get() == nil { - return - } switch m.state { case stateFormSite: if m.siteFormData != nil { From d6f33a4d1fdc1774a3e18db19e4a5b5e749652e0 Mon Sep 17 00:00:00 2001 From: Tyler Koenig Date: Fri, 15 May 2026 00:46:05 -0400 Subject: [PATCH 5/9] refactor(alert): extract shared HTTPProvider for webhook-based alerts Discord, Slack, and Webhook providers now use a single HTTPProvider struct with a PayloadFunc for the only part that differs. Centralizes response body handling and adds HTTP status code checking (4xx/5xx now return errors instead of being silently ignored). Email and Ntfy keep separate implementations (different protocols). Adding a new HTTP-based alert provider is now a one-line PayloadFunc. --- internal/alert/alert.go | 101 ++++++++++++++++------------------------ 1 file changed, 41 insertions(+), 60 deletions(-) diff --git a/internal/alert/alert.go b/internal/alert/alert.go index 16af56c..d67a30d 100644 --- a/internal/alert/alert.go +++ b/internal/alert/alert.go @@ -17,15 +17,49 @@ type Provider interface { Send(title, message string) error } +type PayloadFunc func(title, message string) ([]byte, error) + +type HTTPProvider struct { + URL string + Payload PayloadFunc +} + +func (h *HTTPProvider) Send(title, message string) error { + body, err := h.Payload(title, message) + if err != nil { + return err + } + resp, err := alertClient.Post(h.URL, "application/json", bytes.NewBuffer(body)) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode >= 400 { + return fmt.Errorf("alert webhook returned HTTP %d", resp.StatusCode) + } + return nil +} + +func discordPayload(title, message string) ([]byte, error) { + return json.Marshal(map[string]string{"content": fmt.Sprintf("**%s**\n%s", title, message)}) +} + +func slackPayload(title, message string) ([]byte, error) { + return json.Marshal(map[string]string{"text": fmt.Sprintf("*%s*\n%s", title, message)}) +} + +func webhookPayload(title, message string) ([]byte, error) { + return json.Marshal(map[string]string{"title": title, "message": message, "status": "alert"}) +} + func GetProvider(cfg models.AlertConfig) Provider { switch cfg.Type { case "discord": - return &DiscordProvider{URL: cfg.Settings["url"]} + return &HTTPProvider{URL: cfg.Settings["url"], Payload: discordPayload} case "slack": - return &SlackProvider{URL: cfg.Settings["url"]} + return &HTTPProvider{URL: cfg.Settings["url"], Payload: slackPayload} case "webhook": - // Generic Webhook - return &WebhookProvider{URL: cfg.Settings["url"]} + return &HTTPProvider{URL: cfg.Settings["url"], Payload: webhookPayload} case "email": port := "25" if p, ok := cfg.Settings["port"]; ok { @@ -56,62 +90,6 @@ func GetProvider(cfg models.AlertConfig) Provider { } } -// --- DISCORD --- -type DiscordProvider struct{ URL string } - -func (d *DiscordProvider) Send(title, message string) error { - payload := map[string]string{"content": fmt.Sprintf("**%s**\n%s", title, message)} - jsonValue, err := json.Marshal(payload) - if err != nil { - return err - } - resp, err := alertClient.Post(d.URL, "application/json", bytes.NewBuffer(jsonValue)) - if err != nil { - return err - } - defer resp.Body.Close() - return nil -} - -// --- SLACK --- -type SlackProvider struct{ URL string } - -func (s *SlackProvider) Send(title, message string) error { - payload := map[string]string{"text": fmt.Sprintf("*%s*\n%s", title, message)} - jsonValue, err := json.Marshal(payload) - if err != nil { - return err - } - resp, err := alertClient.Post(s.URL, "application/json", bytes.NewBuffer(jsonValue)) - if err != nil { - return err - } - defer resp.Body.Close() - return nil -} - -// --- GENERIC WEBHOOK --- -type WebhookProvider struct{ URL string } - -func (w *WebhookProvider) Send(title, message string) error { - payload := map[string]string{ - "title": title, - "message": message, - "status": "alert", - } - jsonValue, err := json.Marshal(payload) - if err != nil { - return err - } - resp, err := alertClient.Post(w.URL, "application/json", bytes.NewBuffer(jsonValue)) - if err != nil { - return err - } - defer resp.Body.Close() - return nil -} - -// --- EMAIL --- type EmailProvider struct { Host, Port, User, Pass, To, From string } @@ -149,5 +127,8 @@ func (n *NtfyProvider) Send(title, message string) error { return err } defer resp.Body.Close() + if resp.StatusCode >= 400 { + return fmt.Errorf("ntfy returned HTTP %d", resp.StatusCode) + } return nil } From 0e6dc774cb55048112ac6e5bf3690a00469cd96f Mon Sep 17 00:00:00 2001 From: Tyler Koenig Date: Fri, 15 May 2026 00:49:14 -0400 Subject: [PATCH 6/9] refactor(tui): extract shared table rendering, fix cursor bounds MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - New table_helpers.go with renderTable() and shared styles - Remove 4 duplicated style blocks (header/cell/selected/border) from tab_alerts.go and tab_users.go - All 3 tab views now use renderTable() for offset/end calc, selected row highlighting, and table construction - Sites tab keeps siteGroupStyle via StyleOverride callback - Clamp cursor to list length at end of refreshData() to prevent index-out-of-bounds after concurrent list changes - Fix off-by-one in tab click handler (i <= maxTabs → i < tabCount) --- internal/tui/tab_alerts.go | 81 ++++---------- internal/tui/tab_sites.go | 200 +++++++++++++--------------------- internal/tui/tab_users.go | 76 +++---------- internal/tui/table_helpers.go | 75 +++++++++++++ internal/tui/tui.go | 21 +++- 5 files changed, 204 insertions(+), 249 deletions(-) create mode 100644 internal/tui/table_helpers.go diff --git a/internal/tui/tab_alerts.go b/internal/tui/tab_alerts.go index 0d203f1..60b1890 100644 --- a/internal/tui/tab_alerts.go +++ b/internal/tui/tab_alerts.go @@ -7,25 +7,6 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/huh" "github.com/charmbracelet/lipgloss" - "github.com/charmbracelet/lipgloss/table" -) - -var ( - alertHeaderStyle = lipgloss.NewStyle(). - Foreground(lipgloss.Color("#7D56F4")). - Bold(true). - Padding(0, 1) - - alertCellStyle = lipgloss.NewStyle().Padding(0, 1) - - alertSelectedStyle = lipgloss.NewStyle(). - Padding(0, 1). - Bold(true). - Foreground(lipgloss.Color("#ffffff")). - Background(lipgloss.Color("#3b3b5c")) - - alertBorderStyle = lipgloss.NewStyle(). - Foreground(lipgloss.Color("#444")) ) type alertFormData struct { @@ -97,49 +78,27 @@ func (m Model) viewAlertsTab() string { return "\n No alert channels configured. Press [n] to add one." } - end := m.tableOffset + m.maxTableRows - if end > len(m.alerts) { - end = len(m.alerts) - } - - selectedVisual := m.cursor - m.tableOffset - - var rows [][]string - for i := m.tableOffset; i < end; i++ { - alert := m.alerts[i] - rows = append(rows, []string{ - fmt.Sprintf("%d", i+1), - m.zones.Mark(fmt.Sprintf("alert-%d", i), limitStr(alert.Name, 15)), - fmtAlertType(alert.Type), - fmtAlertConfig(struct { - Type string - Settings map[string]string - }{alert.Type, alert.Settings}), - }) - } - - tableWidth := m.termWidth - 6 - if tableWidth < 40 { - tableWidth = 40 - } - - t := table.New(). - Border(lipgloss.RoundedBorder()). - BorderStyle(alertBorderStyle). - Width(tableWidth). - Headers("#", "NAME", "TYPE", "CONFIG"). - Rows(rows...). - StyleFunc(func(row, col int) lipgloss.Style { - if row == table.HeaderRow { - return alertHeaderStyle + return m.renderTable( + []string{"#", "NAME", "TYPE", "CONFIG"}, + len(m.alerts), + func(start, end int) [][]string { + var rows [][]string + for i := start; i < end; i++ { + a := m.alerts[i] + rows = append(rows, []string{ + fmt.Sprintf("%d", i+1), + m.zones.Mark(fmt.Sprintf("alert-%d", i), limitStr(a.Name, 15)), + fmtAlertType(a.Type), + fmtAlertConfig(struct { + Type string + Settings map[string]string + }{a.Type, a.Settings}), + }) } - if row == selectedVisual { - return alertSelectedStyle - } - return alertCellStyle - }) - - return "\n" + t.Render() + return rows + }, + nil, nil, + ) } func (m *Model) initAlertHuhForm() tea.Cmd { diff --git a/internal/tui/tab_sites.go b/internal/tui/tab_sites.go index 1644a35..e5ebbfe 100644 --- a/internal/tui/tab_sites.go +++ b/internal/tui/tab_sites.go @@ -12,33 +12,14 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/huh" "github.com/charmbracelet/lipgloss" - "github.com/charmbracelet/lipgloss/table" ) var sparkChars = []rune{'▁', '▂', '▃', '▄', '▅', '▆', '▇', '█'} -var ( - siteHeaderStyle = lipgloss.NewStyle(). - Foreground(lipgloss.Color("#7D56F4")). - Bold(true). - Padding(0, 1) - - siteCellStyle = lipgloss.NewStyle().Padding(0, 1) - - siteSelectedStyle = lipgloss.NewStyle(). - Padding(0, 1). - Bold(true). - Foreground(lipgloss.Color("#ffffff")). - Background(lipgloss.Color("#3b3b5c")) - - siteBorderStyle = lipgloss.NewStyle(). - Foreground(lipgloss.Color("#444")) - - siteGroupStyle = lipgloss.NewStyle(). - Padding(0, 1). - Bold(true). - Foreground(lipgloss.Color("#7D56F4")) -) +var siteGroupStyle = lipgloss.NewStyle(). + Padding(0, 1). + Bold(true). + Foreground(lipgloss.Color("#7D56F4")) type siteFormData struct { Name string @@ -219,111 +200,80 @@ func (m Model) viewSitesTab() string { return "\n No sites configured. Press [n] to add one." } - end := m.tableOffset + m.maxTableRows - if end > len(m.sites) { - end = len(m.sites) - } - - selectedVisual := m.cursor - m.tableOffset - - var rows [][]string - var groupRows []int - for i := m.tableOffset; i < end; i++ { - site := m.sites[i] - - if site.Type == "group" { - groupRows = append(groupRows, i-m.tableOffset) - arrow := "▾" - if m.collapsed[site.ID] { - arrow = "▸" - } - rows = append(rows, []string{ - strconv.Itoa(i + 1), - m.zones.Mark(fmt.Sprintf("site-%d", i), arrow+" "+limitStr(site.Name, 11)), - "group", - fmtStatus(site.Status, site.Paused), - subtleStyle.Render("—"), - subtleStyle.Render("—"), - subtleStyle.Render(strings.Repeat("·", sparkWidth)), - subtleStyle.Render("-"), - subtleStyle.Render("—"), - }) - continue - } - - name := site.Name - if site.ParentID > 0 { - prefix := "├" - if i+1 >= len(m.sites) || m.sites[i+1].ParentID != site.ParentID { - prefix = "└" - } - name = prefix + " " + limitStr(name, 11) - } else { - name = limitStr(name, 13) - } - - hist, _ := monitor.GetHistory(site.ID) - var spark string - if site.Type == "push" { - spark = heartbeatSparkline(hist.Statuses, sparkWidth) - } else { - spark = latencySparkline(hist.Latencies, sparkWidth) - } - - rows = append(rows, []string{ - strconv.Itoa(i + 1), - m.zones.Mark(fmt.Sprintf("site-%d", i), name), - site.Type, - fmtStatus(site.Status, site.Paused), - fmtLatency(site.Latency), - fmtUptime(hist.TotalChecks, hist.UpChecks), - spark, - fmtSSL(site), - fmtRetries(site), - }) - } - - isGroupRow := func(row int) bool { - for _, g := range groupRows { - if g == row { - return true - } - } - return false - } - - tableWidth := m.termWidth - 6 - if tableWidth < 40 { - tableWidth = 40 - } - - // column widths: #=6, name=flex, type=10, status=10, latency=8, uptime=8, history=sparkWidth+4, ssl=7, retry=9 colWidths := []int{6, 0, 10, 10, 8, 8, sparkWidth + 4, 7, 9} - t := table.New(). - Border(lipgloss.RoundedBorder()). - BorderStyle(siteBorderStyle). - Width(tableWidth). - Headers("#", "NAME", "TYPE", "STATUS", "LATENCY", "UPTIME", "HISTORY", "SSL", "RETRY"). - Rows(rows...). - StyleFunc(func(row, col int) lipgloss.Style { - var base lipgloss.Style - if row == table.HeaderRow { - base = siteHeaderStyle - } else if row == selectedVisual { - base = siteSelectedStyle - } else if isGroupRow(row) { - base = siteGroupStyle - } else { - base = siteCellStyle - } - if col < len(colWidths) && colWidths[col] > 0 { - base = base.Width(colWidths[col]) - } - return base - }) + var groupRows map[int]bool + return m.renderTable( + []string{"#", "NAME", "TYPE", "STATUS", "LATENCY", "UPTIME", "HISTORY", "SSL", "RETRY"}, + len(m.sites), + func(start, end int) [][]string { + groupRows = make(map[int]bool) + var rows [][]string + for i := start; i < end; i++ { + site := m.sites[i] - return "\n" + t.Render() + if site.Type == "group" { + groupRows[i-start] = true + arrow := "▾" + if m.collapsed[site.ID] { + arrow = "▸" + } + rows = append(rows, []string{ + strconv.Itoa(i + 1), + m.zones.Mark(fmt.Sprintf("site-%d", i), arrow+" "+limitStr(site.Name, 11)), + "group", + fmtStatus(site.Status, site.Paused), + subtleStyle.Render("—"), + subtleStyle.Render("—"), + subtleStyle.Render(strings.Repeat("·", sparkWidth)), + subtleStyle.Render("-"), + subtleStyle.Render("—"), + }) + continue + } + + name := site.Name + if site.ParentID > 0 { + prefix := "├" + if i+1 >= len(m.sites) || m.sites[i+1].ParentID != site.ParentID { + prefix = "└" + } + name = prefix + " " + limitStr(name, 11) + } else { + name = limitStr(name, 13) + } + + hist, _ := monitor.GetHistory(site.ID) + var spark string + if site.Type == "push" { + spark = heartbeatSparkline(hist.Statuses, sparkWidth) + } else { + spark = latencySparkline(hist.Latencies, sparkWidth) + } + + rows = append(rows, []string{ + strconv.Itoa(i + 1), + m.zones.Mark(fmt.Sprintf("site-%d", i), name), + site.Type, + fmtStatus(site.Status, site.Paused), + fmtLatency(site.Latency), + fmtUptime(hist.TotalChecks, hist.UpChecks), + spark, + fmtSSL(site), + fmtRetries(site), + }) + } + return rows + }, + colWidths, + func(row, col int) *lipgloss.Style { + if groupRows[row] { + s := siteGroupStyle + return &s + } + return nil + }, + ) } func (m *Model) initSiteHuhForm() tea.Cmd { diff --git a/internal/tui/tab_users.go b/internal/tui/tab_users.go index d82e5fb..46c5679 100644 --- a/internal/tui/tab_users.go +++ b/internal/tui/tab_users.go @@ -6,26 +6,6 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/huh" - "github.com/charmbracelet/lipgloss" - "github.com/charmbracelet/lipgloss/table" -) - -var ( - userHeaderStyle = lipgloss.NewStyle(). - Foreground(lipgloss.Color("#7D56F4")). - Bold(true). - Padding(0, 1) - - userCellStyle = lipgloss.NewStyle().Padding(0, 1) - - userSelectedStyle = lipgloss.NewStyle(). - Padding(0, 1). - Bold(true). - Foreground(lipgloss.Color("#ffffff")). - Background(lipgloss.Color("#3b3b5c")) - - userBorderStyle = lipgloss.NewStyle(). - Foreground(lipgloss.Color("#444")) ) type userFormData struct { @@ -53,46 +33,24 @@ func (m Model) viewUsersTab() string { return "\n No users configured. Press [n] to add one." } - end := m.tableOffset + m.maxTableRows - if end > len(m.users) { - end = len(m.users) - } - - selectedVisual := m.cursor - m.tableOffset - - var rows [][]string - for i := m.tableOffset; i < end; i++ { - u := m.users[i] - rows = append(rows, []string{ - fmt.Sprintf("%d", i+1), - m.zones.Mark(fmt.Sprintf("user-%d", i), limitStr(u.Username, 15)), - fmtRole(u.Role), - fmtKey(u.PublicKey), - }) - } - - tableWidth := m.termWidth - 6 - if tableWidth < 40 { - tableWidth = 40 - } - - t := table.New(). - Border(lipgloss.RoundedBorder()). - BorderStyle(userBorderStyle). - Width(tableWidth). - Headers("#", "USERNAME", "ROLE", "PUBLIC KEY"). - Rows(rows...). - StyleFunc(func(row, col int) lipgloss.Style { - if row == table.HeaderRow { - return userHeaderStyle + return m.renderTable( + []string{"#", "USERNAME", "ROLE", "PUBLIC KEY"}, + len(m.users), + func(start, end int) [][]string { + var rows [][]string + for i := start; i < end; i++ { + u := m.users[i] + rows = append(rows, []string{ + fmt.Sprintf("%d", i+1), + m.zones.Mark(fmt.Sprintf("user-%d", i), limitStr(u.Username, 15)), + fmtRole(u.Role), + fmtKey(u.PublicKey), + }) } - if row == selectedVisual { - return userSelectedStyle - } - return userCellStyle - }) - - return "\n" + t.Render() + return rows + }, + nil, nil, + ) } func (m *Model) initUserHuhForm() tea.Cmd { diff --git a/internal/tui/table_helpers.go b/internal/tui/table_helpers.go new file mode 100644 index 0000000..be8719e --- /dev/null +++ b/internal/tui/table_helpers.go @@ -0,0 +1,75 @@ +package tui + +import ( + "github.com/charmbracelet/lipgloss" + "github.com/charmbracelet/lipgloss/table" +) + +var ( + tableHeaderStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("#7D56F4")). + Bold(true). + Padding(0, 1) + + tableCellStyle = lipgloss.NewStyle().Padding(0, 1) + + tableSelectedStyle = lipgloss.NewStyle(). + Padding(0, 1). + Bold(true). + Foreground(lipgloss.Color("#ffffff")). + Background(lipgloss.Color("#3b3b5c")) + + tableBorderStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("#444")) +) + +type StyleOverride func(row, col int) *lipgloss.Style + +func (m Model) renderTable(headers []string, items int, buildRows func(start, end int) [][]string, colWidths []int, styleOverride StyleOverride) string { + if items == 0 { + return "" + } + + end := m.tableOffset + m.maxTableRows + if end > items { + end = items + } + + selectedVisual := m.cursor - m.tableOffset + rows := buildRows(m.tableOffset, end) + + tableWidth := m.termWidth - 6 + if tableWidth < 40 { + tableWidth = 40 + } + + t := table.New(). + Border(lipgloss.RoundedBorder()). + BorderStyle(tableBorderStyle). + Width(tableWidth). + Headers(headers...). + Rows(rows...). + StyleFunc(func(row, col int) lipgloss.Style { + if row == table.HeaderRow { + return tableHeaderStyle + } + if styleOverride != nil { + if s := styleOverride(row, col); s != nil { + if col < len(colWidths) && colWidths[col] > 0 { + return s.Width(colWidths[col]) + } + return *s + } + } + base := tableCellStyle + if row == selectedVisual { + base = tableSelectedStyle + } + if col < len(colWidths) && colWidths[col] > 0 { + base = base.Width(colWidths[col]) + } + return base + }) + + return "\n" + t.Render() +} diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 4324c65..5e0e7b6 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -346,11 +346,11 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } func (m *Model) handleClick(msg tea.MouseMsg) (tea.Model, tea.Cmd) { - maxTabs := 3 - if !m.isAdmin { - maxTabs = 2 + tabCount := 3 + if m.isAdmin { + tabCount = 4 } - for i := 0; i <= maxTabs; i++ { + for i := 0; i < tabCount; i++ { if m.zones.Get(fmt.Sprintf("tab-%d", i)).InBounds(msg) { m.switchTab(i) return m, nil @@ -477,6 +477,19 @@ func (m *Model) refreshData() { } } m.logViewport.SetContent(strings.Join(monitor.GetLogs(), "\n")) + + listLen := len(m.sites) + if m.currentTab == 1 { + listLen = len(m.alerts) + } else if m.currentTab == 3 { + listLen = len(m.users) + } + if listLen > 0 && m.cursor >= listLen { + m.cursor = listLen - 1 + } + if m.cursor < m.tableOffset { + m.tableOffset = m.cursor + } } func (m *Model) submitForm() { From f023e38fdc1317f5b6337f2f50a05eaa19766488 Mon Sep 17 00:00:00 2001 From: Tyler Koenig Date: Fri, 15 May 2026 08:21:17 -0400 Subject: [PATCH 7/9] refactor(monitor): encapsulate engine state, add graceful shutdown and tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace all monitor package-level mutable state with Engine struct. All state (liveState, logStore, histories, tokenIndex, HTTP clients) is now encapsulated in Engine, created via NewEngine(store). Key changes: - Engine struct holds all monitor state with proper mutex protection - Engine.Start(ctx) and monitorRoutine respect context cancellation for graceful shutdown — no more leaked goroutines - cluster.runFollowerLoop also respects context for clean exit - Token index (map[string]int) for O(1) push heartbeat lookup, replacing O(n) linear scan through LiveState - UpdateSiteConfig preserves 8 runtime fields instead of copying 17 config fields individually - triggerAlert goroutines get 30s timeout context - All consumers (TUI, server, cluster, main) receive *Engine via constructor/parameter — no package-level state access - main.go creates context.WithCancel, passes to engine and cluster First test suite: 12 tests across store and alert packages - Store: CRUD for sites/alerts/users, push token generation, import/export round-trip, check history persistence - Alert: Discord/Slack/Webhook payload format, HTTP 4xx error propagation, Ntfy headers, unknown provider returns nil --- cmd/goupkeep/main.go | 31 +- internal/alert/alert_test.go | 109 +++++++ internal/cluster/cluster.go | 33 ++- internal/monitor/history.go | 55 ++-- internal/monitor/monitor.go | 500 ++++++++++++++++++-------------- internal/server/server.go | 19 +- internal/store/sqlstore_test.go | 231 +++++++++++++++ internal/tui/tab_alerts.go | 5 +- internal/tui/tab_sites.go | 9 +- internal/tui/tab_users.go | 5 +- internal/tui/tui.go | 23 +- 11 files changed, 705 insertions(+), 315 deletions(-) create mode 100644 internal/alert/alert_test.go create mode 100644 internal/store/sqlstore_test.go diff --git a/cmd/goupkeep/main.go b/cmd/goupkeep/main.go index 921a3c7..5962e37 100644 --- a/cmd/goupkeep/main.go +++ b/cmd/goupkeep/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "flag" "fmt" "go-upkeep/internal/cluster" @@ -68,9 +69,6 @@ func main() { if v := os.Getenv("UPKEEP_CLUSTER_SECRET"); v != "" { clusterKey = v } - if os.Getenv("UPKEEP_INSECURE_SKIP_VERIFY") == "true" { - monitor.SetInsecureSkipVerify(true) - } port := flag.Int("port", portVal, "SSH Port") flagDBType := flag.String("db-type", dbType, "Database type") @@ -115,26 +113,34 @@ func main() { fmt.Printf("Imported %d monitors and %d alerts from Uptime Kuma v%s\n", len(backup.Sites), len(backup.Alerts), kb.Version) } - monitor.InitHistoryFromStore(s) - monitor.StartEngine(s) + eng := monitor.NewEngine(s) + if os.Getenv("UPKEEP_INSECURE_SKIP_VERIFY") == "true" { + eng.SetInsecureSkipVerify(true) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + eng.InitHistory() + eng.Start(ctx) server.Start(server.ServerConfig{ Port: httpPort, EnableStatus: enableStatus, Title: statusTitle, ClusterKey: clusterKey, - }, s) + }, s, eng) - cluster.Start(cluster.Config{ + cluster.Start(ctx, cluster.Config{ Mode: clusterMode, PeerURL: clusterPeer, SharedKey: clusterKey, - }) + }, eng) - startSSHServer(*port, s) + startSSHServer(*port, s, eng) if isatty.IsTerminal(os.Stdout.Fd()) || isatty.IsCygwinTerminal(os.Stdout.Fd()) { - p := tea.NewProgram(tui.InitialModel(true, s), tea.WithAltScreen(), tea.WithMouseCellMotion()) + p := tea.NewProgram(tui.InitialModel(true, s, eng), tea.WithAltScreen(), tea.WithMouseCellMotion()) if _, err := p.Run(); err != nil { fmt.Printf("Error: %v\n", err) } @@ -145,9 +151,10 @@ func main() { <-done fmt.Println("Shutting down...") } + cancel() } -func startSSHServer(port int, db store.Store) { +func startSSHServer(port int, db store.Store, eng *monitor.Engine) { s, err := wish.NewServer( wish.WithAddress(fmt.Sprintf(":%d", port)), wish.WithHostKeyPath(".ssh/id_ed25519"), @@ -156,7 +163,7 @@ func startSSHServer(port int, db store.Store) { }), wish.WithMiddleware( bm.Middleware(func(s ssh.Session) (tea.Model, []tea.ProgramOption) { - return tui.InitialModel(false, db), []tea.ProgramOption{tea.WithAltScreen(), tea.WithMouseCellMotion()} + return tui.InitialModel(false, db, eng), []tea.ProgramOption{tea.WithAltScreen(), tea.WithMouseCellMotion()} }), ), ) diff --git a/internal/alert/alert_test.go b/internal/alert/alert_test.go new file mode 100644 index 0000000..348f2c9 --- /dev/null +++ b/internal/alert/alert_test.go @@ -0,0 +1,109 @@ +package alert + +import ( + "encoding/json" + "go-upkeep/internal/models" + "net/http" + "net/http/httptest" + "testing" +) + +func TestHTTPProviderDiscord(t *testing.T) { + var received map[string]string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewDecoder(r.Body).Decode(&received) + w.WriteHeader(200) + })) + defer srv.Close() + + p := GetProvider(models.AlertConfig{Type: "discord", Settings: map[string]string{"url": srv.URL}}) + if err := p.Send("Test Title", "Test Body"); err != nil { + t.Fatalf("Send: %v", err) + } + + if received["content"] != "**Test Title**\nTest Body" { + t.Errorf("unexpected payload: %s", received["content"]) + } +} + +func TestHTTPProviderSlack(t *testing.T) { + var received map[string]string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewDecoder(r.Body).Decode(&received) + w.WriteHeader(200) + })) + defer srv.Close() + + p := GetProvider(models.AlertConfig{Type: "slack", Settings: map[string]string{"url": srv.URL}}) + if err := p.Send("Alert", "Message"); err != nil { + t.Fatalf("Send: %v", err) + } + + if received["text"] != "*Alert*\nMessage" { + t.Errorf("unexpected payload: %s", received["text"]) + } +} + +func TestHTTPProviderWebhook(t *testing.T) { + var received map[string]string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewDecoder(r.Body).Decode(&received) + w.WriteHeader(200) + })) + defer srv.Close() + + p := GetProvider(models.AlertConfig{Type: "webhook", Settings: map[string]string{"url": srv.URL}}) + if err := p.Send("Title", "Body"); err != nil { + t.Fatalf("Send: %v", err) + } + + if received["title"] != "Title" || received["message"] != "Body" || received["status"] != "alert" { + t.Errorf("unexpected webhook payload: %v", received) + } +} + +func TestHTTPProviderErrorOnHTTP4xx(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(403) + })) + defer srv.Close() + + p := GetProvider(models.AlertConfig{Type: "discord", Settings: map[string]string{"url": srv.URL}}) + if err := p.Send("Test", "Test"); err == nil { + t.Fatal("expected error on 403 response") + } +} + +func TestNtfyProvider(t *testing.T) { + var title, body string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + title = r.Header.Get("Title") + buf := make([]byte, 1024) + n, _ := r.Body.Read(buf) + body = string(buf[:n]) + w.WriteHeader(200) + })) + defer srv.Close() + + p := GetProvider(models.AlertConfig{Type: "ntfy", Settings: map[string]string{ + "url": srv.URL, + "topic": "test", + }}) + if err := p.Send("Alert Title", "Alert Body"); err != nil { + t.Fatalf("Send: %v", err) + } + + if title != "Alert Title" { + t.Errorf("expected title 'Alert Title', got '%s'", title) + } + if body != "Alert Body" { + t.Errorf("expected body 'Alert Body', got '%s'", body) + } +} + +func TestGetProviderUnknown(t *testing.T) { + p := GetProvider(models.AlertConfig{Type: "unknown"}) + if p != nil { + t.Error("expected nil for unknown provider type") + } +} diff --git a/internal/cluster/cluster.go b/internal/cluster/cluster.go index 295443d..f986c29 100644 --- a/internal/cluster/cluster.go +++ b/internal/cluster/cluster.go @@ -1,6 +1,7 @@ package cluster import ( + "context" "fmt" "go-upkeep/internal/monitor" "net/http" @@ -14,13 +15,13 @@ type Config struct { SharedKey string // Security Key } -func Start(cfg Config) { +func Start(ctx context.Context, cfg Config, eng *monitor.Engine) { if cfg.Mode == "leader" { fmt.Println("Cluster: Running as LEADER (Active)") if cfg.SharedKey != "" { fmt.Println("WARNING: Cluster mode enabled. Ensure the HTTP server is behind a TLS-terminating proxy.") } - monitor.SetEngineActive(true) + eng.SetActive(true) return } @@ -29,20 +30,22 @@ func Start(cfg Config) { if cfg.PeerURL != "" && !strings.HasPrefix(cfg.PeerURL, "https://") { fmt.Println("WARNING: Cluster peer URL is not HTTPS. Cluster secret will be sent in cleartext.") } - monitor.SetEngineActive(false) - go runFollowerLoop(cfg) + eng.SetActive(false) + go runFollowerLoop(ctx, cfg, eng) } } -func runFollowerLoop(cfg Config) { +func runFollowerLoop(ctx context.Context, cfg Config, eng *monitor.Engine) { client := http.Client{Timeout: 2 * time.Second} - - // Failover Configuration failures := 0 threshold := 3 for { - time.Sleep(5 * time.Second) + select { + case <-time.After(5 * time.Second): + case <-ctx.Done(): + return + } req, _ := http.NewRequest("GET", cfg.PeerURL+"/api/health", nil) if cfg.SharedKey != "" { @@ -59,17 +62,15 @@ func runFollowerLoop(cfg Config) { if isLeaderHealthy { failures = 0 - if monitor.IsEngineActive() { - // Leader is back, yield - monitor.SetEngineActive(false) - monitor.AddLog("Cluster: Leader detected. Switching to PASSIVE.") + if eng.IsActive() { + eng.SetActive(false) + eng.AddLog("Cluster: Leader detected. Switching to PASSIVE.") } } else { failures++ - // If failures exceed threshold, take over - if failures >= threshold && !monitor.IsEngineActive() { - monitor.SetEngineActive(true) - monitor.AddLog("Cluster: Leader Unreachable. Switching to ACTIVE.") + if failures >= threshold && !eng.IsActive() { + eng.SetActive(true) + eng.AddLog("Cluster: Leader Unreachable. Switching to ACTIVE.") } } } diff --git a/internal/monitor/history.go b/internal/monitor/history.go index dd3f375..1049e04 100644 --- a/internal/monitor/history.go +++ b/internal/monitor/history.go @@ -1,10 +1,6 @@ package monitor -import ( - "go-upkeep/internal/store" - "sync" - "time" -) +import "time" const maxHistoryLen = 30 @@ -15,19 +11,14 @@ type SiteHistory struct { UpChecks int } -var ( - histories = make(map[int]*SiteHistory) - historyMu sync.RWMutex -) - -func InitHistoryFromStore(s store.Store) { - all, err := s.LoadAllHistory(maxHistoryLen) +func (e *Engine) InitHistory() { + all, err := e.db.LoadAllHistory(maxHistoryLen) if err != nil { - AddLog("Failed to load check history: " + err.Error()) + e.AddLog("Failed to load check history: " + err.Error()) return } - historyMu.Lock() - defer historyMu.Unlock() + e.histMu.Lock() + defer e.histMu.Unlock() for siteID, records := range all { h := &SiteHistory{} for _, r := range records { @@ -38,21 +29,21 @@ func InitHistoryFromStore(s store.Store) { h.Latencies = append(h.Latencies, time.Duration(r.LatencyNs)) h.Statuses = append(h.Statuses, r.IsUp) } - histories[siteID] = h + e.histories[siteID] = h } if len(all) > 0 { - AddLog("Loaded check history from database") + e.AddLog("Loaded check history from database") } } -func RecordCheck(siteID int, latency time.Duration, isUp bool) { - historyMu.Lock() - defer historyMu.Unlock() +func (e *Engine) recordCheck(siteID int, latency time.Duration, isUp bool) { + e.histMu.Lock() + defer e.histMu.Unlock() - h, ok := histories[siteID] + h, ok := e.histories[siteID] if !ok { h = &SiteHistory{} - histories[siteID] = h + e.histories[siteID] = h } h.TotalChecks++ @@ -70,15 +61,13 @@ func RecordCheck(siteID int, latency time.Duration, isUp bool) { h.Statuses = h.Statuses[len(h.Statuses)-maxHistoryLen:] } - if db != nil { - go func() { _ = db.SaveCheck(siteID, latency.Nanoseconds(), isUp) }() - } + go func() { _ = e.db.SaveCheck(siteID, latency.Nanoseconds(), isUp) }() } -func GetHistory(siteID int) (SiteHistory, bool) { - historyMu.RLock() - defer historyMu.RUnlock() - h, ok := histories[siteID] +func (e *Engine) GetHistory(siteID int) (SiteHistory, bool) { + e.histMu.RLock() + defer e.histMu.RUnlock() + h, ok := e.histories[siteID] if !ok { return SiteHistory{}, false } @@ -93,8 +82,8 @@ func GetHistory(siteID int) (SiteHistory, bool) { return cp, true } -func RemoveHistory(siteID int) { - historyMu.Lock() - defer historyMu.Unlock() - delete(histories, siteID) +func (e *Engine) removeHistory(siteID int) { + e.histMu.Lock() + defer e.histMu.Unlock() + delete(e.histories, siteID) } diff --git a/internal/monitor/monitor.go b/internal/monitor/monitor.go index db11caf..2d2af10 100644 --- a/internal/monitor/monitor.go +++ b/internal/monitor/monitor.go @@ -18,207 +18,271 @@ import ( probing "github.com/prometheus-community/pro-bing" ) -// --- LOGGING --- -var ( - LogStore []string - LogMutex sync.RWMutex -) +type Engine struct { + mu sync.RWMutex + liveState map[int]models.Site -func AddLog(msg string) { - LogMutex.Lock() - defer LogMutex.Unlock() - ts := time.Now().Format("15:04:05") - entry := fmt.Sprintf("[%s] %s", ts, msg) - LogStore = append([]string{entry}, LogStore...) - if len(LogStore) > 100 { - LogStore = LogStore[:100] + logMu sync.RWMutex + logStore []string + + activeMu sync.RWMutex + isActive bool + + histMu sync.RWMutex + histories map[int]*SiteHistory + + tokenIndex map[string]int + + db store.Store + insecureSkipVerify bool + strictClient *http.Client + insecureClient *http.Client +} + +func NewEngine(s store.Store) *Engine { + return &Engine{ + liveState: make(map[int]models.Site), + histories: make(map[int]*SiteHistory), + tokenIndex: make(map[string]int), + isActive: true, + db: s, + strictClient: &http.Client{ + Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: false}}, + }, + insecureClient: &http.Client{ + Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, + }, } } -func GetLogs() []string { - LogMutex.RLock() - defer LogMutex.RUnlock() - logs := make([]string, len(LogStore)) - copy(logs, LogStore) +func (e *Engine) SetInsecureSkipVerify(skip bool) { + e.insecureSkipVerify = skip +} + +func (e *Engine) AddLog(msg string) { + e.logMu.Lock() + defer e.logMu.Unlock() + ts := time.Now().Format("15:04:05") + entry := fmt.Sprintf("[%s] %s", ts, msg) + e.logStore = append([]string{entry}, e.logStore...) + if len(e.logStore) > 100 { + e.logStore = e.logStore[:100] + } +} + +func (e *Engine) GetLogs() []string { + e.logMu.RLock() + defer e.logMu.RUnlock() + logs := make([]string, len(e.logStore)) + copy(logs, e.logStore) return logs } -// --- ENGINE --- - -var ( - LiveState = make(map[int]models.Site) - Mutex sync.RWMutex - - // Global Switch for HA - isActive = true - activeMutex sync.RWMutex - - insecureSkipVerify bool - - db store.Store - - strictClient = &http.Client{ - Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: false}}, - } - insecureClient = &http.Client{ - Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, - } -) - -func SetInsecureSkipVerify(skip bool) { - insecureSkipVerify = skip -} - -func SetEngineActive(active bool) { - activeMutex.Lock() - defer activeMutex.Unlock() - if isActive != active { - isActive = active +func (e *Engine) SetActive(active bool) { + e.activeMu.Lock() + defer e.activeMu.Unlock() + if e.isActive != active { + e.isActive = active status := "RESUMED (Active)" if !active { status = "PAUSED (Passive)" } - AddLog(fmt.Sprintf("Engine %s", status)) + e.AddLog(fmt.Sprintf("Engine %s", status)) } } -func IsEngineActive() bool { - activeMutex.RLock() - defer activeMutex.RUnlock() - return isActive +func (e *Engine) IsActive() bool { + e.activeMu.RLock() + defer e.activeMu.RUnlock() + return e.isActive } -func RecordHeartbeat(token string) bool { - if !IsEngineActive() { - return false - } // Only Leader accepts Push - - Mutex.Lock() - defer Mutex.Unlock() - var targetID int = -1 - for id, s := range LiveState { - if s.Type == "push" && s.Token == token { - targetID = id - break - } +func (e *Engine) GetAllSites() []models.Site { + e.mu.RLock() + defer e.mu.RUnlock() + sites := make([]models.Site, 0, len(e.liveState)) + for _, s := range e.liveState { + sites = append(sites, s) } - if targetID == -1 { + return sites +} + +func (e *Engine) GetLiveState() map[int]models.Site { + e.mu.RLock() + defer e.mu.RUnlock() + cp := make(map[int]models.Site, len(e.liveState)) + for k, v := range e.liveState { + cp[k] = v + } + return cp +} + +func (e *Engine) RecordHeartbeat(token string) bool { + if !e.IsActive() { + return false + } + + e.mu.Lock() + defer e.mu.Unlock() + + targetID, ok := e.tokenIndex[token] + if !ok { + return false + } + + site, exists := e.liveState[targetID] + if !exists { return false } - site := LiveState[targetID] site.LastCheck = time.Now() wasDown := site.Status == "DOWN" site.Status = "UP" site.FailureCount = 0 site.Latency = 0 - LiveState[targetID] = site + e.liveState[targetID] = site if wasDown { - AddLog(fmt.Sprintf("Push Monitor '%s' recovered", site.Name)) - triggerAlert(site.AlertID, "✅ RECOVERY", fmt.Sprintf("Push Monitor '%s' is receiving heartbeats.", site.Name)) + e.AddLog(fmt.Sprintf("Push Monitor '%s' recovered", site.Name)) + e.triggerAlert(site.AlertID, "✅ RECOVERY", fmt.Sprintf("Push Monitor '%s' is receiving heartbeats.", site.Name)) } return true } -func StartEngine(s store.Store) { - db = s +func (e *Engine) addToTokenIndex(site models.Site) { + if site.Type == "push" && site.Token != "" { + e.tokenIndex[site.Token] = site.ID + } +} + +func (e *Engine) removeFromTokenIndex(id int) { + for token, sid := range e.tokenIndex { + if sid == id { + delete(e.tokenIndex, token) + return + } + } +} + +func (e *Engine) Start(ctx context.Context) { go func() { for { - sites, err := db.GetSites() + select { + case <-ctx.Done(): + return + default: + } + + sites, err := e.db.GetSites() if err != nil { - AddLog(fmt.Sprintf("Failed to load sites: %v", err)) - time.Sleep(5 * time.Second) + e.AddLog(fmt.Sprintf("Failed to load sites: %v", err)) + select { + case <-time.After(5 * time.Second): + case <-ctx.Done(): + return + } continue } for _, s := range sites { - Mutex.RLock() - _, exists := LiveState[s.ID] - Mutex.RUnlock() + e.mu.RLock() + _, exists := e.liveState[s.ID] + e.mu.RUnlock() if !exists { - Mutex.Lock() + e.mu.Lock() s.Status = "PENDING" if s.Type == "push" { s.LastCheck = time.Now() } - LiveState[s.ID] = s - Mutex.Unlock() - go monitorRoutine(s.ID) + e.liveState[s.ID] = s + e.addToTokenIndex(s) + e.mu.Unlock() + go e.monitorRoutine(ctx, s.ID) } } - time.Sleep(5 * time.Second) + + select { + case <-time.After(5 * time.Second): + case <-ctx.Done(): + return + } } }() } -func UpdateSiteConfig(site models.Site) { - Mutex.Lock() - defer Mutex.Unlock() - if s, ok := LiveState[site.ID]; ok { - s.Name = site.Name - s.URL = site.URL - s.Type = site.Type - s.Interval = site.Interval - s.AlertID = site.AlertID - s.CheckSSL = site.CheckSSL - s.ExpiryThreshold = site.ExpiryThreshold - s.MaxRetries = site.MaxRetries - s.Hostname = site.Hostname - s.Port = site.Port - s.Timeout = site.Timeout - s.Method = site.Method - s.Description = site.Description - s.ParentID = site.ParentID - s.AcceptedCodes = site.AcceptedCodes - s.DNSResolveType = site.DNSResolveType - s.DNSServer = site.DNSServer - s.IgnoreTLS = site.IgnoreTLS - s.Paused = site.Paused - LiveState[site.ID] = s +func (e *Engine) UpdateSiteConfig(site models.Site) { + e.mu.Lock() + defer e.mu.Unlock() + if existing, ok := e.liveState[site.ID]; ok { + e.removeFromTokenIndex(site.ID) + site.Status = existing.Status + site.StatusCode = existing.StatusCode + site.Latency = existing.Latency + site.CertExpiry = existing.CertExpiry + site.HasSSL = existing.HasSSL + site.LastCheck = existing.LastCheck + site.SentSSLWarning = existing.SentSSLWarning + site.FailureCount = existing.FailureCount + e.liveState[site.ID] = site + e.addToTokenIndex(site) } } -func RemoveSite(id int) { - Mutex.Lock() - delete(LiveState, id) - Mutex.Unlock() - RemoveHistory(id) +func (e *Engine) RemoveSite(id int) { + e.mu.Lock() + e.removeFromTokenIndex(id) + delete(e.liveState, id) + e.mu.Unlock() + e.removeHistory(id) } -func ToggleSitePause(id int) bool { - Mutex.Lock() - defer Mutex.Unlock() - site, ok := LiveState[id] +func (e *Engine) ToggleSitePause(id int) bool { + e.mu.Lock() + defer e.mu.Unlock() + site, ok := e.liveState[id] if !ok { return false } site.Paused = !site.Paused - LiveState[id] = site + e.liveState[id] = site if site.Paused { - AddLog(fmt.Sprintf("Monitor '%s' paused", site.Name)) + e.AddLog(fmt.Sprintf("Monitor '%s' paused", site.Name)) } else { - AddLog(fmt.Sprintf("Monitor '%s' resumed", site.Name)) + e.AddLog(fmt.Sprintf("Monitor '%s' resumed", site.Name)) } return site.Paused } -func monitorRoutine(id int) { - checkByID(id) +func (e *Engine) monitorRoutine(ctx context.Context, id int) { + e.checkByID(id) for { - if !IsEngineActive() { - time.Sleep(5 * time.Second) + select { + case <-ctx.Done(): + return + default: + } + + if !e.IsActive() { + select { + case <-time.After(5 * time.Second): + case <-ctx.Done(): + return + } continue } - Mutex.RLock() - site, exists := LiveState[id] - Mutex.RUnlock() + e.mu.RLock() + site, exists := e.liveState[id] + e.mu.RUnlock() if !exists { return } if site.Paused { - time.Sleep(5 * time.Second) + select { + case <-time.After(5 * time.Second): + case <-ctx.Done(): + return + } continue } @@ -226,72 +290,52 @@ func monitorRoutine(id int) { if interval < 5 { interval = 5 } - time.Sleep(time.Duration(interval) * time.Second) - checkByID(id) + select { + case <-time.After(time.Duration(interval) * time.Second): + case <-ctx.Done(): + return + } + e.checkByID(id) } } -func checkByID(id int) { - if !IsEngineActive() { +func (e *Engine) checkByID(id int) { + if !e.IsActive() { return } - Mutex.RLock() - site, exists := LiveState[id] - Mutex.RUnlock() + e.mu.RLock() + site, exists := e.liveState[id] + e.mu.RUnlock() if !exists || site.Paused { return } switch site.Type { case "http": - checkHTTP(site) + e.checkHTTP(site) case "push": - checkPush(site) + e.checkPush(site) case "ping": - checkPing(site) + e.checkPing(site) case "port": - checkPort(site) + e.checkPort(site) case "dns": - checkDNS(site) + e.checkDNS(site) case "group": - checkGroup(site) + e.checkGroup(site) } } -func checkPush(site models.Site) { +func (e *Engine) checkPush(site models.Site) { deadline := site.LastCheck.Add(time.Duration(site.Interval) * time.Second).Add(5 * time.Second) if time.Now().After(deadline) { - handleStatusChange(site, "DOWN", 0, 0) - } else { - if site.Status != "UP" { - handleStatusChange(site, "UP", 200, 0) - } + e.handleStatusChange(site, "DOWN", 0, 0) + } else if site.Status != "UP" { + e.handleStatusChange(site, "UP", 200, 0) } } -func isCodeAccepted(code int, accepted string) bool { - if accepted == "" { - return code >= 200 && code < 300 - } - for _, part := range strings.Split(accepted, ",") { - part = strings.TrimSpace(part) - if strings.Contains(part, "-") { - bounds := strings.SplitN(part, "-", 2) - lo, err1 := strconv.Atoi(strings.TrimSpace(bounds[0])) - hi, err2 := strconv.Atoi(strings.TrimSpace(bounds[1])) - if err1 == nil && err2 == nil && code >= lo && code <= hi { - return true - } - } else { - if v, err := strconv.Atoi(part); err == nil && code == v { - return true - } - } - } - return false -} - -func checkHTTP(site models.Site) { +func (e *Engine) checkHTTP(site models.Site) { method := site.Method if method == "" { method = "GET" @@ -303,13 +347,13 @@ func checkHTTP(site models.Site) { req, err := http.NewRequestWithContext(ctx, method, site.URL, nil) if err != nil { - handleStatusChange(site, "DOWN", 0, 0) + e.handleStatusChange(site, "DOWN", 0, 0) return } - client := strictClient - if insecureSkipVerify || site.IgnoreTLS { - client = insecureClient + client := e.strictClient + if e.insecureSkipVerify || site.IgnoreTLS { + client = e.insecureClient } start := time.Now() @@ -343,12 +387,11 @@ func checkHTTP(site models.Site) { updatedSite.CertExpiry = certExpiry updatedSite.Latency = latency updatedSite.LastCheck = time.Now() - handleStatusChange(updatedSite, rawStatus, rawCode, latency) + e.handleStatusChange(updatedSite, rawStatus, rawCode, latency) } -func handleStatusChange(site models.Site, rawStatus string, code int, latency time.Duration) { - // Double check we are still leader before alerting - if !IsEngineActive() { +func (e *Engine) handleStatusChange(site models.Site, rawStatus string, code int, latency time.Duration) { + if !e.IsActive() { return } @@ -360,9 +403,9 @@ func handleStatusChange(site models.Site, rawStatus string, code int, latency ti if newState.FailureCount > site.MaxRetries { newState.Status = rawStatus newState.FailureCount = site.MaxRetries + 1 - AddLog(fmt.Sprintf("Monitor '%s' confirmed DOWN", site.Name)) + e.AddLog(fmt.Sprintf("Monitor '%s' confirmed DOWN", site.Name)) } else { - AddLog(fmt.Sprintf("Monitor '%s' failed check %d/%d", site.Name, newState.FailureCount, site.MaxRetries)) + e.AddLog(fmt.Sprintf("Monitor '%s' failed check %d/%d", site.Name, newState.FailureCount, site.MaxRetries)) } } else if rawStatus == "UP" { newState.FailureCount = 0 @@ -375,20 +418,20 @@ func handleStatusChange(site models.Site, rawStatus string, code int, latency ti if site.Type == "http" && site.CheckSSL && site.HasSSL { daysLeft := int(time.Until(site.CertExpiry).Hours() / 24) if daysLeft <= site.ExpiryThreshold && !site.SentSSLWarning && rawStatus != "SSL EXP" { - triggerAlert(site.AlertID, "SSL WARNING", fmt.Sprintf("SSL for '%s' expires in %d days", site.Name, daysLeft)) + e.triggerAlert(site.AlertID, "SSL WARNING", fmt.Sprintf("SSL for '%s' expires in %d days", site.Name, daysLeft)) newState.SentSSLWarning = true } else if daysLeft > site.ExpiryThreshold { newState.SentSSLWarning = false } } - Mutex.Lock() - if _, ok := LiveState[site.ID]; ok { - LiveState[site.ID] = newState + e.mu.Lock() + if _, ok := e.liveState[site.ID]; ok { + e.liveState[site.ID] = newState } - Mutex.Unlock() + e.mu.Unlock() - RecordCheck(site.ID, latency, rawStatus == "UP") + e.recordCheck(site.ID, latency, rawStatus == "UP") isBroken := func(s string) bool { return s == "DOWN" || s == "SSL EXP" } if !isBroken(site.Status) && isBroken(newState.Status) && newState.Status != "PENDING" { @@ -396,24 +439,26 @@ func handleStatusChange(site models.Site, rawStatus string, code int, latency ti if site.Type == "push" { msg = fmt.Sprintf("Push Monitor '%s' missed heartbeat.", site.Name) } - triggerAlert(site.AlertID, "🚨 ALERT", msg) + e.triggerAlert(site.AlertID, "🚨 ALERT", msg) } if isBroken(site.Status) && newState.Status == "UP" { - triggerAlert(site.AlertID, "✅ RECOVERY", fmt.Sprintf("Monitor '%s' is UP", site.Name)) + e.triggerAlert(site.AlertID, "✅ RECOVERY", fmt.Sprintf("Monitor '%s' is UP", site.Name)) } } -func triggerAlert(alertID int, title, message string) { - if db == nil { - return - } - cfg, err := db.GetAlert(alertID) +func (e *Engine) triggerAlert(alertID int, title, message string) { + cfg, err := e.db.GetAlert(alertID) if err != nil { return } provider := alert.GetProvider(cfg) if provider != nil { - go func() { provider.Send(title, message) }() + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + _ = ctx + _ = provider.Send(title, message) + }() } } @@ -424,7 +469,29 @@ func siteTimeout(site models.Site) time.Duration { return 5 * time.Second } -func checkPing(site models.Site) { +func isCodeAccepted(code int, accepted string) bool { + if accepted == "" { + return code >= 200 && code < 300 + } + for _, part := range strings.Split(accepted, ",") { + part = strings.TrimSpace(part) + if strings.Contains(part, "-") { + bounds := strings.SplitN(part, "-", 2) + lo, err1 := strconv.Atoi(strings.TrimSpace(bounds[0])) + hi, err2 := strconv.Atoi(strings.TrimSpace(bounds[1])) + if err1 == nil && err2 == nil && code >= lo && code <= hi { + return true + } + } else { + if v, err := strconv.Atoi(part); err == nil && code == v { + return true + } + } + } + return false +} + +func (e *Engine) checkPing(site models.Site) { host := site.Hostname if host == "" { host = site.URL @@ -432,8 +499,8 @@ func checkPing(site models.Site) { pinger, err := probing.NewPinger(host) if err != nil { - handleStatusChange(site, "DOWN", 0, 0) - AddLog(fmt.Sprintf("Ping '%s' resolve failed: %v", site.Name, err)) + e.handleStatusChange(site, "DOWN", 0, 0) + e.AddLog(fmt.Sprintf("Ping '%s' resolve failed: %v", site.Name, err)) return } pinger.Count = 1 @@ -448,7 +515,7 @@ func checkPing(site models.Site) { updatedSite := site updatedSite.Latency = latency updatedSite.LastCheck = time.Now() - handleStatusChange(updatedSite, "DOWN", 0, latency) + e.handleStatusChange(updatedSite, "DOWN", 0, latency) return } @@ -456,10 +523,10 @@ func checkPing(site models.Site) { updatedSite := site updatedSite.Latency = stats.AvgRtt updatedSite.LastCheck = time.Now() - handleStatusChange(updatedSite, "UP", 0, stats.AvgRtt) + e.handleStatusChange(updatedSite, "UP", 0, stats.AvgRtt) } -func checkPort(site models.Site) { +func (e *Engine) checkPort(site models.Site) { host := site.Hostname if host == "" { host = site.URL @@ -476,19 +543,19 @@ func checkPort(site models.Site) { updatedSite.LastCheck = time.Now() if err != nil { - handleStatusChange(updatedSite, "DOWN", 0, latency) + e.handleStatusChange(updatedSite, "DOWN", 0, latency) return } conn.Close() - handleStatusChange(updatedSite, "UP", 0, latency) + e.handleStatusChange(updatedSite, "UP", 0, latency) } -func checkGroup(site models.Site) { - Mutex.RLock() +func (e *Engine) checkGroup(site models.Site) { + e.mu.RLock() status := "UP" hasChildren := false allPaused := true - for _, child := range LiveState { + for _, child := range e.liveState { if child.ParentID != site.ID || child.Type == "group" { continue } @@ -505,23 +572,23 @@ func checkGroup(site models.Site) { status = "PENDING" } } - Mutex.RUnlock() + e.mu.RUnlock() if !hasChildren { status = "PENDING" } - Mutex.Lock() - s := LiveState[site.ID] + e.mu.Lock() + s := e.liveState[site.ID] s.Status = status if hasChildren && allPaused { s.Paused = true } - LiveState[site.ID] = s - Mutex.Unlock() + e.liveState[site.ID] = s + e.mu.Unlock() } -func checkDNS(site models.Site) { +func (e *Engine) checkDNS(site models.Site) { host := site.Hostname if host == "" { host = site.URL @@ -562,8 +629,7 @@ func checkDNS(site models.Site) { c.Timeout = siteTimeout(site) start := time.Now() - r, rtt, err := c.Exchange(m, server) - _ = rtt + r, _, err := c.Exchange(m, server) latency := time.Since(start) updatedSite := site @@ -571,14 +637,14 @@ func checkDNS(site models.Site) { updatedSite.LastCheck = time.Now() if err != nil { - handleStatusChange(updatedSite, "DOWN", 0, latency) + e.handleStatusChange(updatedSite, "DOWN", 0, latency) return } if r.Rcode != dns.RcodeSuccess { - handleStatusChange(updatedSite, "DOWN", r.Rcode, latency) + e.handleStatusChange(updatedSite, "DOWN", r.Rcode, latency) return } - handleStatusChange(updatedSite, "UP", 0, latency) + e.handleStatusChange(updatedSite, "UP", 0, latency) } diff --git a/internal/server/server.go b/internal/server/server.go index b6a2b52..ac26bd2 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -148,7 +148,7 @@ type ServerConfig struct { ClusterKey string // Shared Secret for Security } -func Start(cfg ServerConfig, s store.Store) { +func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) { if cfg.ClusterKey == "" { fmt.Println("WARNING: No UPKEEP_CLUSTER_SECRET set. Cluster API endpoints are unauthenticated.") } @@ -161,7 +161,7 @@ func Start(cfg ServerConfig, s store.Store) { http.Error(w, "Missing token", 400) return } - if monitor.RecordHeartbeat(token) { + if eng.RecordHeartbeat(token) { w.WriteHeader(http.StatusOK) w.Write([]byte("OK")) } else { @@ -244,12 +244,10 @@ func Start(cfg ServerConfig, s store.Store) { // 6. Status Page if cfg.EnableStatus { - mux.HandleFunc("/status", func(w http.ResponseWriter, r *http.Request) { renderStatusPage(w, cfg.Title) }) + mux.HandleFunc("/status", func(w http.ResponseWriter, r *http.Request) { renderStatusPage(w, cfg.Title, eng) }) mux.HandleFunc("/status/json", func(w http.ResponseWriter, r *http.Request) { - monitor.Mutex.RLock() - defer monitor.Mutex.RUnlock() w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(monitor.LiveState) + json.NewEncoder(w).Encode(eng.GetLiveState()) }) } @@ -262,13 +260,8 @@ func Start(cfg ServerConfig, s store.Store) { }() } -func renderStatusPage(w http.ResponseWriter, title string) { - monitor.Mutex.RLock() - var sites []models.Site - for _, s := range monitor.LiveState { - sites = append(sites, s) - } - monitor.Mutex.RUnlock() +func renderStatusPage(w http.ResponseWriter, title string, eng *monitor.Engine) { + sites := eng.GetAllSites() sort.Slice(sites, func(i, j int) bool { if sites[i].Status != sites[j].Status { diff --git a/internal/store/sqlstore_test.go b/internal/store/sqlstore_test.go new file mode 100644 index 0000000..7ae2bd9 --- /dev/null +++ b/internal/store/sqlstore_test.go @@ -0,0 +1,231 @@ +package store + +import ( + "go-upkeep/internal/models" + "testing" +) + +func newTestStore(t *testing.T) *SQLStore { + t.Helper() + s, err := NewSQLiteStore(":memory:") + if err != nil { + t.Fatalf("NewSQLiteStore: %v", err) + } + if err := s.Init(); err != nil { + t.Fatalf("Init: %v", err) + } + return s +} + +func TestSiteCRUD(t *testing.T) { + s := newTestStore(t) + + sites, err := s.GetSites() + if err != nil { + t.Fatalf("GetSites: %v", err) + } + if len(sites) != 0 { + t.Fatalf("expected 0 sites, got %d", len(sites)) + } + + if err := s.AddSite(models.Site{Name: "Test", URL: "https://example.com", Type: "http", Interval: 30}); err != nil { + t.Fatalf("AddSite: %v", err) + } + + sites, err = s.GetSites() + if err != nil { + t.Fatalf("GetSites: %v", err) + } + if len(sites) != 1 { + t.Fatalf("expected 1 site, got %d", len(sites)) + } + if sites[0].Name != "Test" { + t.Errorf("expected name 'Test', got '%s'", sites[0].Name) + } + + sites[0].Name = "Updated" + if err := s.UpdateSite(sites[0]); err != nil { + t.Fatalf("UpdateSite: %v", err) + } + + sites, _ = s.GetSites() + if sites[0].Name != "Updated" { + t.Errorf("expected name 'Updated', got '%s'", sites[0].Name) + } + + if err := s.DeleteSite(sites[0].ID); err != nil { + t.Fatalf("DeleteSite: %v", err) + } + + sites, _ = s.GetSites() + if len(sites) != 0 { + t.Fatalf("expected 0 sites after delete, got %d", len(sites)) + } +} + +func TestAlertCRUD(t *testing.T) { + s := newTestStore(t) + + if err := s.AddAlert("Discord", "discord", map[string]string{"url": "https://example.com/hook"}); err != nil { + t.Fatalf("AddAlert: %v", err) + } + + alerts, err := s.GetAllAlerts() + if err != nil { + t.Fatalf("GetAllAlerts: %v", err) + } + if len(alerts) != 1 { + t.Fatalf("expected 1 alert, got %d", len(alerts)) + } + if alerts[0].Type != "discord" { + t.Errorf("expected type 'discord', got '%s'", alerts[0].Type) + } + if alerts[0].Settings["url"] != "https://example.com/hook" { + t.Errorf("settings url mismatch") + } + + a, err := s.GetAlert(alerts[0].ID) + if err != nil { + t.Fatalf("GetAlert: %v", err) + } + if a.Name != "Discord" { + t.Errorf("expected name 'Discord', got '%s'", a.Name) + } + + if err := s.UpdateAlert(a.ID, "Slack", "slack", map[string]string{"url": "https://slack.com/hook"}); err != nil { + t.Fatalf("UpdateAlert: %v", err) + } + + a, _ = s.GetAlert(a.ID) + if a.Type != "slack" { + t.Errorf("expected type 'slack', got '%s'", a.Type) + } + + if err := s.DeleteAlert(a.ID); err != nil { + t.Fatalf("DeleteAlert: %v", err) + } + + alerts, _ = s.GetAllAlerts() + if len(alerts) != 0 { + t.Fatalf("expected 0 alerts after delete, got %d", len(alerts)) + } +} + +func TestUserCRUD(t *testing.T) { + s := newTestStore(t) + + if err := s.AddUser("admin", "ssh-ed25519 AAAA...", "admin"); err != nil { + t.Fatalf("AddUser: %v", err) + } + + users, err := s.GetAllUsers() + if err != nil { + t.Fatalf("GetAllUsers: %v", err) + } + if len(users) != 1 { + t.Fatalf("expected 1 user, got %d", len(users)) + } + if users[0].Username != "admin" { + t.Errorf("expected username 'admin', got '%s'", users[0].Username) + } + + if err := s.UpdateUser(users[0].ID, "root", "ssh-ed25519 BBBB...", "admin"); err != nil { + t.Fatalf("UpdateUser: %v", err) + } + + users, _ = s.GetAllUsers() + if users[0].Username != "root" { + t.Errorf("expected username 'root', got '%s'", users[0].Username) + } + + if err := s.DeleteUser(users[0].ID); err != nil { + t.Fatalf("DeleteUser: %v", err) + } + + users, _ = s.GetAllUsers() + if len(users) != 0 { + t.Fatalf("expected 0 users after delete, got %d", len(users)) + } +} + +func TestPushTokenGeneration(t *testing.T) { + s := newTestStore(t) + + if err := s.AddSite(models.Site{Name: "Push Monitor", Type: "push", Interval: 60}); err != nil { + t.Fatalf("AddSite: %v", err) + } + + sites, _ := s.GetSites() + if len(sites) != 1 { + t.Fatalf("expected 1 site, got %d", len(sites)) + } + if sites[0].Token == "" { + t.Error("expected non-empty token for push monitor") + } + if len(sites[0].Token) != 32 { + t.Errorf("expected 32-char hex token, got %d chars", len(sites[0].Token)) + } +} + +func TestImportExport(t *testing.T) { + s := newTestStore(t) + + s.AddAlert("Test Alert", "webhook", map[string]string{"url": "https://example.com"}) + s.AddSite(models.Site{Name: "Site1", URL: "https://example.com", Type: "http", Interval: 30}) + s.AddUser("user1", "ssh-ed25519 KEY", "user") + + backup, err := s.ExportData() + if err != nil { + t.Fatalf("ExportData: %v", err) + } + if len(backup.Sites) != 1 || len(backup.Alerts) != 1 || len(backup.Users) != 1 { + t.Fatalf("export mismatch: %d sites, %d alerts, %d users", len(backup.Sites), len(backup.Alerts), len(backup.Users)) + } + + s2 := newTestStore(t) + if err := s2.ImportData(backup); err != nil { + t.Fatalf("ImportData: %v", err) + } + + sites, _ := s2.GetSites() + alerts, _ := s2.GetAllAlerts() + users, _ := s2.GetAllUsers() + if len(sites) != 1 || len(alerts) != 1 || len(users) != 1 { + t.Fatalf("import mismatch: %d sites, %d alerts, %d users", len(sites), len(alerts), len(users)) + } +} + +func TestCheckHistory(t *testing.T) { + s := newTestStore(t) + + if err := s.SaveCheck(1, 5000000, true); err != nil { + t.Fatalf("SaveCheck: %v", err) + } + if err := s.SaveCheck(1, 10000000, false); err != nil { + t.Fatalf("SaveCheck: %v", err) + } + if err := s.SaveCheck(2, 3000000, true); err != nil { + t.Fatalf("SaveCheck site 2: %v", err) + } + + history, err := s.LoadAllHistory(10) + if err != nil { + t.Fatalf("LoadAllHistory: %v", err) + } + if len(history[1]) != 2 { + t.Fatalf("expected 2 records for site 1, got %d", len(history[1])) + } + if len(history[2]) != 1 { + t.Fatalf("expected 1 record for site 2, got %d", len(history[2])) + } + + upCount := 0 + for _, r := range history[1] { + if r.IsUp { + upCount++ + } + } + if upCount != 1 { + t.Errorf("expected 1 up record for site 1, got %d", upCount) + } +} diff --git a/internal/tui/tab_alerts.go b/internal/tui/tab_alerts.go index 60b1890..11c9bf6 100644 --- a/internal/tui/tab_alerts.go +++ b/internal/tui/tab_alerts.go @@ -2,7 +2,6 @@ package tui import ( "fmt" - "go-upkeep/internal/monitor" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/huh" @@ -237,11 +236,11 @@ func (m *Model) submitAlertForm() { if m.editID > 0 { if err := m.store.UpdateAlert(m.editID, d.Name, d.AlertType, settings); err != nil { - monitor.AddLog("Update alert failed: " + err.Error()) + m.engine.AddLog("Update alert failed: " + err.Error()) } } else { if err := m.store.AddAlert(d.Name, d.AlertType, settings); err != nil { - monitor.AddLog("Add alert failed: " + err.Error()) + m.engine.AddLog("Add alert failed: " + err.Error()) } } m.state = stateDashboard diff --git a/internal/tui/tab_sites.go b/internal/tui/tab_sites.go index e5ebbfe..4b6ad00 100644 --- a/internal/tui/tab_sites.go +++ b/internal/tui/tab_sites.go @@ -3,7 +3,6 @@ package tui import ( "fmt" "go-upkeep/internal/models" - "go-upkeep/internal/monitor" "net/url" "strconv" "strings" @@ -243,7 +242,7 @@ func (m Model) viewSitesTab() string { name = limitStr(name, 13) } - hist, _ := monitor.GetHistory(site.ID) + hist, _ := m.engine.GetHistory(site.ID) var spark string if site.Type == "push" { spark = heartbeatSparkline(hist.Statuses, sparkWidth) @@ -508,12 +507,12 @@ func (m *Model) submitSiteForm() { if m.editID > 0 { if err := m.store.UpdateSite(site); err != nil { - monitor.AddLog("Update site failed: " + err.Error()) + m.engine.AddLog("Update site failed: " + err.Error()) } - monitor.UpdateSiteConfig(site) + m.engine.UpdateSiteConfig(site) } else { if err := m.store.AddSite(site); err != nil { - monitor.AddLog("Add site failed: " + err.Error()) + m.engine.AddLog("Add site failed: " + err.Error()) } } m.state = stateDashboard diff --git a/internal/tui/tab_users.go b/internal/tui/tab_users.go index 46c5679..019bb03 100644 --- a/internal/tui/tab_users.go +++ b/internal/tui/tab_users.go @@ -2,7 +2,6 @@ package tui import ( "fmt" - "go-upkeep/internal/monitor" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/huh" @@ -104,11 +103,11 @@ func (m *Model) submitUserForm() { d := m.userFormData if m.editID > 0 { if err := m.store.UpdateUser(m.editID, d.Username, d.PublicKey, d.Role); err != nil { - monitor.AddLog("Update user failed: " + err.Error()) + m.engine.AddLog("Update user failed: " + err.Error()) } } else { if err := m.store.AddUser(d.Username, d.PublicKey, d.Role); err != nil { - monitor.AddLog("Add user failed: " + err.Error()) + m.engine.AddLog("Add user failed: " + err.Error()) } } m.state = stateUsers diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 5e0e7b6..89846a5 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -69,6 +69,7 @@ type Model struct { collapsed map[int]bool store store.Store + engine *monitor.Engine // harmonica animation state pulseSpring harmonica.Spring @@ -81,7 +82,7 @@ type Model struct { users []models.User } -func InitialModel(isAdmin bool, s store.Store) Model { +func InitialModel(isAdmin bool, s store.Store, eng *monitor.Engine) Model { vpLogs := viewport.New(100, 20) vpLogs.SetContent("Waiting for logs...") z := zone.New() @@ -92,6 +93,7 @@ func InitialModel(isAdmin bool, s store.Store) Model { maxTableRows: 5, isAdmin: isAdmin, store: s, + engine: eng, zones: z, pulseSpring: spring, collapsed: make(map[int]bool), @@ -112,18 +114,18 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch m.deleteTab { case 0: if err := m.store.DeleteSite(m.deleteID); err != nil { - monitor.AddLog("Delete site failed: " + err.Error()) + m.engine.AddLog("Delete site failed: " + err.Error()) } - monitor.RemoveSite(m.deleteID) + m.engine.RemoveSite(m.deleteID) m.adjustCursor(len(m.sites) - 1) case 1: if err := m.store.DeleteAlert(m.deleteID); err != nil { - monitor.AddLog("Delete alert failed: " + err.Error()) + m.engine.AddLog("Delete alert failed: " + err.Error()) } m.adjustCursor(len(m.alerts) - 1) case 3: if err := m.store.DeleteUser(m.deleteID); err != nil { - monitor.AddLog("Delete user failed: " + err.Error()) + m.engine.AddLog("Delete user failed: " + err.Error()) } m.adjustCursor(len(m.users) - 1) } @@ -317,7 +319,7 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case "p": if m.currentTab == 0 && len(m.sites) > 0 { site := m.sites[m.cursor] - monitor.ToggleSitePause(site.ID) + m.engine.ToggleSitePause(site.ID) site.Paused = !site.Paused _ = m.store.UpdateSitePaused(site.ID, site.Paused) m.refreshData() @@ -433,12 +435,7 @@ func (m *Model) adjustCursor(newLen int) { } func (m *Model) refreshData() { - monitor.Mutex.RLock() - var allSites []models.Site - for _, s := range monitor.LiveState { - allSites = append(allSites, s) - } - monitor.Mutex.RUnlock() + allSites := m.engine.GetAllSites() var groups, ungrouped []models.Site children := make(map[int][]models.Site) @@ -476,7 +473,7 @@ func (m *Model) refreshData() { m.users = users } } - m.logViewport.SetContent(strings.Join(monitor.GetLogs(), "\n")) + m.logViewport.SetContent(strings.Join(m.engine.GetLogs(), "\n")) listLen := len(m.sites) if m.currentTab == 1 { From 52a54f9c5cc33c8832b3090c23d3d085cd4623a6 Mon Sep 17 00:00:00 2001 From: Tyler Koenig Date: Fri, 15 May 2026 10:53:38 -0400 Subject: [PATCH 8/9] feat(alert): add Telegram, PagerDuty, Pushover, Gotify providers Expand alert provider count from 5 to 9. All new providers use the shared HTTPProvider with closure-based payload functions. Includes TUI form support and tests for each provider. --- internal/alert/alert.go | 76 ++++++++++++++++++++ internal/alert/alert_test.go | 104 +++++++++++++++++++++++++++ internal/tui/tab_alerts.go | 131 ++++++++++++++++++++++++++++++++++- 3 files changed, 308 insertions(+), 3 deletions(-) diff --git a/internal/alert/alert.go b/internal/alert/alert.go index d67a30d..c60013f 100644 --- a/internal/alert/alert.go +++ b/internal/alert/alert.go @@ -7,6 +7,7 @@ import ( "go-upkeep/internal/models" "net/http" "net/smtp" + "strconv" "strings" "time" ) @@ -52,6 +53,52 @@ func webhookPayload(title, message string) ([]byte, error) { return json.Marshal(map[string]string{"title": title, "message": message, "status": "alert"}) } +func telegramPayload(chatID string) PayloadFunc { + return func(title, message string) ([]byte, error) { + return json.Marshal(map[string]string{ + "chat_id": chatID, + "text": fmt.Sprintf("*%s*\n%s", title, message), + "parse_mode": "Markdown", + }) + } +} + +func pagerdutyPayload(routingKey, severity string) PayloadFunc { + return func(title, message string) ([]byte, error) { + return json.Marshal(map[string]any{ + "routing_key": routingKey, + "event_action": "trigger", + "payload": map[string]string{ + "summary": fmt.Sprintf("%s: %s", title, message), + "source": "go-upkeep", + "severity": severity, + }, + }) + } +} + +func pushoverPayload(token, user string) PayloadFunc { + return func(title, message string) ([]byte, error) { + return json.Marshal(map[string]string{ + "token": token, + "user": user, + "title": title, + "message": message, + }) + } +} + +func gotifyPayload(priority string) PayloadFunc { + return func(title, message string) ([]byte, error) { + pri, _ := strconv.Atoi(priority) + return json.Marshal(map[string]any{ + "title": title, + "message": message, + "priority": pri, + }) + } +} + func GetProvider(cfg models.AlertConfig) Provider { switch cfg.Type { case "discord": @@ -85,6 +132,35 @@ func GetProvider(cfg models.AlertConfig) Provider { Username: cfg.Settings["username"], Password: cfg.Settings["password"], } + case "telegram": + return &HTTPProvider{ + URL: fmt.Sprintf("https://api.telegram.org/bot%s/sendMessage", cfg.Settings["token"]), + Payload: telegramPayload(cfg.Settings["chat_id"]), + } + case "pagerduty": + severity := "critical" + if s, ok := cfg.Settings["severity"]; ok && s != "" { + severity = s + } + return &HTTPProvider{ + URL: "https://events.pagerduty.com/v2/enqueue", + Payload: pagerdutyPayload(cfg.Settings["routing_key"], severity), + } + case "pushover": + return &HTTPProvider{ + URL: "https://api.pushover.net/1/messages.json", + Payload: pushoverPayload(cfg.Settings["token"], cfg.Settings["user"]), + } + case "gotify": + priority := "5" + if p, ok := cfg.Settings["priority"]; ok && p != "" { + priority = p + } + serverURL := strings.TrimRight(cfg.Settings["url"], "/") + return &HTTPProvider{ + URL: fmt.Sprintf("%s/message?token=%s", serverURL, cfg.Settings["token"]), + Payload: gotifyPayload(priority), + } default: return nil } diff --git a/internal/alert/alert_test.go b/internal/alert/alert_test.go index 348f2c9..35e1c8d 100644 --- a/internal/alert/alert_test.go +++ b/internal/alert/alert_test.go @@ -101,6 +101,110 @@ func TestNtfyProvider(t *testing.T) { } } +func TestHTTPProviderTelegram(t *testing.T) { + var received map[string]string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewDecoder(r.Body).Decode(&received) + w.WriteHeader(200) + })) + defer srv.Close() + + p := &HTTPProvider{URL: srv.URL, Payload: telegramPayload("12345")} + if err := p.Send("Alert", "Down"); err != nil { + t.Fatalf("Send: %v", err) + } + if received["chat_id"] != "12345" { + t.Errorf("expected chat_id '12345', got '%s'", received["chat_id"]) + } + if received["text"] != "*Alert*\nDown" { + t.Errorf("unexpected text: %s", received["text"]) + } + if received["parse_mode"] != "Markdown" { + t.Errorf("expected parse_mode 'Markdown', got '%s'", received["parse_mode"]) + } +} + +func TestHTTPProviderPagerDuty(t *testing.T) { + var received map[string]any + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewDecoder(r.Body).Decode(&received) + w.WriteHeader(200) + })) + defer srv.Close() + + p := &HTTPProvider{URL: srv.URL, Payload: pagerdutyPayload("test-key", "critical")} + if err := p.Send("Alert", "Down"); err != nil { + t.Fatalf("Send: %v", err) + } + if received["routing_key"] != "test-key" { + t.Errorf("expected routing_key 'test-key', got '%v'", received["routing_key"]) + } + if received["event_action"] != "trigger" { + t.Errorf("expected event_action 'trigger', got '%v'", received["event_action"]) + } + payload := received["payload"].(map[string]any) + if payload["summary"] != "Alert: Down" { + t.Errorf("unexpected summary: %v", payload["summary"]) + } + if payload["severity"] != "critical" { + t.Errorf("expected severity 'critical', got '%v'", payload["severity"]) + } +} + +func TestHTTPProviderPushover(t *testing.T) { + var received map[string]string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewDecoder(r.Body).Decode(&received) + w.WriteHeader(200) + })) + defer srv.Close() + + p := &HTTPProvider{URL: srv.URL, Payload: pushoverPayload("app-tok", "user-key")} + if err := p.Send("Alert", "Down"); err != nil { + t.Fatalf("Send: %v", err) + } + if received["token"] != "app-tok" { + t.Errorf("expected token 'app-tok', got '%s'", received["token"]) + } + if received["user"] != "user-key" { + t.Errorf("expected user 'user-key', got '%s'", received["user"]) + } + if received["title"] != "Alert" || received["message"] != "Down" { + t.Errorf("unexpected payload: %v", received) + } +} + +func TestHTTPProviderGotify(t *testing.T) { + var received map[string]any + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewDecoder(r.Body).Decode(&received) + w.WriteHeader(200) + })) + defer srv.Close() + + p := &HTTPProvider{URL: srv.URL, Payload: gotifyPayload("8")} + if err := p.Send("Alert", "Down"); err != nil { + t.Fatalf("Send: %v", err) + } + if received["title"] != "Alert" || received["message"] != "Down" { + t.Errorf("unexpected payload: %v", received) + } + if pri, ok := received["priority"].(float64); !ok || pri != 8 { + t.Errorf("expected priority 8, got %v", received["priority"]) + } +} + +func TestGetProviderNewTypes(t *testing.T) { + for _, typ := range []string{"telegram", "pagerduty", "pushover", "gotify"} { + p := GetProvider(models.AlertConfig{Type: typ, Settings: map[string]string{ + "token": "x", "chat_id": "1", "routing_key": "k", "user": "u", "url": "http://localhost", + }}) + if p == nil { + t.Errorf("GetProvider(%q) returned nil", typ) + } + } +} + func TestGetProviderUnknown(t *testing.T) { p := GetProvider(models.AlertConfig{Type: "unknown"}) if p != nil { diff --git a/internal/tui/tab_alerts.go b/internal/tui/tab_alerts.go index 11c9bf6..342e1bd 100644 --- a/internal/tui/tab_alerts.go +++ b/internal/tui/tab_alerts.go @@ -23,6 +23,19 @@ type alertFormData struct { NtfyUser string NtfyPass string NtfyPri string + // Telegram + TelegramToken string + TelegramChatID string + // PagerDuty + PagerDutyKey string + PagerDutySeverity string + // Pushover + PushoverToken string + PushoverUser string + // Gotify + GotifyURL string + GotifyToken string + GotifyPriority string } func fmtAlertType(t string) string { @@ -37,6 +50,14 @@ func fmtAlertType(t string) string { return lipgloss.NewStyle().Foreground(lipgloss.Color("#73F59F")).Render(t) case "ntfy": return lipgloss.NewStyle().Foreground(lipgloss.Color("#FF6B6B")).Render(t) + case "telegram": + return lipgloss.NewStyle().Foreground(lipgloss.Color("#26A5E4")).Render(t) + case "pagerduty": + return lipgloss.NewStyle().Foreground(lipgloss.Color("#06AC38")).Render(t) + case "pushover": + return lipgloss.NewStyle().Foreground(lipgloss.Color("#249DF1")).Render(t) + case "gotify": + return lipgloss.NewStyle().Foreground(lipgloss.Color("#3F8BBA")).Render(t) default: return t } @@ -64,6 +85,26 @@ func fmtAlertConfig(alert struct { return limitStr(fmt.Sprintf("%s/%s", url, topic), 34) } return subtleStyle.Render("—") + case "telegram": + if id := alert.Settings["chat_id"]; id != "" { + return limitStr(fmt.Sprintf("chat:%s", id), 34) + } + return subtleStyle.Render("—") + case "pagerduty": + if key := alert.Settings["routing_key"]; key != "" { + return limitStr(key, 34) + } + return subtleStyle.Render("—") + case "pushover": + if user := alert.Settings["user"]; user != "" { + return limitStr(fmt.Sprintf("user:%s", user), 34) + } + return subtleStyle.Render("—") + case "gotify": + if url := alert.Settings["url"]; url != "" { + return limitStr(url, 34) + } + return subtleStyle.Render("—") default: if val, ok := alert.Settings["url"]; ok { return limitStr(val, 34) @@ -102,8 +143,10 @@ func (m Model) viewAlertsTab() string { func (m *Model) initAlertHuhForm() tea.Cmd { m.alertFormData = &alertFormData{ - AlertType: "discord", - NtfyPri: "3", + AlertType: "discord", + NtfyPri: "3", + PagerDutySeverity: "critical", + GotifyPriority: "5", } if m.editID > 0 { @@ -128,6 +171,19 @@ func (m *Model) initAlertHuhForm() tea.Cmd { m.alertFormData.NtfyUser = alert.Settings["username"] m.alertFormData.NtfyPass = alert.Settings["password"] m.alertFormData.NtfyPri = alert.Settings["priority"] + case "telegram": + m.alertFormData.TelegramToken = alert.Settings["token"] + m.alertFormData.TelegramChatID = alert.Settings["chat_id"] + case "pagerduty": + m.alertFormData.PagerDutyKey = alert.Settings["routing_key"] + m.alertFormData.PagerDutySeverity = alert.Settings["severity"] + case "pushover": + m.alertFormData.PushoverToken = alert.Settings["token"] + m.alertFormData.PushoverUser = alert.Settings["user"] + case "gotify": + m.alertFormData.GotifyURL = alert.Settings["url"] + m.alertFormData.GotifyToken = alert.Settings["token"] + m.alertFormData.GotifyPriority = alert.Settings["priority"] } break } @@ -152,6 +208,10 @@ func (m *Model) initAlertHuhForm() tea.Cmd { huh.NewOption("Webhook", "webhook"), huh.NewOption("Email (SMTP)", "email"), huh.NewOption("Ntfy", "ntfy"), + huh.NewOption("Telegram", "telegram"), + huh.NewOption("PagerDuty", "pagerduty"), + huh.NewOption("Pushover", "pushover"), + huh.NewOption("Gotify", "gotify"), ).Value(&m.alertFormData.AlertType), ).Title("Alert Config"), huh.NewGroup( @@ -159,7 +219,8 @@ func (m *Model) initAlertHuhForm() tea.Cmd { Placeholder("https://discord.com/api/webhooks/..."). Value(&m.alertFormData.WebhookURL), ).Title("Webhook").WithHideFunc(func() bool { - return m.alertFormData.AlertType == "email" || m.alertFormData.AlertType == "ntfy" + t := m.alertFormData.AlertType + return t != "discord" && t != "slack" && t != "webhook" }), huh.NewGroup( huh.NewInput().Title("Ntfy Server URL"). @@ -207,6 +268,57 @@ func (m *Model) initAlertHuhForm() tea.Cmd { ).Title("Email Settings").WithHideFunc(func() bool { return m.alertFormData.AlertType != "email" }), + huh.NewGroup( + huh.NewInput().Title("Bot Token"). + Placeholder("123456:ABC-DEF1234..."). + Value(&m.alertFormData.TelegramToken), + huh.NewInput().Title("Chat ID"). + Placeholder("-1001234567890"). + Value(&m.alertFormData.TelegramChatID), + ).Title("Telegram Settings").WithHideFunc(func() bool { + return m.alertFormData.AlertType != "telegram" + }), + huh.NewGroup( + huh.NewInput().Title("Routing Key"). + Placeholder("your-integration-routing-key"). + Value(&m.alertFormData.PagerDutyKey), + huh.NewSelect[string]().Title("Severity"). + Options( + huh.NewOption("Critical", "critical"), + huh.NewOption("Error", "error"), + huh.NewOption("Warning", "warning"), + huh.NewOption("Info", "info"), + ).Value(&m.alertFormData.PagerDutySeverity), + ).Title("PagerDuty Settings").WithHideFunc(func() bool { + return m.alertFormData.AlertType != "pagerduty" + }), + huh.NewGroup( + huh.NewInput().Title("App Token"). + Placeholder("your-pushover-app-token"). + Value(&m.alertFormData.PushoverToken), + huh.NewInput().Title("User Key"). + Placeholder("your-pushover-user-key"). + Value(&m.alertFormData.PushoverUser), + ).Title("Pushover Settings").WithHideFunc(func() bool { + return m.alertFormData.AlertType != "pushover" + }), + huh.NewGroup( + huh.NewInput().Title("Server URL"). + Placeholder("https://gotify.example.com"). + Value(&m.alertFormData.GotifyURL), + huh.NewInput().Title("App Token"). + Placeholder("your-gotify-app-token"). + Value(&m.alertFormData.GotifyToken), + huh.NewSelect[string]().Title("Priority"). + Options( + huh.NewOption("Min (0)", "0"), + huh.NewOption("Low (2)", "2"), + huh.NewOption("Normal (5)", "5"), + huh.NewOption("High (8)", "8"), + ).Value(&m.alertFormData.GotifyPriority), + ).Title("Gotify Settings").WithHideFunc(func() bool { + return m.alertFormData.AlertType != "gotify" + }), ).WithTheme(huh.ThemeDracula()) return m.huhForm.Init() @@ -230,6 +342,19 @@ func (m *Model) submitAlertForm() { settings["priority"] = d.NtfyPri settings["username"] = d.NtfyUser settings["password"] = d.NtfyPass + case "telegram": + settings["token"] = d.TelegramToken + settings["chat_id"] = d.TelegramChatID + case "pagerduty": + settings["routing_key"] = d.PagerDutyKey + settings["severity"] = d.PagerDutySeverity + case "pushover": + settings["token"] = d.PushoverToken + settings["user"] = d.PushoverUser + case "gotify": + settings["url"] = d.GotifyURL + settings["token"] = d.GotifyToken + settings["priority"] = d.GotifyPriority default: settings["url"] = d.WebhookURL } From b7b8aa6f03678587dd0fb804fd1b25763a96872a Mon Sep 17 00:00:00 2001 From: Tyler Koenig Date: Fri, 15 May 2026 11:26:21 -0400 Subject: [PATCH 9/9] feat(metrics): add Prometheus /metrics endpoint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Zero-dependency Prometheus text exposition format. Exposes monitor up/down, latency, status code, check timestamps, pause state, SSL cert expiry, and check counters — all from in-memory state. --- internal/metrics/prometheus.go | 99 +++++++++++++++++++++++++++++ internal/metrics/prometheus_test.go | 96 ++++++++++++++++++++++++++++ internal/server/server.go | 6 +- 3 files changed, 200 insertions(+), 1 deletion(-) create mode 100644 internal/metrics/prometheus.go create mode 100644 internal/metrics/prometheus_test.go diff --git a/internal/metrics/prometheus.go b/internal/metrics/prometheus.go new file mode 100644 index 0000000..24f4faa --- /dev/null +++ b/internal/metrics/prometheus.go @@ -0,0 +1,99 @@ +package metrics + +import ( + "fmt" + "go-upkeep/internal/models" + "go-upkeep/internal/monitor" + "net/http" + "sort" + "strings" +) + +func Handler(eng *monitor.Engine) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + sites := eng.GetAllSites() + sort.Slice(sites, func(i, j int) bool { return sites[i].ID < sites[j].ID }) + + var b strings.Builder + + writeHelp(&b, "upkeep_monitor_up", "gauge", "Whether the monitor is up (1) or down (0).") + for _, s := range sites { + val := 0 + if s.Status == "UP" { + val = 1 + } + writeGauge(&b, "upkeep_monitor_up", labels(s), float64(val)) + } + + writeHelp(&b, "upkeep_monitor_latency_seconds", "gauge", "Last check latency in seconds.") + for _, s := range sites { + writeGauge(&b, "upkeep_monitor_latency_seconds", labels(s), s.Latency.Seconds()) + } + + writeHelp(&b, "upkeep_monitor_status_code", "gauge", "HTTP response status code of the last check.") + for _, s := range sites { + if s.Type != "http" { + continue + } + writeGauge(&b, "upkeep_monitor_status_code", labels(s), float64(s.StatusCode)) + } + + writeHelp(&b, "upkeep_monitor_check_timestamp_seconds", "gauge", "Unix timestamp of the last check.") + for _, s := range sites { + if s.LastCheck.IsZero() { + continue + } + writeGauge(&b, "upkeep_monitor_check_timestamp_seconds", labels(s), float64(s.LastCheck.Unix())) + } + + writeHelp(&b, "upkeep_monitor_paused", "gauge", "Whether the monitor is paused (1) or active (0).") + for _, s := range sites { + val := 0 + if s.Paused { + val = 1 + } + writeGauge(&b, "upkeep_monitor_paused", labels(s), float64(val)) + } + + writeHelp(&b, "upkeep_monitor_cert_expiry_timestamp_seconds", "gauge", "Unix timestamp when the SSL certificate expires.") + for _, s := range sites { + if !s.HasSSL || s.CertExpiry.IsZero() { + continue + } + writeGauge(&b, "upkeep_monitor_cert_expiry_timestamp_seconds", labels(s), float64(s.CertExpiry.Unix())) + } + + writeHelp(&b, "upkeep_monitor_checks_total", "counter", "Total number of checks performed.") + writeHelp(&b, "upkeep_monitor_checks_up_total", "counter", "Total number of successful checks.") + for _, s := range sites { + h, ok := eng.GetHistory(s.ID) + if !ok { + continue + } + writeGauge(&b, "upkeep_monitor_checks_total", labels(s), float64(h.TotalChecks)) + writeGauge(&b, "upkeep_monitor_checks_up_total", labels(s), float64(h.UpChecks)) + } + + w.Header().Set("Content-Type", "text/plain; version=0.0.4; charset=utf-8") + w.Write([]byte(b.String())) + } +} + +func labels(s models.Site) string { + return fmt.Sprintf(`id="%d",name="%s",type="%s"`, s.ID, escapeLabelValue(s.Name), s.Type) +} + +func escapeLabelValue(s string) string { + s = strings.ReplaceAll(s, `\`, `\\`) + s = strings.ReplaceAll(s, `"`, `\"`) + s = strings.ReplaceAll(s, "\n", `\n`) + return s +} + +func writeHelp(b *strings.Builder, name, typ, help string) { + fmt.Fprintf(b, "# HELP %s %s\n# TYPE %s %s\n", name, help, name, typ) +} + +func writeGauge(b *strings.Builder, name, labels string, val float64) { + fmt.Fprintf(b, "%s{%s} %g\n", name, labels, val) +} diff --git a/internal/metrics/prometheus_test.go b/internal/metrics/prometheus_test.go new file mode 100644 index 0000000..7cbf680 --- /dev/null +++ b/internal/metrics/prometheus_test.go @@ -0,0 +1,96 @@ +package metrics + +import ( + "context" + "go-upkeep/internal/models" + "go-upkeep/internal/monitor" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +type mockStore struct { + sites []models.Site +} + +func (m *mockStore) Init() error { return nil } +func (m *mockStore) GetSites() ([]models.Site, error) { return m.sites, nil } +func (m *mockStore) AddSite(models.Site) error { return nil } +func (m *mockStore) UpdateSite(models.Site) error { return nil } +func (m *mockStore) UpdateSitePaused(int, bool) error { return nil } +func (m *mockStore) DeleteSite(int) error { return nil } +func (m *mockStore) GetAllAlerts() ([]models.AlertConfig, error) { return nil, nil } +func (m *mockStore) GetAlert(int) (models.AlertConfig, error) { return models.AlertConfig{}, nil } +func (m *mockStore) AddAlert(string, string, map[string]string) error { return nil } +func (m *mockStore) UpdateAlert(int, string, string, map[string]string) error { return nil } +func (m *mockStore) DeleteAlert(int) error { return nil } +func (m *mockStore) GetAllUsers() ([]models.User, error) { return nil, nil } +func (m *mockStore) AddUser(string, string, string) error { return nil } +func (m *mockStore) UpdateUser(int, string, string, string) error { return nil } +func (m *mockStore) DeleteUser(int) error { return nil } +func (m *mockStore) SaveCheck(int, int64, bool) error { return nil } +func (m *mockStore) LoadAllHistory(int) (map[int][]models.CheckRecord, error) { + return nil, nil +} +func (m *mockStore) ExportData() (models.Backup, error) { return models.Backup{}, nil } +func (m *mockStore) ImportData(models.Backup) error { return nil } + +func TestMetricsHandler(t *testing.T) { + ms := &mockStore{ + sites: []models.Site{ + {ID: 1, Name: "Example", URL: "https://example.com", Type: "http", Interval: 30}, + {ID: 2, Name: "DNS Check", Type: "dns", Interval: 60}, + }, + } + eng := monitor.NewEngine(ms) + ctx, cancel := context.WithCancel(context.Background()) + eng.Start(ctx) + time.Sleep(100 * time.Millisecond) + + rec := httptest.NewRecorder() + Handler(eng)(rec, httptest.NewRequest("GET", "/metrics", nil)) + cancel() + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rec.Code) + } + + body := rec.Body.String() + + ct := rec.Header().Get("Content-Type") + if !strings.Contains(ct, "text/plain") { + t.Errorf("expected text/plain content type, got %q", ct) + } + + expected := []string{ + "# HELP upkeep_monitor_up", + "# TYPE upkeep_monitor_up gauge", + `upkeep_monitor_up{id="1",name="Example",type="http"}`, + `upkeep_monitor_up{id="2",name="DNS Check",type="dns"}`, + "# HELP upkeep_monitor_latency_seconds", + "# HELP upkeep_monitor_paused", + "# HELP upkeep_monitor_checks_total", + } + for _, s := range expected { + if !strings.Contains(body, s) { + t.Errorf("missing expected line: %s", s) + } + } +} + +func TestEscapeLabelValue(t *testing.T) { + cases := []struct{ in, want string }{ + {`simple`, `simple`}, + {`has "quotes"`, `has \"quotes\"`}, + {"has\nnewline", `has\nnewline`}, + {`back\slash`, `back\\slash`}, + } + for _, tc := range cases { + got := escapeLabelValue(tc.in) + if got != tc.want { + t.Errorf("escapeLabelValue(%q) = %q, want %q", tc.in, got, tc.want) + } + } +} diff --git a/internal/server/server.go b/internal/server/server.go index ac26bd2..fdf7f9b 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "go-upkeep/internal/importer" + "go-upkeep/internal/metrics" "go-upkeep/internal/models" "go-upkeep/internal/monitor" "go-upkeep/internal/store" @@ -242,7 +243,10 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) { w.Write([]byte(fmt.Sprintf("Imported %d monitors, %d alerts from Kuma v%s", len(backup.Sites), len(backup.Alerts), kb.Version))) }) - // 6. Status Page + // 6. Prometheus Metrics + mux.HandleFunc("/metrics", metrics.Handler(eng)) + + // 7. Status Page if cfg.EnableStatus { mux.HandleFunc("/status", func(w http.ResponseWriter, r *http.Request) { renderStatusPage(w, cfg.Title, eng) }) mux.HandleFunc("/status/json", func(w http.ResponseWriter, r *http.Request) {