Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
7bf278e538
|
|||
|
023234f4c3
|
|||
|
4328d25f22
|
|||
|
f745dcb21f
|
@@ -16,6 +16,11 @@ A follower is a standby replica that takes over if the leader goes down.
|
|||||||
- When the leader recovers, the follower detects it and goes back to standby
|
- When the leader recovers, the follower detects it and goes back to standby
|
||||||
- Both nodes have their own database — they do not share state
|
- Both nodes have their own database — they do not share state
|
||||||
|
|
||||||
|
**Limitations:**
|
||||||
|
- During a network partition where both nodes are healthy, both will run checks and fire alerts independently. There is no leader fencing — the follower has no way to confirm the leader is actually down vs. unreachable from its perspective. This window lasts until the partition heals, at which point the follower detects the leader and steps down.
|
||||||
|
- Expect duplicate alerts and doubled check history entries during a split-brain event. Alerts are idempotent for most providers (a second "site is down" notification is noisy but not harmful).
|
||||||
|
- Failover takeover time is ~15 seconds (3 missed polls × 5 second interval). This is not configurable.
|
||||||
|
|
||||||
**Required env vars:**
|
**Required env vars:**
|
||||||
|
|
||||||
| Node | Variable | Value |
|
| Node | Variable | Value |
|
||||||
|
|||||||
+63
-2
@@ -3,9 +3,11 @@ package alert
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/smtp"
|
"net/smtp"
|
||||||
"net/url"
|
"net/url"
|
||||||
@@ -244,7 +246,6 @@ func (e *EmailProvider) Send(ctx context.Context, title, message string) error {
|
|||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
auth := smtp.PlainAuth("", e.User, e.Pass, e.Host)
|
|
||||||
to := sanitizeHeader(e.To)
|
to := sanitizeHeader(e.To)
|
||||||
from := sanitizeHeader(e.From)
|
from := sanitizeHeader(e.From)
|
||||||
subject := sanitizeHeader(title)
|
subject := sanitizeHeader(title)
|
||||||
@@ -256,7 +257,67 @@ func (e *EmailProvider) Send(ctx context.Context, title, message string) error {
|
|||||||
"Content-Type: text/plain; charset=utf-8\r\n" +
|
"Content-Type: text/plain; charset=utf-8\r\n" +
|
||||||
"\r\n" +
|
"\r\n" +
|
||||||
body + "\r\n")
|
body + "\r\n")
|
||||||
return smtp.SendMail(e.Host+":"+e.Port, auth, from, []string{to}, msg)
|
return sendMailContext(ctx, e.Host, e.Port, e.User, e.Pass, from, []string{to}, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendMailContext is a ctx-aware replacement for smtp.SendMail.
|
||||||
|
// smtp.SendMail ignores context entirely — a blackholed SMTP server hangs for
|
||||||
|
// the OS TCP timeout (minutes). This dials with the context deadline and sets
|
||||||
|
// connection deadlines so cancellation is respected throughout.
|
||||||
|
func sendMailContext(ctx context.Context, host, port, user, pass, from string, rcpt []string, msg []byte) error {
|
||||||
|
addr := host + ":" + port
|
||||||
|
|
||||||
|
dialer := net.Dialer{}
|
||||||
|
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("smtp dial: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
|
_ = conn.SetDeadline(deadline)
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := smtp.NewClient(conn, host)
|
||||||
|
if err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
return fmt.Errorf("smtp client: %w", err)
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
|
||||||
|
if ok, _ := c.Extension("STARTTLS"); ok {
|
||||||
|
if err := c.StartTLS(&tls.Config{ServerName: host}); err != nil {
|
||||||
|
return fmt.Errorf("smtp starttls: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if user != "" || pass != "" {
|
||||||
|
auth := smtp.PlainAuth("", user, pass, host)
|
||||||
|
if err := c.Auth(auth); err != nil {
|
||||||
|
return fmt.Errorf("smtp auth: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.Mail(from); err != nil {
|
||||||
|
return fmt.Errorf("smtp mail: %w", err)
|
||||||
|
}
|
||||||
|
for _, r := range rcpt {
|
||||||
|
if err := c.Rcpt(r); err != nil {
|
||||||
|
return fmt.Errorf("smtp rcpt: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
w, err := c.Data()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("smtp data: %w", err)
|
||||||
|
}
|
||||||
|
if _, err := w.Write(msg); err != nil {
|
||||||
|
return fmt.Errorf("smtp write: %w", err)
|
||||||
|
}
|
||||||
|
if err := w.Close(); err != nil {
|
||||||
|
return fmt.Errorf("smtp data close: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.Quit()
|
||||||
}
|
}
|
||||||
|
|
||||||
type NtfyProvider struct {
|
type NtfyProvider struct {
|
||||||
|
|||||||
@@ -1,14 +1,18 @@
|
|||||||
package alert
|
package alert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"gitea.lerkolabs.com/lerkolabs/uptop/internal/models"
|
"gitea.lerkolabs.com/lerkolabs/uptop/internal/models"
|
||||||
)
|
)
|
||||||
@@ -330,3 +334,116 @@ func TestSanitizeError(t *testing.T) {
|
|||||||
t.Error("nil should stay nil")
|
t.Error("nil should stay nil")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEmailProvider_ContextTimeout(t *testing.T) {
|
||||||
|
// Listener that accepts but never speaks — simulates a blackholed SMTP server.
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Hold connection open, never send banner.
|
||||||
|
go func(c net.Conn) {
|
||||||
|
time.Sleep(30 * time.Second)
|
||||||
|
c.Close()
|
||||||
|
}(conn)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, portStr, _ := net.SplitHostPort(ln.Addr().String())
|
||||||
|
provider := &EmailProvider{
|
||||||
|
Host: "127.0.0.1", Port: portStr,
|
||||||
|
From: "test@test.com", To: "dest@test.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
err = provider.Send(ctx, "test", "body")
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error from stalled SMTP")
|
||||||
|
}
|
||||||
|
if elapsed > 2*time.Second {
|
||||||
|
t.Errorf("Send took %v — context deadline not respected", elapsed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSendMailContext_HappyPath(t *testing.T) {
|
||||||
|
// Minimal fake SMTP server that accepts one message.
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
received := make(chan string, 1)
|
||||||
|
go func() {
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
fmt.Fprintf(conn, "220 localhost ESMTP\r\n")
|
||||||
|
scanner := bufio.NewScanner(conn)
|
||||||
|
var dataMode bool
|
||||||
|
var body strings.Builder
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
if dataMode {
|
||||||
|
if line == "." {
|
||||||
|
dataMode = false
|
||||||
|
fmt.Fprintf(conn, "250 OK\r\n")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
body.WriteString(line + "\n")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(line, "EHLO"):
|
||||||
|
fmt.Fprintf(conn, "250-localhost\r\n250 OK\r\n")
|
||||||
|
case strings.HasPrefix(line, "MAIL FROM"):
|
||||||
|
fmt.Fprintf(conn, "250 OK\r\n")
|
||||||
|
case strings.HasPrefix(line, "RCPT TO"):
|
||||||
|
fmt.Fprintf(conn, "250 OK\r\n")
|
||||||
|
case line == "DATA":
|
||||||
|
fmt.Fprintf(conn, "354 Go ahead\r\n")
|
||||||
|
dataMode = true
|
||||||
|
case line == "QUIT":
|
||||||
|
fmt.Fprintf(conn, "221 Bye\r\n")
|
||||||
|
received <- body.String()
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
fmt.Fprintf(conn, "250 OK\r\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, portStr, _ := net.SplitHostPort(ln.Addr().String())
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err = sendMailContext(ctx, "127.0.0.1", portStr, "", "", "from@test.com", []string{"to@test.com"}, []byte("Subject: test\r\n\r\nhello"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("sendMailContext: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case body := <-received:
|
||||||
|
if !strings.Contains(body, "hello") {
|
||||||
|
t.Errorf("expected body to contain 'hello', got: %s", body)
|
||||||
|
}
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for fake SMTP to receive message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -37,19 +37,26 @@ type CheckResult struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func RunCheck(ctx context.Context, site models.SiteConfig, strict, insecure *http.Client, globalInsecure, allowPrivate bool) CheckResult {
|
func RunCheck(ctx context.Context, site models.SiteConfig, strict, insecure *http.Client, globalInsecure, allowPrivate bool) CheckResult {
|
||||||
|
// Resolve + validate once for non-HTTP types to prevent DNS-rebind TOCTOU:
|
||||||
|
// a second resolve in the check function could return a different (private) IP.
|
||||||
|
// HTTP is safe — SafeDialContext resolves and validates at dial time.
|
||||||
|
var pinnedIP net.IP
|
||||||
if site.Type != "http" && site.Type != "dns" && !allowPrivate {
|
if site.Type != "http" && site.Type != "dns" && !allowPrivate {
|
||||||
host := site.Hostname
|
host := site.Hostname
|
||||||
if host == "" {
|
if host == "" {
|
||||||
host = site.URL
|
host = site.URL
|
||||||
}
|
}
|
||||||
if host != "" {
|
if host != "" {
|
||||||
if ips, err := net.LookupIP(host); err == nil {
|
ips, err := net.LookupIP(host)
|
||||||
for _, ip := range ips {
|
if err != nil {
|
||||||
if isPrivateIP(ip) {
|
return CheckResult{SiteID: site.ID, Status: string(models.StatusDown), ErrorReason: "resolve failed: " + err.Error()}
|
||||||
return CheckResult{SiteID: site.ID, Status: string(models.StatusDown), ErrorReason: "target resolves to private IP"}
|
}
|
||||||
}
|
for _, ip := range ips {
|
||||||
|
if isPrivateIP(ip) {
|
||||||
|
return CheckResult{SiteID: site.ID, Status: string(models.StatusDown), ErrorReason: "target resolves to private IP"}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
pinnedIP = ips[0]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -57,9 +64,9 @@ func RunCheck(ctx context.Context, site models.SiteConfig, strict, insecure *htt
|
|||||||
case "http":
|
case "http":
|
||||||
return runHTTPCheck(ctx, site, strict, insecure, globalInsecure)
|
return runHTTPCheck(ctx, site, strict, insecure, globalInsecure)
|
||||||
case "ping":
|
case "ping":
|
||||||
return runPingCheck(ctx, site)
|
return runPingCheck(ctx, site, pinnedIP)
|
||||||
case "port":
|
case "port":
|
||||||
return runPortCheck(ctx, site)
|
return runPortCheck(ctx, site, pinnedIP)
|
||||||
case "dns":
|
case "dns":
|
||||||
return runDNSCheck(ctx, site, allowPrivate)
|
return runDNSCheck(ctx, site, allowPrivate)
|
||||||
default:
|
default:
|
||||||
@@ -130,7 +137,7 @@ func runHTTPCheck(ctx context.Context, site models.SiteConfig, strict, insecure
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
func runPingCheck(_ context.Context, site models.SiteConfig) CheckResult {
|
func runPingCheck(_ context.Context, site models.SiteConfig, pinnedIP net.IP) CheckResult {
|
||||||
host := site.Hostname
|
host := site.Hostname
|
||||||
if host == "" {
|
if host == "" {
|
||||||
host = site.URL
|
host = site.URL
|
||||||
@@ -140,6 +147,9 @@ func runPingCheck(_ context.Context, site models.SiteConfig) CheckResult {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return CheckResult{SiteID: site.ID, Status: string(models.StatusDown), ErrorReason: "ping setup: " + err.Error()}
|
return CheckResult{SiteID: site.ID, Status: string(models.StatusDown), ErrorReason: "ping setup: " + err.Error()}
|
||||||
}
|
}
|
||||||
|
if pinnedIP != nil {
|
||||||
|
pinger.SetIPAddr(&net.IPAddr{IP: pinnedIP})
|
||||||
|
}
|
||||||
pinger.Count = 1
|
pinger.Count = 1
|
||||||
pinger.Timeout = siteTimeout(site)
|
pinger.Timeout = siteTimeout(site)
|
||||||
pinger.SetPrivileged(false)
|
pinger.SetPrivileged(false)
|
||||||
@@ -159,11 +169,14 @@ func runPingCheck(_ context.Context, site models.SiteConfig) CheckResult {
|
|||||||
return CheckResult{SiteID: site.ID, Status: string(models.StatusUp), LatencyNs: stats.AvgRtt.Nanoseconds()}
|
return CheckResult{SiteID: site.ID, Status: string(models.StatusUp), LatencyNs: stats.AvgRtt.Nanoseconds()}
|
||||||
}
|
}
|
||||||
|
|
||||||
func runPortCheck(_ context.Context, site models.SiteConfig) CheckResult {
|
func runPortCheck(_ context.Context, site models.SiteConfig, pinnedIP net.IP) CheckResult {
|
||||||
host := site.Hostname
|
host := site.Hostname
|
||||||
if host == "" {
|
if host == "" {
|
||||||
host = site.URL
|
host = site.URL
|
||||||
}
|
}
|
||||||
|
if pinnedIP != nil {
|
||||||
|
host = pinnedIP.String()
|
||||||
|
}
|
||||||
addr := net.JoinHostPort(host, strconv.Itoa(site.Port))
|
addr := net.JoinHostPort(host, strconv.Itoa(site.Port))
|
||||||
timeout := siteTimeout(site)
|
timeout := siteTimeout(site)
|
||||||
|
|
||||||
|
|||||||
@@ -161,6 +161,43 @@ func TestRunCheck_Port_Closed(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRunPortCheck_UsesPinnedIP(t *testing.T) {
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
_, portStr, _ := net.SplitHostPort(ln.Addr().String())
|
||||||
|
port, _ := strconv.Atoi(portStr)
|
||||||
|
|
||||||
|
// Pass a pinned IP — runPortCheck should dial it instead of resolving Hostname.
|
||||||
|
site := models.SiteConfig{ID: 1, Type: "port", Hostname: "will-not-resolve.invalid", Port: port, Timeout: 2}
|
||||||
|
result := runPortCheck(context.Background(), site, net.ParseIP("127.0.0.1"))
|
||||||
|
|
||||||
|
if result.Status != "UP" {
|
||||||
|
t.Errorf("expected UP when pinned IP used, got %s: %s", result.Status, result.ErrorReason)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunPortCheck_NilPinnedIP_UsesHostname(t *testing.T) {
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
_, portStr, _ := net.SplitHostPort(ln.Addr().String())
|
||||||
|
port, _ := strconv.Atoi(portStr)
|
||||||
|
|
||||||
|
site := models.SiteConfig{ID: 1, Type: "port", Hostname: "127.0.0.1", Port: port, Timeout: 2}
|
||||||
|
result := runPortCheck(context.Background(), site, nil)
|
||||||
|
|
||||||
|
if result.Status != "UP" {
|
||||||
|
t.Errorf("expected UP with nil pinnedIP fallback, got %s: %s", result.Status, result.ErrorReason)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestRunCheck_Port_BlocksPrivateByDefault(t *testing.T) {
|
func TestRunCheck_Port_BlocksPrivateByDefault(t *testing.T) {
|
||||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -205,12 +205,15 @@ func (s *Server) handleImport(w http.ResponseWriter, r *http.Request) {
|
|||||||
http.Error(w, "Invalid JSON", http.StatusBadRequest)
|
http.Error(w, "Invalid JSON", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// API import never modifies users — cluster-secret holder shouldn't be
|
||||||
|
// able to replace admin accounts. CLI restore still does full import.
|
||||||
|
data.Users = nil
|
||||||
if err := s.store.ImportData(r.Context(), data); err != nil {
|
if err := s.store.ImportData(r.Context(), data); err != nil {
|
||||||
slog.Error("import failed", "err", err)
|
slog.Error("import failed", "err", err)
|
||||||
http.Error(w, "Import failed", http.StatusInternalServerError)
|
http.Error(w, "Import failed", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
_, _ = w.Write([]byte("Import Successful"))
|
_, _ = w.Write([]byte("Import Successful (users excluded — manage via CLI or UPTOP_KEYS)"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleKumaImport(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleKumaImport(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ type Dialect interface {
|
|||||||
BoolFalse() string
|
BoolFalse() string
|
||||||
ResetSequenceOnEmpty(db *sql.DB, table string)
|
ResetSequenceOnEmpty(db *sql.DB, table string)
|
||||||
ImportWipe(tx *sql.Tx)
|
ImportWipe(tx *sql.Tx)
|
||||||
|
ImportWipeUsers(tx *sql.Tx)
|
||||||
ImportResetSequences(tx *sql.Tx)
|
ImportResetSequences(tx *sql.Tx)
|
||||||
UpsertNodeSQL() string
|
UpsertNodeSQL() string
|
||||||
UpsertAlertHealthSQL() string
|
UpsertAlertHealthSQL() string
|
||||||
|
|||||||
@@ -138,9 +138,6 @@ func (d *PostgresDialect) ImportWipe(tx *sql.Tx) {
|
|||||||
if _, err := tx.Exec("TRUNCATE TABLE alerts RESTART IDENTITY CASCADE"); err != nil {
|
if _, err := tx.Exec("TRUNCATE TABLE alerts RESTART IDENTITY CASCADE"); err != nil {
|
||||||
slog.Debug("import wipe failed", "table", "alerts", "err", err)
|
slog.Debug("import wipe failed", "table", "alerts", "err", err)
|
||||||
}
|
}
|
||||||
if _, err := tx.Exec("TRUNCATE TABLE users RESTART IDENTITY CASCADE"); err != nil {
|
|
||||||
slog.Debug("import wipe failed", "table", "users", "err", err)
|
|
||||||
}
|
|
||||||
if _, err := tx.Exec("TRUNCATE TABLE maintenance_windows RESTART IDENTITY CASCADE"); err != nil {
|
if _, err := tx.Exec("TRUNCATE TABLE maintenance_windows RESTART IDENTITY CASCADE"); err != nil {
|
||||||
slog.Debug("import wipe failed", "table", "maintenance_windows", "err", err)
|
slog.Debug("import wipe failed", "table", "maintenance_windows", "err", err)
|
||||||
}
|
}
|
||||||
@@ -155,6 +152,12 @@ func (d *PostgresDialect) ImportWipe(tx *sql.Tx) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *PostgresDialect) ImportWipeUsers(tx *sql.Tx) {
|
||||||
|
if _, err := tx.Exec("TRUNCATE TABLE users RESTART IDENTITY CASCADE"); err != nil {
|
||||||
|
slog.Debug("import wipe failed", "table", "users", "err", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (d *PostgresDialect) ImportResetSequences(tx *sql.Tx) {
|
func (d *PostgresDialect) ImportResetSequences(tx *sql.Tx) {
|
||||||
if _, err := tx.Exec("SELECT setval('sites_id_seq', (SELECT COALESCE(MAX(id), 1) FROM sites))"); err != nil {
|
if _, err := tx.Exec("SELECT setval('sites_id_seq', (SELECT COALESCE(MAX(id), 1) FROM sites))"); err != nil {
|
||||||
slog.Debug("sequence reset failed", "table", "sites", "err", err)
|
slog.Debug("sequence reset failed", "table", "sites", "err", err)
|
||||||
|
|||||||
@@ -167,12 +167,6 @@ func (d *SQLiteDialect) ImportWipe(tx *sql.Tx) {
|
|||||||
if _, err := tx.Exec("DELETE FROM sqlite_sequence WHERE name='alerts'"); err != nil {
|
if _, err := tx.Exec("DELETE FROM sqlite_sequence WHERE name='alerts'"); err != nil {
|
||||||
slog.Debug("import wipe failed", "table", "sqlite_sequence(alerts)", "err", err)
|
slog.Debug("import wipe failed", "table", "sqlite_sequence(alerts)", "err", err)
|
||||||
}
|
}
|
||||||
if _, err := tx.Exec("DELETE FROM users"); err != nil {
|
|
||||||
slog.Debug("import wipe failed", "table", "users", "err", err)
|
|
||||||
}
|
|
||||||
if _, err := tx.Exec("DELETE FROM sqlite_sequence WHERE name='users'"); err != nil {
|
|
||||||
slog.Debug("import wipe failed", "table", "sqlite_sequence(users)", "err", err)
|
|
||||||
}
|
|
||||||
if _, err := tx.Exec("DELETE FROM maintenance_windows"); err != nil {
|
if _, err := tx.Exec("DELETE FROM maintenance_windows"); err != nil {
|
||||||
slog.Debug("import wipe failed", "table", "maintenance_windows", "err", err)
|
slog.Debug("import wipe failed", "table", "maintenance_windows", "err", err)
|
||||||
}
|
}
|
||||||
@@ -190,4 +184,13 @@ func (d *SQLiteDialect) ImportWipe(tx *sql.Tx) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *SQLiteDialect) ImportWipeUsers(tx *sql.Tx) {
|
||||||
|
if _, err := tx.Exec("DELETE FROM users"); err != nil {
|
||||||
|
slog.Debug("import wipe failed", "table", "users", "err", err)
|
||||||
|
}
|
||||||
|
if _, err := tx.Exec("DELETE FROM sqlite_sequence WHERE name='users'"); err != nil {
|
||||||
|
slog.Debug("import wipe failed", "table", "sqlite_sequence(users)", "err", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (d *SQLiteDialect) ImportResetSequences(tx *sql.Tx) {}
|
func (d *SQLiteDialect) ImportResetSequences(tx *sql.Tx) {}
|
||||||
|
|||||||
@@ -742,9 +742,14 @@ func (s *SQLStore) ImportData(ctx context.Context, data models.Backup) error {
|
|||||||
|
|
||||||
s.dialect.ImportWipe(tx)
|
s.dialect.ImportWipe(tx)
|
||||||
|
|
||||||
for _, u := range data.Users {
|
// Only wipe+replace users when callers explicitly provide them (CLI
|
||||||
if _, err := tx.ExecContext(ctx, s.q("INSERT INTO users (username, public_key, role) VALUES (?, ?, ?)"), u.Username, u.PublicKey, u.Role); err != nil {
|
// full restore). API/Kuma imports pass nil — existing users preserved.
|
||||||
return err
|
if data.Users != nil {
|
||||||
|
s.dialect.ImportWipeUsers(tx)
|
||||||
|
for _, u := range data.Users {
|
||||||
|
if _, err := tx.ExecContext(ctx, s.q("INSERT INTO users (username, public_key, role) VALUES (?, ?, ?)"), u.Username, u.PublicKey, u.Role); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, a := range data.Alerts {
|
for _, a := range data.Alerts {
|
||||||
|
|||||||
@@ -276,6 +276,31 @@ func TestImportData_WipesHistory(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestImportData_NilUsersPreservesExisting(t *testing.T) {
|
||||||
|
s := newTestStore(t)
|
||||||
|
|
||||||
|
if err := s.AddUser(context.Background(), "admin", "ssh-ed25519 ADMINKEY", "admin"); err != nil {
|
||||||
|
t.Fatalf("AddUser: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
backup := models.Backup{
|
||||||
|
Sites: []models.SiteConfig{{ID: 1, Name: "New", URL: "https://new.com", Type: "http", Interval: 30}},
|
||||||
|
Alerts: []models.AlertConfig{{ID: 1, Name: "a", Type: "webhook", Settings: map[string]string{"url": "https://h.com"}}},
|
||||||
|
Users: nil,
|
||||||
|
}
|
||||||
|
if err := s.ImportData(context.Background(), backup); err != nil {
|
||||||
|
t.Fatalf("ImportData: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
users, err := s.GetAllUsers(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetAllUsers: %v", err)
|
||||||
|
}
|
||||||
|
if len(users) != 1 || users[0].Username != "admin" {
|
||||||
|
t.Errorf("expected existing admin user preserved, got %d users", len(users))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCheckHistory(t *testing.T) {
|
func TestCheckHistory(t *testing.T) {
|
||||||
s := newTestStore(t)
|
s := newTestStore(t)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user