diff --git a/internal/monitor/monitor.go b/internal/monitor/monitor.go index 4702b9e..6a04a2f 100644 --- a/internal/monitor/monitor.go +++ b/internal/monitor/monitor.go @@ -230,47 +230,53 @@ func (e *Engine) RecordHeartbeat(token string) bool { return false } - e.mu.Lock() - defer e.mu.Unlock() - + e.mu.RLock() targetID, ok := e.tokenIndex[token] + e.mu.RUnlock() if !ok { return false } - site, exists := e.liveState[targetID] + var ( + prevStatus string + name string + alertID int + downSince time.Time + ) + _, exists := e.applyState(targetID, func(s *models.Site) { + prevStatus = s.Status + name = s.Name + alertID = s.AlertID + downSince = s.StatusChangedAt // captured before mutation = when it went down + + s.LastCheck = time.Now() + s.Status = "UP" + s.FailureCount = 0 + s.Latency = 0 + s.LastError = "" + s.LastSuccessAt = time.Now() + if prevStatus != "UP" { + s.StatusChangedAt = time.Now() + } + }) if !exists { return false } - prevStatus := site.Status - site.LastCheck = time.Now() - site.Status = "UP" - site.FailureCount = 0 - site.Latency = 0 - site.LastError = "" - site.LastSuccessAt = time.Now() - - if prevStatus != "UP" { - site.StatusChangedAt = time.Now() - } - - e.liveState[targetID] = site - switch prevStatus { case "PENDING": - e.AddLog(fmt.Sprintf("Push Monitor '%s' received first heartbeat", site.Name)) + e.AddLog(fmt.Sprintf("Push Monitor '%s' received first heartbeat", name)) case "LATE": - e.AddLog(fmt.Sprintf("Push Monitor '%s' heartbeat arrived (was late)", site.Name)) + e.AddLog(fmt.Sprintf("Push Monitor '%s' heartbeat arrived (was late)", name)) case "STALE": - e.AddLog(fmt.Sprintf("Push Monitor '%s' heartbeat arrived (was stale)", site.Name)) + e.AddLog(fmt.Sprintf("Push Monitor '%s' heartbeat arrived (was stale)", name)) case "DOWN": downDur := "" - if !site.StatusChangedAt.IsZero() { - downDur = fmt.Sprintf(" (was down %s)", fmtDurationShort(time.Since(site.StatusChangedAt))) + if !downSince.IsZero() { + downDur = fmt.Sprintf(" (was down %s)", fmtDurationShort(time.Since(downSince))) } - e.AddLog(fmt.Sprintf("Push Monitor '%s' recovered%s", site.Name, downDur)) - go e.triggerAlert(site.AlertID, "✅ RECOVERY", fmt.Sprintf("Push Monitor '%s' is receiving heartbeats.%s", site.Name, downDur)) + e.AddLog(fmt.Sprintf("Push Monitor '%s' recovered%s", name, downDur)) + go e.triggerAlert(alertID, "✅ RECOVERY", fmt.Sprintf("Push Monitor '%s' is receiving heartbeats.%s", name, downDur)) } if prevStatus != "UP" && prevStatus != "PENDING" { @@ -431,20 +437,24 @@ func (e *Engine) RemoveSite(id int) { } func (e *Engine) ToggleSitePause(id int) bool { - e.mu.Lock() - defer e.mu.Unlock() - site, ok := e.liveState[id] + var ( + paused bool + name string + ) + _, ok := e.applyState(id, func(s *models.Site) { + s.Paused = !s.Paused + paused = s.Paused + name = s.Name + }) if !ok { return false } - site.Paused = !site.Paused - e.liveState[id] = site - if site.Paused { - e.AddLog(fmt.Sprintf("Monitor '%s' paused", site.Name)) + if paused { + e.AddLog(fmt.Sprintf("Monitor '%s' paused", name)) } else { - e.AddLog(fmt.Sprintf("Monitor '%s' resumed", site.Name)) + e.AddLog(fmt.Sprintf("Monitor '%s' resumed", name)) } - return site.Paused + return paused } func (e *Engine) monitorRoutine(ctx context.Context, id int) { @@ -508,6 +518,25 @@ func (e *Engine) monitorRoutine(ctx context.Context, id int) { } } +// applyState atomically reads, mutates, and writes back the live entry for id. +// The mutator runs under the engine write lock and receives a pointer to the +// CURRENT live state, so concurrent config edits, pauses, and heartbeats are +// never clobbered by a stale snapshot. The mutator must only touch runtime / +// check-result fields — config fields (Name/URL/Type/Token/Interval/AlertID/…) +// are owned by UpdateSiteConfig and must not be written here. Returns the +// post-mutation copy and whether the site still exists. +func (e *Engine) applyState(id int, mutate func(s *models.Site)) (models.Site, bool) { + e.mu.Lock() + defer e.mu.Unlock() + cur, ok := e.liveState[id] + if !ok { + return models.Site{}, false + } + mutate(&cur) + e.liveState[id] = cur + return cur, true +} + func (e *Engine) checkByID(id int) { if !e.IsActive() { return @@ -567,111 +596,161 @@ func (e *Engine) checkPush(site models.Site) { } } -func (e *Engine) handleStatusChange(site models.Site, rawStatus string, code int, latency time.Duration, errorReason string) { +// handleStatusChange folds a check result into the live state. snap is the +// stale snapshot the check ran against; the actual mutation is applied onto the +// CURRENT live entry via applyState, so a concurrent pause / config edit / +// heartbeat is never reverted by this write. Logs and alerts are emitted after +// the lock is released, off the critical section. +func (e *Engine) handleStatusChange(snap models.Site, rawStatus string, code int, latency time.Duration, errorReason string) { if !e.IsActive() { return } - newState := site - newState.StatusCode = code - newState.LastError = errorReason + inMaint := e.isInMaintenance(snap.ID) - if rawStatus == "UP" { - newState.LastSuccessAt = time.Now() - newState.LastError = "" - } else { - newState.LastSuccessAt = site.LastSuccessAt + var ( + prev, next string + name, typ string + alertID int + failCount, maxRetries int + confirmedDown bool + failedCheck bool + downSince time.Time + sslWarnFire bool + sslDays int + skipped bool + changed bool + ) + + _, exists := e.applyState(snap.ID, func(s *models.Site) { + // A non-UP result computed from a stale snapshot must not override a + // heartbeat (or newer check) that landed while we were evaluating. + if rawStatus != "UP" && s.LastCheck.After(snap.LastCheck) { + skipped = true + return + } + + prev = s.Status + name = s.Name + typ = s.Type + alertID = s.AlertID + maxRetries = s.MaxRetries + downSince = s.StatusChangedAt + + // Fresh check results (measured by the run against snap). + s.StatusCode = code + s.Latency = snap.Latency + s.LastCheck = snap.LastCheck + s.HasSSL = snap.HasSSL + s.CertExpiry = snap.CertExpiry + s.LastError = errorReason + if rawStatus == "UP" { + s.LastSuccessAt = time.Now() + s.LastError = "" + } + + // Status + failure-count transition, based on the CURRENT live status. + switch { + case prev == "UP" && rawStatus != "UP": + s.FailureCount++ + if s.FailureCount > s.MaxRetries { + s.Status = rawStatus + s.FailureCount = s.MaxRetries + 1 + confirmedDown = true + } else { + failedCheck = true + } + case rawStatus == "UP": + s.FailureCount = 0 + s.Status = "UP" + default: + s.Status = rawStatus + s.FailureCount = s.MaxRetries + 1 + } + failCount = s.FailureCount + + if s.Status != prev && prev != "PENDING" { + s.StatusChangedAt = time.Now() + } else if s.StatusChangedAt.IsZero() && s.Status != "PENDING" { + s.StatusChangedAt = time.Now() + } + + // SSL expiry warning (fresh HasSSL/CertExpiry + config threshold). + if typ == "http" && s.CheckSSL && s.HasSSL { + days := int(time.Until(s.CertExpiry).Hours() / 24) + if days <= s.ExpiryThreshold && !s.SentSSLWarning && rawStatus != "SSL EXP" { + sslWarnFire = true + sslDays = days + s.SentSSLWarning = true + } else if days > s.ExpiryThreshold { + s.SentSSLWarning = false + } + } + + next = s.Status + changed = next != prev + }) + + if !exists || skipped { + return } - if site.Status == "UP" && rawStatus != "UP" { - newState.FailureCount++ - if newState.FailureCount > site.MaxRetries { - newState.Status = rawStatus - newState.FailureCount = site.MaxRetries + 1 - if errorReason != "" { - e.AddLog(fmt.Sprintf("Monitor '%s' confirmed DOWN: %s", site.Name, errorReason)) - } else { - e.AddLog(fmt.Sprintf("Monitor '%s' confirmed DOWN", site.Name)) - } + e.recordCheck(snap.ID, latency, rawStatus == "UP") + + if confirmedDown { + if errorReason != "" { + e.AddLog(fmt.Sprintf("Monitor '%s' confirmed DOWN: %s", name, errorReason)) } else { - e.AddLog(fmt.Sprintf("Monitor '%s' failed check %d/%d", site.Name, newState.FailureCount, site.MaxRetries)) + e.AddLog(fmt.Sprintf("Monitor '%s' confirmed DOWN", name)) } - } else if rawStatus == "UP" { - newState.FailureCount = 0 - newState.Status = "UP" - } else { - newState.Status = rawStatus - newState.FailureCount = site.MaxRetries + 1 + } else if failedCheck { + e.AddLog(fmt.Sprintf("Monitor '%s' failed check %d/%d", name, failCount, maxRetries)) } - if newState.Status != site.Status && site.Status != "PENDING" { - newState.StatusChangedAt = time.Now() - } else if site.StatusChangedAt.IsZero() && newState.Status != "PENDING" { - newState.StatusChangedAt = time.Now() - } else { - newState.StatusChangedAt = site.StatusChangedAt + if changed && prev != "PENDING" { + go func() { _ = e.db.SaveStateChange(snap.ID, prev, next, errorReason) }() } - inMaint := e.isInMaintenance(site.ID) - - if site.Type == "http" && site.CheckSSL && site.HasSSL { - daysLeft := int(time.Until(site.CertExpiry).Hours() / 24) - if daysLeft <= site.ExpiryThreshold && !site.SentSSLWarning && rawStatus != "SSL EXP" { - if !inMaint { - e.triggerAlert(site.AlertID, "SSL WARNING", fmt.Sprintf("SSL for '%s' expires in %d days", site.Name, daysLeft)) - } else { - e.AddLog(fmt.Sprintf("SSL warning for '%s' suppressed (maintenance)", site.Name)) - } - newState.SentSSLWarning = true - } else if daysLeft > site.ExpiryThreshold { - newState.SentSSLWarning = false + if sslWarnFire { + if !inMaint { + e.triggerAlert(alertID, "SSL WARNING", fmt.Sprintf("SSL for '%s' expires in %d days", name, sslDays)) + } else { + e.AddLog(fmt.Sprintf("SSL warning for '%s' suppressed (maintenance)", name)) } } - e.mu.Lock() - if _, ok := e.liveState[site.ID]; ok { - e.liveState[site.ID] = newState - } - e.mu.Unlock() - - e.recordCheck(site.ID, latency, rawStatus == "UP") - - if newState.Status != site.Status && site.Status != "PENDING" { - go func() { _ = e.db.SaveStateChange(site.ID, site.Status, newState.Status, errorReason) }() - } - isBroken := func(s string) bool { return s == "DOWN" || s == "SSL EXP" } - if site.Status == "UP" && newState.Status == "LATE" { - e.AddLog(fmt.Sprintf("Monitor '%s' heartbeat overdue", site.Name)) + if prev == "UP" && next == "LATE" { + e.AddLog(fmt.Sprintf("Monitor '%s' heartbeat overdue", name)) } - if !isBroken(site.Status) && isBroken(newState.Status) && newState.Status != "PENDING" { + if !isBroken(prev) && isBroken(next) && next != "PENDING" { if inMaint { - e.AddLog(fmt.Sprintf("Monitor '%s' is DOWN (alerts suppressed — maintenance)", site.Name)) + e.AddLog(fmt.Sprintf("Monitor '%s' is DOWN (alerts suppressed — maintenance)", name)) } else { - msg := fmt.Sprintf("Monitor '%s' is DOWN (%s)", site.Name, rawStatus) + msg := fmt.Sprintf("Monitor '%s' is DOWN (%s)", name, rawStatus) if errorReason != "" { - msg = fmt.Sprintf("Monitor '%s' is DOWN: %s", site.Name, errorReason) + msg = fmt.Sprintf("Monitor '%s' is DOWN: %s", name, errorReason) } - if site.Type == "push" { - msg = fmt.Sprintf("Push Monitor '%s' missed heartbeat.", site.Name) + if typ == "push" { + msg = fmt.Sprintf("Push Monitor '%s' missed heartbeat.", name) } - e.triggerAlert(site.AlertID, "🚨 ALERT", msg) + e.triggerAlert(alertID, "🚨 ALERT", msg) } } - if isBroken(site.Status) && newState.Status == "UP" { + if isBroken(prev) && next == "UP" { downDur := "" - if !site.StatusChangedAt.IsZero() { - downDur = fmt.Sprintf(" (was down %s)", fmtDurationShort(time.Since(site.StatusChangedAt))) + if !downSince.IsZero() { + downDur = fmt.Sprintf(" (was down %s)", fmtDurationShort(time.Since(downSince))) } - e.AddLog(fmt.Sprintf("Monitor '%s' recovered%s", site.Name, downDur)) + e.AddLog(fmt.Sprintf("Monitor '%s' recovered%s", name, downDur)) if !inMaint { - e.triggerAlert(site.AlertID, "✅ RECOVERY", fmt.Sprintf("Monitor '%s' is UP%s", site.Name, downDur)) + e.triggerAlert(alertID, "✅ RECOVERY", fmt.Sprintf("Monitor '%s' is UP%s", name, downDur)) } } - if site.Status == "LATE" && newState.Status == "UP" && !isBroken(site.Status) { - e.AddLog(fmt.Sprintf("Monitor '%s' heartbeat arrived (was late)", site.Name)) + if prev == "LATE" && next == "UP" && !isBroken(prev) { + e.AddLog(fmt.Sprintf("Monitor '%s' heartbeat arrived (was late)", name)) } } @@ -801,14 +880,12 @@ func (e *Engine) checkGroup(site models.Site) { status = "PENDING" } - e.mu.Lock() - s := e.liveState[site.ID] - s.Status = status - if hasChildren && allPaused { - s.Paused = true - } - e.liveState[site.ID] = s - e.mu.Unlock() + e.applyState(site.ID, func(s *models.Site) { + s.Status = status + if hasChildren && allPaused { + s.Paused = true + } + }) } func (e *Engine) SetAggStrategy(strategy AggregationStrategy) { diff --git a/internal/monitor/monitor_test.go b/internal/monitor/monitor_test.go index 45d5a96..39d2afd 100644 --- a/internal/monitor/monitor_test.go +++ b/internal/monitor/monitor_test.go @@ -1077,6 +1077,96 @@ func TestConcurrent_RecordCheckAndGetHistory(t *testing.T) { } } +// --- Group 10: liveState merge (lost-update race) --- + +// A pause that lands while a check is in flight must survive the check's +// write-back. The old code snapshotted the site, ran the check, then wrote the +// whole stale struct back — reverting the pause. +func TestHandleStatusChange_PauseDuringCheckSurvives(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + site := models.Site{ID: 1, Name: "test", Status: "UP", MaxRetries: 0} + injectSite(e, site) + + // `site` is the stale snapshot the check ran against (Paused=false). + // Meanwhile the user pauses the monitor. + e.ToggleSitePause(1) + + // Check completes and folds its result in using the stale snapshot. + e.handleStatusChange(site, "DOWN", 500, 0, "boom") + + s, _ := getSite(e, 1) + if !s.Paused { + t.Error("pause was reverted by a stale check write-back") + } + if s.Status != "DOWN" { + t.Errorf("expected check result still applied (DOWN), got %s", s.Status) + } +} + +// A config edit that lands while a check is in flight must survive; the check +// must not resurrect the old config from its snapshot. +func TestHandleStatusChange_ConfigEditDuringCheckSurvives(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + site := models.Site{ID: 1, Name: "test", URL: "http://old.com", Type: "http", Status: "UP", MaxRetries: 0, Interval: 30} + injectSite(e, site) + + // Config changes mid-check. + e.UpdateSiteConfig(models.Site{ID: 1, Name: "test", URL: "http://new.com", Type: "http", Interval: 60}) + + // Stale check (ran against http://old.com) folds its result in. + e.handleStatusChange(site, "UP", 200, 5*time.Millisecond, "") + + s, _ := getSite(e, 1) + if s.URL != "http://new.com" { + t.Errorf("config edit reverted: URL=%s", s.URL) + } + if s.Interval != 60 { + t.Errorf("config edit reverted: Interval=%d", s.Interval) + } +} + +// The classic push false-DOWN: a heartbeat marks the monitor UP while a +// staleness evaluation (computed from the older LastCheck) is mid-flight. +// The stale DOWN must not overwrite the fresh heartbeat. +func TestHandleStatusChange_HeartbeatNotOverwrittenByStaleDown(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + // Snapshot the engine would have taken before evaluating staleness: + // LastCheck is old, so checkPush decided "DOWN". + snap := models.Site{ID: 1, Name: "push", Type: "push", Token: "tok", Status: "UP", Interval: 10, LastCheck: time.Now().Add(-120 * time.Second)} + injectSite(e, snap) + + // A heartbeat lands first, advancing LastCheck and confirming UP. + if !e.RecordHeartbeat("tok") { + t.Fatal("heartbeat rejected") + } + + // Now the in-flight stale evaluation tries to write DOWN. + e.handleStatusChange(snap, "DOWN", 0, 0, "heartbeat missed") + + s, _ := getSite(e, 1) + if s.Status != "UP" { + t.Errorf("stale DOWN overwrote a fresh heartbeat: status=%s", s.Status) + } +} + +// A check result for a site removed mid-check must be dropped, not recreate it. +func TestHandleStatusChange_RemovedSiteDropped(t *testing.T) { + ms := newMockStore() + e := newTestEngine(ms) + site := models.Site{ID: 1, Name: "test", Status: "UP", MaxRetries: 0} + injectSite(e, site) + + e.RemoveSite(1) + e.handleStatusChange(site, "DOWN", 500, 0, "boom") + + if _, ok := getSite(e, 1); ok { + t.Error("removed site was recreated by a late check write-back") + } +} + // --- Utilities --- func containsStr(s, substr string) bool {