fix(cli): pause spinner cleanly before printing progress output
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
"""CLI commands for nanobot."""
|
"""CLI commands for nanobot."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager, nullcontext
|
||||||
import os
|
import os
|
||||||
import select
|
import select
|
||||||
import signal
|
import signal
|
||||||
@@ -170,6 +170,51 @@ async def _print_interactive_response(response: str, render_markdown: bool) -> N
|
|||||||
await run_in_terminal(_write)
|
await run_in_terminal(_write)
|
||||||
|
|
||||||
|
|
||||||
|
class _ThinkingSpinner:
|
||||||
|
"""Spinner wrapper with pause support for clean progress output."""
|
||||||
|
|
||||||
|
def __init__(self, enabled: bool):
|
||||||
|
self._spinner = console.status(
|
||||||
|
"[dim]nanobot is thinking...[/dim]", spinner="dots"
|
||||||
|
) if enabled else None
|
||||||
|
self._active = False
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
if self._spinner:
|
||||||
|
self._spinner.start()
|
||||||
|
self._active = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *exc):
|
||||||
|
self._active = False
|
||||||
|
if self._spinner:
|
||||||
|
self._spinner.stop()
|
||||||
|
return False
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def pause(self):
|
||||||
|
"""Temporarily stop spinner while printing progress."""
|
||||||
|
if self._spinner and self._active:
|
||||||
|
self._spinner.stop()
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
if self._spinner and self._active:
|
||||||
|
self._spinner.start()
|
||||||
|
|
||||||
|
|
||||||
|
def _print_cli_progress_line(text: str, thinking: _ThinkingSpinner | None) -> None:
|
||||||
|
"""Print a CLI progress line, pausing the spinner if needed."""
|
||||||
|
with thinking.pause() if thinking else nullcontext():
|
||||||
|
console.print(f" [dim]↳ {text}[/dim]")
|
||||||
|
|
||||||
|
|
||||||
|
async def _print_interactive_progress_line(text: str, thinking: _ThinkingSpinner | None) -> None:
|
||||||
|
"""Print an interactive progress line, pausing the spinner if needed."""
|
||||||
|
with thinking.pause() if thinking else nullcontext():
|
||||||
|
await _print_interactive_line(text)
|
||||||
|
|
||||||
|
|
||||||
def _is_exit_command(command: str) -> bool:
|
def _is_exit_command(command: str) -> bool:
|
||||||
"""Return True when input should end interactive chat."""
|
"""Return True when input should end interactive chat."""
|
||||||
return command.lower() in EXIT_COMMANDS
|
return command.lower() in EXIT_COMMANDS
|
||||||
@@ -635,39 +680,6 @@ def agent(
|
|||||||
channels_config=config.channels,
|
channels_config=config.channels,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Show spinner when logs are off (no output to miss); skip when logs are on
|
|
||||||
class _ThinkingSpinner:
|
|
||||||
"""Context manager that owns spinner lifecycle with pause support."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._spinner = None if logs else console.status(
|
|
||||||
"[dim]nanobot is thinking...[/dim]", spinner="dots"
|
|
||||||
)
|
|
||||||
self._active = False
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
if self._spinner:
|
|
||||||
self._spinner.start()
|
|
||||||
self._active = True
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, *exc):
|
|
||||||
self._active = False
|
|
||||||
if self._spinner:
|
|
||||||
self._spinner.stop()
|
|
||||||
return False
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def pause(self):
|
|
||||||
"""Temporarily stop spinner for clean console output."""
|
|
||||||
if self._spinner and self._active:
|
|
||||||
self._spinner.stop()
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
if self._spinner and self._active:
|
|
||||||
self._spinner.start()
|
|
||||||
|
|
||||||
# Shared reference for progress callbacks
|
# Shared reference for progress callbacks
|
||||||
_thinking: _ThinkingSpinner | None = None
|
_thinking: _ThinkingSpinner | None = None
|
||||||
|
|
||||||
@@ -677,17 +689,13 @@ def agent(
|
|||||||
return
|
return
|
||||||
if ch and not tool_hint and not ch.send_progress:
|
if ch and not tool_hint and not ch.send_progress:
|
||||||
return
|
return
|
||||||
if _thinking:
|
_print_cli_progress_line(content, _thinking)
|
||||||
with _thinking.pause():
|
|
||||||
console.print(f" [dim]↳ {content}[/dim]")
|
|
||||||
else:
|
|
||||||
console.print(f" [dim]↳ {content}[/dim]")
|
|
||||||
|
|
||||||
if message:
|
if message:
|
||||||
# Single message mode — direct call, no bus needed
|
# Single message mode — direct call, no bus needed
|
||||||
async def run_once():
|
async def run_once():
|
||||||
nonlocal _thinking
|
nonlocal _thinking
|
||||||
_thinking = _ThinkingSpinner()
|
_thinking = _ThinkingSpinner(enabled=not logs)
|
||||||
with _thinking:
|
with _thinking:
|
||||||
response = await agent_loop.process_direct(message, session_id, on_progress=_cli_progress)
|
response = await agent_loop.process_direct(message, session_id, on_progress=_cli_progress)
|
||||||
_thinking = None
|
_thinking = None
|
||||||
@@ -739,11 +747,8 @@ def agent(
|
|||||||
pass
|
pass
|
||||||
elif ch and not is_tool_hint and not ch.send_progress:
|
elif ch and not is_tool_hint and not ch.send_progress:
|
||||||
pass
|
pass
|
||||||
elif _thinking:
|
|
||||||
with _thinking.pause():
|
|
||||||
await _print_interactive_line(msg.content)
|
|
||||||
else:
|
else:
|
||||||
await _print_interactive_line(msg.content)
|
await _print_interactive_progress_line(msg.content, _thinking)
|
||||||
|
|
||||||
elif not turn_done.is_set():
|
elif not turn_done.is_set():
|
||||||
if msg.content:
|
if msg.content:
|
||||||
@@ -784,7 +789,7 @@ def agent(
|
|||||||
))
|
))
|
||||||
|
|
||||||
nonlocal _thinking
|
nonlocal _thinking
|
||||||
_thinking = _ThinkingSpinner()
|
_thinking = _ThinkingSpinner(enabled=not logs)
|
||||||
with _thinking:
|
with _thinking:
|
||||||
await turn_done.wait()
|
await turn_done.wait()
|
||||||
_thinking = None
|
_thinking = None
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from prompt_toolkit.formatted_text import HTML
|
from prompt_toolkit.formatted_text import HTML
|
||||||
@@ -57,3 +57,57 @@ def test_init_prompt_session_creates_session():
|
|||||||
_, kwargs = MockSession.call_args
|
_, kwargs = MockSession.call_args
|
||||||
assert kwargs["multiline"] is False
|
assert kwargs["multiline"] is False
|
||||||
assert kwargs["enable_open_in_editor"] is False
|
assert kwargs["enable_open_in_editor"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_thinking_spinner_pause_stops_and_restarts():
|
||||||
|
"""Pause should stop the active spinner and restart it afterward."""
|
||||||
|
spinner = MagicMock()
|
||||||
|
|
||||||
|
with patch.object(commands.console, "status", return_value=spinner):
|
||||||
|
thinking = commands._ThinkingSpinner(enabled=True)
|
||||||
|
with thinking:
|
||||||
|
with thinking.pause():
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert spinner.method_calls == [
|
||||||
|
call.start(),
|
||||||
|
call.stop(),
|
||||||
|
call.start(),
|
||||||
|
call.stop(),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_print_cli_progress_line_pauses_spinner_before_printing():
|
||||||
|
"""CLI progress output should pause spinner to avoid garbled lines."""
|
||||||
|
order: list[str] = []
|
||||||
|
spinner = MagicMock()
|
||||||
|
spinner.start.side_effect = lambda: order.append("start")
|
||||||
|
spinner.stop.side_effect = lambda: order.append("stop")
|
||||||
|
|
||||||
|
with patch.object(commands.console, "status", return_value=spinner), \
|
||||||
|
patch.object(commands.console, "print", side_effect=lambda *_args, **_kwargs: order.append("print")):
|
||||||
|
thinking = commands._ThinkingSpinner(enabled=True)
|
||||||
|
with thinking:
|
||||||
|
commands._print_cli_progress_line("tool running", thinking)
|
||||||
|
|
||||||
|
assert order == ["start", "stop", "print", "start", "stop"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_print_interactive_progress_line_pauses_spinner_before_printing():
|
||||||
|
"""Interactive progress output should also pause spinner cleanly."""
|
||||||
|
order: list[str] = []
|
||||||
|
spinner = MagicMock()
|
||||||
|
spinner.start.side_effect = lambda: order.append("start")
|
||||||
|
spinner.stop.side_effect = lambda: order.append("stop")
|
||||||
|
|
||||||
|
async def fake_print(_text: str) -> None:
|
||||||
|
order.append("print")
|
||||||
|
|
||||||
|
with patch.object(commands.console, "status", return_value=spinner), \
|
||||||
|
patch("nanobot.cli.commands._print_interactive_line", side_effect=fake_print):
|
||||||
|
thinking = commands._ThinkingSpinner(enabled=True)
|
||||||
|
with thinking:
|
||||||
|
await commands._print_interactive_progress_line("tool running", thinking)
|
||||||
|
|
||||||
|
assert order == ["start", "stop", "print", "start", "stop"]
|
||||||
|
|||||||
Reference in New Issue
Block a user