Update README.MD and add nano-claude-code v3.0 + original-source-code/src

- README.MD: add original-source-code and nano-claude-code sections, update
  overview table (4 subprojects), add v3.0 news entry, expand comparison table
  with memory/multi-agent/skills dimensions
- nano-claude-code v3.0: multi-agent package (multi_agent/), memory package
  (memory/), skill package (skill/) with built-in /commit and /review skills,
  context compression (compaction.py), tool registry plugin system, diff view,
  17 slash commands, 18 built-in tools, 101 tests (~5000 lines total)
- original-source-code/src: add raw TypeScript source tree (1884 files)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
chauncygu
2026-04-03 10:26:29 -07:00
parent 3de4c595ea
commit 1d4ffa964d
1942 changed files with 521644 additions and 112 deletions

View File

View File

@@ -0,0 +1,187 @@
"""Tests for compaction.py — token estimation, context limits, snipping, split point."""
from __future__ import annotations
import sys
import os
# Ensure project root is on sys.path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from compaction import estimate_tokens, get_context_limit, snip_old_tool_results, find_split_point
# ── estimate_tokens ───────────────────────────────────────────────────────
class TestEstimateTokens:
    """estimate_tokens approximates token count as total characters / 3.5."""

    def test_simple_messages(self):
        conversation = [
            {"role": "user", "content": "Hello world"},
            {"role": "assistant", "content": "Hi there!"},
        ]
        # 11 + 9 = 20 characters across both messages -> int(20 / 3.5) = 5.
        expected = int(20 / 3.5)
        assert estimate_tokens(conversation) == expected

    def test_empty_messages(self):
        assert estimate_tokens([]) == 0

    def test_empty_content(self):
        assert estimate_tokens([{"role": "user", "content": ""}]) == 0

    def test_tool_result_messages(self):
        tool_msg = {"role": "tool", "tool_call_id": "abc", "name": "Read", "content": "x" * 350}
        assert estimate_tokens([tool_msg]) == int(350 / 3.5)

    def test_structured_content(self):
        """Content given as a list of blocks (e.g. Anthropic tool_result style)."""
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "tool_result", "tool_use_id": "id1", "content": "A" * 70},
                ],
            },
        ]
        # String values are counted: "tool_result" (11) + "id1" (3) + 70 = 84 chars.
        assert estimate_tokens(conversation) == int(84 / 3.5)

    def test_with_tool_calls(self):
        conversation = [
            {
                "role": "assistant",
                "content": "ok",
                "tool_calls": [
                    {"id": "c1", "name": "Bash", "input": {"command": "ls"}},
                ],
            },
        ]
        # content "ok" (2) + tool_call string values "c1" (2) + "Bash" (4) = 8 chars.
        assert estimate_tokens(conversation) == int(8 / 3.5)
# ── get_context_limit ─────────────────────────────────────────────────────
class TestGetContextLimit:
    """get_context_limit maps a model name to its provider's context window."""

    def test_anthropic(self):
        limit = get_context_limit("claude-opus-4-6")
        assert limit == 200000

    def test_gemini(self):
        limit = get_context_limit("gemini-2.0-flash")
        assert limit == 1000000

    def test_deepseek(self):
        limit = get_context_limit("deepseek-chat")
        assert limit == 64000

    def test_openai(self):
        limit = get_context_limit("gpt-4o")
        assert limit == 128000

    def test_qwen(self):
        limit = get_context_limit("qwen-max")
        assert limit == 1000000

    def test_unknown_model_fallback(self):
        # Unrecognized names fall back to the openai provider's 128k window.
        limit = get_context_limit("some-random-model-xyz")
        assert limit == 128000

    def test_explicit_provider_prefix(self):
        limit = get_context_limit("ollama/llama3.3")
        assert limit == 128000
