Repositories / more_dspy.git

more_dspy.git

Clone (read-only): git clone http://git.guha-anderson.com/git/more_dspy.git

Branch

Add format_prompt and parse_response functions

These allow calling the DSPy adapter directly to inspect raw prompts
and parse LM responses without configuring or invoking an LM.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Author
Arjun Guha <a.guha@northeastern.edu>
Date
2026-04-02 09:00:17 -0400
Commit
3fc95002f3da090569545fe5cbd8dfa43709e989
src/more_dspy/__init__.py
index 6c2b807..9984f1f 100644
--- a/src/more_dspy/__init__.py
+++ b/src/more_dspy/__init__.py
@@ -1,6 +1,55 @@
+from typing import Optional
+
+import dspy
+from dspy.adapters.chat_adapter import ChatAdapter
+from dspy.dsp.utils.settings import settings
+from dspy.predict.predict import Predict
+from dspy.utils.exceptions import AdapterParseError
+
 from more_dspy.chat_adapter_with_trailing_instructions import ChatAdapterWithTrailingInstructions
 
-__all__ = ["ChatAdapterWithTrailingInstructions"]
+__all__ = [
+    "ChatAdapterWithTrailingInstructions",
+    "format_prompt",
+    "parse_response",
+]
+
+
+def _get_predict(module: dspy.Module) -> Predict:
+    """Extract the inner Predict module from any DSPy module."""
+    if isinstance(module, Predict):
+        return module
+    if hasattr(module, "predict") and isinstance(module.predict, Predict):
+        return module.predict
+    raise TypeError(f"Cannot extract a Predict module from {type(module).__name__}")
+
+
+def format_prompt(module: dspy.Module, **kwargs) -> list[dict]:
+    """Produce the chat messages that would be sent to the LM.
+
+    Uses the adapter from ``dspy.settings`` (falling back to ``ChatAdapter``).
+    No LM needs to be configured.
+    """
+    predict = _get_predict(module)
+    adapter = settings.adapter or ChatAdapter()
+    signature = predict.signature
+    demos = predict.demos
+    return adapter.format(signature=signature, demos=demos, inputs=kwargs)
+
+
+def parse_response(module: dspy.Module, response: str) -> Optional[dspy.Prediction]:
+    """Parse a raw LM response string into a ``dspy.Prediction``.
+
+    Returns ``None`` if the response cannot be parsed.
+    """
+    predict = _get_predict(module)
+    adapter = settings.adapter or ChatAdapter()
+    signature = predict.signature
+    try:
+        fields = adapter.parse(signature, response)
+    except (AdapterParseError, Exception):
+        return None
+    return dspy.Prediction(**fields)
 
 
 def main() -> None:
