Merge pull request 'fix: critical bugs and security hardening' (#19) from fix/critical-bugs-security-hardening into main

This commit was merged in pull request #19.
This commit is contained in:
2026-05-24 01:45:11 +00:00
13 changed files with 2316 additions and 45 deletions
+23 -6
View File
@@ -2,6 +2,7 @@ package main
import ( import (
"context" "context"
"errors"
"flag" "flag"
"fmt" "fmt"
"go-upkeep/internal/cluster" "go-upkeep/internal/cluster"
@@ -17,6 +18,7 @@ import (
"os/signal" "os/signal"
"strconv" "strconv"
"syscall" "syscall"
"time"
tea "github.com/charmbracelet/bubbletea" tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/ssh" "github.com/charmbracelet/ssh"
@@ -225,6 +227,7 @@ func runServe(args []string) {
fmt.Printf("Database connection error: %v\n", dbErr) fmt.Printf("Database connection error: %v\n", dbErr)
os.Exit(1) os.Exit(1)
} }
defer s.Close()
if err := s.Init(); err != nil { if err := s.Init(); err != nil {
fmt.Printf("Database init error: %v\n", err) fmt.Printf("Database init error: %v\n", err)
@@ -263,7 +266,7 @@ func runServe(args []string) {
eng.InitLogs() eng.InitLogs()
eng.Start(ctx) eng.Start(ctx)
server.Start(server.ServerConfig{ httpSrv := server.Start(server.ServerConfig{
Port: httpPort, Port: httpPort,
EnableStatus: enableStatus, EnableStatus: enableStatus,
Title: statusTitle, Title: statusTitle,
@@ -276,7 +279,7 @@ func runServe(args []string) {
SharedKey: clusterKey, SharedKey: clusterKey,
}, eng) }, eng)
startSSHServer(*port, s, eng) sshSrv := startSSHServer(*port, s, eng)
if isatty.IsTerminal(os.Stdout.Fd()) || isatty.IsCygwinTerminal(os.Stdout.Fd()) { if isatty.IsTerminal(os.Stdout.Fd()) || isatty.IsCygwinTerminal(os.Stdout.Fd()) {
p := tea.NewProgram(tui.InitialModel(true, s, eng), tea.WithAltScreen(), tea.WithMouseCellMotion()) p := tea.NewProgram(tui.InitialModel(true, s, eng), tea.WithAltScreen(), tea.WithMouseCellMotion())
@@ -291,9 +294,22 @@ func runServe(args []string) {
fmt.Println("Shutting down...") fmt.Println("Shutting down...")
} }
cancel() cancel()
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second)
defer shutdownCancel()
if httpSrv != nil {
if err := httpSrv.Shutdown(shutdownCtx); err != nil {
log.Printf("HTTP shutdown error: %v", err)
}
}
if sshSrv != nil {
if err := sshSrv.Shutdown(shutdownCtx); err != nil {
log.Printf("SSH shutdown error: %v", err)
}
}
} }
func startSSHServer(port int, db store.Store, eng *monitor.Engine) { func startSSHServer(port int, db store.Store, eng *monitor.Engine) *ssh.Server {
s, err := wish.NewServer( s, err := wish.NewServer(
wish.WithAddress(fmt.Sprintf(":%d", port)), wish.WithAddress(fmt.Sprintf(":%d", port)),
wish.WithHostKeyPath(".ssh/id_ed25519"), wish.WithHostKeyPath(".ssh/id_ed25519"),
@@ -308,13 +324,14 @@ func startSSHServer(port int, db store.Store, eng *monitor.Engine) {
) )
if err != nil { if err != nil {
fmt.Printf("SSH server error: %v\n", err) fmt.Printf("SSH server error: %v\n", err)
return return nil
} }
go func() { go func() {
if err := s.ListenAndServe(); err != nil { if err := s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
log.Fatalf("SSH server failed: %v", err) log.Printf("SSH server error: %v", err)
} }
}() }()
return s
} }
func seedDemoData(s store.Store) { func seedDemoData(s store.Store) {
+17 -6
View File
@@ -2,6 +2,7 @@ package alert
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"go-upkeep/internal/models" "go-upkeep/internal/models"
@@ -15,7 +16,7 @@ import (
var alertClient = &http.Client{Timeout: 10 * time.Second} var alertClient = &http.Client{Timeout: 10 * time.Second}
type Provider interface { type Provider interface {
Send(title, message string) error Send(ctx context.Context, title, message string) error
} }
type PayloadFunc func(title, message string) ([]byte, error) type PayloadFunc func(title, message string) ([]byte, error)
@@ -25,12 +26,17 @@ type HTTPProvider struct {
Payload PayloadFunc Payload PayloadFunc
} }
func (h *HTTPProvider) Send(title, message string) error { func (h *HTTPProvider) Send(ctx context.Context, title, message string) error {
body, err := h.Payload(title, message) body, err := h.Payload(title, message)
if err != nil { if err != nil {
return err return err
} }
resp, err := alertClient.Post(h.URL, "application/json", bytes.NewBuffer(body)) req, err := http.NewRequestWithContext(ctx, "POST", h.URL, bytes.NewBuffer(body))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
resp, err := alertClient.Do(req)
if err != nil { if err != nil {
return err return err
} }
@@ -170,7 +176,12 @@ type EmailProvider struct {
Host, Port, User, Pass, To, From string Host, Port, User, Pass, To, From string
} }
func (e *EmailProvider) Send(title, message string) error { func (e *EmailProvider) Send(ctx context.Context, title, message string) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
auth := smtp.PlainAuth("", e.User, e.Pass, e.Host) auth := smtp.PlainAuth("", e.User, e.Pass, e.Host)
msg := []byte("To: " + e.To + "\r\n" + msg := []byte("To: " + e.To + "\r\n" +
"Subject: Go-Upkeep: " + title + "\r\n" + "Subject: Go-Upkeep: " + title + "\r\n" +
@@ -187,9 +198,9 @@ type NtfyProvider struct {
Password string Password string
} }
func (n *NtfyProvider) Send(title, message string) error { func (n *NtfyProvider) Send(ctx context.Context, title, message string) error {
url := strings.TrimRight(n.ServerURL, "/") + "/" + n.Topic url := strings.TrimRight(n.ServerURL, "/") + "/" + n.Topic
req, err := http.NewRequest("POST", url, strings.NewReader(message)) req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(message))
if err != nil { if err != nil {
return err return err
} }
+10 -9
View File
@@ -1,6 +1,7 @@
package alert package alert
import ( import (
"context"
"encoding/json" "encoding/json"
"go-upkeep/internal/models" "go-upkeep/internal/models"
"net/http" "net/http"
@@ -17,7 +18,7 @@ func TestHTTPProviderDiscord(t *testing.T) {
defer srv.Close() defer srv.Close()
p := GetProvider(models.AlertConfig{Type: "discord", Settings: map[string]string{"url": srv.URL}}) p := GetProvider(models.AlertConfig{Type: "discord", Settings: map[string]string{"url": srv.URL}})
if err := p.Send("Test Title", "Test Body"); err != nil { if err := p.Send(context.Background(), "Test Title", "Test Body"); err != nil {
t.Fatalf("Send: %v", err) t.Fatalf("Send: %v", err)
} }
@@ -35,7 +36,7 @@ func TestHTTPProviderSlack(t *testing.T) {
defer srv.Close() defer srv.Close()
p := GetProvider(models.AlertConfig{Type: "slack", Settings: map[string]string{"url": srv.URL}}) p := GetProvider(models.AlertConfig{Type: "slack", Settings: map[string]string{"url": srv.URL}})
if err := p.Send("Alert", "Message"); err != nil { if err := p.Send(context.Background(), "Alert", "Message"); err != nil {
t.Fatalf("Send: %v", err) t.Fatalf("Send: %v", err)
} }
@@ -53,7 +54,7 @@ func TestHTTPProviderWebhook(t *testing.T) {
defer srv.Close() defer srv.Close()
p := GetProvider(models.AlertConfig{Type: "webhook", Settings: map[string]string{"url": srv.URL}}) p := GetProvider(models.AlertConfig{Type: "webhook", Settings: map[string]string{"url": srv.URL}})
if err := p.Send("Title", "Body"); err != nil { if err := p.Send(context.Background(), "Title", "Body"); err != nil {
t.Fatalf("Send: %v", err) t.Fatalf("Send: %v", err)
} }
@@ -69,7 +70,7 @@ func TestHTTPProviderErrorOnHTTP4xx(t *testing.T) {
defer srv.Close() defer srv.Close()
p := GetProvider(models.AlertConfig{Type: "discord", Settings: map[string]string{"url": srv.URL}}) p := GetProvider(models.AlertConfig{Type: "discord", Settings: map[string]string{"url": srv.URL}})
if err := p.Send("Test", "Test"); err == nil { if err := p.Send(context.Background(), "Test", "Test"); err == nil {
t.Fatal("expected error on 403 response") t.Fatal("expected error on 403 response")
} }
} }
@@ -89,7 +90,7 @@ func TestNtfyProvider(t *testing.T) {
"url": srv.URL, "url": srv.URL,
"topic": "test", "topic": "test",
}}) }})
if err := p.Send("Alert Title", "Alert Body"); err != nil { if err := p.Send(context.Background(), "Alert Title", "Alert Body"); err != nil {
t.Fatalf("Send: %v", err) t.Fatalf("Send: %v", err)
} }
@@ -110,7 +111,7 @@ func TestHTTPProviderTelegram(t *testing.T) {
defer srv.Close() defer srv.Close()
p := &HTTPProvider{URL: srv.URL, Payload: telegramPayload("12345")} p := &HTTPProvider{URL: srv.URL, Payload: telegramPayload("12345")}
if err := p.Send("Alert", "Down"); err != nil { if err := p.Send(context.Background(), "Alert", "Down"); err != nil {
t.Fatalf("Send: %v", err) t.Fatalf("Send: %v", err)
} }
if received["chat_id"] != "12345" { if received["chat_id"] != "12345" {
@@ -133,7 +134,7 @@ func TestHTTPProviderPagerDuty(t *testing.T) {
defer srv.Close() defer srv.Close()
p := &HTTPProvider{URL: srv.URL, Payload: pagerdutyPayload("test-key", "critical")} p := &HTTPProvider{URL: srv.URL, Payload: pagerdutyPayload("test-key", "critical")}
if err := p.Send("Alert", "Down"); err != nil { if err := p.Send(context.Background(), "Alert", "Down"); err != nil {
t.Fatalf("Send: %v", err) t.Fatalf("Send: %v", err)
} }
if received["routing_key"] != "test-key" { if received["routing_key"] != "test-key" {
@@ -160,7 +161,7 @@ func TestHTTPProviderPushover(t *testing.T) {
defer srv.Close() defer srv.Close()
p := &HTTPProvider{URL: srv.URL, Payload: pushoverPayload("app-tok", "user-key")} p := &HTTPProvider{URL: srv.URL, Payload: pushoverPayload("app-tok", "user-key")}
if err := p.Send("Alert", "Down"); err != nil { if err := p.Send(context.Background(), "Alert", "Down"); err != nil {
t.Fatalf("Send: %v", err) t.Fatalf("Send: %v", err)
} }
if received["token"] != "app-tok" { if received["token"] != "app-tok" {
@@ -183,7 +184,7 @@ func TestHTTPProviderGotify(t *testing.T) {
defer srv.Close() defer srv.Close()
p := &HTTPProvider{URL: srv.URL, Payload: gotifyPayload("8")} p := &HTTPProvider{URL: srv.URL, Payload: gotifyPayload("8")}
if err := p.Send("Alert", "Down"); err != nil { if err := p.Send(context.Background(), "Alert", "Down"); err != nil {
t.Fatalf("Send: %v", err) t.Fatalf("Send: %v", err)
} }
if received["title"] != "Alert" || received["message"] != "Down" { if received["title"] != "Alert" || received["message"] != "Down" {
+395
View File
@@ -0,0 +1,395 @@
package cluster
import (
"context"
"encoding/json"
"go-upkeep/internal/models"
"go-upkeep/internal/monitor"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"
)
// --- Mock Store (minimal, for monitor.NewEngine) ---
type mockStore struct {
mu sync.Mutex
sites []models.Site
}
func (m *mockStore) Init() error { return nil }
func (m *mockStore) GetSites() ([]models.Site, error) { return m.sites, nil }
func (m *mockStore) AddSite(models.Site) error { return nil }
func (m *mockStore) UpdateSite(models.Site) error { return nil }
func (m *mockStore) UpdateSitePaused(int, bool) error { return nil }
func (m *mockStore) DeleteSite(int) error { return nil }
func (m *mockStore) GetAllAlerts() ([]models.AlertConfig, error) { return nil, nil }
func (m *mockStore) GetAlert(int) (models.AlertConfig, error) { return models.AlertConfig{}, nil }
func (m *mockStore) AddAlert(string, string, map[string]string) error { return nil }
func (m *mockStore) UpdateAlert(int, string, string, map[string]string) error { return nil }
func (m *mockStore) DeleteAlert(int) error { return nil }
func (m *mockStore) GetAllUsers() ([]models.User, error) { return nil, nil }
func (m *mockStore) AddUser(string, string, string) error { return nil }
func (m *mockStore) UpdateUser(int, string, string, string) error { return nil }
func (m *mockStore) DeleteUser(int) error { return nil }
func (m *mockStore) SaveCheck(int, int64, bool) error { return nil }
func (m *mockStore) SaveCheckFromNode(int, string, int64, bool) error { return nil }
func (m *mockStore) LoadAllHistory(int) (map[int][]models.CheckRecord, error) { return nil, nil }
func (m *mockStore) ExportData() (models.Backup, error) { return models.Backup{}, nil }
func (m *mockStore) ImportData(models.Backup) error { return nil }
func (m *mockStore) GetSiteByName(string) (models.Site, error) { return models.Site{}, nil }
func (m *mockStore) GetAlertByName(string) (models.AlertConfig, error) {
return models.AlertConfig{}, nil
}
func (m *mockStore) AddSiteReturningID(models.Site) (int, error) { return 0, nil }
func (m *mockStore) AddAlertReturningID(string, string, map[string]string) (int, error) {
return 0, nil
}
func (m *mockStore) RegisterNode(models.ProbeNode) error { return nil }
func (m *mockStore) GetNode(string) (models.ProbeNode, error) { return models.ProbeNode{}, nil }
func (m *mockStore) GetAllNodes() ([]models.ProbeNode, error) { return nil, nil }
func (m *mockStore) UpdateNodeLastSeen(string) error { return nil }
func (m *mockStore) DeleteNode(string) error { return nil }
func (m *mockStore) SaveLog(string) error { return nil }
func (m *mockStore) LoadLogs(int) ([]string, error) { return nil, nil }
func (m *mockStore) GetActiveMaintenanceWindows() ([]models.MaintenanceWindow, error) {
return nil, nil
}
func (m *mockStore) GetAllMaintenanceWindows(int) ([]models.MaintenanceWindow, error) {
return nil, nil
}
func (m *mockStore) AddMaintenanceWindow(models.MaintenanceWindow) error { return nil }
func (m *mockStore) EndMaintenanceWindow(int) error { return nil }
func (m *mockStore) DeleteMaintenanceWindow(int) error { return nil }
func (m *mockStore) IsMonitorInMaintenance(int) (bool, error) { return false, nil }
func (m *mockStore) GetPreference(string) (string, error) { return "", nil }
func (m *mockStore) SetPreference(string, string) error { return nil }
func (m *mockStore) Close() error { return nil }
// --- Cluster Start Tests ---
func TestStart_LeaderMode(t *testing.T) {
eng := monitor.NewEngine(&mockStore{})
eng.SetActive(false)
ctx := context.Background()
Start(ctx, Config{Mode: "leader"}, eng)
if !eng.IsActive() {
t.Error("leader mode should set engine active")
}
}
func TestStart_FollowerMode(t *testing.T) {
eng := monitor.NewEngine(&mockStore{})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Start(ctx, Config{Mode: "follower", PeerURL: "http://localhost:9999"}, eng)
time.Sleep(50 * time.Millisecond)
if eng.IsActive() {
t.Error("follower mode should set engine inactive")
}
}
// --- Follower Loop Tests ---
func TestFollowerLoop_FailoverOnLeaderDown(t *testing.T) {
eng := monitor.NewEngine(&mockStore{})
eng.SetActive(false)
// Server always returns 503
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(503)
}))
defer srv.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go runFollowerLoop(ctx, Config{PeerURL: srv.URL, SharedKey: "key"}, eng)
// Follower checks every 5s, needs 3 failures → ~15s minimum
// But we can't wait that long in a test. The loop sleeps 5s between checks.
// We'll wait up to 20s for failover.
deadline := time.After(20 * time.Second)
for {
if eng.IsActive() {
return // success
}
select {
case <-deadline:
t.Fatal("expected failover to ACTIVE after 3 failures")
case <-time.After(500 * time.Millisecond):
}
}
}
func TestFollowerLoop_RecoveryOnLeaderReturn(t *testing.T) {
eng := monitor.NewEngine(&mockStore{})
eng.SetActive(true) // simulate already failed over
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte("OK"))
}))
defer srv.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go runFollowerLoop(ctx, Config{PeerURL: srv.URL}, eng)
deadline := time.After(10 * time.Second)
for {
if !eng.IsActive() {
return // success — switched back to passive
}
select {
case <-deadline:
t.Fatal("expected switch back to PASSIVE when leader returns")
case <-time.After(500 * time.Millisecond):
}
}
}
func TestFollowerLoop_SendsSecret(t *testing.T) {
var mu sync.Mutex
var receivedSecret string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
receivedSecret = r.Header.Get("X-Upkeep-Secret")
mu.Unlock()
w.WriteHeader(200)
w.Write([]byte("OK"))
}))
defer srv.Close()
eng := monitor.NewEngine(&mockStore{})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go runFollowerLoop(ctx, Config{PeerURL: srv.URL, SharedKey: "test-secret"}, eng)
deadline := time.After(10 * time.Second)
for {
mu.Lock()
got := receivedSecret
mu.Unlock()
if got == "test-secret" {
return
}
select {
case <-deadline:
t.Fatalf("expected secret 'test-secret', got %q", got)
case <-time.After(500 * time.Millisecond):
}
}
}
func TestFollowerLoop_CancelContext(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
}))
defer srv.Close()
eng := monitor.NewEngine(&mockStore{})
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
runFollowerLoop(ctx, Config{PeerURL: srv.URL}, eng)
close(done)
}()
cancel()
select {
case <-done:
case <-time.After(3 * time.Second):
t.Fatal("expected follower loop to exit on context cancel")
}
}
// --- Probe Tests ---
func TestProbeRegister_Success(t *testing.T) {
var received map[string]string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewDecoder(r.Body).Decode(&received)
w.WriteHeader(200)
}))
defer srv.Close()
err := probeRegister(context.Background(), srv.Client(), ProbeConfig{
NodeID: "n1", NodeName: "US East", Region: "us-east", LeaderURL: srv.URL, SharedKey: "key",
})
if err != nil {
t.Fatalf("register: %v", err)
}
if received["id"] != "n1" {
t.Errorf("expected id n1, got %s", received["id"])
}
}
func TestProbeRegister_Failure(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(401)
}))
defer srv.Close()
err := probeRegister(context.Background(), srv.Client(), ProbeConfig{
LeaderURL: srv.URL,
})
if err == nil {
t.Error("expected error on 401")
}
}
func TestProbeFetchAssignments_Success(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string][]models.Site{
"sites": {{ID: 1, Name: "s1", Type: "http", URL: "http://example.com"}},
})
}))
defer srv.Close()
sites, err := probeFetchAssignments(context.Background(), srv.Client(), ProbeConfig{
NodeID: "n1", LeaderURL: srv.URL, SharedKey: "key",
})
if err != nil {
t.Fatalf("fetch: %v", err)
}
if len(sites) != 1 {
t.Errorf("expected 1 site, got %d", len(sites))
}
}
func TestProbeFetchAssignments_Unauthorized(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(401)
}))
defer srv.Close()
_, err := probeFetchAssignments(context.Background(), srv.Client(), ProbeConfig{
LeaderURL: srv.URL,
})
if err == nil {
t.Error("expected error on 401")
}
}
func TestProbeExecuteChecks(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
}))
defer srv.Close()
sites := []models.Site{
{ID: 1, Type: "http", URL: srv.URL},
{ID: 2, Type: "http", URL: srv.URL},
}
strict := &http.Client{}
insecure := &http.Client{}
results := probeExecuteChecks(context.Background(), sites, strict, insecure)
if len(results) != 2 {
t.Fatalf("expected 2 results, got %d", len(results))
}
for _, r := range results {
if !r.IsUp {
t.Errorf("site %d expected UP", r.SiteID)
}
}
}
func TestProbeExecuteChecks_Concurrency(t *testing.T) {
var concurrent int64
var maxConcurrent int64
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cur := atomic.AddInt64(&concurrent, 1)
for {
old := atomic.LoadInt64(&maxConcurrent)
if cur <= old || atomic.CompareAndSwapInt64(&maxConcurrent, old, cur) {
break
}
}
time.Sleep(50 * time.Millisecond)
atomic.AddInt64(&concurrent, -1)
w.WriteHeader(200)
}))
defer srv.Close()
var sites []models.Site
for i := 0; i < 20; i++ {
sites = append(sites, models.Site{ID: i + 1, Type: "http", URL: srv.URL})
}
results := probeExecuteChecks(context.Background(), sites, &http.Client{}, &http.Client{})
if len(results) != 20 {
t.Errorf("expected 20 results, got %d", len(results))
}
mc := atomic.LoadInt64(&maxConcurrent)
if mc > 10 {
t.Errorf("expected max 10 concurrent, got %d", mc)
}
}
func TestProbeReportResults_Success(t *testing.T) {
var received struct {
NodeID string `json:"node_id"`
Results []probeResultItem `json:"results"`
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewDecoder(r.Body).Decode(&received)
w.WriteHeader(200)
}))
defer srv.Close()
err := probeReportResults(context.Background(), srv.Client(), ProbeConfig{
NodeID: "n1", LeaderURL: srv.URL, SharedKey: "key",
}, []probeResultItem{{SiteID: 1, LatencyNs: 5000000, IsUp: true}})
if err != nil {
t.Fatalf("report: %v", err)
}
if received.NodeID != "n1" {
t.Errorf("expected n1, got %s", received.NodeID)
}
if len(received.Results) != 1 {
t.Errorf("expected 1 result, got %d", len(received.Results))
}
}
func TestProbeReportResults_Failure(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(500)
}))
defer srv.Close()
err := probeReportResults(context.Background(), srv.Client(), ProbeConfig{
LeaderURL: srv.URL,
}, []probeResultItem{{SiteID: 1}})
if err == nil {
t.Error("expected error on 500")
}
}
// --- sleepCtx ---
func TestSleepCtx_Cancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
start := time.Now()
sleepCtx(ctx, 10*time.Second)
if time.Since(start) > time.Second {
t.Error("expected immediate return on canceled context")
}
}
+3 -1
View File
@@ -96,7 +96,9 @@ func convertKumaNotifications(entries []KumaNotifEntry) map[int]models.AlertConf
result := make(map[int]models.AlertConfig) result := make(map[int]models.AlertConfig)
for _, entry := range entries { for _, entry := range entries {
var cfg KumaNotifConfig var cfg KumaNotifConfig
json.Unmarshal([]byte(entry.Config), &cfg) if err := json.Unmarshal([]byte(entry.Config), &cfg); err != nil {
continue
}
alert := models.AlertConfig{ alert := models.AlertConfig{
ID: entry.ID, ID: entry.ID,
+1
View File
@@ -64,6 +64,7 @@ func (m *mockStore) DeleteMaintenanceWindow(int) error { retur
func (m *mockStore) IsMonitorInMaintenance(int) (bool, error) { return false, nil } func (m *mockStore) IsMonitorInMaintenance(int) (bool, error) { return false, nil }
func (m *mockStore) GetPreference(string) (string, error) { return "", nil } func (m *mockStore) GetPreference(string) (string, error) { return "", nil }
func (m *mockStore) SetPreference(string, string) error { return nil } func (m *mockStore) SetPreference(string, string) error { return nil }
func (m *mockStore) Close() error { return nil }
func TestMetricsHandler(t *testing.T) { func TestMetricsHandler(t *testing.T) {
ms := &mockStore{ ms := &mockStore{
+203
View File
@@ -0,0 +1,203 @@
package monitor
import (
"crypto/tls"
"go-upkeep/internal/models"
"net"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"
)
func TestRunCheck_HTTP_Success(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
}))
defer srv.Close()
site := models.Site{ID: 1, Type: "http", URL: srv.URL}
result := RunCheck(site, http.DefaultClient, http.DefaultClient, false)
if result.Status != "UP" {
t.Errorf("expected UP, got %s", result.Status)
}
if result.StatusCode != 200 {
t.Errorf("expected 200, got %d", result.StatusCode)
}
if result.LatencyNs <= 0 {
t.Error("expected positive latency")
}
}
func TestRunCheck_HTTP_ServerError(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(500)
}))
defer srv.Close()
site := models.Site{ID: 1, Type: "http", URL: srv.URL}
result := RunCheck(site, http.DefaultClient, http.DefaultClient, false)
if result.Status != "DOWN" {
t.Errorf("expected DOWN, got %s", result.Status)
}
if result.StatusCode != 500 {
t.Errorf("expected 500, got %d", result.StatusCode)
}
}
func TestRunCheck_HTTP_CustomAcceptedCodes(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(302)
}))
defer srv.Close()
client := &http.Client{CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}}
site := models.Site{ID: 1, Type: "http", URL: srv.URL, AcceptedCodes: "200-399"}
result := RunCheck(site, client, client, false)
if result.Status != "UP" {
t.Errorf("expected UP with accepted 200-399, got %s", result.Status)
}
}
func TestRunCheck_HTTP_MethodRespected(t *testing.T) {
var receivedMethod string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedMethod = r.Method
w.WriteHeader(200)
}))
defer srv.Close()
site := models.Site{ID: 1, Type: "http", URL: srv.URL, Method: "HEAD"}
RunCheck(site, http.DefaultClient, http.DefaultClient, false)
if receivedMethod != "HEAD" {
t.Errorf("expected HEAD, got %s", receivedMethod)
}
}
func TestRunCheck_HTTP_Timeout(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(2 * time.Second)
w.WriteHeader(200)
}))
defer srv.Close()
site := models.Site{ID: 1, Type: "http", URL: srv.URL, Timeout: 1}
result := RunCheck(site, http.DefaultClient, http.DefaultClient, false)
if result.Status != "DOWN" {
t.Errorf("expected DOWN on timeout, got %s", result.Status)
}
}
func TestRunCheck_HTTP_SSLFields(t *testing.T) {
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
}))
defer srv.Close()
insecureClient := &http.Client{
Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
}
site := models.Site{ID: 1, Type: "http", URL: srv.URL, CheckSSL: true, IgnoreTLS: true}
result := RunCheck(site, http.DefaultClient, insecureClient, false)
if result.Status != "UP" {
t.Errorf("expected UP, got %s", result.Status)
}
if !result.HasSSL {
t.Error("expected HasSSL=true")
}
if result.CertExpiry.IsZero() {
t.Error("expected CertExpiry populated")
}
}
func TestRunCheck_Port_Open(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer ln.Close()
_, portStr, _ := net.SplitHostPort(ln.Addr().String())
port, _ := strconv.Atoi(portStr)
site := models.Site{ID: 1, Type: "port", Hostname: "127.0.0.1", Port: port, Timeout: 2}
result := RunCheck(site, nil, nil, false)
if result.Status != "UP" {
t.Errorf("expected UP, got %s", result.Status)
}
if result.LatencyNs <= 0 {
t.Error("expected positive latency")
}
}
func TestRunCheck_Port_Closed(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
_, portStr, _ := net.SplitHostPort(ln.Addr().String())
port, _ := strconv.Atoi(portStr)
ln.Close()
site := models.Site{ID: 1, Type: "port", Hostname: "127.0.0.1", Port: port, Timeout: 1}
result := RunCheck(site, nil, nil, false)
if result.Status != "DOWN" {
t.Errorf("expected DOWN, got %s", result.Status)
}
}
func TestRunCheck_UnknownType(t *testing.T) {
site := models.Site{ID: 1, Type: "invalid"}
result := RunCheck(site, nil, nil, false)
if result.Status != "DOWN" {
t.Errorf("expected DOWN for unknown type, got %s", result.Status)
}
}
func TestIsCodeAccepted(t *testing.T) {
tests := []struct {
code int
accepted string
want bool
}{
{200, "", true},
{299, "", true},
{300, "", false},
{302, "200-399", true},
{400, "200-399", false},
{301, "200,301,404", true},
{500, "200,301,404", false},
{404, "200-299,400-499", true},
{500, "200-299,400-499", false},
}
for _, tt := range tests {
got := isCodeAccepted(tt.code, tt.accepted)
if got != tt.want {
t.Errorf("isCodeAccepted(%d, %q) = %v, want %v", tt.code, tt.accepted, got, tt.want)
}
}
}
func TestSiteTimeout(t *testing.T) {
if got := siteTimeout(models.Site{Timeout: 0}); got != 5*time.Second {
t.Errorf("expected 5s default, got %v", got)
}
if got := siteTimeout(models.Site{Timeout: 10}); got != 10*time.Second {
t.Errorf("expected 10s, got %v", got)
}
}
+16 -4
View File
@@ -7,6 +7,7 @@ import (
"go-upkeep/internal/alert" "go-upkeep/internal/alert"
"go-upkeep/internal/models" "go-upkeep/internal/models"
"go-upkeep/internal/store" "go-upkeep/internal/store"
"math/rand/v2"
"net/http" "net/http"
"sync" "sync"
"time" "time"
@@ -25,7 +26,7 @@ type Engine struct {
histMu sync.RWMutex histMu sync.RWMutex
histories map[int]*SiteHistory histories map[int]*SiteHistory
tokenIndex map[string]int tokenIndex map[string]int // protected by mu
probeResultsMu sync.RWMutex probeResultsMu sync.RWMutex
probeResults map[int]map[string]NodeResult probeResults map[int]map[string]NodeResult
@@ -277,6 +278,14 @@ func (e *Engine) ToggleSitePause(id int) bool {
} }
func (e *Engine) monitorRoutine(ctx context.Context, id int) { func (e *Engine) monitorRoutine(ctx context.Context, id int) {
// Stagger initial check to avoid thundering herd on startup
stagger := time.Duration(rand.IntN(3000)) * time.Millisecond
select {
case <-time.After(stagger):
case <-ctx.Done():
return
}
e.checkByID(id) e.checkByID(id)
for { for {
select { select {
@@ -314,8 +323,9 @@ func (e *Engine) monitorRoutine(ctx context.Context, id int) {
if interval < 5 { if interval < 5 {
interval = 5 interval = 5
} }
jitter := time.Duration(rand.IntN(interval*100)) * time.Millisecond
select { select {
case <-time.After(time.Duration(interval) * time.Second): case <-time.After(time.Duration(interval)*time.Second + jitter):
case <-ctx.Done(): case <-ctx.Done():
return return
} }
@@ -433,6 +443,7 @@ func (e *Engine) handleStatusChange(site models.Site, rawStatus string, code int
func (e *Engine) triggerAlert(alertID int, title, message string) { func (e *Engine) triggerAlert(alertID int, title, message string) {
cfg, err := e.db.GetAlert(alertID) cfg, err := e.db.GetAlert(alertID)
if err != nil { if err != nil {
e.AddLog(fmt.Sprintf("Failed to load alert config %d: %v", alertID, err))
return return
} }
provider := alert.GetProvider(cfg) provider := alert.GetProvider(cfg)
@@ -440,8 +451,9 @@ func (e *Engine) triggerAlert(alertID int, title, message string) {
go func() { go func() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
_ = ctx if err := provider.Send(ctx, title, message); err != nil {
_ = provider.Send(title, message) e.AddLog(fmt.Sprintf("Alert send failed (%s): %v", cfg.Name, err))
}
}() }()
} }
} }
File diff suppressed because it is too large Load Diff
+22 -11
View File
@@ -1,6 +1,7 @@
package server package server
import ( import (
"crypto/subtle"
"encoding/json" "encoding/json"
"fmt" "fmt"
"go-upkeep/internal/importer" "go-upkeep/internal/importer"
@@ -15,6 +16,10 @@ import (
"strings" "strings"
) )
func checkSecret(got, want string) bool {
return subtle.ConstantTimeCompare([]byte(got), []byte(want)) == 1
}
var statusTpl = template.Must(template.New("status").Parse(` var statusTpl = template.Must(template.New("status").Parse(`
<!DOCTYPE html> <!DOCTYPE html>
<html> <html>
@@ -153,7 +158,7 @@ type ServerConfig struct {
ClusterKey string // Shared Secret for Security ClusterKey string // Shared Secret for Security
} }
func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) { func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) *http.Server {
if cfg.ClusterKey == "" { if cfg.ClusterKey == "" {
fmt.Println("WARNING: No UPKEEP_CLUSTER_SECRET set. Cluster API endpoints are unauthenticated.") fmt.Println("WARNING: No UPKEEP_CLUSTER_SECRET set. Cluster API endpoints are unauthenticated.")
} }
@@ -176,7 +181,7 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) {
// 2. Health Check (For Cluster Follower) // 2. Health Check (For Cluster Follower)
mux.HandleFunc("/api/health", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/health", func(w http.ResponseWriter, r *http.Request) {
if cfg.ClusterKey != "" && r.Header.Get("X-Upkeep-Secret") != cfg.ClusterKey { if cfg.ClusterKey != "" && !checkSecret(r.Header.Get("X-Upkeep-Secret"), cfg.ClusterKey) {
http.Error(w, "Unauthorized", 401) http.Error(w, "Unauthorized", 401)
return return
} }
@@ -186,7 +191,7 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) {
// 3. Config Export // 3. Config Export
mux.HandleFunc("/api/backup/export", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/backup/export", func(w http.ResponseWriter, r *http.Request) {
if cfg.ClusterKey == "" || r.Header.Get("X-Upkeep-Secret") != cfg.ClusterKey { if cfg.ClusterKey == "" || !checkSecret(r.Header.Get("X-Upkeep-Secret"), cfg.ClusterKey) {
http.Error(w, "Unauthorized: UPKEEP_CLUSTER_SECRET required", 401) http.Error(w, "Unauthorized: UPKEEP_CLUSTER_SECRET required", 401)
return return
} }
@@ -205,10 +210,11 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) {
http.Error(w, "POST required", 405) http.Error(w, "POST required", 405)
return return
} }
if cfg.ClusterKey == "" || r.Header.Get("X-Upkeep-Secret") != cfg.ClusterKey { if cfg.ClusterKey == "" || !checkSecret(r.Header.Get("X-Upkeep-Secret"), cfg.ClusterKey) {
http.Error(w, "Unauthorized", 401) http.Error(w, "Unauthorized", 401)
return return
} }
r.Body = http.MaxBytesReader(w, r.Body, 1<<20)
var data models.Backup var data models.Backup
if err := json.NewDecoder(r.Body).Decode(&data); err != nil { if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
http.Error(w, "Invalid JSON", 400) http.Error(w, "Invalid JSON", 400)
@@ -228,10 +234,11 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) {
http.Error(w, "POST required", 405) http.Error(w, "POST required", 405)
return return
} }
if cfg.ClusterKey == "" || r.Header.Get("X-Upkeep-Secret") != cfg.ClusterKey { if cfg.ClusterKey == "" || !checkSecret(r.Header.Get("X-Upkeep-Secret"), cfg.ClusterKey) {
http.Error(w, "Unauthorized", 401) http.Error(w, "Unauthorized", 401)
return return
} }
r.Body = http.MaxBytesReader(w, r.Body, 1<<20)
var kb importer.KumaBackup var kb importer.KumaBackup
if err := json.NewDecoder(r.Body).Decode(&kb); err != nil { if err := json.NewDecoder(r.Body).Decode(&kb); err != nil {
log.Printf("Invalid Kuma JSON: %v", err) log.Printf("Invalid Kuma JSON: %v", err)
@@ -253,10 +260,11 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) {
http.Error(w, "POST required", 405) http.Error(w, "POST required", 405)
return return
} }
if cfg.ClusterKey == "" || r.Header.Get("X-Upkeep-Secret") != cfg.ClusterKey { if cfg.ClusterKey == "" || !checkSecret(r.Header.Get("X-Upkeep-Secret"), cfg.ClusterKey) {
http.Error(w, "Unauthorized", 401) http.Error(w, "Unauthorized", 401)
return return
} }
r.Body = http.MaxBytesReader(w, r.Body, 1<<20)
var req struct { var req struct {
ID string `json:"id"` ID string `json:"id"`
Name string `json:"name"` Name string `json:"name"`
@@ -283,7 +291,7 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) {
// 7. Probe Assignment Fetch // 7. Probe Assignment Fetch
mux.HandleFunc("/api/probe/assignments", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/probe/assignments", func(w http.ResponseWriter, r *http.Request) {
if cfg.ClusterKey == "" || r.Header.Get("X-Upkeep-Secret") != cfg.ClusterKey { if cfg.ClusterKey == "" || !checkSecret(r.Header.Get("X-Upkeep-Secret"), cfg.ClusterKey) {
http.Error(w, "Unauthorized", 401) http.Error(w, "Unauthorized", 401)
return return
} }
@@ -324,10 +332,11 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) {
http.Error(w, "POST required", 405) http.Error(w, "POST required", 405)
return return
} }
if cfg.ClusterKey == "" || r.Header.Get("X-Upkeep-Secret") != cfg.ClusterKey { if cfg.ClusterKey == "" || !checkSecret(r.Header.Get("X-Upkeep-Secret"), cfg.ClusterKey) {
http.Error(w, "Unauthorized", 401) http.Error(w, "Unauthorized", 401)
return return
} }
r.Body = http.MaxBytesReader(w, r.Body, 1<<20)
var req struct { var req struct {
NodeID string `json:"node_id"` NodeID string `json:"node_id"`
Results []struct { Results []struct {
@@ -387,13 +396,15 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) {
}) })
} }
addr := fmt.Sprintf(":%d", cfg.Port)
srv := &http.Server{Addr: addr, Handler: mux}
go func() { go func() {
addr := fmt.Sprintf(":%d", cfg.Port)
fmt.Printf("HTTP Server listening on %s\n", addr) fmt.Printf("HTTP Server listening on %s\n", addr)
if err := http.ListenAndServe(addr, mux); err != nil { if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("HTTP server failed: %v", err) log.Printf("HTTP server error: %v", err)
} }
}() }()
return srv
} }
func renderStatusPage(w http.ResponseWriter, title string, eng *monitor.Engine) { func renderStatusPage(w http.ResponseWriter, title string, eng *monitor.Engine) {
+553
View File
@@ -0,0 +1,553 @@
package server
import (
"bytes"
"encoding/json"
"fmt"
"go-upkeep/internal/models"
"go-upkeep/internal/monitor"
"net"
"net/http"
"sync"
"testing"
"time"
)
// --- Mock Store ---
type mockStore struct {
mu sync.Mutex
sites []models.Site
alerts []models.AlertConfig
nodes map[string]models.ProbeNode
importedData *models.Backup
registeredNodes []models.ProbeNode
maintWindows []models.MaintenanceWindow
}
func newMockStore() *mockStore {
return &mockStore{
nodes: make(map[string]models.ProbeNode),
}
}
func (m *mockStore) Init() error { return nil }
func (m *mockStore) GetSites() ([]models.Site, error) { return m.sites, nil }
func (m *mockStore) AddSite(models.Site) error { return nil }
func (m *mockStore) UpdateSite(models.Site) error { return nil }
func (m *mockStore) UpdateSitePaused(int, bool) error { return nil }
func (m *mockStore) DeleteSite(int) error { return nil }
func (m *mockStore) GetAllAlerts() ([]models.AlertConfig, error) { return m.alerts, nil }
func (m *mockStore) GetAlert(int) (models.AlertConfig, error) { return models.AlertConfig{}, nil }
func (m *mockStore) AddAlert(string, string, map[string]string) error { return nil }
func (m *mockStore) UpdateAlert(int, string, string, map[string]string) error { return nil }
func (m *mockStore) DeleteAlert(int) error { return nil }
func (m *mockStore) GetAllUsers() ([]models.User, error) { return nil, nil }
func (m *mockStore) AddUser(string, string, string) error { return nil }
func (m *mockStore) UpdateUser(int, string, string, string) error { return nil }
func (m *mockStore) DeleteUser(int) error { return nil }
func (m *mockStore) SaveCheck(int, int64, bool) error { return nil }
func (m *mockStore) SaveCheckFromNode(siteID int, nodeID string, latencyNs int64, isUp bool) error {
return nil
}
func (m *mockStore) LoadAllHistory(int) (map[int][]models.CheckRecord, error) {
return nil, nil
}
func (m *mockStore) GetSiteByName(string) (models.Site, error) { return models.Site{}, nil }
func (m *mockStore) GetAlertByName(string) (models.AlertConfig, error) {
return models.AlertConfig{}, nil
}
func (m *mockStore) AddSiteReturningID(models.Site) (int, error) { return 0, nil }
func (m *mockStore) AddAlertReturningID(string, string, map[string]string) (int, error) {
return 0, nil
}
func (m *mockStore) GetAllNodes() ([]models.ProbeNode, error) { return nil, nil }
func (m *mockStore) UpdateNodeLastSeen(string) error { return nil }
func (m *mockStore) DeleteNode(string) error { return nil }
func (m *mockStore) SaveLog(string) error { return nil }
func (m *mockStore) LoadLogs(int) ([]string, error) { return nil, nil }
func (m *mockStore) GetAllMaintenanceWindows(int) ([]models.MaintenanceWindow, error) {
return nil, nil
}
func (m *mockStore) AddMaintenanceWindow(models.MaintenanceWindow) error { return nil }
func (m *mockStore) EndMaintenanceWindow(int) error { return nil }
func (m *mockStore) DeleteMaintenanceWindow(int) error { return nil }
func (m *mockStore) IsMonitorInMaintenance(int) (bool, error) { return false, nil }
func (m *mockStore) GetPreference(string) (string, error) { return "", nil }
func (m *mockStore) SetPreference(string, string) error { return nil }
func (m *mockStore) Close() error { return nil }
func (m *mockStore) ExportData() (models.Backup, error) {
return models.Backup{
Sites: m.sites,
Alerts: m.alerts,
}, nil
}
func (m *mockStore) ImportData(data models.Backup) error {
m.mu.Lock()
defer m.mu.Unlock()
m.importedData = &data
return nil
}
func (m *mockStore) RegisterNode(node models.ProbeNode) error {
m.mu.Lock()
defer m.mu.Unlock()
m.registeredNodes = append(m.registeredNodes, node)
m.nodes[node.ID] = node
return nil
}
func (m *mockStore) GetNode(id string) (models.ProbeNode, error) {
m.mu.Lock()
defer m.mu.Unlock()
if n, ok := m.nodes[id]; ok {
return n, nil
}
return models.ProbeNode{}, fmt.Errorf("not found")
}
func (m *mockStore) GetActiveMaintenanceWindows() ([]models.MaintenanceWindow, error) {
return m.maintWindows, nil
}
// --- Helpers ---
func freePort() int {
ln, _ := net.Listen("tcp", "127.0.0.1:0")
port := ln.Addr().(*net.TCPAddr).Port
ln.Close()
return port
}
type testServer struct {
baseURL string
srv *http.Server
store *mockStore
engine *monitor.Engine
}
func newTestServer(t *testing.T, clusterKey string, enableStatus bool) *testServer {
t.Helper()
ms := newMockStore()
eng := monitor.NewEngine(ms)
port := freePort()
srv := Start(ServerConfig{
Port: port,
EnableStatus: enableStatus,
Title: "Test Status",
ClusterKey: clusterKey,
}, ms, eng)
ts := &testServer{
baseURL: fmt.Sprintf("http://127.0.0.1:%d", port),
srv: srv,
store: ms,
engine: eng,
}
// Wait for server to be ready
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
resp, err := http.Get(ts.baseURL + "/api/health")
if err == nil {
resp.Body.Close()
break
}
time.Sleep(10 * time.Millisecond)
}
t.Cleanup(func() {
srv.Close()
})
return ts
}
func authReq(method, url, secret string, body []byte) (*http.Response, error) {
var req *http.Request
var err error
if body != nil {
req, err = http.NewRequest(method, url, bytes.NewReader(body))
} else {
req, err = http.NewRequest(method, url, nil)
}
if err != nil {
return nil, err
}
if secret != "" {
req.Header.Set("X-Upkeep-Secret", secret)
}
return http.DefaultClient.Do(req)
}
// --- Tests ---
func TestCheckSecret(t *testing.T) {
if !checkSecret("mykey", "mykey") {
t.Error("expected match")
}
if checkSecret("mykey", "wrong") {
t.Error("expected no match")
}
if checkSecret("", "key") {
t.Error("expected no match for empty got")
}
}
// --- Push Heartbeat ---
func TestPush_MissingToken(t *testing.T) {
ts := newTestServer(t, "secret", false)
resp, err := http.Get(ts.baseURL + "/api/push")
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 400 {
t.Errorf("expected 400, got %d", resp.StatusCode)
}
}
func TestPush_InvalidToken(t *testing.T) {
ts := newTestServer(t, "secret", false)
resp, err := http.Get(ts.baseURL + "/api/push?token=bad")
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 404 {
t.Errorf("expected 404, got %d", resp.StatusCode)
}
}
// --- Health ---
func TestHealth_NoSecret(t *testing.T) {
ts := newTestServer(t, "", false)
resp, err := http.Get(ts.baseURL + "/api/health")
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("expected 200 with no cluster key, got %d", resp.StatusCode)
}
}
func TestHealth_ValidSecret(t *testing.T) {
ts := newTestServer(t, "secret", false)
resp, err := authReq("GET", ts.baseURL+"/api/health", "secret", nil)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("expected 200, got %d", resp.StatusCode)
}
}
func TestHealth_WrongSecret(t *testing.T) {
ts := newTestServer(t, "secret", false)
resp, err := authReq("GET", ts.baseURL+"/api/health", "wrong", nil)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 401 {
t.Errorf("expected 401, got %d", resp.StatusCode)
}
}
// --- Backup Export ---
func TestExport_Unauthorized_NoKey(t *testing.T) {
ts := newTestServer(t, "", false)
resp, err := http.Get(ts.baseURL + "/api/backup/export")
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 401 {
t.Errorf("expected 401 when no cluster key configured, got %d", resp.StatusCode)
}
}
func TestExport_Unauthorized_WrongKey(t *testing.T) {
ts := newTestServer(t, "secret", false)
resp, err := authReq("GET", ts.baseURL+"/api/backup/export", "wrong", nil)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 401 {
t.Errorf("expected 401, got %d", resp.StatusCode)
}
}
func TestExport_Success(t *testing.T) {
ts := newTestServer(t, "secret", false)
ts.store.sites = []models.Site{{ID: 1, Name: "example", URL: "http://example.com"}}
resp, err := authReq("GET", ts.baseURL+"/api/backup/export", "secret", nil)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("expected 200, got %d", resp.StatusCode)
}
var backup models.Backup
json.NewDecoder(resp.Body).Decode(&backup)
if len(backup.Sites) != 1 {
t.Errorf("expected 1 site, got %d", len(backup.Sites))
}
}
// --- Backup Import ---
func TestImport_MethodNotAllowed(t *testing.T) {
ts := newTestServer(t, "secret", false)
resp, err := authReq("GET", ts.baseURL+"/api/backup/import", "secret", nil)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 405 {
t.Errorf("expected 405, got %d", resp.StatusCode)
}
}
func TestImport_Unauthorized(t *testing.T) {
ts := newTestServer(t, "secret", false)
body, _ := json.Marshal(models.Backup{})
resp, err := authReq("POST", ts.baseURL+"/api/backup/import", "wrong", body)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 401 {
t.Errorf("expected 401, got %d", resp.StatusCode)
}
}
func TestImport_Success(t *testing.T) {
ts := newTestServer(t, "secret", false)
backup := models.Backup{
Sites: []models.Site{{Name: "imported", URL: "http://example.com"}},
}
body, _ := json.Marshal(backup)
resp, err := authReq("POST", ts.baseURL+"/api/backup/import", "secret", body)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("expected 200, got %d", resp.StatusCode)
}
ts.store.mu.Lock()
defer ts.store.mu.Unlock()
if ts.store.importedData == nil {
t.Error("expected import data to be stored")
}
}
func TestImport_InvalidJSON(t *testing.T) {
ts := newTestServer(t, "secret", false)
resp, err := authReq("POST", ts.baseURL+"/api/backup/import", "secret", []byte("not json"))
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 400 {
t.Errorf("expected 400, got %d", resp.StatusCode)
}
}
// --- Probe Registration ---
func TestProbeRegister_Success(t *testing.T) {
ts := newTestServer(t, "secret", false)
body, _ := json.Marshal(map[string]string{
"id": "node-1", "name": "US East", "region": "us-east",
})
resp, err := authReq("POST", ts.baseURL+"/api/probe/register", "secret", body)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("expected 200, got %d", resp.StatusCode)
}
ts.store.mu.Lock()
defer ts.store.mu.Unlock()
if len(ts.store.registeredNodes) != 1 {
t.Errorf("expected 1 registered node, got %d", len(ts.store.registeredNodes))
}
if ts.store.registeredNodes[0].ID != "node-1" {
t.Errorf("expected node-1, got %s", ts.store.registeredNodes[0].ID)
}
}
func TestProbeRegister_MissingID(t *testing.T) {
ts := newTestServer(t, "secret", false)
body, _ := json.Marshal(map[string]string{"name": "test"})
resp, err := authReq("POST", ts.baseURL+"/api/probe/register", "secret", body)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 400 {
t.Errorf("expected 400, got %d", resp.StatusCode)
}
}
func TestProbeRegister_Unauthorized(t *testing.T) {
ts := newTestServer(t, "secret", false)
body, _ := json.Marshal(map[string]string{"id": "node-1"})
resp, err := authReq("POST", ts.baseURL+"/api/probe/register", "wrong", body)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 401 {
t.Errorf("expected 401, got %d", resp.StatusCode)
}
}
// --- Probe Results ---
func TestProbeResults_Success(t *testing.T) {
ts := newTestServer(t, "secret", false)
body, _ := json.Marshal(map[string]any{
"node_id": "node-1",
"results": []map[string]any{
{"site_id": 1, "latency_ns": 5000000, "is_up": true},
},
})
resp, err := authReq("POST", ts.baseURL+"/api/probe/results", "secret", body)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("expected 200, got %d", resp.StatusCode)
}
}
func TestProbeResults_MissingNodeID(t *testing.T) {
ts := newTestServer(t, "secret", false)
body, _ := json.Marshal(map[string]any{
"results": []map[string]any{},
})
resp, err := authReq("POST", ts.baseURL+"/api/probe/results", "secret", body)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 400 {
t.Errorf("expected 400, got %d", resp.StatusCode)
}
}
// --- Status Page ---
func TestStatusPage_Enabled(t *testing.T) {
ts := newTestServer(t, "secret", true)
resp, err := http.Get(ts.baseURL + "/status")
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("expected 200, got %d", resp.StatusCode)
}
}
func TestStatusJSON_TokensStripped(t *testing.T) {
ts := newTestServer(t, "secret", true)
// Inject a site with a token into engine state
ts.engine.UpdateSiteConfig(models.Site{ID: 1, Name: "test", Type: "push", Token: "secret-token", Status: "UP"})
// Need to inject directly since UpdateSiteConfig only updates existing
func() {
ts.engine.RecordHeartbeat("unused") // just to exercise, won't match
}()
resp, err := http.Get(ts.baseURL + "/status/json")
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("expected 200, got %d", resp.StatusCode)
}
var state map[string]models.Site
json.NewDecoder(resp.Body).Decode(&state)
for _, site := range state {
if site.Token != "" {
t.Error("expected token stripped from status JSON response")
}
}
}
func TestStatusJSON_MaintenanceOverride(t *testing.T) {
ts := newTestServer(t, "secret", true)
ts.store.maintWindows = []models.MaintenanceWindow{
{ID: 1, MonitorID: 0, Type: "maintenance", StartTime: time.Now().Add(-1 * time.Hour)},
}
resp, err := http.Get(ts.baseURL + "/status/json")
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("expected 200, got %d", resp.StatusCode)
}
}
func TestStatusPage_Disabled(t *testing.T) {
ts := newTestServer(t, "secret", false)
resp, err := http.Get(ts.baseURL + "/status")
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 404 {
t.Errorf("expected 404 when status disabled, got %d", resp.StatusCode)
}
}
// --- Probe Assignments ---
func TestProbeAssignments_Success(t *testing.T) {
ts := newTestServer(t, "secret", false)
resp, err := authReq("GET", ts.baseURL+"/api/probe/assignments", "secret", nil)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("expected 200, got %d", resp.StatusCode)
}
var result map[string][]models.Site
json.NewDecoder(resp.Body).Decode(&result)
if _, ok := result["sites"]; !ok {
t.Error("expected 'sites' key in response")
}
}
func TestProbeAssignments_Unauthorized(t *testing.T) {
ts := newTestServer(t, "secret", false)
resp, err := authReq("GET", ts.baseURL+"/api/probe/assignments", "wrong", nil)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 401 {
t.Errorf("expected 401, got %d", resp.StatusCode)
}
}
+26 -8
View File
@@ -29,12 +29,16 @@ func (s *SQLStore) q(query string) string {
return rewritePlaceholders(query, s.dollar) return rewritePlaceholders(query, s.dollar)
} }
func generateToken() string { func generateToken() (string, error) {
b := make([]byte, 16) b := make([]byte, 16)
if _, err := rand.Read(b); err != nil { if _, err := rand.Read(b); err != nil {
panic("crypto/rand failed: " + err.Error()) return "", fmt.Errorf("crypto/rand failed: %w", err)
} }
return hex.EncodeToString(b) return hex.EncodeToString(b), nil
}
func (s *SQLStore) Close() error {
return s.db.Close()
} }
func (s *SQLStore) Init() error { func (s *SQLStore) Init() error {
@@ -77,7 +81,11 @@ func (s *SQLStore) GetSites() ([]models.Site, error) {
func (s *SQLStore) AddSite(site models.Site) error { func (s *SQLStore) AddSite(site models.Site) error {
token := "" token := ""
if site.Type == "push" { if site.Type == "push" {
token = generateToken() var err error
token, err = generateToken()
if err != nil {
return fmt.Errorf("generate push token: %w", err)
}
} }
_, err := s.db.Exec(s.q("INSERT INTO sites (name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused, regions) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"), _, err := s.db.Exec(s.q("INSERT INTO sites (name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused, regions) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"),
site.Name, site.URL, site.Type, token, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries, site.Name, site.URL, site.Type, token, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries,
@@ -89,7 +97,11 @@ func (s *SQLStore) UpdateSite(site models.Site) error {
var existingToken string var existingToken string
s.db.QueryRow(s.q("SELECT token FROM sites WHERE id=?"), site.ID).Scan(&existingToken) s.db.QueryRow(s.q("SELECT token FROM sites WHERE id=?"), site.ID).Scan(&existingToken)
if site.Type == "push" && existingToken == "" { if site.Type == "push" && existingToken == "" {
existingToken = generateToken() var err error
existingToken, err = generateToken()
if err != nil {
return fmt.Errorf("generate push token: %w", err)
}
} }
_, err := s.db.Exec(s.q("UPDATE sites SET name=?, url=?, type=?, token=?, interval=?, alert_id=?, check_ssl=?, threshold=?, max_retries=?, hostname=?, port=?, timeout=?, method=?, description=?, parent_id=?, accepted_codes=?, dns_resolve_type=?, dns_server=?, ignore_tls=?, paused=?, regions=? WHERE id=?"), _, err := s.db.Exec(s.q("UPDATE sites SET name=?, url=?, type=?, token=?, interval=?, alert_id=?, check_ssl=?, threshold=?, max_retries=?, hostname=?, port=?, timeout=?, method=?, description=?, parent_id=?, accepted_codes=?, dns_resolve_type=?, dns_server=?, ignore_tls=?, paused=?, regions=? WHERE id=?"),
site.Name, site.URL, site.Type, existingToken, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries, site.Name, site.URL, site.Type, existingToken, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries,
@@ -132,7 +144,9 @@ func (s *SQLStore) GetAlertByName(name string) (models.AlertConfig, error) {
if err != nil { if err != nil {
return a, err return a, err
} }
json.Unmarshal([]byte(settingsJSON), &a.Settings) if err := json.Unmarshal([]byte(settingsJSON), &a.Settings); err != nil {
return a, fmt.Errorf("unmarshal alert settings: %w", err)
}
return a, nil return a, nil
} }
@@ -171,7 +185,9 @@ func (s *SQLStore) GetAllAlerts() ([]models.AlertConfig, error) {
if err := rows.Scan(&a.ID, &a.Name, &a.Type, &settingsJSON); err != nil { if err := rows.Scan(&a.ID, &a.Name, &a.Type, &settingsJSON); err != nil {
return alerts, err return alerts, err
} }
json.Unmarshal([]byte(settingsJSON), &a.Settings) if err := json.Unmarshal([]byte(settingsJSON), &a.Settings); err != nil {
return alerts, fmt.Errorf("unmarshal alert settings for %q: %w", a.Name, err)
}
alerts = append(alerts, a) alerts = append(alerts, a)
} }
return alerts, rows.Err() return alerts, rows.Err()
@@ -184,7 +200,9 @@ func (s *SQLStore) GetAlert(id int) (models.AlertConfig, error) {
if err != nil { if err != nil {
return a, err return a, err
} }
json.Unmarshal([]byte(settingsJSON), &a.Settings) if err := json.Unmarshal([]byte(settingsJSON), &a.Settings); err != nil {
return a, fmt.Errorf("unmarshal alert settings: %w", err)
}
return a, nil return a, nil
} }
+3
View File
@@ -64,4 +64,7 @@ type Store interface {
// Backup & Restore // Backup & Restore
ExportData() (models.Backup, error) ExportData() (models.Backup, error)
ImportData(data models.Backup) error ImportData(data models.Backup) error
// Lifecycle
Close() error
} }