# ── snip_old_tool_results ─────────────────────────────────────────────────
class TestSnipOldToolResults:
    """snip_old_tool_results truncates oversized tool outputs in old turns.

    Tool messages outside the last ``preserve_last_n_turns`` entries that
    exceed ``max_chars`` get snipped; recent, short, and non-tool messages
    must be left untouched.
    """

    def test_old_tool_results_get_truncated(self):
        long_content = "A" * 5000
        msgs = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "let me check", "tool_calls": []},
            {"role": "tool", "tool_call_id": "t1", "name": "Read", "content": long_content},
            {"role": "user", "content": "thanks"},
            {"role": "assistant", "content": "you're welcome"},
            {"role": "user", "content": "bye"},
            {"role": "assistant", "content": "goodbye"},
            {"role": "user", "content": "wait"},
            {"role": "assistant", "content": "yes?"},
            {"role": "user", "content": "never mind"},
        ]
        result = snip_old_tool_results(msgs, max_chars=2000, preserve_last_n_turns=6)
        # The list is mutated in place and the same object is returned.
        assert result is msgs
        tool_msg = msgs[2]
        assert len(tool_msg["content"]) < 5000
        assert "snipped" in tool_msg["content"]

    def test_recent_tool_results_preserved(self):
        long_content = "B" * 5000
        msgs = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "ok", "tool_calls": []},
            {"role": "tool", "tool_call_id": "t1", "name": "Read", "content": long_content},
        ]
        # All 3 messages are within preserve_last_n_turns=6, so nothing changes.
        # (fix: dropped the dead `result =` binding that was never used)
        snip_old_tool_results(msgs, max_chars=2000, preserve_last_n_turns=6)
        assert msgs[2]["content"] == long_content  # not truncated

    def test_short_tool_results_not_touched(self):
        # Tool output below max_chars is kept verbatim even when it is old.
        msgs = [
            {"role": "tool", "tool_call_id": "t1", "name": "Bash", "content": "short"},
            {"role": "user", "content": "a"},
            {"role": "user", "content": "b"},
            {"role": "user", "content": "c"},
            {"role": "user", "content": "d"},
            {"role": "user", "content": "e"},
            {"role": "user", "content": "f"},
        ]
        snip_old_tool_results(msgs, max_chars=2000, preserve_last_n_turns=6)
        assert msgs[0]["content"] == "short"

    def test_non_tool_messages_untouched(self):
        # Only role == "tool" messages are candidates for snipping.
        msgs = [
            {"role": "user", "content": "X" * 5000},
            {"role": "user", "content": "a"},
            {"role": "user", "content": "b"},
            {"role": "user", "content": "c"},
            {"role": "user", "content": "d"},
            {"role": "user", "content": "e"},
            {"role": "user", "content": "f"},
        ]
        snip_old_tool_results(msgs, max_chars=2000, preserve_last_n_turns=6)
        assert msgs[0]["content"] == "X" * 5000
# ── find_split_point ──────────────────────────────────────────────────────
class TestFindSplitPoint:
    """find_split_point picks where to cut the history for compaction."""

    def test_returns_reasonable_index(self):
        history = [
            {"role": "user", "content": "A" * 1000},
            {"role": "assistant", "content": "B" * 1000},
            {"role": "user", "content": "C" * 1000},
            {"role": "assistant", "content": "D" * 1000},
            {"role": "user", "content": "E" * 1000},
        ]
        split = find_split_point(history, keep_ratio=0.3)
        # Equal-size messages with keep_ratio=0.3 should split near the tail.
        assert 2 <= split <= 4

    def test_single_message(self):
        history = [{"role": "user", "content": "hello"}]
        assert find_split_point(history, keep_ratio=0.3) == 0

    def test_empty_messages(self):
        assert find_split_point([], keep_ratio=0.3) == 0

    def test_split_preserves_recent(self):
        # The kept tail should carry roughly 30% of the tokens.
        history = [{"role": "user", "content": "X" * 100} for _ in range(10)]
        split = find_split_point(history, keep_ratio=0.3)
        total = estimate_tokens(history)
        kept = estimate_tokens(history[split:])
        # Allow generous tolerance around the 30% target.
        assert kept >= total * 0.2
        assert kept <= total * 0.5

View File

@@ -0,0 +1,50 @@
import sys, os, tempfile
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import pytest
def test_generate_unified_diff():
    """A one-line change yields standard unified-diff headers and +/- lines."""
    from tools import generate_unified_diff
    original = "line1\nline2\nline3\n"
    modified = "line1\nline2_modified\nline3\n"
    diff = generate_unified_diff(original, modified, "test.py")
    for expected in ("--- a/test.py", "+++ b/test.py", "-line2", "+line2_modified"):
        assert expected in diff


def test_generate_unified_diff_empty_old():
    """Diffing against empty old text shows the new content as additions."""
    from tools import generate_unified_diff
    diff = generate_unified_diff("", "new content\n", "test.py")
    assert "+new content" in diff
def test_edit_returns_diff(tmp_path):
    """_edit on an existing file reports the change as a diff."""
    from tools import _edit
    target = tmp_path / "test.txt"
    target.write_text("hello world\n")
    result = _edit(str(target), "hello", "goodbye")
    assert "-hello world" in result
    assert "+goodbye world" in result


def test_write_existing_returns_diff(tmp_path):
    """_write over an existing file reports old vs. new content as a diff."""
    from tools import _write
    target = tmp_path / "test.txt"
    target.write_text("old content\n")
    result = _write(str(target), "new content\n")
    assert "-old content" in result
    assert "+new content" in result


def test_write_new_file_no_diff(tmp_path):
    """_write to a brand-new path reports "Created" with no diff body."""
    from tools import _write
    target = tmp_path / "new.txt"
    result = _write(str(target), "content\n")
    assert "Created" in result
    assert "---" not in result