tests/test_format_and_parse.py
new file mode 100644
index 0000000..4da5186
--- /dev/null
+++ b/tests/test_format_and_parse.py
@@ -0,0 +1,209 @@
+import dspy
+from dspy.adapters.chat_adapter import ChatAdapter
+from dspy.adapters.json_adapter import JSONAdapter
+
+from more_dspy import format_prompt, parse_response
+
+
+class QA(dspy.Signature):
+    """Answer the question."""
+
+    # NOTE: the docstring above is the LM instruction text, not documentation;
+    # changing it would change the formatted prompt.
+    question: str = dspy.InputField()
+    answer: str = dspy.OutputField()
+
+
+class Classify(dspy.Signature):
+    """Classify the text."""
+
+    text: str = dspy.InputField()
+    label: str = dspy.OutputField()
+    # float annotation: adapters cast the raw "0.95" completion text to 0.95
+    # (exercised by the multiple-output tests below).
+    confidence: float = dspy.OutputField()
+
+
+# ---------------------------------------------------------------------------
+# ChatAdapter tests
+# ---------------------------------------------------------------------------
+
+
+def test_format_prompt_predict_chat():
+    with dspy.context(adapter=ChatAdapter()):
+        module = dspy.Predict(QA)
+        messages = format_prompt(module, question="What is 2+2?")
+
+    assert isinstance(messages, list)
+    assert messages[0]["role"] == "system"
+    # The last user message should contain the question value
+    user_msgs = [m for m in messages if m["role"] == "user"]
+    assert any("What is 2+2?" in m["content"] for m in user_msgs)
+
+
+def test_format_prompt_cot_chat():
+    with dspy.context(adapter=ChatAdapter()):
+        module = dspy.ChainOfThought(QA)
+        messages = format_prompt(module, question="What is 2+2?")
+
+    user_content = [m for m in messages if m["role"] == "user"][-1]["content"]
+    # ChainOfThought adds a reasoning field
+    assert "reasoning" in messages[0]["content"].lower() or "reasoning" in user_content.lower()
+
+
+def test_parse_response_predict_chat():
+    with dspy.context(adapter=ChatAdapter()):
+        module = dspy.Predict(QA)
+        completion = "[[ ## answer ## ]]\n4\n\n[[ ## completed ## ]]"
+        pred = parse_response(module, completion)
+
+    assert pred is not None
+    assert pred.answer == "4"
+
+
+def test_parse_response_cot_chat():
+    with dspy.context(adapter=ChatAdapter()):
+        module = dspy.ChainOfThought(QA)
+        completion = (
+            "[[ ## reasoning ## ]]\n2 plus 2 equals 4\n\n"
+            "[[ ## answer ## ]]\n4\n\n"
+            "[[ ## completed ## ]]"
+        )
+        pred = parse_response(module, completion)
+
+    assert pred is not None
+    assert pred.answer == "4"
+    assert pred.reasoning == "2 plus 2 equals 4"
+
+
+def test_parse_response_bad_input_chat():
+    with dspy.context(adapter=ChatAdapter()):
+        module = dspy.Predict(QA)
+        pred = parse_response(module, "just some random text")
+
+    assert pred is None
+
+
+def test_parse_response_multiple_outputs_chat():
+    with dspy.context(adapter=ChatAdapter()):
+        module = dspy.Predict(Classify)
+        completion = (
+            "[[ ## label ## ]]\npositive\n\n"
+            "[[ ## confidence ## ]]\n0.95\n\n"
+            "[[ ## completed ## ]]"
+        )
+        pred = parse_response(module, completion)
+
+    assert pred is not None
+    assert pred.label == "positive"
+    assert pred.confidence == 0.95
+
+
+# ---------------------------------------------------------------------------
+# JSONAdapter tests
+# ---------------------------------------------------------------------------
+
+
+def test_format_prompt_predict_json():
+    with dspy.context(adapter=JSONAdapter()):
+        module = dspy.Predict(QA)
+        messages = format_prompt(module, question="What is 2+2?")
+
+    assert isinstance(messages, list)
+    assert messages[0]["role"] == "system"
+    user_msgs = [m for m in messages if m["role"] == "user"]
+    assert any("What is 2+2?" in m["content"] for m in user_msgs)
+    # JSONAdapter mentions JSON in its output requirements
+    last_user = user_msgs[-1]["content"]
+    assert "JSON" in last_user or "json" in last_user
+
+
+def test_format_prompt_cot_json():
+    with dspy.context(adapter=JSONAdapter()):
+        module = dspy.ChainOfThought(QA)
+        messages = format_prompt(module, question="What is 2+2?")
+
+    system_content = messages[0]["content"]
+    assert "reasoning" in system_content.lower()
+
+
+def test_parse_response_predict_json():
+    with dspy.context(adapter=JSONAdapter()):
+        module = dspy.Predict(QA)
+        completion = '{"answer": "4"}'
+        pred = parse_response(module, completion)
+
+    assert pred is not None
+    assert pred.answer == "4"
+
+
+def test_parse_response_cot_json():
+    with dspy.context(adapter=JSONAdapter()):
+        module = dspy.ChainOfThought(QA)
+        completion = '{"reasoning": "2 plus 2 equals 4", "answer": "4"}'
+        pred = parse_response(module, completion)
+
+    assert pred is not None
+    assert pred.answer == "4"
+    assert pred.reasoning == "2 plus 2 equals 4"
+
+
+def test_parse_response_bad_input_json():
+    with dspy.context(adapter=JSONAdapter()):
+        module = dspy.Predict(QA)
+        pred = parse_response(module, "not json at all {{{")
+
+    assert pred is None
+
+
+def test_parse_response_multiple_outputs_json():
+    with dspy.context(adapter=JSONAdapter()):
+        module = dspy.Predict(Classify)
+        completion = '{"label": "positive", "confidence": 0.95}'
+        pred = parse_response(module, completion)
+
+    assert pred is not None
+    assert pred.label == "positive"
+    assert pred.confidence == 0.95
+
+
+# ---------------------------------------------------------------------------
+# Adapter from dspy.settings (default)
+# ---------------------------------------------------------------------------
+
+
+def test_format_prompt_uses_settings_adapter():
+    """format_prompt should respect the adapter set via dspy.configure/context."""
+    with dspy.context(adapter=JSONAdapter()):
+        module = dspy.Predict(QA)
+        messages = format_prompt(module, question="hi")
+        last_user = [m for m in messages if m["role"] == "user"][-1]["content"]
+        assert "JSON" in last_user or "json" in last_user
+
+    with dspy.context(adapter=ChatAdapter()):
+        messages = format_prompt(module, question="hi")
+        last_user = [m for m in messages if m["role"] == "user"][-1]["content"]
+        assert "[[ ## completed ## ]]" in last_user
+
+
+def test_roundtrip_chat():
+    """format_prompt + simulated LM response + parse_response should roundtrip."""
+    with dspy.context(adapter=ChatAdapter()):
+        module = dspy.Predict(QA)
+        messages = format_prompt(module, question="Capital of France?")
+        assert len(messages) >= 2
+
+        response = "[[ ## answer ## ]]\nParis\n\n[[ ## completed ## ]]"
+        pred = parse_response(module, response)
+
+    assert pred is not None
+    assert pred.answer == "Paris"
+
+
+def test_roundtrip_json():
+    with dspy.context(adapter=JSONAdapter()):
+        module = dspy.Predict(QA)
+        messages = format_prompt(module, question="Capital of France?")
+        assert len(messages) >= 2
+
+        response = '{"answer": "Paris"}'
+        pred = parse_response(module, response)
+
+    assert pred is not None
+    assert pred.answer == "Paris"