Repositories / more_dspy.git

src/more_dspy/chat_adapter_with_trailing_instructions.py

Clone (read-only): git clone http://git.guha-anderson.com/git/more_dspy.git

Branch
3133 bytes · ed69971c920c
"""A DSPy ChatAdapter that renders an ``extra_instructions`` input as trailing text."""

from typing import Any

from dspy.adapters.chat_adapter import ChatAdapter
from dspy.signatures.signature import Signature
from dspy.utils.callback import BaseCallback

# Name of the optional input field that is removed from the normal field
# rendering and appended to the end of the final user message instead.
EXTRA_INSTRUCTIONS_FIELD = "extra_instructions"


def _strip_extra_instructions(signature: type[Signature]) -> type[Signature]:
    """Return *signature* without its ``extra_instructions`` input field.

    Signatures that do not declare the field are returned unchanged. Otherwise
    a new Signature is built from the remaining fields. Auto-generated
    instructions (which name every field, including ``extra_instructions``)
    are regenerated for the remaining fields; user-provided instructions are
    kept verbatim, even if they mention the field name.
    """
    if EXTRA_INSTRUCTIONS_FIELD not in signature.input_fields:
        return signature

    all_fields = dict(signature.fields)
    remaining = {
        name: field for name, field in all_fields.items() if name != EXTRA_INSTRUCTIONS_FIELD
    }

    # Decide whether the instructions were auto-generated by rebuilding the
    # default text for the ORIGINAL fields (a Signature created without
    # explicit instructions) and comparing exactly. A substring check for the
    # field name would also throw away custom instructions that merely
    # mention extra_instructions.
    default_instructions = Signature(all_fields).instructions
    if signature.instructions == default_instructions:
        # Auto-generated: let Signature regenerate them for the reduced field set.
        return Signature(remaining)
    return Signature(remaining, signature.instructions)


class ChatAdapterWithTrailingInstructions(ChatAdapter):
    """ChatAdapter variant that appends extra_instructions to the end of the user message.

    If the signature has an InputField called ``extra_instructions`` whose value
    is not None or "", its text (converted with ``str``) is appended at the very
    end of the final user message, separated by a newline. JSON adapter fallback
    is disabled.

    The ``extra_instructions`` field is excluded from the system message and from
    the normal field rendering in user messages — it only appears as trailing text.
    """

    def __init__(
        self,
        callbacks: list[BaseCallback] | None = None,
        use_native_function_calling: bool = False,
        native_response_types: list[type[type]] | None = None,
    ):
        """Initialize like ChatAdapter, but with the JSON adapter fallback always off."""
        super().__init__(
            callbacks=callbacks,
            use_native_function_calling=use_native_function_calling,
            native_response_types=native_response_types,
            use_json_adapter_fallback=False,
        )

    def format_field_description(self, signature: type[Signature]) -> str:
        """Describe the signature's fields, omitting ``extra_instructions``."""
        return super().format_field_description(_strip_extra_instructions(signature))

    def format_field_structure(self, signature: type[Signature]) -> str:
        """Render the field structure, omitting ``extra_instructions``."""
        return super().format_field_structure(_strip_extra_instructions(signature))

    def format_task_description(self, signature: type[Signature]) -> str:
        """Render the task description from the signature minus ``extra_instructions``."""
        return super().format_task_description(_strip_extra_instructions(signature))

    def format_user_message_content(
        self,
        signature: type[Signature],
        inputs: dict[str, Any],
        prefix: str = "",
        suffix: str = "",
        main_request: bool = False,
    ) -> str:
        """Format a user message, appending ``extra_instructions`` to the main request.

        The parent renders the message from the stripped signature, so the
        ``extra_instructions`` value never appears as a regular field. Only when
        ``main_request`` is true and the value is neither None nor "" is it
        appended after a newline as the final text of the message.
        """
        extra = inputs.get(EXTRA_INSTRUCTIONS_FIELD)
        content = super().format_user_message_content(
            _strip_extra_instructions(signature),
            inputs,
            prefix=prefix,
            suffix=suffix,
            main_request=main_request,
        )
        if main_request and extra not in (None, ""):
            # f-string formatting stringifies non-str values instead of raising
            # TypeError as plain `+` concatenation would.
            content = f"{content}\n{extra}"
        return content