From e09919b6794e107933b6fcb22dedf82359e9de10 Mon Sep 17 00:00:00 2001 From: Tyler Koenig Date: Tue, 19 May 2026 18:30:17 -0400 Subject: [PATCH] fix: harden API, DB schema, and CLI safety - Add 'reminder' to glyph CHECK constraint (was accepted by parser but rejected by DB) - Default serve bind to 127.0.0.1, add --host flag for LAN access - Validate card_data as JSON in Store.Create/Update/Promote - Return pagination envelope {data,total,limit,offset} from list endpoint - Append absorb breadcrumb to source entity before soft-delete - Add Levenshtein fuzzy match to catch command typos before routing to add - Replace DDL string-matching migrations with versioned schema_version table - Update web UI and API tests for envelope response format --- cmd/root.go | 55 +++++++++++++++ cmd/serve.go | 8 ++- go.mod | 6 +- go.sum | 2 - internal/api/api_test.go | 35 +++++---- internal/api/entities.go | 32 ++++++++- internal/db/db.go | 149 +++++++++++++++++++++++++-------------- internal/db/entities.go | 33 +++++++-- web/app.js | 12 ++-- 9 files changed, 243 insertions(+), 89 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index 7424fac..067fa09 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1,6 +1,7 @@ package cmd import ( + "fmt" "os" "strings" @@ -26,6 +27,10 @@ func Execute() error { isFlag := strings.HasPrefix(first, "-") && !strings.Contains(first, " ") if first != "help" && first != "completion" && !isFlag && !isSubcommand(first) { + if near := nearSubcommand(first); near != "" { + fmt.Fprintf(os.Stderr, "unknown command %q — did you mean %q?\n", first, near) + os.Exit(1) + } // "--" stops cobra from parsing glyph prefixes like "-" as flags rootCmd.SetArgs(append([]string{"add", "--"}, os.Args[1:]...)) } @@ -47,6 +52,56 @@ func isSubcommand(name string) bool { return false } +func nearSubcommand(name string) string { + for _, c := range rootCmd.Commands() { + if d := editDist(name, c.Name()); d > 0 && d <= 2 { + return c.Name() + } + for _, alias := range c.Aliases { + if d := editDist(name, alias); d > 0 && d <= 2 { + return alias + } + } + } + return "" +} + +func editDist(a, b string) int { + la, lb := len(a), len(b) + if la == 0 { + return lb + } + if lb == 0 { + return la + } + prev := make([]int, lb+1) + for j := range prev { + prev[j] = j + } + for i := 1; i <= la; i++ { + curr := make([]int, lb+1) + curr[0] = i + for j := 1; j <= lb; j++ { + cost := 1 + if a[i-1] == b[j-1] { + cost = 0 + } + ins := curr[j-1] + 1 + del := prev[j] + 1 + sub := prev[j-1] + cost + curr[j] = ins + if del < curr[j] { + curr[j] = del + } + if sub < curr[j] { + curr[j] = sub + } + } + prev = curr + } + return prev[lb] +} + func init() { rootCmd.AddCommand(addCmd) rootCmd.AddCommand(lsCmd) diff --git a/cmd/serve.go b/cmd/serve.go index 157548a..2fabf53 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -19,6 +19,7 @@ var WebFS fs.FS var ( servePort int + serveHost string serveDev bool tlsCert string tlsKey string @@ -32,6 +33,7 @@ var serveCmd = &cobra.Command{ func init() { serveCmd.Flags().IntVar(&servePort, "port", 0, "port to listen on (default 4444, or 4443 with TLS)") + serveCmd.Flags().StringVar(&serveHost, "host", "127.0.0.1", "address to bind to (default localhost only)") serveCmd.Flags().BoolVar(&serveDev, "dev", false, "enable CORS for development") serveCmd.Flags().StringVar(&tlsCert, "tls-cert", "", "path to TLS certificate file") serveCmd.Flags().StringVar(&tlsKey, "tls-key", "", "path to TLS private key file") @@ -70,7 +72,7 @@ func runServe(_ *cobra.Command, _ []string) error { router = api.NewRouter(store, serveDev, WebFS) } - addr := fmt.Sprintf(":%d", port) + addr := fmt.Sprintf("%s:%d", serveHost, port) srv := &http.Server{ Addr: addr, Handler: router, @@ -81,9 +83,9 @@ func runServe(_ *cobra.Command, _ []string) error { go func() { if useTLS { - fmt.Printf("nib serving on https://localhost%s\n", addr) + fmt.Printf("nib serving on https://%s\n", addr) } else { - fmt.Printf("nib serving on http://localhost%s\n", addr) + fmt.Printf("nib serving on http://%s\n", addr) } if serveDev { fmt.Println(" CORS enabled (dev mode)") diff --git a/go.mod b/go.mod index ec47185..58dcc91 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,9 @@ go 1.24.4 require ( github.com/atotto/clipboard v0.1.4 + github.com/charmbracelet/bubbles v1.0.0 + github.com/charmbracelet/bubbletea v1.3.10 + github.com/charmbracelet/lipgloss v1.1.0 github.com/go-chi/chi/v5 v5.2.5 github.com/oklog/ulid/v2 v2.1.1 github.com/spf13/cobra v1.10.2 @@ -12,10 +15,7 @@ require ( require ( github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect - github.com/charmbracelet/bubbles v1.0.0 // indirect - github.com/charmbracelet/bubbletea v1.3.10 // indirect github.com/charmbracelet/colorprofile v0.4.1 // indirect - github.com/charmbracelet/lipgloss v1.1.0 // indirect github.com/charmbracelet/x/ansi v0.11.6 // indirect github.com/charmbracelet/x/cellbuf v0.0.15 // indirect github.com/charmbracelet/x/term v0.2.2 // indirect diff --git a/go.sum b/go.sum index 696989a..7252a8d 100644 --- a/go.sum +++ b/go.sum @@ -74,8 +74,6 @@ golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= -golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY= diff --git a/internal/api/api_test.go b/internal/api/api_test.go index 7f10aa1..397aeff 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -25,6 +25,20 @@ func testServer(t *testing.T) (*httptest.Server, *db.Store) { return srv, store } +type listEnvelope struct { + Data []EntityResponse `json:"data"` + Total int `json:"total"` + Limit int `json:"limit"` + Offset int `json:"offset"` +} + +func decodeList(t *testing.T, resp *http.Response) []EntityResponse { + t.Helper() + var env listEnvelope + json.NewDecoder(resp.Body).Decode(&env) + return env.Data +} + func postJSON(t *testing.T, srv *httptest.Server, path string, body any) *http.Response { t.Helper() b, err := json.Marshal(body) @@ -157,8 +171,7 @@ func TestListEntities_Default(t *testing.T) { } defer resp.Body.Close() - var entities []EntityResponse - json.NewDecoder(resp.Body).Decode(&entities) + entities := decodeList(t, resp) if len(entities) != 2 { t.Fatalf("expected 2, got %d", len(entities)) } @@ -175,8 +188,7 @@ func TestListEntities_FilterTag(t *testing.T) { } defer resp.Body.Close() - var entities []EntityResponse - json.NewDecoder(resp.Body).Decode(&entities) + entities := decodeList(t, resp) if len(entities) != 1 { t.Fatalf("expected 1, got %d", len(entities)) } @@ -198,8 +210,7 @@ func TestListEntities_CardsOnly(t *testing.T) { } defer resp.Body.Close() - var entities []EntityResponse - json.NewDecoder(resp.Body).Decode(&entities) + entities := decodeList(t, resp) if len(entities) != 1 { t.Fatalf("expected 1 card, got %d", len(entities)) } @@ -215,16 +226,14 @@ func TestListEntities_Pagination(t *testing.T) { if err != nil { t.Fatal(err) } - var page1 []EntityResponse - json.NewDecoder(resp.Body).Decode(&page1) + page1 := decodeList(t, resp) resp.Body.Close() resp, err = http.Get(srv.URL + "/api/entities?limit=2&offset=2") if err != nil { t.Fatal(err) } - var page2 []EntityResponse - json.NewDecoder(resp.Body).Decode(&page2) + page2 := decodeList(t, resp) resp.Body.Close() if len(page1) != 2 || len(page2) != 2 { @@ -517,8 +526,7 @@ func TestAbsorbEntity_Success(t *testing.T) { if err != nil { t.Fatal(err) } - var entities []EntityResponse - json.NewDecoder(listResp.Body).Decode(&entities) + entities := decodeList(t, listResp) listResp.Body.Close() for _, ent := range entities { if ent.ID == source.ID { @@ -686,8 +694,7 @@ func TestListEntities_TitleInResponse(t *testing.T) { } defer resp.Body.Close() - var entities []EntityResponse - json.NewDecoder(resp.Body).Decode(&entities) + entities := decodeList(t, resp) if len(entities) != 1 { t.Fatalf("expected 1, got %d", len(entities)) } diff --git a/internal/api/entities.go b/internal/api/entities.go index 6e78e2e..b4058d4 100644 --- a/internal/api/entities.go +++ b/internal/api/entities.go @@ -102,6 +102,15 @@ func listEntities(store *db.Store) http.HandlerFunc { } p.Offset = offset } + if p.Limit <= 0 { + p.Limit = 50 + } + + total, err := store.Count(p) + if err != nil { + writeInternalError(w, err) + return + } entities, err := store.List(p) if err != nil { @@ -109,11 +118,16 @@ func listEntities(store *db.Store) http.HandlerFunc { return } - resp := make([]EntityResponse, len(entities)) + items := make([]EntityResponse, len(entities)) for i, e := range entities { - resp[i] = entityToResponse(e) + items[i] = entityToResponse(e) } - writeJSON(w, http.StatusOK, resp) + writeJSON(w, http.StatusOK, map[string]any{ + "data": items, + "total": total, + "limit": p.Limit, + "offset": p.Offset, + }) } } @@ -161,6 +175,10 @@ func createEntity(store *db.Store) http.HandlerFunc { } if err := store.Create(e); err != nil { + if err == db.ErrInvalidCardData { + writeError(w, http.StatusBadRequest, "invalid_card_data", "card_data must be valid JSON") + return + } writeInternalError(w, err) return } @@ -227,6 +245,10 @@ func updateEntity(store *db.Store) http.HandlerFunc { writeError(w, http.StatusNotFound, "not_found", "no entity with id "+id) return } + if err == db.ErrInvalidCardData { + writeError(w, http.StatusBadRequest, "invalid_card_data", "card_data must be valid JSON") + return + } writeInternalError(w, err) return } @@ -291,6 +313,10 @@ func promoteEntity(store *db.Store) http.HandlerFunc { writeError(w, http.StatusBadRequest, "invalid_promote", "entity is already crystallized") return } + if err == db.ErrInvalidCardData { + writeError(w, http.StatusBadRequest, "invalid_card_data", "card_data must be valid JSON") + return + } writeInternalError(w, err) return } diff --git a/internal/db/db.go b/internal/db/db.go index fc389ea..982e072 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -6,7 +6,6 @@ import ( "fmt" "os" "path/filepath" - "strings" _ "modernc.org/sqlite" ) @@ -16,6 +15,7 @@ var ( ErrAlreadyPromoted = errors.New("invalid_promote") ErrAlreadyFluid = errors.New("invalid_demote") ErrTargetCrystallized = errors.New("invalid_absorb") + ErrInvalidCardData = errors.New("invalid_card_data") ) type Store struct { @@ -51,64 +51,65 @@ func (s *Store) Close() error { return s.db.Close() } -func (s *Store) migrate() error { - _, err := s.db.Exec(` - CREATE TABLE IF NOT EXISTS entities ( - id TEXT PRIMARY KEY, - created_at TEXT NOT NULL, - modified_at TEXT NOT NULL, - body TEXT NOT NULL, - glyph TEXT NOT NULL - CHECK (glyph IN ('todo', 'event', 'note')), - time_anchor TEXT, - completed_at TEXT, - pinned INTEGER NOT NULL DEFAULT 0, - deleted_at TEXT, - card_type TEXT - CHECK (card_type IN ('snippet', 'template', 'checklist', 'decision', 'link', 'note') - OR card_type IS NULL), - card_data TEXT, - use_count INTEGER NOT NULL DEFAULT 0, - last_used_at TEXT - ); +const currentSchema = 3 - CREATE TABLE IF NOT EXISTS entity_tags ( - entity_id TEXT NOT NULL REFERENCES entities(id) ON DELETE CASCADE, - tag TEXT NOT NULL, - PRIMARY KEY (entity_id, tag) - ); +var migrations = []func(db *sql.DB) error{ + // v1: initial schema + func(db *sql.DB) error { + _, err := db.Exec(` + CREATE TABLE IF NOT EXISTS entities ( + id TEXT PRIMARY KEY, + created_at TEXT NOT NULL, + modified_at TEXT NOT NULL, + body TEXT NOT NULL, + glyph TEXT NOT NULL, + time_anchor TEXT, + completed_at TEXT, + pinned INTEGER NOT NULL DEFAULT 0, + deleted_at TEXT, + card_type TEXT, + card_data TEXT, + use_count INTEGER NOT NULL DEFAULT 0, + last_used_at TEXT + ); - CREATE INDEX IF NOT EXISTS idx_entities_created - ON entities(created_at DESC) WHERE deleted_at IS NULL; - CREATE INDEX IF NOT EXISTS idx_entities_card_use - ON entities(use_count DESC) - WHERE card_type IS NOT NULL AND deleted_at IS NULL; - CREATE INDEX IF NOT EXISTS idx_entity_tags_tag - ON entity_tags(tag); - `) - if err != nil { + CREATE TABLE IF NOT EXISTS entity_tags ( + entity_id TEXT NOT NULL REFERENCES entities(id) ON DELETE CASCADE, + tag TEXT NOT NULL, + PRIMARY KEY (entity_id, tag) + ); + + CREATE INDEX IF NOT EXISTS idx_entities_created + ON entities(created_at DESC) WHERE deleted_at IS NULL; + CREATE INDEX IF NOT EXISTS idx_entities_card_use + ON entities(use_count DESC) + WHERE card_type IS NOT NULL AND deleted_at IS NULL; + CREATE INDEX IF NOT EXISTS idx_entity_tags_tag + ON entity_tags(tag); + `) return err - } + }, - s.db.Exec(`ALTER TABLE entities ADD COLUMN title TEXT`) - s.db.Exec(`ALTER TABLE entities ADD COLUMN description TEXT`) + // v2: add title and description columns + func(db *sql.DB) error { + db.Exec(`ALTER TABLE entities ADD COLUMN title TEXT`) + db.Exec(`ALTER TABLE entities ADD COLUMN description TEXT`) + return nil + }, - // Migrate CHECK constraint to include 'note' card type - var needsMigrate bool - row := s.db.QueryRow(`SELECT sql FROM sqlite_master WHERE type='table' AND name='entities'`) - var ddl string - if row.Scan(&ddl) == nil { - hasNote := strings.Contains(ddl, "'link', 'note'") - hasModified := strings.Contains(ddl, "modified_at") - needsMigrate = !hasNote || !hasModified - } - if needsMigrate { - tx, err := s.db.Begin() + // v3: rebuild table with CHECK constraints (card_type 'note', glyph 'reminder') + func(db *sql.DB) error { + tx, err := db.Begin() if err != nil { return err } defer tx.Rollback() + // Disable FK checks during rebuild to avoid dangling references + if _, err := tx.Exec(`PRAGMA foreign_keys = OFF`); err != nil { + return fmt.Errorf("migrate fk off: %w", err) + } + if _, err := tx.Exec(`ALTER TABLE entities RENAME TO _entities_migrate`); err != nil { return fmt.Errorf("migrate rename: %w", err) } @@ -118,7 +119,7 @@ func (s *Store) migrate() error { modified_at TEXT NOT NULL, body TEXT NOT NULL, glyph TEXT NOT NULL - CHECK (glyph IN ('todo', 'event', 'note')), + CHECK (glyph IN ('todo', 'event', 'note', 'reminder')), time_anchor TEXT, completed_at TEXT, pinned INTEGER NOT NULL DEFAULT 0, @@ -140,12 +141,54 @@ func (s *Store) migrate() error { if _, err := tx.Exec(`DROP TABLE _entities_migrate`); err != nil { return fmt.Errorf("migrate drop: %w", err) } - if err := tx.Commit(); err != nil { - return fmt.Errorf("migrate commit: %w", err) + + // Rebuild entity_tags to point FK at new entities table + if _, err := tx.Exec(`ALTER TABLE entity_tags RENAME TO _tags_migrate`); err != nil { + return fmt.Errorf("migrate tags rename: %w", err) + } + if _, err := tx.Exec(`CREATE TABLE entity_tags ( + entity_id TEXT NOT NULL REFERENCES entities(id) ON DELETE CASCADE, + tag TEXT NOT NULL, + PRIMARY KEY (entity_id, tag) + )`); err != nil { + return fmt.Errorf("migrate tags create: %w", err) + } + if _, err := tx.Exec(`INSERT INTO entity_tags SELECT * FROM _tags_migrate`); err != nil { + return fmt.Errorf("migrate tags copy: %w", err) + } + if _, err := tx.Exec(`DROP TABLE _tags_migrate`); err != nil { + return fmt.Errorf("migrate tags drop: %w", err) + } + + if _, err := tx.Exec(`PRAGMA foreign_keys = ON`); err != nil { + return fmt.Errorf("migrate fk on: %w", err) + } + + return tx.Commit() + }, +} + +func (s *Store) migrate() error { + s.db.Exec(`CREATE TABLE IF NOT EXISTS schema_version (version INTEGER NOT NULL)`) + + var version int + err := s.db.QueryRow(`SELECT version FROM schema_version`).Scan(&version) + if err != nil { + version = 0 + } + + for i := version; i < len(migrations); i++ { + if err := migrations[i](s.db); err != nil { + return fmt.Errorf("migration %d: %w", i+1, err) } } - return nil + if version == 0 { + _, err = s.db.Exec(`INSERT INTO schema_version (version) VALUES (?)`, len(migrations)) + } else if len(migrations) > version { + _, err = s.db.Exec(`UPDATE schema_version SET version = ?`, len(migrations)) + } + return err } func DefaultPath() (string, error) { diff --git a/internal/db/entities.go b/internal/db/entities.go index d73f654..79dc615 100644 --- a/internal/db/entities.go +++ b/internal/db/entities.go @@ -104,6 +104,9 @@ type EntityUpdate struct { } func (s *Store) Create(e *Entity) error { + if e.CardData != nil && !json.Valid([]byte(*e.CardData)) { + return ErrInvalidCardData + } now := time.Now().UTC() e.ID = nibulid.New() e.CreatedAt = now @@ -179,7 +182,7 @@ func (s *Store) Get(id string) (*Entity, error) { return e, nil } -func (s *Store) List(params ListParams) ([]*Entity, error) { +func listWhere(params ListParams) (string, []any) { var where []string var args []any @@ -214,10 +217,23 @@ func (s *Store) List(params ListParams) ([]*Entity, error) { args = append(args, string(*params.CardTypeFilter)) } - whereClause := "" + clause := "" if len(where) > 0 { - whereClause = "WHERE " + strings.Join(where, " AND ") + clause = "WHERE " + strings.Join(where, " AND ") } + return clause, args +} + +func (s *Store) Count(params ListParams) (int, error) { + whereClause, args := listWhere(params) + query := fmt.Sprintf("SELECT COUNT(*) FROM entities e %s", whereClause) + var count int + err := s.db.QueryRow(query, args...).Scan(&count) + return count, err +} + +func (s *Store) List(params ListParams) ([]*Entity, error) { + whereClause, args := listWhere(params) orderCol := "e.created_at" switch params.Sort { @@ -336,6 +352,9 @@ func (s *Store) Update(id string, u *EntityUpdate) error { args = append(args, string(*u.CardType)) } if u.CardData != nil { + if !json.Valid([]byte(*u.CardData)) { + return ErrInvalidCardData + } sets = append(sets, "card_data = ?") args = append(args, *u.CardData) } @@ -370,6 +389,9 @@ func (s *Store) Promote(id string, cardType CardType, cardData *string) error { dataVal := "{}" if cardData != nil { + if !json.Valid([]byte(*cardData)) { + return ErrInvalidCardData + } dataVal = *cardData } @@ -473,8 +495,9 @@ func (s *Store) Absorb(targetID, sourceID string) error { } } - if _, err := tx.Exec("UPDATE entities SET deleted_at = ? WHERE id = ?", - now, sourceID); err != nil { + absorbNote := source.Body + "\n\n[absorbed into " + targetID + "]" + if _, err := tx.Exec("UPDATE entities SET body = ?, deleted_at = ?, modified_at = ? WHERE id = ?", + absorbNote, now, now, sourceID); err != nil { return err } diff --git a/web/app.js b/web/app.js index 9a24b48..25b07e7 100644 --- a/web/app.js +++ b/web/app.js @@ -1247,9 +1247,9 @@ async function loadEntities() { const params = buildListParams(0); - const results = await api.listEntities(params); - state.entities = results; - state.hasMore = results.length === PAGE_SIZE; + const resp = await api.listEntities(params); + state.entities = resp.data; + state.hasMore = (resp.offset + resp.data.length) < resp.total; state.selectedIndex = -1; renderEntityList(); renderDetailPane(); @@ -1258,9 +1258,9 @@ async function loadMore() { const params = buildListParams(state.entities.length); - const results = await api.listEntities(params); - state.entities = state.entities.concat(results); - state.hasMore = results.length === PAGE_SIZE; + const resp = await api.listEntities(params); + state.entities = state.entities.concat(resp.data); + state.hasMore = (resp.offset + resp.data.length) < resp.total; renderEntityList(); }