def test_diff_truncation():
    """maybe_truncate_diff caps a long diff and notes the omitted lines."""
    from tools import generate_unified_diff, maybe_truncate_diff
    before = "\n".join(f"line{i}" for i in range(200))
    after = "\n".join(f"CHANGED{i}" for i in range(200))
    full_diff = generate_unified_diff(before, after, "big.py")
    truncated = maybe_truncate_diff(full_diff, max_lines=50)
    assert "more lines" in truncated
    # The result stays close to the requested cap.
    assert truncated.count("\n") < 60

View File

@@ -0,0 +1,275 @@
"""Tests for the memory package (memory/)."""
import pytest
from pathlib import Path
import memory.store as _store
from memory.store import (
MemoryEntry,
save_memory,
load_index,
load_entries,
delete_memory,
search_memory,
_slugify,
parse_frontmatter,
get_index_content,
)
from memory.context import get_memory_context, truncate_index_content
from memory.scan import (
scan_memory_dir,
format_memory_manifest,
memory_age_days,
memory_age_str,
memory_freshness_text,
MemoryHeader,
)
from memory.types import MEMORY_TYPES
# ── Fixtures ─────────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def redirect_memory_dirs(tmp_path, monkeypatch):
    """Point both memory scopes at throwaway directories for every test."""
    user_dir = tmp_path / "user_memory"
    proj_dir = tmp_path / "project_memory"
    user_dir.mkdir()
    proj_dir.mkdir()
    monkeypatch.setattr(_store, "USER_MEMORY_DIR", user_dir)
    # The project dir is resolved through a function, so patch that instead.
    monkeypatch.setattr(_store, "get_project_memory_dir", lambda: proj_dir)
def _make_entry(name="test note", description="a test", type_="user",
                content="hello world", scope="user"):
    """Build a MemoryEntry with sensible test defaults, overridable per test."""
    return MemoryEntry(
        name=name,
        description=description,
        type=type_,
        content=content,
        created="2026-04-02",
        scope=scope,
    )
# ── Save and Load ─────────────────────────────────────────────────────────
class TestSaveAndLoad:
    """Round-trip persistence of MemoryEntry objects."""

    def test_roundtrip(self):
        save_memory(_make_entry(), scope="user")
        entries = load_entries("user")
        assert len(entries) == 1
        loaded = entries[0]
        assert loaded.name == "test note"
        assert loaded.description == "a test"
        assert loaded.type == "user"
        assert loaded.content == "hello world"

    def test_creates_file_on_disk(self):
        entry = _make_entry()
        save_memory(entry, scope="user")
        path = Path(entry.file_path)
        assert path.exists()
        assert "hello world" in path.read_text()

    def test_update_existing(self):
        """Saving the same name twice keeps one entry with the newer content."""
        save_memory(_make_entry(content="version 1"), scope="user")
        save_memory(_make_entry(content="version 2"), scope="user")
        entries = load_entries("user")
        assert len(entries) == 1
        assert entries[0].content == "version 2"

    def test_project_scope_stored_separately(self):
        save_memory(_make_entry(name="user note"), scope="user")
        save_memory(_make_entry(name="proj note"), scope="project")
        # Each scope sees only its own entry.
        for scope, expected_name in (("user", "user note"), ("project", "proj note")):
            entries = load_entries(scope)
            assert len(entries) == 1
            assert entries[0].name == expected_name

    def test_load_index_all_combines_scopes(self):
        save_memory(_make_entry(name="user note"), scope="user")
        save_memory(_make_entry(name="proj note"), scope="project")
        names = {e.name for e in load_index("all")}
        assert "user note" in names
        assert "proj note" in names
# ── Delete ────────────────────────────────────────────────────────────────
class TestDelete:
    """delete_memory removes both the on-disk file and the index entry."""

    def test_delete_removes_file_and_index(self):
        entry = _make_entry()
        save_memory(entry, scope="user")
        delete_memory("test note", scope="user")
        assert load_entries("user") == []
        assert not Path(entry.file_path).exists()

    def test_delete_nonexistent_no_error(self):
        # Deleting an unknown name must be a silent no-op, not an exception.
        delete_memory("nonexistent", scope="user")

    def test_delete_from_project_scope(self):
        save_memory(_make_entry(name="proj note"), scope="project")
        delete_memory("proj note", scope="project")
        assert load_entries("project") == []
