Repositories / more_dspy.git

tests/test_format_and_parse.py

Clone (read-only): git clone http://git.guha-anderson.com/git/more_dspy.git

Branch
6827 bytes · 4da51863c8dc
"""Tests for ``more_dspy.format_prompt`` and ``more_dspy.parse_response``.

Every test pins the adapter explicitly via ``dspy.context(adapter=...)`` so the
outcome does not depend on any global ``dspy.configure`` state. Tests are
grouped by adapter (ChatAdapter, JSONAdapter), followed by tests checking that
the adapter is taken from ``dspy.settings`` and that formatting a prompt and
parsing a simulated completion round-trip.
"""

import math

import dspy
from dspy.adapters.chat_adapter import ChatAdapter
from dspy.adapters.json_adapter import JSONAdapter

from more_dspy import format_prompt, parse_response


class QA(dspy.Signature):
    """Answer the question."""

    question: str = dspy.InputField()
    answer: str = dspy.OutputField()


class Classify(dspy.Signature):
    """Classify the text."""

    text: str = dspy.InputField()
    label: str = dspy.OutputField()
    confidence: float = dspy.OutputField()


def _user_contents(messages):
    """Return the ``content`` of every message whose role is ``"user"``, in order."""
    return [m["content"] for m in messages if m["role"] == "user"]


# ---------------------------------------------------------------------------
# ChatAdapter tests
# ---------------------------------------------------------------------------


def test_format_prompt_predict_chat():
    with dspy.context(adapter=ChatAdapter()):
        module = dspy.Predict(QA)
        messages = format_prompt(module, question="What is 2+2?")
        assert isinstance(messages, list)
        assert messages[0]["role"] == "system"
        # The question value must be rendered into at least one user message.
        assert any("What is 2+2?" in c for c in _user_contents(messages))


def test_format_prompt_cot_chat():
    with dspy.context(adapter=ChatAdapter()):
        module = dspy.ChainOfThought(QA)
        messages = format_prompt(module, question="What is 2+2?")
        user_content = _user_contents(messages)[-1]
        # ChainOfThought adds a reasoning field; it may be described in the
        # system prompt or in the final user turn.
        assert (
            "reasoning" in messages[0]["content"].lower()
            or "reasoning" in user_content.lower()
        )


def test_parse_response_predict_chat():
    with dspy.context(adapter=ChatAdapter()):
        module = dspy.Predict(QA)
        completion = "[[ ## answer ## ]]\n4\n\n[[ ## completed ## ]]"
        pred = parse_response(module, completion)
        assert pred is not None
        assert pred.answer == "4"


def test_parse_response_cot_chat():
    with dspy.context(adapter=ChatAdapter()):
        module = dspy.ChainOfThought(QA)
        completion = (
            "[[ ## reasoning ## ]]\n2 plus 2 equals 4\n\n"
            "[[ ## answer ## ]]\n4\n\n"
            "[[ ## completed ## ]]"
        )
        pred = parse_response(module, completion)
        assert pred is not None
        assert pred.answer == "4"
        assert pred.reasoning == "2 plus 2 equals 4"


def test_parse_response_bad_input_chat():
    with dspy.context(adapter=ChatAdapter()):
        module = dspy.Predict(QA)
        # Unparseable completions are reported as None rather than raising.
        pred = parse_response(module, "just some random text")
        assert pred is None


def test_parse_response_multiple_outputs_chat():
    with dspy.context(adapter=ChatAdapter()):
        module = dspy.Predict(Classify)
        completion = (
            "[[ ## label ## ]]\npositive\n\n"
            "[[ ## confidence ## ]]\n0.95\n\n"
            "[[ ## completed ## ]]"
        )
        pred = parse_response(module, completion)
        assert pred is not None
        assert pred.label == "positive"
        # Compare floats with a tolerance rather than exact equality.
        assert math.isclose(pred.confidence, 0.95)


# ---------------------------------------------------------------------------
# JSONAdapter tests
# ---------------------------------------------------------------------------


