From 675a18d94019f9e131c647d24b5176ddfaae4516 Mon Sep 17 00:00:00 2001 From: Tobias Huttinger Date: Mon, 4 May 2026 21:53:09 +0200 Subject: [PATCH] Refactor as well as better command handling, including feature to hide system messages from llm --- src/lorabot/bot.py | 118 ++++---------------------------------- src/lorabot/commands.py | 121 +++++++++++++++++++++++++++++++++++++++ src/lorabot/db.py | 72 +++++++++++++++++------ src/lorabot/handler.py | 114 ++++++++++++++++++++++++++++++++++++ src/lorabot/llm.py | 10 +++- src/lorabot/transport.py | 55 ++++++++++++++++++ 6 files changed, 366 insertions(+), 124 deletions(-) create mode 100644 src/lorabot/commands.py create mode 100644 src/lorabot/handler.py create mode 100644 src/lorabot/transport.py diff --git a/src/lorabot/bot.py b/src/lorabot/bot.py index 2b4bd22..fdcaf3c 100644 --- a/src/lorabot/bot.py +++ b/src/lorabot/bot.py @@ -1,26 +1,21 @@ -"""Main run loop: connect to the MeshCore device, route DMs through the LLM, reply.""" +"""Main run loop: connect to the MeshCore device, wire collaborators, listen for DMs.""" from __future__ import annotations import asyncio import logging -from collections import defaultdict -from datetime import datetime, timezone from meshcore import EventType, MeshCore from . import db, web +from .commands import build_default_registry from .config import Settings +from .handler import build_dm_handler from .llm import LLMClient -from .messages import split_to_bytes, trim_to_bytes log = logging.getLogger("lorabot") -def _now_iso() -> str: - return datetime.now(timezone.utc).isoformat() - - async def run() -> None: cfg = Settings() @@ -83,104 +78,15 @@ async def run() -> None: _advert_loop(state, cfg.advertise.interval_seconds, cfg.advertise.at_startup) ) - # One lock per sender so a burst of messages from the same peer is processed - # serially while different peers stay independent. - locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) - - contacts_lock = asyncio.Lock() - - async def _resolve_contact(prefix: str): - """Look up a contact by pubkey prefix. Re-pulls contacts from the device on miss.""" - contact = mc.get_contact_by_key_prefix(prefix) - if contact is not None: - return contact - async with contacts_lock: - contact = mc.get_contact_by_key_prefix(prefix) - if contact is not None: - return contact - try: - await mc.commands.get_contacts() - except Exception: - log.exception("get_contacts refresh failed") - return None - return mc.get_contact_by_key_prefix(prefix) - - async def on_dm(event) -> None: - data = event.payload or {} - prefix = data.get("pubkey_prefix") - text = (data.get("text") or "").strip() - if not prefix or not text: - return - - contact = await _resolve_contact(prefix) - if contact is None: - log.info("ignoring DM from unknown sender %s (no contact after refresh)", prefix) - return - - public_key = contact["public_key"] - contact_name = contact.get("adv_name", "") - log.info("DM from %s (%s): %s", contact_name, public_key[:12], text) - - async with locks[public_key]: - db.upsert_conversation(db_conn, public_key, contact_name) - db.add_message(db_conn, public_key, "user", text) - state.publish("message", { - "public_key": public_key, - "contact_name": contact_name, - "role": "user", - "content": text, - "created_at": _now_iso(), - }) - - if text.strip().lower() == "/clear": - reply = "history cleared." - db.add_message(db_conn, public_key, "assistant", reply) - db.clear_history(db_conn, public_key) - state.publish("message", { - "public_key": public_key, - "contact_name": contact_name, - "role": "assistant", - "content": reply, - "created_at": _now_iso(), - }) - outgoing = trim_to_bytes(reply, cfg.message.max_bytes) - log.info("/clear from %s — context reset", public_key[:12]) - result = await mc.commands.send_msg(contact, outgoing) - if result.type == EventType.ERROR: - log.error("send_msg failed for %s: %s", public_key[:12], result.payload) - return - - history = db.get_history(db_conn, public_key) - - try: - reply = await llm.reply(history) - except Exception: - log.exception("LLM call failed for %s", public_key[:12]) - return - - chunks = split_to_bytes(reply, cfg.message.max_bytes, max_chunks=2) - delivered = "".join(chunks) - db.add_message(db_conn, public_key, "assistant", delivered) - state.publish("message", { - "public_key": public_key, - "contact_name": contact_name, - "role": "assistant", - "content": delivered, - "created_at": _now_iso(), - }) - dropped = len(reply.encode("utf-8")) - len(delivered.encode("utf-8")) - if dropped > 0: - log.info("reply to %s split into %d chunks, dropped %d trailing bytes", - public_key[:12], len(chunks), dropped) - - for i, chunk in enumerate(chunks, 1): - log.info("reply to %s (%d/%d, %d bytes): %s", - public_key[:12], i, len(chunks), len(chunk.encode("utf-8")), chunk) - result = await mc.commands.send_msg(contact, chunk) - if result.type == EventType.ERROR: - log.error("send_msg failed for %s chunk %d/%d: %s", - public_key[:12], i, len(chunks), result.payload) - break + registry = build_default_registry() + on_dm = build_dm_handler( + mc=mc, + db_conn=db_conn, + llm=llm, + registry=registry, + state=state, + cfg=cfg, + ) sub = mc.subscribe(EventType.CONTACT_MSG_RECV, on_dm) await mc.start_auto_message_fetching() diff --git a/src/lorabot/commands.py b/src/lorabot/commands.py new file mode 100644 index 0000000..50adfe1 --- /dev/null +++ b/src/lorabot/commands.py @@ -0,0 +1,121 @@ +"""Slash-command parser and registry for incoming DMs. + +A command is any DM whose text starts with ``/``. The first whitespace-separated +token (case-insensitive) selects the handler; the rest is passed through as the +raw argument string. Handlers return the reply text; ``None`` means "no reply". +""" + +from __future__ import annotations + +import sqlite3 +from collections.abc import Awaitable, Callable +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from . import db + +if TYPE_CHECKING: + from .config import Settings + from .web import AppState + + +@dataclass +class CommandContext: + """Everything a command handler might need to read state or react.""" + + db_conn: sqlite3.Connection + public_key: str + contact_name: str + cfg: Settings + state: AppState + + +@dataclass +class CommandResult: + """What a command returns. ``after_send`` runs once the assistant turn is persisted.""" + + reply: str | None = None + after_send: Callable[[CommandContext], Awaitable[None]] | None = field(default=None, repr=False) + + +CommandHandler = Callable[[CommandContext, str], Awaitable[CommandResult | str | None]] + + +@dataclass +class Command: + name: str # without leading slash, lowercase + description: str + handler: CommandHandler + + +class CommandRegistry: + def __init__(self) -> None: + self._commands: dict[str, Command] = {} + + def register(self, name: str, description: str) -> Callable[[CommandHandler], CommandHandler]: + def decorator(fn: CommandHandler) -> CommandHandler: + self._commands[name.lower()] = Command(name.lower(), description, fn) + return fn + return decorator + + def list(self) -> list[Command]: + return sorted(self._commands.values(), key=lambda c: c.name) + + @staticmethod + def parse(text: str) -> tuple[str, str] | None: + """Return ``(name, args)`` if ``text`` is a slash command, else ``None``.""" + stripped = text.strip() + if not stripped.startswith("/"): + return None + head, _, rest = stripped[1:].partition(" ") + if not head: + return None + return head.lower(), rest.strip() + + async def dispatch(self, ctx: CommandContext, text: str) -> CommandResult | None: + """If ``text`` is a known command, run it. ``None`` means "not a command".""" + parsed = self.parse(text) + if parsed is None: + return None + name, args = parsed + cmd = self._commands.get(name) + if cmd is None: + return CommandResult(reply=f"unknown command: /{name}") + out = await cmd.handler(ctx, args) + if isinstance(out, CommandResult): + return out + return CommandResult(reply=out) + + +def build_default_registry() -> CommandRegistry: + """Registry with the built-in commands wired up.""" + reg = CommandRegistry() + + async def _clear_after_send(ctx: CommandContext) -> None: + # Bump the watermark *after* the "history cleared." reply is persisted so + # neither side of this exchange leaks into the next LLM context. + db.clear_history(ctx.db_conn, ctx.public_key) + + @reg.register("clear", "reset LLM context for this conversation") + async def _clear(_ctx: CommandContext, _args: str) -> CommandResult: + return CommandResult(reply="history cleared.", after_send=_clear_after_send) + + @reg.register("thinking", "show or set thinking mode: /thinking [on|off]") + async def _thinking(ctx: CommandContext, args: str) -> str: + arg = args.strip().lower() + if not arg: + current = db.get_thinking_enabled(ctx.db_conn, ctx.public_key) + return f"thinking is {'on' if current else 'off'}" + if arg in ("on", "1", "true", "yes"): + db.set_thinking_enabled(ctx.db_conn, ctx.public_key, True) + return "thinking on" + if arg in ("off", "0", "false", "no"): + db.set_thinking_enabled(ctx.db_conn, ctx.public_key, False) + return "thinking off" + return "usage: /thinking [on|off]" + + @reg.register("help", "list available commands") + async def _help(_ctx: CommandContext, _args: str) -> str: + return "\n".join(f"/{c.name} — {c.description}" for c in reg.list()) + + return reg diff --git a/src/lorabot/db.py b/src/lorabot/db.py index d3a23b2..45c32ce 100644 --- a/src/lorabot/db.py +++ b/src/lorabot/db.py @@ -7,19 +7,21 @@ from pathlib import Path SCHEMA = """ CREATE TABLE IF NOT EXISTS conversations ( - public_key TEXT PRIMARY KEY, - contact_name TEXT, - cleared_at_id INTEGER NOT NULL DEFAULT 0, - created_at TEXT NOT NULL DEFAULT (datetime('now')), - updated_at TEXT NOT NULL DEFAULT (datetime('now')) + public_key TEXT PRIMARY KEY, + contact_name TEXT, + cleared_at_id INTEGER NOT NULL DEFAULT 0, + thinking_enabled INTEGER NOT NULL DEFAULT 0, + created_at TEXT NOT NULL DEFAULT (datetime('now')), + updated_at TEXT NOT NULL DEFAULT (datetime('now')) ); CREATE TABLE IF NOT EXISTS messages ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - public_key TEXT NOT NULL REFERENCES conversations(public_key), - role TEXT NOT NULL CHECK (role IN ('user', 'assistant')), - content TEXT NOT NULL, - created_at TEXT NOT NULL DEFAULT (datetime('now')) + id INTEGER PRIMARY KEY AUTOINCREMENT, + public_key TEXT NOT NULL REFERENCES conversations(public_key), + role TEXT NOT NULL CHECK (role IN ('user', 'assistant')), + content TEXT NOT NULL, + hidden_from_llm INTEGER NOT NULL DEFAULT 0, + created_at TEXT NOT NULL DEFAULT (datetime('now')) ); CREATE INDEX IF NOT EXISTS idx_messages_pubkey_id @@ -42,9 +44,19 @@ def connect(path: str | Path) -> sqlite3.Connection: def _migrate(conn: sqlite3.Connection) -> None: """Apply additive migrations for DBs created before later columns existed.""" - cols = {row["name"] for row in conn.execute("PRAGMA table_info(conversations)")} - if "cleared_at_id" not in cols: + conv_cols = {row["name"] for row in conn.execute("PRAGMA table_info(conversations)")} + if "cleared_at_id" not in conv_cols: conn.execute("ALTER TABLE conversations ADD COLUMN cleared_at_id INTEGER NOT NULL DEFAULT 0") + if "thinking_enabled" not in conv_cols: + conn.execute( + "ALTER TABLE conversations ADD COLUMN thinking_enabled INTEGER NOT NULL DEFAULT 0" + ) + + msg_cols = {row["name"] for row in conn.execute("PRAGMA table_info(messages)")} + if "hidden_from_llm" not in msg_cols: + conn.execute( + "ALTER TABLE messages ADD COLUMN hidden_from_llm INTEGER NOT NULL DEFAULT 0" + ) def upsert_conversation(conn: sqlite3.Connection, public_key: str, contact_name: str) -> None: @@ -60,10 +72,17 @@ def upsert_conversation(conn: sqlite3.Connection, public_key: str, contact_name: ) -def add_message(conn: sqlite3.Connection, public_key: str, role: str, content: str) -> None: +def add_message( + conn: sqlite3.Connection, + public_key: str, + role: str, + content: str, + *, + hidden_from_llm: bool = False, +) -> None: conn.execute( - "INSERT INTO messages (public_key, role, content) VALUES (?, ?, ?)", - (public_key, role, content), + "INSERT INTO messages (public_key, role, content, hidden_from_llm) VALUES (?, ?, ?, ?)", + (public_key, role, content, 1 if hidden_from_llm else 0), ) @@ -78,7 +97,9 @@ def get_history(conn: sqlite3.Connection, public_key: str) -> list[dict[str, str SELECT m.role, m.content FROM messages m JOIN conversations c ON c.public_key = m.public_key - WHERE m.public_key = ? AND m.id > c.cleared_at_id + WHERE m.public_key = ? + AND m.id > c.cleared_at_id + AND m.hidden_from_llm = 0 ORDER BY m.id ASC """, (public_key,), @@ -86,6 +107,25 @@ def get_history(conn: sqlite3.Connection, public_key: str) -> list[dict[str, str return [{"role": row["role"], "content": row["content"]} for row in rows] +def get_thinking_enabled(conn: sqlite3.Connection, public_key: str) -> bool: + row = conn.execute( + "SELECT thinking_enabled FROM conversations WHERE public_key = ?", + (public_key,), + ).fetchone() + return bool(row["thinking_enabled"]) if row is not None else False + + +def set_thinking_enabled(conn: sqlite3.Connection, public_key: str, enabled: bool) -> None: + conn.execute( + """ + UPDATE conversations + SET thinking_enabled = ?, updated_at = datetime('now') + WHERE public_key = ? + """, + (1 if enabled else 0, public_key), + ) + + def clear_history(conn: sqlite3.Connection, public_key: str) -> None: """Bump the per-conversation watermark to the current max message id. diff --git a/src/lorabot/handler.py b/src/lorabot/handler.py new file mode 100644 index 0000000..5d06d1a --- /dev/null +++ b/src/lorabot/handler.py @@ -0,0 +1,114 @@ +"""DM event handler: route incoming DMs through the command registry or the LLM.""" + +from __future__ import annotations + +import asyncio +import logging +import sqlite3 +from collections import defaultdict +from collections.abc import Awaitable, Callable +from datetime import datetime, timezone + +from meshcore import MeshCore + +from . import db +from .commands import CommandContext, CommandRegistry +from .config import Settings +from .llm import LLMClient +from .transport import resolve_contact, send_chunked +from .web import AppState + +log = logging.getLogger("lorabot") + + +def _now_iso() -> str: + return datetime.now(timezone.utc).isoformat() + + +def build_dm_handler( + *, + mc: MeshCore, + db_conn: sqlite3.Connection, + llm: LLMClient, + registry: CommandRegistry, + state: AppState, + cfg: Settings, +) -> Callable[[object], Awaitable[None]]: + """Return an ``on_dm(event)`` closure with all collaborators bound.""" + + # One lock per sender so a burst of messages from the same peer is processed + # serially while different peers stay independent. + locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) + contacts_lock = asyncio.Lock() + + async def on_dm(event) -> None: + data = event.payload or {} + prefix = data.get("pubkey_prefix") + text = (data.get("text") or "").strip() + if not prefix or not text: + return + + contact = await resolve_contact(mc, prefix, contacts_lock) + if contact is None: + log.info("ignoring DM from unknown sender %s (no contact after refresh)", prefix) + return + + public_key = contact["public_key"] + contact_name = contact.get("adv_name", "") + log.info("DM from %s (%s): %s", contact_name, public_key[:12], text) + + # Decided up front so both turns of a command exchange are stored consistently + # with hidden_from_llm. Commands and their replies stay out of LLM context; + # the web UI still shows them. + is_command = registry.parse(text) is not None + + async with locks[public_key]: + db.upsert_conversation(db_conn, public_key, contact_name) + db.add_message(db_conn, public_key, "user", text, hidden_from_llm=is_command) + state.publish("message", { + "public_key": public_key, + "contact_name": contact_name, + "role": "user", + "content": text, + "created_at": _now_iso(), + }) + + ctx = CommandContext( + db_conn=db_conn, + public_key=public_key, + contact_name=contact_name, + cfg=cfg, + state=state, + ) + + if is_command: + cmd_result = await registry.dispatch(ctx, text) + if cmd_result is None or cmd_result.reply is None: + return + reply = cmd_result.reply + else: + cmd_result = None + thinking = db.get_thinking_enabled(db_conn, public_key) + try: + reply = await llm.reply( + db.get_history(db_conn, public_key), + thinking=thinking, + ) + except Exception: + log.exception("LLM call failed for %s", public_key[:12]) + return + + delivered = await send_chunked(mc, contact, reply, cfg.message.max_bytes) + db.add_message(db_conn, public_key, "assistant", delivered, hidden_from_llm=is_command) + state.publish("message", { + "public_key": public_key, + "contact_name": contact_name, + "role": "assistant", + "content": delivered, + "created_at": _now_iso(), + }) + + if cmd_result is not None and cmd_result.after_send is not None: + await cmd_result.after_send(ctx) + + return on_dm diff --git a/src/lorabot/llm.py b/src/lorabot/llm.py index 1871d40..fa3613b 100644 --- a/src/lorabot/llm.py +++ b/src/lorabot/llm.py @@ -21,8 +21,13 @@ class LLMClient: self._system_prompt = system_prompt self._temperature = temperature - async def reply(self, history: list[dict[str, str]]) -> str: - """Send the system prompt + ``history`` and return the assistant's text.""" + async def reply(self, history: list[dict[str, str]], *, thinking: bool = False) -> str: + """Send the system prompt + ``history`` and return the assistant's text. + + ``thinking`` toggles the llama.cpp/server chat-template kwarg ``enable_thinking``, + which controls whether the model prepends its hidden reasoning turn (Gemma-style). + Passed via OpenAI SDK ``extra_body`` since it is not part of the standard schema. + """ messages: list[dict[str, str]] = [ {"role": "system", "content": self._system_prompt}, *history, @@ -31,6 +36,7 @@ class LLMClient: model=self._model, messages=messages, temperature=self._temperature, + extra_body={"chat_template_kwargs": {"enable_thinking": thinking}}, ) return (resp.choices[0].message.content or "").strip() diff --git a/src/lorabot/transport.py b/src/lorabot/transport.py new file mode 100644 index 0000000..f4868c6 --- /dev/null +++ b/src/lorabot/transport.py @@ -0,0 +1,55 @@ +"""MeshCore-side helpers: contact resolution and chunked sending.""" + +from __future__ import annotations + +import asyncio +import logging + +from meshcore import EventType, MeshCore + +from .messages import split_to_bytes + +log = logging.getLogger("lorabot") + + +async def resolve_contact(mc: MeshCore, prefix: str, lock: asyncio.Lock): + """Look up a contact by pubkey prefix; re-pull contacts from the device on miss.""" + contact = mc.get_contact_by_key_prefix(prefix) + if contact is not None: + return contact + async with lock: + contact = mc.get_contact_by_key_prefix(prefix) + if contact is not None: + return contact + try: + await mc.commands.get_contacts() + except Exception: + log.exception("get_contacts refresh failed") + return None + return mc.get_contact_by_key_prefix(prefix) + + +async def send_chunked(mc: MeshCore, contact, text: str, max_bytes: int, max_chunks: int = 2) -> str: + """Split ``text`` into byte-budgeted chunks and send them in order. + + Stops sending on the first transport error. Returns the actually-delivered text + (concatenated chunks that the caller should record as the assistant turn). + """ + chunks = split_to_bytes(text, max_bytes, max_chunks=max_chunks) + delivered = "".join(chunks) + pk_short = contact["public_key"][:12] + + dropped = len(text.encode("utf-8")) - len(delivered.encode("utf-8")) + if dropped > 0: + log.info("reply to %s split into %d chunks, dropped %d trailing bytes", + pk_short, len(chunks), dropped) + + for i, chunk in enumerate(chunks, 1): + log.info("reply to %s (%d/%d, %d bytes): %s", + pk_short, i, len(chunks), len(chunk.encode("utf-8")), chunk) + result = await mc.commands.send_msg(contact, chunk) + if result.type == EventType.ERROR: + log.error("send_msg failed for %s chunk %d/%d: %s", + pk_short, i, len(chunks), result.payload) + break + return delivered