Files
nib-v1/internal/db/entities.go
T
lerko d715b053e7 refactor(db): thread context.Context through all Store methods
Enables request-scoped cancellation, timeouts, and graceful shutdown
for all database operations across API handlers, CLI commands, and TUI.
2026-05-20 20:51:51 -04:00

726 lines
16 KiB
Go

package db
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
nibulid "github.com/lerko/nib/internal/ulid"
)
// Glyph is the string-typed kind marker carried in Entity.Glyph.
type Glyph string
// Glyph values recognised by ValidGlyph.
const (
GlyphNote Glyph = "note"
GlyphTodo Glyph = "todo"
GlyphEvent Glyph = "event"
GlyphReminder Glyph = "reminder"
)
// CardType is the string-typed card variety carried in Entity.CardType.
// An entity whose CardType is nil has not been promoted to a card.
type CardType string
// CardType values recognised by ValidCardType.
const (
CardSnippet CardType = "snippet"
CardTemplate CardType = "template"
CardChecklist CardType = "checklist"
CardDecision CardType = "decision"
CardLink CardType = "link"
CardNote CardType = "note"
)
// ValidGlyph reports whether s is exactly one of the declared Glyph values.
// The comparison is case-sensitive.
func ValidGlyph(s string) bool {
	candidate := Glyph(s)
	for _, known := range [...]Glyph{GlyphNote, GlyphTodo, GlyphEvent, GlyphReminder} {
		if candidate == known {
			return true
		}
	}
	return false
}
// ValidCardType reports whether s is exactly one of the declared CardType
// values. The comparison is case-sensitive.
func ValidCardType(s string) bool {
	candidate := CardType(s)
	known := [...]CardType{
		CardSnippet, CardTemplate, CardChecklist,
		CardDecision, CardLink, CardNote,
	}
	for i := range known {
		if known[i] == candidate {
			return true
		}
	}
	return false
}
// Entity is one row of the entities table together with its tags.
type Entity struct {
// ID is assigned by Store.Create; any caller-supplied value is overwritten.
ID string
// CreatedAt and ModifiedAt are set to the current UTC time by Store.Create.
CreatedAt time.Time
ModifiedAt time.Time
Body string
// Title and Description are optional.
Title *string
Description *string
Glyph Glyph
// TimeAnchor is an optional string; its format is not validated in this file.
TimeAnchor *string
CompletedAt *time.Time
// Pinned is persisted as an integer (1 or 0).
Pinned bool
// DeletedAt is non-nil for soft-deleted entities (see Store.SoftDelete).
DeletedAt *time.Time
// CardType is nil until the entity is promoted (see Store.Promote and Store.Demote).
CardType *CardType
// CardData holds JSON text; Create, Update and Promote reject invalid JSON.
CardData *string
// UseCount and LastUsedAt are advanced by Store.IncrementUse.
UseCount int
LastUsedAt *time.Time
// Tags comes from the entity_tags table; Get and List populate it with an
// empty, non-nil slice when the entity has no tags.
Tags []string
}
// ListParams selects, orders and pages the rows returned by Store.List and
// counted by Store.Count. Nil pointer fields add no filter.
type ListParams struct {
// Tag keeps only entities carrying this tag.
Tag *string
// Date, From and To are compared against date(created_at) in SQL:
// equality, inclusive lower bound and inclusive upper bound respectively.
Date *string
From *string
To *string
// Since keeps entities with created_at >= Since.
Since *time.Time
// ModifiedBefore keeps entities with modified_at < ModifiedBefore.
ModifiedBefore *time.Time
// CardsOnly keeps only entities whose card_type is set.
CardsOnly bool
// IncludeDeleted disables the default exclusion of soft-deleted rows.
IncludeDeleted bool
// CardTypeFilter keeps only entities with exactly this card_type.
CardTypeFilter *CardType
// Sort is read by List: "use_count" and "modified_at" select those columns;
// every other value sorts by created_at.
Sort string
// Order is read by List: "asc" (any letter case) is ascending; anything
// else is descending.
Order string
// Limit is the page size; List substitutes 50 when it is <= 0.
Limit int
// Offset is passed to the SQL OFFSET clause by List.
Offset int
}
// DefaultListParams returns the baseline listing parameters: no filters,
// Sort "created", Order "desc" and a Limit of 50.
func DefaultListParams() ListParams {
	var p ListParams
	p.Sort = "created"
	p.Order = "desc"
	p.Limit = 50
	return p
}
// EntityUpdate describes a partial update applied by Store.Update.
// Nil pointer fields leave the corresponding column unchanged.
type EntityUpdate struct {
Body *string
Title *string
Description *string
Glyph *Glyph
TimeAnchor *string
// ClearTime sets time_anchor to NULL and takes precedence over TimeAnchor.
ClearTime bool
CompletedAt *time.Time
// ClearCompleted sets completed_at to NULL and takes precedence over CompletedAt.
ClearCompleted bool
Pinned *bool
CardType *CardType
// CardData must be valid JSON, otherwise Update returns ErrInvalidCardData.
CardData *string
// Tags, when non-nil, replaces the entity's whole tag set.
Tags *[]string
}
// Create stores e as a new entity together with its tags.
//
// ErrInvalidCardData is returned, before e is touched, when CardData is set
// but is not valid JSON. Otherwise the ID, CreatedAt and ModifiedAt fields of
// e are overwritten (new ID from nibulid.New, current UTC time) and a nil
// Tags slice is replaced by an empty one. The entity row and its tag rows are
// written inside a single transaction.
func (s *Store) Create(ctx context.Context, e *Entity) error {
	if data := e.CardData; data != nil {
		if !json.Valid([]byte(*data)) {
			return ErrInvalidCardData
		}
	}

	stamp := time.Now().UTC()
	e.ID = nibulid.New()
	e.CreatedAt, e.ModifiedAt = stamp, stamp
	if e.Tags == nil {
		e.Tags = []string{}
	}

	const insertEntity = `
INSERT INTO entities (id, created_at, modified_at, body, title, description,
glyph, time_anchor, completed_at, pinned, deleted_at,
card_type, card_data, use_count, last_used_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// Rollback is a no-op once Commit has succeeded.
	defer tx.Rollback()

	// One value per column, in the order listed in insertEntity.
	values := []any{
		e.ID,
		e.CreatedAt.Format(time.RFC3339),
		e.ModifiedAt.Format(time.RFC3339),
		e.Body,
		e.Title,
		e.Description,
		string(e.Glyph),
		e.TimeAnchor,
		formatTimePtr(e.CompletedAt),
		boolToInt(e.Pinned),
		formatTimePtr(e.DeletedAt),
		cardTypePtr(e.CardType),
		e.CardData,
		e.UseCount,
		formatTimePtr(e.LastUsedAt),
	}
	if _, err := tx.ExecContext(ctx, insertEntity, values...); err != nil {
		return err
	}
	if err := insertTags(ctx, tx, e.ID, e.Tags); err != nil {
		return err
	}
	return tx.Commit()
}
// Get loads the entity with the given id, including its tags.
//
// It returns ErrNotFound when no row has that id. Soft-deleted entities are
// returned like any other row. A failure to convert the scanned columns is
// wrapped with the entity id.
func (s *Store) Get(ctx context.Context, id string) (*Entity, error) {
	const selectByID = `
SELECT id, created_at, modified_at, body, title, description,
glyph, time_anchor, completed_at, pinned, deleted_at,
card_type, card_data, use_count, last_used_at
FROM entities WHERE id = ?`

	var (
		ent     Entity
		scanned = newEntityRow()
	)
	switch err := s.db.QueryRowContext(ctx, selectByID, id).Scan(scanned.ptrs(&ent)...); {
	case err == sql.ErrNoRows:
		return nil, ErrNotFound
	case err != nil:
		return nil, err
	}

	if err := scanned.apply(&ent); err != nil {
		return nil, fmt.Errorf("scan entity %s: %w", id, err)
	}

	tags, err := s.loadTags(ctx, id)
	if err != nil {
		return nil, err
	}
	ent.Tags = tags
	return &ent, nil
}
// listWhere translates params into a SQL WHERE clause over the entities table
// aliased as "e", plus the bind arguments for its placeholders.
//
// The returned clause is either empty or starts with "WHERE"; conditions are
// joined with AND. Soft-deleted rows are excluded unless IncludeDeleted is
// set. Sort, Order, Limit and Offset are not consulted here.
//
// Since and ModifiedBefore are converted to UTC before formatting. The stored
// created_at/modified_at values are UTC RFC 3339 strings and the comparison is
// textual, so a bound rendered with a non-UTC offset would compare incorrectly.
func listWhere(params ListParams) (string, []any) {
	var (
		conds []string
		args  []any
	)
	add := func(cond string, vals ...any) {
		conds = append(conds, cond)
		args = append(args, vals...)
	}

	if !params.IncludeDeleted {
		add("e.deleted_at IS NULL")
	}
	if params.Tag != nil {
		add("e.id IN (SELECT entity_id FROM entity_tags WHERE tag = ?)", *params.Tag)
	}
	if params.Date != nil {
		add("date(e.created_at) = ?", *params.Date)
	}
	if params.From != nil {
		add("date(e.created_at) >= ?", *params.From)
	}
	if params.To != nil {
		add("date(e.created_at) <= ?", *params.To)
	}
	if params.Since != nil {
		add("e.created_at >= ?", params.Since.UTC().Format(time.RFC3339))
	}
	if params.CardsOnly {
		add("e.card_type IS NOT NULL")
	}
	if params.CardTypeFilter != nil {
		add("e.card_type = ?", string(*params.CardTypeFilter))
	}
	if params.ModifiedBefore != nil {
		add("e.modified_at < ?", params.ModifiedBefore.UTC().Format(time.RFC3339))
	}

	if len(conds) == 0 {
		return "", args
	}
	return "WHERE " + strings.Join(conds, " AND "), args
}
// Count returns the number of entities matching the filters in params.
// Only the filter fields are used; Sort, Order, Limit and Offset are ignored.
func (s *Store) Count(ctx context.Context, params ListParams) (int, error) {
	cond, bind := listWhere(params)
	stmt := fmt.Sprintf("SELECT COUNT(*) FROM entities e %s", cond)

	var total int
	if err := s.db.QueryRowContext(ctx, stmt, bind...).Scan(&total); err != nil {
		return total, err
	}
	return total, nil
}
// List returns the entities matching params, with their tags loaded.
//
// Filtering follows listWhere. Sort selects the ordering column:
// "use_count" and "modified_at" pick those columns and every other value
// (including the empty string) orders by created_at. Order "asc" (any letter
// case) sorts ascending; anything else sorts descending. A Limit <= 0 is
// replaced by 50. The result is nil when no rows match.
//
// Timestamps are stored with one-second resolution, so rows can tie on the
// sort column. The entity id is used as a secondary sort key, in the same
// direction, so that LIMIT/OFFSET paging sees a stable total order and does
// not skip or repeat rows between pages.
func (s *Store) List(ctx context.Context, params ListParams) ([]*Entity, error) {
	whereClause, args := listWhere(params)

	// Only fixed column names and directions are interpolated into the query;
	// caller-supplied values never reach the SQL text.
	orderCol := "e.created_at"
	switch params.Sort {
	case "use_count":
		orderCol = "e.use_count"
	case "modified_at":
		orderCol = "e.modified_at"
	}
	orderDir := "DESC"
	if strings.ToLower(params.Order) == "asc" {
		orderDir = "ASC"
	}

	limit := params.Limit
	if limit <= 0 {
		limit = 50
	}

	query := fmt.Sprintf(`
SELECT e.id, e.created_at, e.modified_at, e.body, e.title, e.description,
e.glyph, e.time_anchor, e.completed_at, e.pinned, e.deleted_at,
e.card_type, e.card_data, e.use_count, e.last_used_at
FROM entities e
%s
ORDER BY %s %s, e.id %s
LIMIT ? OFFSET ?`, whereClause, orderCol, orderDir, orderDir)
	args = append(args, limit, params.Offset)

	rows, err := s.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var entities []*Entity
	for rows.Next() {
		e := &Entity{}
		row := newEntityRow()
		if err := rows.Scan(row.ptrs(e)...); err != nil {
			return nil, err
		}
		if err := row.apply(e); err != nil {
			return nil, err
		}
		entities = append(entities, e)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	if err := s.batchLoadTags(ctx, entities); err != nil {
		return nil, err
	}
	return entities, nil
}
// Update applies the non-nil fields of u to the entity with the given id.
//
// The entity is looked up first, so ErrNotFound is returned for an unknown
// id. modified_at is always refreshed to the current UTC time. ClearTime and
// ClearCompleted set their column to NULL and win over TimeAnchor and
// CompletedAt. CardData that is not valid JSON yields ErrInvalidCardData and
// nothing is written. When u.Tags is non-nil the existing tags are deleted and
// replaced. The column update and the tag replacement share one transaction.
func (s *Store) Update(ctx context.Context, id string, u *EntityUpdate) error {
	current, err := s.Get(ctx, id)
	if err != nil {
		return err
	}

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// Rollback is a no-op once Commit has succeeded.
	defer tx.Rollback()

	columns := []string{"modified_at = ?"}
	values := []any{time.Now().UTC().Format(time.RFC3339)}
	// assign queues "column = ?" together with its bind value.
	assign := func(column string, value any) {
		columns = append(columns, column+" = ?")
		values = append(values, value)
	}

	if u.Body != nil {
		assign("body", *u.Body)
	}
	if u.Title != nil {
		assign("title", *u.Title)
	}
	if u.Description != nil {
		assign("description", *u.Description)
	}
	if u.Glyph != nil {
		assign("glyph", string(*u.Glyph))
	}
	switch {
	case u.ClearTime:
		columns = append(columns, "time_anchor = NULL")
	case u.TimeAnchor != nil:
		assign("time_anchor", *u.TimeAnchor)
	}
	switch {
	case u.ClearCompleted:
		columns = append(columns, "completed_at = NULL")
	case u.CompletedAt != nil:
		assign("completed_at", u.CompletedAt.Format(time.RFC3339))
	}
	if u.Pinned != nil {
		assign("pinned", boolToInt(*u.Pinned))
	}
	if u.CardType != nil {
		assign("card_type", string(*u.CardType))
	}
	if u.CardData != nil {
		if !json.Valid([]byte(*u.CardData)) {
			return ErrInvalidCardData
		}
		assign("card_data", *u.CardData)
	}

	// The id placeholder comes last, after every SET value.
	values = append(values, current.ID)
	stmt := fmt.Sprintf("UPDATE entities SET %s WHERE id = ?", strings.Join(columns, ", "))
	if _, err := tx.ExecContext(ctx, stmt, values...); err != nil {
		return err
	}

	if u.Tags != nil {
		_, err := tx.ExecContext(ctx, "DELETE FROM entity_tags WHERE entity_id = ?", current.ID)
		if err != nil {
			return err
		}
		if err := insertTags(ctx, tx, current.ID, *u.Tags); err != nil {
			return err
		}
	}
	return tx.Commit()
}
// Promote turns the entity with the given id into a card of type cardType.
//
// It returns ErrNotFound for an unknown id, ErrAlreadyPromoted when the
// entity already has a card type, and ErrInvalidCardData when cardData is
// non-nil but not valid JSON. A nil cardData is stored as "{}". modified_at
// is refreshed. cardType itself is not validated here.
func (s *Store) Promote(ctx context.Context, id string, cardType CardType, cardData *string) error {
	current, err := s.Get(ctx, id)
	switch {
	case err != nil:
		return err
	case current.CardType != nil:
		return ErrAlreadyPromoted
	}

	payload := "{}"
	if cardData != nil {
		payload = *cardData
		if !json.Valid([]byte(payload)) {
			return ErrInvalidCardData
		}
	}

	const promote = `
UPDATE entities SET card_type = ?, card_data = ?, modified_at = ?
WHERE id = ?`
	stamp := time.Now().UTC().Format(time.RFC3339)
	_, err = s.db.ExecContext(ctx, promote, string(cardType), payload, stamp, id)
	return err
}
// Demote strips the card state from the entity with the given id: card_type,
// card_data and last_used_at become NULL, use_count is reset to 0 and
// modified_at is refreshed.
//
// It returns ErrNotFound for an unknown id and ErrAlreadyFluid when the
// entity has no card type.
func (s *Store) Demote(ctx context.Context, id string) error {
	current, err := s.Get(ctx, id)
	if err != nil {
		return err
	}
	if current.CardType == nil {
		return ErrAlreadyFluid
	}

	const demote = `
UPDATE entities SET card_type = NULL, card_data = NULL,
use_count = 0, last_used_at = NULL, modified_at = ?
WHERE id = ?`
	stamp := time.Now().UTC().Format(time.RFC3339)
	if _, err := s.db.ExecContext(ctx, demote, stamp, id); err != nil {
		return err
	}
	return nil
}
// DeleteResult reports which kind of deletion Store.SoftDelete performed.
type DeleteResult int
const (
// DeletedSoft means the row was kept and its deleted_at column was set.
DeletedSoft DeleteResult = iota
// DeletedHard means the row was already soft-deleted and has been removed.
DeletedHard
)
// SoftDelete deletes the entity with the given id in two stages.
//
// A row whose deleted_at is NULL gets deleted_at set to the current UTC time
// and DeletedSoft is returned; modified_at is left untouched. A row that is
// already soft-deleted is removed from the table and DeletedHard is returned.
// ErrNotFound is returned when no row has that id.
func (s *Store) SoftDelete(ctx context.Context, id string) (DeleteResult, error) {
	var marker sql.NullString
	lookup := s.db.QueryRowContext(ctx, "SELECT deleted_at FROM entities WHERE id = ?", id)
	switch err := lookup.Scan(&marker); {
	case err == sql.ErrNoRows:
		return 0, ErrNotFound
	case err != nil:
		return 0, err
	}

	if !marker.Valid {
		stamp := time.Now().UTC().Format(time.RFC3339)
		_, err := s.db.ExecContext(ctx, "UPDATE entities SET deleted_at = ? WHERE id = ?", stamp, id)
		return DeletedSoft, err
	}

	_, err := s.db.ExecContext(ctx, "DELETE FROM entities WHERE id = ?", id)
	return DeletedHard, err
}
// Absorb merges the entity sourceID into the entity targetID.
//
// In one transaction it:
//   - sets the target body to target body + "\n" + source body,
//   - copies every source tag the target does not already carry,
//   - clears the card columns of the source if it was a card,
//   - rewrites the source body with an "[absorbed into <targetID>]" trailer
//     and soft-deletes the source.
//
// It returns ErrNotFound when either entity is missing and
// ErrTargetCrystallized when the target has a card type. Absorbing an entity
// into itself is rejected with an error: it would otherwise rewrite the
// entity's body and soft-delete it.
func (s *Store) Absorb(ctx context.Context, targetID, sourceID string) error {
	target, err := s.Get(ctx, targetID)
	if err != nil {
		return err
	}
	// Checked after the lookup so that an unknown id still yields ErrNotFound.
	if targetID == sourceID {
		return fmt.Errorf("cannot absorb entity %s into itself", targetID)
	}
	source, err := s.Get(ctx, sourceID)
	if err != nil {
		return err
	}
	if target.CardType != nil {
		return ErrTargetCrystallized
	}

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// Rollback is a no-op once Commit has succeeded.
	defer tx.Rollback()

	now := time.Now().UTC().Format(time.RFC3339)

	merged := target.Body + "\n" + source.Body
	if _, err := tx.ExecContext(ctx, "UPDATE entities SET body = ?, modified_at = ? WHERE id = ?",
		merged, now, targetID); err != nil {
		return err
	}

	// Skip tags the target already has; this only saves statements, since the
	// INSERT OR IGNORE below would ignore duplicates anyway.
	seen := make(map[string]bool, len(target.Tags))
	for _, t := range target.Tags {
		seen[t] = true
	}
	for _, t := range source.Tags {
		if seen[t] {
			continue
		}
		if _, err := tx.ExecContext(ctx, "INSERT OR IGNORE INTO entity_tags (entity_id, tag) VALUES (?, ?)",
			targetID, t); err != nil {
			return err
		}
	}

	if source.CardType != nil {
		if _, err := tx.ExecContext(ctx, `UPDATE entities SET card_type = NULL, card_data = NULL,
use_count = 0, last_used_at = NULL, modified_at = ? WHERE id = ?`,
			now, sourceID); err != nil {
			return err
		}
	}

	absorbNote := source.Body + "\n\n[absorbed into " + targetID + "]"
	if _, err := tx.ExecContext(ctx, "UPDATE entities SET body = ?, deleted_at = ?, modified_at = ? WHERE id = ?",
		absorbNote, now, now, sourceID); err != nil {
		return err
	}
	return tx.Commit()
}
// IncrementUse adds one to the entity's use_count and sets last_used_at to
// the current UTC time. modified_at is not changed.
//
// It returns ErrNotFound when no row has the given id. A failure to obtain
// the affected-row count is returned as its own error rather than being
// reported as ErrNotFound.
func (s *Store) IncrementUse(ctx context.Context, id string) error {
	res, err := s.db.ExecContext(ctx, `
UPDATE entities SET use_count = use_count + 1, last_used_at = ?
WHERE id = ?`,
		time.Now().UTC().Format(time.RFC3339), id)
	if err != nil {
		return err
	}
	n, err := res.RowsAffected()
	if err != nil {
		return fmt.Errorf("increment use %s: %w", id, err)
	}
	if n == 0 {
		return ErrNotFound
	}
	return nil
}
// Resolve expands an id prefix to the single full entity id that starts with
// it. Soft-deleted entities are considered too.
//
// It returns ErrNotFound when nothing matches and a descriptive error when
// more than one entity matches.
//
// The prefix is matched literally: the LIKE wildcards '%' and '_' (and the
// escape character itself) are escaped, so a prefix containing them cannot
// match unrelated ids.
func (s *Store) Resolve(ctx context.Context, prefix string) (string, error) {
	// Backslash must be replaced first-class alongside the wildcards because
	// it is the ESCAPE character declared in the query below.
	pattern := strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`).Replace(prefix) + "%"

	rows, err := s.db.QueryContext(ctx, "SELECT id FROM entities WHERE id LIKE ? ESCAPE '\\'", pattern)
	if err != nil {
		return "", err
	}
	defer rows.Close()

	var ids []string
	for rows.Next() {
		var id string
		if err := rows.Scan(&id); err != nil {
			return "", err
		}
		ids = append(ids, id)
	}
	if err := rows.Err(); err != nil {
		return "", err
	}

	switch len(ids) {
	case 0:
		return "", ErrNotFound
	case 1:
		return ids[0], nil
	default:
		return "", fmt.Errorf("ambiguous id prefix %q matches %d entities", prefix, len(ids))
	}
}
// entityRow holds the scan targets for entity columns that need conversion
// before they can be stored on an Entity: timestamps arrive as strings,
// nullable columns as sql.NullString, and pinned as an integer.
// ptrs exposes these fields to Scan and apply converts them afterwards.
type entityRow struct {
createdAt, modifiedAt string
completedAt, deletedAt, lastUsedAt sql.NullString
timeAnchor, cardType, cardData sql.NullString
title, description sql.NullString
pinned int
}
// newEntityRow returns a fresh, zero-valued scan buffer.
func newEntityRow() *entityRow {
	return new(entityRow)
}
// ptrs returns the 15 Scan destinations for an entity SELECT, in column
// order: id, created_at, modified_at, body, title, description, glyph,
// time_anchor, completed_at, pinned, deleted_at, card_type, card_data,
// use_count, last_used_at. Columns that need no conversion point straight
// into e; the rest point into r and are transferred by apply.
func (r *entityRow) ptrs(e *Entity) []any {
	dest := make([]any, 0, 15)
	dest = append(dest, &e.ID, &r.createdAt, &r.modifiedAt)
	dest = append(dest, &e.Body, &r.title, &r.description)
	dest = append(dest, &e.Glyph, &r.timeAnchor, &r.completedAt)
	dest = append(dest, &r.pinned, &r.deletedAt, &r.cardType)
	dest = append(dest, &r.cardData, &e.UseCount, &r.lastUsedAt)
	return dest
}
// apply converts the values scanned into r and stores them on e.
//
// created_at and modified_at must parse as RFC 3339; a failure is returned
// wrapped with the column name. The nullable timestamps (completed_at,
// deleted_at, last_used_at) go through parseTimePtr, which yields nil both
// for NULL and for text that does not parse, so malformed values there are
// not reported.
func (r *entityRow) apply(e *Entity) error {
	created, err := time.Parse(time.RFC3339, r.createdAt)
	e.CreatedAt = created
	if err != nil {
		return fmt.Errorf("created_at: %w", err)
	}

	modified, err := time.Parse(time.RFC3339, r.modifiedAt)
	e.ModifiedAt = modified
	if err != nil {
		return fmt.Errorf("modified_at: %w", err)
	}

	// Optional text columns.
	e.Title = nullToPtr(r.title)
	e.Description = nullToPtr(r.description)
	e.TimeAnchor = nullToPtr(r.timeAnchor)
	e.CardData = nullToPtr(r.cardData)
	e.CardType = nullToCardType(r.cardType)

	// Optional timestamps.
	e.CompletedAt = parseTimePtr(r.completedAt)
	e.DeletedAt = parseTimePtr(r.deletedAt)
	e.LastUsedAt = parseTimePtr(r.lastUsedAt)

	e.Pinned = r.pinned != 0
	return nil
}
// batchLoadTags fills in Tags for every entity in entities.
//
// Each entity ends up with a non-nil slice, sorted by tag, that is empty when
// it has no tags. The lookup binds one SQL parameter per entity, so the work
// is split into fixed-size batches to stay below the database's limit on
// bound parameters regardless of how many entities are passed in.
func (s *Store) batchLoadTags(ctx context.Context, entities []*Entity) error {
	// Comfortably below SQLite's historical default of 999 host parameters.
	const batchSize = 500

	for start := 0; start < len(entities); start += batchSize {
		end := start + batchSize
		if end > len(entities) {
			end = len(entities)
		}
		if err := s.loadTagsForBatch(ctx, entities[start:end]); err != nil {
			return err
		}
	}
	return nil
}

