diff --git a/internal/monitor/checker.go b/internal/monitor/checker.go index 11eb9d5..37bd653 100644 --- a/internal/monitor/checker.go +++ b/internal/monitor/checker.go @@ -37,19 +37,26 @@ type CheckResult struct { } func RunCheck(ctx context.Context, site models.SiteConfig, strict, insecure *http.Client, globalInsecure, allowPrivate bool) CheckResult { + // Resolve + validate once for non-HTTP types to prevent DNS-rebind TOCTOU: + // a second resolve in the check function could return a different (private) IP. + // HTTP is safe — SafeDialContext resolves and validates at dial time. + var pinnedIP net.IP if site.Type != "http" && site.Type != "dns" && !allowPrivate { host := site.Hostname if host == "" { host = site.URL } if host != "" { - if ips, err := net.LookupIP(host); err == nil { - for _, ip := range ips { - if isPrivateIP(ip) { - return CheckResult{SiteID: site.ID, Status: string(models.StatusDown), ErrorReason: "target resolves to private IP"} - } + ips, err := net.LookupIP(host) + if err != nil { + return CheckResult{SiteID: site.ID, Status: string(models.StatusDown), ErrorReason: "resolve failed: " + err.Error()} + } + for _, ip := range ips { + if isPrivateIP(ip) { + return CheckResult{SiteID: site.ID, Status: string(models.StatusDown), ErrorReason: "target resolves to private IP"} } } + pinnedIP = ips[0] } } @@ -57,9 +64,9 @@ func RunCheck(ctx context.Context, site models.SiteConfig, strict, insecure *htt case "http": return runHTTPCheck(ctx, site, strict, insecure, globalInsecure) case "ping": - return runPingCheck(ctx, site) + return runPingCheck(ctx, site, pinnedIP) case "port": - return runPortCheck(ctx, site) + return runPortCheck(ctx, site, pinnedIP) case "dns": return runDNSCheck(ctx, site, allowPrivate) default: @@ -130,7 +137,7 @@ func runHTTPCheck(ctx context.Context, site models.SiteConfig, strict, insecure return result } -func runPingCheck(_ context.Context, site models.SiteConfig) CheckResult { +func runPingCheck(_ context.Context, site models.SiteConfig, pinnedIP net.IP) CheckResult { host := site.Hostname if host == "" { host = site.URL @@ -140,6 +147,9 @@ func runPingCheck(_ context.Context, site models.SiteConfig) CheckResult { if err != nil { return CheckResult{SiteID: site.ID, Status: string(models.StatusDown), ErrorReason: "ping setup: " + err.Error()} } + if pinnedIP != nil { + pinger.SetIPAddr(&net.IPAddr{IP: pinnedIP}) + } pinger.Count = 1 pinger.Timeout = siteTimeout(site) pinger.SetPrivileged(false) @@ -159,11 +169,14 @@ func runPingCheck(_ context.Context, site models.SiteConfig) CheckResult { return CheckResult{SiteID: site.ID, Status: string(models.StatusUp), LatencyNs: stats.AvgRtt.Nanoseconds()} } -func runPortCheck(_ context.Context, site models.SiteConfig) CheckResult { +func runPortCheck(_ context.Context, site models.SiteConfig, pinnedIP net.IP) CheckResult { host := site.Hostname if host == "" { host = site.URL } + if pinnedIP != nil { + host = pinnedIP.String() + } addr := net.JoinHostPort(host, strconv.Itoa(site.Port)) timeout := siteTimeout(site) diff --git a/internal/monitor/checker_test.go b/internal/monitor/checker_test.go index 9233519..558b404 100644 --- a/internal/monitor/checker_test.go +++ b/internal/monitor/checker_test.go @@ -161,6 +161,43 @@ func TestRunCheck_Port_Closed(t *testing.T) { } } +func TestRunPortCheck_UsesPinnedIP(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + _, portStr, _ := net.SplitHostPort(ln.Addr().String()) + port, _ := strconv.Atoi(portStr) + + // Pass a pinned IP — runPortCheck should dial it instead of resolving Hostname. + site := models.SiteConfig{ID: 1, Type: "port", Hostname: "will-not-resolve.invalid", Port: port, Timeout: 2} + result := runPortCheck(context.Background(), site, net.ParseIP("127.0.0.1")) + + if result.Status != "UP" { + t.Errorf("expected UP when pinned IP used, got %s: %s", result.Status, result.ErrorReason) + } +} + +func TestRunPortCheck_NilPinnedIP_UsesHostname(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + _, portStr, _ := net.SplitHostPort(ln.Addr().String()) + port, _ := strconv.Atoi(portStr) + + site := models.SiteConfig{ID: 1, Type: "port", Hostname: "127.0.0.1", Port: port, Timeout: 2} + result := runPortCheck(context.Background(), site, nil) + + if result.Status != "UP" { + t.Errorf("expected UP with nil pinnedIP fallback, got %s: %s", result.Status, result.ErrorReason) + } +} + func TestRunCheck_Port_BlocksPrivateByDefault(t *testing.T) { ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil {