diff --git a/.gitignore b/.gitignore
index e1f6a8e..1fd01d6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -38,4 +38,5 @@ tmp
# Old repo
/go-upkeep/
-*.local.json
\ No newline at end of file
+*.local.json
+*.local.md
\ No newline at end of file
diff --git a/cmd/goupkeep/main.go b/cmd/goupkeep/main.go
index 26c01b3..5962e37 100644
--- a/cmd/goupkeep/main.go
+++ b/cmd/goupkeep/main.go
@@ -1,6 +1,7 @@
package main
import (
+ "context"
"flag"
"fmt"
"go-upkeep/internal/cluster"
@@ -68,9 +69,6 @@ func main() {
if v := os.Getenv("UPKEEP_CLUSTER_SECRET"); v != "" {
clusterKey = v
}
- if os.Getenv("UPKEEP_INSECURE_SKIP_VERIFY") == "true" {
- monitor.SetInsecureSkipVerify(true)
- }
port := flag.Int("port", portVal, "SSH Port")
flagDBType := flag.String("db-type", dbType, "Database type")
@@ -80,20 +78,23 @@ func main() {
flag.Parse()
var s store.Store
+ var dbErr error
if *flagDBType == "postgres" {
- s = &store.PostgresStore{ConnStr: *flagDSN}
+ s, dbErr = store.NewPostgresStore(*flagDSN)
fmt.Printf("Using PostgreSQL: %s\n", *flagDSN)
} else {
- s = &store.SQLiteStore{DBPath: *flagDSN}
+ s, dbErr = store.NewSQLiteStore(*flagDSN)
fmt.Printf("Using SQLite: %s\n", *flagDSN)
}
+ if dbErr != nil {
+ fmt.Printf("Database connection error: %v\n", dbErr)
+ os.Exit(1)
+ }
if err := s.Init(); err != nil {
- fmt.Printf("Database Init Error: %v\n", err)
+ fmt.Printf("Database init error: %v\n", err)
os.Exit(1)
}
- store.SetGlobal(s)
-
if *demo {
seedDemoData(s)
}
@@ -112,26 +113,34 @@ func main() {
fmt.Printf("Imported %d monitors and %d alerts from Uptime Kuma v%s\n", len(backup.Sites), len(backup.Alerts), kb.Version)
}
- monitor.InitHistoryFromStore()
- monitor.StartEngine()
+ eng := monitor.NewEngine(s)
+ if os.Getenv("UPKEEP_INSECURE_SKIP_VERIFY") == "true" {
+ eng.SetInsecureSkipVerify(true)
+ }
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ eng.InitHistory()
+ eng.Start(ctx)
server.Start(server.ServerConfig{
Port: httpPort,
EnableStatus: enableStatus,
Title: statusTitle,
ClusterKey: clusterKey,
- })
+ }, s, eng)
- cluster.Start(cluster.Config{
+ cluster.Start(ctx, cluster.Config{
Mode: clusterMode,
PeerURL: clusterPeer,
SharedKey: clusterKey,
- })
+ }, eng)
- startSSHServer(*port)
+ startSSHServer(*port, s, eng)
if isatty.IsTerminal(os.Stdout.Fd()) || isatty.IsCygwinTerminal(os.Stdout.Fd()) {
- p := tea.NewProgram(tui.InitialModel(true), tea.WithAltScreen(), tea.WithMouseCellMotion())
+ p := tea.NewProgram(tui.InitialModel(true, s, eng), tea.WithAltScreen(), tea.WithMouseCellMotion())
if _, err := p.Run(); err != nil {
fmt.Printf("Error: %v\n", err)
}
@@ -142,18 +151,19 @@ func main() {
<-done
fmt.Println("Shutting down...")
}
+ cancel()
}
-func startSSHServer(port int) {
+func startSSHServer(port int, db store.Store, eng *monitor.Engine) {
s, err := wish.NewServer(
wish.WithAddress(fmt.Sprintf(":%d", port)),
wish.WithHostKeyPath(".ssh/id_ed25519"),
wish.WithPublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
- return isKeyAllowed(key)
+ return isKeyAllowed(db, key)
}),
wish.WithMiddleware(
bm.Middleware(func(s ssh.Session) (tea.Model, []tea.ProgramOption) {
- return tui.InitialModel(false), []tea.ProgramOption{tea.WithAltScreen(), tea.WithMouseCellMotion()}
+ return tui.InitialModel(false, db, eng), []tea.ProgramOption{tea.WithAltScreen(), tea.WithMouseCellMotion()}
}),
),
)
@@ -161,11 +171,16 @@ func startSSHServer(port int) {
fmt.Printf("SSH server error: %v\n", err)
return
}
- go func() { s.ListenAndServe() }()
+ go func() {
+ if err := s.ListenAndServe(); err != nil {
+ log.Fatalf("SSH server failed: %v", err)
+ }
+ }()
}
func seedDemoData(s store.Store) {
- if existing := s.GetSites(); len(existing) > 0 {
+ existing, _ := s.GetSites()
+ if len(existing) > 0 {
return
}
fmt.Println("Seeding demo data...")
@@ -178,7 +193,7 @@ func seedDemoData(s store.Store) {
"from": "oncall@example.com", "to": "team@example.com",
})
- alerts := s.GetAllAlerts()
+ alerts, _ := s.GetAllAlerts()
alertID := 0
if len(alerts) > 0 {
alertID = alerts[0].ID
@@ -196,8 +211,11 @@ func seedDemoData(s store.Store) {
s.AddSite(models.Site{Name: "SSH Server", Type: "port", Interval: 60, AlertID: alertID, Hostname: "10.0.0.1", Port: 22, Timeout: 5, ExpiryThreshold: 7})
}
-func isKeyAllowed(incomingKey ssh.PublicKey) bool {
- users := store.Get().GetAllUsers()
+func isKeyAllowed(db store.Store, incomingKey ssh.PublicKey) bool {
+ users, err := db.GetAllUsers()
+ if err != nil {
+ return false
+ }
for _, u := range users {
allowedKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(u.PublicKey))
if err != nil {
diff --git a/internal/alert/alert.go b/internal/alert/alert.go
index 71b7570..c60013f 100644
--- a/internal/alert/alert.go
+++ b/internal/alert/alert.go
@@ -7,6 +7,7 @@ import (
"go-upkeep/internal/models"
"net/http"
"net/smtp"
+ "strconv"
"strings"
"time"
)
@@ -17,15 +18,95 @@ type Provider interface {
Send(title, message string) error
}
+type PayloadFunc func(title, message string) ([]byte, error)
+
+type HTTPProvider struct {
+ URL string
+ Payload PayloadFunc
+}
+
+func (h *HTTPProvider) Send(title, message string) error {
+ body, err := h.Payload(title, message)
+ if err != nil {
+ return err
+ }
+ resp, err := alertClient.Post(h.URL, "application/json", bytes.NewBuffer(body))
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode >= 400 {
+ return fmt.Errorf("alert webhook returned HTTP %d", resp.StatusCode)
+ }
+ return nil
+}
+
+func discordPayload(title, message string) ([]byte, error) {
+ return json.Marshal(map[string]string{"content": fmt.Sprintf("**%s**\n%s", title, message)})
+}
+
+func slackPayload(title, message string) ([]byte, error) {
+ return json.Marshal(map[string]string{"text": fmt.Sprintf("*%s*\n%s", title, message)})
+}
+
+func webhookPayload(title, message string) ([]byte, error) {
+ return json.Marshal(map[string]string{"title": title, "message": message, "status": "alert"})
+}
+
+func telegramPayload(chatID string) PayloadFunc {
+ return func(title, message string) ([]byte, error) {
+ return json.Marshal(map[string]string{
+ "chat_id": chatID,
+ "text": fmt.Sprintf("*%s*\n%s", title, message),
+ "parse_mode": "Markdown",
+ })
+ }
+}
+
+func pagerdutyPayload(routingKey, severity string) PayloadFunc {
+ return func(title, message string) ([]byte, error) {
+ return json.Marshal(map[string]any{
+ "routing_key": routingKey,
+ "event_action": "trigger",
+ "payload": map[string]string{
+ "summary": fmt.Sprintf("%s: %s", title, message),
+ "source": "go-upkeep",
+ "severity": severity,
+ },
+ })
+ }
+}
+
+func pushoverPayload(token, user string) PayloadFunc {
+ return func(title, message string) ([]byte, error) {
+ return json.Marshal(map[string]string{
+ "token": token,
+ "user": user,
+ "title": title,
+ "message": message,
+ })
+ }
+}
+
+func gotifyPayload(priority string) PayloadFunc {
+ return func(title, message string) ([]byte, error) {
+ pri, _ := strconv.Atoi(priority)
+ return json.Marshal(map[string]any{
+ "title": title,
+ "message": message,
+ "priority": pri,
+ })
+ }
+}
+
func GetProvider(cfg models.AlertConfig) Provider {
switch cfg.Type {
case "discord":
- return &DiscordProvider{URL: cfg.Settings["url"]}
+ return &HTTPProvider{URL: cfg.Settings["url"], Payload: discordPayload}
case "slack":
- return &SlackProvider{URL: cfg.Settings["url"]}
+ return &HTTPProvider{URL: cfg.Settings["url"], Payload: slackPayload}
case "webhook":
- // Generic Webhook
- return &WebhookProvider{URL: cfg.Settings["url"]}
+ return &HTTPProvider{URL: cfg.Settings["url"], Payload: webhookPayload}
case "email":
port := "25"
if p, ok := cfg.Settings["port"]; ok {
@@ -51,58 +132,40 @@ func GetProvider(cfg models.AlertConfig) Provider {
Username: cfg.Settings["username"],
Password: cfg.Settings["password"],
}
+ case "telegram":
+ return &HTTPProvider{
+ URL: fmt.Sprintf("https://api.telegram.org/bot%s/sendMessage", cfg.Settings["token"]),
+ Payload: telegramPayload(cfg.Settings["chat_id"]),
+ }
+ case "pagerduty":
+ severity := "critical"
+ if s, ok := cfg.Settings["severity"]; ok && s != "" {
+ severity = s
+ }
+ return &HTTPProvider{
+ URL: "https://events.pagerduty.com/v2/enqueue",
+ Payload: pagerdutyPayload(cfg.Settings["routing_key"], severity),
+ }
+ case "pushover":
+ return &HTTPProvider{
+ URL: "https://api.pushover.net/1/messages.json",
+ Payload: pushoverPayload(cfg.Settings["token"], cfg.Settings["user"]),
+ }
+ case "gotify":
+ priority := "5"
+ if p, ok := cfg.Settings["priority"]; ok && p != "" {
+ priority = p
+ }
+ serverURL := strings.TrimRight(cfg.Settings["url"], "/")
+ return &HTTPProvider{
+ URL: fmt.Sprintf("%s/message?token=%s", serverURL, cfg.Settings["token"]),
+ Payload: gotifyPayload(priority),
+ }
default:
return nil
}
}
-// --- DISCORD ---
-type DiscordProvider struct{ URL string }
-
-func (d *DiscordProvider) Send(title, message string) error {
- payload := map[string]string{"content": fmt.Sprintf("**%s**\n%s", title, message)}
- jsonValue, _ := json.Marshal(payload)
- resp, err := alertClient.Post(d.URL, "application/json", bytes.NewBuffer(jsonValue))
- if err != nil {
- return err
- }
- resp.Body.Close()
- return nil
-}
-
-// --- SLACK ---
-type SlackProvider struct{ URL string }
-
-func (s *SlackProvider) Send(title, message string) error {
- payload := map[string]string{"text": fmt.Sprintf("*%s*\n%s", title, message)}
- jsonValue, _ := json.Marshal(payload)
- resp, err := alertClient.Post(s.URL, "application/json", bytes.NewBuffer(jsonValue))
- if err != nil {
- return err
- }
- resp.Body.Close()
- return nil
-}
-
-// --- GENERIC WEBHOOK ---
-type WebhookProvider struct{ URL string }
-
-func (w *WebhookProvider) Send(title, message string) error {
- payload := map[string]string{
- "title": title,
- "message": message,
- "status": "alert",
- }
- jsonValue, _ := json.Marshal(payload)
- resp, err := alertClient.Post(w.URL, "application/json", bytes.NewBuffer(jsonValue))
- if err != nil {
- return err
- }
- resp.Body.Close()
- return nil
-}
-
-// --- EMAIL ---
type EmailProvider struct {
Host, Port, User, Pass, To, From string
}
@@ -139,6 +202,9 @@ func (n *NtfyProvider) Send(title, message string) error {
if err != nil {
return err
}
- resp.Body.Close()
+ defer resp.Body.Close()
+ if resp.StatusCode >= 400 {
+ return fmt.Errorf("ntfy returned HTTP %d", resp.StatusCode)
+ }
return nil
}
diff --git a/internal/alert/alert_test.go b/internal/alert/alert_test.go
new file mode 100644
index 0000000..35e1c8d
--- /dev/null
+++ b/internal/alert/alert_test.go
@@ -0,0 +1,213 @@
+package alert
+
+import (
+ "encoding/json"
+ "go-upkeep/internal/models"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+)
+
+func TestHTTPProviderDiscord(t *testing.T) {
+ var received map[string]string
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ json.NewDecoder(r.Body).Decode(&received)
+ w.WriteHeader(200)
+ }))
+ defer srv.Close()
+
+ p := GetProvider(models.AlertConfig{Type: "discord", Settings: map[string]string{"url": srv.URL}})
+ if err := p.Send("Test Title", "Test Body"); err != nil {
+ t.Fatalf("Send: %v", err)
+ }
+
+ if received["content"] != "**Test Title**\nTest Body" {
+ t.Errorf("unexpected payload: %s", received["content"])
+ }
+}
+
+func TestHTTPProviderSlack(t *testing.T) {
+ var received map[string]string
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ json.NewDecoder(r.Body).Decode(&received)
+ w.WriteHeader(200)
+ }))
+ defer srv.Close()
+
+ p := GetProvider(models.AlertConfig{Type: "slack", Settings: map[string]string{"url": srv.URL}})
+ if err := p.Send("Alert", "Message"); err != nil {
+ t.Fatalf("Send: %v", err)
+ }
+
+ if received["text"] != "*Alert*\nMessage" {
+ t.Errorf("unexpected payload: %s", received["text"])
+ }
+}
+
+func TestHTTPProviderWebhook(t *testing.T) {
+ var received map[string]string
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ json.NewDecoder(r.Body).Decode(&received)
+ w.WriteHeader(200)
+ }))
+ defer srv.Close()
+
+ p := GetProvider(models.AlertConfig{Type: "webhook", Settings: map[string]string{"url": srv.URL}})
+ if err := p.Send("Title", "Body"); err != nil {
+ t.Fatalf("Send: %v", err)
+ }
+
+ if received["title"] != "Title" || received["message"] != "Body" || received["status"] != "alert" {
+ t.Errorf("unexpected webhook payload: %v", received)
+ }
+}
+
+func TestHTTPProviderErrorOnHTTP4xx(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(403)
+ }))
+ defer srv.Close()
+
+ p := GetProvider(models.AlertConfig{Type: "discord", Settings: map[string]string{"url": srv.URL}})
+ if err := p.Send("Test", "Test"); err == nil {
+ t.Fatal("expected error on 403 response")
+ }
+}
+
+func TestNtfyProvider(t *testing.T) {
+ var title, body string
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ title = r.Header.Get("Title")
+ buf := make([]byte, 1024)
+ n, _ := r.Body.Read(buf)
+ body = string(buf[:n])
+ w.WriteHeader(200)
+ }))
+ defer srv.Close()
+
+ p := GetProvider(models.AlertConfig{Type: "ntfy", Settings: map[string]string{
+ "url": srv.URL,
+ "topic": "test",
+ }})
+ if err := p.Send("Alert Title", "Alert Body"); err != nil {
+ t.Fatalf("Send: %v", err)
+ }
+
+ if title != "Alert Title" {
+ t.Errorf("expected title 'Alert Title', got '%s'", title)
+ }
+ if body != "Alert Body" {
+ t.Errorf("expected body 'Alert Body', got '%s'", body)
+ }
+}
+
+func TestHTTPProviderTelegram(t *testing.T) {
+ var received map[string]string
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ json.NewDecoder(r.Body).Decode(&received)
+ w.WriteHeader(200)
+ }))
+ defer srv.Close()
+
+ p := &HTTPProvider{URL: srv.URL, Payload: telegramPayload("12345")}
+ if err := p.Send("Alert", "Down"); err != nil {
+ t.Fatalf("Send: %v", err)
+ }
+ if received["chat_id"] != "12345" {
+ t.Errorf("expected chat_id '12345', got '%s'", received["chat_id"])
+ }
+ if received["text"] != "*Alert*\nDown" {
+ t.Errorf("unexpected text: %s", received["text"])
+ }
+ if received["parse_mode"] != "Markdown" {
+ t.Errorf("expected parse_mode 'Markdown', got '%s'", received["parse_mode"])
+ }
+}
+
+func TestHTTPProviderPagerDuty(t *testing.T) {
+ var received map[string]any
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ json.NewDecoder(r.Body).Decode(&received)
+ w.WriteHeader(200)
+ }))
+ defer srv.Close()
+
+ p := &HTTPProvider{URL: srv.URL, Payload: pagerdutyPayload("test-key", "critical")}
+ if err := p.Send("Alert", "Down"); err != nil {
+ t.Fatalf("Send: %v", err)
+ }
+ if received["routing_key"] != "test-key" {
+ t.Errorf("expected routing_key 'test-key', got '%v'", received["routing_key"])
+ }
+ if received["event_action"] != "trigger" {
+ t.Errorf("expected event_action 'trigger', got '%v'", received["event_action"])
+ }
+ payload := received["payload"].(map[string]any)
+ if payload["summary"] != "Alert: Down" {
+ t.Errorf("unexpected summary: %v", payload["summary"])
+ }
+ if payload["severity"] != "critical" {
+ t.Errorf("expected severity 'critical', got '%v'", payload["severity"])
+ }
+}
+
+func TestHTTPProviderPushover(t *testing.T) {
+ var received map[string]string
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ json.NewDecoder(r.Body).Decode(&received)
+ w.WriteHeader(200)
+ }))
+ defer srv.Close()
+
+ p := &HTTPProvider{URL: srv.URL, Payload: pushoverPayload("app-tok", "user-key")}
+ if err := p.Send("Alert", "Down"); err != nil {
+ t.Fatalf("Send: %v", err)
+ }
+ if received["token"] != "app-tok" {
+ t.Errorf("expected token 'app-tok', got '%s'", received["token"])
+ }
+ if received["user"] != "user-key" {
+ t.Errorf("expected user 'user-key', got '%s'", received["user"])
+ }
+ if received["title"] != "Alert" || received["message"] != "Down" {
+ t.Errorf("unexpected payload: %v", received)
+ }
+}
+
+func TestHTTPProviderGotify(t *testing.T) {
+ var received map[string]any
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ json.NewDecoder(r.Body).Decode(&received)
+ w.WriteHeader(200)
+ }))
+ defer srv.Close()
+
+ p := &HTTPProvider{URL: srv.URL, Payload: gotifyPayload("8")}
+ if err := p.Send("Alert", "Down"); err != nil {
+ t.Fatalf("Send: %v", err)
+ }
+ if received["title"] != "Alert" || received["message"] != "Down" {
+ t.Errorf("unexpected payload: %v", received)
+ }
+ if pri, ok := received["priority"].(float64); !ok || pri != 8 {
+ t.Errorf("expected priority 8, got %v", received["priority"])
+ }
+}
+
+func TestGetProviderNewTypes(t *testing.T) {
+ for _, typ := range []string{"telegram", "pagerduty", "pushover", "gotify"} {
+ p := GetProvider(models.AlertConfig{Type: typ, Settings: map[string]string{
+ "token": "x", "chat_id": "1", "routing_key": "k", "user": "u", "url": "http://localhost",
+ }})
+ if p == nil {
+ t.Errorf("GetProvider(%q) returned nil", typ)
+ }
+ }
+}
+
+func TestGetProviderUnknown(t *testing.T) {
+ p := GetProvider(models.AlertConfig{Type: "unknown"})
+ if p != nil {
+ t.Error("expected nil for unknown provider type")
+ }
+}
diff --git a/internal/cluster/cluster.go b/internal/cluster/cluster.go
index 295443d..f986c29 100644
--- a/internal/cluster/cluster.go
+++ b/internal/cluster/cluster.go
@@ -1,6 +1,7 @@
package cluster
import (
+ "context"
"fmt"
"go-upkeep/internal/monitor"
"net/http"
@@ -14,13 +15,13 @@ type Config struct {
SharedKey string // Security Key
}
-func Start(cfg Config) {
+func Start(ctx context.Context, cfg Config, eng *monitor.Engine) {
if cfg.Mode == "leader" {
fmt.Println("Cluster: Running as LEADER (Active)")
if cfg.SharedKey != "" {
fmt.Println("WARNING: Cluster mode enabled. Ensure the HTTP server is behind a TLS-terminating proxy.")
}
- monitor.SetEngineActive(true)
+ eng.SetActive(true)
return
}
@@ -29,20 +30,22 @@ func Start(cfg Config) {
if cfg.PeerURL != "" && !strings.HasPrefix(cfg.PeerURL, "https://") {
fmt.Println("WARNING: Cluster peer URL is not HTTPS. Cluster secret will be sent in cleartext.")
}
- monitor.SetEngineActive(false)
- go runFollowerLoop(cfg)
+ eng.SetActive(false)
+ go runFollowerLoop(ctx, cfg, eng)
}
}
-func runFollowerLoop(cfg Config) {
+func runFollowerLoop(ctx context.Context, cfg Config, eng *monitor.Engine) {
client := http.Client{Timeout: 2 * time.Second}
-
- // Failover Configuration
failures := 0
threshold := 3
for {
- time.Sleep(5 * time.Second)
+ select {
+ case <-time.After(5 * time.Second):
+ case <-ctx.Done():
+ return
+ }
req, _ := http.NewRequest("GET", cfg.PeerURL+"/api/health", nil)
if cfg.SharedKey != "" {
@@ -59,17 +62,15 @@ func runFollowerLoop(cfg Config) {
if isLeaderHealthy {
failures = 0
- if monitor.IsEngineActive() {
- // Leader is back, yield
- monitor.SetEngineActive(false)
- monitor.AddLog("Cluster: Leader detected. Switching to PASSIVE.")
+ if eng.IsActive() {
+ eng.SetActive(false)
+ eng.AddLog("Cluster: Leader detected. Switching to PASSIVE.")
}
} else {
failures++
- // If failures exceed threshold, take over
- if failures >= threshold && !monitor.IsEngineActive() {
- monitor.SetEngineActive(true)
- monitor.AddLog("Cluster: Leader Unreachable. Switching to ACTIVE.")
+ if failures >= threshold && !eng.IsActive() {
+ eng.SetActive(true)
+ eng.AddLog("Cluster: Leader Unreachable. Switching to ACTIVE.")
}
}
}
diff --git a/internal/metrics/prometheus.go b/internal/metrics/prometheus.go
new file mode 100644
index 0000000..24f4faa
--- /dev/null
+++ b/internal/metrics/prometheus.go
@@ -0,0 +1,99 @@
+package metrics
+
+import (
+ "fmt"
+ "go-upkeep/internal/models"
+ "go-upkeep/internal/monitor"
+ "net/http"
+ "sort"
+ "strings"
+)
+
+func Handler(eng *monitor.Engine) http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ sites := eng.GetAllSites()
+ sort.Slice(sites, func(i, j int) bool { return sites[i].ID < sites[j].ID })
+
+ var b strings.Builder
+
+ writeHelp(&b, "upkeep_monitor_up", "gauge", "Whether the monitor is up (1) or down (0).")
+ for _, s := range sites {
+ val := 0
+ if s.Status == "UP" {
+ val = 1
+ }
+ writeGauge(&b, "upkeep_monitor_up", labels(s), float64(val))
+ }
+
+ writeHelp(&b, "upkeep_monitor_latency_seconds", "gauge", "Last check latency in seconds.")
+ for _, s := range sites {
+ writeGauge(&b, "upkeep_monitor_latency_seconds", labels(s), s.Latency.Seconds())
+ }
+
+ writeHelp(&b, "upkeep_monitor_status_code", "gauge", "HTTP response status code of the last check.")
+ for _, s := range sites {
+ if s.Type != "http" {
+ continue
+ }
+ writeGauge(&b, "upkeep_monitor_status_code", labels(s), float64(s.StatusCode))
+ }
+
+ writeHelp(&b, "upkeep_monitor_check_timestamp_seconds", "gauge", "Unix timestamp of the last check.")
+ for _, s := range sites {
+ if s.LastCheck.IsZero() {
+ continue
+ }
+ writeGauge(&b, "upkeep_monitor_check_timestamp_seconds", labels(s), float64(s.LastCheck.Unix()))
+ }
+
+ writeHelp(&b, "upkeep_monitor_paused", "gauge", "Whether the monitor is paused (1) or active (0).")
+ for _, s := range sites {
+ val := 0
+ if s.Paused {
+ val = 1
+ }
+ writeGauge(&b, "upkeep_monitor_paused", labels(s), float64(val))
+ }
+
+ writeHelp(&b, "upkeep_monitor_cert_expiry_timestamp_seconds", "gauge", "Unix timestamp when the SSL certificate expires.")
+ for _, s := range sites {
+ if !s.HasSSL || s.CertExpiry.IsZero() {
+ continue
+ }
+ writeGauge(&b, "upkeep_monitor_cert_expiry_timestamp_seconds", labels(s), float64(s.CertExpiry.Unix()))
+ }
+
+ writeHelp(&b, "upkeep_monitor_checks_total", "counter", "Total number of checks performed.")
+ writeHelp(&b, "upkeep_monitor_checks_up_total", "counter", "Total number of successful checks.")
+ for _, s := range sites {
+ h, ok := eng.GetHistory(s.ID)
+ if !ok {
+ continue
+ }
+ writeGauge(&b, "upkeep_monitor_checks_total", labels(s), float64(h.TotalChecks))
+ writeGauge(&b, "upkeep_monitor_checks_up_total", labels(s), float64(h.UpChecks))
+ }
+
+ w.Header().Set("Content-Type", "text/plain; version=0.0.4; charset=utf-8")
+ w.Write([]byte(b.String()))
+ }
+}
+
+func labels(s models.Site) string {
+ return fmt.Sprintf(`id="%d",name="%s",type="%s"`, s.ID, escapeLabelValue(s.Name), s.Type)
+}
+
+func escapeLabelValue(s string) string {
+ s = strings.ReplaceAll(s, `\`, `\\`)
+ s = strings.ReplaceAll(s, `"`, `\"`)
+ s = strings.ReplaceAll(s, "\n", `\n`)
+ return s
+}
+
+func writeHelp(b *strings.Builder, name, typ, help string) {
+ fmt.Fprintf(b, "# HELP %s %s\n# TYPE %s %s\n", name, help, name, typ)
+}
+
+func writeGauge(b *strings.Builder, name, labels string, val float64) {
+ fmt.Fprintf(b, "%s{%s} %g\n", name, labels, val)
+}
diff --git a/internal/metrics/prometheus_test.go b/internal/metrics/prometheus_test.go
new file mode 100644
index 0000000..7cbf680
--- /dev/null
+++ b/internal/metrics/prometheus_test.go
@@ -0,0 +1,96 @@
+package metrics
+
+import (
+ "context"
+ "go-upkeep/internal/models"
+ "go-upkeep/internal/monitor"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+)
+
+type mockStore struct {
+ sites []models.Site
+}
+
+func (m *mockStore) Init() error { return nil }
+func (m *mockStore) GetSites() ([]models.Site, error) { return m.sites, nil }
+func (m *mockStore) AddSite(models.Site) error { return nil }
+func (m *mockStore) UpdateSite(models.Site) error { return nil }
+func (m *mockStore) UpdateSitePaused(int, bool) error { return nil }
+func (m *mockStore) DeleteSite(int) error { return nil }
+func (m *mockStore) GetAllAlerts() ([]models.AlertConfig, error) { return nil, nil }
+func (m *mockStore) GetAlert(int) (models.AlertConfig, error) { return models.AlertConfig{}, nil }
+func (m *mockStore) AddAlert(string, string, map[string]string) error { return nil }
+func (m *mockStore) UpdateAlert(int, string, string, map[string]string) error { return nil }
+func (m *mockStore) DeleteAlert(int) error { return nil }
+func (m *mockStore) GetAllUsers() ([]models.User, error) { return nil, nil }
+func (m *mockStore) AddUser(string, string, string) error { return nil }
+func (m *mockStore) UpdateUser(int, string, string, string) error { return nil }
+func (m *mockStore) DeleteUser(int) error { return nil }
+func (m *mockStore) SaveCheck(int, int64, bool) error { return nil }
+func (m *mockStore) LoadAllHistory(int) (map[int][]models.CheckRecord, error) {
+ return nil, nil
+}
+func (m *mockStore) ExportData() (models.Backup, error) { return models.Backup{}, nil }
+func (m *mockStore) ImportData(models.Backup) error { return nil }
+
+func TestMetricsHandler(t *testing.T) {
+ ms := &mockStore{
+ sites: []models.Site{
+ {ID: 1, Name: "Example", URL: "https://example.com", Type: "http", Interval: 30},
+ {ID: 2, Name: "DNS Check", Type: "dns", Interval: 60},
+ },
+ }
+ eng := monitor.NewEngine(ms)
+ ctx, cancel := context.WithCancel(context.Background())
+ eng.Start(ctx)
+ time.Sleep(100 * time.Millisecond)
+
+ rec := httptest.NewRecorder()
+ Handler(eng)(rec, httptest.NewRequest("GET", "/metrics", nil))
+ cancel()
+
+ if rec.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d", rec.Code)
+ }
+
+ body := rec.Body.String()
+
+ ct := rec.Header().Get("Content-Type")
+ if !strings.Contains(ct, "text/plain") {
+ t.Errorf("expected text/plain content type, got %q", ct)
+ }
+
+ expected := []string{
+ "# HELP upkeep_monitor_up",
+ "# TYPE upkeep_monitor_up gauge",
+ `upkeep_monitor_up{id="1",name="Example",type="http"}`,
+ `upkeep_monitor_up{id="2",name="DNS Check",type="dns"}`,
+ "# HELP upkeep_monitor_latency_seconds",
+ "# HELP upkeep_monitor_paused",
+ "# HELP upkeep_monitor_checks_total",
+ }
+ for _, s := range expected {
+ if !strings.Contains(body, s) {
+ t.Errorf("missing expected line: %s", s)
+ }
+ }
+}
+
+func TestEscapeLabelValue(t *testing.T) {
+ cases := []struct{ in, want string }{
+ {`simple`, `simple`},
+ {`has "quotes"`, `has \"quotes\"`},
+ {"has\nnewline", `has\nnewline`},
+ {`back\slash`, `back\\slash`},
+ }
+ for _, tc := range cases {
+ got := escapeLabelValue(tc.in)
+ if got != tc.want {
+ t.Errorf("escapeLabelValue(%q) = %q, want %q", tc.in, got, tc.want)
+ }
+ }
+}
diff --git a/internal/monitor/history.go b/internal/monitor/history.go
index 8642255..1049e04 100644
--- a/internal/monitor/history.go
+++ b/internal/monitor/history.go
@@ -1,10 +1,6 @@
package monitor
-import (
- "go-upkeep/internal/store"
- "sync"
- "time"
-)
+import "time"
const maxHistoryLen = 30
@@ -15,19 +11,14 @@ type SiteHistory struct {
UpChecks int
}
-var (
- histories = make(map[int]*SiteHistory)
- historyMu sync.RWMutex
-)
-
-func InitHistoryFromStore() {
- s := store.Get()
- if s == nil {
+func (e *Engine) InitHistory() {
+ all, err := e.db.LoadAllHistory(maxHistoryLen)
+ if err != nil {
+ e.AddLog("Failed to load check history: " + err.Error())
return
}
- all := s.LoadAllHistory(maxHistoryLen)
- historyMu.Lock()
- defer historyMu.Unlock()
+ e.histMu.Lock()
+ defer e.histMu.Unlock()
for siteID, records := range all {
h := &SiteHistory{}
for _, r := range records {
@@ -38,21 +29,21 @@ func InitHistoryFromStore() {
h.Latencies = append(h.Latencies, time.Duration(r.LatencyNs))
h.Statuses = append(h.Statuses, r.IsUp)
}
- histories[siteID] = h
+ e.histories[siteID] = h
}
if len(all) > 0 {
- AddLog("Loaded check history from database")
+ e.AddLog("Loaded check history from database")
}
}
-func RecordCheck(siteID int, latency time.Duration, isUp bool) {
- historyMu.Lock()
- defer historyMu.Unlock()
+func (e *Engine) recordCheck(siteID int, latency time.Duration, isUp bool) {
+ e.histMu.Lock()
+ defer e.histMu.Unlock()
- h, ok := histories[siteID]
+ h, ok := e.histories[siteID]
if !ok {
h = &SiteHistory{}
- histories[siteID] = h
+ e.histories[siteID] = h
}
h.TotalChecks++
@@ -70,15 +61,13 @@ func RecordCheck(siteID int, latency time.Duration, isUp bool) {
h.Statuses = h.Statuses[len(h.Statuses)-maxHistoryLen:]
}
- if s := store.Get(); s != nil {
- go s.SaveCheck(siteID, latency.Nanoseconds(), isUp)
- }
+ go func() { _ = e.db.SaveCheck(siteID, latency.Nanoseconds(), isUp) }()
}
-func GetHistory(siteID int) (SiteHistory, bool) {
- historyMu.RLock()
- defer historyMu.RUnlock()
- h, ok := histories[siteID]
+func (e *Engine) GetHistory(siteID int) (SiteHistory, bool) {
+ e.histMu.RLock()
+ defer e.histMu.RUnlock()
+ h, ok := e.histories[siteID]
if !ok {
return SiteHistory{}, false
}
@@ -93,8 +82,8 @@ func GetHistory(siteID int) (SiteHistory, bool) {
return cp, true
}
-func RemoveHistory(siteID int) {
- historyMu.Lock()
- defer historyMu.Unlock()
- delete(histories, siteID)
+func (e *Engine) removeHistory(siteID int) {
+ e.histMu.Lock()
+ defer e.histMu.Unlock()
+ delete(e.histories, siteID)
}
diff --git a/internal/monitor/monitor.go b/internal/monitor/monitor.go
index 1bb02fe..2d2af10 100644
--- a/internal/monitor/monitor.go
+++ b/internal/monitor/monitor.go
@@ -1,6 +1,7 @@
package monitor
import (
+ "context"
"crypto/tls"
"fmt"
"go-upkeep/internal/alert"
@@ -9,6 +10,7 @@ import (
"net"
"net/http"
"strconv"
+ "strings"
"sync"
"time"
@@ -16,198 +18,271 @@ import (
probing "github.com/prometheus-community/pro-bing"
)
-// --- LOGGING ---
-var (
- LogStore []string
- LogMutex sync.RWMutex
-)
+type Engine struct {
+ mu sync.RWMutex
+ liveState map[int]models.Site
-func AddLog(msg string) {
- LogMutex.Lock()
- defer LogMutex.Unlock()
- ts := time.Now().Format("15:04:05")
- entry := fmt.Sprintf("[%s] %s", ts, msg)
- LogStore = append([]string{entry}, LogStore...)
- if len(LogStore) > 100 {
- LogStore = LogStore[:100]
+ logMu sync.RWMutex
+ logStore []string
+
+ activeMu sync.RWMutex
+ isActive bool
+
+ histMu sync.RWMutex
+ histories map[int]*SiteHistory
+
+ tokenIndex map[string]int
+
+ db store.Store
+ insecureSkipVerify bool
+ strictClient *http.Client
+ insecureClient *http.Client
+}
+
+func NewEngine(s store.Store) *Engine {
+ return &Engine{
+ liveState: make(map[int]models.Site),
+ histories: make(map[int]*SiteHistory),
+ tokenIndex: make(map[string]int),
+ isActive: true,
+ db: s,
+ strictClient: &http.Client{
+ Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: false}},
+ },
+ insecureClient: &http.Client{
+ Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
+ },
}
}
-func GetLogs() []string {
- LogMutex.RLock()
- defer LogMutex.RUnlock()
- logs := make([]string, len(LogStore))
- copy(logs, LogStore)
+func (e *Engine) SetInsecureSkipVerify(skip bool) {
+ e.insecureSkipVerify = skip
+}
+
+func (e *Engine) AddLog(msg string) {
+ e.logMu.Lock()
+ defer e.logMu.Unlock()
+ ts := time.Now().Format("15:04:05")
+ entry := fmt.Sprintf("[%s] %s", ts, msg)
+ e.logStore = append([]string{entry}, e.logStore...)
+ if len(e.logStore) > 100 {
+ e.logStore = e.logStore[:100]
+ }
+}
+
+func (e *Engine) GetLogs() []string {
+ e.logMu.RLock()
+ defer e.logMu.RUnlock()
+ logs := make([]string, len(e.logStore))
+ copy(logs, e.logStore)
return logs
}
-// --- ENGINE ---
-
-var (
- LiveState = make(map[int]models.Site)
- Mutex sync.RWMutex
-
- // Global Switch for HA
- isActive = true
- activeMutex sync.RWMutex
-
- insecureSkipVerify bool
-)
-
-func SetInsecureSkipVerify(skip bool) {
- insecureSkipVerify = skip
-}
-
-func SetEngineActive(active bool) {
- activeMutex.Lock()
- defer activeMutex.Unlock()
- if isActive != active {
- isActive = active
+func (e *Engine) SetActive(active bool) {
+ e.activeMu.Lock()
+ defer e.activeMu.Unlock()
+ if e.isActive != active {
+ e.isActive = active
status := "RESUMED (Active)"
if !active {
status = "PAUSED (Passive)"
}
- AddLog(fmt.Sprintf("Engine %s", status))
+ e.AddLog(fmt.Sprintf("Engine %s", status))
}
}
-func IsEngineActive() bool {
- activeMutex.RLock()
- defer activeMutex.RUnlock()
- return isActive
+func (e *Engine) IsActive() bool {
+ e.activeMu.RLock()
+ defer e.activeMu.RUnlock()
+ return e.isActive
}
-func RecordHeartbeat(token string) bool {
- if !IsEngineActive() {
- return false
- } // Only Leader accepts Push
-
- Mutex.Lock()
- defer Mutex.Unlock()
- var targetID int = -1
- for id, s := range LiveState {
- if s.Type == "push" && s.Token == token {
- targetID = id
- break
- }
+func (e *Engine) GetAllSites() []models.Site {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ sites := make([]models.Site, 0, len(e.liveState))
+ for _, s := range e.liveState {
+ sites = append(sites, s)
}
- if targetID == -1 {
+ return sites
+}
+
+func (e *Engine) GetLiveState() map[int]models.Site {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ cp := make(map[int]models.Site, len(e.liveState))
+ for k, v := range e.liveState {
+ cp[k] = v
+ }
+ return cp
+}
+
+func (e *Engine) RecordHeartbeat(token string) bool {
+ if !e.IsActive() {
+ return false
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ targetID, ok := e.tokenIndex[token]
+ if !ok {
+ return false
+ }
+
+ site, exists := e.liveState[targetID]
+ if !exists {
return false
}
- site := LiveState[targetID]
site.LastCheck = time.Now()
wasDown := site.Status == "DOWN"
site.Status = "UP"
site.FailureCount = 0
site.Latency = 0
- LiveState[targetID] = site
+ e.liveState[targetID] = site
if wasDown {
- AddLog(fmt.Sprintf("Push Monitor '%s' recovered", site.Name))
- triggerAlert(site.AlertID, "โ
RECOVERY", fmt.Sprintf("Push Monitor '%s' is receiving heartbeats.", site.Name))
+ e.AddLog(fmt.Sprintf("Push Monitor '%s' recovered", site.Name))
+ e.triggerAlert(site.AlertID, "โ
RECOVERY", fmt.Sprintf("Push Monitor '%s' is receiving heartbeats.", site.Name))
}
return true
}
-func StartEngine() {
+func (e *Engine) addToTokenIndex(site models.Site) {
+ if site.Type == "push" && site.Token != "" {
+ e.tokenIndex[site.Token] = site.ID
+ }
+}
+
+func (e *Engine) removeFromTokenIndex(id int) {
+ for token, sid := range e.tokenIndex {
+ if sid == id {
+ delete(e.tokenIndex, token)
+ return
+ }
+ }
+}
+
+func (e *Engine) Start(ctx context.Context) {
go func() {
for {
- s_instance := store.Get()
- if s_instance == nil {
- time.Sleep(1 * time.Second)
- continue
+ select {
+ case <-ctx.Done():
+ return
+ default:
}
- sites := s_instance.GetSites()
+ sites, err := e.db.GetSites()
+ if err != nil {
+ e.AddLog(fmt.Sprintf("Failed to load sites: %v", err))
+ select {
+ case <-time.After(5 * time.Second):
+ case <-ctx.Done():
+ return
+ }
+ continue
+ }
for _, s := range sites {
- Mutex.RLock()
- _, exists := LiveState[s.ID]
- Mutex.RUnlock()
+ e.mu.RLock()
+ _, exists := e.liveState[s.ID]
+ e.mu.RUnlock()
if !exists {
- Mutex.Lock()
+ e.mu.Lock()
s.Status = "PENDING"
if s.Type == "push" {
s.LastCheck = time.Now()
}
- LiveState[s.ID] = s
- Mutex.Unlock()
- go monitorRoutine(s.ID)
+ e.liveState[s.ID] = s
+ e.addToTokenIndex(s)
+ e.mu.Unlock()
+ go e.monitorRoutine(ctx, s.ID)
}
}
- time.Sleep(5 * time.Second)
+
+ select {
+ case <-time.After(5 * time.Second):
+ case <-ctx.Done():
+ return
+ }
}
}()
}
-func UpdateSiteConfig(site models.Site) {
- Mutex.Lock()
- defer Mutex.Unlock()
- if s, ok := LiveState[site.ID]; ok {
- s.Name = site.Name
- s.URL = site.URL
- s.Type = site.Type
- s.Interval = site.Interval
- s.AlertID = site.AlertID
- s.CheckSSL = site.CheckSSL
- s.ExpiryThreshold = site.ExpiryThreshold
- s.MaxRetries = site.MaxRetries
- s.Hostname = site.Hostname
- s.Port = site.Port
- s.Timeout = site.Timeout
- s.Method = site.Method
- s.Description = site.Description
- s.ParentID = site.ParentID
- s.AcceptedCodes = site.AcceptedCodes
- s.DNSResolveType = site.DNSResolveType
- s.DNSServer = site.DNSServer
- s.IgnoreTLS = site.IgnoreTLS
- s.Paused = site.Paused
- LiveState[site.ID] = s
+func (e *Engine) UpdateSiteConfig(site models.Site) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if existing, ok := e.liveState[site.ID]; ok {
+ e.removeFromTokenIndex(site.ID)
+ site.Status = existing.Status
+ site.StatusCode = existing.StatusCode
+ site.Latency = existing.Latency
+ site.CertExpiry = existing.CertExpiry
+ site.HasSSL = existing.HasSSL
+ site.LastCheck = existing.LastCheck
+ site.SentSSLWarning = existing.SentSSLWarning
+ site.FailureCount = existing.FailureCount
+ e.liveState[site.ID] = site
+ e.addToTokenIndex(site)
}
}
-func RemoveSite(id int) {
- Mutex.Lock()
- delete(LiveState, id)
- Mutex.Unlock()
- RemoveHistory(id)
+func (e *Engine) RemoveSite(id int) {
+ e.mu.Lock()
+ e.removeFromTokenIndex(id)
+ delete(e.liveState, id)
+ e.mu.Unlock()
+ e.removeHistory(id)
}
-func ToggleSitePause(id int) bool {
- Mutex.Lock()
- defer Mutex.Unlock()
- site, ok := LiveState[id]
+func (e *Engine) ToggleSitePause(id int) bool {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ site, ok := e.liveState[id]
if !ok {
return false
}
site.Paused = !site.Paused
- LiveState[id] = site
+ e.liveState[id] = site
if site.Paused {
- AddLog(fmt.Sprintf("Monitor '%s' paused", site.Name))
+ e.AddLog(fmt.Sprintf("Monitor '%s' paused", site.Name))
} else {
- AddLog(fmt.Sprintf("Monitor '%s' resumed", site.Name))
+ e.AddLog(fmt.Sprintf("Monitor '%s' resumed", site.Name))
}
return site.Paused
}
-func monitorRoutine(id int) {
- checkByID(id)
+func (e *Engine) monitorRoutine(ctx context.Context, id int) {
+ e.checkByID(id)
for {
- if !IsEngineActive() {
- time.Sleep(5 * time.Second)
+ select {
+ case <-ctx.Done():
+ return
+ default:
+ }
+
+ if !e.IsActive() {
+ select {
+ case <-time.After(5 * time.Second):
+ case <-ctx.Done():
+ return
+ }
continue
}
- Mutex.RLock()
- site, exists := LiveState[id]
- Mutex.RUnlock()
+ e.mu.RLock()
+ site, exists := e.liveState[id]
+ e.mu.RUnlock()
if !exists {
return
}
if site.Paused {
- time.Sleep(5 * time.Second)
+ select {
+ case <-time.After(5 * time.Second):
+ case <-ctx.Done():
+ return
+ }
continue
}
@@ -215,58 +290,74 @@ func monitorRoutine(id int) {
if interval < 5 {
interval = 5
}
- time.Sleep(time.Duration(interval) * time.Second)
- checkByID(id)
+ select {
+ case <-time.After(time.Duration(interval) * time.Second):
+ case <-ctx.Done():
+ return
+ }
+ e.checkByID(id)
}
}
-func checkByID(id int) {
- if !IsEngineActive() {
+func (e *Engine) checkByID(id int) {
+ if !e.IsActive() {
return
}
- Mutex.RLock()
- site, exists := LiveState[id]
- Mutex.RUnlock()
+ e.mu.RLock()
+ site, exists := e.liveState[id]
+ e.mu.RUnlock()
if !exists || site.Paused {
return
}
switch site.Type {
case "http":
- checkHTTP(site)
+ e.checkHTTP(site)
case "push":
- checkPush(site)
+ e.checkPush(site)
case "ping":
- checkPing(site)
+ e.checkPing(site)
case "port":
- checkPort(site)
+ e.checkPort(site)
case "dns":
- checkDNS(site)
+ e.checkDNS(site)
case "group":
- checkGroup(site)
+ e.checkGroup(site)
}
}
-func checkPush(site models.Site) {
+func (e *Engine) checkPush(site models.Site) {
deadline := site.LastCheck.Add(time.Duration(site.Interval) * time.Second).Add(5 * time.Second)
if time.Now().After(deadline) {
- handleStatusChange(site, "DOWN", 0, 0)
- } else {
- if site.Status != "UP" {
- handleStatusChange(site, "UP", 200, 0)
- }
+ e.handleStatusChange(site, "DOWN", 0, 0)
+ } else if site.Status != "UP" {
+ e.handleStatusChange(site, "UP", 200, 0)
}
}
-func checkHTTP(site models.Site) {
- start := time.Now()
- timeout := time.Duration(site.Timeout) * time.Second
- if timeout <= 0 {
- timeout = 5 * time.Second
+func (e *Engine) checkHTTP(site models.Site) {
+ method := site.Method
+ if method == "" {
+ method = "GET"
}
- skipTLS := insecureSkipVerify || site.IgnoreTLS
- client := &http.Client{Timeout: timeout, Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: skipTLS}}}
- resp, err := client.Get(site.URL)
+
+ timeout := siteTimeout(site)
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
+
+ req, err := http.NewRequestWithContext(ctx, method, site.URL, nil)
+ if err != nil {
+ e.handleStatusChange(site, "DOWN", 0, 0)
+ return
+ }
+
+ client := e.strictClient
+ if e.insecureSkipVerify || site.IgnoreTLS {
+ client = e.insecureClient
+ }
+
+ start := time.Now()
+ resp, err := client.Do(req)
latency := time.Since(start)
rawStatus := "UP"
@@ -279,7 +370,7 @@ func checkHTTP(site models.Site) {
} else {
defer resp.Body.Close()
rawCode = resp.StatusCode
- if resp.StatusCode >= 400 {
+ if !isCodeAccepted(rawCode, site.AcceptedCodes) {
rawStatus = "DOWN"
}
if site.CheckSSL && resp.TLS != nil && len(resp.TLS.PeerCertificates) > 0 {
@@ -296,12 +387,11 @@ func checkHTTP(site models.Site) {
updatedSite.CertExpiry = certExpiry
updatedSite.Latency = latency
updatedSite.LastCheck = time.Now()
- handleStatusChange(updatedSite, rawStatus, rawCode, latency)
+ e.handleStatusChange(updatedSite, rawStatus, rawCode, latency)
}
-func handleStatusChange(site models.Site, rawStatus string, code int, latency time.Duration) {
- // Double check we are still leader before alerting
- if !IsEngineActive() {
+func (e *Engine) handleStatusChange(site models.Site, rawStatus string, code int, latency time.Duration) {
+ if !e.IsActive() {
return
}
@@ -313,9 +403,9 @@ func handleStatusChange(site models.Site, rawStatus string, code int, latency ti
if newState.FailureCount > site.MaxRetries {
newState.Status = rawStatus
newState.FailureCount = site.MaxRetries + 1
- AddLog(fmt.Sprintf("Monitor '%s' confirmed DOWN", site.Name))
+ e.AddLog(fmt.Sprintf("Monitor '%s' confirmed DOWN", site.Name))
} else {
- AddLog(fmt.Sprintf("Monitor '%s' failed check %d/%d", site.Name, newState.FailureCount, site.MaxRetries))
+ e.AddLog(fmt.Sprintf("Monitor '%s' failed check %d/%d", site.Name, newState.FailureCount, site.MaxRetries))
}
} else if rawStatus == "UP" {
newState.FailureCount = 0
@@ -328,20 +418,20 @@ func handleStatusChange(site models.Site, rawStatus string, code int, latency ti
if site.Type == "http" && site.CheckSSL && site.HasSSL {
daysLeft := int(time.Until(site.CertExpiry).Hours() / 24)
if daysLeft <= site.ExpiryThreshold && !site.SentSSLWarning && rawStatus != "SSL EXP" {
- triggerAlert(site.AlertID, "SSL WARNING", fmt.Sprintf("SSL for '%s' expires in %d days", site.Name, daysLeft))
+ e.triggerAlert(site.AlertID, "SSL WARNING", fmt.Sprintf("SSL for '%s' expires in %d days", site.Name, daysLeft))
newState.SentSSLWarning = true
} else if daysLeft > site.ExpiryThreshold {
newState.SentSSLWarning = false
}
}
- Mutex.Lock()
- if _, ok := LiveState[site.ID]; ok {
- LiveState[site.ID] = newState
+ e.mu.Lock()
+ if _, ok := e.liveState[site.ID]; ok {
+ e.liveState[site.ID] = newState
}
- Mutex.Unlock()
+ e.mu.Unlock()
- RecordCheck(site.ID, latency, rawStatus == "UP")
+ e.recordCheck(site.ID, latency, rawStatus == "UP")
isBroken := func(s string) bool { return s == "DOWN" || s == "SSL EXP" }
if !isBroken(site.Status) && isBroken(newState.Status) && newState.Status != "PENDING" {
@@ -349,25 +439,26 @@ func handleStatusChange(site models.Site, rawStatus string, code int, latency ti
if site.Type == "push" {
msg = fmt.Sprintf("Push Monitor '%s' missed heartbeat.", site.Name)
}
- triggerAlert(site.AlertID, "๐จ ALERT", msg)
+ e.triggerAlert(site.AlertID, "๐จ ALERT", msg)
}
if isBroken(site.Status) && newState.Status == "UP" {
- triggerAlert(site.AlertID, "โ
RECOVERY", fmt.Sprintf("Monitor '%s' is UP", site.Name))
+ e.triggerAlert(site.AlertID, "โ
RECOVERY", fmt.Sprintf("Monitor '%s' is UP", site.Name))
}
}
-func triggerAlert(alertID int, title, message string) {
- s_instance := store.Get()
- if s_instance == nil {
- return
- }
- cfg, ok := s_instance.GetAlert(alertID)
- if !ok {
+func (e *Engine) triggerAlert(alertID int, title, message string) {
+ cfg, err := e.db.GetAlert(alertID)
+ if err != nil {
return
}
provider := alert.GetProvider(cfg)
if provider != nil {
- go func() { provider.Send(title, message) }()
+ go func() {
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+ _ = ctx
+ _ = provider.Send(title, message)
+ }()
}
}
@@ -378,7 +469,29 @@ func siteTimeout(site models.Site) time.Duration {
return 5 * time.Second
}
-func checkPing(site models.Site) {
+func isCodeAccepted(code int, accepted string) bool {
+ if accepted == "" {
+ return code >= 200 && code < 300
+ }
+ for _, part := range strings.Split(accepted, ",") {
+ part = strings.TrimSpace(part)
+ if strings.Contains(part, "-") {
+ bounds := strings.SplitN(part, "-", 2)
+ lo, err1 := strconv.Atoi(strings.TrimSpace(bounds[0]))
+ hi, err2 := strconv.Atoi(strings.TrimSpace(bounds[1]))
+ if err1 == nil && err2 == nil && code >= lo && code <= hi {
+ return true
+ }
+ } else {
+ if v, err := strconv.Atoi(part); err == nil && code == v {
+ return true
+ }
+ }
+ }
+ return false
+}
+
+func (e *Engine) checkPing(site models.Site) {
host := site.Hostname
if host == "" {
host = site.URL
@@ -386,8 +499,8 @@ func checkPing(site models.Site) {
pinger, err := probing.NewPinger(host)
if err != nil {
- handleStatusChange(site, "DOWN", 0, 0)
- AddLog(fmt.Sprintf("Ping '%s' resolve failed: %v", site.Name, err))
+ e.handleStatusChange(site, "DOWN", 0, 0)
+ e.AddLog(fmt.Sprintf("Ping '%s' resolve failed: %v", site.Name, err))
return
}
pinger.Count = 1
@@ -402,7 +515,7 @@ func checkPing(site models.Site) {
updatedSite := site
updatedSite.Latency = latency
updatedSite.LastCheck = time.Now()
- handleStatusChange(updatedSite, "DOWN", 0, latency)
+ e.handleStatusChange(updatedSite, "DOWN", 0, latency)
return
}
@@ -410,10 +523,10 @@ func checkPing(site models.Site) {
updatedSite := site
updatedSite.Latency = stats.AvgRtt
updatedSite.LastCheck = time.Now()
- handleStatusChange(updatedSite, "UP", 0, stats.AvgRtt)
+ e.handleStatusChange(updatedSite, "UP", 0, stats.AvgRtt)
}
-func checkPort(site models.Site) {
+func (e *Engine) checkPort(site models.Site) {
host := site.Hostname
if host == "" {
host = site.URL
@@ -430,19 +543,19 @@ func checkPort(site models.Site) {
updatedSite.LastCheck = time.Now()
if err != nil {
- handleStatusChange(updatedSite, "DOWN", 0, latency)
+ e.handleStatusChange(updatedSite, "DOWN", 0, latency)
return
}
conn.Close()
- handleStatusChange(updatedSite, "UP", 0, latency)
+ e.handleStatusChange(updatedSite, "UP", 0, latency)
}
-func checkGroup(site models.Site) {
- Mutex.RLock()
+func (e *Engine) checkGroup(site models.Site) {
+ e.mu.RLock()
status := "UP"
hasChildren := false
allPaused := true
- for _, child := range LiveState {
+ for _, child := range e.liveState {
if child.ParentID != site.ID || child.Type == "group" {
continue
}
@@ -459,23 +572,23 @@ func checkGroup(site models.Site) {
status = "PENDING"
}
}
- Mutex.RUnlock()
+ e.mu.RUnlock()
if !hasChildren {
status = "PENDING"
}
- Mutex.Lock()
- s := LiveState[site.ID]
+ e.mu.Lock()
+ s := e.liveState[site.ID]
s.Status = status
if hasChildren && allPaused {
s.Paused = true
}
- LiveState[site.ID] = s
- Mutex.Unlock()
+ e.liveState[site.ID] = s
+ e.mu.Unlock()
}
-func checkDNS(site models.Site) {
+func (e *Engine) checkDNS(site models.Site) {
host := site.Hostname
if host == "" {
host = site.URL
@@ -516,8 +629,7 @@ func checkDNS(site models.Site) {
c.Timeout = siteTimeout(site)
start := time.Now()
- r, rtt, err := c.Exchange(m, server)
- _ = rtt
+ r, _, err := c.Exchange(m, server)
latency := time.Since(start)
updatedSite := site
@@ -525,14 +637,14 @@ func checkDNS(site models.Site) {
updatedSite.LastCheck = time.Now()
if err != nil {
- handleStatusChange(updatedSite, "DOWN", 0, latency)
+ e.handleStatusChange(updatedSite, "DOWN", 0, latency)
return
}
if r.Rcode != dns.RcodeSuccess {
- handleStatusChange(updatedSite, "DOWN", r.Rcode, latency)
+ e.handleStatusChange(updatedSite, "DOWN", r.Rcode, latency)
return
}
- handleStatusChange(updatedSite, "UP", 0, latency)
+ e.handleStatusChange(updatedSite, "UP", 0, latency)
}
diff --git a/internal/server/server.go b/internal/server/server.go
index f814cc3..fdf7f9b 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -4,14 +4,144 @@ import (
"encoding/json"
"fmt"
"go-upkeep/internal/importer"
+ "go-upkeep/internal/metrics"
"go-upkeep/internal/models"
"go-upkeep/internal/monitor"
"go-upkeep/internal/store"
"html/template"
+ "log"
"net/http"
"sort"
)
+var statusTpl = template.Must(template.New("status").Parse(`
+
+
+
+ {{.Title}}
+
+
+
+
+
+
{{.Title}}
+
+
+
+
Powered by Go-Upkeep
+
+
+
+`))
+
type ServerConfig struct {
Port int
EnableStatus bool
@@ -19,7 +149,7 @@ type ServerConfig struct {
ClusterKey string // Shared Secret for Security
}
-func Start(cfg ServerConfig) {
+func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) {
if cfg.ClusterKey == "" {
fmt.Println("WARNING: No UPKEEP_CLUSTER_SECRET set. Cluster API endpoints are unauthenticated.")
}
@@ -32,7 +162,7 @@ func Start(cfg ServerConfig) {
http.Error(w, "Missing token", 400)
return
}
- if monitor.RecordHeartbeat(token) {
+ if eng.RecordHeartbeat(token) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
} else {
@@ -56,7 +186,12 @@ func Start(cfg ServerConfig) {
http.Error(w, "Unauthorized: UPKEEP_CLUSTER_SECRET required", 401)
return
}
- data := store.Get().ExportData()
+ data, err := s.ExportData()
+ if err != nil {
+ log.Printf("Export failed: %v", err)
+ http.Error(w, "Export failed", 500)
+ return
+ }
json.NewEncoder(w).Encode(data)
})
@@ -75,8 +210,9 @@ func Start(cfg ServerConfig) {
http.Error(w, "Invalid JSON", 400)
return
}
- if err := store.Get().ImportData(data); err != nil {
- http.Error(w, "Import Failed: "+err.Error(), 500)
+ if err := s.ImportData(data); err != nil {
+ log.Printf("Import failed: %v", err)
+ http.Error(w, "Import failed", 500)
return
}
w.Write([]byte("Import Successful"))
@@ -94,42 +230,42 @@ func Start(cfg ServerConfig) {
}
var kb importer.KumaBackup
if err := json.NewDecoder(r.Body).Decode(&kb); err != nil {
- http.Error(w, "Invalid Kuma JSON: "+err.Error(), 400)
+ log.Printf("Invalid Kuma JSON: %v", err)
+ http.Error(w, "Invalid Kuma JSON", 400)
return
}
backup := importer.ConvertKuma(&kb)
- if err := store.Get().ImportData(backup); err != nil {
- http.Error(w, "Import Failed: "+err.Error(), 500)
+ if err := s.ImportData(backup); err != nil {
+ log.Printf("Kuma import failed: %v", err)
+ http.Error(w, "Import failed", 500)
return
}
w.Write([]byte(fmt.Sprintf("Imported %d monitors, %d alerts from Kuma v%s", len(backup.Sites), len(backup.Alerts), kb.Version)))
})
- // 6. Status Page
+ // 6. Prometheus Metrics
+ mux.HandleFunc("/metrics", metrics.Handler(eng))
+
+ // 7. Status Page
if cfg.EnableStatus {
- mux.HandleFunc("/status", func(w http.ResponseWriter, r *http.Request) { renderStatusPage(w, cfg.Title) })
+ mux.HandleFunc("/status", func(w http.ResponseWriter, r *http.Request) { renderStatusPage(w, cfg.Title, eng) })
mux.HandleFunc("/status/json", func(w http.ResponseWriter, r *http.Request) {
- monitor.Mutex.RLock()
- defer monitor.Mutex.RUnlock()
w.Header().Set("Content-Type", "application/json")
- json.NewEncoder(w).Encode(monitor.LiveState)
+ json.NewEncoder(w).Encode(eng.GetLiveState())
})
}
go func() {
addr := fmt.Sprintf(":%d", cfg.Port)
fmt.Printf("HTTP Server listening on %s\n", addr)
- http.ListenAndServe(addr, mux)
+ if err := http.ListenAndServe(addr, mux); err != nil {
+ log.Fatalf("HTTP server failed: %v", err)
+ }
}()
}
-func renderStatusPage(w http.ResponseWriter, title string) {
- monitor.Mutex.RLock()
- var sites []models.Site
- for _, s := range monitor.LiveState {
- sites = append(sites, s)
- }
- monitor.Mutex.RUnlock()
+func renderStatusPage(w http.ResponseWriter, title string, eng *monitor.Engine) {
+ sites := eng.GetAllSites()
sort.Slice(sites, func(i, j int) bool {
if sites[i].Status != sites[j].Status {
@@ -143,138 +279,9 @@ func renderStatusPage(w http.ResponseWriter, title string) {
return sites[i].Name < sites[j].Name
})
- const tpl = `
-
-
-
- {{.Title}}
-
-
-
-
-
-
{{.Title}}
-
-
-
-
Powered by Go-Upkeep
-
-
-
- `
-
- t, _ := template.New("status").Parse(tpl)
data := struct {
Title string
Sites []models.Site
}{Title: title, Sites: sites}
- t.Execute(w, data)
+ statusTpl.Execute(w, data)
}
diff --git a/internal/store/dialect.go b/internal/store/dialect.go
new file mode 100644
index 0000000..4e1ba04
--- /dev/null
+++ b/internal/store/dialect.go
@@ -0,0 +1,36 @@
+package store
+
+import "database/sql"
+
+type Dialect interface {
+ DriverName() string
+ CreateTablesSQL() []string
+ MigrationsSQL() []string
+ BoolFalse() string
+ ResetSequenceOnEmpty(db *sql.DB, table string)
+ ImportWipe(tx *sql.Tx)
+ ImportResetSequences(tx *sql.Tx)
+}
+
+// rewritePlaceholders converts ? markers to $1, $2, etc. for Postgres.
+// For SQLite (or any dialect not needing rewrite), returns the input unchanged.
+func rewritePlaceholders(query string, dollarStyle bool) string {
+ if !dollarStyle {
+ return query
+ }
+ buf := make([]byte, 0, len(query)+32)
+ n := 0
+ for i := 0; i < len(query); i++ {
+ if query[i] == '?' {
+ n++
+ buf = append(buf, '$')
+ if n >= 10 {
+ buf = append(buf, byte('0'+n/10))
+ }
+ buf = append(buf, byte('0'+n%10))
+ } else {
+ buf = append(buf, query[i])
+ }
+ }
+ return string(buf)
+}
diff --git a/internal/store/postgres.go b/internal/store/postgres.go
index c5201d0..78fcc8d 100644
--- a/internal/store/postgres.go
+++ b/internal/store/postgres.go
@@ -2,77 +2,53 @@ package store
import (
"database/sql"
- "encoding/json"
- "go-upkeep/internal/models"
_ "github.com/lib/pq"
)
-type PostgresStore struct {
- ConnStr string
- db *sql.DB
+type PostgresDialect struct{}
+
+func NewPostgresStore(connStr string) (*SQLStore, error) {
+ return NewSQLStore("postgres", connStr, &PostgresDialect{})
}
-func (p *PostgresStore) Init() error {
- var err error
- p.db, err = sql.Open("postgres", p.ConnStr)
- if err != nil {
- return err
- }
+func (d *PostgresDialect) DriverName() string { return "postgres" }
+func (d *PostgresDialect) BoolFalse() string { return "FALSE" }
- queries := []string{
+func (d *PostgresDialect) CreateTablesSQL() []string {
+ return []string{
`CREATE TABLE IF NOT EXISTS alerts (
id SERIAL PRIMARY KEY,
- name TEXT,
- type TEXT,
- settings TEXT
- );`,
+ name TEXT, type TEXT, settings TEXT
+ )`,
`CREATE TABLE IF NOT EXISTS sites (
id SERIAL PRIMARY KEY,
- name TEXT DEFAULT 'New Monitor',
- url TEXT,
- type TEXT DEFAULT 'http',
- token TEXT,
- interval INTEGER,
- alert_id INTEGER,
- check_ssl BOOLEAN DEFAULT FALSE,
- threshold INTEGER DEFAULT 7,
- max_retries INTEGER DEFAULT 0,
- hostname TEXT DEFAULT '',
- port INTEGER DEFAULT 0,
- timeout INTEGER DEFAULT 0,
- method TEXT DEFAULT 'GET',
- description TEXT DEFAULT '',
- parent_id INTEGER DEFAULT 0,
- accepted_codes TEXT DEFAULT '200-299',
- dns_resolve_type TEXT DEFAULT '',
- dns_server TEXT DEFAULT '',
- ignore_tls BOOLEAN DEFAULT FALSE,
- paused BOOLEAN DEFAULT FALSE
- );`,
+ name TEXT DEFAULT 'New Monitor', url TEXT, type TEXT DEFAULT 'http',
+ token TEXT, interval INTEGER, alert_id INTEGER,
+ check_ssl BOOLEAN DEFAULT FALSE, threshold INTEGER DEFAULT 7,
+ max_retries INTEGER DEFAULT 0, hostname TEXT DEFAULT '',
+ port INTEGER DEFAULT 0, timeout INTEGER DEFAULT 0,
+ method TEXT DEFAULT 'GET', description TEXT DEFAULT '',
+ parent_id INTEGER DEFAULT 0, accepted_codes TEXT DEFAULT '200-299',
+ dns_resolve_type TEXT DEFAULT '', dns_server TEXT DEFAULT '',
+ ignore_tls BOOLEAN DEFAULT FALSE, paused BOOLEAN DEFAULT FALSE
+ )`,
`CREATE TABLE IF NOT EXISTS users (
id SERIAL PRIMARY KEY,
- username TEXT NOT NULL,
- public_key TEXT NOT NULL,
+ username TEXT NOT NULL, public_key TEXT NOT NULL,
role TEXT DEFAULT 'user'
- );`,
+ )`,
`CREATE TABLE IF NOT EXISTS check_history (
id SERIAL PRIMARY KEY,
- site_id INTEGER NOT NULL,
- latency_ns BIGINT,
- is_up BOOLEAN,
- checked_at TIMESTAMP DEFAULT NOW()
- );`,
- }
- for _, q := range queries {
- if _, err := p.db.Exec(q); err != nil {
- return err
- }
+ site_id INTEGER NOT NULL, latency_ns BIGINT,
+ is_up BOOLEAN, checked_at TIMESTAMP DEFAULT NOW()
+ )`,
+ `CREATE INDEX IF NOT EXISTS idx_check_history_site ON check_history(site_id, checked_at DESC)`,
}
+}
- p.db.Exec("CREATE INDEX IF NOT EXISTS idx_check_history_site ON check_history(site_id, checked_at DESC)")
-
- migrations := []string{
+func (d *PostgresDialect) MigrationsSQL() []string {
+ return []string{
"ALTER TABLE sites ADD COLUMN IF NOT EXISTS hostname TEXT DEFAULT ''",
"ALTER TABLE sites ADD COLUMN IF NOT EXISTS port INTEGER DEFAULT 0",
"ALTER TABLE sites ADD COLUMN IF NOT EXISTS timeout INTEGER DEFAULT 0",
@@ -85,181 +61,18 @@ func (p *PostgresStore) Init() error {
"ALTER TABLE sites ADD COLUMN IF NOT EXISTS ignore_tls BOOLEAN DEFAULT FALSE",
"ALTER TABLE sites ADD COLUMN IF NOT EXISTS paused BOOLEAN DEFAULT FALSE",
}
- for _, m := range migrations {
- p.db.Exec(m)
- }
-
- return nil
}
-// ... [CRUD Methods are identical to Phase 4, keeping them concise here] ...
-func (p *PostgresStore) GetSites() []models.Site {
- rows, err := p.db.Query("SELECT id, COALESCE(name, url), url, COALESCE(type, 'http'), COALESCE(token, ''), interval, alert_id, check_ssl, threshold, max_retries, COALESCE(hostname, ''), COALESCE(port, 0), COALESCE(timeout, 0), COALESCE(method, 'GET'), COALESCE(description, ''), COALESCE(parent_id, 0), COALESCE(accepted_codes, '200-299'), COALESCE(dns_resolve_type, ''), COALESCE(dns_server, ''), COALESCE(ignore_tls, FALSE), COALESCE(paused, FALSE) FROM sites")
- if err != nil {
- return []models.Site{}
- }
- defer rows.Close()
- var sites []models.Site
- for rows.Next() {
- var s models.Site
- rows.Scan(&s.ID, &s.Name, &s.URL, &s.Type, &s.Token, &s.Interval, &s.AlertID, &s.CheckSSL, &s.ExpiryThreshold, &s.MaxRetries,
- &s.Hostname, &s.Port, &s.Timeout, &s.Method, &s.Description, &s.ParentID, &s.AcceptedCodes, &s.DNSResolveType, &s.DNSServer, &s.IgnoreTLS, &s.Paused)
- sites = append(sites, s)
- }
- return sites
-}
-func (p *PostgresStore) AddSite(site models.Site) {
- token := ""
- if site.Type == "push" {
- token = generateToken()
- }
- p.db.Exec("INSERT INTO sites (name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20)",
- site.Name, site.URL, site.Type, token, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries,
- site.Hostname, site.Port, site.Timeout, site.Method, site.Description, site.ParentID, site.AcceptedCodes, site.DNSResolveType, site.DNSServer, site.IgnoreTLS, site.Paused)
-}
-func (p *PostgresStore) UpdateSite(site models.Site) {
- var existingToken string
- p.db.QueryRow("SELECT token FROM sites WHERE id=$1", site.ID).Scan(&existingToken)
- if site.Type == "push" && existingToken == "" {
- existingToken = generateToken()
- }
- p.db.Exec("UPDATE sites SET name=$1, url=$2, type=$3, token=$4, interval=$5, alert_id=$6, check_ssl=$7, threshold=$8, max_retries=$9, hostname=$10, port=$11, timeout=$12, method=$13, description=$14, parent_id=$15, accepted_codes=$16, dns_resolve_type=$17, dns_server=$18, ignore_tls=$19, paused=$20 WHERE id=$21",
- site.Name, site.URL, site.Type, existingToken, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries,
- site.Hostname, site.Port, site.Timeout, site.Method, site.Description, site.ParentID, site.AcceptedCodes, site.DNSResolveType, site.DNSServer, site.IgnoreTLS, site.Paused, site.ID)
-}
-func (p *PostgresStore) UpdateSitePaused(id int, paused bool) {
- p.db.Exec("UPDATE sites SET paused=$1 WHERE id=$2", paused, id)
-}
-func (p *PostgresStore) DeleteSite(id int) { p.db.Exec("DELETE FROM sites WHERE id=$1", id) }
-func (p *PostgresStore) GetAllAlerts() []models.AlertConfig {
- rows, err := p.db.Query("SELECT id, name, type, settings FROM alerts")
- if err != nil {
- return []models.AlertConfig{}
- }
- defer rows.Close()
- var alerts []models.AlertConfig
- for rows.Next() {
- var a models.AlertConfig
- var settingsJSON string
- rows.Scan(&a.ID, &a.Name, &a.Type, &settingsJSON)
- json.Unmarshal([]byte(settingsJSON), &a.Settings)
- alerts = append(alerts, a)
- }
- return alerts
-}
-func (p *PostgresStore) GetAlert(id int) (models.AlertConfig, bool) {
- var a models.AlertConfig
- var settingsJSON string
- err := p.db.QueryRow("SELECT id, name, type, settings FROM alerts WHERE id = $1", id).Scan(&a.ID, &a.Name, &a.Type, &settingsJSON)
- if err != nil {
- return a, false
- }
- json.Unmarshal([]byte(settingsJSON), &a.Settings)
- return a, true
-}
-func (p *PostgresStore) AddAlert(name, aType string, settings map[string]string) {
- jsonBytes, _ := json.Marshal(settings)
- p.db.Exec("INSERT INTO alerts (name, type, settings) VALUES ($1, $2, $3)", name, aType, string(jsonBytes))
-}
-func (p *PostgresStore) UpdateAlert(id int, name, aType string, settings map[string]string) {
- jsonBytes, _ := json.Marshal(settings)
- p.db.Exec("UPDATE alerts SET name=$1, type=$2, settings=$3 WHERE id=$4", name, aType, string(jsonBytes), id)
-}
-func (p *PostgresStore) DeleteAlert(id int) { p.db.Exec("DELETE FROM alerts WHERE id=$1", id) }
-func (p *PostgresStore) GetAllUsers() []models.User {
- rows, err := p.db.Query("SELECT id, username, public_key, role FROM users")
- if err != nil {
- return []models.User{}
- }
- defer rows.Close()
- var users []models.User
- for rows.Next() {
- var u models.User
- rows.Scan(&u.ID, &u.Username, &u.PublicKey, &u.Role)
- users = append(users, u)
- }
- return users
-}
-func (p *PostgresStore) AddUser(username, publicKey, role string) error {
- _, err := p.db.Exec("INSERT INTO users (username, public_key, role) VALUES ($1, $2, $3)", username, publicKey, role)
- return err
-}
-func (p *PostgresStore) UpdateUser(id int, username, publicKey, role string) error {
- _, err := p.db.Exec("UPDATE users SET username=$1, public_key=$2, role=$3 WHERE id=$4", username, publicKey, role, id)
- return err
-}
-func (p *PostgresStore) DeleteUser(id int) error {
- _, err := p.db.Exec("DELETE FROM users WHERE id=$1", id)
- return err
-}
-
-func (p *PostgresStore) SaveCheck(siteID int, latencyNs int64, isUp bool) {
- p.db.Exec("INSERT INTO check_history (site_id, latency_ns, is_up) VALUES ($1, $2, $3)", siteID, latencyNs, isUp)
- p.db.Exec(`DELETE FROM check_history WHERE site_id = $1 AND id NOT IN (
- SELECT id FROM check_history WHERE site_id = $1 ORDER BY checked_at DESC LIMIT 1000
- )`, siteID)
-}
-
-func (p *PostgresStore) LoadAllHistory(limit int) map[int][]models.CheckRecord {
- result := make(map[int][]models.CheckRecord)
- rows, err := p.db.Query(`
- SELECT site_id, latency_ns, is_up FROM (
- SELECT site_id, latency_ns, is_up,
- ROW_NUMBER() OVER (PARTITION BY site_id ORDER BY checked_at DESC) AS rn
- FROM check_history
- ) sub WHERE rn <= $1`, limit)
- if err != nil {
- return result
- }
- defer rows.Close()
- for rows.Next() {
- var r models.CheckRecord
- rows.Scan(&r.SiteID, &r.LatencyNs, &r.IsUp)
- result[r.SiteID] = append(result[r.SiteID], r)
- }
- for id, records := range result {
- for i, j := 0, len(records)-1; i < j; i, j = i+1, j-1 {
- records[i], records[j] = records[j], records[i]
- }
- result[id] = records
- }
- return result
-}
-
-func (p *PostgresStore) ExportData() models.Backup {
- return models.Backup{
- Sites: p.GetSites(),
- Alerts: p.GetAllAlerts(),
- Users: p.GetAllUsers(),
- }
-}
-
-func (p *PostgresStore) ImportData(data models.Backup) error {
- tx, err := p.db.Begin()
- if err != nil {
- return err
- }
+func (d *PostgresDialect) ResetSequenceOnEmpty(db *sql.DB, table string) {}
+func (d *PostgresDialect) ImportWipe(tx *sql.Tx) {
tx.Exec("TRUNCATE TABLE sites RESTART IDENTITY CASCADE")
tx.Exec("TRUNCATE TABLE alerts RESTART IDENTITY CASCADE")
tx.Exec("TRUNCATE TABLE users RESTART IDENTITY CASCADE")
-
- for _, u := range data.Users {
- tx.Exec("INSERT INTO users (username, public_key, role) VALUES ($1, $2, $3)", u.Username, u.PublicKey, u.Role)
- }
- for _, a := range data.Alerts {
- jsonBytes, _ := json.Marshal(a.Settings)
- tx.Exec("INSERT INTO alerts (id, name, type, settings) VALUES ($1, $2, $3, $4)", a.ID, a.Name, a.Type, string(jsonBytes))
- }
- for _, st := range data.Sites {
- tx.Exec("INSERT INTO sites (id, name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21)",
- st.ID, st.Name, st.URL, st.Type, st.Token, st.Interval, st.AlertID, st.CheckSSL, st.ExpiryThreshold, st.MaxRetries,
- st.Hostname, st.Port, st.Timeout, st.Method, st.Description, st.ParentID, st.AcceptedCodes, st.DNSResolveType, st.DNSServer, st.IgnoreTLS, st.Paused)
- }
-
- tx.Exec("SELECT setval('sites_id_seq', (SELECT MAX(id) FROM sites))")
- tx.Exec("SELECT setval('alerts_id_seq', (SELECT MAX(id) FROM alerts))")
- tx.Exec("SELECT setval('users_id_seq', (SELECT MAX(id) FROM users))")
-
- return tx.Commit()
+}
+
+func (d *PostgresDialect) ImportResetSequences(tx *sql.Tx) {
+ tx.Exec("SELECT setval('sites_id_seq', (SELECT COALESCE(MAX(id), 1) FROM sites))")
+ tx.Exec("SELECT setval('alerts_id_seq', (SELECT COALESCE(MAX(id), 1) FROM alerts))")
+ tx.Exec("SELECT setval('users_id_seq', (SELECT COALESCE(MAX(id), 1) FROM users))")
}
diff --git a/internal/store/sqlite.go b/internal/store/sqlite.go
index e83d96e..dbeb74d 100644
--- a/internal/store/sqlite.go
+++ b/internal/store/sqlite.go
@@ -1,77 +1,54 @@
package store
import (
- "crypto/rand"
"database/sql"
- "encoding/hex"
- "encoding/json"
- "go-upkeep/internal/models"
_ "github.com/mattn/go-sqlite3"
)
-type SQLiteStore struct {
- DBPath string
- db *sql.DB
+type SQLiteDialect struct{}
+
+func NewSQLiteStore(path string) (*SQLStore, error) {
+ return NewSQLStore("sqlite3", path, &SQLiteDialect{})
}
-func (s *SQLiteStore) Init() error {
- var err error
- s.db, err = sql.Open("sqlite3", s.DBPath)
- if err != nil {
- return err
- }
+func (d *SQLiteDialect) DriverName() string { return "sqlite3" }
+func (d *SQLiteDialect) BoolFalse() string { return "0" }
- createTables := `
- CREATE TABLE IF NOT EXISTS alerts (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- name TEXT,
- type TEXT,
- settings TEXT
- );
- CREATE TABLE IF NOT EXISTS sites (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- name TEXT DEFAULT 'New Monitor',
- url TEXT,
- type TEXT DEFAULT 'http',
- token TEXT,
- interval INTEGER,
- alert_id INTEGER,
- check_ssl BOOLEAN DEFAULT 0,
- threshold INTEGER DEFAULT 7,
- max_retries INTEGER DEFAULT 0,
- hostname TEXT DEFAULT '',
- port INTEGER DEFAULT 0,
- timeout INTEGER DEFAULT 0,
- method TEXT DEFAULT 'GET',
- description TEXT DEFAULT '',
- parent_id INTEGER DEFAULT 0,
- accepted_codes TEXT DEFAULT '200-299',
- dns_resolve_type TEXT DEFAULT '',
- dns_server TEXT DEFAULT '',
- ignore_tls BOOLEAN DEFAULT 0,
- paused BOOLEAN DEFAULT 0
- );
- CREATE TABLE IF NOT EXISTS users (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- username TEXT NOT NULL,
- public_key TEXT NOT NULL,
- role TEXT DEFAULT 'user'
- );
- CREATE TABLE IF NOT EXISTS check_history (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- site_id INTEGER NOT NULL,
- latency_ns INTEGER,
- is_up BOOLEAN,
- checked_at DATETIME DEFAULT CURRENT_TIMESTAMP
- );
- CREATE INDEX IF NOT EXISTS idx_check_history_site ON check_history(site_id, checked_at DESC);`
- _, err = s.db.Exec(createTables)
- if err != nil {
- return err
+func (d *SQLiteDialect) CreateTablesSQL() []string {
+ return []string{
+ `CREATE TABLE IF NOT EXISTS alerts (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ name TEXT, type TEXT, settings TEXT
+ )`,
+ `CREATE TABLE IF NOT EXISTS sites (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ name TEXT DEFAULT 'New Monitor', url TEXT, type TEXT DEFAULT 'http',
+ token TEXT, interval INTEGER, alert_id INTEGER,
+ check_ssl BOOLEAN DEFAULT 0, threshold INTEGER DEFAULT 7,
+ max_retries INTEGER DEFAULT 0, hostname TEXT DEFAULT '',
+ port INTEGER DEFAULT 0, timeout INTEGER DEFAULT 0,
+ method TEXT DEFAULT 'GET', description TEXT DEFAULT '',
+ parent_id INTEGER DEFAULT 0, accepted_codes TEXT DEFAULT '200-299',
+ dns_resolve_type TEXT DEFAULT '', dns_server TEXT DEFAULT '',
+ ignore_tls BOOLEAN DEFAULT 0, paused BOOLEAN DEFAULT 0
+ )`,
+ `CREATE TABLE IF NOT EXISTS users (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ username TEXT NOT NULL, public_key TEXT NOT NULL,
+ role TEXT DEFAULT 'user'
+ )`,
+ `CREATE TABLE IF NOT EXISTS check_history (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ site_id INTEGER NOT NULL, latency_ns INTEGER,
+ is_up BOOLEAN, checked_at DATETIME DEFAULT CURRENT_TIMESTAMP
+ )`,
+ `CREATE INDEX IF NOT EXISTS idx_check_history_site ON check_history(site_id, checked_at DESC)`,
}
+}
- migrations := []string{
+func (d *SQLiteDialect) MigrationsSQL() []string {
+ return []string{
"ALTER TABLE sites ADD COLUMN hostname TEXT DEFAULT ''",
"ALTER TABLE sites ADD COLUMN port INTEGER DEFAULT 0",
"ALTER TABLE sites ADD COLUMN timeout INTEGER DEFAULT 0",
@@ -84,202 +61,23 @@ func (s *SQLiteStore) Init() error {
"ALTER TABLE sites ADD COLUMN ignore_tls BOOLEAN DEFAULT 0",
"ALTER TABLE sites ADD COLUMN paused BOOLEAN DEFAULT 0",
}
- for _, m := range migrations {
- s.db.Exec(m)
- }
-
- return nil
}
-func generateToken() string {
- b := make([]byte, 16)
- if _, err := rand.Read(b); err != nil {
- panic("crypto/rand failed: " + err.Error())
- }
- return hex.EncodeToString(b)
-}
-
-func (s *SQLiteStore) GetSites() []models.Site {
- rows, err := s.db.Query("SELECT id, COALESCE(name, url), url, COALESCE(type, 'http'), COALESCE(token, ''), interval, alert_id, check_ssl, threshold, max_retries, COALESCE(hostname, ''), COALESCE(port, 0), COALESCE(timeout, 0), COALESCE(method, 'GET'), COALESCE(description, ''), COALESCE(parent_id, 0), COALESCE(accepted_codes, '200-299'), COALESCE(dns_resolve_type, ''), COALESCE(dns_server, ''), COALESCE(ignore_tls, 0), COALESCE(paused, 0) FROM sites")
- if err != nil {
- return []models.Site{}
- }
- defer rows.Close()
- var sites []models.Site
- for rows.Next() {
- var st models.Site
- rows.Scan(&st.ID, &st.Name, &st.URL, &st.Type, &st.Token, &st.Interval, &st.AlertID, &st.CheckSSL, &st.ExpiryThreshold, &st.MaxRetries, &st.Hostname, &st.Port, &st.Timeout, &st.Method, &st.Description, &st.ParentID, &st.AcceptedCodes, &st.DNSResolveType, &st.DNSServer, &st.IgnoreTLS, &st.Paused)
- sites = append(sites, st)
- }
- return sites
-}
-func (s *SQLiteStore) AddSite(site models.Site) {
- token := ""
- if site.Type == "push" {
- token = generateToken()
- }
- s.db.Exec("INSERT INTO sites (name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
- site.Name, site.URL, site.Type, token, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries,
- site.Hostname, site.Port, site.Timeout, site.Method, site.Description, site.ParentID, site.AcceptedCodes, site.DNSResolveType, site.DNSServer, site.IgnoreTLS, site.Paused)
-}
-func (s *SQLiteStore) UpdateSite(site models.Site) {
- var existingToken string
- s.db.QueryRow("SELECT token FROM sites WHERE id=?", site.ID).Scan(&existingToken)
- if site.Type == "push" && existingToken == "" {
- existingToken = generateToken()
- }
- s.db.Exec("UPDATE sites SET name=?, url=?, type=?, token=?, interval=?, alert_id=?, check_ssl=?, threshold=?, max_retries=?, hostname=?, port=?, timeout=?, method=?, description=?, parent_id=?, accepted_codes=?, dns_resolve_type=?, dns_server=?, ignore_tls=?, paused=? WHERE id=?",
- site.Name, site.URL, site.Type, existingToken, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries,
- site.Hostname, site.Port, site.Timeout, site.Method, site.Description, site.ParentID, site.AcceptedCodes, site.DNSResolveType, site.DNSServer, site.IgnoreTLS, site.Paused, site.ID)
-}
-func (s *SQLiteStore) UpdateSitePaused(id int, paused bool) {
- s.db.Exec("UPDATE sites SET paused=? WHERE id=?", paused, id)
-}
-func (s *SQLiteStore) DeleteSite(id int) {
- s.db.Exec("DELETE FROM sites WHERE id=?", id)
+func (d *SQLiteDialect) ResetSequenceOnEmpty(db *sql.DB, table string) {
var count int
- s.db.QueryRow("SELECT COUNT(*) FROM sites").Scan(&count)
+ db.QueryRow("SELECT COUNT(*) FROM " + table).Scan(&count)
if count == 0 {
- s.db.Exec("DELETE FROM sqlite_sequence WHERE name='sites'")
- }
-}
-func (s *SQLiteStore) GetAllAlerts() []models.AlertConfig {
- rows, err := s.db.Query("SELECT id, name, type, settings FROM alerts")
- if err != nil {
- return []models.AlertConfig{}
- }
- defer rows.Close()
- var alerts []models.AlertConfig
- for rows.Next() {
- var a models.AlertConfig
- var settingsJSON string
- rows.Scan(&a.ID, &a.Name, &a.Type, &settingsJSON)
- json.Unmarshal([]byte(settingsJSON), &a.Settings)
- alerts = append(alerts, a)
- }
- return alerts
-}
-func (s *SQLiteStore) GetAlert(id int) (models.AlertConfig, bool) {
- var a models.AlertConfig
- var settingsJSON string
- err := s.db.QueryRow("SELECT id, name, type, settings FROM alerts WHERE id = ?", id).Scan(&a.ID, &a.Name, &a.Type, &settingsJSON)
- if err != nil {
- return a, false
- }
- json.Unmarshal([]byte(settingsJSON), &a.Settings)
- return a, true
-}
-func (s *SQLiteStore) AddAlert(name, aType string, settings map[string]string) {
- jsonBytes, _ := json.Marshal(settings)
- s.db.Exec("INSERT INTO alerts (name, type, settings) VALUES (?, ?, ?)", name, aType, string(jsonBytes))
-}
-func (s *SQLiteStore) UpdateAlert(id int, name, aType string, settings map[string]string) {
- jsonBytes, _ := json.Marshal(settings)
- s.db.Exec("UPDATE alerts SET name=?, type=?, settings=? WHERE id=?", name, aType, string(jsonBytes), id)
-}
-func (s *SQLiteStore) DeleteAlert(id int) {
- s.db.Exec("DELETE FROM alerts WHERE id=?", id)
- var count int
- s.db.QueryRow("SELECT COUNT(*) FROM alerts").Scan(&count)
- if count == 0 {
- s.db.Exec("DELETE FROM sqlite_sequence WHERE name='alerts'")
- }
-}
-func (s *SQLiteStore) GetAllUsers() []models.User {
- rows, err := s.db.Query("SELECT id, username, public_key, role FROM users")
- if err != nil {
- return []models.User{}
- }
- defer rows.Close()
- var users []models.User
- for rows.Next() {
- var u models.User
- rows.Scan(&u.ID, &u.Username, &u.PublicKey, &u.Role)
- users = append(users, u)
- }
- return users
-}
-func (s *SQLiteStore) AddUser(username, publicKey, role string) error {
- _, err := s.db.Exec("INSERT INTO users (username, public_key, role) VALUES (?, ?, ?)", username, publicKey, role)
- return err
-}
-func (s *SQLiteStore) UpdateUser(id int, username, publicKey, role string) error {
- _, err := s.db.Exec("UPDATE users SET username=?, public_key=?, role=? WHERE id=?", username, publicKey, role, id)
- return err
-}
-func (s *SQLiteStore) DeleteUser(id int) error {
- _, err := s.db.Exec("DELETE FROM users WHERE id=?", id)
- return err
-}
-
-func (s *SQLiteStore) SaveCheck(siteID int, latencyNs int64, isUp bool) {
- s.db.Exec("INSERT INTO check_history (site_id, latency_ns, is_up) VALUES (?, ?, ?)", siteID, latencyNs, isUp)
- s.db.Exec(`DELETE FROM check_history WHERE site_id = ? AND id NOT IN (
- SELECT id FROM check_history WHERE site_id = ? ORDER BY checked_at DESC LIMIT 1000
- )`, siteID, siteID)
-}
-
-func (s *SQLiteStore) LoadAllHistory(limit int) map[int][]models.CheckRecord {
- result := make(map[int][]models.CheckRecord)
- rows, err := s.db.Query(`
- SELECT site_id, latency_ns, is_up FROM (
- SELECT site_id, latency_ns, is_up,
- ROW_NUMBER() OVER (PARTITION BY site_id ORDER BY checked_at DESC) AS rn
- FROM check_history
- ) WHERE rn <= ?`, limit)
- if err != nil {
- return result
- }
- defer rows.Close()
- for rows.Next() {
- var r models.CheckRecord
- rows.Scan(&r.SiteID, &r.LatencyNs, &r.IsUp)
- result[r.SiteID] = append(result[r.SiteID], r)
- }
- for id, records := range result {
- for i, j := 0, len(records)-1; i < j; i, j = i+1, j-1 {
- records[i], records[j] = records[j], records[i]
- }
- result[id] = records
- }
- return result
-}
-
-func (s *SQLiteStore) ExportData() models.Backup {
- return models.Backup{
- Sites: s.GetSites(),
- Alerts: s.GetAllAlerts(),
- Users: s.GetAllUsers(),
+ db.Exec("DELETE FROM sqlite_sequence WHERE name=?", table)
}
}
-func (s *SQLiteStore) ImportData(data models.Backup) error {
- tx, err := s.db.Begin()
- if err != nil {
- return err
- }
-
- // Wipe Existing
+func (d *SQLiteDialect) ImportWipe(tx *sql.Tx) {
tx.Exec("DELETE FROM sites")
tx.Exec("DELETE FROM sqlite_sequence WHERE name='sites'")
tx.Exec("DELETE FROM alerts")
tx.Exec("DELETE FROM sqlite_sequence WHERE name='alerts'")
tx.Exec("DELETE FROM users")
tx.Exec("DELETE FROM sqlite_sequence WHERE name='users'")
-
- // Insert New
- for _, u := range data.Users {
- tx.Exec("INSERT INTO users (username, public_key, role) VALUES (?, ?, ?)", u.Username, u.PublicKey, u.Role)
- }
- for _, a := range data.Alerts {
- jsonBytes, _ := json.Marshal(a.Settings)
- tx.Exec("INSERT INTO alerts (id, name, type, settings) VALUES (?, ?, ?, ?)", a.ID, a.Name, a.Type, string(jsonBytes))
- }
- for _, st := range data.Sites {
- tx.Exec("INSERT INTO sites (id, name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
- st.ID, st.Name, st.URL, st.Type, st.Token, st.Interval, st.AlertID, st.CheckSSL, st.ExpiryThreshold, st.MaxRetries,
- st.Hostname, st.Port, st.Timeout, st.Method, st.Description, st.ParentID, st.AcceptedCodes, st.DNSResolveType, st.DNSServer, st.IgnoreTLS, st.Paused)
- }
-
- return tx.Commit()
}
+
+func (d *SQLiteDialect) ImportResetSequences(tx *sql.Tx) {}
diff --git a/internal/store/sqlstore.go b/internal/store/sqlstore.go
new file mode 100644
index 0000000..3142399
--- /dev/null
+++ b/internal/store/sqlstore.go
@@ -0,0 +1,291 @@
+package store
+
+import (
+ "crypto/rand"
+ "database/sql"
+ "encoding/hex"
+ "encoding/json"
+ "fmt"
+ "go-upkeep/internal/models"
+)
+
+type SQLStore struct {
+ db *sql.DB
+ dialect Dialect
+ dollar bool
+}
+
+func NewSQLStore(driverName, dsn string, dialect Dialect) (*SQLStore, error) {
+ db, err := sql.Open(driverName, dsn)
+ if err != nil {
+ return nil, err
+ }
+ _, isDollar := dialect.(*PostgresDialect)
+ return &SQLStore{db: db, dialect: dialect, dollar: isDollar}, nil
+}
+
+func (s *SQLStore) q(query string) string {
+ return rewritePlaceholders(query, s.dollar)
+}
+
+func generateToken() string {
+ b := make([]byte, 16)
+ if _, err := rand.Read(b); err != nil {
+ panic("crypto/rand failed: " + err.Error())
+ }
+ return hex.EncodeToString(b)
+}
+
+func (s *SQLStore) Init() error {
+ for _, stmt := range s.dialect.CreateTablesSQL() {
+ if _, err := s.db.Exec(stmt); err != nil {
+ return err
+ }
+ }
+ for _, m := range s.dialect.MigrationsSQL() {
+ s.db.Exec(m)
+ }
+ return nil
+}
+
+func (s *SQLStore) GetSites() ([]models.Site, error) {
+ bf := s.dialect.BoolFalse()
+ query := fmt.Sprintf(
+ "SELECT id, COALESCE(name, url), url, COALESCE(type, 'http'), COALESCE(token, ''), interval, alert_id, check_ssl, threshold, max_retries, COALESCE(hostname, ''), COALESCE(port, 0), COALESCE(timeout, 0), COALESCE(method, 'GET'), COALESCE(description, ''), COALESCE(parent_id, 0), COALESCE(accepted_codes, '200-299'), COALESCE(dns_resolve_type, ''), COALESCE(dns_server, ''), COALESCE(ignore_tls, %s), COALESCE(paused, %s) FROM sites",
+ bf, bf,
+ )
+ rows, err := s.db.Query(query)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ var sites []models.Site
+ for rows.Next() {
+ var st models.Site
+ if err := rows.Scan(&st.ID, &st.Name, &st.URL, &st.Type, &st.Token, &st.Interval, &st.AlertID,
+ &st.CheckSSL, &st.ExpiryThreshold, &st.MaxRetries, &st.Hostname, &st.Port, &st.Timeout,
+ &st.Method, &st.Description, &st.ParentID, &st.AcceptedCodes, &st.DNSResolveType,
+ &st.DNSServer, &st.IgnoreTLS, &st.Paused); err != nil {
+ return sites, err
+ }
+ sites = append(sites, st)
+ }
+ return sites, rows.Err()
+}
+
+func (s *SQLStore) AddSite(site models.Site) error {
+ token := ""
+ if site.Type == "push" {
+ token = generateToken()
+ }
+ _, err := s.db.Exec(s.q("INSERT INTO sites (name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"),
+ site.Name, site.URL, site.Type, token, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries,
+ site.Hostname, site.Port, site.Timeout, site.Method, site.Description, site.ParentID, site.AcceptedCodes, site.DNSResolveType, site.DNSServer, site.IgnoreTLS, site.Paused)
+ return err
+}
+
+func (s *SQLStore) UpdateSite(site models.Site) error {
+ var existingToken string
+ s.db.QueryRow(s.q("SELECT token FROM sites WHERE id=?"), site.ID).Scan(&existingToken)
+ if site.Type == "push" && existingToken == "" {
+ existingToken = generateToken()
+ }
+ _, err := s.db.Exec(s.q("UPDATE sites SET name=?, url=?, type=?, token=?, interval=?, alert_id=?, check_ssl=?, threshold=?, max_retries=?, hostname=?, port=?, timeout=?, method=?, description=?, parent_id=?, accepted_codes=?, dns_resolve_type=?, dns_server=?, ignore_tls=?, paused=? WHERE id=?"),
+ site.Name, site.URL, site.Type, existingToken, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries,
+ site.Hostname, site.Port, site.Timeout, site.Method, site.Description, site.ParentID, site.AcceptedCodes, site.DNSResolveType, site.DNSServer, site.IgnoreTLS, site.Paused, site.ID)
+ return err
+}
+
+func (s *SQLStore) UpdateSitePaused(id int, paused bool) error {
+ _, err := s.db.Exec(s.q("UPDATE sites SET paused=? WHERE id=?"), paused, id)
+ return err
+}
+
+func (s *SQLStore) DeleteSite(id int) error {
+ _, err := s.db.Exec(s.q("DELETE FROM sites WHERE id=?"), id)
+ if err != nil {
+ return err
+ }
+ s.dialect.ResetSequenceOnEmpty(s.db, "sites")
+ return nil
+}
+
+func (s *SQLStore) GetAllAlerts() ([]models.AlertConfig, error) {
+ rows, err := s.db.Query("SELECT id, name, type, settings FROM alerts")
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ var alerts []models.AlertConfig
+ for rows.Next() {
+ var a models.AlertConfig
+ var settingsJSON string
+ if err := rows.Scan(&a.ID, &a.Name, &a.Type, &settingsJSON); err != nil {
+ return alerts, err
+ }
+ json.Unmarshal([]byte(settingsJSON), &a.Settings)
+ alerts = append(alerts, a)
+ }
+ return alerts, rows.Err()
+}
+
+func (s *SQLStore) GetAlert(id int) (models.AlertConfig, error) {
+ var a models.AlertConfig
+ var settingsJSON string
+ err := s.db.QueryRow(s.q("SELECT id, name, type, settings FROM alerts WHERE id = ?"), id).Scan(&a.ID, &a.Name, &a.Type, &settingsJSON)
+ if err != nil {
+ return a, err
+ }
+ json.Unmarshal([]byte(settingsJSON), &a.Settings)
+ return a, nil
+}
+
+func (s *SQLStore) AddAlert(name, aType string, settings map[string]string) error {
+ jsonBytes, err := json.Marshal(settings)
+ if err != nil {
+ return err
+ }
+ _, err = s.db.Exec(s.q("INSERT INTO alerts (name, type, settings) VALUES (?, ?, ?)"), name, aType, string(jsonBytes))
+ return err
+}
+
+func (s *SQLStore) UpdateAlert(id int, name, aType string, settings map[string]string) error {
+ jsonBytes, err := json.Marshal(settings)
+ if err != nil {
+ return err
+ }
+ _, err = s.db.Exec(s.q("UPDATE alerts SET name=?, type=?, settings=? WHERE id=?"), name, aType, string(jsonBytes), id)
+ return err
+}
+
+func (s *SQLStore) DeleteAlert(id int) error {
+ _, err := s.db.Exec(s.q("DELETE FROM alerts WHERE id=?"), id)
+ if err != nil {
+ return err
+ }
+ s.dialect.ResetSequenceOnEmpty(s.db, "alerts")
+ return nil
+}
+
+func (s *SQLStore) GetAllUsers() ([]models.User, error) {
+ rows, err := s.db.Query("SELECT id, username, public_key, role FROM users")
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ var users []models.User
+ for rows.Next() {
+ var u models.User
+ if err := rows.Scan(&u.ID, &u.Username, &u.PublicKey, &u.Role); err != nil {
+ return users, err
+ }
+ users = append(users, u)
+ }
+ return users, rows.Err()
+}
+
+func (s *SQLStore) AddUser(username, publicKey, role string) error {
+ _, err := s.db.Exec(s.q("INSERT INTO users (username, public_key, role) VALUES (?, ?, ?)"), username, publicKey, role)
+ return err
+}
+
+func (s *SQLStore) UpdateUser(id int, username, publicKey, role string) error {
+ _, err := s.db.Exec(s.q("UPDATE users SET username=?, public_key=?, role=? WHERE id=?"), username, publicKey, role, id)
+ return err
+}
+
+func (s *SQLStore) DeleteUser(id int) error {
+ _, err := s.db.Exec(s.q("DELETE FROM users WHERE id=?"), id)
+ return err
+}
+
+func (s *SQLStore) SaveCheck(siteID int, latencyNs int64, isUp bool) error {
+ _, err := s.db.Exec(s.q("INSERT INTO check_history (site_id, latency_ns, is_up) VALUES (?, ?, ?)"), siteID, latencyNs, isUp)
+ if err != nil {
+ return err
+ }
+ _, err = s.db.Exec(s.q(`DELETE FROM check_history WHERE site_id = ? AND id NOT IN (
+ SELECT id FROM check_history WHERE site_id = ? ORDER BY checked_at DESC LIMIT 1000
+ )`), siteID, siteID)
+ return err
+}
+
+func (s *SQLStore) LoadAllHistory(limit int) (map[int][]models.CheckRecord, error) {
+ result := make(map[int][]models.CheckRecord)
+ rows, err := s.db.Query(s.q(`
+ SELECT site_id, latency_ns, is_up FROM (
+ SELECT site_id, latency_ns, is_up,
+ ROW_NUMBER() OVER (PARTITION BY site_id ORDER BY checked_at DESC) AS rn
+ FROM check_history
+ ) sub WHERE rn <= ?`), limit)
+ if err != nil {
+ return result, err
+ }
+ defer rows.Close()
+ for rows.Next() {
+ var r models.CheckRecord
+ if err := rows.Scan(&r.SiteID, &r.LatencyNs, &r.IsUp); err != nil {
+ return result, err
+ }
+ result[r.SiteID] = append(result[r.SiteID], r)
+ }
+ for id, records := range result {
+ for i, j := 0, len(records)-1; i < j; i, j = i+1, j-1 {
+ records[i], records[j] = records[j], records[i]
+ }
+ result[id] = records
+ }
+ return result, rows.Err()
+}
+
+func (s *SQLStore) ExportData() (models.Backup, error) {
+ sites, err := s.GetSites()
+ if err != nil {
+ return models.Backup{}, err
+ }
+ alerts, err := s.GetAllAlerts()
+ if err != nil {
+ return models.Backup{}, err
+ }
+ users, err := s.GetAllUsers()
+ if err != nil {
+ return models.Backup{}, err
+ }
+ return models.Backup{Sites: sites, Alerts: alerts, Users: users}, nil
+}
+
+func (s *SQLStore) ImportData(data models.Backup) error {
+ tx, err := s.db.Begin()
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ s.dialect.ImportWipe(tx)
+
+ for _, u := range data.Users {
+ if _, err := tx.Exec(s.q("INSERT INTO users (username, public_key, role) VALUES (?, ?, ?)"), u.Username, u.PublicKey, u.Role); err != nil {
+ return err
+ }
+ }
+ for _, a := range data.Alerts {
+ jsonBytes, err := json.Marshal(a.Settings)
+ if err != nil {
+ return err
+ }
+ if _, err := tx.Exec(s.q("INSERT INTO alerts (id, name, type, settings) VALUES (?, ?, ?, ?)"), a.ID, a.Name, a.Type, string(jsonBytes)); err != nil {
+ return err
+ }
+ }
+ for _, st := range data.Sites {
+ if _, err := tx.Exec(s.q("INSERT INTO sites (id, name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"),
+ st.ID, st.Name, st.URL, st.Type, st.Token, st.Interval, st.AlertID, st.CheckSSL, st.ExpiryThreshold, st.MaxRetries,
+ st.Hostname, st.Port, st.Timeout, st.Method, st.Description, st.ParentID, st.AcceptedCodes, st.DNSResolveType, st.DNSServer, st.IgnoreTLS, st.Paused); err != nil {
+ return err
+ }
+ }
+
+ s.dialect.ImportResetSequences(tx)
+
+ return tx.Commit()
+}
diff --git a/internal/store/sqlstore_test.go b/internal/store/sqlstore_test.go
new file mode 100644
index 0000000..7ae2bd9
--- /dev/null
+++ b/internal/store/sqlstore_test.go
@@ -0,0 +1,231 @@
+package store
+
+import (
+ "go-upkeep/internal/models"
+ "testing"
+)
+
+func newTestStore(t *testing.T) *SQLStore {
+ t.Helper()
+ s, err := NewSQLiteStore(":memory:")
+ if err != nil {
+ t.Fatalf("NewSQLiteStore: %v", err)
+ }
+ if err := s.Init(); err != nil {
+ t.Fatalf("Init: %v", err)
+ }
+ return s
+}
+
+func TestSiteCRUD(t *testing.T) {
+ s := newTestStore(t)
+
+ sites, err := s.GetSites()
+ if err != nil {
+ t.Fatalf("GetSites: %v", err)
+ }
+ if len(sites) != 0 {
+ t.Fatalf("expected 0 sites, got %d", len(sites))
+ }
+
+ if err := s.AddSite(models.Site{Name: "Test", URL: "https://example.com", Type: "http", Interval: 30}); err != nil {
+ t.Fatalf("AddSite: %v", err)
+ }
+
+ sites, err = s.GetSites()
+ if err != nil {
+ t.Fatalf("GetSites: %v", err)
+ }
+ if len(sites) != 1 {
+ t.Fatalf("expected 1 site, got %d", len(sites))
+ }
+ if sites[0].Name != "Test" {
+ t.Errorf("expected name 'Test', got '%s'", sites[0].Name)
+ }
+
+ sites[0].Name = "Updated"
+ if err := s.UpdateSite(sites[0]); err != nil {
+ t.Fatalf("UpdateSite: %v", err)
+ }
+
+ sites, _ = s.GetSites()
+ if sites[0].Name != "Updated" {
+ t.Errorf("expected name 'Updated', got '%s'", sites[0].Name)
+ }
+
+ if err := s.DeleteSite(sites[0].ID); err != nil {
+ t.Fatalf("DeleteSite: %v", err)
+ }
+
+ sites, _ = s.GetSites()
+ if len(sites) != 0 {
+ t.Fatalf("expected 0 sites after delete, got %d", len(sites))
+ }
+}
+
+func TestAlertCRUD(t *testing.T) {
+ s := newTestStore(t)
+
+ if err := s.AddAlert("Discord", "discord", map[string]string{"url": "https://example.com/hook"}); err != nil {
+ t.Fatalf("AddAlert: %v", err)
+ }
+
+ alerts, err := s.GetAllAlerts()
+ if err != nil {
+ t.Fatalf("GetAllAlerts: %v", err)
+ }
+ if len(alerts) != 1 {
+ t.Fatalf("expected 1 alert, got %d", len(alerts))
+ }
+ if alerts[0].Type != "discord" {
+ t.Errorf("expected type 'discord', got '%s'", alerts[0].Type)
+ }
+ if alerts[0].Settings["url"] != "https://example.com/hook" {
+ t.Errorf("settings url mismatch")
+ }
+
+ a, err := s.GetAlert(alerts[0].ID)
+ if err != nil {
+ t.Fatalf("GetAlert: %v", err)
+ }
+ if a.Name != "Discord" {
+ t.Errorf("expected name 'Discord', got '%s'", a.Name)
+ }
+
+ if err := s.UpdateAlert(a.ID, "Slack", "slack", map[string]string{"url": "https://slack.com/hook"}); err != nil {
+ t.Fatalf("UpdateAlert: %v", err)
+ }
+
+ a, _ = s.GetAlert(a.ID)
+ if a.Type != "slack" {
+ t.Errorf("expected type 'slack', got '%s'", a.Type)
+ }
+
+ if err := s.DeleteAlert(a.ID); err != nil {
+ t.Fatalf("DeleteAlert: %v", err)
+ }
+
+ alerts, _ = s.GetAllAlerts()
+ if len(alerts) != 0 {
+ t.Fatalf("expected 0 alerts after delete, got %d", len(alerts))
+ }
+}
+
+func TestUserCRUD(t *testing.T) {
+ s := newTestStore(t)
+
+ if err := s.AddUser("admin", "ssh-ed25519 AAAA...", "admin"); err != nil {
+ t.Fatalf("AddUser: %v", err)
+ }
+
+ users, err := s.GetAllUsers()
+ if err != nil {
+ t.Fatalf("GetAllUsers: %v", err)
+ }
+ if len(users) != 1 {
+ t.Fatalf("expected 1 user, got %d", len(users))
+ }
+ if users[0].Username != "admin" {
+ t.Errorf("expected username 'admin', got '%s'", users[0].Username)
+ }
+
+ if err := s.UpdateUser(users[0].ID, "root", "ssh-ed25519 BBBB...", "admin"); err != nil {
+ t.Fatalf("UpdateUser: %v", err)
+ }
+
+ users, _ = s.GetAllUsers()
+ if users[0].Username != "root" {
+ t.Errorf("expected username 'root', got '%s'", users[0].Username)
+ }
+
+ if err := s.DeleteUser(users[0].ID); err != nil {
+ t.Fatalf("DeleteUser: %v", err)
+ }
+
+ users, _ = s.GetAllUsers()
+ if len(users) != 0 {
+ t.Fatalf("expected 0 users after delete, got %d", len(users))
+ }
+}
+
+func TestPushTokenGeneration(t *testing.T) {
+ s := newTestStore(t)
+
+ if err := s.AddSite(models.Site{Name: "Push Monitor", Type: "push", Interval: 60}); err != nil {
+ t.Fatalf("AddSite: %v", err)
+ }
+
+ sites, _ := s.GetSites()
+ if len(sites) != 1 {
+ t.Fatalf("expected 1 site, got %d", len(sites))
+ }
+ if sites[0].Token == "" {
+ t.Error("expected non-empty token for push monitor")
+ }
+ if len(sites[0].Token) != 32 {
+ t.Errorf("expected 32-char hex token, got %d chars", len(sites[0].Token))
+ }
+}
+
+func TestImportExport(t *testing.T) {
+ s := newTestStore(t)
+
+ s.AddAlert("Test Alert", "webhook", map[string]string{"url": "https://example.com"})
+ s.AddSite(models.Site{Name: "Site1", URL: "https://example.com", Type: "http", Interval: 30})
+ s.AddUser("user1", "ssh-ed25519 KEY", "user")
+
+ backup, err := s.ExportData()
+ if err != nil {
+ t.Fatalf("ExportData: %v", err)
+ }
+ if len(backup.Sites) != 1 || len(backup.Alerts) != 1 || len(backup.Users) != 1 {
+ t.Fatalf("export mismatch: %d sites, %d alerts, %d users", len(backup.Sites), len(backup.Alerts), len(backup.Users))
+ }
+
+ s2 := newTestStore(t)
+ if err := s2.ImportData(backup); err != nil {
+ t.Fatalf("ImportData: %v", err)
+ }
+
+ sites, _ := s2.GetSites()
+ alerts, _ := s2.GetAllAlerts()
+ users, _ := s2.GetAllUsers()
+ if len(sites) != 1 || len(alerts) != 1 || len(users) != 1 {
+ t.Fatalf("import mismatch: %d sites, %d alerts, %d users", len(sites), len(alerts), len(users))
+ }
+}
+
+func TestCheckHistory(t *testing.T) {
+ s := newTestStore(t)
+
+ if err := s.SaveCheck(1, 5000000, true); err != nil {
+ t.Fatalf("SaveCheck: %v", err)
+ }
+ if err := s.SaveCheck(1, 10000000, false); err != nil {
+ t.Fatalf("SaveCheck: %v", err)
+ }
+ if err := s.SaveCheck(2, 3000000, true); err != nil {
+ t.Fatalf("SaveCheck site 2: %v", err)
+ }
+
+ history, err := s.LoadAllHistory(10)
+ if err != nil {
+ t.Fatalf("LoadAllHistory: %v", err)
+ }
+ if len(history[1]) != 2 {
+ t.Fatalf("expected 2 records for site 1, got %d", len(history[1]))
+ }
+ if len(history[2]) != 1 {
+ t.Fatalf("expected 1 record for site 2, got %d", len(history[2]))
+ }
+
+ upCount := 0
+ for _, r := range history[1] {
+ if r.IsUp {
+ upCount++
+ }
+ }
+ if upCount != 1 {
+ t.Errorf("expected 1 up record for site 1, got %d", upCount)
+ }
+}
diff --git a/internal/store/store.go b/internal/store/store.go
index af3cd71..35afa0b 100644
--- a/internal/store/store.go
+++ b/internal/store/store.go
@@ -8,40 +8,30 @@ type Store interface {
Init() error
// Sites
- GetSites() []models.Site
- AddSite(site models.Site)
- UpdateSite(site models.Site)
- UpdateSitePaused(id int, paused bool)
- DeleteSite(id int)
+ GetSites() ([]models.Site, error)
+ AddSite(site models.Site) error
+ UpdateSite(site models.Site) error
+ UpdateSitePaused(id int, paused bool) error
+ DeleteSite(id int) error
// Alerts
- GetAllAlerts() []models.AlertConfig
- GetAlert(id int) (models.AlertConfig, bool)
- AddAlert(name, aType string, settings map[string]string)
- UpdateAlert(id int, name, aType string, settings map[string]string)
- DeleteAlert(id int)
+ GetAllAlerts() ([]models.AlertConfig, error)
+ GetAlert(id int) (models.AlertConfig, error)
+ AddAlert(name, aType string, settings map[string]string) error
+ UpdateAlert(id int, name, aType string, settings map[string]string) error
+ DeleteAlert(id int) error
// Users
- GetAllUsers() []models.User
+ GetAllUsers() ([]models.User, error)
AddUser(username, publicKey, role string) error
UpdateUser(id int, username, publicKey, role string) error
DeleteUser(id int) error
// History
- SaveCheck(siteID int, latencyNs int64, isUp bool)
- LoadAllHistory(limit int) map[int][]models.CheckRecord
+ SaveCheck(siteID int, latencyNs int64, isUp bool) error
+ LoadAllHistory(limit int) (map[int][]models.CheckRecord, error)
// Backup & Restore
- ExportData() models.Backup
+ ExportData() (models.Backup, error)
ImportData(data models.Backup) error
}
-
-var Current Store
-
-func SetGlobal(s Store) {
- Current = s
-}
-
-func Get() Store {
- return Current
-}
diff --git a/internal/tui/tab_alerts.go b/internal/tui/tab_alerts.go
index 8a0447b..342e1bd 100644
--- a/internal/tui/tab_alerts.go
+++ b/internal/tui/tab_alerts.go
@@ -2,30 +2,10 @@ package tui
import (
"fmt"
- "go-upkeep/internal/store"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/huh"
"github.com/charmbracelet/lipgloss"
- "github.com/charmbracelet/lipgloss/table"
-)
-
-var (
- alertHeaderStyle = lipgloss.NewStyle().
- Foreground(lipgloss.Color("#7D56F4")).
- Bold(true).
- Padding(0, 1)
-
- alertCellStyle = lipgloss.NewStyle().Padding(0, 1)
-
- alertSelectedStyle = lipgloss.NewStyle().
- Padding(0, 1).
- Bold(true).
- Foreground(lipgloss.Color("#ffffff")).
- Background(lipgloss.Color("#3b3b5c"))
-
- alertBorderStyle = lipgloss.NewStyle().
- Foreground(lipgloss.Color("#444"))
)
type alertFormData struct {
@@ -43,6 +23,19 @@ type alertFormData struct {
NtfyUser string
NtfyPass string
NtfyPri string
+ // Telegram
+ TelegramToken string
+ TelegramChatID string
+ // PagerDuty
+ PagerDutyKey string
+ PagerDutySeverity string
+ // Pushover
+ PushoverToken string
+ PushoverUser string
+ // Gotify
+ GotifyURL string
+ GotifyToken string
+ GotifyPriority string
}
func fmtAlertType(t string) string {
@@ -57,6 +50,14 @@ func fmtAlertType(t string) string {
return lipgloss.NewStyle().Foreground(lipgloss.Color("#73F59F")).Render(t)
case "ntfy":
return lipgloss.NewStyle().Foreground(lipgloss.Color("#FF6B6B")).Render(t)
+ case "telegram":
+ return lipgloss.NewStyle().Foreground(lipgloss.Color("#26A5E4")).Render(t)
+ case "pagerduty":
+ return lipgloss.NewStyle().Foreground(lipgloss.Color("#06AC38")).Render(t)
+ case "pushover":
+ return lipgloss.NewStyle().Foreground(lipgloss.Color("#249DF1")).Render(t)
+ case "gotify":
+ return lipgloss.NewStyle().Foreground(lipgloss.Color("#3F8BBA")).Render(t)
default:
return t
}
@@ -84,6 +85,26 @@ func fmtAlertConfig(alert struct {
return limitStr(fmt.Sprintf("%s/%s", url, topic), 34)
}
return subtleStyle.Render("โ")
+ case "telegram":
+ if id := alert.Settings["chat_id"]; id != "" {
+ return limitStr(fmt.Sprintf("chat:%s", id), 34)
+ }
+ return subtleStyle.Render("โ")
+ case "pagerduty":
+ if key := alert.Settings["routing_key"]; key != "" {
+ return limitStr(key, 34)
+ }
+ return subtleStyle.Render("โ")
+ case "pushover":
+ if user := alert.Settings["user"]; user != "" {
+ return limitStr(fmt.Sprintf("user:%s", user), 34)
+ }
+ return subtleStyle.Render("โ")
+ case "gotify":
+ if url := alert.Settings["url"]; url != "" {
+ return limitStr(url, 34)
+ }
+ return subtleStyle.Render("โ")
default:
if val, ok := alert.Settings["url"]; ok {
return limitStr(val, 34)
@@ -97,55 +118,35 @@ func (m Model) viewAlertsTab() string {
return "\n No alert channels configured. Press [n] to add one."
}
- end := m.tableOffset + m.maxTableRows
- if end > len(m.alerts) {
- end = len(m.alerts)
- }
-
- selectedVisual := m.cursor - m.tableOffset
-
- var rows [][]string
- for i := m.tableOffset; i < end; i++ {
- alert := m.alerts[i]
- rows = append(rows, []string{
- fmt.Sprintf("%d", i+1),
- m.zones.Mark(fmt.Sprintf("alert-%d", i), limitStr(alert.Name, 15)),
- fmtAlertType(alert.Type),
- fmtAlertConfig(struct {
- Type string
- Settings map[string]string
- }{alert.Type, alert.Settings}),
- })
- }
-
- tableWidth := m.termWidth - 6
- if tableWidth < 40 {
- tableWidth = 40
- }
-
- t := table.New().
- Border(lipgloss.RoundedBorder()).
- BorderStyle(alertBorderStyle).
- Width(tableWidth).
- Headers("#", "NAME", "TYPE", "CONFIG").
- Rows(rows...).
- StyleFunc(func(row, col int) lipgloss.Style {
- if row == table.HeaderRow {
- return alertHeaderStyle
+ return m.renderTable(
+ []string{"#", "NAME", "TYPE", "CONFIG"},
+ len(m.alerts),
+ func(start, end int) [][]string {
+ var rows [][]string
+ for i := start; i < end; i++ {
+ a := m.alerts[i]
+ rows = append(rows, []string{
+ fmt.Sprintf("%d", i+1),
+ m.zones.Mark(fmt.Sprintf("alert-%d", i), limitStr(a.Name, 15)),
+ fmtAlertType(a.Type),
+ fmtAlertConfig(struct {
+ Type string
+ Settings map[string]string
+ }{a.Type, a.Settings}),
+ })
}
- if row == selectedVisual {
- return alertSelectedStyle
- }
- return alertCellStyle
- })
-
- return "\n" + t.Render()
+ return rows
+ },
+ nil, nil,
+ )
}
func (m *Model) initAlertHuhForm() tea.Cmd {
m.alertFormData = &alertFormData{
- AlertType: "discord",
- NtfyPri: "3",
+ AlertType: "discord",
+ NtfyPri: "3",
+ PagerDutySeverity: "critical",
+ GotifyPriority: "5",
}
if m.editID > 0 {
@@ -170,6 +171,19 @@ func (m *Model) initAlertHuhForm() tea.Cmd {
m.alertFormData.NtfyUser = alert.Settings["username"]
m.alertFormData.NtfyPass = alert.Settings["password"]
m.alertFormData.NtfyPri = alert.Settings["priority"]
+ case "telegram":
+ m.alertFormData.TelegramToken = alert.Settings["token"]
+ m.alertFormData.TelegramChatID = alert.Settings["chat_id"]
+ case "pagerduty":
+ m.alertFormData.PagerDutyKey = alert.Settings["routing_key"]
+ m.alertFormData.PagerDutySeverity = alert.Settings["severity"]
+ case "pushover":
+ m.alertFormData.PushoverToken = alert.Settings["token"]
+ m.alertFormData.PushoverUser = alert.Settings["user"]
+ case "gotify":
+ m.alertFormData.GotifyURL = alert.Settings["url"]
+ m.alertFormData.GotifyToken = alert.Settings["token"]
+ m.alertFormData.GotifyPriority = alert.Settings["priority"]
}
break
}
@@ -194,6 +208,10 @@ func (m *Model) initAlertHuhForm() tea.Cmd {
huh.NewOption("Webhook", "webhook"),
huh.NewOption("Email (SMTP)", "email"),
huh.NewOption("Ntfy", "ntfy"),
+ huh.NewOption("Telegram", "telegram"),
+ huh.NewOption("PagerDuty", "pagerduty"),
+ huh.NewOption("Pushover", "pushover"),
+ huh.NewOption("Gotify", "gotify"),
).Value(&m.alertFormData.AlertType),
).Title("Alert Config"),
huh.NewGroup(
@@ -201,7 +219,8 @@ func (m *Model) initAlertHuhForm() tea.Cmd {
Placeholder("https://discord.com/api/webhooks/...").
Value(&m.alertFormData.WebhookURL),
).Title("Webhook").WithHideFunc(func() bool {
- return m.alertFormData.AlertType == "email" || m.alertFormData.AlertType == "ntfy"
+ t := m.alertFormData.AlertType
+ return t != "discord" && t != "slack" && t != "webhook"
}),
huh.NewGroup(
huh.NewInput().Title("Ntfy Server URL").
@@ -249,6 +268,57 @@ func (m *Model) initAlertHuhForm() tea.Cmd {
).Title("Email Settings").WithHideFunc(func() bool {
return m.alertFormData.AlertType != "email"
}),
+ huh.NewGroup(
+ huh.NewInput().Title("Bot Token").
+ Placeholder("123456:ABC-DEF1234...").
+ Value(&m.alertFormData.TelegramToken),
+ huh.NewInput().Title("Chat ID").
+ Placeholder("-1001234567890").
+ Value(&m.alertFormData.TelegramChatID),
+ ).Title("Telegram Settings").WithHideFunc(func() bool {
+ return m.alertFormData.AlertType != "telegram"
+ }),
+ huh.NewGroup(
+ huh.NewInput().Title("Routing Key").
+ Placeholder("your-integration-routing-key").
+ Value(&m.alertFormData.PagerDutyKey),
+ huh.NewSelect[string]().Title("Severity").
+ Options(
+ huh.NewOption("Critical", "critical"),
+ huh.NewOption("Error", "error"),
+ huh.NewOption("Warning", "warning"),
+ huh.NewOption("Info", "info"),
+ ).Value(&m.alertFormData.PagerDutySeverity),
+ ).Title("PagerDuty Settings").WithHideFunc(func() bool {
+ return m.alertFormData.AlertType != "pagerduty"
+ }),
+ huh.NewGroup(
+ huh.NewInput().Title("App Token").
+ Placeholder("your-pushover-app-token").
+ Value(&m.alertFormData.PushoverToken),
+ huh.NewInput().Title("User Key").
+ Placeholder("your-pushover-user-key").
+ Value(&m.alertFormData.PushoverUser),
+ ).Title("Pushover Settings").WithHideFunc(func() bool {
+ return m.alertFormData.AlertType != "pushover"
+ }),
+ huh.NewGroup(
+ huh.NewInput().Title("Server URL").
+ Placeholder("https://gotify.example.com").
+ Value(&m.alertFormData.GotifyURL),
+ huh.NewInput().Title("App Token").
+ Placeholder("your-gotify-app-token").
+ Value(&m.alertFormData.GotifyToken),
+ huh.NewSelect[string]().Title("Priority").
+ Options(
+ huh.NewOption("Min (0)", "0"),
+ huh.NewOption("Low (2)", "2"),
+ huh.NewOption("Normal (5)", "5"),
+ huh.NewOption("High (8)", "8"),
+ ).Value(&m.alertFormData.GotifyPriority),
+ ).Title("Gotify Settings").WithHideFunc(func() bool {
+ return m.alertFormData.AlertType != "gotify"
+ }),
).WithTheme(huh.ThemeDracula())
return m.huhForm.Init()
@@ -272,14 +342,31 @@ func (m *Model) submitAlertForm() {
settings["priority"] = d.NtfyPri
settings["username"] = d.NtfyUser
settings["password"] = d.NtfyPass
+ case "telegram":
+ settings["token"] = d.TelegramToken
+ settings["chat_id"] = d.TelegramChatID
+ case "pagerduty":
+ settings["routing_key"] = d.PagerDutyKey
+ settings["severity"] = d.PagerDutySeverity
+ case "pushover":
+ settings["token"] = d.PushoverToken
+ settings["user"] = d.PushoverUser
+ case "gotify":
+ settings["url"] = d.GotifyURL
+ settings["token"] = d.GotifyToken
+ settings["priority"] = d.GotifyPriority
default:
settings["url"] = d.WebhookURL
}
if m.editID > 0 {
- store.Get().UpdateAlert(m.editID, d.Name, d.AlertType, settings)
+ if err := m.store.UpdateAlert(m.editID, d.Name, d.AlertType, settings); err != nil {
+ m.engine.AddLog("Update alert failed: " + err.Error())
+ }
} else {
- store.Get().AddAlert(d.Name, d.AlertType, settings)
+ if err := m.store.AddAlert(d.Name, d.AlertType, settings); err != nil {
+ m.engine.AddLog("Add alert failed: " + err.Error())
+ }
}
m.state = stateDashboard
}
diff --git a/internal/tui/tab_sites.go b/internal/tui/tab_sites.go
index 2656613..4b6ad00 100644
--- a/internal/tui/tab_sites.go
+++ b/internal/tui/tab_sites.go
@@ -3,8 +3,6 @@ package tui
import (
"fmt"
"go-upkeep/internal/models"
- "go-upkeep/internal/monitor"
- "go-upkeep/internal/store"
"net/url"
"strconv"
"strings"
@@ -13,33 +11,14 @@ import (
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/huh"
"github.com/charmbracelet/lipgloss"
- "github.com/charmbracelet/lipgloss/table"
)
var sparkChars = []rune{'โ', 'โ', 'โ', 'โ', 'โ
', 'โ', 'โ', 'โ'}
-var (
- siteHeaderStyle = lipgloss.NewStyle().
- Foreground(lipgloss.Color("#7D56F4")).
- Bold(true).
- Padding(0, 1)
-
- siteCellStyle = lipgloss.NewStyle().Padding(0, 1)
-
- siteSelectedStyle = lipgloss.NewStyle().
- Padding(0, 1).
- Bold(true).
- Foreground(lipgloss.Color("#ffffff")).
- Background(lipgloss.Color("#3b3b5c"))
-
- siteBorderStyle = lipgloss.NewStyle().
- Foreground(lipgloss.Color("#444"))
-
- siteGroupStyle = lipgloss.NewStyle().
- Padding(0, 1).
- Bold(true).
- Foreground(lipgloss.Color("#7D56F4"))
-)
+var siteGroupStyle = lipgloss.NewStyle().
+ Padding(0, 1).
+ Bold(true).
+ Foreground(lipgloss.Color("#7D56F4"))
type siteFormData struct {
Name string
@@ -220,111 +199,80 @@ func (m Model) viewSitesTab() string {
return "\n No sites configured. Press [n] to add one."
}
- end := m.tableOffset + m.maxTableRows
- if end > len(m.sites) {
- end = len(m.sites)
- }
-
- selectedVisual := m.cursor - m.tableOffset
-
- var rows [][]string
- var groupRows []int
- for i := m.tableOffset; i < end; i++ {
- site := m.sites[i]
-
- if site.Type == "group" {
- groupRows = append(groupRows, i-m.tableOffset)
- arrow := "โพ"
- if m.collapsed[site.ID] {
- arrow = "โธ"
- }
- rows = append(rows, []string{
- strconv.Itoa(i + 1),
- m.zones.Mark(fmt.Sprintf("site-%d", i), arrow+" "+limitStr(site.Name, 11)),
- "group",
- fmtStatus(site.Status, site.Paused),
- subtleStyle.Render("โ"),
- subtleStyle.Render("โ"),
- subtleStyle.Render(strings.Repeat("ยท", sparkWidth)),
- subtleStyle.Render("-"),
- subtleStyle.Render("โ"),
- })
- continue
- }
-
- name := site.Name
- if site.ParentID > 0 {
- prefix := "โ"
- if i+1 >= len(m.sites) || m.sites[i+1].ParentID != site.ParentID {
- prefix = "โ"
- }
- name = prefix + " " + limitStr(name, 11)
- } else {
- name = limitStr(name, 13)
- }
-
- hist, _ := monitor.GetHistory(site.ID)
- var spark string
- if site.Type == "push" {
- spark = heartbeatSparkline(hist.Statuses, sparkWidth)
- } else {
- spark = latencySparkline(hist.Latencies, sparkWidth)
- }
-
- rows = append(rows, []string{
- strconv.Itoa(i + 1),
- m.zones.Mark(fmt.Sprintf("site-%d", i), name),
- site.Type,
- fmtStatus(site.Status, site.Paused),
- fmtLatency(site.Latency),
- fmtUptime(hist.TotalChecks, hist.UpChecks),
- spark,
- fmtSSL(site),
- fmtRetries(site),
- })
- }
-
- isGroupRow := func(row int) bool {
- for _, g := range groupRows {
- if g == row {
- return true
- }
- }
- return false
- }
-
- tableWidth := m.termWidth - 6
- if tableWidth < 40 {
- tableWidth = 40
- }
-
- // column widths: #=6, name=flex, type=10, status=10, latency=8, uptime=8, history=sparkWidth+4, ssl=7, retry=9
colWidths := []int{6, 0, 10, 10, 8, 8, sparkWidth + 4, 7, 9}
- t := table.New().
- Border(lipgloss.RoundedBorder()).
- BorderStyle(siteBorderStyle).
- Width(tableWidth).
- Headers("#", "NAME", "TYPE", "STATUS", "LATENCY", "UPTIME", "HISTORY", "SSL", "RETRY").
- Rows(rows...).
- StyleFunc(func(row, col int) lipgloss.Style {
- var base lipgloss.Style
- if row == table.HeaderRow {
- base = siteHeaderStyle
- } else if row == selectedVisual {
- base = siteSelectedStyle
- } else if isGroupRow(row) {
- base = siteGroupStyle
- } else {
- base = siteCellStyle
- }
- if col < len(colWidths) && colWidths[col] > 0 {
- base = base.Width(colWidths[col])
- }
- return base
- })
+ var groupRows map[int]bool
+ return m.renderTable(
+ []string{"#", "NAME", "TYPE", "STATUS", "LATENCY", "UPTIME", "HISTORY", "SSL", "RETRY"},
+ len(m.sites),
+ func(start, end int) [][]string {
+ groupRows = make(map[int]bool)
+ var rows [][]string
+ for i := start; i < end; i++ {
+ site := m.sites[i]
- return "\n" + t.Render()
+ if site.Type == "group" {
+ groupRows[i-start] = true
+ arrow := "โพ"
+ if m.collapsed[site.ID] {
+ arrow = "โธ"
+ }
+ rows = append(rows, []string{
+ strconv.Itoa(i + 1),
+ m.zones.Mark(fmt.Sprintf("site-%d", i), arrow+" "+limitStr(site.Name, 11)),
+ "group",
+ fmtStatus(site.Status, site.Paused),
+ subtleStyle.Render("โ"),
+ subtleStyle.Render("โ"),
+ subtleStyle.Render(strings.Repeat("ยท", sparkWidth)),
+ subtleStyle.Render("-"),
+ subtleStyle.Render("โ"),
+ })
+ continue
+ }
+
+ name := site.Name
+ if site.ParentID > 0 {
+ prefix := "โ"
+ if i+1 >= len(m.sites) || m.sites[i+1].ParentID != site.ParentID {
+ prefix = "โ"
+ }
+ name = prefix + " " + limitStr(name, 11)
+ } else {
+ name = limitStr(name, 13)
+ }
+
+ hist, _ := m.engine.GetHistory(site.ID)
+ var spark string
+ if site.Type == "push" {
+ spark = heartbeatSparkline(hist.Statuses, sparkWidth)
+ } else {
+ spark = latencySparkline(hist.Latencies, sparkWidth)
+ }
+
+ rows = append(rows, []string{
+ strconv.Itoa(i + 1),
+ m.zones.Mark(fmt.Sprintf("site-%d", i), name),
+ site.Type,
+ fmtStatus(site.Status, site.Paused),
+ fmtLatency(site.Latency),
+ fmtUptime(hist.TotalChecks, hist.UpChecks),
+ spark,
+ fmtSSL(site),
+ fmtRetries(site),
+ })
+ }
+ return rows
+ },
+ colWidths,
+ func(row, col int) *lipgloss.Style {
+ if groupRows[row] {
+ s := siteGroupStyle
+ return &s
+ }
+ return nil
+ },
+ )
}
func (m *Model) initSiteHuhForm() tea.Cmd {
@@ -361,8 +309,8 @@ func (m *Model) initSiteHuhForm() tea.Cmd {
}
alertOpts := []huh.Option[string]{huh.NewOption("None", "0")}
- if store.Get() != nil {
- for _, a := range store.Get().GetAllAlerts() {
+ if alerts, err := m.store.GetAllAlerts(); err == nil {
+ for _, a := range alerts {
alertOpts = append(alertOpts, huh.NewOption(
fmt.Sprintf("%s (%s)", a.Name, a.Type),
strconv.Itoa(a.ID),
@@ -558,10 +506,14 @@ func (m *Model) submitSiteForm() {
}
if m.editID > 0 {
- store.Get().UpdateSite(site)
- monitor.UpdateSiteConfig(site)
+ if err := m.store.UpdateSite(site); err != nil {
+ m.engine.AddLog("Update site failed: " + err.Error())
+ }
+ m.engine.UpdateSiteConfig(site)
} else {
- store.Get().AddSite(site)
+ if err := m.store.AddSite(site); err != nil {
+ m.engine.AddLog("Add site failed: " + err.Error())
+ }
}
m.state = stateDashboard
}
diff --git a/internal/tui/tab_users.go b/internal/tui/tab_users.go
index 4858b7c..019bb03 100644
--- a/internal/tui/tab_users.go
+++ b/internal/tui/tab_users.go
@@ -2,30 +2,9 @@ package tui
import (
"fmt"
- "go-upkeep/internal/store"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/huh"
- "github.com/charmbracelet/lipgloss"
- "github.com/charmbracelet/lipgloss/table"
-)
-
-var (
- userHeaderStyle = lipgloss.NewStyle().
- Foreground(lipgloss.Color("#7D56F4")).
- Bold(true).
- Padding(0, 1)
-
- userCellStyle = lipgloss.NewStyle().Padding(0, 1)
-
- userSelectedStyle = lipgloss.NewStyle().
- Padding(0, 1).
- Bold(true).
- Foreground(lipgloss.Color("#ffffff")).
- Background(lipgloss.Color("#3b3b5c"))
-
- userBorderStyle = lipgloss.NewStyle().
- Foreground(lipgloss.Color("#444"))
)
type userFormData struct {
@@ -53,46 +32,24 @@ func (m Model) viewUsersTab() string {
return "\n No users configured. Press [n] to add one."
}
- end := m.tableOffset + m.maxTableRows
- if end > len(m.users) {
- end = len(m.users)
- }
-
- selectedVisual := m.cursor - m.tableOffset
-
- var rows [][]string
- for i := m.tableOffset; i < end; i++ {
- u := m.users[i]
- rows = append(rows, []string{
- fmt.Sprintf("%d", i+1),
- m.zones.Mark(fmt.Sprintf("user-%d", i), limitStr(u.Username, 15)),
- fmtRole(u.Role),
- fmtKey(u.PublicKey),
- })
- }
-
- tableWidth := m.termWidth - 6
- if tableWidth < 40 {
- tableWidth = 40
- }
-
- t := table.New().
- Border(lipgloss.RoundedBorder()).
- BorderStyle(userBorderStyle).
- Width(tableWidth).
- Headers("#", "USERNAME", "ROLE", "PUBLIC KEY").
- Rows(rows...).
- StyleFunc(func(row, col int) lipgloss.Style {
- if row == table.HeaderRow {
- return userHeaderStyle
+ return m.renderTable(
+ []string{"#", "USERNAME", "ROLE", "PUBLIC KEY"},
+ len(m.users),
+ func(start, end int) [][]string {
+ var rows [][]string
+ for i := start; i < end; i++ {
+ u := m.users[i]
+ rows = append(rows, []string{
+ fmt.Sprintf("%d", i+1),
+ m.zones.Mark(fmt.Sprintf("user-%d", i), limitStr(u.Username, 15)),
+ fmtRole(u.Role),
+ fmtKey(u.PublicKey),
+ })
}
- if row == selectedVisual {
- return userSelectedStyle
- }
- return userCellStyle
- })
-
- return "\n" + t.Render()
+ return rows
+ },
+ nil, nil,
+ )
}
func (m *Model) initUserHuhForm() tea.Cmd {
@@ -145,9 +102,13 @@ func (m *Model) initUserHuhForm() tea.Cmd {
func (m *Model) submitUserForm() {
d := m.userFormData
if m.editID > 0 {
- store.Get().UpdateUser(m.editID, d.Username, d.PublicKey, d.Role)
+ if err := m.store.UpdateUser(m.editID, d.Username, d.PublicKey, d.Role); err != nil {
+ m.engine.AddLog("Update user failed: " + err.Error())
+ }
} else {
- store.Get().AddUser(d.Username, d.PublicKey, d.Role)
+ if err := m.store.AddUser(d.Username, d.PublicKey, d.Role); err != nil {
+ m.engine.AddLog("Add user failed: " + err.Error())
+ }
}
m.state = stateUsers
}
diff --git a/internal/tui/table_helpers.go b/internal/tui/table_helpers.go
new file mode 100644
index 0000000..be8719e
--- /dev/null
+++ b/internal/tui/table_helpers.go
@@ -0,0 +1,75 @@
+package tui
+
+import (
+ "github.com/charmbracelet/lipgloss"
+ "github.com/charmbracelet/lipgloss/table"
+)
+
+var (
+ tableHeaderStyle = lipgloss.NewStyle().
+ Foreground(lipgloss.Color("#7D56F4")).
+ Bold(true).
+ Padding(0, 1)
+
+ tableCellStyle = lipgloss.NewStyle().Padding(0, 1)
+
+ tableSelectedStyle = lipgloss.NewStyle().
+ Padding(0, 1).
+ Bold(true).
+ Foreground(lipgloss.Color("#ffffff")).
+ Background(lipgloss.Color("#3b3b5c"))
+
+ tableBorderStyle = lipgloss.NewStyle().
+ Foreground(lipgloss.Color("#444"))
+)
+
+type StyleOverride func(row, col int) *lipgloss.Style
+
+func (m Model) renderTable(headers []string, items int, buildRows func(start, end int) [][]string, colWidths []int, styleOverride StyleOverride) string {
+ if items == 0 {
+ return ""
+ }
+
+ end := m.tableOffset + m.maxTableRows
+ if end > items {
+ end = items
+ }
+
+ selectedVisual := m.cursor - m.tableOffset
+ rows := buildRows(m.tableOffset, end)
+
+ tableWidth := m.termWidth - 6
+ if tableWidth < 40 {
+ tableWidth = 40
+ }
+
+ t := table.New().
+ Border(lipgloss.RoundedBorder()).
+ BorderStyle(tableBorderStyle).
+ Width(tableWidth).
+ Headers(headers...).
+ Rows(rows...).
+ StyleFunc(func(row, col int) lipgloss.Style {
+ if row == table.HeaderRow {
+ return tableHeaderStyle
+ }
+ if styleOverride != nil {
+ if s := styleOverride(row, col); s != nil {
+ if col < len(colWidths) && colWidths[col] > 0 {
+ return s.Width(colWidths[col])
+ }
+ return *s
+ }
+ }
+ base := tableCellStyle
+ if row == selectedVisual {
+ base = tableSelectedStyle
+ }
+ if col < len(colWidths) && colWidths[col] > 0 {
+ base = base.Width(colWidths[col])
+ }
+ return base
+ })
+
+ return "\n" + t.Render()
+}
diff --git a/internal/tui/tui.go b/internal/tui/tui.go
index 4972a3a..89846a5 100644
--- a/internal/tui/tui.go
+++ b/internal/tui/tui.go
@@ -68,6 +68,8 @@ type Model struct {
deleteTab int
collapsed map[int]bool
+ store store.Store
+ engine *monitor.Engine
// harmonica animation state
pulseSpring harmonica.Spring
@@ -80,7 +82,7 @@ type Model struct {
users []models.User
}
-func InitialModel(isAdmin bool) Model {
+func InitialModel(isAdmin bool, s store.Store, eng *monitor.Engine) Model {
vpLogs := viewport.New(100, 20)
vpLogs.SetContent("Waiting for logs...")
z := zone.New()
@@ -90,6 +92,8 @@ func InitialModel(isAdmin bool) Model {
logViewport: vpLogs,
maxTableRows: 5,
isAdmin: isAdmin,
+ store: s,
+ engine: eng,
zones: z,
pulseSpring: spring,
collapsed: make(map[int]bool),
@@ -107,19 +111,23 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if keyMsg, ok := msg.(tea.KeyMsg); ok {
switch keyMsg.String() {
case "y", "Y":
- if store.Get() != nil {
- switch m.deleteTab {
- case 0:
- store.Get().DeleteSite(m.deleteID)
- monitor.RemoveSite(m.deleteID)
- m.adjustCursor(len(m.sites) - 1)
- case 1:
- store.Get().DeleteAlert(m.deleteID)
- m.adjustCursor(len(m.alerts) - 1)
- case 3:
- store.Get().DeleteUser(m.deleteID)
- m.adjustCursor(len(m.users) - 1)
+ switch m.deleteTab {
+ case 0:
+ if err := m.store.DeleteSite(m.deleteID); err != nil {
+ m.engine.AddLog("Delete site failed: " + err.Error())
}
+ m.engine.RemoveSite(m.deleteID)
+ m.adjustCursor(len(m.sites) - 1)
+ case 1:
+ if err := m.store.DeleteAlert(m.deleteID); err != nil {
+ m.engine.AddLog("Delete alert failed: " + err.Error())
+ }
+ m.adjustCursor(len(m.alerts) - 1)
+ case 3:
+ if err := m.store.DeleteUser(m.deleteID); err != nil {
+ m.engine.AddLog("Delete user failed: " + err.Error())
+ }
+ m.adjustCursor(len(m.users) - 1)
}
m.refreshData()
m.state = stateDashboard
@@ -311,11 +319,9 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case "p":
if m.currentTab == 0 && len(m.sites) > 0 {
site := m.sites[m.cursor]
- monitor.ToggleSitePause(site.ID)
+ m.engine.ToggleSitePause(site.ID)
site.Paused = !site.Paused
- if store.Get() != nil {
- store.Get().UpdateSitePaused(site.ID, site.Paused)
- }
+ _ = m.store.UpdateSitePaused(site.ID, site.Paused)
m.refreshData()
}
case "d", "backspace":
@@ -342,11 +348,11 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
func (m *Model) handleClick(msg tea.MouseMsg) (tea.Model, tea.Cmd) {
- maxTabs := 3
- if !m.isAdmin {
- maxTabs = 2
+ tabCount := 3
+ if m.isAdmin {
+ tabCount = 4
}
- for i := 0; i <= maxTabs; i++ {
+ for i := 0; i < tabCount; i++ {
if m.zones.Get(fmt.Sprintf("tab-%d", i)).InBounds(msg) {
m.switchTab(i)
return m, nil
@@ -429,12 +435,7 @@ func (m *Model) adjustCursor(newLen int) {
}
func (m *Model) refreshData() {
- monitor.Mutex.RLock()
- var allSites []models.Site
- for _, s := range monitor.LiveState {
- allSites = append(allSites, s)
- }
- monitor.Mutex.RUnlock()
+ allSites := m.engine.GetAllSites()
var groups, ungrouped []models.Site
children := make(map[int][]models.Site)
@@ -464,19 +465,31 @@ func (m *Model) refreshData() {
}
ordered = append(ordered, ungrouped...)
m.sites = ordered
- if store.Get() != nil {
- m.alerts = store.Get().GetAllAlerts()
- if m.isAdmin {
- m.users = store.Get().GetAllUsers()
+ if alerts, err := m.store.GetAllAlerts(); err == nil {
+ m.alerts = alerts
+ }
+ if m.isAdmin {
+ if users, err := m.store.GetAllUsers(); err == nil {
+ m.users = users
}
}
- m.logViewport.SetContent(strings.Join(monitor.GetLogs(), "\n"))
+ m.logViewport.SetContent(strings.Join(m.engine.GetLogs(), "\n"))
+
+ listLen := len(m.sites)
+ if m.currentTab == 1 {
+ listLen = len(m.alerts)
+ } else if m.currentTab == 3 {
+ listLen = len(m.users)
+ }
+ if listLen > 0 && m.cursor >= listLen {
+ m.cursor = listLen - 1
+ }
+ if m.cursor < m.tableOffset {
+ m.tableOffset = m.cursor
+ }
}
func (m *Model) submitForm() {
- if store.Get() == nil {
- return
- }
switch m.state {
case stateFormSite:
if m.siteFormData != nil {