# ── Search ────────────────────────────────────────────────────────────────
class TestSearch:
    """Keyword search over names and bodies, case-insensitively."""

    def test_search_by_keyword(self):
        save_memory(_make_entry(name="python tips", content="use list comprehension"), scope="user")
        save_memory(_make_entry(name="rust tips", content="use iterators"), scope="user")
        hits = search_memory("python")
        assert len(hits) == 1
        assert hits[0].name == "python tips"

    def test_search_case_insensitive(self):
        save_memory(_make_entry(name="Important Note", content="something"), scope="user")
        # Lowercase query must still match the capitalized name.
        assert len(search_memory("important")) == 1

    def test_search_in_content(self):
        save_memory(_make_entry(name="misc", content="the quick brown fox"), scope="user")
        # The query matches body text, not just the name.
        assert len(search_memory("brown fox")) == 1

    def test_search_across_scopes(self):
        save_memory(_make_entry(name="user note", content="alpha"), scope="user")
        save_memory(_make_entry(name="proj note", content="alpha"), scope="project")
        assert len(search_memory("alpha", scope="all")) == 2
# ── Memory context ────────────────────────────────────────────────────────
class TestGetMemoryContext:
    """get_memory_context renders the memory index for injection into prompts."""

    def test_returns_index_text(self):
        save_memory(_make_entry(name="my note", description="desc here"), scope="user")
        ctx = get_memory_context()
        for needle in ("my note", "desc here"):
            assert needle in ctx

    def test_empty_when_no_memories(self):
        # No saved entries -> empty string, not a placeholder header.
        assert get_memory_context() == ""

    def test_project_memories_labelled(self):
        save_memory(_make_entry(name="proj note", description="project context"), scope="project")
        ctx = get_memory_context()
        assert "Project memories" in ctx
        assert "proj note" in ctx
# ── Truncation ────────────────────────────────────────────────────────────
class TestTruncation:
    """truncate_index_content caps the index by line count and by byte size."""

    def test_no_truncation_within_limits(self):
        small = "- line\n" * 10
        assert "WARNING" not in truncate_index_content(small)

    def test_line_truncation(self):
        many_lines = "\n".join(f"- line {i}" for i in range(300))
        out = truncate_index_content(many_lines)
        assert "WARNING" in out
        assert "lines" in out

    def test_byte_truncation(self):
        # Just over 25000 bytes of single-character content.
        oversized = "x" * 25001
        assert "WARNING" in truncate_index_content(oversized)
# ── Slugify ───────────────────────────────────────────────────────────────
class TestSlugify:
    """_slugify lowercases, underscores spaces, drops specials, caps length."""

    def test_basic(self):
        assert _slugify("Hello World") == "hello_world"

    def test_special_chars(self):
        assert _slugify("foo@bar!baz") == "foobarbaz"

    def test_max_length(self):
        # Slugs are capped at 60 characters.
        slug = _slugify("a" * 100)
        assert len(slug) == 60
# ── parse_frontmatter ─────────────────────────────────────────────────────
class TestParseFrontmatter:
    """parse_frontmatter splits '---'-delimited metadata from the body."""

    def test_parse(self):
        raw = "---\nname: foo\ntype: user\n---\nbody text"
        meta, body = parse_frontmatter(raw)
        assert meta["name"] == "foo"
        assert meta["type"] == "user"
        assert body == "body text"

    def test_no_frontmatter(self):
        # Text without a frontmatter block comes back unchanged as the body.
        meta, body = parse_frontmatter("just plain text")
        assert meta == {}
        assert body == "just plain text"
# ── scan / age / freshness ────────────────────────────────────────────────
class TestScanAndAge:
    """scan_memory_dir plus the age/freshness formatting helpers.

    fix: the original re-imported ``time`` inside six separate methods;
    the import is hoisted to module level instead.
    """

    def test_scan_memory_dir(self):
        save_memory(_make_entry(name="note a"), scope="user")
        save_memory(_make_entry(name="note b"), scope="user")
        user_dir = _store.USER_MEMORY_DIR
        headers = scan_memory_dir(user_dir, "user")
        assert len(headers) == 2
        assert all(isinstance(h, MemoryHeader) for h in headers)

    def test_format_manifest(self):
        headers = [
            MemoryHeader(
                filename="foo.md",
                file_path="/tmp/foo.md",
                mtime_s=time.time(),
                description="test desc",
                type="user",
                scope="user",
            )
        ]
        manifest = format_memory_manifest(headers)
        assert "foo.md" in manifest
        assert "test desc" in manifest
        # A just-touched file is labelled "today".
        assert "today" in manifest

    def test_memory_age_days_today(self):
        assert memory_age_days(time.time()) == 0

    def test_memory_age_days_old(self):
        five_days_ago = time.time() - 5 * 86400
        assert memory_age_days(five_days_ago) == 5

    def test_memory_age_str(self):
        now = time.time()
        assert memory_age_str(now) == "today"
        assert memory_age_str(now - 86400) == "yesterday"
        assert memory_age_str(now - 3 * 86400) == "3 days ago"

    def test_freshness_text_fresh(self):
        # Fresh memories produce no warning text at all.
        assert memory_freshness_text(time.time()) == ""

    def test_freshness_text_stale(self):
        ten_days_ago = time.time() - 10 * 86400
        text = memory_freshness_text(ten_days_ago)
        assert "10 days old" in text
        assert "stale" in text.lower() or "outdated" in text.lower()
