Merge remote-tracking branch 'origin/main'
Some checks failed
Test Suite / test (3.11) (push) Failing after 1m14s
Test Suite / test (3.12) (push) Failing after 1m7s
Test Suite / test (3.13) (push) Failing after 1m26s

This commit is contained in:
Hua
2026-03-23 12:56:03 +08:00
5 changed files with 28 additions and 3 deletions

View File

@@ -155,6 +155,7 @@ class AgentLoop:
context_window_tokens=context_window_tokens, context_window_tokens=context_window_tokens,
build_messages=self.context.build_messages, build_messages=self.context.build_messages,
get_tool_definitions=self.tools.get_definitions, get_tool_definitions=self.tools.get_definitions,
max_completion_tokens=provider.generation.max_tokens,
) )
self._register_default_tools() self._register_default_tools()

View File

@@ -228,6 +228,8 @@ class MemoryConsolidator:
_MAX_CONSOLIDATION_ROUNDS = 5 _MAX_CONSOLIDATION_ROUNDS = 5
_SAFETY_BUFFER = 1024 # extra headroom for tokenizer estimation drift
def __init__( def __init__(
self, self,
workspace: Path, workspace: Path,
@@ -237,12 +239,14 @@ class MemoryConsolidator:
context_window_tokens: int, context_window_tokens: int,
build_messages: Callable[..., list[dict[str, Any]]], build_messages: Callable[..., list[dict[str, Any]]],
get_tool_definitions: Callable[[], list[dict[str, Any]]], get_tool_definitions: Callable[[], list[dict[str, Any]]],
max_completion_tokens: int = 4096,
): ):
self.workspace = workspace self.workspace = workspace
self.provider = provider self.provider = provider
self.model = model self.model = model
self.sessions = sessions self.sessions = sessions
self.context_window_tokens = context_window_tokens self.context_window_tokens = context_window_tokens
self.max_completion_tokens = max_completion_tokens
self._build_messages = build_messages self._build_messages = build_messages
self._get_tool_definitions = get_tool_definitions self._get_tool_definitions = get_tool_definitions
self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary() self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
@@ -356,17 +360,22 @@ class MemoryConsolidator:
return await self._archive_messages_locked(session, snapshot) return await self._archive_messages_locked(session, snapshot)
async def maybe_consolidate_by_tokens(self, session: Session) -> None: async def maybe_consolidate_by_tokens(self, session: Session) -> None:
"""Loop: archive old messages until prompt fits within half the context window.""" """Loop: archive old messages until prompt fits within safe budget.
The budget reserves space for completion tokens and a safety buffer
so the LLM request never exceeds the context window.
"""
if not session.messages or self.context_window_tokens <= 0: if not session.messages or self.context_window_tokens <= 0:
return return
lock = self.get_lock(session.key) lock = self.get_lock(session.key)
async with lock: async with lock:
target = self.context_window_tokens // 2 budget = self.context_window_tokens - self.max_completion_tokens - self._SAFETY_BUFFER
target = budget // 2
estimated, source = self.estimate_session_prompt_tokens(session) estimated, source = self.estimate_session_prompt_tokens(session)
if estimated <= 0: if estimated <= 0:
return return
if estimated < self.context_window_tokens: if estimated < budget:
logger.debug( logger.debug(
"Token consolidation idle {}: {}/{} via {}", "Token consolidation idle {}: {}/{} via {}",
session.key, session.key,

View File

@@ -785,6 +785,7 @@ def agent(
on_stream_end=renderer.on_end, on_stream_end=renderer.on_end,
) )
if not renderer.streamed: if not renderer.streamed:
await renderer.close()
_print_agent_response( _print_agent_response(
response.content if response else "", response.content if response else "",
render_markdown=markdown, render_markdown=markdown,
@@ -906,9 +907,13 @@ def agent(
if turn_response: if turn_response:
content, meta = turn_response[0] content, meta = turn_response[0]
if content and not meta.get("_streamed"): if content and not meta.get("_streamed"):
if renderer:
await renderer.close()
_print_agent_response( _print_agent_response(
content, render_markdown=markdown, metadata=meta, content, render_markdown=markdown, metadata=meta,
) )
elif renderer and not renderer.streamed:
await renderer.close()
except KeyboardInterrupt: except KeyboardInterrupt:
_restore_terminal() _restore_terminal()
console.print("\nGoodbye!") console.print("\nGoodbye!")

View File

@@ -119,3 +119,10 @@ class StreamRenderer:
self._start_spinner() self._start_spinner()
else: else:
_make_console().print() _make_console().print()
async def close(self) -> None:
"""Stop spinner/live without rendering a final streamed round."""
if self._live:
self._live.stop()
self._live = None
self._stop_spinner()

View File

@@ -10,8 +10,10 @@ from nanobot.providers.base import LLMResponse
def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -> AgentLoop: def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -> AgentLoop:
from nanobot.providers.base import GenerationSettings
provider = MagicMock() provider = MagicMock()
provider.get_default_model.return_value = "test-model" provider.get_default_model.return_value = "test-model"
provider.generation = GenerationSettings(max_tokens=0)
provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter") provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter")
_response = LLMResponse(content="ok", tool_calls=[]) _response = LLMResponse(content="ok", tool_calls=[])
provider.chat_with_retry = AsyncMock(return_value=_response) provider.chat_with_retry = AsyncMock(return_value=_response)
@@ -25,6 +27,7 @@ def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -
context_window_tokens=context_window_tokens, context_window_tokens=context_window_tokens,
) )
loop.tools.get_definitions = MagicMock(return_value=[]) loop.tools.get_definitions = MagicMock(return_value=[])
loop.memory_consolidator._SAFETY_BUFFER = 0
return loop return loop