# pocketflow/utils/call_llm.py

"""Minimal OpenAI Chat Completions helper used by Pocket Flow demos."""
from __future__ import annotations
import json
from typing import Iterable, Optional
# The openai SDK is optional at import time: if it is missing, OpenAI is left as
# None and _get_client() raises a RuntimeError only when a call is attempted.
try:
    from openai import OpenAI
except ImportError: # pragma: no cover - optional dependency
    # None sentinel keeps this module importable without the SDK installed.
    OpenAI = None # type: ignore[misc,assignment]
# The API key comes from a local llm_secrets module that the user must create.
# Unlike the SDK, a missing secrets module is fatal as soon as this file is imported.
try:
    from llm_secrets import OPENAI_API_KEY
except ImportError as exc:
    raise ImportError("Create llm_secrets.py with OPENAI_API_KEY before calling call_llm") from exc
# Shared client instance; stays None until _get_client() builds it.
_client: Optional[OpenAI] = None


def _get_client() -> OpenAI:
    """Return the module-wide OpenAI client, constructing it on the first call.

    Raises:
        RuntimeError: the ``openai`` package could not be imported.
        ValueError: ``OPENAI_API_KEY`` is empty or still the placeholder value.
    """
    global _client
    if OpenAI is None:  # type: ignore[truthy-function]
        raise RuntimeError("Install the 'openai' package to use call_llm")
    if _client is not None:
        return _client  # built by an earlier call; reuse it

    placeholder = "REPLACE_WITH_YOUR_KEY"
    if not OPENAI_API_KEY or OPENAI_API_KEY == placeholder:
        raise ValueError("Set OPENAI_API_KEY in llm_secrets.py before calling call_llm")
    _client = OpenAI(api_key=OPENAI_API_KEY)
    return _client
def call_llm(messages: Iterable[dict] | str, model: str = "gpt-4o-mini") -> str:
    """Run one Chat Completions request and return the assistant's reply text.

    Args:
        messages: a bare prompt string, which is wrapped as a single
            ``{"role": "user", ...}`` message, or an iterable of chat-message
            dicts that is materialised into a list and sent unchanged.
        model: name of the OpenAI chat model to use.

    Returns:
        The first choice's message content with surrounding whitespace
        stripped, or ``""`` when that content is empty or ``None``.
    """
    # Resolve the client first so configuration errors surface before the
    # caller's message iterable is consumed.
    client = _get_client()
    if isinstance(messages, str):
        payload = [{"role": "user", "content": messages}]
    else:
        payload = list(messages)
    completion = client.chat.completions.create(model=model, messages=payload)
    first_choice = completion.choices[0]
    reply = first_choice.message.content or ""
    return reply.strip()
def call_llm_json(messages: Iterable[dict] | str, model: str = "gpt-4o-mini") -> dict:
    """Call the LLM and decode the JSON object contained in its reply.

    The reply is first parsed as the span from its first ``{`` to its last
    ``}``, which handles a single object wrapped in prose or a Markdown code
    fence. If that span is not valid JSON (e.g. the reply also contains
    unrelated braces, or text after the object that itself contains ``}``),
    every ``{`` is tried in order as the start of an object and the first one
    that decodes is returned; any text following it is ignored.

    Args:
        messages: a prompt string or an iterable of chat-message dicts,
            forwarded unchanged to ``call_llm``.
        model: chat model name, forwarded to ``call_llm``.

    Returns:
        The decoded JSON object.

    Raises:
        ValueError: the reply contains no ``{`` or no ``}`` at all.
        json.JSONDecodeError: no ``{`` starts a valid JSON object; the error
            raised is the one from parsing the first-to-last-brace span.
    """
    raw = call_llm(messages, model=model)
    start = raw.find("{")
    end = raw.rfind("}")
    if start == -1 or end == -1:
        raise ValueError(f"LLM response does not contain JSON: {raw}")

    # Fast path: the widest brace-delimited span. Every reply that parsed
    # before still parses to the same object.
    try:
        return json.loads(raw[start : end + 1])
    except json.JSONDecodeError as exc:
        span_error = exc

    # Fallback: raw_decode parses one JSON value from the start of its input
    # and ignores whatever follows, so stray braces or trailing prose around
    # the real object no longer break parsing.
    decoder = json.JSONDecoder()
    candidate = start
    while candidate != -1:
        try:
            parsed, _ = decoder.raw_decode(raw[candidate:])
        except json.JSONDecodeError:
            pass  # not an object start; try the next "{"
        else:
            return parsed
        candidate = raw.find("{", candidate + 1)

    # Nothing decoded: surface the same error type/message as the fast path.
    raise span_error
# Public API of this module (the names exposed by ``from ... import *``).
__all__ = ["call_llm", "call_llm_json"]