Repositories / more_dspy.git
more_dspy.git
Clone (read-only): git clone http://git.guha-anderson.com/git/more_dspy.git
@@ -0,0 +1,7 @@ +from more_dspy.chat_adapter_with_trailing_instructions import ChatAdapterWithTrailingInstructions + +__all__ = ["ChatAdapterWithTrailingInstructions"] + + +def main() -> None: + print("Hello from more-dspy!")
@@ -0,0 +1,79 @@ +from typing import Any + +from dspy.adapters.chat_adapter import ChatAdapter +from dspy.signatures.signature import Signature +from dspy.utils.callback import BaseCallback + +EXTRA_INSTRUCTIONS_FIELD = "extra_instructions" + + +def _strip_extra_instructions(signature: type[Signature]) -> type[Signature]: + if EXTRA_INSTRUCTIONS_FIELD not in signature.input_fields: + return signature + fields = {k: v for k, v in signature.fields.items() if k != EXTRA_INSTRUCTIONS_FIELD} + # Build a fresh Signature so auto-generated instructions are regenerated + # without mentioning extra_instructions. Custom (user-provided) instructions + # won't reference the field name, so they're safe to pass through. + stripped = Signature(fields, signature.instructions) + # If the instructions were auto-generated, they'll mention extra_instructions. + # Regenerate by creating without explicit instructions. + if EXTRA_INSTRUCTIONS_FIELD in stripped.instructions: + stripped = Signature(fields) + return stripped + + +class ChatAdapterWithTrailingInstructions(ChatAdapter): + """ChatAdapter variant that appends extra_instructions to the end of the user message. + + If the signature has an InputField called ``extra_instructions`` whose value + is not None or "", its text is appended at the very end of the final user + message (separated by a newline). JSON adapter fallback is disabled. + + The ``extra_instructions`` field is excluded from the system message and from + the normal field rendering in user messages — it only appears as trailing text. + """ + + def __init__( + self, + callbacks: list[BaseCallback] | None = None, + use_native_function_calling: bool = False, + native_response_types: list[type[type]] | None = None, + ): + super().__init__( + callbacks=callbacks, + use_native_function_calling=use_native_function_calling, + native_response_types=native_response_types, + use_json_adapter_fallback=False, + ) + + def format_field_description(self, signature: type[Signature]) -> str: + return super().format_field_description(_strip_extra_instructions(signature)) + + def format_field_structure(self, signature: type[Signature]) -> str: + return super().format_field_structure(_strip_extra_instructions(signature)) + + def format_task_description(self, signature: type[Signature]) -> str: + return super().format_task_description(_strip_extra_instructions(signature)) + + def format_user_message_content( + self, + signature: type[Signature], + inputs: dict[str, Any], + prefix: str = "", + suffix: str = "", + main_request: bool = False, + ) -> str: + extra = inputs.get(EXTRA_INSTRUCTIONS_FIELD) + + content = super().format_user_message_content( + _strip_extra_instructions(signature), + inputs, + prefix=prefix, + suffix=suffix, + main_request=main_request, + ) + + if main_request and extra not in (None, ""): + content = content + "\n" + extra + + return content
@@ -0,0 +1,94 @@ +import dspy + +from more_dspy import ChatAdapterWithTrailingInstructions + + +class SigWithExtra(dspy.Signature): + x: str = dspy.InputField() + y: str = dspy.InputField() + extra_instructions: str = dspy.InputField() + z: str = dspy.OutputField() + + +class SigWithoutExtra(dspy.Signature): + x: str = dspy.InputField() + y: str = dspy.InputField() + z: str = dspy.OutputField() + + +adapter = ChatAdapterWithTrailingInstructions() + + +def get_last_user_content(messages: list[dict]) -> str: + for msg in reversed(messages): + if msg["role"] == "user": + return msg["content"] + raise ValueError("No user message found") + + +def test_extra_instructions_appended(): + inputs = {"x": "hello", "y": "world", "extra_instructions": "Be concise."} + messages = adapter.format(SigWithExtra, demos=[], inputs=inputs) + content = get_last_user_content(messages) + assert content.endswith("Be concise.") + + +def test_extra_instructions_empty_string(): + inputs = {"x": "hello", "y": "world", "extra_instructions": ""} + messages = adapter.format(SigWithExtra, demos=[], inputs=inputs) + content = get_last_user_content(messages) + assert not content.endswith("\n") + + +def test_extra_instructions_none(): + inputs = {"x": "hello", "y": "world", "extra_instructions": None} + messages = adapter.format(SigWithExtra, demos=[], inputs=inputs) + content = get_last_user_content(messages) + assert "extra_instructions" not in content.split("\n")[-1] + + +def test_signature_without_extra_instructions_field(): + inputs = {"x": "hello", "y": "world"} + messages = adapter.format(SigWithoutExtra, demos=[], inputs=inputs) + content = get_last_user_content(messages) + # Should work normally, ending with the output requirements + assert "[[ ## completed ## ]]" in content + + +def test_extra_instructions_after_output_requirements(): + """extra_instructions should come after the 'Respond with...' output requirements.""" + inputs = {"x": "hello", "y": "world", "extra_instructions": "TRAILING"} + messages = adapter.format(SigWithExtra, demos=[], inputs=inputs) + content = get_last_user_content(messages) + respond_pos = content.find("Respond with the corresponding output fields") + trailing_pos = content.find("TRAILING") + assert respond_pos < trailing_pos + + +def test_no_json_fallback(): + assert adapter.use_json_adapter_fallback is False + + +def test_extra_instructions_not_in_demos(): + """extra_instructions should not appear as a field in demo messages.""" + demo = {"x": "a", "y": "b", "extra_instructions": "DEMO EXTRA", "z": "out"} + inputs = {"x": "hello", "y": "world", "extra_instructions": "MAIN EXTRA"} + messages = adapter.format(SigWithExtra, demos=[demo], inputs=inputs) + + # The demo user message should NOT contain extra_instructions as a field + demo_user = messages[1] + assert demo_user["role"] == "user" + assert "extra_instructions" not in demo_user["content"] + assert "DEMO EXTRA" not in demo_user["content"] + + # The last user message should have MAIN EXTRA at the end + content = get_last_user_content(messages) + assert content.endswith("MAIN EXTRA") + + +def test_extra_instructions_not_in_system_message(): + """extra_instructions should not appear in the system message.""" + inputs = {"x": "hello", "y": "world", "extra_instructions": "Be concise."} + messages = adapter.format(SigWithExtra, demos=[], inputs=inputs) + system_msg = messages[0]["content"] + assert "extra_instructions" not in system_msg