diff --git a/cmd/uptop/main.go b/cmd/uptop/main.go index 5afc488..9a2977a 100644 --- a/cmd/uptop/main.go +++ b/cmd/uptop/main.go @@ -385,13 +385,14 @@ func runServe(args []string) { tlsKey := os.Getenv("UPTOP_TLS_KEY") httpSrv := server.Start(server.ServerConfig{ - Port: httpPort, - EnableStatus: enableStatus, - Title: statusTitle, - ClusterKey: clusterKey, - TLSCert: tlsCert, - TLSKey: tlsKey, - ClusterMode: clusterMode, + Port: httpPort, + EnableStatus: enableStatus, + Title: statusTitle, + ClusterKey: clusterKey, + TLSCert: tlsCert, + TLSKey: tlsKey, + ClusterMode: clusterMode, + MetricsPublic: os.Getenv("UPTOP_METRICS_PUBLIC") == "true", }, s, eng) cluster.Start(ctx, cluster.Config{ diff --git a/internal/alert/alert.go b/internal/alert/alert.go index 6a9a235..2ca4f47 100644 --- a/internal/alert/alert.go +++ b/internal/alert/alert.go @@ -25,6 +25,7 @@ type PayloadFunc func(title, message string) ([]byte, error) type HTTPProvider struct { URL string Payload PayloadFunc + Headers map[string]string } func (h *HTTPProvider) Send(ctx context.Context, title, message string) error { @@ -37,6 +38,9 @@ func (h *HTTPProvider) Send(ctx context.Context, title, message string) error { return err } req.Header.Set("Content-Type", "application/json") + for k, v := range h.Headers { + req.Header.Set(k, v) + } resp, err := alertClient.Do(req) if err != nil { return err @@ -165,8 +169,9 @@ func GetProvider(cfg models.AlertConfig) Provider { } serverURL := strings.TrimRight(cfg.Settings["url"], "/") return &HTTPProvider{ - URL: fmt.Sprintf("%s/message?token=%s", serverURL, cfg.Settings["token"]), + URL: serverURL + "/message", Payload: gotifyPayload(priority), + Headers: map[string]string{"X-Gotify-Key": cfg.Settings["token"]}, } default: return nil diff --git a/internal/server/ratelimit.go b/internal/server/ratelimit.go new file mode 100644 index 0000000..b8f82f7 --- /dev/null +++ b/internal/server/ratelimit.go @@ -0,0 +1,91 @@ +package server + +import ( + "net" + "net/http" + "sync" + "time" +) + +type visitor struct { + tokens float64 + lastSeen time.Time +} + +type RateLimiter struct { + mu sync.Mutex + visitors map[string]*visitor + rate float64 + burst float64 +} + +func NewRateLimiter(requestsPerMinute int) *RateLimiter { + rl := &RateLimiter{ + visitors: make(map[string]*visitor), + rate: float64(requestsPerMinute) / 60.0, + burst: float64(requestsPerMinute), + } + go rl.cleanup() + return rl +} + +func (rl *RateLimiter) Allow(ip string) bool { + rl.mu.Lock() + defer rl.mu.Unlock() + + v, exists := rl.visitors[ip] + now := time.Now() + + if !exists { + rl.visitors[ip] = &visitor{tokens: rl.burst - 1, lastSeen: now} + return true + } + + elapsed := now.Sub(v.lastSeen).Seconds() + v.tokens += elapsed * rl.rate + if v.tokens > rl.burst { + v.tokens = rl.burst + } + v.lastSeen = now + + if v.tokens < 1 { + return false + } + v.tokens-- + return true +} + +func (rl *RateLimiter) cleanup() { + for { + time.Sleep(5 * time.Minute) + rl.mu.Lock() + cutoff := time.Now().Add(-10 * time.Minute) + for ip, v := range rl.visitors { + if v.lastSeen.Before(cutoff) { + delete(rl.visitors, ip) + } + } + rl.mu.Unlock() + } +} + +func clientIP(r *http.Request) string { + if fwd := r.Header.Get("X-Forwarded-For"); fwd != "" { + return fwd + } + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return r.RemoteAddr + } + return host +} + +func RateLimit(limiter *RateLimiter, next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if !limiter.Allow(clientIP(r)) { + http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests) + return + } + next(w, r) + } +} diff --git a/internal/server/server.go b/internal/server/server.go index b75f5cd..63e53df 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -22,6 +22,31 @@ func checkSecret(got, want string) bool { return subtle.ConstantTimeCompare([]byte(got), []byte(want)) == 1 } +func extractBearerToken(r *http.Request) string { + auth := r.Header.Get("Authorization") + if strings.HasPrefix(auth, "Bearer ") { + return strings.TrimPrefix(auth, "Bearer ") + } + return "" +} + +var sensitiveKeys = map[string]bool{ + "pass": true, "password": true, "token": true, + "routing_key": true, "user": true, "username": true, +} + +func redactSettings(settings map[string]string) map[string]string { + redacted := make(map[string]string, len(settings)) + for k, v := range settings { + if sensitiveKeys[k] && v != "" { + redacted[k] = "***REDACTED***" + } else { + redacted[k] = v + } + } + return redacted +} + var statusTpl = template.Must(template.New("status").Parse(` @@ -154,24 +179,37 @@ var statusTpl = template.Must(template.New("status").Parse(` `)) type ServerConfig struct { - Port int - EnableStatus bool - Title string - ClusterKey string - TLSCert string - TLSKey string - ClusterMode string + Port int + EnableStatus bool + Title string + ClusterKey string + TLSCert string + TLSKey string + ClusterMode string + MetricsPublic bool } func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) *http.Server { if cfg.ClusterKey == "" { fmt.Println("WARNING: No UPTOP_CLUSTER_SECRET set. Cluster API endpoints are unauthenticated.") } + + pushRL := NewRateLimiter(60) + probeRL := NewRateLimiter(30) + backupRL := NewRateLimiter(10) + statusRL := NewRateLimiter(120) + mux := http.NewServeMux() // 1. Push Heartbeat - mux.HandleFunc("/api/push", func(w http.ResponseWriter, r *http.Request) { - token := r.URL.Query().Get("token") + mux.HandleFunc("/api/push", RateLimit(pushRL, func(w http.ResponseWriter, r *http.Request) { + token := extractBearerToken(r) + if token == "" { + if qt := r.URL.Query().Get("token"); qt != "" { + token = qt + log.Printf("DEPRECATED: push token in query string — use Authorization: Bearer header instead") + } + } if token == "" { http.Error(w, "Missing token", http.StatusBadRequest) return @@ -182,7 +220,7 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) *http.Server { } else { http.Error(w, "Invalid Token", http.StatusNotFound) } - }) + })) // 2. Health Check (For Cluster Follower) mux.HandleFunc("/api/health", func(w http.ResponseWriter, r *http.Request) { @@ -195,7 +233,7 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) *http.Server { }) // 3. Config Export - mux.HandleFunc("/api/backup/export", func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/api/backup/export", RateLimit(backupRL, func(w http.ResponseWriter, r *http.Request) { if cfg.ClusterKey == "" || !checkSecret(r.Header.Get("X-Upkeep-Secret"), cfg.ClusterKey) { http.Error(w, "Unauthorized: UPTOP_CLUSTER_SECRET required", http.StatusUnauthorized) return @@ -206,11 +244,16 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) *http.Server { http.Error(w, "Export failed", http.StatusInternalServerError) return } + if r.URL.Query().Get("redact_secrets") != "false" { + for i := range data.Alerts { + data.Alerts[i].Settings = redactSettings(data.Alerts[i].Settings) + } + } _ = json.NewEncoder(w).Encode(data) //nolint:errcheck - }) + })) // 4. Config Import - mux.HandleFunc("/api/backup/import", func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/api/backup/import", RateLimit(backupRL, func(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "POST required", http.StatusMethodNotAllowed) return @@ -231,10 +274,10 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) *http.Server { return } _, _ = w.Write([]byte("Import Successful")) - }) + })) // 5. Kuma Import - mux.HandleFunc("/api/import/kuma", func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/api/import/kuma", RateLimit(backupRL, func(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "POST required", http.StatusMethodNotAllowed) return @@ -257,10 +300,10 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) *http.Server { return } fmt.Fprintf(w, "Imported %d monitors, %d alerts from Kuma v%s", len(backup.Sites), len(backup.Alerts), kb.Version) - }) + })) // 6. Probe Registration - mux.HandleFunc("/api/probe/register", func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/api/probe/register", RateLimit(probeRL, func(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "POST required", http.StatusMethodNotAllowed) return @@ -292,10 +335,10 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) *http.Server { return } _ = json.NewEncoder(w).Encode(map[string]bool{"ok": true}) //nolint:errcheck - }) + })) // 7. Probe Assignment Fetch - mux.HandleFunc("/api/probe/assignments", func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/api/probe/assignments", RateLimit(probeRL, func(w http.ResponseWriter, r *http.Request) { if cfg.ClusterKey == "" || !checkSecret(r.Header.Get("X-Upkeep-Secret"), cfg.ClusterKey) { http.Error(w, "Unauthorized", http.StatusUnauthorized) return @@ -329,10 +372,10 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) *http.Server { } w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(map[string][]models.Site{"sites": assigned}) //nolint:errcheck - }) + })) // 8. Probe Result Submission - mux.HandleFunc("/api/probe/results", func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/api/probe/results", RateLimit(probeRL, func(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "POST required", http.StatusMethodNotAllowed) return @@ -368,15 +411,23 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) *http.Server { log.Printf("Failed to update node last seen: %v", err) } _ = json.NewEncoder(w).Encode(map[string]bool{"ok": true}) //nolint:errcheck - }) + })) // 9. Prometheus Metrics - mux.HandleFunc("/metrics", metrics.Handler(eng)) + mux.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) { + if !cfg.MetricsPublic && cfg.ClusterKey != "" { + if !checkSecret(r.Header.Get("X-Upkeep-Secret"), cfg.ClusterKey) { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + } + metrics.Handler(eng)(w, r) + }) // 10. Status Page if cfg.EnableStatus { - mux.HandleFunc("/status", func(w http.ResponseWriter, r *http.Request) { renderStatusPage(w, cfg.Title, eng) }) - mux.HandleFunc("/status/json", func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/status", RateLimit(statusRL, func(w http.ResponseWriter, r *http.Request) { renderStatusPage(w, cfg.Title, eng) })) + mux.HandleFunc("/status/json", RateLimit(statusRL, func(w http.ResponseWriter, r *http.Request) { state := eng.GetLiveState() activeWindows, _ := s.GetActiveMaintenanceWindows() maintSet := make(map[int]bool) @@ -400,7 +451,7 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) *http.Server { } w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(state) //nolint:errcheck - }) + })) } if cfg.ClusterMode != "" && cfg.ClusterMode != "leader" && cfg.TLSCert == "" { @@ -413,7 +464,14 @@ func Start(cfg ServerConfig, s store.Store, eng *monitor.Engine) *http.Server { } addr := fmt.Sprintf(":%d", cfg.Port) - srv := &http.Server{Addr: addr, Handler: handler, ReadHeaderTimeout: 10 * time.Second} + srv := &http.Server{ + Addr: addr, + Handler: handler, + ReadHeaderTimeout: 10 * time.Second, + ReadTimeout: 30 * time.Second, + WriteTimeout: 60 * time.Second, + IdleTimeout: 120 * time.Second, + } go func() { if cfg.TLSCert != "" && cfg.TLSKey != "" { fmt.Printf("HTTPS Server listening on %s\n", addr) diff --git a/internal/store/dialect.go b/internal/store/dialect.go index f6e35b2..4a4f8e8 100644 --- a/internal/store/dialect.go +++ b/internal/store/dialect.go @@ -1,6 +1,9 @@ package store -import "database/sql" +import ( + "database/sql" + "strconv" +) type Dialect interface { DriverName() string @@ -13,8 +16,6 @@ type Dialect interface { UpsertNodeSQL() string } -// rewritePlaceholders converts ? markers to $1, $2, etc. for Postgres. -// For SQLite (or any dialect not needing rewrite), returns the input unchanged. func rewritePlaceholders(query string, dollarStyle bool) string { if !dollarStyle { return query @@ -25,10 +26,7 @@ func rewritePlaceholders(query string, dollarStyle bool) string { if query[i] == '?' { n++ buf = append(buf, '$') - if n >= 10 { - buf = append(buf, byte('0'+n/10)) - } - buf = append(buf, byte('0'+n%10)) + buf = append(buf, []byte(strconv.Itoa(n))...) } else { buf = append(buf, query[i]) } diff --git a/internal/store/sqlstore.go b/internal/store/sqlstore.go index f3368e3..78f3e23 100644 --- a/internal/store/sqlstore.go +++ b/internal/store/sqlstore.go @@ -195,25 +195,47 @@ func (s *SQLStore) GetAlertByName(name string) (models.AlertConfig, error) { } func (s *SQLStore) AddSiteReturningID(site models.Site) (int, error) { - if err := s.AddSite(site); err != nil { - return 0, err + token := "" + if site.Type == "push" { + var err error + token, err = generateToken() + if err != nil { + return 0, fmt.Errorf("generate push token: %w", err) + } } - created, err := s.GetSiteByName(site.Name) + if s.dollar { + var id int + err := s.db.QueryRow(s.q("INSERT INTO sites (name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused, regions) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) RETURNING id"), + site.Name, site.URL, site.Type, token, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries, + site.Hostname, site.Port, site.Timeout, site.Method, site.Description, site.ParentID, site.AcceptedCodes, site.DNSResolveType, site.DNSServer, site.IgnoreTLS, site.Paused, site.Regions).Scan(&id) + return id, err + } + result, err := s.db.Exec(s.q("INSERT INTO sites (name, url, type, token, interval, alert_id, check_ssl, threshold, max_retries, hostname, port, timeout, method, description, parent_id, accepted_codes, dns_resolve_type, dns_server, ignore_tls, paused, regions) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"), + site.Name, site.URL, site.Type, token, site.Interval, site.AlertID, site.CheckSSL, site.ExpiryThreshold, site.MaxRetries, + site.Hostname, site.Port, site.Timeout, site.Method, site.Description, site.ParentID, site.AcceptedCodes, site.DNSResolveType, site.DNSServer, site.IgnoreTLS, site.Paused, site.Regions) if err != nil { return 0, err } - return created.ID, nil + id, err := result.LastInsertId() + return int(id), err } func (s *SQLStore) AddAlertReturningID(name, aType string, settings map[string]string) (int, error) { - if err := s.AddAlert(name, aType, settings); err != nil { - return 0, err - } - created, err := s.GetAlertByName(name) + stored, err := s.marshalSettings(settings) if err != nil { return 0, err } - return created.ID, nil + if s.dollar { + var id int + err := s.db.QueryRow(s.q("INSERT INTO alerts (name, type, settings) VALUES (?, ?, ?) RETURNING id"), name, aType, stored).Scan(&id) + return id, err + } + result, err := s.db.Exec(s.q("INSERT INTO alerts (name, type, settings) VALUES (?, ?, ?)"), name, aType, stored) + if err != nil { + return 0, err + } + id, err := result.LastInsertId() + return int(id), err } func (s *SQLStore) GetAllAlerts() ([]models.AlertConfig, error) {