refactor(db): thread context.Context through all Store methods

Enables request-scoped cancellation, timeouts, and graceful shutdown
for all database operations across API handlers, CLI commands, and TUI.
This commit is contained in:
2026-05-20 20:51:51 -04:00
parent 50b80f4407
commit d715b053e7
18 changed files with 267 additions and 228 deletions
+47 -48
View File
@@ -1,6 +1,7 @@
package db
import (
"context"
"database/sql"
"encoding/json"
"fmt"
@@ -104,7 +105,7 @@ type EntityUpdate struct {
Tags *[]string
}
func (s *Store) Create(e *Entity) error {
func (s *Store) Create(ctx context.Context, e *Entity) error {
if e.CardData != nil && !json.Valid([]byte(*e.CardData)) {
return ErrInvalidCardData
}
@@ -116,13 +117,13 @@ func (s *Store) Create(e *Entity) error {
e.Tags = []string{}
}
tx, err := s.db.Begin()
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
_, err = tx.Exec(`
_, err = tx.ExecContext(ctx, `
INSERT INTO entities (id, created_at, modified_at, body, title, description,
glyph, time_anchor, completed_at, pinned, deleted_at,
card_type, card_data, use_count, last_used_at)
@@ -147,18 +148,18 @@ func (s *Store) Create(e *Entity) error {
return err
}
if err := insertTags(tx, e.ID, e.Tags); err != nil {
if err := insertTags(ctx, tx, e.ID, e.Tags); err != nil {
return err
}
return tx.Commit()
}
func (s *Store) Get(id string) (*Entity, error) {
func (s *Store) Get(ctx context.Context, id string) (*Entity, error) {
e := &Entity{}
row := newEntityRow()
err := s.db.QueryRow(`
err := s.db.QueryRowContext(ctx, `
SELECT id, created_at, modified_at, body, title, description,
glyph, time_anchor, completed_at, pinned, deleted_at,
card_type, card_data, use_count, last_used_at
@@ -174,7 +175,7 @@ func (s *Store) Get(id string) (*Entity, error) {
return nil, fmt.Errorf("scan entity %s: %w", id, err)
}
tags, err := s.loadTags(id)
tags, err := s.loadTags(ctx, id)
if err != nil {
return nil, err
}
@@ -229,15 +230,15 @@ func listWhere(params ListParams) (string, []any) {
return clause, args
}
func (s *Store) Count(params ListParams) (int, error) {
func (s *Store) Count(ctx context.Context, params ListParams) (int, error) {
whereClause, args := listWhere(params)
query := fmt.Sprintf("SELECT COUNT(*) FROM entities e %s", whereClause)
var count int
err := s.db.QueryRow(query, args...).Scan(&count)
err := s.db.QueryRowContext(ctx, query, args...).Scan(&count)
return count, err
}
func (s *Store) List(params ListParams) ([]*Entity, error) {
func (s *Store) List(ctx context.Context, params ListParams) ([]*Entity, error) {
whereClause, args := listWhere(params)
orderCol := "e.created_at"
@@ -275,7 +276,7 @@ func (s *Store) List(params ListParams) ([]*Entity, error) {
args = append(args, limit, params.Offset)
rows, err := s.db.Query(query, args...)
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
@@ -297,20 +298,20 @@ func (s *Store) List(params ListParams) ([]*Entity, error) {
return nil, err
}
if err := s.batchLoadTags(entities); err != nil {
if err := s.batchLoadTags(ctx, entities); err != nil {
return nil, err
}
return entities, nil
}
func (s *Store) Update(id string, u *EntityUpdate) error {
existing, err := s.Get(id)
func (s *Store) Update(ctx context.Context, id string, u *EntityUpdate) error {
existing, err := s.Get(ctx, id)
if err != nil {
return err
}
tx, err := s.db.Begin()
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return err
}
@@ -369,15 +370,15 @@ func (s *Store) Update(id string, u *EntityUpdate) error {
args = append(args, existing.ID)
query := fmt.Sprintf("UPDATE entities SET %s WHERE id = ?", strings.Join(sets, ", "))
if _, err := tx.Exec(query, args...); err != nil {
if _, err := tx.ExecContext(ctx, query, args...); err != nil {
return err
}
if u.Tags != nil {
if _, err := tx.Exec("DELETE FROM entity_tags WHERE entity_id = ?", existing.ID); err != nil {
if _, err := tx.ExecContext(ctx, "DELETE FROM entity_tags WHERE entity_id = ?", existing.ID); err != nil {
return err
}
if err := insertTags(tx, existing.ID, *u.Tags); err != nil {
if err := insertTags(ctx, tx, existing.ID, *u.Tags); err != nil {
return err
}
}
@@ -385,8 +386,8 @@ func (s *Store) Update(id string, u *EntityUpdate) error {
return tx.Commit()
}
func (s *Store) Promote(id string, cardType CardType, cardData *string) error {
e, err := s.Get(id)
func (s *Store) Promote(ctx context.Context, id string, cardType CardType, cardData *string) error {
e, err := s.Get(ctx, id)
if err != nil {
return err
}
@@ -402,15 +403,15 @@ func (s *Store) Promote(id string, cardType CardType, cardData *string) error {
dataVal = *cardData
}
_, err = s.db.Exec(`
_, err = s.db.ExecContext(ctx, `
UPDATE entities SET card_type = ?, card_data = ?, modified_at = ?
WHERE id = ?`,
string(cardType), dataVal, time.Now().UTC().Format(time.RFC3339), id)
return err
}
func (s *Store) Demote(id string) error {
e, err := s.Get(id)
func (s *Store) Demote(ctx context.Context, id string) error {
e, err := s.Get(ctx, id)
if err != nil {
return err
}
@@ -418,7 +419,7 @@ func (s *Store) Demote(id string) error {
return ErrAlreadyFluid
}
_, err = s.db.Exec(`
_, err = s.db.ExecContext(ctx, `
UPDATE entities SET card_type = NULL, card_data = NULL,
use_count = 0, last_used_at = NULL, modified_at = ?
WHERE id = ?`,
@@ -433,9 +434,9 @@ const (
DeletedHard
)
func (s *Store) SoftDelete(id string) (DeleteResult, error) {
func (s *Store) SoftDelete(ctx context.Context, id string) (DeleteResult, error) {
var deletedAt sql.NullString
err := s.db.QueryRow("SELECT deleted_at FROM entities WHERE id = ?", id).Scan(&deletedAt)
err := s.db.QueryRowContext(ctx, "SELECT deleted_at FROM entities WHERE id = ?", id).Scan(&deletedAt)
if err == sql.ErrNoRows {
return 0, ErrNotFound
}
@@ -444,21 +445,21 @@ func (s *Store) SoftDelete(id string) (DeleteResult, error) {
}
if deletedAt.Valid {
_, err = s.db.Exec("DELETE FROM entities WHERE id = ?", id)
_, err = s.db.ExecContext(ctx, "DELETE FROM entities WHERE id = ?", id)
return DeletedHard, err
}
_, err = s.db.Exec("UPDATE entities SET deleted_at = ? WHERE id = ?",
_, err = s.db.ExecContext(ctx, "UPDATE entities SET deleted_at = ? WHERE id = ?",
time.Now().UTC().Format(time.RFC3339), id)
return DeletedSoft, err
}
func (s *Store) Absorb(targetID, sourceID string) error {
target, err := s.Get(targetID)
func (s *Store) Absorb(ctx context.Context, targetID, sourceID string) error {
target, err := s.Get(ctx, targetID)
if err != nil {
return err
}
source, err := s.Get(sourceID)
source, err := s.Get(ctx, sourceID)
if err != nil {
return err
}
@@ -467,7 +468,7 @@ func (s *Store) Absorb(targetID, sourceID string) error {
return ErrTargetCrystallized
}
tx, err := s.db.Begin()
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return err
}
@@ -476,7 +477,7 @@ func (s *Store) Absorb(targetID, sourceID string) error {
now := time.Now().UTC().Format(time.RFC3339)
merged := target.Body + "\n" + source.Body
if _, err := tx.Exec("UPDATE entities SET body = ?, modified_at = ? WHERE id = ?",
if _, err := tx.ExecContext(ctx, "UPDATE entities SET body = ?, modified_at = ? WHERE id = ?",
merged, now, targetID); err != nil {
return err
}
@@ -487,7 +488,7 @@ func (s *Store) Absorb(targetID, sourceID string) error {
}
for _, t := range source.Tags {
if !seen[t] {
if _, err := tx.Exec("INSERT OR IGNORE INTO entity_tags (entity_id, tag) VALUES (?, ?)",
if _, err := tx.ExecContext(ctx, "INSERT OR IGNORE INTO entity_tags (entity_id, tag) VALUES (?, ?)",
targetID, t); err != nil {
return err
}
@@ -495,7 +496,7 @@ func (s *Store) Absorb(targetID, sourceID string) error {
}
if source.CardType != nil {
if _, err := tx.Exec(`UPDATE entities SET card_type = NULL, card_data = NULL,
if _, err := tx.ExecContext(ctx, `UPDATE entities SET card_type = NULL, card_data = NULL,
use_count = 0, last_used_at = NULL, modified_at = ? WHERE id = ?`,
now, sourceID); err != nil {
return err
@@ -503,7 +504,7 @@ func (s *Store) Absorb(targetID, sourceID string) error {
}
absorbNote := source.Body + "\n\n[absorbed into " + targetID + "]"
if _, err := tx.Exec("UPDATE entities SET body = ?, deleted_at = ?, modified_at = ? WHERE id = ?",
if _, err := tx.ExecContext(ctx, "UPDATE entities SET body = ?, deleted_at = ?, modified_at = ? WHERE id = ?",
absorbNote, now, now, sourceID); err != nil {
return err
}
@@ -511,8 +512,8 @@ func (s *Store) Absorb(targetID, sourceID string) error {
return tx.Commit()
}
func (s *Store) IncrementUse(id string) error {
res, err := s.db.Exec(`
func (s *Store) IncrementUse(ctx context.Context, id string) error {
res, err := s.db.ExecContext(ctx, `
UPDATE entities SET use_count = use_count + 1, last_used_at = ?
WHERE id = ?`,
time.Now().UTC().Format(time.RFC3339), id)
@@ -526,8 +527,8 @@ func (s *Store) IncrementUse(id string) error {
return nil
}
func (s *Store) Resolve(prefix string) (string, error) {
rows, err := s.db.Query("SELECT id FROM entities WHERE id LIKE ?", prefix+"%")
func (s *Store) Resolve(ctx context.Context, prefix string) (string, error) {
rows, err := s.db.QueryContext(ctx, "SELECT id FROM entities WHERE id LIKE ?", prefix+"%")
if err != nil {
return "", err
}
@@ -593,9 +594,7 @@ func (r *entityRow) apply(e *Entity) error {
return nil
}
// helpers
func (s *Store) batchLoadTags(entities []*Entity) error {
func (s *Store) batchLoadTags(ctx context.Context, entities []*Entity) error {
if len(entities) == 0 {
return nil
}
@@ -615,7 +614,7 @@ func (s *Store) batchLoadTags(entities []*Entity) error {
strings.Join(placeholders, ","),
)
rows, err := s.db.Query(query, args...)
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return err
}
@@ -633,8 +632,8 @@ func (s *Store) batchLoadTags(entities []*Entity) error {
return rows.Err()
}
func (s *Store) loadTags(entityID string) ([]string, error) {
rows, err := s.db.Query("SELECT tag FROM entity_tags WHERE entity_id = ? ORDER BY tag", entityID)
func (s *Store) loadTags(ctx context.Context, entityID string) ([]string, error) {
rows, err := s.db.QueryContext(ctx, "SELECT tag FROM entity_tags WHERE entity_id = ? ORDER BY tag", entityID)
if err != nil {
return nil, err
}
@@ -657,9 +656,9 @@ func (s *Store) loadTags(entityID string) ([]string, error) {
return tags, nil
}
func insertTags(tx *sql.Tx, entityID string, tags []string) error {
func insertTags(ctx context.Context, tx *sql.Tx, entityID string, tags []string) error {
for _, tag := range tags {
if _, err := tx.Exec("INSERT OR IGNORE INTO entity_tags (entity_id, tag) VALUES (?, ?)",
if _, err := tx.ExecContext(ctx, "INSERT OR IGNORE INTO entity_tags (entity_id, tag) VALUES (?, ?)",
entityID, tag); err != nil {
return err
}