Repositories / more_dspy.git
more_dspy.git
Clone (read-only): git clone http://git.guha-anderson.com/git/more_dspy.git
@@ -1,6 +1,55 @@ +from typing import Optional + +import dspy +from dspy.adapters.chat_adapter import ChatAdapter +from dspy.dsp.utils.settings import settings +from dspy.predict.predict import Predict +from dspy.utils.exceptions import AdapterParseError + from more_dspy.chat_adapter_with_trailing_instructions import ChatAdapterWithTrailingInstructions -__all__ = ["ChatAdapterWithTrailingInstructions"] +__all__ = [ + "ChatAdapterWithTrailingInstructions", + "format_prompt", + "parse_response", +] + + +def _get_predict(module: dspy.Module) -> Predict: + """Extract the inner Predict module from any DSPy module.""" + if isinstance(module, Predict): + return module + if hasattr(module, "predict") and isinstance(module.predict, Predict): + return module.predict + raise TypeError(f"Cannot extract a Predict module from {type(module).__name__}") + + +def format_prompt(module: dspy.Module, **kwargs) -> list[dict]: + """Produce the chat messages that would be sent to the LM. + + Uses the adapter from ``dspy.settings`` (falling back to ``ChatAdapter``). + No LM needs to be configured. + """ + predict = _get_predict(module) + adapter = settings.adapter or ChatAdapter() + signature = predict.signature + demos = predict.demos + return adapter.format(signature=signature, demos=demos, inputs=kwargs) + + +def parse_response(module: dspy.Module, response: str) -> Optional[dspy.Prediction]: + """Parse a raw LM response string into a ``dspy.Prediction``. + + Returns ``None`` if the response cannot be parsed. + """ + predict = _get_predict(module) + adapter = settings.adapter or ChatAdapter() + signature = predict.signature + try: + fields = adapter.parse(signature, response) + except (AdapterParseError, Exception): + return None + return dspy.Prediction(**fields) def main() -> None:
# Diff hunk @@ -0,0 +1,209 @@ -- newly added test module.
import dspy
from dspy.adapters.chat_adapter import ChatAdapter
from dspy.adapters.json_adapter import JSONAdapter

from more_dspy import format_prompt, parse_response


class QA(dspy.Signature):
    """Answer the question."""

    question: str = dspy.InputField()
    answer: str = dspy.OutputField()


class Classify(dspy.Signature):
    """Classify the text."""

    text: str = dspy.InputField()
    label: str = dspy.OutputField()
    confidence: float = dspy.OutputField()


def _user_texts(chat):
    """Return the ``content`` of every ``user``-role message, in order."""
    return [turn["content"] for turn in chat if turn["role"] == "user"]


# ---------------------------------------------------------------------------
# ChatAdapter tests
# ---------------------------------------------------------------------------


def test_format_prompt_predict_chat():
    """A plain Predict yields a system-first chat that carries the question."""
    with dspy.context(adapter=ChatAdapter()):
        program = dspy.Predict(QA)
        chat = format_prompt(program, question="What is 2+2?")

    assert isinstance(chat, list)
    assert chat[0]["role"] == "system"
    # Some user message must carry the question value.
    assert any("What is 2+2?" in text for text in _user_texts(chat))


def test_format_prompt_cot_chat():
    """ChainOfThought prompts mention the extra reasoning field."""
    with dspy.context(adapter=ChatAdapter()):
        program = dspy.ChainOfThought(QA)
        chat = format_prompt(program, question="What is 2+2?")

    final_user = _user_texts(chat)[-1]
    system_text = chat[0]["content"]
    # ChainOfThought adds a reasoning field
    assert "reasoning" in system_text.lower() or "reasoning" in final_user.lower()


def test_parse_response_predict_chat():
    """A well-formed chat completion parses into the answer field."""
    raw = "[[ ## answer ## ]]\n4\n\n[[ ## completed ## ]]"
    with dspy.context(adapter=ChatAdapter()):
        program = dspy.Predict(QA)
        prediction = parse_response(program, raw)

    assert prediction is not None
    assert prediction.answer == "4"


def test_parse_response_cot_chat():
    """Reasoning and answer are both recovered from a chat completion."""
    raw = (
        "[[ ## reasoning ## ]]\n2 plus 2 equals 4\n\n"
        "[[ ## answer ## ]]\n4\n\n"
        "[[ ## completed ## ]]"
    )
    with dspy.context(adapter=ChatAdapter()):
        program = dspy.ChainOfThought(QA)
        prediction = parse_response(program, raw)

    assert prediction is not None
    assert prediction.answer == "4"
    assert prediction.reasoning == "2 plus 2 equals 4"


def test_parse_response_bad_input_chat():
    """Text without any field markers parses to None."""
    with dspy.context(adapter=ChatAdapter()):
        program = dspy.Predict(QA)
        prediction = parse_response(program, "just some random text")

    assert prediction is None