# ── Memory types ──────────────────────────────────────────────────────────
class TestMemoryTypes:
    """MEMORY_TYPES enumerates the four supported memory categories."""

    def test_types_list(self):
        expected = {"user", "feedback", "project", "reference"}
        assert set(MEMORY_TYPES) == expected

View File

@@ -0,0 +1,234 @@
from __future__ import annotations
import pytest
from pathlib import Path
import skill.loader as _loader
from skill.loader import _parse_skill_file, _parse_list_field, find_skill, SkillDef
from skill import load_skills, substitute_arguments
COMMIT_MD = """\
---
name: commit
description: Create a git commit
triggers: [/commit, commit changes]
tools: [Bash, Read]
---
Review staged changes and create a commit with a descriptive message.
"""
REVIEW_MD = """\
---
name: review
description: Review a pull request
triggers: [/review, /review-pr]
tools: [Bash, Read, Grep]
---
Analyze the PR diff and provide constructive feedback.
"""
ARGS_MD = """\
---
name: deploy
description: Deploy to an environment
triggers: [/deploy]
tools: [Bash]
argument-hint: [env] [version]
arguments: [env, version]
---
Deploy $VERSION to $ENV environment. Full args: $ARGUMENTS
"""
@pytest.fixture()
def skill_dir(tmp_path, monkeypatch):
    """Temp skills directory with commit/review samples; builtins disabled."""
    skills_dir = tmp_path / "skills"
    skills_dir.mkdir()
    for fname, body in (("commit.md", COMMIT_MD), ("review.md", REVIEW_MD)):
        (skills_dir / fname).write_text(body, encoding="utf-8")
    monkeypatch.setattr(_loader, "_get_skill_paths", lambda: [skills_dir])
    # Empty the builtin skill list so only our fixtures are discoverable.
    monkeypatch.setattr(_loader, "_BUILTIN_SKILLS", [])
    return skills_dir
# ------------------------------------------------------------------
# _parse_list_field
# ------------------------------------------------------------------
def test_parse_list_field_bracket():
    """Bracketed comma lists parse into their elements."""
    parsed = _parse_list_field("[a, b, c]")
    assert parsed == ["a", "b", "c"]


def test_parse_list_field_plain():
    """Bare comma lists (no brackets) parse the same way."""
    parsed = _parse_list_field("a, b, c")
    assert parsed == ["a", "b", "c"]


def test_parse_list_field_single():
    """A single bare value becomes a one-element list."""
    assert _parse_list_field("solo") == ["solo"]
# ------------------------------------------------------------------
# _parse_skill_file
# ------------------------------------------------------------------
def test_parse_skill_file(skill_dir):
    """All frontmatter fields of the commit fixture round-trip into SkillDef."""
    path = skill_dir / "commit.md"
    parsed = _parse_skill_file(path)
    assert parsed is not None
    assert parsed.name == "commit"
    assert parsed.description == "Create a git commit"
    for trigger in ("/commit", "commit changes"):
        assert trigger in parsed.triggers
    for tool in ("Bash", "Read"):
        assert tool in parsed.tools
    assert "commit" in parsed.prompt.lower()
    assert parsed.file_path == str(path)


def test_parse_skill_file_review(skill_dir):
    parsed = _parse_skill_file(skill_dir / "review.md")
    assert parsed is not None
    assert parsed.name == "review"
    for trigger in ("/review", "/review-pr"):
        assert trigger in parsed.triggers


def test_parse_skill_file_invalid(tmp_path):
    # A file without a frontmatter block is not a skill.
    bad = tmp_path / "bad.md"
    bad.write_text("no frontmatter here", encoding="utf-8")
    assert _parse_skill_file(bad) is None


def test_parse_skill_file_no_name(tmp_path):
    # Frontmatter lacking a `name:` field is rejected.
    unnamed = tmp_path / "noname.md"
    unnamed.write_text("---\ndescription: test\n---\nbody\n", encoding="utf-8")
    assert _parse_skill_file(unnamed) is None


def test_parse_skill_file_context_fork(tmp_path):
    fork_path = tmp_path / "fork.md"
    fork_path.write_text("---\nname: fork-task\ndescription: test\ncontext: fork\n---\nbody\n")
    parsed = _parse_skill_file(fork_path)
    assert parsed is not None
    assert parsed.context == "fork"


