refactor(db): thread context.Context through all Store methods
Enables request-scoped cancellation, timeouts, and graceful shutdown for all database operations across API handlers, CLI commands, and TUI.
This commit is contained in:
+47
-48
@@ -1,6 +1,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -104,7 +105,7 @@ type EntityUpdate struct {
|
||||
Tags *[]string
|
||||
}
|
||||
|
||||
func (s *Store) Create(e *Entity) error {
|
||||
func (s *Store) Create(ctx context.Context, e *Entity) error {
|
||||
if e.CardData != nil && !json.Valid([]byte(*e.CardData)) {
|
||||
return ErrInvalidCardData
|
||||
}
|
||||
@@ -116,13 +117,13 @@ func (s *Store) Create(e *Entity) error {
|
||||
e.Tags = []string{}
|
||||
}
|
||||
|
||||
tx, err := s.db.Begin()
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
_, err = tx.Exec(`
|
||||
_, err = tx.ExecContext(ctx, `
|
||||
INSERT INTO entities (id, created_at, modified_at, body, title, description,
|
||||
glyph, time_anchor, completed_at, pinned, deleted_at,
|
||||
card_type, card_data, use_count, last_used_at)
|
||||
@@ -147,18 +148,18 @@ func (s *Store) Create(e *Entity) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := insertTags(tx, e.ID, e.Tags); err != nil {
|
||||
if err := insertTags(ctx, tx, e.ID, e.Tags); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (s *Store) Get(id string) (*Entity, error) {
|
||||
func (s *Store) Get(ctx context.Context, id string) (*Entity, error) {
|
||||
e := &Entity{}
|
||||
row := newEntityRow()
|
||||
|
||||
err := s.db.QueryRow(`
|
||||
err := s.db.QueryRowContext(ctx, `
|
||||
SELECT id, created_at, modified_at, body, title, description,
|
||||
glyph, time_anchor, completed_at, pinned, deleted_at,
|
||||
card_type, card_data, use_count, last_used_at
|
||||
@@ -174,7 +175,7 @@ func (s *Store) Get(id string) (*Entity, error) {
|
||||
return nil, fmt.Errorf("scan entity %s: %w", id, err)
|
||||
}
|
||||
|
||||
tags, err := s.loadTags(id)
|
||||
tags, err := s.loadTags(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -229,15 +230,15 @@ func listWhere(params ListParams) (string, []any) {
|
||||
return clause, args
|
||||
}
|
||||
|
||||
func (s *Store) Count(params ListParams) (int, error) {
|
||||
func (s *Store) Count(ctx context.Context, params ListParams) (int, error) {
|
||||
whereClause, args := listWhere(params)
|
||||
query := fmt.Sprintf("SELECT COUNT(*) FROM entities e %s", whereClause)
|
||||
var count int
|
||||
err := s.db.QueryRow(query, args...).Scan(&count)
|
||||
err := s.db.QueryRowContext(ctx, query, args...).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (s *Store) List(params ListParams) ([]*Entity, error) {
|
||||
func (s *Store) List(ctx context.Context, params ListParams) ([]*Entity, error) {
|
||||
whereClause, args := listWhere(params)
|
||||
|
||||
orderCol := "e.created_at"
|
||||
@@ -275,7 +276,7 @@ func (s *Store) List(params ListParams) ([]*Entity, error) {
|
||||
|
||||
args = append(args, limit, params.Offset)
|
||||
|
||||
rows, err := s.db.Query(query, args...)
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -297,20 +298,20 @@ func (s *Store) List(params ListParams) ([]*Entity, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.batchLoadTags(entities); err != nil {
|
||||
if err := s.batchLoadTags(ctx, entities); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return entities, nil
|
||||
}
|
||||
|
||||
func (s *Store) Update(id string, u *EntityUpdate) error {
|
||||
existing, err := s.Get(id)
|
||||
func (s *Store) Update(ctx context.Context, id string, u *EntityUpdate) error {
|
||||
existing, err := s.Get(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tx, err := s.db.Begin()
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -369,15 +370,15 @@ func (s *Store) Update(id string, u *EntityUpdate) error {
|
||||
args = append(args, existing.ID)
|
||||
query := fmt.Sprintf("UPDATE entities SET %s WHERE id = ?", strings.Join(sets, ", "))
|
||||
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
if _, err := tx.ExecContext(ctx, query, args...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if u.Tags != nil {
|
||||
if _, err := tx.Exec("DELETE FROM entity_tags WHERE entity_id = ?", existing.ID); err != nil {
|
||||
if _, err := tx.ExecContext(ctx, "DELETE FROM entity_tags WHERE entity_id = ?", existing.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := insertTags(tx, existing.ID, *u.Tags); err != nil {
|
||||
if err := insertTags(ctx, tx, existing.ID, *u.Tags); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -385,8 +386,8 @@ func (s *Store) Update(id string, u *EntityUpdate) error {
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (s *Store) Promote(id string, cardType CardType, cardData *string) error {
|
||||
e, err := s.Get(id)
|
||||
func (s *Store) Promote(ctx context.Context, id string, cardType CardType, cardData *string) error {
|
||||
e, err := s.Get(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -402,15 +403,15 @@ func (s *Store) Promote(id string, cardType CardType, cardData *string) error {
|
||||
dataVal = *cardData
|
||||
}
|
||||
|
||||
_, err = s.db.Exec(`
|
||||
_, err = s.db.ExecContext(ctx, `
|
||||
UPDATE entities SET card_type = ?, card_data = ?, modified_at = ?
|
||||
WHERE id = ?`,
|
||||
string(cardType), dataVal, time.Now().UTC().Format(time.RFC3339), id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) Demote(id string) error {
|
||||
e, err := s.Get(id)
|
||||
func (s *Store) Demote(ctx context.Context, id string) error {
|
||||
e, err := s.Get(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -418,7 +419,7 @@ func (s *Store) Demote(id string) error {
|
||||
return ErrAlreadyFluid
|
||||
}
|
||||
|
||||
_, err = s.db.Exec(`
|
||||
_, err = s.db.ExecContext(ctx, `
|
||||
UPDATE entities SET card_type = NULL, card_data = NULL,
|
||||
use_count = 0, last_used_at = NULL, modified_at = ?
|
||||
WHERE id = ?`,
|
||||
@@ -433,9 +434,9 @@ const (
|
||||
DeletedHard
|
||||
)
|
||||
|
||||
func (s *Store) SoftDelete(id string) (DeleteResult, error) {
|
||||
func (s *Store) SoftDelete(ctx context.Context, id string) (DeleteResult, error) {
|
||||
var deletedAt sql.NullString
|
||||
err := s.db.QueryRow("SELECT deleted_at FROM entities WHERE id = ?", id).Scan(&deletedAt)
|
||||
err := s.db.QueryRowContext(ctx, "SELECT deleted_at FROM entities WHERE id = ?", id).Scan(&deletedAt)
|
||||
if err == sql.ErrNoRows {
|
||||
return 0, ErrNotFound
|
||||
}
|
||||
@@ -444,21 +445,21 @@ func (s *Store) SoftDelete(id string) (DeleteResult, error) {
|
||||
}
|
||||
|
||||
if deletedAt.Valid {
|
||||
_, err = s.db.Exec("DELETE FROM entities WHERE id = ?", id)
|
||||
_, err = s.db.ExecContext(ctx, "DELETE FROM entities WHERE id = ?", id)
|
||||
return DeletedHard, err
|
||||
}
|
||||
|
||||
_, err = s.db.Exec("UPDATE entities SET deleted_at = ? WHERE id = ?",
|
||||
_, err = s.db.ExecContext(ctx, "UPDATE entities SET deleted_at = ? WHERE id = ?",
|
||||
time.Now().UTC().Format(time.RFC3339), id)
|
||||
return DeletedSoft, err
|
||||
}
|
||||
|
||||
func (s *Store) Absorb(targetID, sourceID string) error {
|
||||
target, err := s.Get(targetID)
|
||||
func (s *Store) Absorb(ctx context.Context, targetID, sourceID string) error {
|
||||
target, err := s.Get(ctx, targetID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
source, err := s.Get(sourceID)
|
||||
source, err := s.Get(ctx, sourceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -467,7 +468,7 @@ func (s *Store) Absorb(targetID, sourceID string) error {
|
||||
return ErrTargetCrystallized
|
||||
}
|
||||
|
||||
tx, err := s.db.Begin()
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -476,7 +477,7 @@ func (s *Store) Absorb(targetID, sourceID string) error {
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
merged := target.Body + "\n" + source.Body
|
||||
|
||||
if _, err := tx.Exec("UPDATE entities SET body = ?, modified_at = ? WHERE id = ?",
|
||||
if _, err := tx.ExecContext(ctx, "UPDATE entities SET body = ?, modified_at = ? WHERE id = ?",
|
||||
merged, now, targetID); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -487,7 +488,7 @@ func (s *Store) Absorb(targetID, sourceID string) error {
|
||||
}
|
||||
for _, t := range source.Tags {
|
||||
if !seen[t] {
|
||||
if _, err := tx.Exec("INSERT OR IGNORE INTO entity_tags (entity_id, tag) VALUES (?, ?)",
|
||||
if _, err := tx.ExecContext(ctx, "INSERT OR IGNORE INTO entity_tags (entity_id, tag) VALUES (?, ?)",
|
||||
targetID, t); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -495,7 +496,7 @@ func (s *Store) Absorb(targetID, sourceID string) error {
|
||||
}
|
||||
|
||||
if source.CardType != nil {
|
||||
if _, err := tx.Exec(`UPDATE entities SET card_type = NULL, card_data = NULL,
|
||||
if _, err := tx.ExecContext(ctx, `UPDATE entities SET card_type = NULL, card_data = NULL,
|
||||
use_count = 0, last_used_at = NULL, modified_at = ? WHERE id = ?`,
|
||||
now, sourceID); err != nil {
|
||||
return err
|
||||
@@ -503,7 +504,7 @@ func (s *Store) Absorb(targetID, sourceID string) error {
|
||||
}
|
||||
|
||||
absorbNote := source.Body + "\n\n[absorbed into " + targetID + "]"
|
||||
if _, err := tx.Exec("UPDATE entities SET body = ?, deleted_at = ?, modified_at = ? WHERE id = ?",
|
||||
if _, err := tx.ExecContext(ctx, "UPDATE entities SET body = ?, deleted_at = ?, modified_at = ? WHERE id = ?",
|
||||
absorbNote, now, now, sourceID); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -511,8 +512,8 @@ func (s *Store) Absorb(targetID, sourceID string) error {
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (s *Store) IncrementUse(id string) error {
|
||||
res, err := s.db.Exec(`
|
||||
func (s *Store) IncrementUse(ctx context.Context, id string) error {
|
||||
res, err := s.db.ExecContext(ctx, `
|
||||
UPDATE entities SET use_count = use_count + 1, last_used_at = ?
|
||||
WHERE id = ?`,
|
||||
time.Now().UTC().Format(time.RFC3339), id)
|
||||
@@ -526,8 +527,8 @@ func (s *Store) IncrementUse(id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) Resolve(prefix string) (string, error) {
|
||||
rows, err := s.db.Query("SELECT id FROM entities WHERE id LIKE ?", prefix+"%")
|
||||
func (s *Store) Resolve(ctx context.Context, prefix string) (string, error) {
|
||||
rows, err := s.db.QueryContext(ctx, "SELECT id FROM entities WHERE id LIKE ?", prefix+"%")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -593,9 +594,7 @@ func (r *entityRow) apply(e *Entity) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// helpers
|
||||
|
||||
func (s *Store) batchLoadTags(entities []*Entity) error {
|
||||
func (s *Store) batchLoadTags(ctx context.Context, entities []*Entity) error {
|
||||
if len(entities) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -615,7 +614,7 @@ func (s *Store) batchLoadTags(entities []*Entity) error {
|
||||
strings.Join(placeholders, ","),
|
||||
)
|
||||
|
||||
rows, err := s.db.Query(query, args...)
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -633,8 +632,8 @@ func (s *Store) batchLoadTags(entities []*Entity) error {
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
func (s *Store) loadTags(entityID string) ([]string, error) {
|
||||
rows, err := s.db.Query("SELECT tag FROM entity_tags WHERE entity_id = ? ORDER BY tag", entityID)
|
||||
func (s *Store) loadTags(ctx context.Context, entityID string) ([]string, error) {
|
||||
rows, err := s.db.QueryContext(ctx, "SELECT tag FROM entity_tags WHERE entity_id = ? ORDER BY tag", entityID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -657,9 +656,9 @@ func (s *Store) loadTags(entityID string) ([]string, error) {
|
||||
return tags, nil
|
||||
}
|
||||
|
||||
func insertTags(tx *sql.Tx, entityID string, tags []string) error {
|
||||
func insertTags(ctx context.Context, tx *sql.Tx, entityID string, tags []string) error {
|
||||
for _, tag := range tags {
|
||||
if _, err := tx.Exec("INSERT OR IGNORE INTO entity_tags (entity_id, tag) VALUES (?, ?)",
|
||||
if _, err := tx.ExecContext(ctx, "INSERT OR IGNORE INTO entity_tags (entity_id, tag) VALUES (?, ?)",
|
||||
entityID, tag); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user