From 946e85324155899061fce8e8fc305bba19e2be59 Mon Sep 17 00:00:00 2001 From: Tobias Huttinger Date: Wed, 6 May 2026 20:01:28 +0200 Subject: [PATCH] Added tool registry as well as hardening of message sending, checking ACKs and retransmit --- config.example.toml | 20 +++- src/lorabot/__main__.py | 21 +++- src/lorabot/bot.py | 20 +++- src/lorabot/config.py | 19 ++++ src/lorabot/handler.py | 11 +- src/lorabot/llm.py | 104 ++++++++++++++++-- src/lorabot/tools/__init__.py | 29 +++++ src/lorabot/tools/base.py | 90 +++++++++++++++ src/lorabot/tools/weather.py | 201 ++++++++++++++++++++++++++++++++++ src/lorabot/tools/web.py | 193 ++++++++++++++++++++++++++++++++ src/lorabot/transport.py | 126 ++++++++++++++------- 11 files changed, 768 insertions(+), 66 deletions(-) create mode 100644 src/lorabot/tools/__init__.py create mode 100644 src/lorabot/tools/base.py create mode 100644 src/lorabot/tools/weather.py create mode 100644 src/lorabot/tools/web.py diff --git a/config.example.toml b/config.example.toml index 91d1f47..c8051b7 100644 --- a/config.example.toml +++ b/config.example.toml @@ -4,6 +4,11 @@ # LORABOT_LLM__BASE_URL=http://llama:8080/v1 # LORABOT_MESHCORE__SERIAL_PORT=/dev/ttyACM0 +[logging] +# DEBUG | INFO | WARNING | ERROR | CRITICAL (case-insensitive). +# DEBUG adds per-iteration LLM request logs and Tavily request details. +level = "INFO" + [meshcore] serial_port = "/dev/ttyUSB0" baud_rate = 115200 @@ -11,7 +16,7 @@ baud_rate = 115200 [llm] base_url = "http://localhost:8080/v1" api_key = "not-needed" -model = "llama-3.1-8b-instruct" +model = "gemma-4-E4B" system_prompt = "You are a concise assistant on a low-bandwidth mesh radio. Replies must be brief — under 180 bytes." temperature = 0.7 request_timeout_seconds = 60 @@ -23,6 +28,10 @@ sqlite_path = "data/lorabot.db" # MeshCore MAX_PACKET_PAYLOAD is 184 bytes. Lower this if your text-frame # headers further constrain the usable payload on your device. 
max_bytes = 184 +# Seconds to wait for an ACK before treating a chunk as failed. +ack_timeout_seconds = 30 +# How many times to retry a chunk after failure (0 = no retries). +send_retries = 1 [web] # Built-in read-only web UI: stored conversations + live status. @@ -39,3 +48,12 @@ interval_seconds = 3600 at_startup = true # Flood = multi-hop advert across the mesh. False = zero-hop (neighbors only). flood = false + +# LLM tool calling. The weather tool (Open-Meteo, no key) is always on. Tools +# in this section are optional and only registered when configured. Requires a +# tool-capable model on the LLM server (Llama 3.1, Qwen, Hermes, …); models +# without tool support will simply ignore them. +[tools.tavily] +# Web search + page extraction via https://tavily.com (free tier available). +# Leave empty to disable both web_search and fetch_url tools. +api_key = "" diff --git a/src/lorabot/__main__.py b/src/lorabot/__main__.py index 4a74da0..7566536 100644 --- a/src/lorabot/__main__.py +++ b/src/lorabot/__main__.py @@ -6,15 +6,26 @@ import asyncio import logging from .bot import run +from .config import Settings + +_LOG_FORMAT = "%(asctime)s %(levelname)-7s %(name)s: %(message)s" def _cli() -> None: - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s %(levelname)-7s %(name)s: %(message)s", - ) + # Bootstrap config so any error during Settings() loading is still logged + # nicely. ``force=True`` lets us reapply with the user's level afterwards. 
+ logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT) try: - asyncio.run(run()) + cfg = Settings() + except Exception: + logging.exception("failed to load configuration") + raise SystemExit(1) from None + + level = getattr(logging, cfg.logging.level.upper(), logging.INFO) + logging.basicConfig(level=level, format=_LOG_FORMAT, force=True) + + try: + asyncio.run(run(cfg)) except KeyboardInterrupt: pass diff --git a/src/lorabot/bot.py b/src/lorabot/bot.py index fdcaf3c..d4f0d9b 100644 --- a/src/lorabot/bot.py +++ b/src/lorabot/bot.py @@ -12,12 +12,18 @@ from .commands import build_default_registry from .config import Settings from .handler import build_dm_handler from .llm import LLMClient +from .tools import build_default_registry as build_default_tool_registry +from .transport import MeshTransport log = logging.getLogger("lorabot") -async def run() -> None: - cfg = Settings() +async def run(cfg: Settings | None = None) -> None: + # ``cfg`` is normally built by the entry point so logging can be configured + # from it before we get here; falling back to ``Settings()`` keeps direct + # ``run()`` calls (tests, embedding) working. + if cfg is None: + cfg = Settings() db_conn = db.connect(cfg.storage.sqlite_path) state = web.AppState( @@ -27,6 +33,10 @@ async def run() -> None: loop=asyncio.get_running_loop(), ) + tool_registry = build_default_tool_registry( + tavily_api_key=cfg.tools.tavily.api_key, + ) + llm = LLMClient( base_url=cfg.llm.base_url, api_key=cfg.llm.api_key, @@ -34,6 +44,7 @@ async def run() -> None: system_prompt=cfg.llm.system_prompt, temperature=cfg.llm.temperature, timeout=cfg.llm.request_timeout_seconds, + tools=tool_registry, ) web_task: asyncio.Task | None = None @@ -48,6 +59,7 @@ async def run() -> None: # in the local cache before the peer's first DM lands. 
mc.auto_update_contacts = True await mc.ensure_contacts() + transport = MeshTransport(mc, ack_timeout=cfg.message.ack_timeout_seconds, send_retries=cfg.message.send_retries) except BaseException: state.set_connected(False) if web_task is not None: @@ -57,6 +69,7 @@ async def run() -> None: except (asyncio.CancelledError, Exception): pass await llm.aclose() + await tool_registry.aclose() db_conn.close() raise state.set_connected(True, node_name=_self_name(mc)) @@ -80,12 +93,12 @@ async def run() -> None: registry = build_default_registry() on_dm = build_dm_handler( - mc=mc, db_conn=db_conn, llm=llm, registry=registry, state=state, cfg=cfg, + transport=transport, ) sub = mc.subscribe(EventType.CONTACT_MSG_RECV, on_dm) @@ -101,6 +114,7 @@ async def run() -> None: await mc.stop_auto_message_fetching() await mc.disconnect() await llm.aclose() + await tool_registry.aclose() for task in (advert_task, web_task): if task is not None: task.cancel() diff --git a/src/lorabot/config.py b/src/lorabot/config.py index b3816d9..a96051b 100644 --- a/src/lorabot/config.py +++ b/src/lorabot/config.py @@ -37,6 +37,8 @@ class StorageCfg(BaseModel): class MessageCfg(BaseModel): max_bytes: int = Field(default=184, gt=0) + ack_timeout_seconds: float = Field(default=30.0, gt=0) + send_retries: int = Field(default=1, ge=0) class WebCfg(BaseModel): @@ -47,6 +49,21 @@ class WebCfg(BaseModel): port: int = Field(default=8080, gt=0, lt=65536) +class LoggingCfg(BaseModel): + # Standard level names: DEBUG, INFO, WARNING, ERROR, CRITICAL. Case-insensitive. + level: str = "INFO" + + +class TavilyCfg(BaseModel): + # Sign up at https://tavily.com for a key. When empty, the web_search and + # fetch_url tools simply aren't registered (the rest of the bot is unaffected). + api_key: str = "" + + +class ToolsCfg(BaseModel): + tavily: TavilyCfg = TavilyCfg() + + class AdvertiseCfg(BaseModel): enabled: bool = True # Seconds between automatic adverts. 0 = manual only (button still works). 
@@ -68,6 +85,8 @@ class Settings(BaseSettings): message: MessageCfg = MessageCfg() web: WebCfg = WebCfg() advertise: AdvertiseCfg = AdvertiseCfg() + tools: ToolsCfg = ToolsCfg() + logging: LoggingCfg = LoggingCfg() model_config = SettingsConfigDict( env_prefix="LORABOT_", diff --git a/src/lorabot/handler.py b/src/lorabot/handler.py index 9fa5493..bfcf7aa 100644 --- a/src/lorabot/handler.py +++ b/src/lorabot/handler.py @@ -9,13 +9,11 @@ from collections import defaultdict from collections.abc import Awaitable, Callable from datetime import datetime, timezone -from meshcore import MeshCore - from . import db from .commands import CommandContext, CommandRegistry from .config import Settings from .llm import LLMClient -from .transport import resolve_contact, send_chunked +from .transport import MeshTransport from .web import AppState log = logging.getLogger("lorabot") @@ -27,19 +25,18 @@ def _now_iso() -> str: def build_dm_handler( *, - mc: MeshCore, db_conn: sqlite3.Connection, llm: LLMClient, registry: CommandRegistry, state: AppState, cfg: Settings, + transport: MeshTransport, ) -> Callable[[object], Awaitable[None]]: """Return an ``on_dm(event)`` closure with all collaborators bound.""" # One lock per sender so a burst of messages from the same peer is processed # serially while different peers stay independent. 
locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) - contacts_lock = asyncio.Lock() async def on_dm(event) -> None: data = event.payload or {} @@ -48,7 +45,7 @@ def build_dm_handler( if not prefix or not text: return - contact = await resolve_contact(mc, prefix, contacts_lock) + contact = await transport.resolve_contact(prefix) if contact is None: log.info("ignoring DM from unknown sender %s (no contact after refresh)", prefix) return @@ -98,7 +95,7 @@ def build_dm_handler( log.exception("LLM call failed for %s", public_key[:12]) return - delivered = await send_chunked(mc, contact, reply, cfg.message.max_bytes) + delivered = await transport.send_chunked(contact, reply, cfg.message.max_bytes) if not delivered: # Nothing made it onto the radio; don't persist anything return diff --git a/src/lorabot/llm.py b/src/lorabot/llm.py index fa3613b..a5b85fc 100644 --- a/src/lorabot/llm.py +++ b/src/lorabot/llm.py @@ -2,8 +2,19 @@ from __future__ import annotations +import logging +from typing import Any + from openai import AsyncOpenAI +from .tools import ToolRegistry + +log = logging.getLogger("lorabot") + +# Hard cap on assistant <-> tool round-trips per ``reply`` call. Stops a misbehaving +# model from looping on tool calls forever; in practice 1-2 iterations is normal. +_MAX_TOOL_ITERATIONS = 5 + class LLMClient: def __init__( @@ -15,11 +26,13 @@ class LLMClient: system_prompt: str, temperature: float, timeout: float, + tools: ToolRegistry | None = None, ) -> None: self._client = AsyncOpenAI(base_url=base_url, api_key=api_key, timeout=timeout) self._model = model self._system_prompt = system_prompt self._temperature = temperature + self._tools = tools async def reply(self, history: list[dict[str, str]], *, thinking: bool = False) -> str: """Send the system prompt + ``history`` and return the assistant's text. 
@@ -27,18 +40,93 @@ class LLMClient: ``thinking`` toggles the llama.cpp/server chat-template kwarg ``enable_thinking``, which controls whether the model prepends its hidden reasoning turn (Gemma-style). Passed via OpenAI SDK ``extra_body`` since it is not part of the standard schema. + + If a ``ToolRegistry`` is wired in, the model may emit ``tool_calls``; we + dispatch them, append the results, and re-call until the model produces + plain content (or the iteration cap kicks in). """ - messages: list[dict[str, str]] = [ + # Local working copy: we'll mutate this with assistant/tool turns across + # the tool loop without touching the persisted DB history. + messages: list[dict[str, Any]] = [ {"role": "system", "content": self._system_prompt}, *history, ] - resp = await self._client.chat.completions.create( - model=self._model, - messages=messages, - temperature=self._temperature, - extra_body={"chat_template_kwargs": {"enable_thinking": thinking}}, - ) - return (resp.choices[0].message.content or "").strip() + + # Build the tool spec list once. ``None`` (not ``[]``) means "don't send + # the tools field at all" — some servers reject an empty list. + tool_specs = self._tools.specs() if self._tools else None + + for iteration in range(_MAX_TOOL_ITERATIONS): + # Only attach ``tools`` when we actually have any registered. + kwargs: dict[str, Any] = {} + if tool_specs: + kwargs["tools"] = tool_specs + + log.debug( + "LLM request: model=%s iter=%d msgs=%d tools=%d thinking=%s", + self._model, + iteration + 1, + len(messages), + len(tool_specs) if tool_specs else 0, + thinking, + ) + resp = await self._client.chat.completions.create( + model=self._model, + messages=messages, + temperature=self._temperature, + extra_body={"chat_template_kwargs": {"enable_thinking": thinking}}, + **kwargs, + ) + msg = resp.choices[0].message + + # No tool calls → this is the final text answer; return it. + # Also bail if tools aren't even configured: nothing to dispatch to. 
+ if not msg.tool_calls or self._tools is None: + return (msg.content or "").strip() + + # Visible at INFO so a normal log lets you see when the model + # actually reached for a tool, and which one(s). + log.info( + "LLM requested tool calls (iter %d): %s", + iteration + 1, + [tc.function.name for tc in msg.tool_calls], + ) + + # Echo the assistant's tool-calling turn back into ``messages`` so the + # next round-trip sees it. The OpenAI protocol requires this exact + # shape (role=assistant + tool_calls list) before role=tool entries. + messages.append({ + "role": "assistant", + "content": msg.content, + "tool_calls": [ + { + "id": tc.id, + "type": "function", + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + } + for tc in msg.tool_calls + ], + }) + + # Run each tool the model asked for and append its result. Errors are + # surfaced as plain strings inside the registry so the model can see + # them and recover (e.g. retry with a different argument). + for tc in msg.tool_calls: + result = await self._tools.dispatch( + tc.function.name, tc.function.arguments + ) + messages.append({ + "role": "tool", + "tool_call_id": tc.id, + "content": result, + }) + # Loop continues: re-call the model with the tool results in context. + + log.warning("LLM tool loop exceeded %d iterations", _MAX_TOOL_ITERATIONS) + return "(tool loop limit exceeded)" async def aclose(self) -> None: await self._client.close() diff --git a/src/lorabot/tools/__init__.py b/src/lorabot/tools/__init__.py new file mode 100644 index 0000000..9769dba --- /dev/null +++ b/src/lorabot/tools/__init__.py @@ -0,0 +1,29 @@ +"""LLM tool calling: registry + built-in tools.""" + +from __future__ import annotations + +from .base import Tool, ToolRegistry +from .weather import WeatherTool +from .web import FetchUrlTool, WebSearchTool + + +def build_default_registry(*, tavily_api_key: str = "") -> ToolRegistry: + """Build the default registry. 
+ Tavily-backed tools are only registered + when an API key is configured — without one, the bot silently runs with + just the keyless tools (weather via Open-Meteo).""" + reg = ToolRegistry() + reg.register(WeatherTool()) + if tavily_api_key: + reg.register(WebSearchTool(api_key=tavily_api_key)) + reg.register(FetchUrlTool(api_key=tavily_api_key)) + return reg + + +__all__ = [ + "Tool", + "ToolRegistry", + "WeatherTool", + "WebSearchTool", + "FetchUrlTool", + "build_default_registry", +] diff --git a/src/lorabot/tools/base.py b/src/lorabot/tools/base.py new file mode 100644 index 0000000..ea7d843 --- /dev/null +++ b/src/lorabot/tools/base.py @@ -0,0 +1,90 @@ +"""Tool abstraction and registry for OpenAI-style function calling.""" + +from __future__ import annotations + +import json +import logging +import time +from abc import ABC, abstractmethod +from typing import Any + +log = logging.getLogger("lorabot") + + +def _truncate_for_log(s: str, max_len: int = 200) -> str: + """Cap a log field so a giant tool argument doesn't blow up log lines.""" + return s if len(s) <= max_len else s[:max_len] + "…" + + +class Tool(ABC): + """Single LLM-callable function. Subclasses set ``name``/``description``/ + ``parameters`` (JSON Schema) and implement ``run``.""" + + name: str + description: str + parameters: dict[str, Any] + + @abstractmethod + async def run(self, **kwargs: Any) -> str: + ...
+ + async def aclose(self) -> None: + return None + + def spec(self) -> dict[str, Any]: + return { + "type": "function", + "function": { + "name": self.name, + "description": self.description, + "parameters": self.parameters, + }, + } + + +class ToolRegistry: + def __init__(self) -> None: + self._tools: dict[str, Tool] = {} + + def register(self, tool: Tool) -> None: + self._tools[tool.name] = tool + + def specs(self) -> list[dict[str, Any]]: + return [t.spec() for t in self._tools.values()] + + def __bool__(self) -> bool: + return bool(self._tools) + + async def dispatch(self, name: str, arguments_json: str) -> str: + """Run the named tool with JSON-encoded arguments. Errors are returned + as plain strings so the LLM can see and react to them.""" + tool = self._tools.get(name) + if tool is None: + log.warning("tool call rejected: unknown tool %r", name) + return f"error: unknown tool {name!r}" + try: + args = json.loads(arguments_json) if arguments_json else {} + except json.JSONDecodeError as exc: + log.warning("tool %s: bad arguments JSON: %s", name, exc) + return f"error: bad arguments JSON ({exc})" + + # Log entry + exit so failures are easy to attribute and we get a + # rough timing picture per tool. Args are truncated to keep logs sane. 
+ log.info("tool %s called: %s", name, _truncate_for_log(arguments_json or "{}")) + started = time.monotonic() + try: + result = await tool.run(**args) + except Exception as exc: + elapsed_ms = (time.monotonic() - started) * 1000 + log.exception("tool %s failed after %.0f ms", name, elapsed_ms) + return f"error: {exc}" + elapsed_ms = (time.monotonic() - started) * 1000 + log.info("tool %s ok in %.0f ms (%d chars)", name, elapsed_ms, len(result)) + return result + + async def aclose(self) -> None: + for tool in self._tools.values(): + try: + await tool.aclose() + except Exception: + log.exception("aclose failed for tool %s", tool.name) diff --git a/src/lorabot/tools/weather.py b/src/lorabot/tools/weather.py new file mode 100644 index 0000000..d393d66 --- /dev/null +++ b/src/lorabot/tools/weather.py @@ -0,0 +1,201 @@ +"""Weather tool backed by Open-Meteo (no API key). + +Returns the current observation, and optionally a multi-day daily forecast +(min/max temperature + weather code per day, in the location's local timezone). +""" + +from __future__ import annotations + +from datetime import date + +import aiohttp + +from .base import Tool + +GEOCODE_URL = "https://geocoding-api.open-meteo.com/v1/search" +FORECAST_URL = "https://api.open-meteo.com/v1/forecast" + +# Hard upper bound on daily forecast length. Open-Meteo allows up to 16 days but +# bandwidth and model context favor keeping replies compact. +_MAX_FORECAST_DAYS = 7 + +# WMO weather interpretation codes — compact summaries to keep LoRa replies short. 
+_WMO: dict[int, str] = { + 0: "clear", + 1: "mostly clear", 2: "partly cloudy", 3: "overcast", + 45: "fog", 48: "rime fog", + 51: "light drizzle", 53: "drizzle", 55: "heavy drizzle", + 56: "freezing drizzle", 57: "freezing drizzle", + 61: "light rain", 63: "rain", 65: "heavy rain", + 66: "freezing rain", 67: "freezing rain", + 71: "light snow", 73: "snow", 75: "heavy snow", 77: "snow grains", + 80: "rain showers", 81: "rain showers", 82: "violent showers", + 85: "snow showers", 86: "heavy snow showers", + 95: "thunderstorm", 96: "thunderstorm w/ hail", 99: "thunderstorm w/ hail", +} + + +class WeatherTool(Tool): + name = "get_weather" + description = ( + "Current weather conditions for a place, optionally followed by a daily " + "forecast. Accepts a location name (e.g. 'Berlin', 'San Francisco, US') " + "or 'lat,lon' decimal coordinates. Set forecast_days to include the next " + "N days (today inclusive); 0 returns only the current observation." + ) + parameters = { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City/place name, or 'lat,lon' decimal coordinates.", + }, + "forecast_days": { + "type": "integer", + "description": ( + f"Number of daily-forecast days to include (0–{_MAX_FORECAST_DAYS}, " + "today inclusive). Defaults to 0 (current weather only)." 
+ ), + "minimum": 0, + "maximum": _MAX_FORECAST_DAYS, + "default": 0, + }, + }, + "required": ["location"], + } + + def __init__(self) -> None: + self._session: aiohttp.ClientSession | None = None + + async def aclose(self) -> None: + if self._session is not None and not self._session.closed: + await self._session.close() + + async def _get_session(self) -> aiohttp.ClientSession: + if self._session is None or self._session.closed: + self._session = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=10) + ) + return self._session + + async def run(self, location: str = "", forecast_days: int = 0) -> str: + location = (location or "").strip() + if not location: + return "error: location is required" + + # Clamp defensively: model may ignore the schema bounds. + try: + forecast_days = max(0, min(int(forecast_days), _MAX_FORECAST_DAYS)) + except (TypeError, ValueError): + forecast_days = 0 + + coords = _parse_latlon(location) + if coords is not None: + lat, lon = coords + label = f"{lat:.3f},{lon:.3f}" + else: + geo = await self._geocode(location) + if geo is None: + return f"error: could not find location {location!r}" + lat, lon, label = geo + + return await self._forecast(lat, lon, label, forecast_days) + + async def _geocode(self, name: str) -> tuple[float, float, str] | None: + session = await self._get_session() + async with session.get( + GEOCODE_URL, + params={"name": name, "count": "1", "format": "json"}, + ) as resp: + resp.raise_for_status() + data = await resp.json() + results = data.get("results") or [] + if not results: + return None + r = results[0] + bits = [r.get("name") or name] + if cc := r.get("country_code"): + bits.append(str(cc)) + return float(r["latitude"]), float(r["longitude"]), ", ".join(bits) + + async def _forecast( + self, lat: float, lon: float, label: str, forecast_days: int + ) -> str: + # Build the request. ``timezone=auto`` makes daily aggregations align to + # the location's local calendar day instead of UTC. 
+ params: dict[str, str] = { + "latitude": str(lat), + "longitude": str(lon), + "current": "temperature_2m,weather_code,wind_speed_10m", + "timezone": "auto", + } + if forecast_days > 0: + params["daily"] = "temperature_2m_max,temperature_2m_min,weather_code" + params["forecast_days"] = str(forecast_days) + + session = await self._get_session() + async with session.get(FORECAST_URL, params=params) as resp: + resp.raise_for_status() + data = await resp.json() + + # Current observation line (always present). + cur = data.get("current") or {} + temp = cur.get("temperature_2m") + code = cur.get("weather_code") + wind = cur.get("wind_speed_10m") + cond = _WMO.get(int(code), f"code {code}") if code is not None else "?" + cur_bits = [f"{label} now"] + if temp is not None: + cur_bits.append(f"{temp}°C") + cur_bits.append(cond) + if wind is not None: + cur_bits.append(f"wind {wind} km/h") + lines = [", ".join(cur_bits)] + + # Daily forecast block (one line per day): "Mon 8-17°C, partly cloudy". + # Open-Meteo returns parallel arrays under ``daily`` keyed by date. + if forecast_days > 0: + daily = data.get("daily") or {} + days = daily.get("time") or [] + tmax = daily.get("temperature_2m_max") or [] + tmin = daily.get("temperature_2m_min") or [] + codes = daily.get("weather_code") or [] + for i, day_str in enumerate(days): + day_label = _short_day(day_str) + hi = tmax[i] if i < len(tmax) else None + lo = tmin[i] if i < len(tmin) else None + day_code = codes[i] if i < len(codes) else None + day_cond = ( + _WMO.get(int(day_code), f"code {day_code}") + if day_code is not None + else "?" + ) + if lo is not None and hi is not None: + lines.append(f"{day_label} {lo}–{hi}°C, {day_cond}") + else: + lines.append(f"{day_label} {day_cond}") + + return "; ".join(lines) + + +def _short_day(iso_date: str) -> str: + """Render an ISO date (``YYYY-MM-DD``) as a short weekday (``Mon``). 
+ Falls back to the raw string if parsing fails.""" + try: + return date.fromisoformat(iso_date).strftime("%a") + except (TypeError, ValueError): + return iso_date + + +def _parse_latlon(text: str) -> tuple[float, float] | None: + if "," not in text: + return None + a, b = text.split(",", 1) + try: + lat = float(a.strip()) + lon = float(b.strip()) + except ValueError: + return None + if not (-90 <= lat <= 90 and -180 <= lon <= 180): + return None + return lat, lon diff --git a/src/lorabot/tools/web.py b/src/lorabot/tools/web.py new file mode 100644 index 0000000..94f9452 --- /dev/null +++ b/src/lorabot/tools/web.py @@ -0,0 +1,193 @@ +"""Web search + page-extraction tools backed by Tavily (https://tavily.com). + +Exposes two tools to the LLM: + +- ``web_search`` — query Tavily for a list of {title, url, snippet} hits. +- ``fetch_url`` — pull the readable body text of a single page via Tavily's + ``/extract`` endpoint (so we don't have to parse HTML ourselves). + +Tavily handles all the messy bits — URL fetching, redirects, boilerplate +stripping, robots — so this module is just a thin HTTP wrapper. +""" + +from __future__ import annotations + +import logging + +import aiohttp + +from .base import Tool + +log = logging.getLogger("lorabot") + +_SEARCH_URL = "https://api.tavily.com/search" +_EXTRACT_URL = "https://api.tavily.com/extract" + +# Sane defaults for the search tool. Upper bound keeps the tool message small +# enough that the LoRa-side LLM can still reason over it. +_DEFAULT_RESULTS = 5 +_MAX_RESULTS = 10 + +# Cap extracted-page body size so a giant article doesn't blow the model's +# context. ~4 KB is plenty to summarize from for a 180-byte LoRa reply. +_MAX_EXTRACT_CHARS = 4000 + + +class _TavilyHTTP: + """Tiny per-tool HTTP helper. 
Each tool instance owns one ``aiohttp`` + session; lazily created on first request, closed via ``aclose``.""" + + def __init__(self, api_key: str) -> None: + self._api_key = api_key + self._session: aiohttp.ClientSession | None = None + + async def post(self, url: str, payload: dict) -> dict: + if self._session is None or self._session.closed: + self._session = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=20) + ) + headers = {"Authorization": f"Bearer {self._api_key}"} + log.debug("tavily POST %s payload-keys=%s", url, list(payload)) + async with self._session.post(url, json=payload, headers=headers) as resp: + # On non-2xx, capture the response body before raising so the log + # tells us *why* (auth, quota, bad query) — raise_for_status alone + # would only show the status code. + if resp.status >= 400: + body = await resp.text() + log.warning( + "tavily POST %s -> %d: %s", url, resp.status, body[:200] + ) + resp.raise_for_status() + return await resp.json() + + async def aclose(self) -> None: + if self._session is not None and not self._session.closed: + await self._session.close() + + +class WebSearchTool(Tool): + name = "web_search" + description = ( + "Search the web. Returns titles, URLs and SHORT snippets — useful for " + "headlines, finding sources and orientation. Snippets are a few " + "sentences at most and are often truncated mid-thought. " + "If the user asked for any of: specific numbers, prices, dates, " + "percentages, technical specs, exact quotes, step-by-step " + "instructions, or 'current'/'latest' data — you MUST also call " + "fetch_url on the most relevant result before answering." + ) + parameters = { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search query.", + }, + "max_results": { + "type": "integer", + "description": ( + f"How many results to return (1–{_MAX_RESULTS}). " + f"Default {_DEFAULT_RESULTS}." 
+ ), + "minimum": 1, + "maximum": _MAX_RESULTS, + "default": _DEFAULT_RESULTS, + }, + }, + "required": ["query"], + } + + def __init__(self, api_key: str) -> None: + self._http = _TavilyHTTP(api_key) + + async def aclose(self) -> None: + await self._http.aclose() + + async def run(self, query: str = "", max_results: int = _DEFAULT_RESULTS) -> str: + query = (query or "").strip() + if not query: + return "error: query is required" + + # Defensive clamp — model may ignore the JSON-schema bounds. + try: + n = max(1, min(int(max_results), _MAX_RESULTS)) + except (TypeError, ValueError): + n = _DEFAULT_RESULTS + + data = await self._http.post( + _SEARCH_URL, + { + "query": query, + "max_results": n, + # "basic" is fast and cheap; "advanced" costs more credits but + # returns higher-quality snippets. Stick with basic for now. + "search_depth": "basic", + }, + ) + results = data.get("results") or [] + if not results: + return "no results" + + # Format as a numbered list — readable for the model and compact enough + # to fit comfortably in context alongside the rest of the conversation. + lines: list[str] = [] + for i, r in enumerate(results, 1): + title = (r.get("title") or "").strip() or "(no title)" + url = r.get("url") or "" + snippet = (r.get("content") or "").strip() + lines.append(f"{i}. {title} — {url}\n {snippet}") + return "\n".join(lines) + + +class FetchUrlTool(Tool): + name = "fetch_url" + description = ( + "Fetch the full readable body text of a single web page. Call this " + "after web_search whenever the user asked for specifics — exact " + "numbers, prices, dates, quotes, technical details, or current data — " + "even if a snippet looks like it might answer. Snippets are short and " + "often misleading; the full page is the source of truth. Long pages " + "are truncated to keep context manageable." 
+ ) + parameters = { + "type": "object", + "properties": { + "url": { + "type": "string", + "description": "Absolute URL of the page to fetch.", + }, + }, + "required": ["url"], + } + + def __init__(self, api_key: str) -> None: + self._http = _TavilyHTTP(api_key) + + async def aclose(self) -> None: + await self._http.aclose() + + async def run(self, url: str = "") -> str: + url = (url or "").strip() + if not url: + return "error: url is required" + + data = await self._http.post( + _EXTRACT_URL, + {"urls": [url], "extract_depth": "basic"}, + ) + results = data.get("results") or [] + if not results: + # Tavily reports unreachable / blocked pages here instead of in + # ``results``. Surface the reason so the model can react. + failed = data.get("failed_results") or [] + if failed: + reason = (failed[0] or {}).get("error", "unknown") + return f"error: extract failed ({reason})" + return "error: no content extracted" + + body = (results[0].get("raw_content") or "").strip() + if not body: + return "error: page returned empty content" + if len(body) > _MAX_EXTRACT_CHARS: + body = body[:_MAX_EXTRACT_CHARS] + " …[truncated]" + return body diff --git a/src/lorabot/transport.py b/src/lorabot/transport.py index 41669e1..42c3a39 100644 --- a/src/lorabot/transport.py +++ b/src/lorabot/transport.py @@ -1,9 +1,10 @@ -"""MeshCore-side helpers: contact resolution and chunked sending.""" +"""MeshCore-side transport: contact resolution and reliable chunked sending.""" from __future__ import annotations import asyncio import logging +from collections import defaultdict from meshcore import EventType, MeshCore @@ -11,48 +12,89 @@ from .messages import split_to_bytes log = logging.getLogger("lorabot") +class MeshTransport: + """Owns the MeshCore connection reference; handles contact resolution and + reliable chunked message delivery. 
-async def resolve_contact(mc: MeshCore, prefix: str, lock: asyncio.Lock): - """Look up a contact by pubkey prefix; re-pull contacts from the device on miss.""" - contact = mc.get_contact_by_key_prefix(prefix) - if contact is not None: - return contact - async with lock: - contact = mc.get_contact_by_key_prefix(prefix) + Create one instance per MeshCore connection and pass it to build_dm_handler. + """ + + def __init__(self, mc: MeshCore, ack_timeout: float = 30.0, send_retries: int = 1) -> None: + self._mc = mc + self._ack_timeout = ack_timeout + self._send_retries = send_retries + self._contacts_lock = asyncio.Lock() + self._send_locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock) + + async def resolve_contact(self, prefix: str): + """Look up a contact by pubkey prefix; re-pull contacts from the device on miss.""" + contact = self._mc.get_contact_by_key_prefix(prefix) if contact is not None: return contact + async with self._contacts_lock: + contact = self._mc.get_contact_by_key_prefix(prefix) + if contact is not None: + return contact + try: + await self._mc.commands.get_contacts() + except Exception: + log.exception("get_contacts refresh failed") + return None + return self._mc.get_contact_by_key_prefix(prefix) + + async def send_chunked( + self, + contact, + text: str, + max_bytes: int, + max_chunks: int = 2, + ) -> str: + """Split ``text`` and send chunks in order, waiting for ACK before each next one. + + Returns the concatenation of chunks that were successfully delivered. 
+ """ + chunks = split_to_bytes(text, max_bytes, max_chunks=max_chunks) + pk = contact["public_key"] + pk_short = pk[:12] + + planned_bytes = sum(len(c.encode("utf-8")) for c in chunks) + dropped = len(text.encode("utf-8")) - planned_bytes + if dropped > 0: + log.info("reply to %s split into %d chunks, dropped %d trailing bytes", + pk_short, len(chunks), dropped) + + sent: list[str] = [] + async with self._send_locks[pk]: + for i, chunk in enumerate(chunks, 1): + log.info("reply to %s (%d/%d, %d bytes): %s", + pk_short, i, len(chunks), len(chunk.encode("utf-8")), chunk) + if not await self._send_chunk(contact, chunk): + break + sent.append(chunk) + return "".join(sent) + + async def _send_chunk(self, contact, chunk: str) -> bool: + """Send one chunk, retrying up to ``send_retries`` times on failure.""" + for attempt in range(self._send_retries + 1): + if await self._attempt_send(contact, chunk): + return True + if attempt < self._send_retries: + log.info("retrying chunk for %s (attempt %d/%d)", + contact["public_key"][:12], attempt + 2, self._send_retries + 1) + log.error("chunk delivery failed for %s after %d attempts", + contact["public_key"][:12], self._send_retries + 1) + return False + + async def _attempt_send(self, contact, chunk: str) -> bool: + """One send attempt. Returns True when ``send_msg`` yields a non-error result within ``ack_timeout``, False on device error or timeout. NOTE(review): confirm a non-error result implies the peer's ACK.""" + try: - await mc.commands.get_contacts() - except Exception: - log.exception("get_contacts refresh failed") - return None - return mc.get_contact_by_key_prefix(prefix) - - -async def send_chunked(mc: MeshCore, contact, text: str, max_bytes: int, max_chunks: int = 2) -> str: - """Split ``text`` into byte-budgeted chunks and send them in order. - - Stops sending on the first transport error. Returns the concatenation of the - chunks that were actually accepted by the device, so the caller records the - truth — not what we hoped to send.
- """ - chunks = split_to_bytes(text, max_bytes, max_chunks=max_chunks) - pk_short = contact["public_key"][:12] - planned_bytes = sum(len(c.encode("utf-8")) for c in chunks) - - dropped = len(text.encode("utf-8")) - planned_bytes - if dropped > 0: - log.info("reply to %s split into %d chunks, dropped %d trailing bytes", - pk_short, len(chunks), dropped) - - sent: list[str] = [] - for i, chunk in enumerate(chunks, 1): - log.info("reply to %s (%d/%d, %d bytes): %s", - pk_short, i, len(chunks), len(chunk.encode("utf-8")), chunk) - result = await mc.commands.send_msg(contact, chunk) - if result.type == EventType.ERROR: - log.error("send_msg failed for %s chunk %d/%d: %s", - pk_short, i, len(chunks), result.payload) - break - sent.append(chunk) - return "".join(sent) + result = await asyncio.wait_for( + self._mc.commands.send_msg(contact, chunk), timeout=self._ack_timeout + ) + if result.type != EventType.ERROR: + return True + log.warning("send_msg error for %s: %s", contact["public_key"][:12], result.payload) + except asyncio.TimeoutError: + log.warning("ACK timeout for %s after %.1fs", + contact["public_key"][:12], self._ack_timeout) + return False