def test_parse_skill_file_allowed_tools(tmp_path):
    # An `allowed-tools:` field also populates skill.tools.
    path = tmp_path / "t.md"
    path.write_text("---\nname: myskill\ndescription: d\nallowed-tools: [Bash, Read]\n---\nbody\n")
    parsed = _parse_skill_file(path)
    assert parsed is not None
    for tool in ("Bash", "Read"):
        assert tool in parsed.tools
# ------------------------------------------------------------------
# load_skills
# ------------------------------------------------------------------
def test_load_skills(skill_dir):
    """Both fixture skills load from the patched directory."""
    loaded = load_skills()
    assert len(loaded) == 2
    assert {s.name for s in loaded} == {"commit", "review"}


def test_load_skills_empty_dir(tmp_path, monkeypatch):
    empty_dir = tmp_path / "empty_skills"
    empty_dir.mkdir()
    monkeypatch.setattr(_loader, "_get_skill_paths", lambda: [empty_dir])
    monkeypatch.setattr(_loader, "_BUILTIN_SKILLS", [])
    assert load_skills() == []


def test_load_skills_nonexistent_dir(tmp_path, monkeypatch):
    # A missing directory is skipped, not an error.
    missing = tmp_path / "does_not_exist"
    monkeypatch.setattr(_loader, "_get_skill_paths", lambda: [missing])
    monkeypatch.setattr(_loader, "_BUILTIN_SKILLS", [])
    assert load_skills() == []


def test_load_skills_builtins_present(monkeypatch):
    """Without patching, builtins (commit, review) should be present."""
    monkeypatch.setattr(_loader, "_get_skill_paths", lambda: [])
    names = {s.name for s in load_skills()}
    assert "commit" in names
    assert "review" in names


def test_load_skills_project_overrides_builtin(tmp_path, monkeypatch):
    """A project skill with the same name overrides the builtin."""
    skills_dir = tmp_path / "skills"
    skills_dir.mkdir()
    # Project-level "commit" carrying a different description.
    (skills_dir / "commit.md").write_text(
        "---\nname: commit\ndescription: OVERRIDDEN\n---\ncustom commit prompt\n"
    )
    monkeypatch.setattr(_loader, "_get_skill_paths", lambda: [skills_dir])
    commit = next(s for s in load_skills() if s.name == "commit")
    assert commit.description == "OVERRIDDEN"
# ------------------------------------------------------------------
# find_skill
# ------------------------------------------------------------------
def test_find_skill_commit(skill_dir):
    """A bare /commit trigger resolves to the commit skill."""
    match = find_skill("/commit")
    assert match is not None
    assert match.name == "commit"


def test_find_skill_review(skill_dir):
    match = find_skill("/review")
    assert match is not None
    assert match.name == "review"


def test_find_skill_review_pr(skill_dir):
    # The trigger is matched even with trailing arguments after it.
    match = find_skill("/review-pr some-pr-url")
    assert match is not None
    assert match.name == "review"


def test_find_skill_nonexistent(skill_dir):
    # Unknown commands return None rather than raising.
    assert find_skill("/nonexistent") is None
# ------------------------------------------------------------------
# substitute_arguments
# ------------------------------------------------------------------
def test_substitute_arguments_placeholder():
    """$ARGUMENTS is replaced by the raw argument string."""
    result = substitute_arguments("Deploy $ARGUMENTS please", "v1.2 prod", [])
    assert result == "Deploy v1.2 prod please"


def test_substitute_named_args():
    """Named placeholders are filled positionally from the argument string.

    fix: dropped the unused ``tmp_path`` fixture parameter — the test never
    touched the filesystem, so requesting the fixture only created a spurious
    temp directory.
    """
    result = substitute_arguments(
        "Deploy $VERSION to $ENV. Full args: $ARGUMENTS",
        "1.0 staging",
        ["env", "version"],
    )
    # arg_names are positional: env=1.0, version=staging
    assert "$VERSION" not in result
    assert "$ENV" not in result
    assert "$ARGUMENTS" not in result


def test_substitute_missing_arg():
    # If user provides fewer args than named slots, missing ones become ""
    result = substitute_arguments("Hello $NAME!", "", ["name"])
    assert result == "Hello !"


def test_substitute_no_placeholders():
    """A prompt without placeholders passes through untouched."""
    result = substitute_arguments("just a plain prompt", "some args", [])
    assert result == "just a plain prompt"

View File

