Merge remote-tracking branch 'origin/main'
@@ -155,6 +155,7 @@ class AgentLoop:
             context_window_tokens=context_window_tokens,
             build_messages=self.context.build_messages,
             get_tool_definitions=self.tools.get_definitions,
+            max_completion_tokens=provider.generation.max_tokens,
         )
         self._register_default_tools()

@@ -228,6 +228,8 @@ class MemoryConsolidator:

     _MAX_CONSOLIDATION_ROUNDS = 5

+    _SAFETY_BUFFER = 1024  # extra headroom for tokenizer estimation drift
+
     def __init__(
         self,
         workspace: Path,
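Why the drift headroom matters, with illustrative numbers (not values from this repo): token estimators can disagree with the provider's real tokenizer by a few percent, and the dangerous direction is undercounting.

# Sketch: an estimator that undercounts the true prompt by 3%.
context_window = 32_768
true_prompt_tokens = 28_000
estimated_tokens = int(true_prompt_tokens * 0.97)  # 27_160 after drift

# A check based on the estimate alone could pass while the real request
# overflows; 1024 tokens of headroom absorbs drift of this magnitude.
drift = true_prompt_tokens - estimated_tokens
assert drift <= 1024  # 840 <= 1024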
@@ -237,12 +239,14 @@ class MemoryConsolidator:
         context_window_tokens: int,
         build_messages: Callable[..., list[dict[str, Any]]],
         get_tool_definitions: Callable[[], list[dict[str, Any]]],
+        max_completion_tokens: int = 4096,
     ):
         self.workspace = workspace
         self.provider = provider
         self.model = model
         self.sessions = sessions
         self.context_window_tokens = context_window_tokens
+        self.max_completion_tokens = max_completion_tokens
         self._build_messages = build_messages
         self._get_tool_definitions = get_tool_definitions
         self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
@@ -356,17 +360,22 @@ class MemoryConsolidator:
             return await self._archive_messages_locked(session, snapshot)

     async def maybe_consolidate_by_tokens(self, session: Session) -> None:
-        """Loop: archive old messages until prompt fits within half the context window."""
+        """Loop: archive old messages until prompt fits within safe budget.
+
+        The budget reserves space for completion tokens and a safety buffer
+        so the LLM request never exceeds the context window.
+        """
         if not session.messages or self.context_window_tokens <= 0:
             return

         lock = self.get_lock(session.key)
         async with lock:
-            target = self.context_window_tokens // 2
+            budget = self.context_window_tokens - self.max_completion_tokens - self._SAFETY_BUFFER
+            target = budget // 2
             estimated, source = self.estimate_session_prompt_tokens(session)
             if estimated <= 0:
                 return
-            if estimated < self.context_window_tokens:
+            if estimated < budget:
                 logger.debug(
                     "Token consolidation idle {}: {}/{} via {}",
                     session.key,
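A quick worked example of the new budget arithmetic, with illustrative numbers rather than values from this repo (in AgentLoop the completion reservation comes from provider.generation.max_tokens):

context_window_tokens = 32_768
max_completion_tokens = 4_096
SAFETY_BUFFER = 1_024

budget = context_window_tokens - max_completion_tokens - SAFETY_BUFFER  # 27_648
target = budget // 2                                                    # 13_824

# Consolidation now idles while the estimate stays under 27_648 instead of
# the full 32_768, and per the docstring archiving continues until the
# prompt estimate fits under the 13_824 midpoint.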
@@ -785,6 +785,7 @@ def agent(
             on_stream_end=renderer.on_end,
         )
         if not renderer.streamed:
+            await renderer.close()
             _print_agent_response(
                 response.content if response else "",
                 render_markdown=markdown,
@@ -906,9 +907,13 @@ def agent(
             if turn_response:
                 content, meta = turn_response[0]
                 if content and not meta.get("_streamed"):
+                    if renderer:
+                        await renderer.close()
                     _print_agent_response(
                         content, render_markdown=markdown, metadata=meta,
                     )
+                elif renderer and not renderer.streamed:
+                    await renderer.close()
     except KeyboardInterrupt:
         _restore_terminal()
         console.print("\nGoodbye!")
@@ -119,3 +119,10 @@ class StreamRenderer:
             self._start_spinner()
         else:
             _make_console().print()
+
+    async def close(self) -> None:
+        """Stop spinner/live without rendering a final streamed round."""
+        if self._live:
+            self._live.stop()
+            self._live = None
+        self._stop_spinner()
@@ -10,8 +10,10 @@ from nanobot.providers.base import LLMResponse


 def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -> AgentLoop:
+    from nanobot.providers.base import GenerationSettings
     provider = MagicMock()
     provider.get_default_model.return_value = "test-model"
+    provider.generation = GenerationSettings(max_tokens=0)
     provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter")
     _response = LLMResponse(content="ok", tool_calls=[])
     provider.chat_with_retry = AsyncMock(return_value=_response)
@@ -25,6 +27,7 @@ def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -
         context_window_tokens=context_window_tokens,
     )
     loop.tools.get_definitions = MagicMock(return_value=[])
+    loop.memory_consolidator._SAFETY_BUFFER = 0
     return loop

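One note on the test wiring: with GenerationSettings(max_tokens=0) and _SAFETY_BUFFER zeroed in _make_loop, the budget collapses back to the raw window, so threshold tests written before this change keep their meaning:

# Hypothetical test value; the identity holds for any window size.
context_window_tokens = 1_000
budget = context_window_tokens - 0 - 0  # max_tokens=0, _SAFETY_BUFFER=0
assert budget == context_window_tokens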