def test_format_prompt_predict_json():
    with dspy.context(adapter=JSONAdapter()):
        module = dspy.Predict(QA)
        messages = format_prompt(module, question="What is 2+2?")
        assert isinstance(messages, list)
        assert messages[0]["role"] == "system"
        user_contents = _user_contents(messages)
        assert any("What is 2+2?" in c for c in user_contents)
        # JSONAdapter mentions JSON in its output requirements (any casing).
        assert "json" in user_contents[-1].lower()


def test_format_prompt_cot_json():
    with dspy.context(adapter=JSONAdapter()):
        module = dspy.ChainOfThought(QA)
        messages = format_prompt(module, question="What is 2+2?")
        system_content = messages[0]["content"]
        assert "reasoning" in system_content.lower()


def test_parse_response_predict_json():
    with dspy.context(adapter=JSONAdapter()):
        module = dspy.Predict(QA)
        completion = '{"answer": "4"}'
        pred = parse_response(module, completion)
        assert pred is not None
        assert pred.answer == "4"


def test_parse_response_cot_json():
    with dspy.context(adapter=JSONAdapter()):
        module = dspy.ChainOfThought(QA)
        completion = '{"reasoning": "2 plus 2 equals 4", "answer": "4"}'
        pred = parse_response(module, completion)
        assert pred is not None
        assert pred.answer == "4"
        assert pred.reasoning == "2 plus 2 equals 4"


def test_parse_response_bad_input_json():
    with dspy.context(adapter=JSONAdapter()):
        module = dspy.Predict(QA)
        # Unparseable completions are reported as None rather than raising.
        pred = parse_response(module, "not json at all {{{")
        assert pred is None


def test_parse_response_multiple_outputs_json():
    with dspy.context(adapter=JSONAdapter()):
        module = dspy.Predict(Classify)
        completion = '{"label": "positive", "confidence": 0.95}'
        pred = parse_response(module, completion)
        assert pred is not None
        assert pred.label == "positive"
        # Compare floats with a tolerance rather than exact equality.
        assert math.isclose(pred.confidence, 0.95)


# ---------------------------------------------------------------------------
# Adapter from dspy.settings (default)
# ---------------------------------------------------------------------------


def test_format_prompt_uses_settings_adapter():
    """format_prompt should respect the adapter set via dspy.configure/context."""
    with dspy.context(adapter=JSONAdapter()):
        module = dspy.Predict(QA)
        messages = format_prompt(module, question="hi")
        last_user = _user_contents(messages)[-1]
        assert "json" in last_user.lower()

    # Reuse the same module under a different adapter: the adapter must be
    # resolved when format_prompt is called, not when the module was built.
    with dspy.context(adapter=ChatAdapter()):
        messages = format_prompt(module, question="hi")
        last_user = _user_contents(messages)[-1]
        assert "[[ ## completed ## ]]" in last_user


def test_roundtrip_chat():
    """format_prompt + simulated LM response + parse_response should roundtrip."""
    with dspy.context(adapter=ChatAdapter()):
        module = dspy.Predict(QA)
        messages = format_prompt(module, question="Capital of France?")
        assert len(messages) >= 2
        assert any("Capital of France?" in c for c in _user_contents(messages))
        response = "[[ ## answer ## ]]\nParis\n\n[[ ## completed ## ]]"
        pred = parse_response(module, response)
        assert pred is not None
        assert pred.answer == "Paris"


def test_roundtrip_json():
    """Same round trip as test_roundtrip_chat, using the JSONAdapter."""
    with dspy.context(adapter=JSONAdapter()):
        module = dspy.Predict(QA)
        messages = format_prompt(module, question="Capital of France?")
        assert len(messages) >= 2
        assert any("Capital of France?" in c for c in _user_contents(messages))
        response = '{"answer": "Paris"}'
        pred = parse_response(module, response)
        assert pred is not None
        assert pred.answer == "Paris"