@@ -0,0 +1,136 @@
"""Tests for the sub-agent system (subagent.py)."""
import time
import threading
import pytest
from multi_agent.subagent import SubAgentManager, SubAgentTask, _extract_final_text
# ── Mock for _agent_run ──────────────────────────────────────────────────
def _make_mock_agent_run(sleep_per_iter=0.05, iters=3):
    """Return a mock _agent_run that simulates work and checks cancellation.

    The returned callable is a *generator function* matching the real
    _agent_run's protocol: it sleeps in small chunks (so cancellation can
    interrupt mid-run), appends one final assistant message to state, and
    yields a single event.
    """
    def mock_agent_run(prompt, state, config, system_prompt, depth=0, cancel_check=None):
        # Simulate `iters` chunks of work; poll cancel_check between chunks.
        for i in range(iters):
            if cancel_check and cancel_check():
                # Bare return in a generator ends iteration early — no message
                # is appended, mimicking a cancelled run.
                return
            time.sleep(sleep_per_iter)
        # Append an assistant message to state
        state.messages.append({
            "role": "assistant",
            "content": f"Result for: {prompt}",
            "tool_calls": [],
        })
        # Yield a TurnDone-like event (generator protocol)
        yield None
    return mock_agent_run
def _make_slow_mock(sleep_per_iter=0.2, iters=10):
    """Slow variant of the mock agent run, used by the cancellation tests."""
    # Delegates to the fast factory with longer, more numerous sleeps.
    return _make_mock_agent_run(sleep_per_iter=sleep_per_iter, iters=iters)
@pytest.fixture
def manager(monkeypatch):
    """SubAgentManager whose _agent_run is replaced by the fast mock."""
    monkeypatch.setattr("multi_agent.subagent._agent_run", _make_mock_agent_run())
    mgr = SubAgentManager(max_concurrent=3, max_depth=3)
    yield mgr
    # Teardown: stop worker threads after each test.
    mgr.shutdown()
@pytest.fixture
def slow_manager(monkeypatch):
    """SubAgentManager backed by the slow mock, for cancellation tests."""
    monkeypatch.setattr("multi_agent.subagent._agent_run", _make_slow_mock())
    mgr = SubAgentManager(max_concurrent=3, max_depth=3)
    yield mgr
    # Teardown: stop worker threads after each test.
    mgr.shutdown()
# ── Tests ────────────────────────────────────────────────────────────────
class TestSpawnAndWait:
    """Basic spawn → wait lifecycle."""

    def test_spawn_and_wait_completes(self, manager):
        task = manager.spawn("hello", {}, "system")
        finished = manager.wait(task.id, timeout=5)
        assert finished is not None
        assert finished.status == "completed"
        assert finished.result == "Result for: hello"

    def test_spawn_returns_immediately(self, manager):
        # spawn is asynchronous: by the time it returns the task has not
        # had a chance to complete.
        task = manager.spawn("hello", {}, "system")
        assert task.status in ("pending", "running")
class TestListTasks:
    """list_tasks reports every spawned task."""

    def test_list_tasks(self, manager):
        first = manager.spawn("task1", {}, "system")
        second = manager.spawn("task2", {}, "system")
        listed = manager.list_tasks()
        listed_ids = {t.id for t in listed}
        assert first.id in listed_ids
        assert second.id in listed_ids
        assert len(listed) == 2
class TestCancel:
    """Cancelling a running task flips it to "cancelled" once the worker stops."""

    def test_cancel_running_task(self, slow_manager):
        task = slow_manager.spawn("slow task", {}, "system")
        # fix: the original slept a fixed 0.1s and asserted "running", which
        # is flaky on a loaded machine. Poll with a deadline instead.
        deadline = time.time() + 2.0
        while task.status != "running" and time.time() < deadline:
            time.sleep(0.01)
        assert task.status == "running"
        success = slow_manager.cancel(task.id)
        assert success is True
        # Wait for the worker to observe the cancel flag and actually finish.
        slow_manager.wait(task.id, timeout=5)
        assert task.status == "cancelled"
class TestDepthLimit:
    """Recursion depth is capped by the manager's max_depth."""

    def test_spawn_at_max_depth_fails(self, manager):
        # The manager was built with max_depth=3, so spawning at depth=3
        # is rejected immediately.
        task = manager.spawn("deep", {}, "system", depth=3)
        assert task.status == "failed"
        assert "Max depth" in task.result
class TestGetResult:
    """get_result returns the final text for known tasks, None otherwise."""

    def test_get_result_completed(self, manager):
        task = manager.spawn("hello", {}, "system")
        manager.wait(task.id, timeout=5)
        assert manager.get_result(task.id) == "Result for: hello"

    def test_get_result_unknown_id(self, manager):
        assert manager.get_result("nonexistent_id") is None
class TestExtractFinalText:
    """_extract_final_text picks the last assistant message's content."""

    def test_extracts_last_assistant(self):
        history = [
            {"role": "user", "content": "hi"},
            {"role": "assistant", "content": "first"},
            {"role": "user", "content": "more"},
            {"role": "assistant", "content": "second"},
        ]
        assert _extract_final_text(history) == "second"

    def test_returns_none_for_empty(self):
        assert _extract_final_text([]) is None

    def test_returns_none_no_assistant(self):
        # A history with only user messages has no final text.
        history = [{"role": "user", "content": "hi"}]
        assert _extract_final_text(history) is None
