fix: code hardening from senior dev audit #40

Merged
lerko merged 6 commits from fix/audit-phase1-hardening into main 2026-05-21 01:04:31 +00:00
18 changed files with 267 additions and 228 deletions
Showing only changes of commit d715b053e7 - Show all commits
+4 -4
View File
@@ -19,19 +19,19 @@ func init() {
rootCmd.AddCommand(absorbCmd) rootCmd.AddCommand(absorbCmd)
} }
func runAbsorb(_ *cobra.Command, args []string) error { func runAbsorb(cmd *cobra.Command, args []string) error {
store, err := openStore() store, err := openStore()
if err != nil { if err != nil {
return err return err
} }
defer store.Close() defer store.Close()
targetID, err := store.Resolve(args[0]) targetID, err := store.Resolve(cmd.Context(), args[0])
if err != nil { if err != nil {
return fmt.Errorf("not_found — no entity with id %s", args[0]) return fmt.Errorf("not_found — no entity with id %s", args[0])
} }
sourceID, err := store.Resolve(args[1]) sourceID, err := store.Resolve(cmd.Context(), args[1])
if err != nil { if err != nil {
return fmt.Errorf("not_found — no entity with id %s", args[1]) return fmt.Errorf("not_found — no entity with id %s", args[1])
} }
@@ -40,7 +40,7 @@ func runAbsorb(_ *cobra.Command, args []string) error {
return fmt.Errorf("target and source must be different entities") return fmt.Errorf("target and source must be different entities")
} }
if err := store.Absorb(targetID, sourceID); err != nil { if err := store.Absorb(cmd.Context(), targetID, sourceID); err != nil {
if err == db.ErrTargetCrystallized { if err == db.ErrTargetCrystallized {
return fmt.Errorf("invalid_absorb — target %s is crystallized, demote first", return fmt.Errorf("invalid_absorb — target %s is crystallized, demote first",
display.FormatID(targetID)) display.FormatID(targetID))
+2 -2
View File
@@ -17,7 +17,7 @@ var addCmd = &cobra.Command{
RunE: runAdd, RunE: runAdd,
} }
func runAdd(_ *cobra.Command, args []string) error { func runAdd(cmd *cobra.Command, args []string) error {
input := strings.Join(args, " ") input := strings.Join(args, " ")
parsed, err := parse.Parse(input) parsed, err := parse.Parse(input)
@@ -47,7 +47,7 @@ func runAdd(_ *cobra.Command, args []string) error {
e.CardType = &ct e.CardType = &ct
} }
if err := store.Create(e); err != nil { if err := store.Create(cmd.Context(), e); err != nil {
return err return err
} }
+2 -2
View File
@@ -26,7 +26,7 @@ func init() {
rootCmd.AddCommand(cardsCmd) rootCmd.AddCommand(cardsCmd)
} }
func runCards(_ *cobra.Command, _ []string) error { func runCards(cmd *cobra.Command, _ []string) error {
store, err := openStore() store, err := openStore()
if err != nil { if err != nil {
return err return err
@@ -49,7 +49,7 @@ func runCards(_ *cobra.Command, _ []string) error {
p.CardTypeFilter = &ct p.CardTypeFilter = &ct
} }
entities, err := store.List(p) entities, err := store.List(cmd.Context(), p)
if err != nil { if err != nil {
return err return err
} }
+4 -4
View File
@@ -19,19 +19,19 @@ func init() {
rootCmd.AddCommand(copyCmd) rootCmd.AddCommand(copyCmd)
} }
func runCopy(_ *cobra.Command, args []string) error { func runCopy(cmd *cobra.Command, args []string) error {
store, err := openStore() store, err := openStore()
if err != nil { if err != nil {
return err return err
} }
defer store.Close() defer store.Close()
id, err := store.Resolve(args[0]) id, err := store.Resolve(cmd.Context(), args[0])
if err != nil { if err != nil {
return fmt.Errorf("not_found — no entity with id %s", args[0]) return fmt.Errorf("not_found — no entity with id %s", args[0])
} }
e, err := store.Get(id) e, err := store.Get(cmd.Context(), id)
if err != nil { if err != nil {
return err return err
} }
@@ -40,7 +40,7 @@ func runCopy(_ *cobra.Command, args []string) error {
return fmt.Errorf("clipboard: %w", err) return fmt.Errorf("clipboard: %w", err)
} }
if err := store.IncrementUse(id); err != nil { if err := store.IncrementUse(cmd.Context(), id); err != nil {
return err return err
} }
+3 -3
View File
@@ -19,19 +19,19 @@ func init() {
rootCmd.AddCommand(deleteCmd) rootCmd.AddCommand(deleteCmd)
} }
func runDelete(_ *cobra.Command, args []string) error { func runDelete(cmd *cobra.Command, args []string) error {
store, err := openStore() store, err := openStore()
if err != nil { if err != nil {
return err return err
} }
defer store.Close() defer store.Close()
id, err := store.Resolve(args[0]) id, err := store.Resolve(cmd.Context(), args[0])
if err != nil { if err != nil {
return fmt.Errorf("not_found — no entity with id %s", args[0]) return fmt.Errorf("not_found — no entity with id %s", args[0])
} }
result, err := store.SoftDelete(id) result, err := store.SoftDelete(cmd.Context(), id)
if err != nil { if err != nil {
return err return err
} }
+7 -6
View File
@@ -1,6 +1,7 @@
package cmd package cmd
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"os" "os"
@@ -35,7 +36,7 @@ type demoEntity struct {
Tags []string `json:"tags"` Tags []string `json:"tags"`
} }
func runDemo(_ *cobra.Command, _ []string) error { func runDemo(cmd *cobra.Command, _ []string) error {
tmpDir, err := os.MkdirTemp("", "nib-demo-*") tmpDir, err := os.MkdirTemp("", "nib-demo-*")
if err != nil { if err != nil {
return err return err
@@ -48,7 +49,7 @@ func runDemo(_ *cobra.Command, _ []string) error {
return err return err
} }
if err := seedDemo(store); err != nil { if err := seedDemo(cmd.Context(), store); err != nil {
store.Close() store.Close()
return fmt.Errorf("seed demo data: %w", err) return fmt.Errorf("seed demo data: %w", err)
} }
@@ -58,7 +59,7 @@ func runDemo(_ *cobra.Command, _ []string) error {
return runServe(nil, nil) return runServe(nil, nil)
} }
func seedDemo(store *db.Store) error { func seedDemo(ctx context.Context, store *db.Store) error {
data, err := findDemoFile() data, err := findDemoFile()
if err != nil { if err != nil {
return err return err
@@ -94,19 +95,19 @@ func seedDemo(store *db.Store) error {
e.CompletedAt = &t e.CompletedAt = &t
} }
if err := store.Create(e); err != nil { if err := store.Create(ctx, e); err != nil {
return fmt.Errorf("entity %d: %w", i, err) return fmt.Errorf("entity %d: %w", i, err)
} }
if entry.CardType != nil { if entry.CardType != nil {
ct := db.CardType(*entry.CardType) ct := db.CardType(*entry.CardType)
if err := store.Promote(e.ID, ct, entry.CardData); err != nil { if err := store.Promote(ctx, e.ID, ct, entry.CardData); err != nil {
return fmt.Errorf("promote entity %d: %w", i, err) return fmt.Errorf("promote entity %d: %w", i, err)
} }
} }
if entry.Deleted { if entry.Deleted {
store.SoftDelete(e.ID) store.SoftDelete(ctx, e.ID)
} }
} }
+3 -3
View File
@@ -19,19 +19,19 @@ func init() {
rootCmd.AddCommand(demoteCmd) rootCmd.AddCommand(demoteCmd)
} }
func runDemote(_ *cobra.Command, args []string) error { func runDemote(cmd *cobra.Command, args []string) error {
store, err := openStore() store, err := openStore()
if err != nil { if err != nil {
return err return err
} }
defer store.Close() defer store.Close()
id, err := store.Resolve(args[0]) id, err := store.Resolve(cmd.Context(), args[0])
if err != nil { if err != nil {
return fmt.Errorf("not_found — no entity with id %s", args[0]) return fmt.Errorf("not_found — no entity with id %s", args[0])
} }
if err := store.Demote(id); err != nil { if err := store.Demote(cmd.Context(), id); err != nil {
if err == db.ErrAlreadyFluid { if err == db.ErrAlreadyFluid {
return fmt.Errorf("invalid_demote — entity %s is already fluid", display.FormatID(id)) return fmt.Errorf("invalid_demote — entity %s is already fluid", display.FormatID(id))
} }
+9 -9
View File
@@ -21,19 +21,19 @@ func init() {
rootCmd.AddCommand(editCmd) rootCmd.AddCommand(editCmd)
} }
func runEdit(_ *cobra.Command, args []string) error { func runEdit(cmd *cobra.Command, args []string) error {
store, err := openStore() store, err := openStore()
if err != nil { if err != nil {
return err return err
} }
defer store.Close() defer store.Close()
id, err := store.Resolve(args[0]) id, err := store.Resolve(cmd.Context(), args[0])
if err != nil { if err != nil {
return fmt.Errorf("not_found — no entity with id %s", args[0]) return fmt.Errorf("not_found — no entity with id %s", args[0])
} }
e, err := store.Get(id) e, err := store.Get(cmd.Context(), id)
if err != nil { if err != nil {
return err return err
} }
@@ -55,11 +55,11 @@ func runEdit(_ *cobra.Command, args []string) error {
editor = "vi" editor = "vi"
} }
cmd := exec.Command(editor, tmpfile.Name()) editorCmd := exec.Command(editor, tmpfile.Name())
cmd.Stdin = os.Stdin editorCmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout editorCmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr editorCmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil { if err := editorCmd.Run(); err != nil {
return fmt.Errorf("editor: %w", err) return fmt.Errorf("editor: %w", err)
} }
@@ -74,7 +74,7 @@ func runEdit(_ *cobra.Command, args []string) error {
return nil return nil
} }
if err := store.Update(id, &db.EntityUpdate{Body: &body}); err != nil { if err := store.Update(cmd.Context(), id, &db.EntityUpdate{Body: &body}); err != nil {
return err return err
} }
+2 -2
View File
@@ -36,7 +36,7 @@ func init() {
lsCmd.Flags().BoolVar(&lsAll, "all", false, "include deleted entities") lsCmd.Flags().BoolVar(&lsAll, "all", false, "include deleted entities")
} }
func runLs(_ *cobra.Command, _ []string) error { func runLs(cmd *cobra.Command, _ []string) error {
store, err := openStore() store, err := openStore()
if err != nil { if err != nil {
return err return err
@@ -88,7 +88,7 @@ func runLs(_ *cobra.Command, _ []string) error {
p.Since = &since p.Since = &since
} }
entities, err := store.List(p) entities, err := store.List(cmd.Context(), p)
if err != nil { if err != nil {
return err return err
} }
+4 -4
View File
@@ -20,14 +20,14 @@ func init() {
rootCmd.AddCommand(promoteCmd) rootCmd.AddCommand(promoteCmd)
} }
func runPromote(_ *cobra.Command, args []string) error { func runPromote(cmd *cobra.Command, args []string) error {
store, err := openStore() store, err := openStore()
if err != nil { if err != nil {
return err return err
} }
defer store.Close() defer store.Close()
id, err := store.Resolve(args[0]) id, err := store.Resolve(cmd.Context(), args[0])
if err != nil { if err != nil {
return fmt.Errorf("not_found — no entity with id %s", args[0]) return fmt.Errorf("not_found — no entity with id %s", args[0])
} }
@@ -40,14 +40,14 @@ func runPromote(_ *cobra.Command, args []string) error {
cardType = db.CardType(args[1]) cardType = db.CardType(args[1])
} }
e, err := store.Get(id) e, err := store.Get(cmd.Context(), id)
if err != nil { if err != nil {
return err return err
} }
cd := carddata.GenerateCardData(cardType, e.Body) cd := carddata.GenerateCardData(cardType, e.Body)
if err := store.Promote(id, cardType, cd); err != nil { if err := store.Promote(cmd.Context(), id, cardType, cd); err != nil {
if err == db.ErrAlreadyPromoted { if err == db.ErrAlreadyPromoted {
return fmt.Errorf("invalid_promote — entity %s is already a %s", return fmt.Errorf("invalid_promote — entity %s is already a %s",
display.FormatID(id), *e.CardType) display.FormatID(id), *e.CardType)
+15 -15
View File
@@ -109,13 +109,13 @@ func listEntities(store *db.Store) http.HandlerFunc {
p.Limit = 50 p.Limit = 50
} }
total, err := store.Count(p) total, err := store.Count(r.Context(), p)
if err != nil { if err != nil {
writeInternalError(w, err) writeInternalError(w, err)
return return
} }
entities, err := store.List(p) entities, err := store.List(r.Context(), p)
if err != nil { if err != nil {
writeInternalError(w, err) writeInternalError(w, err)
return return
@@ -177,7 +177,7 @@ func createEntity(store *db.Store) http.HandlerFunc {
e.CardData = req.CardData e.CardData = req.CardData
} }
if err := store.Create(e); err != nil { if err := store.Create(r.Context(), e); err != nil {
if err == db.ErrInvalidCardData { if err == db.ErrInvalidCardData {
writeError(w, http.StatusBadRequest, "invalid_card_data", "card_data must be valid JSON") writeError(w, http.StatusBadRequest, "invalid_card_data", "card_data must be valid JSON")
return return
@@ -193,7 +193,7 @@ func createEntity(store *db.Store) http.HandlerFunc {
func getEntity(store *db.Store) http.HandlerFunc { func getEntity(store *db.Store) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id") id := chi.URLParam(r, "id")
e, err := store.Get(id) e, err := store.Get(r.Context(), id)
if err != nil { if err != nil {
if err == db.ErrNotFound { if err == db.ErrNotFound {
writeError(w, http.StatusNotFound, "not_found", "no entity with id "+id) writeError(w, http.StatusNotFound, "not_found", "no entity with id "+id)
@@ -243,7 +243,7 @@ func updateEntity(store *db.Store) http.HandlerFunc {
u.CardType = &ct u.CardType = &ct
} }
if err := store.Update(id, u); err != nil { if err := store.Update(r.Context(), id, u); err != nil {
if err == db.ErrNotFound { if err == db.ErrNotFound {
writeError(w, http.StatusNotFound, "not_found", "no entity with id "+id) writeError(w, http.StatusNotFound, "not_found", "no entity with id "+id)
return return
@@ -256,7 +256,7 @@ func updateEntity(store *db.Store) http.HandlerFunc {
return return
} }
e, err := store.Get(id) e, err := store.Get(r.Context(), id)
if err != nil { if err != nil {
writeInternalError(w, err) writeInternalError(w, err)
return return
@@ -272,7 +272,7 @@ type DeleteResponse struct {
func deleteEntity(store *db.Store) http.HandlerFunc { func deleteEntity(store *db.Store) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id") id := chi.URLParam(r, "id")
result, err := store.SoftDelete(id) result, err := store.SoftDelete(r.Context(), id)
if err != nil { if err != nil {
if err == db.ErrNotFound { if err == db.ErrNotFound {
writeError(w, http.StatusNotFound, "not_found", "no entity with id "+id) writeError(w, http.StatusNotFound, "not_found", "no entity with id "+id)
@@ -307,7 +307,7 @@ func promoteEntity(store *db.Store) http.HandlerFunc {
return return
} }
if err := store.Promote(id, db.CardType(req.CardType), req.CardData); err != nil { if err := store.Promote(r.Context(), id, db.CardType(req.CardType), req.CardData); err != nil {
if err == db.ErrNotFound { if err == db.ErrNotFound {
writeError(w, http.StatusNotFound, "not_found", "no entity with id "+id) writeError(w, http.StatusNotFound, "not_found", "no entity with id "+id)
return return
@@ -324,7 +324,7 @@ func promoteEntity(store *db.Store) http.HandlerFunc {
return return
} }
e, err := store.Get(id) e, err := store.Get(r.Context(), id)
if err != nil { if err != nil {
writeInternalError(w, err) writeInternalError(w, err)
return return
@@ -337,7 +337,7 @@ func demoteEntity(store *db.Store) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id") id := chi.URLParam(r, "id")
if err := store.Demote(id); err != nil { if err := store.Demote(r.Context(), id); err != nil {
if err == db.ErrNotFound { if err == db.ErrNotFound {
writeError(w, http.StatusNotFound, "not_found", "no entity with id "+id) writeError(w, http.StatusNotFound, "not_found", "no entity with id "+id)
return return
@@ -350,7 +350,7 @@ func demoteEntity(store *db.Store) http.HandlerFunc {
return return
} }
e, err := store.Get(id) e, err := store.Get(r.Context(), id)
if err != nil { if err != nil {
writeInternalError(w, err) writeInternalError(w, err)
return return
@@ -381,7 +381,7 @@ func absorbEntity(store *db.Store) http.HandlerFunc {
return return
} }
if err := store.Absorb(id, req.SourceID); err != nil { if err := store.Absorb(r.Context(), id, req.SourceID); err != nil {
if err == db.ErrNotFound { if err == db.ErrNotFound {
writeError(w, http.StatusNotFound, "not_found", "target or source entity not found") writeError(w, http.StatusNotFound, "not_found", "target or source entity not found")
return return
@@ -394,7 +394,7 @@ func absorbEntity(store *db.Store) http.HandlerFunc {
return return
} }
e, err := store.Get(id) e, err := store.Get(r.Context(), id)
if err != nil { if err != nil {
writeInternalError(w, err) writeInternalError(w, err)
return return
@@ -407,7 +407,7 @@ func useEntity(store *db.Store) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id") id := chi.URLParam(r, "id")
if err := store.IncrementUse(id); err != nil { if err := store.IncrementUse(r.Context(), id); err != nil {
if err == db.ErrNotFound { if err == db.ErrNotFound {
writeError(w, http.StatusNotFound, "not_found", "no entity with id "+id) writeError(w, http.StatusNotFound, "not_found", "no entity with id "+id)
return return
@@ -416,7 +416,7 @@ func useEntity(store *db.Store) http.HandlerFunc {
return return
} }
e, err := store.Get(id) e, err := store.Get(r.Context(), id)
if err != nil { if err != nil {
writeInternalError(w, err) writeInternalError(w, err)
return return
+1 -1
View File
@@ -14,7 +14,7 @@ type TagResponse struct {
func listTags(store *db.Store) http.HandlerFunc { func listTags(store *db.Store) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
cardsOnly := r.URL.Query().Get("cards_only") == "true" cardsOnly := r.URL.Query().Get("cards_only") == "true"
tags, err := store.ListTags(cardsOnly) tags, err := store.ListTags(r.Context(), cardsOnly)
if err != nil { if err != nil {
writeInternalError(w, err) writeInternalError(w, err)
return return
+47 -48
View File
@@ -1,6 +1,7 @@
package db package db
import ( import (
"context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
@@ -104,7 +105,7 @@ type EntityUpdate struct {
Tags *[]string Tags *[]string
} }
func (s *Store) Create(e *Entity) error { func (s *Store) Create(ctx context.Context, e *Entity) error {
if e.CardData != nil && !json.Valid([]byte(*e.CardData)) { if e.CardData != nil && !json.Valid([]byte(*e.CardData)) {
return ErrInvalidCardData return ErrInvalidCardData
} }
@@ -116,13 +117,13 @@ func (s *Store) Create(e *Entity) error {
e.Tags = []string{} e.Tags = []string{}
} }
tx, err := s.db.Begin() tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return err
} }
defer tx.Rollback() defer tx.Rollback()
_, err = tx.Exec(` _, err = tx.ExecContext(ctx, `
INSERT INTO entities (id, created_at, modified_at, body, title, description, INSERT INTO entities (id, created_at, modified_at, body, title, description,
glyph, time_anchor, completed_at, pinned, deleted_at, glyph, time_anchor, completed_at, pinned, deleted_at,
card_type, card_data, use_count, last_used_at) card_type, card_data, use_count, last_used_at)
@@ -147,18 +148,18 @@ func (s *Store) Create(e *Entity) error {
return err return err
} }
if err := insertTags(tx, e.ID, e.Tags); err != nil { if err := insertTags(ctx, tx, e.ID, e.Tags); err != nil {
return err return err
} }
return tx.Commit() return tx.Commit()
} }
func (s *Store) Get(id string) (*Entity, error) { func (s *Store) Get(ctx context.Context, id string) (*Entity, error) {
e := &Entity{} e := &Entity{}
row := newEntityRow() row := newEntityRow()
err := s.db.QueryRow(` err := s.db.QueryRowContext(ctx, `
SELECT id, created_at, modified_at, body, title, description, SELECT id, created_at, modified_at, body, title, description,
glyph, time_anchor, completed_at, pinned, deleted_at, glyph, time_anchor, completed_at, pinned, deleted_at,
card_type, card_data, use_count, last_used_at card_type, card_data, use_count, last_used_at
@@ -174,7 +175,7 @@ func (s *Store) Get(id string) (*Entity, error) {
return nil, fmt.Errorf("scan entity %s: %w", id, err) return nil, fmt.Errorf("scan entity %s: %w", id, err)
} }
tags, err := s.loadTags(id) tags, err := s.loadTags(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -229,15 +230,15 @@ func listWhere(params ListParams) (string, []any) {
return clause, args return clause, args
} }
func (s *Store) Count(params ListParams) (int, error) { func (s *Store) Count(ctx context.Context, params ListParams) (int, error) {
whereClause, args := listWhere(params) whereClause, args := listWhere(params)
query := fmt.Sprintf("SELECT COUNT(*) FROM entities e %s", whereClause) query := fmt.Sprintf("SELECT COUNT(*) FROM entities e %s", whereClause)
var count int var count int
err := s.db.QueryRow(query, args...).Scan(&count) err := s.db.QueryRowContext(ctx, query, args...).Scan(&count)
return count, err return count, err
} }
func (s *Store) List(params ListParams) ([]*Entity, error) { func (s *Store) List(ctx context.Context, params ListParams) ([]*Entity, error) {
whereClause, args := listWhere(params) whereClause, args := listWhere(params)
orderCol := "e.created_at" orderCol := "e.created_at"
@@ -275,7 +276,7 @@ func (s *Store) List(params ListParams) ([]*Entity, error) {
args = append(args, limit, params.Offset) args = append(args, limit, params.Offset)
rows, err := s.db.Query(query, args...) rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -297,20 +298,20 @@ func (s *Store) List(params ListParams) ([]*Entity, error) {
return nil, err return nil, err
} }
if err := s.batchLoadTags(entities); err != nil { if err := s.batchLoadTags(ctx, entities); err != nil {
return nil, err return nil, err
} }
return entities, nil return entities, nil
} }
func (s *Store) Update(id string, u *EntityUpdate) error { func (s *Store) Update(ctx context.Context, id string, u *EntityUpdate) error {
existing, err := s.Get(id) existing, err := s.Get(ctx, id)
if err != nil { if err != nil {
return err return err
} }
tx, err := s.db.Begin() tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return err
} }
@@ -369,15 +370,15 @@ func (s *Store) Update(id string, u *EntityUpdate) error {
args = append(args, existing.ID) args = append(args, existing.ID)
query := fmt.Sprintf("UPDATE entities SET %s WHERE id = ?", strings.Join(sets, ", ")) query := fmt.Sprintf("UPDATE entities SET %s WHERE id = ?", strings.Join(sets, ", "))
if _, err := tx.Exec(query, args...); err != nil { if _, err := tx.ExecContext(ctx, query, args...); err != nil {
return err return err
} }
if u.Tags != nil { if u.Tags != nil {
if _, err := tx.Exec("DELETE FROM entity_tags WHERE entity_id = ?", existing.ID); err != nil { if _, err := tx.ExecContext(ctx, "DELETE FROM entity_tags WHERE entity_id = ?", existing.ID); err != nil {
return err return err
} }
if err := insertTags(tx, existing.ID, *u.Tags); err != nil { if err := insertTags(ctx, tx, existing.ID, *u.Tags); err != nil {
return err return err
} }
} }
@@ -385,8 +386,8 @@ func (s *Store) Update(id string, u *EntityUpdate) error {
return tx.Commit() return tx.Commit()
} }
func (s *Store) Promote(id string, cardType CardType, cardData *string) error { func (s *Store) Promote(ctx context.Context, id string, cardType CardType, cardData *string) error {
e, err := s.Get(id) e, err := s.Get(ctx, id)
if err != nil { if err != nil {
return err return err
} }
@@ -402,15 +403,15 @@ func (s *Store) Promote(id string, cardType CardType, cardData *string) error {
dataVal = *cardData dataVal = *cardData
} }
_, err = s.db.Exec(` _, err = s.db.ExecContext(ctx, `
UPDATE entities SET card_type = ?, card_data = ?, modified_at = ? UPDATE entities SET card_type = ?, card_data = ?, modified_at = ?
WHERE id = ?`, WHERE id = ?`,
string(cardType), dataVal, time.Now().UTC().Format(time.RFC3339), id) string(cardType), dataVal, time.Now().UTC().Format(time.RFC3339), id)
return err return err
} }
func (s *Store) Demote(id string) error { func (s *Store) Demote(ctx context.Context, id string) error {
e, err := s.Get(id) e, err := s.Get(ctx, id)
if err != nil { if err != nil {
return err return err
} }
@@ -418,7 +419,7 @@ func (s *Store) Demote(id string) error {
return ErrAlreadyFluid return ErrAlreadyFluid
} }
_, err = s.db.Exec(` _, err = s.db.ExecContext(ctx, `
UPDATE entities SET card_type = NULL, card_data = NULL, UPDATE entities SET card_type = NULL, card_data = NULL,
use_count = 0, last_used_at = NULL, modified_at = ? use_count = 0, last_used_at = NULL, modified_at = ?
WHERE id = ?`, WHERE id = ?`,
@@ -433,9 +434,9 @@ const (
DeletedHard DeletedHard
) )
func (s *Store) SoftDelete(id string) (DeleteResult, error) { func (s *Store) SoftDelete(ctx context.Context, id string) (DeleteResult, error) {
var deletedAt sql.NullString var deletedAt sql.NullString
err := s.db.QueryRow("SELECT deleted_at FROM entities WHERE id = ?", id).Scan(&deletedAt) err := s.db.QueryRowContext(ctx, "SELECT deleted_at FROM entities WHERE id = ?", id).Scan(&deletedAt)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return 0, ErrNotFound return 0, ErrNotFound
} }
@@ -444,21 +445,21 @@ func (s *Store) SoftDelete(id string) (DeleteResult, error) {
} }
if deletedAt.Valid { if deletedAt.Valid {
_, err = s.db.Exec("DELETE FROM entities WHERE id = ?", id) _, err = s.db.ExecContext(ctx, "DELETE FROM entities WHERE id = ?", id)
return DeletedHard, err return DeletedHard, err
} }
_, err = s.db.Exec("UPDATE entities SET deleted_at = ? WHERE id = ?", _, err = s.db.ExecContext(ctx, "UPDATE entities SET deleted_at = ? WHERE id = ?",
time.Now().UTC().Format(time.RFC3339), id) time.Now().UTC().Format(time.RFC3339), id)
return DeletedSoft, err return DeletedSoft, err
} }
func (s *Store) Absorb(targetID, sourceID string) error { func (s *Store) Absorb(ctx context.Context, targetID, sourceID string) error {
target, err := s.Get(targetID) target, err := s.Get(ctx, targetID)
if err != nil { if err != nil {
return err return err
} }
source, err := s.Get(sourceID) source, err := s.Get(ctx, sourceID)
if err != nil { if err != nil {
return err return err
} }
@@ -467,7 +468,7 @@ func (s *Store) Absorb(targetID, sourceID string) error {
return ErrTargetCrystallized return ErrTargetCrystallized
} }
tx, err := s.db.Begin() tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return err
} }
@@ -476,7 +477,7 @@ func (s *Store) Absorb(targetID, sourceID string) error {
now := time.Now().UTC().Format(time.RFC3339) now := time.Now().UTC().Format(time.RFC3339)
merged := target.Body + "\n" + source.Body merged := target.Body + "\n" + source.Body
if _, err := tx.Exec("UPDATE entities SET body = ?, modified_at = ? WHERE id = ?", if _, err := tx.ExecContext(ctx, "UPDATE entities SET body = ?, modified_at = ? WHERE id = ?",
merged, now, targetID); err != nil { merged, now, targetID); err != nil {
return err return err
} }
@@ -487,7 +488,7 @@ func (s *Store) Absorb(targetID, sourceID string) error {
} }
for _, t := range source.Tags { for _, t := range source.Tags {
if !seen[t] { if !seen[t] {
if _, err := tx.Exec("INSERT OR IGNORE INTO entity_tags (entity_id, tag) VALUES (?, ?)", if _, err := tx.ExecContext(ctx, "INSERT OR IGNORE INTO entity_tags (entity_id, tag) VALUES (?, ?)",
targetID, t); err != nil { targetID, t); err != nil {
return err return err
} }
@@ -495,7 +496,7 @@ func (s *Store) Absorb(targetID, sourceID string) error {
} }
if source.CardType != nil { if source.CardType != nil {
if _, err := tx.Exec(`UPDATE entities SET card_type = NULL, card_data = NULL, if _, err := tx.ExecContext(ctx, `UPDATE entities SET card_type = NULL, card_data = NULL,
use_count = 0, last_used_at = NULL, modified_at = ? WHERE id = ?`, use_count = 0, last_used_at = NULL, modified_at = ? WHERE id = ?`,
now, sourceID); err != nil { now, sourceID); err != nil {
return err return err
@@ -503,7 +504,7 @@ func (s *Store) Absorb(targetID, sourceID string) error {
} }
absorbNote := source.Body + "\n\n[absorbed into " + targetID + "]" absorbNote := source.Body + "\n\n[absorbed into " + targetID + "]"
if _, err := tx.Exec("UPDATE entities SET body = ?, deleted_at = ?, modified_at = ? WHERE id = ?", if _, err := tx.ExecContext(ctx, "UPDATE entities SET body = ?, deleted_at = ?, modified_at = ? WHERE id = ?",
absorbNote, now, now, sourceID); err != nil { absorbNote, now, now, sourceID); err != nil {
return err return err
} }
@@ -511,8 +512,8 @@ func (s *Store) Absorb(targetID, sourceID string) error {
return tx.Commit() return tx.Commit()
} }
func (s *Store) IncrementUse(id string) error { func (s *Store) IncrementUse(ctx context.Context, id string) error {
res, err := s.db.Exec(` res, err := s.db.ExecContext(ctx, `
UPDATE entities SET use_count = use_count + 1, last_used_at = ? UPDATE entities SET use_count = use_count + 1, last_used_at = ?
WHERE id = ?`, WHERE id = ?`,
time.Now().UTC().Format(time.RFC3339), id) time.Now().UTC().Format(time.RFC3339), id)
@@ -526,8 +527,8 @@ func (s *Store) IncrementUse(id string) error {
return nil return nil
} }
func (s *Store) Resolve(prefix string) (string, error) { func (s *Store) Resolve(ctx context.Context, prefix string) (string, error) {
rows, err := s.db.Query("SELECT id FROM entities WHERE id LIKE ?", prefix+"%") rows, err := s.db.QueryContext(ctx, "SELECT id FROM entities WHERE id LIKE ?", prefix+"%")
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -593,9 +594,7 @@ func (r *entityRow) apply(e *Entity) error {
return nil return nil
} }
// helpers func (s *Store) batchLoadTags(ctx context.Context, entities []*Entity) error {
func (s *Store) batchLoadTags(entities []*Entity) error {
if len(entities) == 0 { if len(entities) == 0 {
return nil return nil
} }
@@ -615,7 +614,7 @@ func (s *Store) batchLoadTags(entities []*Entity) error {
strings.Join(placeholders, ","), strings.Join(placeholders, ","),
) )
rows, err := s.db.Query(query, args...) rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return err return err
} }
@@ -633,8 +632,8 @@ func (s *Store) batchLoadTags(entities []*Entity) error {
return rows.Err() return rows.Err()
} }
func (s *Store) loadTags(entityID string) ([]string, error) { func (s *Store) loadTags(ctx context.Context, entityID string) ([]string, error) {
rows, err := s.db.Query("SELECT tag FROM entity_tags WHERE entity_id = ? ORDER BY tag", entityID) rows, err := s.db.QueryContext(ctx, "SELECT tag FROM entity_tags WHERE entity_id = ? ORDER BY tag", entityID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -657,9 +656,9 @@ func (s *Store) loadTags(entityID string) ([]string, error) {
return tags, nil return tags, nil
} }
func insertTags(tx *sql.Tx, entityID string, tags []string) error { func insertTags(ctx context.Context, tx *sql.Tx, entityID string, tags []string) error {
for _, tag := range tags { for _, tag := range tags {
if _, err := tx.Exec("INSERT OR IGNORE INTO entity_tags (entity_id, tag) VALUES (?, ?)", if _, err := tx.ExecContext(ctx, "INSERT OR IGNORE INTO entity_tags (entity_id, tag) VALUES (?, ?)",
entityID, tag); err != nil { entityID, tag); err != nil {
return err return err
} }
+116 -87
View File
@@ -1,6 +1,7 @@
package db package db
import ( import (
"context"
"testing" "testing"
"time" "time"
) )
@@ -11,15 +12,16 @@ func ptr[T any](v T) *T {
func TestCreate_Note(t *testing.T) { func TestCreate_Note(t *testing.T) {
s := testStore(t) s := testStore(t)
ctx := context.Background()
e := &Entity{Body: "hello world", Glyph: GlyphNote} e := &Entity{Body: "hello world", Glyph: GlyphNote}
if err := s.Create(e); err != nil { if err := s.Create(ctx, e); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if e.ID == "" { if e.ID == "" {
t.Fatal("ID not set") t.Fatal("ID not set")
} }
got, err := s.Get(e.ID) got, err := s.Get(ctx, e.ID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -33,12 +35,13 @@ func TestCreate_Note(t *testing.T) {
func TestCreate_TodoWithTimeAnchor(t *testing.T) { func TestCreate_TodoWithTimeAnchor(t *testing.T) {
s := testStore(t) s := testStore(t)
ctx := context.Background()
e := &Entity{Body: "deploy", Glyph: GlyphTodo, TimeAnchor: ptr("14:00")} e := &Entity{Body: "deploy", Glyph: GlyphTodo, TimeAnchor: ptr("14:00")}
if err := s.Create(e); err != nil { if err := s.Create(ctx, e); err != nil {
t.Fatal(err) t.Fatal(err)
} }
got, err := s.Get(e.ID) got, err := s.Get(ctx, e.ID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -49,12 +52,13 @@ func TestCreate_TodoWithTimeAnchor(t *testing.T) {
func TestCreate_WithTags(t *testing.T) { func TestCreate_WithTags(t *testing.T) {
s := testStore(t) s := testStore(t)
ctx := context.Background()
e := &Entity{Body: "deploy nginx", Glyph: GlyphNote, Tags: []string{"ops", "nginx"}} e := &Entity{Body: "deploy nginx", Glyph: GlyphNote, Tags: []string{"ops", "nginx"}}
if err := s.Create(e); err != nil { if err := s.Create(ctx, e); err != nil {
t.Fatal(err) t.Fatal(err)
} }
got, err := s.Get(e.ID) got, err := s.Get(ctx, e.ID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -65,13 +69,14 @@ func TestCreate_WithTags(t *testing.T) {
func TestCreate_WithCardType(t *testing.T) { func TestCreate_WithCardType(t *testing.T) {
s := testStore(t) s := testStore(t)
ctx := context.Background()
ct := CardSnippet ct := CardSnippet
e := &Entity{Body: "trick", Glyph: GlyphNote, CardType: &ct} e := &Entity{Body: "trick", Glyph: GlyphNote, CardType: &ct}
if err := s.Create(e); err != nil { if err := s.Create(ctx, e); err != nil {
t.Fatal(err) t.Fatal(err)
} }
got, err := s.Get(e.ID) got, err := s.Get(ctx, e.ID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -82,7 +87,7 @@ func TestCreate_WithCardType(t *testing.T) {
func TestGet_NotFound(t *testing.T) { func TestGet_NotFound(t *testing.T) {
s := testStore(t) s := testStore(t)
_, err := s.Get("01NONEXISTENT0000000000000") _, err := s.Get(context.Background(), "01NONEXISTENT0000000000000")
if err != ErrNotFound { if err != ErrNotFound {
t.Errorf("expected ErrNotFound, got %v", err) t.Errorf("expected ErrNotFound, got %v", err)
} }
@@ -90,11 +95,12 @@ func TestGet_NotFound(t *testing.T) {
func TestList_DefaultParams(t *testing.T) { func TestList_DefaultParams(t *testing.T) {
s := testStore(t) s := testStore(t)
ctx := context.Background()
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
s.Create(&Entity{Body: "note", Glyph: GlyphNote}) s.Create(ctx, &Entity{Body: "note", Glyph: GlyphNote})
} }
entities, err := s.List(DefaultListParams()) entities, err := s.List(ctx, DefaultListParams())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -109,15 +115,16 @@ func TestList_DefaultParams(t *testing.T) {
func TestList_FilterByTag(t *testing.T) { func TestList_FilterByTag(t *testing.T) {
s := testStore(t) s := testStore(t)
s.Create(&Entity{Body: "a", Glyph: GlyphNote, Tags: []string{"ops"}}) ctx := context.Background()
s.Create(&Entity{Body: "b", Glyph: GlyphNote, Tags: []string{"home"}}) s.Create(ctx, &Entity{Body: "a", Glyph: GlyphNote, Tags: []string{"ops"}})
s.Create(&Entity{Body: "c", Glyph: GlyphNote, Tags: []string{"ops", "home"}}) s.Create(ctx, &Entity{Body: "b", Glyph: GlyphNote, Tags: []string{"home"}})
s.Create(ctx, &Entity{Body: "c", Glyph: GlyphNote, Tags: []string{"ops", "home"}})
p := DefaultListParams() p := DefaultListParams()
tag := "ops" tag := "ops"
p.Tag = &tag p.Tag = &tag
entities, err := s.List(p) entities, err := s.List(ctx, p)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -128,13 +135,14 @@ func TestList_FilterByTag(t *testing.T) {
func TestList_FilterByDate(t *testing.T) { func TestList_FilterByDate(t *testing.T) {
s := testStore(t) s := testStore(t)
s.Create(&Entity{Body: "today", Glyph: GlyphNote}) ctx := context.Background()
s.Create(ctx, &Entity{Body: "today", Glyph: GlyphNote})
p := DefaultListParams() p := DefaultListParams()
date := time.Now().UTC().Format("2006-01-02") date := time.Now().UTC().Format("2006-01-02")
p.Date = &date p.Date = &date
entities, err := s.List(p) entities, err := s.List(ctx, p)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -144,7 +152,7 @@ func TestList_FilterByDate(t *testing.T) {
otherDate := "2020-01-01" otherDate := "2020-01-01"
p.Date = &otherDate p.Date = &otherDate
entities, err = s.List(p) entities, err = s.List(ctx, p)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -155,13 +163,14 @@ func TestList_FilterByDate(t *testing.T) {
func TestList_CardsOnly(t *testing.T) { func TestList_CardsOnly(t *testing.T) {
s := testStore(t) s := testStore(t)
s.Create(&Entity{Body: "fluid", Glyph: GlyphNote}) ctx := context.Background()
s.Create(ctx, &Entity{Body: "fluid", Glyph: GlyphNote})
ct := CardSnippet ct := CardSnippet
s.Create(&Entity{Body: "card", Glyph: GlyphNote, CardType: &ct}) s.Create(ctx, &Entity{Body: "card", Glyph: GlyphNote, CardType: &ct})
p := DefaultListParams() p := DefaultListParams()
p.CardsOnly = true p.CardsOnly = true
entities, err := s.List(p) entities, err := s.List(ctx, p)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -175,12 +184,13 @@ func TestList_CardsOnly(t *testing.T) {
func TestList_IncludeDeleted(t *testing.T) { func TestList_IncludeDeleted(t *testing.T) {
s := testStore(t) s := testStore(t)
ctx := context.Background()
e := &Entity{Body: "doomed", Glyph: GlyphNote} e := &Entity{Body: "doomed", Glyph: GlyphNote}
s.Create(e) s.Create(ctx, e)
s.SoftDelete(e.ID) s.SoftDelete(ctx, e.ID)
p := DefaultListParams() p := DefaultListParams()
entities, err := s.List(p) entities, err := s.List(ctx, p)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -189,7 +199,7 @@ func TestList_IncludeDeleted(t *testing.T) {
} }
p.IncludeDeleted = true p.IncludeDeleted = true
entities, err = s.List(p) entities, err = s.List(ctx, p)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -200,17 +210,18 @@ func TestList_IncludeDeleted(t *testing.T) {
func TestList_SortByUseCount(t *testing.T) { func TestList_SortByUseCount(t *testing.T) {
s := testStore(t) s := testStore(t)
ctx := context.Background()
ct := CardSnippet ct := CardSnippet
e1 := &Entity{Body: "low", Glyph: GlyphNote, CardType: &ct} e1 := &Entity{Body: "low", Glyph: GlyphNote, CardType: &ct}
e2 := &Entity{Body: "high", Glyph: GlyphNote, CardType: &ct} e2 := &Entity{Body: "high", Glyph: GlyphNote, CardType: &ct}
s.Create(e1) s.Create(ctx, e1)
s.Create(e2) s.Create(ctx, e2)
s.IncrementUse(e2.ID) s.IncrementUse(ctx, e2.ID)
s.IncrementUse(e2.ID) s.IncrementUse(ctx, e2.ID)
p := DefaultListParams() p := DefaultListParams()
p.Sort = "use_count" p.Sort = "use_count"
entities, err := s.List(p) entities, err := s.List(ctx, p)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -221,14 +232,15 @@ func TestList_SortByUseCount(t *testing.T) {
func TestList_Pagination(t *testing.T) { func TestList_Pagination(t *testing.T) {
s := testStore(t) s := testStore(t)
ctx := context.Background()
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
s.Create(&Entity{Body: "note", Glyph: GlyphNote}) s.Create(ctx, &Entity{Body: "note", Glyph: GlyphNote})
} }
p := DefaultListParams() p := DefaultListParams()
p.Limit = 3 p.Limit = 3
p.Offset = 0 p.Offset = 0
page1, err := s.List(p) page1, err := s.List(ctx, p)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -237,7 +249,7 @@ func TestList_Pagination(t *testing.T) {
} }
p.Offset = 3 p.Offset = 3
page2, err := s.List(p) page2, err := s.List(ctx, p)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -251,16 +263,17 @@ func TestList_Pagination(t *testing.T) {
func TestUpdate_Body(t *testing.T) { func TestUpdate_Body(t *testing.T) {
s := testStore(t) s := testStore(t)
ctx := context.Background()
e := &Entity{Body: "old", Glyph: GlyphNote} e := &Entity{Body: "old", Glyph: GlyphNote}
s.Create(e) s.Create(ctx, e)
time.Sleep(1100 * time.Millisecond) time.Sleep(1100 * time.Millisecond)
newBody := "new" newBody := "new"
if err := s.Update(e.ID, &EntityUpdate{Body: &newBody}); err != nil { if err := s.Update(ctx, e.ID, &EntityUpdate{Body: &newBody}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
got, _ := s.Get(e.ID) got, _ := s.Get(ctx, e.ID)
if got.Body != "new" { if got.Body != "new" {
t.Errorf("body not updated: %q", got.Body) t.Errorf("body not updated: %q", got.Body)
} }
@@ -271,15 +284,16 @@ func TestUpdate_Body(t *testing.T) {
func TestUpdate_Tags(t *testing.T) { func TestUpdate_Tags(t *testing.T) {
s := testStore(t) s := testStore(t)
ctx := context.Background()
e := &Entity{Body: "test", Glyph: GlyphNote, Tags: []string{"old"}} e := &Entity{Body: "test", Glyph: GlyphNote, Tags: []string{"old"}}
s.Create(e) s.Create(ctx, e)
newTags := []string{"new1", "new2"} newTags := []string{"new1", "new2"}
if err := s.Update(e.ID, &EntityUpdate{Tags: &newTags}); err != nil { if err := s.Update(ctx, e.ID, &EntityUpdate{Tags: &newTags}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
got, _ := s.Get(e.ID) got, _ := s.Get(ctx, e.ID)
if len(got.Tags) != 2 { if len(got.Tags) != 2 {
t.Fatalf("expected 2 tags, got %d: %v", len(got.Tags), got.Tags) t.Fatalf("expected 2 tags, got %d: %v", len(got.Tags), got.Tags)
} }
@@ -287,14 +301,15 @@ func TestUpdate_Tags(t *testing.T) {
func TestPromote_Success(t *testing.T) { func TestPromote_Success(t *testing.T) {
s := testStore(t) s := testStore(t)
ctx := context.Background()
e := &Entity{Body: "trick", Glyph: GlyphNote} e := &Entity{Body: "trick", Glyph: GlyphNote}
s.Create(e) s.Create(ctx, e)
if err := s.Promote(e.ID, CardSnippet, nil); err != nil { if err := s.Promote(ctx, e.ID, CardSnippet, nil); err != nil {
t.Fatal(err) t.Fatal(err)
} }
got, _ := s.Get(e.ID) got, _ := s.Get(ctx, e.ID)
if got.CardType == nil || *got.CardType != CardSnippet { if got.CardType == nil || *got.CardType != CardSnippet {
t.Errorf("expected snippet, got %v", got.CardType) t.Errorf("expected snippet, got %v", got.CardType)
} }
@@ -302,26 +317,28 @@ func TestPromote_Success(t *testing.T) {
func TestPromote_AlreadyPromoted(t *testing.T) { func TestPromote_AlreadyPromoted(t *testing.T) {
s := testStore(t) s := testStore(t)
ctx := context.Background()
ct := CardSnippet ct := CardSnippet
e := &Entity{Body: "trick", Glyph: GlyphNote, CardType: &ct} e := &Entity{Body: "trick", Glyph: GlyphNote, CardType: &ct}
s.Create(e) s.Create(ctx, e)
if err := s.Promote(e.ID, CardTemplate, nil); err != ErrAlreadyPromoted { if err := s.Promote(ctx, e.ID, CardTemplate, nil); err != ErrAlreadyPromoted {
t.Errorf("expected ErrAlreadyPromoted, got %v", err) t.Errorf("expected ErrAlreadyPromoted, got %v", err)
} }
} }
func TestDemote_Success(t *testing.T) { func TestDemote_Success(t *testing.T) {
s := testStore(t) s := testStore(t)
ctx := context.Background()
e := &Entity{Body: "trick", Glyph: GlyphNote} e := &Entity{Body: "trick", Glyph: GlyphNote}
s.Create(e) s.Create(ctx, e)
s.Promote(e.ID, CardSnippet, nil) s.Promote(ctx, e.ID, CardSnippet, nil)
if err := s.Demote(e.ID); err != nil { if err := s.Demote(ctx, e.ID); err != nil {
t.Fatal(err) t.Fatal(err)
} }
got, _ := s.Get(e.ID) got, _ := s.Get(ctx, e.ID)
if got.CardType != nil { if got.CardType != nil {
t.Errorf("expected nil card_type, got %v", got.CardType) t.Errorf("expected nil card_type, got %v", got.CardType)
} }
@@ -332,20 +349,22 @@ func TestDemote_Success(t *testing.T) {
func TestDemote_AlreadyFluid(t *testing.T) { func TestDemote_AlreadyFluid(t *testing.T) {
s := testStore(t) s := testStore(t)
ctx := context.Background()
e := &Entity{Body: "trick", Glyph: GlyphNote} e := &Entity{Body: "trick", Glyph: GlyphNote}
s.Create(e) s.Create(ctx, e)
if err := s.Demote(e.ID); err != ErrAlreadyFluid { if err := s.Demote(ctx, e.ID); err != ErrAlreadyFluid {
t.Errorf("expected ErrAlreadyFluid, got %v", err) t.Errorf("expected ErrAlreadyFluid, got %v", err)
} }
} }
func TestSoftDelete_First(t *testing.T) { func TestSoftDelete_First(t *testing.T) {
s := testStore(t) s := testStore(t)
ctx := context.Background()
e := &Entity{Body: "doomed", Glyph: GlyphNote} e := &Entity{Body: "doomed", Glyph: GlyphNote}
s.Create(e) s.Create(ctx, e)
result, err := s.SoftDelete(e.ID) result, err := s.SoftDelete(ctx, e.ID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -353,7 +372,7 @@ func TestSoftDelete_First(t *testing.T) {
t.Errorf("expected DeletedSoft, got %d", result) t.Errorf("expected DeletedSoft, got %d", result)
} }
got, _ := s.Get(e.ID) got, _ := s.Get(ctx, e.ID)
if got.DeletedAt == nil { if got.DeletedAt == nil {
t.Error("expected deleted_at to be set") t.Error("expected deleted_at to be set")
} }
@@ -361,11 +380,12 @@ func TestSoftDelete_First(t *testing.T) {
func TestSoftDelete_Second(t *testing.T) { func TestSoftDelete_Second(t *testing.T) {
s := testStore(t) s := testStore(t)
ctx := context.Background()
e := &Entity{Body: "doomed", Glyph: GlyphNote} e := &Entity{Body: "doomed", Glyph: GlyphNote}
s.Create(e) s.Create(ctx, e)
s.SoftDelete(e.ID) s.SoftDelete(ctx, e.ID)
result, err := s.SoftDelete(e.ID) result, err := s.SoftDelete(ctx, e.ID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -373,7 +393,7 @@ func TestSoftDelete_Second(t *testing.T) {
t.Errorf("expected DeletedHard, got %d", result) t.Errorf("expected DeletedHard, got %d", result)
} }
_, err = s.Get(e.ID) _, err = s.Get(ctx, e.ID)
if err != ErrNotFound { if err != ErrNotFound {
t.Errorf("expected ErrNotFound after hard delete, got %v", err) t.Errorf("expected ErrNotFound after hard delete, got %v", err)
} }
@@ -381,7 +401,7 @@ func TestSoftDelete_Second(t *testing.T) {
func TestSoftDelete_NotFound(t *testing.T) { func TestSoftDelete_NotFound(t *testing.T) {
s := testStore(t) s := testStore(t)
_, err := s.SoftDelete("01NONEXISTENT0000000000000") _, err := s.SoftDelete(context.Background(), "01NONEXISTENT0000000000000")
if err != ErrNotFound { if err != ErrNotFound {
t.Errorf("expected ErrNotFound, got %v", err) t.Errorf("expected ErrNotFound, got %v", err)
} }
@@ -389,15 +409,16 @@ func TestSoftDelete_NotFound(t *testing.T) {
func TestIncrementUse(t *testing.T) { func TestIncrementUse(t *testing.T) {
s := testStore(t) s := testStore(t)
ctx := context.Background()
ct := CardSnippet ct := CardSnippet
e := &Entity{Body: "trick", Glyph: GlyphNote, CardType: &ct} e := &Entity{Body: "trick", Glyph: GlyphNote, CardType: &ct}
s.Create(e) s.Create(ctx, e)
if err := s.IncrementUse(e.ID); err != nil { if err := s.IncrementUse(ctx, e.ID); err != nil {
t.Fatal(err) t.Fatal(err)
} }
got, _ := s.Get(e.ID) got, _ := s.Get(ctx, e.ID)
if got.UseCount != 1 { if got.UseCount != 1 {
t.Errorf("expected use_count=1, got %d", got.UseCount) t.Errorf("expected use_count=1, got %d", got.UseCount)
} }
@@ -408,10 +429,11 @@ func TestIncrementUse(t *testing.T) {
func TestResolve_FullID(t *testing.T) { func TestResolve_FullID(t *testing.T) {
s := testStore(t) s := testStore(t)
ctx := context.Background()
e := &Entity{Body: "test", Glyph: GlyphNote} e := &Entity{Body: "test", Glyph: GlyphNote}
s.Create(e) s.Create(ctx, e)
got, err := s.Resolve(e.ID) got, err := s.Resolve(ctx, e.ID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -422,10 +444,11 @@ func TestResolve_FullID(t *testing.T) {
func TestResolve_Prefix(t *testing.T) { func TestResolve_Prefix(t *testing.T) {
s := testStore(t) s := testStore(t)
ctx := context.Background()
e := &Entity{Body: "test", Glyph: GlyphNote} e := &Entity{Body: "test", Glyph: GlyphNote}
s.Create(e) s.Create(ctx, e)
got, err := s.Resolve(e.ID[:6]) got, err := s.Resolve(ctx, e.ID[:6])
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -436,7 +459,7 @@ func TestResolve_Prefix(t *testing.T) {
func TestResolve_NotFound(t *testing.T) { func TestResolve_NotFound(t *testing.T) {
s := testStore(t) s := testStore(t)
_, err := s.Resolve("ZZZZZZZZZ") _, err := s.Resolve(context.Background(), "ZZZZZZZZZ")
if err != ErrNotFound { if err != ErrNotFound {
t.Errorf("expected ErrNotFound, got %v", err) t.Errorf("expected ErrNotFound, got %v", err)
} }
@@ -444,24 +467,25 @@ func TestResolve_NotFound(t *testing.T) {
func TestAbsorb_SourceIsCard(t *testing.T) { func TestAbsorb_SourceIsCard(t *testing.T) {
s := testStore(t) s := testStore(t)
ctx := context.Background()
target := &Entity{Body: "target", Glyph: GlyphNote, Tags: []string{"a"}} target := &Entity{Body: "target", Glyph: GlyphNote, Tags: []string{"a"}}
s.Create(target) s.Create(ctx, target)
source := &Entity{Body: "source", Glyph: GlyphNote} source := &Entity{Body: "source", Glyph: GlyphNote}
s.Create(source) s.Create(ctx, source)
s.Promote(source.ID, CardSnippet, nil) s.Promote(ctx, source.ID, CardSnippet, nil)
s.IncrementUse(source.ID) s.IncrementUse(ctx, source.ID)
if err := s.Absorb(target.ID, source.ID); err != nil { if err := s.Absorb(ctx, target.ID, source.ID); err != nil {
t.Fatal(err) t.Fatal(err)
} }
got, _ := s.Get(target.ID) got, _ := s.Get(ctx, target.ID)
if got.Body != "target\nsource" { if got.Body != "target\nsource" {
t.Errorf("merged body: %q", got.Body) t.Errorf("merged body: %q", got.Body)
} }
src, _ := s.Get(source.ID) src, _ := s.Get(ctx, source.ID)
if src.CardType != nil { if src.CardType != nil {
t.Error("source card_type should be cleared after absorb") t.Error("source card_type should be cleared after absorb")
} }
@@ -475,6 +499,7 @@ func TestAbsorb_SourceIsCard(t *testing.T) {
func TestCreate_WithTitleAndDescription(t *testing.T) { func TestCreate_WithTitleAndDescription(t *testing.T) {
s := testStore(t) s := testStore(t)
ctx := context.Background()
e := &Entity{ e := &Entity{
Body: "body text", Body: "body text",
Title: ptr("nginx trick"), Title: ptr("nginx trick"),
@@ -482,11 +507,11 @@ func TestCreate_WithTitleAndDescription(t *testing.T) {
Glyph: GlyphNote, Glyph: GlyphNote,
Tags: []string{"ops"}, Tags: []string{"ops"},
} }
if err := s.Create(e); err != nil { if err := s.Create(ctx, e); err != nil {
t.Fatal(err) t.Fatal(err)
} }
got, err := s.Get(e.ID) got, err := s.Get(ctx, e.ID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -503,12 +528,13 @@ func TestCreate_WithTitleAndDescription(t *testing.T) {
func TestCreate_WithoutTitle(t *testing.T) { func TestCreate_WithoutTitle(t *testing.T) {
s := testStore(t) s := testStore(t)
ctx := context.Background()
e := &Entity{Body: "just body", Glyph: GlyphNote} e := &Entity{Body: "just body", Glyph: GlyphNote}
if err := s.Create(e); err != nil { if err := s.Create(ctx, e); err != nil {
t.Fatal(err) t.Fatal(err)
} }
got, _ := s.Get(e.ID) got, _ := s.Get(ctx, e.ID)
if got.Title != nil { if got.Title != nil {
t.Errorf("expected nil title, got %v", got.Title) t.Errorf("expected nil title, got %v", got.Title)
} }
@@ -519,15 +545,16 @@ func TestCreate_WithoutTitle(t *testing.T) {
func TestUpdate_Title(t *testing.T) { func TestUpdate_Title(t *testing.T) {
s := testStore(t) s := testStore(t)
ctx := context.Background()
e := &Entity{Body: "body", Glyph: GlyphNote} e := &Entity{Body: "body", Glyph: GlyphNote}
s.Create(e) s.Create(ctx, e)
newTitle := "new title" newTitle := "new title"
if err := s.Update(e.ID, &EntityUpdate{Title: &newTitle}); err != nil { if err := s.Update(ctx, e.ID, &EntityUpdate{Title: &newTitle}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
got, _ := s.Get(e.ID) got, _ := s.Get(ctx, e.ID)
if got.Title == nil || *got.Title != "new title" { if got.Title == nil || *got.Title != "new title" {
t.Errorf("title: got %v", got.Title) t.Errorf("title: got %v", got.Title)
} }
@@ -535,15 +562,16 @@ func TestUpdate_Title(t *testing.T) {
func TestUpdate_Description(t *testing.T) { func TestUpdate_Description(t *testing.T) {
s := testStore(t) s := testStore(t)
ctx := context.Background()
e := &Entity{Body: "body", Glyph: GlyphNote} e := &Entity{Body: "body", Glyph: GlyphNote}
s.Create(e) s.Create(ctx, e)
newDesc := "new desc" newDesc := "new desc"
if err := s.Update(e.ID, &EntityUpdate{Description: &newDesc}); err != nil { if err := s.Update(ctx, e.ID, &EntityUpdate{Description: &newDesc}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
got, _ := s.Get(e.ID) got, _ := s.Get(ctx, e.ID)
if got.Description == nil || *got.Description != "new desc" { if got.Description == nil || *got.Description != "new desc" {
t.Errorf("description: got %v", got.Description) t.Errorf("description: got %v", got.Description)
} }
@@ -551,16 +579,17 @@ func TestUpdate_Description(t *testing.T) {
func TestAbsorb_PreservesTargetTitle(t *testing.T) { func TestAbsorb_PreservesTargetTitle(t *testing.T) {
s := testStore(t) s := testStore(t)
ctx := context.Background()
target := &Entity{Body: "target body", Title: ptr("target title"), Glyph: GlyphNote} target := &Entity{Body: "target body", Title: ptr("target title"), Glyph: GlyphNote}
source := &Entity{Body: "source body", Title: ptr("source title"), Glyph: GlyphNote} source := &Entity{Body: "source body", Title: ptr("source title"), Glyph: GlyphNote}
s.Create(target) s.Create(ctx, target)
s.Create(source) s.Create(ctx, source)
if err := s.Absorb(target.ID, source.ID); err != nil { if err := s.Absorb(ctx, target.ID, source.ID); err != nil {
t.Fatal(err) t.Fatal(err)
} }
got, _ := s.Get(target.ID) got, _ := s.Get(ctx, target.ID)
if got.Title == nil || *got.Title != "target title" { if got.Title == nil || *got.Title != "target title" {
t.Errorf("target title should be preserved, got %v", got.Title) t.Errorf("target title should be preserved, got %v", got.Title)
} }
+4 -2
View File
@@ -1,16 +1,18 @@
package db package db
import "context"
type TagCount struct { type TagCount struct {
Tag string Tag string
Count int Count int
} }
func (s *Store) ListTags(cardsOnly bool) ([]TagCount, error) { func (s *Store) ListTags(ctx context.Context, cardsOnly bool) ([]TagCount, error) {
where := "WHERE e.deleted_at IS NULL" where := "WHERE e.deleted_at IS NULL"
if cardsOnly { if cardsOnly {
where += " AND e.card_type IS NOT NULL" where += " AND e.card_type IS NOT NULL"
} }
rows, err := s.db.Query(` rows, err := s.db.QueryContext(ctx, `
SELECT t.tag, COUNT(*) as cnt SELECT t.tag, COUNT(*) as cnt
FROM entity_tags t FROM entity_tags t
JOIN entities e ON t.entity_id = e.id JOIN entities e ON t.entity_id = e.id
+20 -14
View File
@@ -1,10 +1,13 @@
package db package db
import "testing" import (
"context"
"testing"
)
func TestListTags_Empty(t *testing.T) { func TestListTags_Empty(t *testing.T) {
s := testStore(t) s := testStore(t)
tags, err := s.ListTags(false) tags, err := s.ListTags(context.Background(), false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -15,11 +18,12 @@ func TestListTags_Empty(t *testing.T) {
func TestListTags_Counts(t *testing.T) { func TestListTags_Counts(t *testing.T) {
s := testStore(t) s := testStore(t)
s.Create(&Entity{Body: "a", Glyph: GlyphNote, Tags: []string{"ops", "nginx"}}) ctx := context.Background()
s.Create(&Entity{Body: "b", Glyph: GlyphNote, Tags: []string{"ops"}}) s.Create(ctx, &Entity{Body: "a", Glyph: GlyphNote, Tags: []string{"ops", "nginx"}})
s.Create(&Entity{Body: "c", Glyph: GlyphNote, Tags: []string{"home"}}) s.Create(ctx, &Entity{Body: "b", Glyph: GlyphNote, Tags: []string{"ops"}})
s.Create(ctx, &Entity{Body: "c", Glyph: GlyphNote, Tags: []string{"home"}})
tags, err := s.ListTags(false) tags, err := s.ListTags(ctx, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -44,13 +48,14 @@ func TestListTags_Counts(t *testing.T) {
func TestListTags_ExcludesDeleted(t *testing.T) { func TestListTags_ExcludesDeleted(t *testing.T) {
s := testStore(t) s := testStore(t)
ctx := context.Background()
e := &Entity{Body: "doomed", Glyph: GlyphNote, Tags: []string{"gone"}} e := &Entity{Body: "doomed", Glyph: GlyphNote, Tags: []string{"gone"}}
s.Create(e) s.Create(ctx, e)
s.SoftDelete(e.ID) s.SoftDelete(ctx, e.ID)
s.Create(&Entity{Body: "alive", Glyph: GlyphNote, Tags: []string{"here"}}) s.Create(ctx, &Entity{Body: "alive", Glyph: GlyphNote, Tags: []string{"here"}})
tags, err := s.ListTags(false) tags, err := s.ListTags(ctx, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -64,12 +69,13 @@ func TestListTags_ExcludesDeleted(t *testing.T) {
func TestListTags_CardsOnly(t *testing.T) { func TestListTags_CardsOnly(t *testing.T) {
s := testStore(t) s := testStore(t)
s.Create(&Entity{Body: "fluid", Glyph: GlyphNote, Tags: []string{"ops", "shared"}}) ctx := context.Background()
s.Create(ctx, &Entity{Body: "fluid", Glyph: GlyphNote, Tags: []string{"ops", "shared"}})
ct := CardSnippet ct := CardSnippet
s.Create(&Entity{Body: "card", Glyph: GlyphNote, Tags: []string{"ops", "code"}, CardType: &ct}) s.Create(ctx, &Entity{Body: "card", Glyph: GlyphNote, Tags: []string{"ops", "code"}, CardType: &ct})
all, err := s.ListTags(false) all, err := s.ListTags(ctx, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -77,7 +83,7 @@ func TestListTags_CardsOnly(t *testing.T) {
t.Fatalf("all tags: expected 3, got %d", len(all)) t.Fatalf("all tags: expected 3, got %d", len(all))
} }
cards, err := s.ListTags(true) cards, err := s.ListTags(ctx, true)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
+21 -20
View File
@@ -1,6 +1,7 @@
package tui package tui
import ( import (
"context"
"os" "os"
"os/exec" "os/exec"
"strings" "strings"
@@ -82,7 +83,7 @@ type errMsg struct {
func loadEntities(store *db.Store, params db.ListParams) tea.Cmd { func loadEntities(store *db.Store, params db.ListParams) tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
entities, err := store.List(params) entities, err := store.List(context.Background(), params)
if err != nil { if err != nil {
return errMsg{err} return errMsg{err}
} }
@@ -92,7 +93,7 @@ func loadEntities(store *db.Store, params db.ListParams) tea.Cmd {
func createEntity(store *db.Store, e *db.Entity) tea.Cmd { func createEntity(store *db.Store, e *db.Entity) tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
if err := store.Create(e); err != nil { if err := store.Create(context.Background(), e); err != nil {
return errMsg{err} return errMsg{err}
} }
return entityCreatedMsg{e} return entityCreatedMsg{e}
@@ -101,7 +102,7 @@ func createEntity(store *db.Store, e *db.Entity) tea.Cmd {
func deleteEntity(store *db.Store, id string) tea.Cmd { func deleteEntity(store *db.Store, id string) tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
if _, err := store.SoftDelete(id); err != nil { if _, err := store.SoftDelete(context.Background(), id); err != nil {
return errMsg{err} return errMsg{err}
} }
return entityDeletedMsg{id} return entityDeletedMsg{id}
@@ -118,10 +119,10 @@ func toggleTodo(store *db.Store, e *db.Entity) tea.Cmd {
update = db.EntityUpdate{ClearCompleted: true} update = db.EntityUpdate{ClearCompleted: true}
} }
if err := store.Update(e.ID, &update); err != nil { if err := store.Update(context.Background(), e.ID, &update); err != nil {
return errMsg{err} return errMsg{err}
} }
updated, err := store.Get(e.ID) updated, err := store.Get(context.Background(), e.ID)
if err != nil { if err != nil {
return errMsg{err} return errMsg{err}
} }
@@ -137,10 +138,10 @@ func pinEntity(store *db.Store, e *db.Entity) tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
newPinned := !e.Pinned newPinned := !e.Pinned
update := db.EntityUpdate{Pinned: &newPinned} update := db.EntityUpdate{Pinned: &newPinned}
if err := store.Update(e.ID, &update); err != nil { if err := store.Update(context.Background(), e.ID, &update); err != nil {
return errMsg{err} return errMsg{err}
} }
updated, err := store.Get(e.ID) updated, err := store.Get(context.Background(), e.ID)
if err != nil { if err != nil {
return errMsg{err} return errMsg{err}
} }
@@ -155,7 +156,7 @@ func pinEntity(store *db.Store, e *db.Entity) tea.Cmd {
func promoteEntity(store *db.Store, id string, ct db.CardType, body string) tea.Cmd { func promoteEntity(store *db.Store, id string, ct db.CardType, body string) tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
cd := carddata.GenerateCardData(ct, body) cd := carddata.GenerateCardData(ct, body)
if err := store.Promote(id, ct, cd); err != nil { if err := store.Promote(context.Background(), id, ct, cd); err != nil {
return errMsg{err} return errMsg{err}
} }
return entityPromotedMsg{id, ct} return entityPromotedMsg{id, ct}
@@ -164,7 +165,7 @@ func promoteEntity(store *db.Store, id string, ct db.CardType, body string) tea.
func demoteEntity(store *db.Store, id string) tea.Cmd { func demoteEntity(store *db.Store, id string) tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
if err := store.Demote(id); err != nil { if err := store.Demote(context.Background(), id); err != nil {
return errMsg{err} return errMsg{err}
} }
return entityDemotedMsg{id} return entityDemotedMsg{id}
@@ -176,7 +177,7 @@ func copyToClipboard(store *db.Store, e *db.Entity) tea.Cmd {
if err := clipboard.WriteAll(e.Body); err != nil { if err := clipboard.WriteAll(e.Body); err != nil {
return errMsg{err} return errMsg{err}
} }
if err := store.IncrementUse(e.ID); err != nil { if err := store.IncrementUse(context.Background(), e.ID); err != nil {
return errMsg{err} return errMsg{err}
} }
return entityCopiedMsg{} return entityCopiedMsg{}
@@ -185,7 +186,7 @@ func copyToClipboard(store *db.Store, e *db.Entity) tea.Cmd {
func loadTags(store *db.Store) tea.Cmd { func loadTags(store *db.Store) tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
tags, err := store.ListTags(false) tags, err := store.ListTags(context.Background(), false)
if err != nil { if err != nil {
return errMsg{err} return errMsg{err}
} }
@@ -195,7 +196,7 @@ func loadTags(store *db.Store) tea.Cmd {
func loadRailTags(store *db.Store) tea.Cmd { func loadRailTags(store *db.Store) tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
tags, err := store.ListTags(false) tags, err := store.ListTags(context.Background(), false)
if err != nil { if err != nil {
return errMsg{err} return errMsg{err}
} }
@@ -243,7 +244,7 @@ func editInEditor(store *db.Store, e *db.Entity) tea.Cmd {
} }
update := db.EntityUpdate{Body: &newBody} update := db.EntityUpdate{Body: &newBody}
if updateErr := store.Update(e.ID, &update); updateErr != nil { if updateErr := store.Update(context.Background(), e.ID, &update); updateErr != nil {
return editorFinishedMsg{updateErr} return editorFinishedMsg{updateErr}
} }
@@ -253,7 +254,7 @@ func editInEditor(store *db.Store, e *db.Entity) tea.Cmd {
func loadAbsorbSources(store *db.Store, targetID string) tea.Cmd { func loadAbsorbSources(store *db.Store, targetID string) tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
entities, err := store.List(db.DefaultListParams()) entities, err := store.List(context.Background(), db.DefaultListParams())
if err != nil { if err != nil {
return errMsg{err} return errMsg{err}
} }
@@ -263,7 +264,7 @@ func loadAbsorbSources(store *db.Store, targetID string) tea.Cmd {
func absorbEntity(store *db.Store, targetID, sourceID string) tea.Cmd { func absorbEntity(store *db.Store, targetID, sourceID string) tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
if err := store.Absorb(targetID, sourceID); err != nil { if err := store.Absorb(context.Background(), targetID, sourceID); err != nil {
return errMsg{err} return errMsg{err}
} }
return entityAbsorbedMsg{targetID} return entityAbsorbedMsg{targetID}
@@ -273,7 +274,7 @@ func absorbEntity(store *db.Store, targetID, sourceID string) tea.Cmd {
func persistSteps(store *db.Store, entityID string, stepsJSON string) tea.Cmd { func persistSteps(store *db.Store, entityID string, stepsJSON string) tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
update := db.EntityUpdate{CardData: &stepsJSON} update := db.EntityUpdate{CardData: &stepsJSON}
if err := store.Update(entityID, &update); err != nil { if err := store.Update(context.Background(), entityID, &update); err != nil {
return errMsg{err} return errMsg{err}
} }
return stepsPersistedMsg{} return stepsPersistedMsg{}
@@ -285,7 +286,7 @@ func copyResolved(store *db.Store, entityID string, resolved string) tea.Cmd {
if err := clipboard.WriteAll(resolved); err != nil { if err := clipboard.WriteAll(resolved); err != nil {
return errMsg{err} return errMsg{err}
} }
if err := store.IncrementUse(entityID); err != nil { if err := store.IncrementUse(context.Background(), entityID); err != nil {
return errMsg{err} return errMsg{err}
} }
return templateCopiedMsg{} return templateCopiedMsg{}
@@ -300,7 +301,7 @@ func clearStatusAfter(d time.Duration, seq int) tea.Cmd {
func loadStaleEntities(store *db.Store) tea.Cmd { func loadStaleEntities(store *db.Store) tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
entities, err := store.List(staleParams()) entities, err := store.List(context.Background(), staleParams())
if err != nil { if err != nil {
return errMsg{err} return errMsg{err}
} }
@@ -310,7 +311,7 @@ func loadStaleEntities(store *db.Store) tea.Cmd {
func stumbleDismiss(store *db.Store, id string) tea.Cmd { func stumbleDismiss(store *db.Store, id string) tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
if _, err := store.SoftDelete(id); err != nil { if _, err := store.SoftDelete(context.Background(), id); err != nil {
return errMsg{err} return errMsg{err}
} }
return stumbleActionMsg{"dismissed"} return stumbleActionMsg{"dismissed"}
@@ -321,7 +322,7 @@ func stumblePin(store *db.Store, id string) tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
pinned := true pinned := true
update := db.EntityUpdate{Pinned: &pinned} update := db.EntityUpdate{Pinned: &pinned}
if err := store.Update(id, &update); err != nil { if err := store.Update(context.Background(), id, &update); err != nil {
return errMsg{err} return errMsg{err}
} }
return stumbleActionMsg{"pinned"} return stumbleActionMsg{"pinned"}
+2 -1
View File
@@ -1,6 +1,7 @@
package tui package tui
import ( import (
"context"
"fmt" "fmt"
"strings" "strings"
"time" "time"
@@ -1095,7 +1096,7 @@ func (m model) reloadDetail(id string) tea.Cmd {
return tea.Batch( return tea.Batch(
loadEntities(m.store, m.listParams()), loadEntities(m.store, m.listParams()),
func() tea.Msg { func() tea.Msg {
e, err := m.store.Get(id) e, err := m.store.Get(context.Background(), id)
if err != nil { if err != nil {
return errMsg{err} return errMsg{err}
} }