"""Minimal OpenAI Chat Completions helper used by Pocket Flow demos."""

from __future__ import annotations

import json
from typing import Iterable, Optional

try:
    from openai import OpenAI
except ImportError:  # pragma: no cover - optional dependency
    OpenAI = None  # type: ignore[misc,assignment]

try:
    from llm_secrets import OPENAI_API_KEY
except ImportError as exc:
    raise ImportError("Create llm_secrets.py with OPENAI_API_KEY before calling call_llm") from exc

_client: Optional[OpenAI] = None


def _get_client() -> OpenAI:
    """Create the OpenAI client on first use and reuse it for later calls."""
    global _client
    if OpenAI is None:  # type: ignore[truthy-function]
        raise RuntimeError("Install the 'openai' package to use call_llm")
    if _client is None:
        if not OPENAI_API_KEY or OPENAI_API_KEY == "REPLACE_WITH_YOUR_KEY":
            raise ValueError("Set OPENAI_API_KEY in llm_secrets.py before calling call_llm")
        _client = OpenAI(api_key=OPENAI_API_KEY)
    return _client


def call_llm(messages: Iterable[dict] | str, model: str = "gpt-4o-mini") -> str:
    """Send a prompt or list of chat messages to OpenAI and return the text reply."""
    client = _get_client()
    chat_messages = (
        [{"role": "user", "content": messages}]
        if isinstance(messages, str)
        else list(messages)
    )
    response = client.chat.completions.create(model=model, messages=chat_messages)
    message = response.choices[0].message.content or ""
    return message.strip()


def call_llm_json(messages: Iterable[dict] | str, model: str = "gpt-4o-mini") -> dict:
    """Convenience wrapper that expects a JSON object in the response."""
    raw = call_llm(messages, model=model)
    # Slice from the first "{" to the last "}" so stray prose or Markdown code
    # fences around the JSON object do not break parsing.
    start = raw.find("{")
    end = raw.rfind("}")
    if start == -1 or end == -1:
        raise ValueError(f"LLM response does not contain JSON: {raw}")
    return json.loads(raw[start : end + 1])


__all__ = ["call_llm", "call_llm_json"]
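

# Illustrative smoke test (not part of the original helper): run this module
# directly to exercise both wrappers. It assumes llm_secrets.py holds a valid
# OPENAI_API_KEY and that making a real API request is acceptable; the prompts
# below are placeholders chosen for the demo only.
if __name__ == "__main__":
    print(call_llm("Say hello in five words or fewer."))
    print(
        call_llm_json(
            "Return a JSON object with keys 'language' and 'mascot' describing Python."
        )
    )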