Repositories / more_dspy.git

tests/test_chat_adapter_with_trailing_instructions.py

Clone (read-only): git clone http://git.guha-anderson.com/git/more_dspy.git

Branch
3512 bytes · d78a05df6e2c
"""Tests for ChatAdapterWithTrailingInstructions.

As exercised below, the adapter is expected to keep the ``extra_instructions``
input out of the system message and out of demo messages, and to append its
value to the end of the final user message, after the "Respond with..." output
requirements.
"""

import dspy

from more_dspy import ChatAdapterWithTrailingInstructions


# NOTE: these signatures intentionally have no class docstring -- in DSPy a
# Signature's docstring becomes its instructions and would change the
# rendered system message.
class SigWithExtra(dspy.Signature):
    x: str = dspy.InputField()
    y: str = dspy.InputField()
    extra_instructions: str = dspy.InputField()
    z: str = dspy.OutputField()


class SigWithoutExtra(dspy.Signature):
    x: str = dspy.InputField()
    y: str = dspy.InputField()
    z: str = dspy.OutputField()


adapter = ChatAdapterWithTrailingInstructions()


def get_last_user_content(messages: list[dict]) -> str:
    """Return the ``content`` of the last message whose role is ``"user"``.

    Raises:
        ValueError: if ``messages`` contains no user message.
    """
    for msg in reversed(messages):
        if msg["role"] == "user":
            return msg["content"]
    raise ValueError("No user message found")


def test_extra_instructions_appended():
    inputs = {"x": "hello", "y": "world", "extra_instructions": "Be concise."}
    messages = adapter.format(SigWithExtra, demos=[], inputs=inputs)
    content = get_last_user_content(messages)
    assert content.endswith("Be concise.")


def test_extra_instructions_empty_string():
    # An empty value must not leave a dangling separator at the end.
    inputs = {"x": "hello", "y": "world", "extra_instructions": ""}
    messages = adapter.format(SigWithExtra, demos=[], inputs=inputs)
    content = get_last_user_content(messages)
    assert not content.endswith("\n")


def test_extra_instructions_none():
    inputs = {"x": "hello", "y": "world", "extra_instructions": None}
    messages = adapter.format(SigWithExtra, demos=[], inputs=inputs)
    content = get_last_user_content(messages)
    assert "extra_instructions" not in content.split("\n")[-1]
    # None should mean "no extra instructions", not be rendered as text.
    assert not content.endswith("None")


def test_signature_without_extra_instructions_field():
    inputs = {"x": "hello", "y": "world"}
    messages = adapter.format(SigWithoutExtra, demos=[], inputs=inputs)
    content = get_last_user_content(messages)
    # Should work normally, ending with the output requirements, whose final
    # line names the completion marker.
    assert "[[ ## completed ## ]]" in content.split("\n")[-1]


def test_extra_instructions_after_output_requirements():
    """extra_instructions should come after the 'Respond with...' output requirements."""
    inputs = {"x": "hello", "y": "world", "extra_instructions": "TRAILING"}
    messages = adapter.format(SigWithExtra, demos=[], inputs=inputs)
    content = get_last_user_content(messages)
    respond_pos = content.find("Respond with the corresponding output fields")
    trailing_pos = content.find("TRAILING")
    # str.find returns -1 for a missing substring; check presence first so a
    # missing output-requirements block cannot pass vacuously (-1 < pos).
    assert respond_pos != -1, "output requirements missing from user message"
    assert trailing_pos != -1, "extra_instructions missing from user message"
    assert respond_pos < trailing_pos


def test_no_json_fallback():
    assert adapter.use_json_adapter_fallback is False


def test_extra_instructions_not_in_demos():
    """extra_instructions should not appear as a field in demo messages."""
    demo = {"x": "a", "y": "b", "extra_instructions": "DEMO EXTRA", "z": "out"}
    inputs = {"x": "hello", "y": "world", "extra_instructions": "MAIN EXTRA"}
    messages = adapter.format(SigWithExtra, demos=[demo], inputs=inputs)

    # The demo user message should NOT contain extra_instructions as a field.
    demo_user = messages[1]
    assert demo_user["role"] == "user"
    assert "extra_instructions" not in demo_user["content"]
    assert "DEMO EXTRA" not in demo_user["content"]

    # The last user message should have MAIN EXTRA at the end.
    content = get_last_user_content(messages)
    assert content.endswith("MAIN EXTRA")


def test_extra_instructions_not_in_system_message():
    """extra_instructions should not appear in the system message."""
    inputs = {"x": "hello", "y": "world", "extra_instructions": "Be concise."}
    messages = adapter.format(SigWithExtra, demos=[], inputs=inputs)
    system = messages[0]
    assert system["role"] == "system"
    assert "extra_instructions" not in system["content"]
    # Neither the field name nor its value belongs in the system message.
    assert "Be concise." not in system["content"]