Added tool registry as well as hardening of message sending, checking ACKs and retransmit
This commit is contained in:
+19
-1
@@ -4,6 +4,11 @@
|
|||||||
# LORABOT_LLM__BASE_URL=http://llama:8080/v1
|
# LORABOT_LLM__BASE_URL=http://llama:8080/v1
|
||||||
# LORABOT_MESHCORE__SERIAL_PORT=/dev/ttyACM0
|
# LORABOT_MESHCORE__SERIAL_PORT=/dev/ttyACM0
|
||||||
|
|
||||||
|
[logging]
|
||||||
|
# DEBUG | INFO | WARNING | ERROR | CRITICAL (case-insensitive).
|
||||||
|
# DEBUG adds per-iteration LLM request logs and Tavily request details.
|
||||||
|
level = "INFO"
|
||||||
|
|
||||||
[meshcore]
|
[meshcore]
|
||||||
serial_port = "/dev/ttyUSB0"
|
serial_port = "/dev/ttyUSB0"
|
||||||
baud_rate = 115200
|
baud_rate = 115200
|
||||||
@@ -11,7 +16,7 @@ baud_rate = 115200
|
|||||||
[llm]
|
[llm]
|
||||||
base_url = "http://localhost:8080/v1"
|
base_url = "http://localhost:8080/v1"
|
||||||
api_key = "not-needed"
|
api_key = "not-needed"
|
||||||
model = "llama-3.1-8b-instruct"
|
model = "gemma-4-E4B"
|
||||||
system_prompt = "You are a concise assistant on a low-bandwidth mesh radio. Replies must be brief — under 180 bytes."
|
system_prompt = "You are a concise assistant on a low-bandwidth mesh radio. Replies must be brief — under 180 bytes."
|
||||||
temperature = 0.7
|
temperature = 0.7
|
||||||
request_timeout_seconds = 60
|
request_timeout_seconds = 60
|
||||||
@@ -23,6 +28,10 @@ sqlite_path = "data/lorabot.db"
|
|||||||
# MeshCore MAX_PACKET_PAYLOAD is 184 bytes. Lower this if your text-frame
|
# MeshCore MAX_PACKET_PAYLOAD is 184 bytes. Lower this if your text-frame
|
||||||
# headers further constrain the usable payload on your device.
|
# headers further constrain the usable payload on your device.
|
||||||
max_bytes = 184
|
max_bytes = 184
|
||||||
|
# Seconds to wait for an ACK before treating a chunk as failed.
|
||||||
|
ack_timeout_seconds = 30
|
||||||
|
# How many times to retry a chunk after failure (0 = no retries).
|
||||||
|
send_retries = 1
|
||||||
|
|
||||||
[web]
|
[web]
|
||||||
# Built-in read-only web UI: stored conversations + live status.
|
# Built-in read-only web UI: stored conversations + live status.
|
||||||
@@ -39,3 +48,12 @@ interval_seconds = 3600
|
|||||||
at_startup = true
|
at_startup = true
|
||||||
# Flood = multi-hop advert across the mesh. False = zero-hop (neighbors only).
|
# Flood = multi-hop advert across the mesh. False = zero-hop (neighbors only).
|
||||||
flood = false
|
flood = false
|
||||||
|
|
||||||
|
# LLM tool calling. The weather tool (Open-Meteo, no key) is always on. Tools
|
||||||
|
# in this section are optional and only registered when configured. Requires a
|
||||||
|
# tool-capable model on the LLM server (Llama 3.1, Qwen, Hermes, …); models
|
||||||
|
# without tool support will simply ignore them.
|
||||||
|
[tools.tavily]
|
||||||
|
# Web search + page extraction via https://tavily.com (free tier available).
|
||||||
|
# Leave empty to disable both web_search and fetch_url tools.
|
||||||
|
api_key = ""
|
||||||
|
|||||||
+16
-5
@@ -6,15 +6,26 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from .bot import run
|
from .bot import run
|
||||||
|
from .config import Settings
|
||||||
|
|
||||||
|
_LOG_FORMAT = "%(asctime)s %(levelname)-7s %(name)s: %(message)s"
|
||||||
|
|
||||||
|
|
||||||
def _cli() -> None:
|
def _cli() -> None:
|
||||||
logging.basicConfig(
|
# Bootstrap config so any error during Settings() loading is still logged
|
||||||
level=logging.INFO,
|
# nicely. ``force=True`` lets us reapply with the user's level afterwards.
|
||||||
format="%(asctime)s %(levelname)-7s %(name)s: %(message)s",
|
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
asyncio.run(run())
|
cfg = Settings()
|
||||||
|
except Exception:
|
||||||
|
logging.exception("failed to load configuration")
|
||||||
|
raise SystemExit(1) from None
|
||||||
|
|
||||||
|
level = getattr(logging, cfg.logging.level.upper(), logging.INFO)
|
||||||
|
logging.basicConfig(level=level, format=_LOG_FORMAT, force=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
asyncio.run(run(cfg))
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
+17
-3
@@ -12,12 +12,18 @@ from .commands import build_default_registry
|
|||||||
from .config import Settings
|
from .config import Settings
|
||||||
from .handler import build_dm_handler
|
from .handler import build_dm_handler
|
||||||
from .llm import LLMClient
|
from .llm import LLMClient
|
||||||
|
from .tools import build_default_registry as build_default_tool_registry
|
||||||
|
from .transport import MeshTransport
|
||||||
|
|
||||||
log = logging.getLogger("lorabot")
|
log = logging.getLogger("lorabot")
|
||||||
|
|
||||||
|
|
||||||
async def run() -> None:
|
async def run(cfg: Settings | None = None) -> None:
|
||||||
cfg = Settings()
|
# ``cfg`` is normally built by the entry point so logging can be configured
|
||||||
|
# from it before we get here; falling back to ``Settings()`` keeps direct
|
||||||
|
# ``run()`` calls (tests, embedding) working.
|
||||||
|
if cfg is None:
|
||||||
|
cfg = Settings()
|
||||||
|
|
||||||
db_conn = db.connect(cfg.storage.sqlite_path)
|
db_conn = db.connect(cfg.storage.sqlite_path)
|
||||||
state = web.AppState(
|
state = web.AppState(
|
||||||
@@ -27,6 +33,10 @@ async def run() -> None:
|
|||||||
loop=asyncio.get_running_loop(),
|
loop=asyncio.get_running_loop(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tool_registry = build_default_tool_registry(
|
||||||
|
tavily_api_key=cfg.tools.tavily.api_key,
|
||||||
|
)
|
||||||
|
|
||||||
llm = LLMClient(
|
llm = LLMClient(
|
||||||
base_url=cfg.llm.base_url,
|
base_url=cfg.llm.base_url,
|
||||||
api_key=cfg.llm.api_key,
|
api_key=cfg.llm.api_key,
|
||||||
@@ -34,6 +44,7 @@ async def run() -> None:
|
|||||||
system_prompt=cfg.llm.system_prompt,
|
system_prompt=cfg.llm.system_prompt,
|
||||||
temperature=cfg.llm.temperature,
|
temperature=cfg.llm.temperature,
|
||||||
timeout=cfg.llm.request_timeout_seconds,
|
timeout=cfg.llm.request_timeout_seconds,
|
||||||
|
tools=tool_registry,
|
||||||
)
|
)
|
||||||
|
|
||||||
web_task: asyncio.Task | None = None
|
web_task: asyncio.Task | None = None
|
||||||
@@ -48,6 +59,7 @@ async def run() -> None:
|
|||||||
# in the local cache before the peer's first DM lands.
|
# in the local cache before the peer's first DM lands.
|
||||||
mc.auto_update_contacts = True
|
mc.auto_update_contacts = True
|
||||||
await mc.ensure_contacts()
|
await mc.ensure_contacts()
|
||||||
|
transport = MeshTransport(mc, ack_timeout=cfg.message.ack_timeout_seconds, send_retries=cfg.message.send_retries)
|
||||||
except BaseException:
|
except BaseException:
|
||||||
state.set_connected(False)
|
state.set_connected(False)
|
||||||
if web_task is not None:
|
if web_task is not None:
|
||||||
@@ -57,6 +69,7 @@ async def run() -> None:
|
|||||||
except (asyncio.CancelledError, Exception):
|
except (asyncio.CancelledError, Exception):
|
||||||
pass
|
pass
|
||||||
await llm.aclose()
|
await llm.aclose()
|
||||||
|
await tool_registry.aclose()
|
||||||
db_conn.close()
|
db_conn.close()
|
||||||
raise
|
raise
|
||||||
state.set_connected(True, node_name=_self_name(mc))
|
state.set_connected(True, node_name=_self_name(mc))
|
||||||
@@ -80,12 +93,12 @@ async def run() -> None:
|
|||||||
|
|
||||||
registry = build_default_registry()
|
registry = build_default_registry()
|
||||||
on_dm = build_dm_handler(
|
on_dm = build_dm_handler(
|
||||||
mc=mc,
|
|
||||||
db_conn=db_conn,
|
db_conn=db_conn,
|
||||||
llm=llm,
|
llm=llm,
|
||||||
registry=registry,
|
registry=registry,
|
||||||
state=state,
|
state=state,
|
||||||
cfg=cfg,
|
cfg=cfg,
|
||||||
|
transport=transport,
|
||||||
)
|
)
|
||||||
|
|
||||||
sub = mc.subscribe(EventType.CONTACT_MSG_RECV, on_dm)
|
sub = mc.subscribe(EventType.CONTACT_MSG_RECV, on_dm)
|
||||||
@@ -101,6 +114,7 @@ async def run() -> None:
|
|||||||
await mc.stop_auto_message_fetching()
|
await mc.stop_auto_message_fetching()
|
||||||
await mc.disconnect()
|
await mc.disconnect()
|
||||||
await llm.aclose()
|
await llm.aclose()
|
||||||
|
await tool_registry.aclose()
|
||||||
for task in (advert_task, web_task):
|
for task in (advert_task, web_task):
|
||||||
if task is not None:
|
if task is not None:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|||||||
@@ -37,6 +37,8 @@ class StorageCfg(BaseModel):
|
|||||||
|
|
||||||
class MessageCfg(BaseModel):
|
class MessageCfg(BaseModel):
|
||||||
max_bytes: int = Field(default=184, gt=0)
|
max_bytes: int = Field(default=184, gt=0)
|
||||||
|
ack_timeout_seconds: float = Field(default=30.0, gt=0)
|
||||||
|
send_retries: int = Field(default=1, ge=0)
|
||||||
|
|
||||||
|
|
||||||
class WebCfg(BaseModel):
|
class WebCfg(BaseModel):
|
||||||
@@ -47,6 +49,21 @@ class WebCfg(BaseModel):
|
|||||||
port: int = Field(default=8080, gt=0, lt=65536)
|
port: int = Field(default=8080, gt=0, lt=65536)
|
||||||
|
|
||||||
|
|
||||||
|
class LoggingCfg(BaseModel):
|
||||||
|
# Standard level names: DEBUG, INFO, WARNING, ERROR, CRITICAL. Case-insensitive.
|
||||||
|
level: str = "INFO"
|
||||||
|
|
||||||
|
|
||||||
|
class TavilyCfg(BaseModel):
|
||||||
|
# Sign up at https://tavily.com for a key. When empty, the web_search and
|
||||||
|
# fetch_url tools simply aren't registered (the rest of the bot is unaffected).
|
||||||
|
api_key: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class ToolsCfg(BaseModel):
|
||||||
|
tavily: TavilyCfg = TavilyCfg()
|
||||||
|
|
||||||
|
|
||||||
class AdvertiseCfg(BaseModel):
|
class AdvertiseCfg(BaseModel):
|
||||||
enabled: bool = True
|
enabled: bool = True
|
||||||
# Seconds between automatic adverts. 0 = manual only (button still works).
|
# Seconds between automatic adverts. 0 = manual only (button still works).
|
||||||
@@ -68,6 +85,8 @@ class Settings(BaseSettings):
|
|||||||
message: MessageCfg = MessageCfg()
|
message: MessageCfg = MessageCfg()
|
||||||
web: WebCfg = WebCfg()
|
web: WebCfg = WebCfg()
|
||||||
advertise: AdvertiseCfg = AdvertiseCfg()
|
advertise: AdvertiseCfg = AdvertiseCfg()
|
||||||
|
tools: ToolsCfg = ToolsCfg()
|
||||||
|
logging: LoggingCfg = LoggingCfg()
|
||||||
|
|
||||||
model_config = SettingsConfigDict(
|
model_config = SettingsConfigDict(
|
||||||
env_prefix="LORABOT_",
|
env_prefix="LORABOT_",
|
||||||
|
|||||||
@@ -9,13 +9,11 @@ from collections import defaultdict
|
|||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from meshcore import MeshCore
|
|
||||||
|
|
||||||
from . import db
|
from . import db
|
||||||
from .commands import CommandContext, CommandRegistry
|
from .commands import CommandContext, CommandRegistry
|
||||||
from .config import Settings
|
from .config import Settings
|
||||||
from .llm import LLMClient
|
from .llm import LLMClient
|
||||||
from .transport import resolve_contact, send_chunked
|
from .transport import MeshTransport
|
||||||
from .web import AppState
|
from .web import AppState
|
||||||
|
|
||||||
log = logging.getLogger("lorabot")
|
log = logging.getLogger("lorabot")
|
||||||
@@ -27,19 +25,18 @@ def _now_iso() -> str:
|
|||||||
|
|
||||||
def build_dm_handler(
|
def build_dm_handler(
|
||||||
*,
|
*,
|
||||||
mc: MeshCore,
|
|
||||||
db_conn: sqlite3.Connection,
|
db_conn: sqlite3.Connection,
|
||||||
llm: LLMClient,
|
llm: LLMClient,
|
||||||
registry: CommandRegistry,
|
registry: CommandRegistry,
|
||||||
state: AppState,
|
state: AppState,
|
||||||
cfg: Settings,
|
cfg: Settings,
|
||||||
|
transport: MeshTransport,
|
||||||
) -> Callable[[object], Awaitable[None]]:
|
) -> Callable[[object], Awaitable[None]]:
|
||||||
"""Return an ``on_dm(event)`` closure with all collaborators bound."""
|
"""Return an ``on_dm(event)`` closure with all collaborators bound."""
|
||||||
|
|
||||||
# One lock per sender so a burst of messages from the same peer is processed
|
# One lock per sender so a burst of messages from the same peer is processed
|
||||||
# serially while different peers stay independent.
|
# serially while different peers stay independent.
|
||||||
locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||||
contacts_lock = asyncio.Lock()
|
|
||||||
|
|
||||||
async def on_dm(event) -> None:
|
async def on_dm(event) -> None:
|
||||||
data = event.payload or {}
|
data = event.payload or {}
|
||||||
@@ -48,7 +45,7 @@ def build_dm_handler(
|
|||||||
if not prefix or not text:
|
if not prefix or not text:
|
||||||
return
|
return
|
||||||
|
|
||||||
contact = await resolve_contact(mc, prefix, contacts_lock)
|
contact = await transport.resolve_contact(prefix)
|
||||||
if contact is None:
|
if contact is None:
|
||||||
log.info("ignoring DM from unknown sender %s (no contact after refresh)", prefix)
|
log.info("ignoring DM from unknown sender %s (no contact after refresh)", prefix)
|
||||||
return
|
return
|
||||||
@@ -98,7 +95,7 @@ def build_dm_handler(
|
|||||||
log.exception("LLM call failed for %s", public_key[:12])
|
log.exception("LLM call failed for %s", public_key[:12])
|
||||||
return
|
return
|
||||||
|
|
||||||
delivered = await send_chunked(mc, contact, reply, cfg.message.max_bytes)
|
delivered = await transport.send_chunked(contact, reply, cfg.message.max_bytes)
|
||||||
if not delivered:
|
if not delivered:
|
||||||
# Nothing made it onto the radio; don't persist anything
|
# Nothing made it onto the radio; don't persist anything
|
||||||
return
|
return
|
||||||
|
|||||||
+96
-8
@@ -2,8 +2,19 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
|
from .tools import ToolRegistry
|
||||||
|
|
||||||
|
log = logging.getLogger("lorabot")
|
||||||
|
|
||||||
|
# Hard cap on assistant <-> tool round-trips per ``reply`` call. Stops a misbehaving
|
||||||
|
# model from looping on tool calls forever; in practice 1-2 iterations is normal.
|
||||||
|
_MAX_TOOL_ITERATIONS = 5
|
||||||
|
|
||||||
|
|
||||||
class LLMClient:
|
class LLMClient:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -15,11 +26,13 @@ class LLMClient:
|
|||||||
system_prompt: str,
|
system_prompt: str,
|
||||||
temperature: float,
|
temperature: float,
|
||||||
timeout: float,
|
timeout: float,
|
||||||
|
tools: ToolRegistry | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._client = AsyncOpenAI(base_url=base_url, api_key=api_key, timeout=timeout)
|
self._client = AsyncOpenAI(base_url=base_url, api_key=api_key, timeout=timeout)
|
||||||
self._model = model
|
self._model = model
|
||||||
self._system_prompt = system_prompt
|
self._system_prompt = system_prompt
|
||||||
self._temperature = temperature
|
self._temperature = temperature
|
||||||
|
self._tools = tools
|
||||||
|
|
||||||
async def reply(self, history: list[dict[str, str]], *, thinking: bool = False) -> str:
|
async def reply(self, history: list[dict[str, str]], *, thinking: bool = False) -> str:
|
||||||
"""Send the system prompt + ``history`` and return the assistant's text.
|
"""Send the system prompt + ``history`` and return the assistant's text.
|
||||||
@@ -27,18 +40,93 @@ class LLMClient:
|
|||||||
``thinking`` toggles the llama.cpp/server chat-template kwarg ``enable_thinking``,
|
``thinking`` toggles the llama.cpp/server chat-template kwarg ``enable_thinking``,
|
||||||
which controls whether the model prepends its hidden reasoning turn (Gemma-style).
|
which controls whether the model prepends its hidden reasoning turn (Gemma-style).
|
||||||
Passed via OpenAI SDK ``extra_body`` since it is not part of the standard schema.
|
Passed via OpenAI SDK ``extra_body`` since it is not part of the standard schema.
|
||||||
|
|
||||||
|
If a ``ToolRegistry`` is wired in, the model may emit ``tool_calls``; we
|
||||||
|
dispatch them, append the results, and re-call until the model produces
|
||||||
|
plain content (or the iteration cap kicks in).
|
||||||
"""
|
"""
|
||||||
messages: list[dict[str, str]] = [
|
# Local working copy: we'll mutate this with assistant/tool turns across
|
||||||
|
# the tool loop without touching the persisted DB history.
|
||||||
|
messages: list[dict[str, Any]] = [
|
||||||
{"role": "system", "content": self._system_prompt},
|
{"role": "system", "content": self._system_prompt},
|
||||||
*history,
|
*history,
|
||||||
]
|
]
|
||||||
resp = await self._client.chat.completions.create(
|
|
||||||
model=self._model,
|
# Build the tool spec list once. ``None`` (not ``[]``) means "don't send
|
||||||
messages=messages,
|
# the tools field at all" — some servers reject an empty list.
|
||||||
temperature=self._temperature,
|
tool_specs = self._tools.specs() if self._tools else None
|
||||||
extra_body={"chat_template_kwargs": {"enable_thinking": thinking}},
|
|
||||||
)
|
for iteration in range(_MAX_TOOL_ITERATIONS):
|
||||||
return (resp.choices[0].message.content or "").strip()
|
# Only attach ``tools`` when we actually have any registered.
|
||||||
|
kwargs: dict[str, Any] = {}
|
||||||
|
if tool_specs:
|
||||||
|
kwargs["tools"] = tool_specs
|
||||||
|
|
||||||
|
log.debug(
|
||||||
|
"LLM request: model=%s iter=%d msgs=%d tools=%d thinking=%s",
|
||||||
|
self._model,
|
||||||
|
iteration + 1,
|
||||||
|
len(messages),
|
||||||
|
len(tool_specs) if tool_specs else 0,
|
||||||
|
thinking,
|
||||||
|
)
|
||||||
|
resp = await self._client.chat.completions.create(
|
||||||
|
model=self._model,
|
||||||
|
messages=messages,
|
||||||
|
temperature=self._temperature,
|
||||||
|
extra_body={"chat_template_kwargs": {"enable_thinking": thinking}},
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
msg = resp.choices[0].message
|
||||||
|
|
||||||
|
# No tool calls → this is the final text answer; return it.
|
||||||
|
# Also bail if tools aren't even configured: nothing to dispatch to.
|
||||||
|
if not msg.tool_calls or self._tools is None:
|
||||||
|
return (msg.content or "").strip()
|
||||||
|
|
||||||
|
# Visible at INFO so a normal log lets you see when the model
|
||||||
|
# actually reached for a tool, and which one(s).
|
||||||
|
log.info(
|
||||||
|
"LLM requested tool calls (iter %d): %s",
|
||||||
|
iteration + 1,
|
||||||
|
[tc.function.name for tc in msg.tool_calls],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Echo the assistant's tool-calling turn back into ``messages`` so the
|
||||||
|
# next round-trip sees it. The OpenAI protocol requires this exact
|
||||||
|
# shape (role=assistant + tool_calls list) before role=tool entries.
|
||||||
|
messages.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": msg.content,
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": tc.id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": tc.function.name,
|
||||||
|
"arguments": tc.function.arguments,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for tc in msg.tool_calls
|
||||||
|
],
|
||||||
|
})
|
||||||
|
|
||||||
|
# Run each tool the model asked for and append its result. Errors are
|
||||||
|
# surfaced as plain strings inside the registry so the model can see
|
||||||
|
# them and recover (e.g. retry with a different argument).
|
||||||
|
for tc in msg.tool_calls:
|
||||||
|
result = await self._tools.dispatch(
|
||||||
|
tc.function.name, tc.function.arguments
|
||||||
|
)
|
||||||
|
messages.append({
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": tc.id,
|
||||||
|
"content": result,
|
||||||
|
})
|
||||||
|
# Loop continues: re-call the model with the tool results in context.
|
||||||
|
|
||||||
|
log.warning("LLM tool loop exceeded %d iterations", _MAX_TOOL_ITERATIONS)
|
||||||
|
return "(tool loop limit exceeded)"
|
||||||
|
|
||||||
async def aclose(self) -> None:
|
async def aclose(self) -> None:
|
||||||
await self._client.close()
|
await self._client.close()
|
||||||
|
|||||||
@@ -0,0 +1,29 @@
|
|||||||
|
"""LLM tool calling: registry + built-in tools."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from .base import Tool, ToolRegistry
|
||||||
|
from .weather import WeatherTool
|
||||||
|
from .web import FetchUrlTool, WebSearchTool
|
||||||
|
|
||||||
|
|
||||||
|
def build_default_registry(*, tavily_api_key: str = "") -> ToolRegistry:
|
||||||
|
"""Build the default registry. Tavily-backed tools are only registered
|
||||||
|
when an API key is configured — without one, the bot silently runs with
|
||||||
|
just the offline tools."""
|
||||||
|
reg = ToolRegistry()
|
||||||
|
reg.register(WeatherTool())
|
||||||
|
if tavily_api_key:
|
||||||
|
reg.register(WebSearchTool(api_key=tavily_api_key))
|
||||||
|
reg.register(FetchUrlTool(api_key=tavily_api_key))
|
||||||
|
return reg
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Tool",
|
||||||
|
"ToolRegistry",
|
||||||
|
"WeatherTool",
|
||||||
|
"WebSearchTool",
|
||||||
|
"FetchUrlTool",
|
||||||
|
"build_default_registry",
|
||||||
|
]
|
||||||
@@ -0,0 +1,90 @@
|
|||||||
|
"""Tool abstraction and registry for OpenAI-style function calling."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
log = logging.getLogger("lorabot")
|
||||||
|
|
||||||
|
|
||||||
|
def _truncate_for_log(s: str, max_len: int = 200) -> str:
|
||||||
|
"""Cap a log field so a giant tool argument doesn't blow up log lines."""
|
||||||
|
return s if len(s) <= max_len else s[:max_len] + "…"
|
||||||
|
|
||||||
|
|
||||||
|
class Tool(ABC):
|
||||||
|
"""Single LLM-callable function. Subclasses set ``name``/``description``/
|
||||||
|
``parameters`` (JSON Schema) and implement ``run``."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
parameters: dict[str, Any]
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def run(self, **kwargs: Any) -> str:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def aclose(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def spec(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": self.name,
|
||||||
|
"description": self.description,
|
||||||
|
"parameters": self.parameters,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ToolRegistry:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._tools: dict[str, Tool] = {}
|
||||||
|
|
||||||
|
def register(self, tool: Tool) -> None:
|
||||||
|
self._tools[tool.name] = tool
|
||||||
|
|
||||||
|
def specs(self) -> list[dict[str, Any]]:
|
||||||
|
return [t.spec() for t in self._tools.values()]
|
||||||
|
|
||||||
|
def __bool__(self) -> bool:
|
||||||
|
return bool(self._tools)
|
||||||
|
|
||||||
|
async def dispatch(self, name: str, arguments_json: str) -> str:
|
||||||
|
"""Run the named tool with JSON-encoded arguments. Errors are returned
|
||||||
|
as plain strings so the LLM can see and react to them."""
|
||||||
|
tool = self._tools.get(name)
|
||||||
|
if tool is None:
|
||||||
|
log.warning("tool call rejected: unknown tool %r", name)
|
||||||
|
return f"error: unknown tool {name!r}"
|
||||||
|
try:
|
||||||
|
args = json.loads(arguments_json) if arguments_json else {}
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
log.warning("tool %s: bad arguments JSON: %s", name, exc)
|
||||||
|
return f"error: bad arguments JSON ({exc})"
|
||||||
|
|
||||||
|
# Log entry + exit so failures are easy to attribute and we get a
|
||||||
|
# rough timing picture per tool. Args are truncated to keep logs sane.
|
||||||
|
log.info("tool %s called: %s", name, _truncate_for_log(arguments_json or "{}"))
|
||||||
|
started = time.monotonic()
|
||||||
|
try:
|
||||||
|
result = await tool.run(**args)
|
||||||
|
except Exception as exc:
|
||||||
|
elapsed_ms = (time.monotonic() - started) * 1000
|
||||||
|
log.exception("tool %s failed after %.0f ms", name, elapsed_ms)
|
||||||
|
return f"error: {exc}"
|
||||||
|
elapsed_ms = (time.monotonic() - started) * 1000
|
||||||
|
log.info("tool %s ok in %.0f ms (%d chars)", name, elapsed_ms, len(result))
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def aclose(self) -> None:
|
||||||
|
for tool in self._tools.values():
|
||||||
|
try:
|
||||||
|
await tool.aclose()
|
||||||
|
except Exception:
|
||||||
|
log.exception("aclose failed for tool %s", tool.name)
|
||||||
@@ -0,0 +1,201 @@
|
|||||||
|
"""Weather tool backed by Open-Meteo (no API key).
|
||||||
|
|
||||||
|
Returns the current observation, and optionally a multi-day daily forecast
|
||||||
|
(min/max temperature + weather code per day, in the location's local timezone).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import date
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from .base import Tool
|
||||||
|
|
||||||
|
GEOCODE_URL = "https://geocoding-api.open-meteo.com/v1/search"
|
||||||
|
FORECAST_URL = "https://api.open-meteo.com/v1/forecast"
|
||||||
|
|
||||||
|
# Hard upper bound on daily forecast length. Open-Meteo allows up to 16 days but
|
||||||
|
# bandwidth and model context favor keeping replies compact.
|
||||||
|
_MAX_FORECAST_DAYS = 7
|
||||||
|
|
||||||
|
# WMO weather interpretation codes — compact summaries to keep LoRa replies short.
|
||||||
|
_WMO: dict[int, str] = {
|
||||||
|
0: "clear",
|
||||||
|
1: "mostly clear", 2: "partly cloudy", 3: "overcast",
|
||||||
|
45: "fog", 48: "rime fog",
|
||||||
|
51: "light drizzle", 53: "drizzle", 55: "heavy drizzle",
|
||||||
|
56: "freezing drizzle", 57: "freezing drizzle",
|
||||||
|
61: "light rain", 63: "rain", 65: "heavy rain",
|
||||||
|
66: "freezing rain", 67: "freezing rain",
|
||||||
|
71: "light snow", 73: "snow", 75: "heavy snow", 77: "snow grains",
|
||||||
|
80: "rain showers", 81: "rain showers", 82: "violent showers",
|
||||||
|
85: "snow showers", 86: "heavy snow showers",
|
||||||
|
95: "thunderstorm", 96: "thunderstorm w/ hail", 99: "thunderstorm w/ hail",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class WeatherTool(Tool):
|
||||||
|
name = "get_weather"
|
||||||
|
description = (
|
||||||
|
"Current weather conditions for a place, optionally followed by a daily "
|
||||||
|
"forecast. Accepts a location name (e.g. 'Berlin', 'San Francisco, US') "
|
||||||
|
"or 'lat,lon' decimal coordinates. Set forecast_days to include the next "
|
||||||
|
"N days (today inclusive); 0 returns only the current observation."
|
||||||
|
)
|
||||||
|
parameters = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "City/place name, or 'lat,lon' decimal coordinates.",
|
||||||
|
},
|
||||||
|
"forecast_days": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": (
|
||||||
|
f"Number of daily-forecast days to include (0–{_MAX_FORECAST_DAYS}, "
|
||||||
|
"today inclusive). Defaults to 0 (current weather only)."
|
||||||
|
),
|
||||||
|
"minimum": 0,
|
||||||
|
"maximum": _MAX_FORECAST_DAYS,
|
||||||
|
"default": 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["location"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._session: aiohttp.ClientSession | None = None
|
||||||
|
|
||||||
|
async def aclose(self) -> None:
|
||||||
|
if self._session is not None and not self._session.closed:
|
||||||
|
await self._session.close()
|
||||||
|
|
||||||
|
async def _get_session(self) -> aiohttp.ClientSession:
|
||||||
|
if self._session is None or self._session.closed:
|
||||||
|
self._session = aiohttp.ClientSession(
|
||||||
|
timeout=aiohttp.ClientTimeout(total=10)
|
||||||
|
)
|
||||||
|
return self._session
|
||||||
|
|
||||||
|
async def run(self, location: str = "", forecast_days: int = 0) -> str:
|
||||||
|
location = (location or "").strip()
|
||||||
|
if not location:
|
||||||
|
return "error: location is required"
|
||||||
|
|
||||||
|
# Clamp defensively: model may ignore the schema bounds.
|
||||||
|
try:
|
||||||
|
forecast_days = max(0, min(int(forecast_days), _MAX_FORECAST_DAYS))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
forecast_days = 0
|
||||||
|
|
||||||
|
coords = _parse_latlon(location)
|
||||||
|
if coords is not None:
|
||||||
|
lat, lon = coords
|
||||||
|
label = f"{lat:.3f},{lon:.3f}"
|
||||||
|
else:
|
||||||
|
geo = await self._geocode(location)
|
||||||
|
if geo is None:
|
||||||
|
return f"error: could not find location {location!r}"
|
||||||
|
lat, lon, label = geo
|
||||||
|
|
||||||
|
return await self._forecast(lat, lon, label, forecast_days)
|
||||||
|
|
||||||
|
async def _geocode(self, name: str) -> tuple[float, float, str] | None:
|
||||||
|
session = await self._get_session()
|
||||||
|
async with session.get(
|
||||||
|
GEOCODE_URL,
|
||||||
|
params={"name": name, "count": "1", "format": "json"},
|
||||||
|
) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = await resp.json()
|
||||||
|
results = data.get("results") or []
|
||||||
|
if not results:
|
||||||
|
return None
|
||||||
|
r = results[0]
|
||||||
|
bits = [r.get("name") or name]
|
||||||
|
if cc := r.get("country_code"):
|
||||||
|
bits.append(str(cc))
|
||||||
|
return float(r["latitude"]), float(r["longitude"]), ", ".join(bits)
|
||||||
|
|
||||||
|
async def _forecast(
|
||||||
|
self, lat: float, lon: float, label: str, forecast_days: int
|
||||||
|
) -> str:
|
||||||
|
# Build the request. ``timezone=auto`` makes daily aggregations align to
|
||||||
|
# the location's local calendar day instead of UTC.
|
||||||
|
params: dict[str, str] = {
|
||||||
|
"latitude": str(lat),
|
||||||
|
"longitude": str(lon),
|
||||||
|
"current": "temperature_2m,weather_code,wind_speed_10m",
|
||||||
|
"timezone": "auto",
|
||||||
|
}
|
||||||
|
if forecast_days > 0:
|
||||||
|
params["daily"] = "temperature_2m_max,temperature_2m_min,weather_code"
|
||||||
|
params["forecast_days"] = str(forecast_days)
|
||||||
|
|
||||||
|
session = await self._get_session()
|
||||||
|
async with session.get(FORECAST_URL, params=params) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = await resp.json()
|
||||||
|
|
||||||
|
# Current observation line (always present).
|
||||||
|
cur = data.get("current") or {}
|
||||||
|
temp = cur.get("temperature_2m")
|
||||||
|
code = cur.get("weather_code")
|
||||||
|
wind = cur.get("wind_speed_10m")
|
||||||
|
cond = _WMO.get(int(code), f"code {code}") if code is not None else "?"
|
||||||
|
cur_bits = [f"{label} now"]
|
||||||
|
if temp is not None:
|
||||||
|
cur_bits.append(f"{temp}°C")
|
||||||
|
cur_bits.append(cond)
|
||||||
|
if wind is not None:
|
||||||
|
cur_bits.append(f"wind {wind} km/h")
|
||||||
|
lines = [", ".join(cur_bits)]
|
||||||
|
|
||||||
|
# Daily forecast block (one line per day): "Mon 8-17°C, partly cloudy".
|
||||||
|
# Open-Meteo returns parallel arrays under ``daily`` keyed by date.
|
||||||
|
if forecast_days > 0:
|
||||||
|
daily = data.get("daily") or {}
|
||||||
|
days = daily.get("time") or []
|
||||||
|
tmax = daily.get("temperature_2m_max") or []
|
||||||
|
tmin = daily.get("temperature_2m_min") or []
|
||||||
|
codes = daily.get("weather_code") or []
|
||||||
|
for i, day_str in enumerate(days):
|
||||||
|
day_label = _short_day(day_str)
|
||||||
|
hi = tmax[i] if i < len(tmax) else None
|
||||||
|
lo = tmin[i] if i < len(tmin) else None
|
||||||
|
day_code = codes[i] if i < len(codes) else None
|
||||||
|
day_cond = (
|
||||||
|
_WMO.get(int(day_code), f"code {day_code}")
|
||||||
|
if day_code is not None
|
||||||
|
else "?"
|
||||||
|
)
|
||||||
|
if lo is not None and hi is not None:
|
||||||
|
lines.append(f"{day_label} {lo}–{hi}°C, {day_cond}")
|
||||||
|
else:
|
||||||
|
lines.append(f"{day_label} {day_cond}")
|
||||||
|
|
||||||
|
return "; ".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def _short_day(iso_date: str) -> str:
|
||||||
|
"""Render an ISO date (``YYYY-MM-DD``) as a short weekday (``Mon``).
|
||||||
|
Falls back to the raw string if parsing fails."""
|
||||||
|
try:
|
||||||
|
return date.fromisoformat(iso_date).strftime("%a")
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return iso_date
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_latlon(text: str) -> tuple[float, float] | None:
|
||||||
|
if "," not in text:
|
||||||
|
return None
|
||||||
|
a, b = text.split(",", 1)
|
||||||
|
try:
|
||||||
|
lat = float(a.strip())
|
||||||
|
lon = float(b.strip())
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
if not (-90 <= lat <= 90 and -180 <= lon <= 180):
|
||||||
|
return None
|
||||||
|
return lat, lon
|
||||||
@@ -0,0 +1,193 @@
|
|||||||
|
"""Web search + page-extraction tools backed by Tavily (https://tavily.com).
|
||||||
|
|
||||||
|
Exposes two tools to the LLM:
|
||||||
|
|
||||||
|
- ``web_search`` — query Tavily for a list of {title, url, snippet} hits.
|
||||||
|
- ``fetch_url`` — pull the readable body text of a single page via Tavily's
|
||||||
|
``/extract`` endpoint (so we don't have to parse HTML ourselves).
|
||||||
|
|
||||||
|
Tavily handles all the messy bits — URL fetching, redirects, boilerplate
|
||||||
|
stripping, robots — so this module is just a thin HTTP wrapper.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from .base import Tool
|
||||||
|
|
||||||
|
log = logging.getLogger("lorabot")
|
||||||
|
|
||||||
|
_SEARCH_URL = "https://api.tavily.com/search"
|
||||||
|
_EXTRACT_URL = "https://api.tavily.com/extract"
|
||||||
|
|
||||||
|
# Sane defaults for the search tool. Upper bound keeps the tool message small
|
||||||
|
# enough that the LoRa-side LLM can still reason over it.
|
||||||
|
_DEFAULT_RESULTS = 5
|
||||||
|
_MAX_RESULTS = 10
|
||||||
|
|
||||||
|
# Cap extracted-page body size so a giant article doesn't blow the model's
|
||||||
|
# context. ~4 KB is plenty to summarize from for a 180-byte LoRa reply.
|
||||||
|
_MAX_EXTRACT_CHARS = 4000
|
||||||
|
|
||||||
|
|
||||||
|
class _TavilyHTTP:
|
||||||
|
"""Tiny per-tool HTTP helper. Each tool instance owns one ``aiohttp``
|
||||||
|
session; lazily created on first request, closed via ``aclose``."""
|
||||||
|
|
||||||
|
def __init__(self, api_key: str) -> None:
|
||||||
|
self._api_key = api_key
|
||||||
|
self._session: aiohttp.ClientSession | None = None
|
||||||
|
|
||||||
|
async def post(self, url: str, payload: dict) -> dict:
|
||||||
|
if self._session is None or self._session.closed:
|
||||||
|
self._session = aiohttp.ClientSession(
|
||||||
|
timeout=aiohttp.ClientTimeout(total=20)
|
||||||
|
)
|
||||||
|
headers = {"Authorization": f"Bearer {self._api_key}"}
|
||||||
|
log.debug("tavily POST %s payload-keys=%s", url, list(payload))
|
||||||
|
async with self._session.post(url, json=payload, headers=headers) as resp:
|
||||||
|
# On non-2xx, capture the response body before raising so the log
|
||||||
|
# tells us *why* (auth, quota, bad query) — raise_for_status alone
|
||||||
|
# would only show the status code.
|
||||||
|
if resp.status >= 400:
|
||||||
|
body = await resp.text()
|
||||||
|
log.warning(
|
||||||
|
"tavily POST %s -> %d: %s", url, resp.status, body[:200]
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return await resp.json()
|
||||||
|
|
||||||
|
async def aclose(self) -> None:
|
||||||
|
if self._session is not None and not self._session.closed:
|
||||||
|
await self._session.close()
|
||||||
|
|
||||||
|
|
||||||
|
class WebSearchTool(Tool):
|
||||||
|
name = "web_search"
|
||||||
|
description = (
|
||||||
|
"Search the web. Returns titles, URLs and SHORT snippets — useful for "
|
||||||
|
"headlines, finding sources and orientation. Snippets are a few "
|
||||||
|
"sentences at most and are often truncated mid-thought. "
|
||||||
|
"If the user asked for any of: specific numbers, prices, dates, "
|
||||||
|
"percentages, technical specs, exact quotes, step-by-step "
|
||||||
|
"instructions, or 'current'/'latest' data — you MUST also call "
|
||||||
|
"fetch_url on the most relevant result before answering."
|
||||||
|
)
|
||||||
|
parameters = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Search query.",
|
||||||
|
},
|
||||||
|
"max_results": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": (
|
||||||
|
f"How many results to return (1–{_MAX_RESULTS}). "
|
||||||
|
f"Default {_DEFAULT_RESULTS}."
|
||||||
|
),
|
||||||
|
"minimum": 1,
|
||||||
|
"maximum": _MAX_RESULTS,
|
||||||
|
"default": _DEFAULT_RESULTS,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["query"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, api_key: str) -> None:
|
||||||
|
self._http = _TavilyHTTP(api_key)
|
||||||
|
|
||||||
|
async def aclose(self) -> None:
|
||||||
|
await self._http.aclose()
|
||||||
|
|
||||||
|
async def run(self, query: str = "", max_results: int = _DEFAULT_RESULTS) -> str:
|
||||||
|
query = (query or "").strip()
|
||||||
|
if not query:
|
||||||
|
return "error: query is required"
|
||||||
|
|
||||||
|
# Defensive clamp — model may ignore the JSON-schema bounds.
|
||||||
|
try:
|
||||||
|
n = max(1, min(int(max_results), _MAX_RESULTS))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
n = _DEFAULT_RESULTS
|
||||||
|
|
||||||
|
data = await self._http.post(
|
||||||
|
_SEARCH_URL,
|
||||||
|
{
|
||||||
|
"query": query,
|
||||||
|
"max_results": n,
|
||||||
|
# "basic" is fast and cheap; "advanced" costs more credits but
|
||||||
|
# returns higher-quality snippets. Stick with basic for now.
|
||||||
|
"search_depth": "basic",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
results = data.get("results") or []
|
||||||
|
if not results:
|
||||||
|
return "no results"
|
||||||
|
|
||||||
|
# Format as a numbered list — readable for the model and compact enough
|
||||||
|
# to fit comfortably in context alongside the rest of the conversation.
|
||||||
|
lines: list[str] = []
|
||||||
|
for i, r in enumerate(results, 1):
|
||||||
|
title = (r.get("title") or "").strip() or "(no title)"
|
||||||
|
url = r.get("url") or ""
|
||||||
|
snippet = (r.get("content") or "").strip()
|
||||||
|
lines.append(f"{i}. {title} — {url}\n {snippet}")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
class FetchUrlTool(Tool):
|
||||||
|
name = "fetch_url"
|
||||||
|
description = (
|
||||||
|
"Fetch the full readable body text of a single web page. Call this "
|
||||||
|
"after web_search whenever the user asked for specifics — exact "
|
||||||
|
"numbers, prices, dates, quotes, technical details, or current data — "
|
||||||
|
"even if a snippet looks like it might answer. Snippets are short and "
|
||||||
|
"often misleading; the full page is the source of truth. Long pages "
|
||||||
|
"are truncated to keep context manageable."
|
||||||
|
)
|
||||||
|
parameters = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"url": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Absolute URL of the page to fetch.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["url"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, api_key: str) -> None:
|
||||||
|
self._http = _TavilyHTTP(api_key)
|
||||||
|
|
||||||
|
async def aclose(self) -> None:
|
||||||
|
await self._http.aclose()
|
||||||
|
|
||||||
|
async def run(self, url: str = "") -> str:
|
||||||
|
url = (url or "").strip()
|
||||||
|
if not url:
|
||||||
|
return "error: url is required"
|
||||||
|
|
||||||
|
data = await self._http.post(
|
||||||
|
_EXTRACT_URL,
|
||||||
|
{"urls": [url], "extract_depth": "basic"},
|
||||||
|
)
|
||||||
|
results = data.get("results") or []
|
||||||
|
if not results:
|
||||||
|
# Tavily reports unreachable / blocked pages here instead of in
|
||||||
|
# ``results``. Surface the reason so the model can react.
|
||||||
|
failed = data.get("failed_results") or []
|
||||||
|
if failed:
|
||||||
|
reason = (failed[0] or {}).get("error", "unknown")
|
||||||
|
return f"error: extract failed ({reason})"
|
||||||
|
return "error: no content extracted"
|
||||||
|
|
||||||
|
body = (results[0].get("raw_content") or "").strip()
|
||||||
|
if not body:
|
||||||
|
return "error: page returned empty content"
|
||||||
|
if len(body) > _MAX_EXTRACT_CHARS:
|
||||||
|
body = body[:_MAX_EXTRACT_CHARS] + " …[truncated]"
|
||||||
|
return body
|
||||||
+84
-42
@@ -1,9 +1,10 @@
|
|||||||
"""MeshCore-side helpers: contact resolution and chunked sending."""
|
"""MeshCore-side transport: contact resolution and reliable chunked sending."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
from meshcore import EventType, MeshCore
|
from meshcore import EventType, MeshCore
|
||||||
|
|
||||||
@@ -11,48 +12,89 @@ from .messages import split_to_bytes
|
|||||||
|
|
||||||
log = logging.getLogger("lorabot")
|
log = logging.getLogger("lorabot")
|
||||||
|
|
||||||
|
class MeshTransport:
|
||||||
|
"""Owns the MeshCore connection reference; handles contact resolution and
|
||||||
|
reliable chunked message delivery.
|
||||||
|
|
||||||
async def resolve_contact(mc: MeshCore, prefix: str, lock: asyncio.Lock):
|
Create one instance per MeshCore connection and pass it to build_dm_handler.
|
||||||
"""Look up a contact by pubkey prefix; re-pull contacts from the device on miss."""
|
"""
|
||||||
contact = mc.get_contact_by_key_prefix(prefix)
|
|
||||||
if contact is not None:
|
def __init__(self, mc: MeshCore, ack_timeout: float = 30.0, send_retries: int = 1) -> None:
|
||||||
return contact
|
self._mc = mc
|
||||||
async with lock:
|
self._ack_timeout = ack_timeout
|
||||||
contact = mc.get_contact_by_key_prefix(prefix)
|
self._send_retries = send_retries
|
||||||
|
self._contacts_lock = asyncio.Lock()
|
||||||
|
self._send_locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||||
|
|
||||||
|
async def resolve_contact(self, prefix: str):
|
||||||
|
"""Look up a contact by pubkey prefix; re-pull contacts from the device on miss."""
|
||||||
|
contact = self._mc.get_contact_by_key_prefix(prefix)
|
||||||
if contact is not None:
|
if contact is not None:
|
||||||
return contact
|
return contact
|
||||||
|
async with self._contacts_lock:
|
||||||
|
contact = self._mc.get_contact_by_key_prefix(prefix)
|
||||||
|
if contact is not None:
|
||||||
|
return contact
|
||||||
|
try:
|
||||||
|
await self._mc.commands.get_contacts()
|
||||||
|
except Exception:
|
||||||
|
log.exception("get_contacts refresh failed")
|
||||||
|
return None
|
||||||
|
return self._mc.get_contact_by_key_prefix(prefix)
|
||||||
|
|
||||||
|
async def send_chunked(
|
||||||
|
self,
|
||||||
|
contact,
|
||||||
|
text: str,
|
||||||
|
max_bytes: int,
|
||||||
|
max_chunks: int = 2,
|
||||||
|
) -> str:
|
||||||
|
"""Split ``text`` and send chunks in order, waiting for ACK before each next one.
|
||||||
|
|
||||||
|
Returns the concatenation of chunks that were successfully delivered.
|
||||||
|
"""
|
||||||
|
chunks = split_to_bytes(text, max_bytes, max_chunks=max_chunks)
|
||||||
|
pk = contact["public_key"]
|
||||||
|
pk_short = pk[:12]
|
||||||
|
|
||||||
|
planned_bytes = sum(len(c.encode("utf-8")) for c in chunks)
|
||||||
|
dropped = len(text.encode("utf-8")) - planned_bytes
|
||||||
|
if dropped > 0:
|
||||||
|
log.info("reply to %s split into %d chunks, dropped %d trailing bytes",
|
||||||
|
pk_short, len(chunks), dropped)
|
||||||
|
|
||||||
|
sent: list[str] = []
|
||||||
|
async with self._send_locks[pk]:
|
||||||
|
for i, chunk in enumerate(chunks, 1):
|
||||||
|
log.info("reply to %s (%d/%d, %d bytes): %s",
|
||||||
|
pk_short, i, len(chunks), len(chunk.encode("utf-8")), chunk)
|
||||||
|
if not await self._send_chunk(contact, chunk):
|
||||||
|
break
|
||||||
|
sent.append(chunk)
|
||||||
|
return "".join(sent)
|
||||||
|
|
||||||
|
async def _send_chunk(self, contact, chunk: str) -> bool:
|
||||||
|
"""Send one chunk, retrying once on failure."""
|
||||||
|
for attempt in range(self._send_retries + 1):
|
||||||
|
if await self._attempt_send(contact, chunk):
|
||||||
|
return True
|
||||||
|
if attempt < self._send_retries:
|
||||||
|
log.info("retrying chunk for %s (attempt %d/%d)",
|
||||||
|
contact["public_key"][:12], attempt + 2, self._send_retries + 1)
|
||||||
|
log.error("chunk delivery failed for %s after %d attempts",
|
||||||
|
contact["public_key"][:12], self._send_retries + 1)
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _attempt_send(self, contact, chunk: str) -> bool:
|
||||||
|
"""One send attempt. Returns True on ACK, False on device error or timeout."""
|
||||||
try:
|
try:
|
||||||
await mc.commands.get_contacts()
|
result = await asyncio.wait_for(
|
||||||
except Exception:
|
self._mc.commands.send_msg(contact, chunk), timeout=self._ack_timeout
|
||||||
log.exception("get_contacts refresh failed")
|
)
|
||||||
return None
|
if result.type != EventType.ERROR:
|
||||||
return mc.get_contact_by_key_prefix(prefix)
|
return True
|
||||||
|
log.warning("send_msg error for %s: %s", contact["public_key"][:12], result.payload)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
async def send_chunked(mc: MeshCore, contact, text: str, max_bytes: int, max_chunks: int = 2) -> str:
|
log.warning("ACK timeout for %s after %.1fs",
|
||||||
"""Split ``text`` into byte-budgeted chunks and send them in order.
|
contact["public_key"][:12], self._ack_timeout)
|
||||||
|
return False
|
||||||
Stops sending on the first transport error. Returns the concatenation of the
|
|
||||||
chunks that were actually accepted by the device, so the caller records the
|
|
||||||
truth — not what we hoped to send.
|
|
||||||
"""
|
|
||||||
chunks = split_to_bytes(text, max_bytes, max_chunks=max_chunks)
|
|
||||||
pk_short = contact["public_key"][:12]
|
|
||||||
planned_bytes = sum(len(c.encode("utf-8")) for c in chunks)
|
|
||||||
|
|
||||||
dropped = len(text.encode("utf-8")) - planned_bytes
|
|
||||||
if dropped > 0:
|
|
||||||
log.info("reply to %s split into %d chunks, dropped %d trailing bytes",
|
|
||||||
pk_short, len(chunks), dropped)
|
|
||||||
|
|
||||||
sent: list[str] = []
|
|
||||||
for i, chunk in enumerate(chunks, 1):
|
|
||||||
log.info("reply to %s (%d/%d, %d bytes): %s",
|
|
||||||
pk_short, i, len(chunks), len(chunk.encode("utf-8")), chunk)
|
|
||||||
result = await mc.commands.send_msg(contact, chunk)
|
|
||||||
if result.type == EventType.ERROR:
|
|
||||||
log.error("send_msg failed for %s chunk %d/%d: %s",
|
|
||||||
pk_short, i, len(chunks), result.payload)
|
|
||||||
break
|
|
||||||
sent.append(chunk)
|
|
||||||
return "".join(sent)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user