aboutsummaryrefslogtreecommitdiff
path: root/build_index.py
diff options
context:
space:
mode:
authortwells46 <173561638+twells46@users.noreply.github.com>2026-04-01 15:20:50 -0500
committertwells46 <173561638+twells46@users.noreply.github.com>2026-04-01 15:20:50 -0500
commit2f37974a4c84f7ffdd07e2c223eba2d8bd981b61 (patch)
tree1741f17884077e9d4e0dbfe5908305fc21661ced /build_index.py
Initial commit
Diffstat (limited to 'build_index.py')
-rw-r--r--build_index.py192
1 file changed, 192 insertions, 0 deletions
diff --git a/build_index.py b/build_index.py
new file mode 100644
index 0000000..8597c68
--- /dev/null
+++ b/build_index.py
@@ -0,0 +1,192 @@
#!/usr/bin/env python3
from __future__ import annotations

import json
import os
import sqlite3
from pathlib import Path
from typing import Any

import numpy as np

# Paths and model configuration for the card index builder.
DB_PATH = Path("cheat.db")
CARDS_PATH = Path("./cards.jsonl")
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
LOCAL_CACHE_DIR = Path("models/hf")

# Point the Hugging Face caches at the local directory BEFORE importing
# sentence_transformers: huggingface_hub reads HF_HOME into module-level
# constants at import time, so setting it after the import has no effect.
os.environ.setdefault("HF_HOME", str(LOCAL_CACHE_DIR.resolve()))
os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", str(LOCAL_CACHE_DIR.resolve()))

from sentence_transformers import SentenceTransformer  # noqa: E402 (must follow env setup)
+
# Every card record in cards.jsonl must carry these keys.
REQUIRED_FIELDS = {
    "id",
    "intent",
    "command",
    "alternatives",
    "explanation",
    "requires",
    "packages",
    "tags",
    "platform",
    "shell",
    "safety",
}


def load_cards(path: Path) -> list[dict[str, Any]]:
    """Parse a JSONL file of cards, skipping blank lines.

    Raises ValueError (chained to the original error where applicable) if a
    line is not valid JSON or a card is missing any of REQUIRED_FIELDS.
    """
    loaded: list[dict[str, Any]] = []
    with path.open("r", encoding="utf-8") as handle:
        for line_no, raw in enumerate(handle, start=1):
            text = raw.strip()
            if not text:
                continue
            try:
                record = json.loads(text)
            except json.JSONDecodeError as e:
                raise ValueError(f"Invalid JSON on line {line_no}: {e}") from e

            missing = REQUIRED_FIELDS - record.keys()
            if missing:
                raise ValueError(f"Missing required fields on line {line_no}: {sorted(missing)}")

            loaded.append(record)
    return loaded
+
+
def build_search_text(card: dict[str, Any]) -> str:
    """
    Build a compact semantic representation for embedding.
    This is what the retriever will search over.
    """
    # (label, card field, list separator) in output order; a None separator
    # marks a plain-string field that is emitted as-is.
    layout: tuple[tuple[str, str, str | None], ...] = (
        ("Intents", "intent", " | "),
        ("Tags", "tags", ", "),
        ("Command", "command", None),
        ("Alternatives", "alternatives", " | "),
        ("Explanation", "explanation", None),
        ("Requires", "requires", ", "),
        ("Platform", "platform", ", "),
    )

    sections: list[str] = []
    for label, field, sep in layout:
        value = card.get(field, "" if sep is None else [])
        if not value:
            continue  # empty strings/lists are omitted entirely
        rendered = value if sep is None else sep.join(value)
        sections.append(f"{label}: {rendered}")

    return "\n".join(sections)
+
+
def serialize_embedding(vec: np.ndarray) -> bytes:
    """Serialize *vec* as a raw buffer of native-endian float32 values."""
    as_f32 = vec.astype(np.float32)
    return as_f32.tobytes()