def test_parse_response_multiple_outputs_chat():
    """Several output fields, including a float, are parsed from chat format."""
    raw = (
        "[[ ## label ## ]]\npositive\n\n"
        "[[ ## confidence ## ]]\n0.95\n\n"
        "[[ ## completed ## ]]"
    )
    with dspy.context(adapter=ChatAdapter()):
        program = dspy.Predict(Classify)
        prediction = parse_response(program, raw)

    assert prediction is not None
    assert prediction.label == "positive"
    assert prediction.confidence == 0.95


# ---------------------------------------------------------------------------
# JSONAdapter tests
# ---------------------------------------------------------------------------


def test_format_prompt_predict_json():
    """JSONAdapter prompts carry the question and mention JSON."""
    with dspy.context(adapter=JSONAdapter()):
        program = dspy.Predict(QA)
        chat = format_prompt(program, question="What is 2+2?")

    assert isinstance(chat, list)
    assert chat[0]["role"] == "system"
    user_texts = _user_texts(chat)
    assert any("What is 2+2?" in text for text in user_texts)
    # JSONAdapter mentions JSON in its output requirements
    final_user = user_texts[-1]
    assert "JSON" in final_user or "json" in final_user


def test_format_prompt_cot_json():
    """The system message names the reasoning field under JSONAdapter."""
    with dspy.context(adapter=JSONAdapter()):
        program = dspy.ChainOfThought(QA)
        chat = format_prompt(program, question="What is 2+2?")

    system_text = chat[0]["content"]
    assert "reasoning" in system_text.lower()


def test_parse_response_predict_json():
    """A JSON object with the answer key parses into the answer field."""
    raw = '{"answer": "4"}'
    with dspy.context(adapter=JSONAdapter()):
        program = dspy.Predict(QA)
        prediction = parse_response(program, raw)

    assert prediction is not None
    assert prediction.answer == "4"


def test_parse_response_cot_json():
    """Reasoning and answer are both recovered from a JSON completion."""
    raw = '{"reasoning": "2 plus 2 equals 4", "answer": "4"}'
    with dspy.context(adapter=JSONAdapter()):
        program = dspy.ChainOfThought(QA)
        prediction = parse_response(program, raw)

    assert prediction is not None
    assert prediction.answer == "4"
    assert prediction.reasoning == "2 plus 2 equals 4"


def test_parse_response_bad_input_json():
    """Malformed JSON parses to None."""
    with dspy.context(adapter=JSONAdapter()):
        program = dspy.Predict(QA)
        prediction = parse_response(program, "not json at all {{{")

    assert prediction is None


def test_parse_response_multiple_outputs_json():
    """Several output fields, including a float, are parsed from JSON."""
    raw = '{"label": "positive", "confidence": 0.95}'
    with dspy.context(adapter=JSONAdapter()):
        program = dspy.Predict(Classify)
        prediction = parse_response(program, raw)

    assert prediction is not None
    assert prediction.label == "positive"
    assert prediction.confidence == 0.95


# ---------------------------------------------------------------------------
# Adapter from dspy.settings (default)
# ---------------------------------------------------------------------------


def test_format_prompt_uses_settings_adapter():
    """format_prompt should respect the adapter set via dspy.configure/context."""

    def final_user_text(chat):
        # Content of the last message whose role is "user".
        user_turns = [turn for turn in chat if turn["role"] == "user"]
        return user_turns[-1]["content"]

    with dspy.context(adapter=JSONAdapter()):
        program = dspy.Predict(QA)
        json_prompt = final_user_text(format_prompt(program, question="hi"))
        assert "JSON" in json_prompt or "json" in json_prompt

    # Same module, different ambient adapter: the prompt format must follow.
    with dspy.context(adapter=ChatAdapter()):
        chat_prompt = final_user_text(format_prompt(program, question="hi"))
        assert "[[ ## completed ## ]]" in chat_prompt


def test_roundtrip_chat():
    """format_prompt + simulated LM response + parse_response should roundtrip."""
    simulated = "[[ ## answer ## ]]\nParis\n\n[[ ## completed ## ]]"
    with dspy.context(adapter=ChatAdapter()):
        program = dspy.Predict(QA)
        chat = format_prompt(program, question="Capital of France?")
        assert len(chat) >= 2

        prediction = parse_response(program, simulated)

    assert prediction is not None
    assert prediction.answer == "Paris"


def test_roundtrip_json():
    """Same roundtrip as the chat variant, under JSONAdapter."""
    simulated = '{"answer": "Paris"}'
    with dspy.context(adapter=JSONAdapter()):
        program = dspy.Predict(QA)
        chat = format_prompt(program, question="Capital of France?")
        assert len(chat) >= 2

        prediction = parse_response(program, simulated)

    assert prediction is not None
    assert prediction.answer == "Paris"