// loadTagsForBatch loads the tags for one batch of entities with a single
// IN (...) query and appends each tag to its owning entity.
func (s *Store) loadTagsForBatch(ctx context.Context, batch []*Entity) error {
	if len(batch) == 0 {
		return nil
	}

	byID := make(map[string]*Entity, len(batch))
	placeholders := make([]string, len(batch))
	args := make([]any, len(batch))
	for i, e := range batch {
		e.Tags = []string{}
		byID[e.ID] = e
		placeholders[i] = "?"
		args[i] = e.ID
	}

	query := fmt.Sprintf(
		"SELECT entity_id, tag FROM entity_tags WHERE entity_id IN (%s) ORDER BY entity_id, tag",
		strings.Join(placeholders, ","),
	)
	rows, err := s.db.QueryContext(ctx, query, args...)
	if err != nil {
		return err
	}
	defer rows.Close()

	for rows.Next() {
		var entityID, tag string
		if err := rows.Scan(&entityID, &tag); err != nil {
			return err
		}
		if e, ok := byID[entityID]; ok {
			e.Tags = append(e.Tags, tag)
		}
	}
	return rows.Err()
}
// loadTags returns the tags of one entity, sorted by tag. The result is an
// empty, non-nil slice when the entity has no tags, and nil on error.
func (s *Store) loadTags(ctx context.Context, entityID string) ([]string, error) {
	const byEntity = "SELECT tag FROM entity_tags WHERE entity_id = ? ORDER BY tag"

	rows, err := s.db.QueryContext(ctx, byEntity, entityID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	// Start non-nil so that "no tags" is an empty slice rather than nil.
	found := []string{}
	for rows.Next() {
		var name string
		if scanErr := rows.Scan(&name); scanErr != nil {
			return nil, scanErr
		}
		found = append(found, name)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return found, nil
}
// insertTags attaches each tag to entityID inside tx. Tags that are already
// attached are skipped by the INSERT OR IGNORE. The first failing statement
// aborts the loop and its error is returned; tx is not touched when tags is
// empty.
func insertTags(ctx context.Context, tx *sql.Tx, entityID string, tags []string) error {
	const attach = "INSERT OR IGNORE INTO entity_tags (entity_id, tag) VALUES (?, ?)"
	for i := range tags {
		_, err := tx.ExecContext(ctx, attach, entityID, tags[i])
		if err != nil {
			return err
		}
	}
	return nil
}
// formatTimePtr converts an optional time into a SQL bind value: an untyped
// nil for a nil pointer, otherwise the RFC 3339 rendering of *t in the
// time's own location.
func formatTimePtr(t *time.Time) interface{} {
	if t != nil {
		return t.Format(time.RFC3339)
	}
	return nil
}
// parseTimePtr converts a nullable RFC 3339 column into an optional time.
// It returns nil for SQL NULL and also for text that fails to parse; the
// parse error is deliberately not surfaced.
func parseTimePtr(ns sql.NullString) *time.Time {
	if ns.Valid {
		if parsed, err := time.Parse(time.RFC3339, ns.String); err == nil {
			return &parsed
		}
	}
	return nil
}
// nullToPtr converts a nullable string column into an optional string:
// nil for SQL NULL, otherwise a pointer to the text (which may be empty).
func nullToPtr(ns sql.NullString) *string {
	var out *string
	if ns.Valid {
		out = &ns.String
	}
	return out
}
// nullToCardType converts a nullable card_type column into an optional
// CardType: nil for SQL NULL, otherwise a pointer to the converted text.
// The text is not checked against the declared CardType values.
func nullToCardType(ns sql.NullString) *CardType {
	if ns.Valid {
		converted := CardType(ns.String)
		return &converted
	}
	return nil
}
// cardTypePtr converts an optional CardType into a SQL bind value: an
// untyped nil for a nil pointer, otherwise the card type as a plain string.
func cardTypePtr(ct *CardType) interface{} {
	if ct != nil {
		return string(*ct)
	}
	return nil
}
// boolToInt maps true to 1 and false to 0, the integer form used for the
// pinned column.
func boolToInt(b bool) int {
	n := 0
	if b {
		n = 1
	}
	return n
}
// CardDataJSON decodes CardData as a JSON object.
//
// It returns (nil, nil) when CardData is nil. A decoding failure, including
// JSON that is not an object, is returned wrapped with a "card_data" prefix.
func (e *Entity) CardDataJSON() (map[string]interface{}, error) {
	raw := e.CardData
	if raw == nil {
		return nil, nil
	}

	var decoded map[string]interface{}
	err := json.Unmarshal([]byte(*raw), &decoded)
	if err != nil {
		return nil, fmt.Errorf("card_data: %w", err)
	}
	return decoded, nil
}