class TestWaitUnknown:
    def test_wait_unknown_returns_none(self, manager):
        # Waiting on an id the manager never issued yields None, not an error.
        assert manager.wait("nonexistent") is None

View File

@@ -0,0 +1,160 @@
from __future__ import annotations
import pytest
from tool_registry import (
ToolDef,
clear_registry,
execute_tool,
get_all_tools,
get_tool,
get_tool_schemas,
register_tool,
)
@pytest.fixture(autouse=True)
def _clean_registry():
    """Reset the global tool registry before and after every test.

    autouse=True means no test in this module can see tools registered by
    another; the trailing clear also keeps other modules isolated from us.
    """
    clear_registry()
    yield
    clear_registry()
def _make_echo_tool(name: str = "echo", read_only: bool = False) -> ToolDef:
    """Build a minimal tool that returns its `text` parameter unchanged."""
    def echo(params: dict, config: dict) -> str:
        return params["text"]

    return ToolDef(
        name=name,
        schema={
            "name": name,
            "description": f"Echo tool ({name})",
            "input_schema": {
                "type": "object",
                "properties": {
                    "text": {"type": "string", "description": "text to echo"},
                },
                "required": ["text"],
            },
        },
        func=echo,
        read_only=read_only,
        concurrent_safe=True,
    )
# ------------------------------------------------------------------
# register and get
# ------------------------------------------------------------------
def test_register_and_get():
    """A registered tool is retrievable by name."""
    register_tool(_make_echo_tool())
    fetched = get_tool("echo")
    assert fetched is not None
    assert fetched.name == "echo"


def test_get_unknown_returns_none():
    # Lookups never raise; unknown names simply yield None.
    assert get_tool("no_such_tool") is None
# ------------------------------------------------------------------
# get_all_tools
# ------------------------------------------------------------------
def test_get_all_tools_empty():
    """A fresh registry lists no tools."""
    assert get_all_tools() == []


def test_get_all_tools():
    for tool_name in ("a", "b"):
        register_tool(_make_echo_tool(tool_name))
    assert sorted(t.name for t in get_all_tools()) == ["a", "b"]
# ------------------------------------------------------------------
# get_tool_schemas
# ------------------------------------------------------------------
def test_get_tool_schemas():
    """get_tool_schemas exports one schema dict per registered tool."""
    register_tool(_make_echo_tool("echo"))
    exported = get_tool_schemas()
    assert len(exported) == 1
    assert exported[0]["name"] == "echo"
# ------------------------------------------------------------------
# execute_tool
# ------------------------------------------------------------------
def test_execute_tool():
    """execute_tool dispatches to the registered function."""
    register_tool(_make_echo_tool())
    assert execute_tool("echo", {"text": "hello"}, config={}) == "hello"


def test_execute_unknown_tool():
    # Executing an unregistered name returns an error string, not a raise.
    message = execute_tool("missing", {}, config={})
    lowered = message.lower()
    assert "unknown" in lowered or "not found" in lowered
# ------------------------------------------------------------------
# output truncation
# ------------------------------------------------------------------
def test_output_truncation():
    """Oversized output keeps the head and tail with a truncation marker."""
    def big_func(params: dict, config: dict) -> str:
        return "x" * 100

    register_tool(
        ToolDef(
            name="big",
            schema={"name": "big", "description": "big", "input_schema": {"type": "object", "properties": {}}},
            func=big_func,
            read_only=True,
            concurrent_safe=True,
        )
    )
    result = execute_tool("big", {}, config={}, max_output=40)
    # first half = 20 chars, last quarter = 10 chars, marker in between
    assert len(result) < 100
    assert "truncated" in result
    assert result.startswith("x" * 20)
    assert result.endswith("x" * 10)


def test_no_truncation_when_within_limit():
    register_tool(_make_echo_tool())
    assert execute_tool("echo", {"text": "short"}, config={}) == "short"
# ------------------------------------------------------------------
# duplicate register overwrites
# ------------------------------------------------------------------
def test_duplicate_register_overwrites():
    """Registering the same name twice keeps only the newest definition."""
    register_tool(_make_echo_tool("dup"))

    def new_func(params: dict, config: dict) -> str:
        return "new"

    register_tool(
        ToolDef(
            name="dup",
            schema={"name": "dup", "description": "new", "input_schema": {"type": "object", "properties": {}}},
            func=new_func,
            read_only=False,
            concurrent_safe=False,
        )
    )
    # The second registration replaces the first rather than duplicating it.
    assert len(get_all_tools()) == 1
    assert execute_tool("dup", {}, config={}) == "new"