Merge remote-tracking branch 'origin/main'
This commit is contained in:
@@ -155,6 +155,7 @@ class AgentLoop:
|
||||
context_window_tokens=context_window_tokens,
|
||||
build_messages=self.context.build_messages,
|
||||
get_tool_definitions=self.tools.get_definitions,
|
||||
max_completion_tokens=provider.generation.max_tokens,
|
||||
)
|
||||
self._register_default_tools()
|
||||
|
||||
|
||||
@@ -228,6 +228,8 @@ class MemoryConsolidator:
|
||||
|
||||
_MAX_CONSOLIDATION_ROUNDS = 5
|
||||
|
||||
_SAFETY_BUFFER = 1024 # extra headroom for tokenizer estimation drift
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace: Path,
|
||||
@@ -237,12 +239,14 @@ class MemoryConsolidator:
|
||||
context_window_tokens: int,
|
||||
build_messages: Callable[..., list[dict[str, Any]]],
|
||||
get_tool_definitions: Callable[[], list[dict[str, Any]]],
|
||||
max_completion_tokens: int = 4096,
|
||||
):
|
||||
self.workspace = workspace
|
||||
self.provider = provider
|
||||
self.model = model
|
||||
self.sessions = sessions
|
||||
self.context_window_tokens = context_window_tokens
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
self._build_messages = build_messages
|
||||
self._get_tool_definitions = get_tool_definitions
|
||||
self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
|
||||
@@ -356,17 +360,22 @@ class MemoryConsolidator:
|
||||
return await self._archive_messages_locked(session, snapshot)
|
||||
|
||||
async def maybe_consolidate_by_tokens(self, session: Session) -> None:
|
||||
"""Loop: archive old messages until prompt fits within half the context window."""
|
||||
"""Loop: archive old messages until prompt fits within safe budget.
|
||||
|
||||
The budget reserves space for completion tokens and a safety buffer
|
||||
so the LLM request never exceeds the context window.
|
||||
"""
|
||||
if not session.messages or self.context_window_tokens <= 0:
|
||||
return
|
||||
|
||||
lock = self.get_lock(session.key)
|
||||
async with lock:
|
||||
target = self.context_window_tokens // 2
|
||||
budget = self.context_window_tokens - self.max_completion_tokens - self._SAFETY_BUFFER
|
||||
target = budget // 2
|
||||
estimated, source = self.estimate_session_prompt_tokens(session)
|
||||
if estimated <= 0:
|
||||
return
|
||||
if estimated < self.context_window_tokens:
|
||||
if estimated < budget:
|
||||
logger.debug(
|
||||
"Token consolidation idle {}: {}/{} via {}",
|
||||
session.key,
|
||||
|
||||
@@ -785,6 +785,7 @@ def agent(
|
||||
on_stream_end=renderer.on_end,
|
||||
)
|
||||
if not renderer.streamed:
|
||||
await renderer.close()
|
||||
_print_agent_response(
|
||||
response.content if response else "",
|
||||
render_markdown=markdown,
|
||||
@@ -906,9 +907,13 @@ def agent(
|
||||
if turn_response:
|
||||
content, meta = turn_response[0]
|
||||
if content and not meta.get("_streamed"):
|
||||
if renderer:
|
||||
await renderer.close()
|
||||
_print_agent_response(
|
||||
content, render_markdown=markdown, metadata=meta,
|
||||
)
|
||||
elif renderer and not renderer.streamed:
|
||||
await renderer.close()
|
||||
except KeyboardInterrupt:
|
||||
_restore_terminal()
|
||||
console.print("\nGoodbye!")
|
||||
|
||||
@@ -119,3 +119,10 @@ class StreamRenderer:
|
||||
self._start_spinner()
|
||||
else:
|
||||
_make_console().print()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Stop spinner/live without rendering a final streamed round."""
|
||||
if self._live:
|
||||
self._live.stop()
|
||||
self._live = None
|
||||
self._stop_spinner()
|
||||
|
||||
@@ -10,8 +10,10 @@ from nanobot.providers.base import LLMResponse
|
||||
|
||||
|
||||
def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -> AgentLoop:
|
||||
from nanobot.providers.base import GenerationSettings
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.generation = GenerationSettings(max_tokens=0)
|
||||
provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter")
|
||||
_response = LLMResponse(content="ok", tool_calls=[])
|
||||
provider.chat_with_retry = AsyncMock(return_value=_response)
|
||||
@@ -25,6 +27,7 @@ def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -
|
||||
context_window_tokens=context_window_tokens,
|
||||
)
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
loop.memory_consolidator._SAFETY_BUFFER = 0
|
||||
return loop
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user