Repositories / more_dspy.git
src/more_dspy/chat_adapter_with_trailing_instructions.py
Clone (read-only): git clone http://git.guha-anderson.com/git/more_dspy.git
from typing import Any
from dspy.adapters.chat_adapter import ChatAdapter
from dspy.signatures.signature import Signature
from dspy.utils.callback import BaseCallback
# Name of the optional InputField whose value is not rendered as a regular
# field; instead it is appended as trailing text to the final user message.
EXTRA_INSTRUCTIONS_FIELD = "extra_instructions"
def _strip_extra_instructions(signature: type[Signature]) -> type[Signature]:
    """Return *signature* with the ``extra_instructions`` input field removed.

    If *signature* has no input field named ``extra_instructions`` it is
    returned unchanged. Otherwise a new signature is built from all remaining
    fields.

    The instructions are handled in one of two ways:

    * If they are dspy's auto-generated default for *signature*, they are
      regenerated for the reduced field set. The default text names every
      field, ``extra_instructions`` included.
    * Otherwise they are user-written and are passed through verbatim, even if
      they mention ``extra_instructions`` by name.
    """
    if EXTRA_INSTRUCTIONS_FIELD not in signature.input_fields:
        return signature

    remaining = {
        name: field
        for name, field in signature.fields.items()
        if name != EXTRA_INSTRUCTIONS_FIELD
    }

    # Ask dspy what it would auto-generate for the *full* field set and compare.
    # A substring search for the field name would wrongly discard custom
    # instructions such as "Follow extra_instructions when given".
    default_instructions = Signature(dict(signature.fields)).instructions
    if signature.instructions == default_instructions:
        # Omitting instructions makes dspy regenerate them for `remaining`.
        return Signature(remaining)
    return Signature(remaining, signature.instructions)
class ChatAdapterWithTrailingInstructions(ChatAdapter):
    """ChatAdapter variant that appends extra_instructions to the end of the user message.

    If the signature has an InputField called ``extra_instructions`` whose value
    is not None or "", its text is appended at the very end of the final user
    message (separated by a newline). Non-string values are converted with
    ``str()`` first. JSON adapter fallback is disabled.

    The ``extra_instructions`` field is excluded from the system message and from
    the normal field rendering in user messages — it only appears as trailing text.
    """

    def __init__(
        self,
        callbacks: list[BaseCallback] | None = None,
        use_native_function_calling: bool = False,
        native_response_types: list[type] | None = None,
    ):
        """Create the adapter, forwarding all arguments to ``ChatAdapter``.

        Args:
            callbacks: Passed through to ``ChatAdapter``.
            use_native_function_calling: Passed through to ``ChatAdapter``.
            native_response_types: Passed through to ``ChatAdapter``.
        """
        # The JSON fallback is always turned off. NOTE(review): presumably
        # because a fallback adapter would not apply the extra_instructions
        # handling defined in this class — confirm.
        super().__init__(
            callbacks=callbacks,
            use_native_function_calling=use_native_function_calling,
            native_response_types=native_response_types,
            use_json_adapter_fallback=False,
        )

    def format_field_description(self, signature: type[Signature]) -> str:
        """Describe the signature's fields, omitting ``extra_instructions``."""
        return super().format_field_description(_strip_extra_instructions(signature))

    def format_field_structure(self, signature: type[Signature]) -> str:
        """Describe the message structure, omitting ``extra_instructions``."""
        return super().format_field_structure(_strip_extra_instructions(signature))

    def format_task_description(self, signature: type[Signature]) -> str:
        """Describe the task using a signature without ``extra_instructions``."""
        return super().format_task_description(_strip_extra_instructions(signature))

    def format_user_message_content(
        self,
        signature: type[Signature],
        inputs: dict[str, Any],
        prefix: str = "",
        suffix: str = "",
        main_request: bool = False,
    ) -> str:
        """Render a user message, appending ``extra_instructions`` to the main request.

        The parent renders the message from a signature that lacks the
        ``extra_instructions`` field. The value is appended only when
        ``main_request`` is true and the value is neither ``None`` nor ``""``.
        It goes after a newline, following everything the parent produced
        (including ``suffix``).

        Returns:
            The rendered user-message text.
        """
        extra = inputs.get(EXTRA_INSTRUCTIONS_FIELD)
        content = super().format_user_message_content(
            _strip_extra_instructions(signature),
            inputs,
            prefix=prefix,
            suffix=suffix,
            main_request=main_request,
        )
        # Demos and history messages (main_request=False) never get the trailer.
        if not main_request or extra is None or extra == "":
            return content
        # str() so a non-string value (e.g. a list) doesn't raise TypeError
        # on concatenation.
        return content + "\n" + str(extra)