Repositories / more_dspy.git
tests/test_format_and_parse.py
Clone (read-only): git clone http://git.guha-anderson.com/git/more_dspy.git
import dspy
from dspy.adapters.chat_adapter import ChatAdapter
from dspy.adapters.json_adapter import JSONAdapter
from more_dspy import format_prompt, parse_response
# Minimal one-input / one-output signature shared by most tests in this file:
# callers pass `question=` to format_prompt and read `.answer` off parsed output.
# NOTE(review): dspy presumably renders the class docstring as the task
# instructions in the prompt, so it is intentionally kept verbatim.
class QA(dspy.Signature):
    """Answer the question."""
    question: str = dspy.InputField()  # supplied as format_prompt(..., question=...)
    answer: str = dspy.OutputField()  # read back as pred.answer after parse_response
# Signature with two output fields of different types, used to check that
# parse_response fills several outputs at once (pred.label, pred.confidence).
# NOTE(review): dspy presumably renders the class docstring as the task
# instructions in the prompt, so it is intentionally kept verbatim.
class Classify(dspy.Signature):
    """Classify the text."""
    text: str = dspy.InputField()
    label: str = dspy.OutputField()
    # Declared float: the tests compare the parsed value to 0.95, which a raw
    # string "0.95" would not equal, so this also exercises type conversion.
    confidence: float = dspy.OutputField()
# ---------------------------------------------------------------------------
# ChatAdapter tests
# ---------------------------------------------------------------------------
def test_format_prompt_predict_chat():
    """Predict + ChatAdapter: system message first, question echoed to a user turn."""
    with dspy.context(adapter=ChatAdapter()):
        predictor = dspy.Predict(QA)
        prompt = format_prompt(predictor, question="What is 2+2?")

    assert isinstance(prompt, list)
    assert prompt[0]["role"] == "system"
    # At least one user-role message must carry the question text.
    assert any(
        msg["role"] == "user" and "What is 2+2?" in msg["content"] for msg in prompt
    )
def test_format_prompt_cot_chat():
    """ChainOfThought + ChatAdapter: the extra `reasoning` field is mentioned."""
    with dspy.context(adapter=ChatAdapter()):
        cot = dspy.ChainOfThought(QA)
        prompt = format_prompt(cot, question="What is 2+2?")

    last_user_text = [msg for msg in prompt if msg["role"] == "user"][-1]["content"]
    system_text = prompt[0]["content"]
    # ChainOfThought adds a `reasoning` output field; accept it being described
    # in either the system prompt or the final user turn.
    assert any("reasoning" in text.lower() for text in (system_text, last_user_text))
def test_parse_response_predict_chat():
    """A well-formed ChatAdapter completion parses into a prediction with `answer`."""
    raw = "[[ ## answer ## ]]\n4\n\n[[ ## completed ## ]]"
    with dspy.context(adapter=ChatAdapter()):
        predictor = dspy.Predict(QA)
        result = parse_response(predictor, raw)

    assert result is not None
    assert result.answer == "4"
def test_parse_response_cot_chat():
    """ChatAdapter parses both the reasoning and answer sections for ChainOfThought."""
    raw = (
        "[[ ## reasoning ## ]]\n2 plus 2 equals 4\n\n"
        "[[ ## answer ## ]]\n4\n\n"
        "[[ ## completed ## ]]"
    )
    with dspy.context(adapter=ChatAdapter()):
        cot = dspy.ChainOfThought(QA)
        result = parse_response(cot, raw)

    assert result is not None
    assert result.answer == "4"
    assert result.reasoning == "2 plus 2 equals 4"
def test_parse_response_bad_input_chat():
    """Text with no ChatAdapter field markers yields None instead of a prediction."""
    with dspy.context(adapter=ChatAdapter()):
        predictor = dspy.Predict(QA)
        result = parse_response(predictor, "just some random text")

    assert result is None
def test_parse_response_multiple_outputs_chat():
    """ChatAdapter fills every output field of a multi-output signature."""
    raw = (
        "[[ ## label ## ]]\npositive\n\n"
        "[[ ## confidence ## ]]\n0.95\n\n"
        "[[ ## completed ## ]]"
    )
    with dspy.context(adapter=ChatAdapter()):
        classifier = dspy.Predict(Classify)
        result = parse_response(classifier, raw)

    assert result is not None
    assert result.label == "positive"
    # `confidence` is declared float; the raw text "0.95" would not compare equal.
    assert result.confidence == 0.95
