babf1d6620
Split EDITOR env var on whitespace so multi-word values like "code --wait" work correctly. Add allow-list switch for sort column and order direction at the query boundary to prevent future callers from passing unsanitized values into SQL.
697 lines
15 KiB
Go
697 lines
15 KiB
Go
package db
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
nibulid "github.com/lerko/nib/internal/ulid"
|
|
)
|
|
|
|
// Glyph is the value stored in an entity's glyph column. ValidGlyph reports
// whether an arbitrary string is one of the values declared below.
type Glyph string

// Recognized glyph values.
const (
	GlyphNote     = Glyph("note")
	GlyphTodo     = Glyph("todo")
	GlyphEvent    = Glyph("event")
	GlyphReminder = Glyph("reminder")
)
|
|
|
|
// CardType is the value stored in an entity's card_type column. An entity
// whose CardType is nil has not been promoted; ValidCardType reports whether
// an arbitrary string is one of the values declared below.
type CardType string

// Recognized card types.
const (
	CardSnippet   = CardType("snippet")
	CardTemplate  = CardType("template")
	CardChecklist = CardType("checklist")
	CardDecision  = CardType("decision")
	CardLink      = CardType("link")
	CardNote      = CardType("note")
)
|
|
|
|
func ValidGlyph(s string) bool {
|
|
switch Glyph(s) {
|
|
case GlyphNote, GlyphTodo, GlyphEvent, GlyphReminder:
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func ValidCardType(s string) bool {
|
|
switch CardType(s) {
|
|
case CardSnippet, CardTemplate, CardChecklist, CardDecision, CardLink, CardNote:
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
type Entity struct {
|
|
ID string
|
|
CreatedAt time.Time
|
|
ModifiedAt time.Time
|
|
Body string
|
|
Title *string
|
|
Description *string
|
|
Glyph Glyph
|
|
TimeAnchor *string
|
|
CompletedAt *time.Time
|
|
Pinned bool
|
|
DeletedAt *time.Time
|
|
CardType *CardType
|
|
CardData *string
|
|
UseCount int
|
|
LastUsedAt *time.Time
|
|
Tags []string
|
|
}
|
|
|
|
type ListParams struct {
|
|
Tag *string
|
|
Date *string
|
|
From *string
|
|
To *string
|
|
Since *time.Time
|
|
CardsOnly bool
|
|
IncludeDeleted bool
|
|
CardTypeFilter *CardType
|
|
Sort string
|
|
Order string
|
|
Limit int
|
|
Offset int
|
|
}
|
|
|
|
func DefaultListParams() ListParams {
|
|
return ListParams{
|
|
Sort: "created",
|
|
Order: "desc",
|
|
Limit: 50,
|
|
}
|
|
}
|
|
|
|
type EntityUpdate struct {
|
|
Body *string
|
|
Title *string
|
|
Description *string
|
|
Glyph *Glyph
|
|
TimeAnchor *string
|
|
ClearTime bool
|
|
CompletedAt *time.Time
|
|
ClearCompleted bool
|
|
Pinned *bool
|
|
CardType *CardType
|
|
CardData *string
|
|
Tags *[]string
|
|
}
|
|
|
|
func (s *Store) Create(e *Entity) error {
|
|
now := time.Now().UTC()
|
|
e.ID = nibulid.New()
|
|
e.CreatedAt = now
|
|
e.ModifiedAt = now
|
|
if e.Tags == nil {
|
|
e.Tags = []string{}
|
|
}
|
|
|
|
tx, err := s.db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
_, err = tx.Exec(`
|
|
INSERT INTO entities (id, created_at, modified_at, body, title, description,
|
|
glyph, time_anchor, completed_at, pinned, deleted_at,
|
|
card_type, card_data, use_count, last_used_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
e.ID,
|
|
e.CreatedAt.Format(time.RFC3339),
|
|
e.ModifiedAt.Format(time.RFC3339),
|
|
e.Body,
|
|
e.Title,
|
|
e.Description,
|
|
string(e.Glyph),
|
|
e.TimeAnchor,
|
|
formatTimePtr(e.CompletedAt),
|
|
boolToInt(e.Pinned),
|
|
formatTimePtr(e.DeletedAt),
|
|
cardTypePtr(e.CardType),
|
|
e.CardData,
|
|
e.UseCount,
|
|
formatTimePtr(e.LastUsedAt),
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := insertTags(tx, e.ID, e.Tags); err != nil {
|
|
return err
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
func (s *Store) Get(id string) (*Entity, error) {
|
|
e := &Entity{}
|
|
row := newEntityRow()
|
|
|
|
err := s.db.QueryRow(`
|
|
SELECT id, created_at, modified_at, body, title, description,
|
|
glyph, time_anchor, completed_at, pinned, deleted_at,
|
|
card_type, card_data, use_count, last_used_at
|
|
FROM entities WHERE id = ?`, id).Scan(row.ptrs(e)...)
|
|
if err == sql.ErrNoRows {
|
|
return nil, ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := row.apply(e); err != nil {
|
|
return nil, fmt.Errorf("scan entity %s: %w", id, err)
|
|
}
|
|
|
|
tags, err := s.loadTags(id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
e.Tags = tags
|
|
|
|
return e, nil
|
|
}
|
|
|
|
func (s *Store) List(params ListParams) ([]*Entity, error) {
|
|
var where []string
|
|
var args []any
|
|
|
|
if !params.IncludeDeleted {
|
|
where = append(where, "e.deleted_at IS NULL")
|
|
}
|
|
if params.Tag != nil {
|
|
where = append(where, "e.id IN (SELECT entity_id FROM entity_tags WHERE tag = ?)")
|
|
args = append(args, *params.Tag)
|
|
}
|
|
if params.Date != nil {
|
|
where = append(where, "date(e.created_at) = ?")
|
|
args = append(args, *params.Date)
|
|
}
|
|
if params.From != nil {
|
|
where = append(where, "date(e.created_at) >= ?")
|
|
args = append(args, *params.From)
|
|
}
|
|
if params.To != nil {
|
|
where = append(where, "date(e.created_at) <= ?")
|
|
args = append(args, *params.To)
|
|
}
|
|
if params.Since != nil {
|
|
where = append(where, "e.created_at >= ?")
|
|
args = append(args, params.Since.Format(time.RFC3339))
|
|
}
|
|
if params.CardsOnly {
|
|
where = append(where, "e.card_type IS NOT NULL")
|
|
}
|
|
if params.CardTypeFilter != nil {
|
|
where = append(where, "e.card_type = ?")
|
|
args = append(args, string(*params.CardTypeFilter))
|
|
}
|
|
|
|
whereClause := ""
|
|
if len(where) > 0 {
|
|
whereClause = "WHERE " + strings.Join(where, " AND ")
|
|
}
|
|
|
|
orderCol := "e.created_at"
|
|
switch params.Sort {
|
|
case "use_count":
|
|
orderCol = "e.use_count"
|
|
case "created_at", "":
|
|
orderCol = "e.created_at"
|
|
default:
|
|
orderCol = "e.created_at"
|
|
}
|
|
orderDir := "DESC"
|
|
switch strings.ToLower(params.Order) {
|
|
case "asc":
|
|
orderDir = "ASC"
|
|
default:
|
|
orderDir = "DESC"
|
|
}
|
|
|
|
limit := params.Limit
|
|
if limit <= 0 {
|
|
limit = 50
|
|
}
|
|
|
|
query := fmt.Sprintf(`
|
|
SELECT e.id, e.created_at, e.modified_at, e.body, e.title, e.description,
|
|
e.glyph, e.time_anchor, e.completed_at, e.pinned, e.deleted_at,
|
|
e.card_type, e.card_data, e.use_count, e.last_used_at
|
|
FROM entities e
|
|
%s
|
|
ORDER BY %s %s
|
|
LIMIT ? OFFSET ?`, whereClause, orderCol, orderDir)
|
|
|
|
args = append(args, limit, params.Offset)
|
|
|
|
rows, err := s.db.Query(query, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var entities []*Entity
|
|
for rows.Next() {
|
|
e := &Entity{}
|
|
row := newEntityRow()
|
|
if err := rows.Scan(row.ptrs(e)...); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := row.apply(e); err != nil {
|
|
return nil, err
|
|
}
|
|
entities = append(entities, e)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := s.batchLoadTags(entities); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return entities, nil
|
|
}
|
|
|
|
func (s *Store) Update(id string, u *EntityUpdate) error {
|
|
existing, err := s.Get(id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
tx, err := s.db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
var sets []string
|
|
var args []any
|
|
|
|
sets = append(sets, "modified_at = ?")
|
|
args = append(args, time.Now().UTC().Format(time.RFC3339))
|
|
|
|
if u.Body != nil {
|
|
sets = append(sets, "body = ?")
|
|
args = append(args, *u.Body)
|
|
}
|
|
if u.Title != nil {
|
|
sets = append(sets, "title = ?")
|
|
args = append(args, *u.Title)
|
|
}
|
|
if u.Description != nil {
|
|
sets = append(sets, "description = ?")
|
|
args = append(args, *u.Description)
|
|
}
|
|
if u.Glyph != nil {
|
|
sets = append(sets, "glyph = ?")
|
|
args = append(args, string(*u.Glyph))
|
|
}
|
|
if u.ClearTime {
|
|
sets = append(sets, "time_anchor = NULL")
|
|
} else if u.TimeAnchor != nil {
|
|
sets = append(sets, "time_anchor = ?")
|
|
args = append(args, *u.TimeAnchor)
|
|
}
|
|
if u.ClearCompleted {
|
|
sets = append(sets, "completed_at = NULL")
|
|
} else if u.CompletedAt != nil {
|
|
sets = append(sets, "completed_at = ?")
|
|
args = append(args, u.CompletedAt.Format(time.RFC3339))
|
|
}
|
|
if u.Pinned != nil {
|
|
sets = append(sets, "pinned = ?")
|
|
args = append(args, boolToInt(*u.Pinned))
|
|
}
|
|
if u.CardType != nil {
|
|
sets = append(sets, "card_type = ?")
|
|
args = append(args, string(*u.CardType))
|
|
}
|
|
if u.CardData != nil {
|
|
sets = append(sets, "card_data = ?")
|
|
args = append(args, *u.CardData)
|
|
}
|
|
|
|
args = append(args, existing.ID)
|
|
query := fmt.Sprintf("UPDATE entities SET %s WHERE id = ?", strings.Join(sets, ", "))
|
|
|
|
if _, err := tx.Exec(query, args...); err != nil {
|
|
return err
|
|
}
|
|
|
|
if u.Tags != nil {
|
|
if _, err := tx.Exec("DELETE FROM entity_tags WHERE entity_id = ?", existing.ID); err != nil {
|
|
return err
|
|
}
|
|
if err := insertTags(tx, existing.ID, *u.Tags); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
func (s *Store) Promote(id string, cardType CardType, cardData *string) error {
|
|
e, err := s.Get(id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if e.CardType != nil {
|
|
return ErrAlreadyPromoted
|
|
}
|
|
|
|
dataVal := "{}"
|
|
if cardData != nil {
|
|
dataVal = *cardData
|
|
}
|
|
|
|
_, err = s.db.Exec(`
|
|
UPDATE entities SET card_type = ?, card_data = ?, modified_at = ?
|
|
WHERE id = ?`,
|
|
string(cardType), dataVal, time.Now().UTC().Format(time.RFC3339), id)
|
|
return err
|
|
}
|
|
|
|
func (s *Store) Demote(id string) error {
|
|
e, err := s.Get(id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if e.CardType == nil {
|
|
return ErrAlreadyFluid
|
|
}
|
|
|
|
_, err = s.db.Exec(`
|
|
UPDATE entities SET card_type = NULL, card_data = NULL,
|
|
use_count = 0, last_used_at = NULL, modified_at = ?
|
|
WHERE id = ?`,
|
|
time.Now().UTC().Format(time.RFC3339), id)
|
|
return err
|
|
}
|
|
|
|
// DeleteResult reports which stage of deletion SoftDelete performed.
type DeleteResult int

const (
	// DeletedSoft means the entity was marked deleted (deleted_at set).
	DeletedSoft DeleteResult = 0
	// DeletedHard means the entity was already soft-deleted and its row was removed.
	DeletedHard DeleteResult = 1
)
|
|
|
|
func (s *Store) SoftDelete(id string) (DeleteResult, error) {
|
|
var deletedAt sql.NullString
|
|
err := s.db.QueryRow("SELECT deleted_at FROM entities WHERE id = ?", id).Scan(&deletedAt)
|
|
if err == sql.ErrNoRows {
|
|
return 0, ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
if deletedAt.Valid {
|
|
_, err = s.db.Exec("DELETE FROM entities WHERE id = ?", id)
|
|
return DeletedHard, err
|
|
}
|
|
|
|
_, err = s.db.Exec("UPDATE entities SET deleted_at = ? WHERE id = ?",
|
|
time.Now().UTC().Format(time.RFC3339), id)
|
|
return DeletedSoft, err
|
|
}
|
|
|
|
func (s *Store) Absorb(targetID, sourceID string) error {
|
|
target, err := s.Get(targetID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
source, err := s.Get(sourceID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if target.CardType != nil {
|
|
return ErrTargetCrystallized
|
|
}
|
|
|
|
tx, err := s.db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
now := time.Now().UTC().Format(time.RFC3339)
|
|
merged := target.Body + "\n" + source.Body
|
|
|
|
if _, err := tx.Exec("UPDATE entities SET body = ?, modified_at = ? WHERE id = ?",
|
|
merged, now, targetID); err != nil {
|
|
return err
|
|
}
|
|
|
|
seen := map[string]bool{}
|
|
for _, t := range target.Tags {
|
|
seen[t] = true
|
|
}
|
|
for _, t := range source.Tags {
|
|
if !seen[t] {
|
|
if _, err := tx.Exec("INSERT OR IGNORE INTO entity_tags (entity_id, tag) VALUES (?, ?)",
|
|
targetID, t); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
if source.CardType != nil {
|
|
if _, err := tx.Exec(`UPDATE entities SET card_type = NULL, card_data = NULL,
|
|
use_count = 0, last_used_at = NULL, modified_at = ? WHERE id = ?`,
|
|
now, sourceID); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if _, err := tx.Exec("UPDATE entities SET deleted_at = ? WHERE id = ?",
|
|
now, sourceID); err != nil {
|
|
return err
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
func (s *Store) IncrementUse(id string) error {
|
|
res, err := s.db.Exec(`
|
|
UPDATE entities SET use_count = use_count + 1, last_used_at = ?
|
|
WHERE id = ?`,
|
|
time.Now().UTC().Format(time.RFC3339), id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
n, _ := res.RowsAffected()
|
|
if n == 0 {
|
|
return ErrNotFound
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Store) Resolve(prefix string) (string, error) {
|
|
rows, err := s.db.Query("SELECT id FROM entities WHERE id LIKE ?", prefix+"%")
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var ids []string
|
|
for rows.Next() {
|
|
var id string
|
|
if err := rows.Scan(&id); err != nil {
|
|
return "", err
|
|
}
|
|
ids = append(ids, id)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
switch len(ids) {
|
|
case 0:
|
|
return "", ErrNotFound
|
|
case 1:
|
|
return ids[0], nil
|
|
default:
|
|
return "", fmt.Errorf("ambiguous id prefix %q matches %d entities", prefix, len(ids))
|
|
}
|
|
}
|
|
|
|
// entityRow holds raw column values for the entity columns that cannot be
// scanned straight into Entity fields. Timestamps are RFC3339 text, nullable
// columns are sql.NullString, and pinned is an integer flag. apply converts
// them into the typed Entity fields.
type entityRow struct {
	createdAt   string
	modifiedAt  string
	completedAt sql.NullString
	deletedAt   sql.NullString
	lastUsedAt  sql.NullString
	timeAnchor  sql.NullString
	cardType    sql.NullString
	cardData    sql.NullString
	title       sql.NullString
	description sql.NullString
	pinned      int
}
|
|
|
|
// newEntityRow returns a fresh, zeroed scan buffer for one entity row.
func newEntityRow() *entityRow { return new(entityRow) }
|
|
|
|
func (r *entityRow) ptrs(e *Entity) []any {
|
|
return []any{
|
|
&e.ID, &r.createdAt, &r.modifiedAt, &e.Body, &r.title, &r.description,
|
|
&e.Glyph, &r.timeAnchor, &r.completedAt, &r.pinned, &r.deletedAt,
|
|
&r.cardType, &r.cardData, &e.UseCount, &r.lastUsedAt,
|
|
}
|
|
}
|
|
|
|
func (r *entityRow) apply(e *Entity) error {
|
|
var err error
|
|
if e.CreatedAt, err = time.Parse(time.RFC3339, r.createdAt); err != nil {
|
|
return fmt.Errorf("created_at: %w", err)
|
|
}
|
|
if e.ModifiedAt, err = time.Parse(time.RFC3339, r.modifiedAt); err != nil {
|
|
return fmt.Errorf("modified_at: %w", err)
|
|
}
|
|
e.Title = nullToPtr(r.title)
|
|
e.Description = nullToPtr(r.description)
|
|
e.TimeAnchor = nullToPtr(r.timeAnchor)
|
|
e.CompletedAt = parseTimePtr(r.completedAt)
|
|
e.Pinned = r.pinned != 0
|
|
e.DeletedAt = parseTimePtr(r.deletedAt)
|
|
e.CardType = nullToCardType(r.cardType)
|
|
e.CardData = nullToPtr(r.cardData)
|
|
e.LastUsedAt = parseTimePtr(r.lastUsedAt)
|
|
return nil
|
|
}
|
|
|
|
// helpers
|
|
|
|
func (s *Store) batchLoadTags(entities []*Entity) error {
|
|
if len(entities) == 0 {
|
|
return nil
|
|
}
|
|
|
|
idMap := make(map[string]*Entity, len(entities))
|
|
placeholders := make([]string, len(entities))
|
|
args := make([]any, len(entities))
|
|
for i, e := range entities {
|
|
e.Tags = []string{}
|
|
idMap[e.ID] = e
|
|
placeholders[i] = "?"
|
|
args[i] = e.ID
|
|
}
|
|
|
|
query := fmt.Sprintf(
|
|
"SELECT entity_id, tag FROM entity_tags WHERE entity_id IN (%s) ORDER BY entity_id, tag",
|
|
strings.Join(placeholders, ","),
|
|
)
|
|
|
|
rows, err := s.db.Query(query, args...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
var entityID, tag string
|
|
if err := rows.Scan(&entityID, &tag); err != nil {
|
|
return err
|
|
}
|
|
if e, ok := idMap[entityID]; ok {
|
|
e.Tags = append(e.Tags, tag)
|
|
}
|
|
}
|
|
return rows.Err()
|
|
}
|
|
|
|
func (s *Store) loadTags(entityID string) ([]string, error) {
|
|
rows, err := s.db.Query("SELECT tag FROM entity_tags WHERE entity_id = ? ORDER BY tag", entityID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var tags []string
|
|
for rows.Next() {
|
|
var tag string
|
|
if err := rows.Scan(&tag); err != nil {
|
|
return nil, err
|
|
}
|
|
tags = append(tags, tag)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
if tags == nil {
|
|
tags = []string{}
|
|
}
|
|
return tags, nil
|
|
}
|
|
|
|
// insertTags attaches each tag to entityID within tx. It uses INSERT OR IGNORE,
// so rows rejected by a uniqueness constraint are skipped silently. It stops
// at the first error. An empty tags slice performs no statements.
func insertTags(tx *sql.Tx, entityID string, tags []string) error {
	const stmt = "INSERT OR IGNORE INTO entity_tags (entity_id, tag) VALUES (?, ?)"
	for i := range tags {
		if _, err := tx.Exec(stmt, entityID, tags[i]); err != nil {
			return err
		}
	}
	return nil
}
|
|
|
|
// formatTimePtr renders t as RFC3339 text for storage. A nil t becomes an
// untyped nil, which the driver stores as NULL.
func formatTimePtr(t *time.Time) interface{} {
	if t != nil {
		return t.Format(time.RFC3339)
	}
	return nil
}
|
|
|
|
// parseTimePtr decodes a nullable RFC3339 column. Both NULL and text that
// fails to parse yield nil; parse errors are deliberately not reported.
func parseTimePtr(ns sql.NullString) *time.Time {
	if ns.Valid {
		if parsed, err := time.Parse(time.RFC3339, ns.String); err == nil {
			return &parsed
		}
	}
	return nil
}
|
|
|
|
// nullToPtr maps a nullable text column to *string: NULL becomes nil, and any
// other value becomes a pointer to a fresh copy of the text.
func nullToPtr(ns sql.NullString) *string {
	if ns.Valid {
		text := ns.String
		return &text
	}
	return nil
}
|
|
|
|
func nullToCardType(ns sql.NullString) *CardType {
|
|
if !ns.Valid {
|
|
return nil
|
|
}
|
|
ct := CardType(ns.String)
|
|
return &ct
|
|
}
|
|
|
|
func cardTypePtr(ct *CardType) interface{} {
|
|
if ct == nil {
|
|
return nil
|
|
}
|
|
return string(*ct)
|
|
}
|
|
|
|
// boolToInt encodes b as the 0/1 integer used for boolean columns such as pinned.
func boolToInt(b bool) int {
	n := 0
	if b {
		n = 1
	}
	return n
}
|
|
|
|
func (e *Entity) CardDataJSON() (map[string]interface{}, error) {
|
|
if e.CardData == nil {
|
|
return nil, nil
|
|
}
|
|
var m map[string]interface{}
|
|
if err := json.Unmarshal([]byte(*e.CardData), &m); err != nil {
|
|
return nil, fmt.Errorf("card_data: %w", err)
|
|
}
|
|
return m, nil
|
|
}
|