+
+
def upsert_card(conn: sqlite3.Connection, card: dict[str, Any], search_text: str) -> None:
    """Insert *card* into the cards table, or refresh every column if the id exists.

    List-valued fields are stored as JSON text columns; updated_at is stamped
    by SQLite on both insert and update.
    """
    sql = """
        INSERT INTO cards (
            id, command, explanation, intent_json, alternatives_json, requires_json,
            packages_json, tags_json, platform_json, shell_json, safety, search_text, updated_at
        )
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
        ON CONFLICT(id) DO UPDATE SET
            command=excluded.command,
            explanation=excluded.explanation,
            intent_json=excluded.intent_json,
            alternatives_json=excluded.alternatives_json,
            requires_json=excluded.requires_json,
            packages_json=excluded.packages_json,
            tags_json=excluded.tags_json,
            platform_json=excluded.platform_json,
            shell_json=excluded.shell_json,
            safety=excluded.safety,
            search_text=excluded.search_text,
            updated_at=CURRENT_TIMESTAMP
    """

    def dump(key: str) -> str:
        # JSON-encode a list field without escaping non-ASCII text.
        return json.dumps(card[key], ensure_ascii=False)

    params = (
        card["id"],
        card["command"],
        card["explanation"],
        dump("intent"),
        dump("alternatives"),
        dump("requires"),
        dump("packages"),
        dump("tags"),
        dump("platform"),
        dump("shell"),
        card["safety"],
        search_text,
    )
    conn.execute(sql, params)
+
+
def upsert_embedding(
    conn: sqlite3.Connection,
    card_id: str,
    model_name: str,
    vec: np.ndarray,
) -> None:
    """Store the embedding for *card_id*, replacing any existing row."""
    sql = """
        INSERT INTO card_embeddings (
            card_id, model_name, embedding_blob, embedding_dim
        )
        VALUES (?, ?, ?, ?)
        ON CONFLICT(card_id) DO UPDATE SET
            model_name=excluded.model_name,
            embedding_blob=excluded.embedding_blob,
            embedding_dim=excluded.embedding_dim
    """
    # Same packing as serialize_embedding: raw native-endian float32 bytes.
    blob = vec.astype(np.float32).tobytes()
    conn.execute(sql, (card_id, model_name, blob, int(vec.shape[0])))
+
+
def main() -> None:
    """Embed every card and upsert cards + embeddings into the SQLite DB.

    Preconditions: the database schema already exists (scripts/init_db.py)
    and the embedding model is already present in the local HF cache.
    """
    if not DB_PATH.exists():
        raise FileNotFoundError(
            f"Database not found at {DB_PATH}. Run scripts/init_db.py first."
        )
    if not CARDS_PATH.exists():
        raise FileNotFoundError(f"Cards file not found at {CARDS_PATH}")

    all_cards = load_cards(CARDS_PATH)

    # local_files_only=True: never touch the network; fail fast if the model
    # is not already cached locally.
    encoder = SentenceTransformer(
        MODEL_NAME,
        cache_folder=str(LOCAL_CACHE_DIR.resolve()),
        local_files_only=True,
    )

    texts = [build_search_text(c) for c in all_cards]
    vectors = encoder.encode(
        texts,
        normalize_embeddings=True,
        convert_to_numpy=True,
        show_progress_bar=True,
    )

    conn = sqlite3.connect(DB_PATH)
    try:
        conn.execute("PRAGMA foreign_keys=ON;")
        # One transaction for the whole batch; commit only after every
        # card and embedding has been written.
        for card, text, vector in zip(all_cards, texts, vectors):
            upsert_card(conn, card, text)
            upsert_embedding(conn, card["id"], MODEL_NAME, vector)
        conn.commit()
        print(f"Indexed {len(all_cards)} cards into {DB_PATH}")
    finally:
        conn.close()


if __name__ == "__main__":
    main()