# ---------------------------------------------------------------------------
# JSONAdapter tests
# ---------------------------------------------------------------------------
def test_format_prompt_predict_json():
    """Predict + JSONAdapter: question echoed, final user turn asks for JSON."""
    with dspy.context(adapter=JSONAdapter()):
        predictor = dspy.Predict(QA)
        prompt = format_prompt(predictor, question="What is 2+2?")

    assert isinstance(prompt, list)
    assert prompt[0]["role"] == "system"
    user_turns = [msg for msg in prompt if msg["role"] == "user"]
    assert any("What is 2+2?" in turn["content"] for turn in user_turns)
    # JSONAdapter states its JSON output requirement in the last user turn.
    final_user_text = user_turns[-1]["content"]
    assert any(marker in final_user_text for marker in ("JSON", "json"))
def test_format_prompt_cot_json():
    """With JSONAdapter, ChainOfThought's `reasoning` field appears in the system prompt."""
    with dspy.context(adapter=JSONAdapter()):
        cot = dspy.ChainOfThought(QA)
        prompt = format_prompt(cot, question="What is 2+2?")

    assert "reasoning" in prompt[0]["content"].lower()
def test_parse_response_predict_json():
    """A JSON object completion parses into a prediction with `answer`."""
    raw = '{"answer": "4"}'
    with dspy.context(adapter=JSONAdapter()):
        predictor = dspy.Predict(QA)
        result = parse_response(predictor, raw)

    assert result is not None
    assert result.answer == "4"
def test_parse_response_cot_json():
    """JSONAdapter parses both `reasoning` and `answer` keys for ChainOfThought."""
    raw = '{"reasoning": "2 plus 2 equals 4", "answer": "4"}'
    with dspy.context(adapter=JSONAdapter()):
        cot = dspy.ChainOfThought(QA)
        result = parse_response(cot, raw)

    assert result is not None
    assert result.answer == "4"
    assert result.reasoning == "2 plus 2 equals 4"
def test_parse_response_bad_input_json():
    """Malformed JSON yields None instead of a prediction."""
    with dspy.context(adapter=JSONAdapter()):
        predictor = dspy.Predict(QA)
        result = parse_response(predictor, "not json at all {{{")

    assert result is None
def test_parse_response_multiple_outputs_json():
    """JSONAdapter fills every output field of a multi-output signature."""
    raw = '{"label": "positive", "confidence": 0.95}'
    with dspy.context(adapter=JSONAdapter()):
        classifier = dspy.Predict(Classify)
        result = parse_response(classifier, raw)

    assert result is not None
    assert result.label == "positive"
    assert result.confidence == 0.95
# ---------------------------------------------------------------------------
# Adapter from dspy.settings (default)
# ---------------------------------------------------------------------------
def test_format_prompt_uses_settings_adapter():
    """format_prompt should respect the adapter set via dspy.configure/context."""

    def final_user_text(msgs):
        # Content of the last user-role message in a formatted prompt.
        return [m for m in msgs if m["role"] == "user"][-1]["content"]

    with dspy.context(adapter=JSONAdapter()):
        predictor = dspy.Predict(QA)
        json_tail = final_user_text(format_prompt(predictor, question="hi"))
    assert "JSON" in json_tail or "json" in json_tail

    # Same module object, different active adapter: the output now follows
    # ChatAdapter's field-marker convention.
    with dspy.context(adapter=ChatAdapter()):
        chat_tail = final_user_text(format_prompt(predictor, question="hi"))
    assert "[[ ## completed ## ]]" in chat_tail
def test_roundtrip_chat():
    """format_prompt + simulated LM response + parse_response should roundtrip."""
    question = "Capital of France?"
    with dspy.context(adapter=ChatAdapter()):
        module = dspy.Predict(QA)
        messages = format_prompt(module, question=question)
        response = "[[ ## answer ## ]]\nParis\n\n[[ ## completed ## ]]"
        pred = parse_response(module, response)

    # Format half: the prompt must actually carry the input we supplied;
    # checking only the message count would not prove the question got through.
    assert len(messages) >= 2
    assert any(question in m["content"] for m in messages if m["role"] == "user")

    # Parse half: the simulated completion maps back onto the output field.
    assert pred is not None
    assert pred.answer == "Paris"
def test_roundtrip_json():
    """JSONAdapter: format_prompt + simulated JSON response + parse_response roundtrip."""
    question = "Capital of France?"
    with dspy.context(adapter=JSONAdapter()):
        module = dspy.Predict(QA)
        messages = format_prompt(module, question=question)
        response = '{"answer": "Paris"}'
        pred = parse_response(module, response)

    # Format half: the prompt must actually carry the input we supplied;
    # checking only the message count would not prove the question got through.
    assert len(messages) >= 2
    assert any(question in m["content"] for m in messages if m["role"] == "user")

    # Parse half: the simulated JSON completion maps back onto the output field.
    assert pred is not None
    assert pred.answer == "Paris"