Files
AI-upflux-docprocessor/code/utils/langgraph_agent.py

207 lines
7.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Simple LangGraph agent using AWS Bedrock.
This agent demonstrates a basic ReAct-style agent with tool calling capabilities.
"""
import boto3
import json
import asyncio
import yaml
from pathlib import Path
from typing import Annotated, TypedDict, Literal
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from langchain_aws import ChatBedrock
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage
from langchain_core.tools import tool
from langfuse import Langfuse
from langfuse.langchain import CallbackHandler
from utils.secrets_manager import SECRETS
from utils.config import AWS_REGION, BEDROCK_MODEL_ARN, LANGFUSE_HOST
# Module-level Langfuse client, used to flush traces after each agent run.
# Credentials come from AWS Secrets Manager; host from app config.
langfuse = Langfuse(
    secret_key=SECRETS["LANGFUSE-SECRET-KEY"],
    public_key=SECRETS["LANGFUSE-PUBLIC-KEY"],
    host=LANGFUSE_HOST,
)
# Load per-service rule texts once at import time from rules.yaml
# (located next to this module).
_RULES_PATH = Path(__file__).parent / "rules.yaml"
with open(_RULES_PATH, encoding="utf-8") as _f:
    _data = yaml.safe_load(_f)
# Both maps are keyed by service code; values are prompt text fragments:
# RULES holds auto-accept criteria, MIN_DOC the minimum required documentation.
RULES: dict[str, str] = _data["rules"]
MIN_DOC: dict[str, str] = _data["min_doc"]
# Define the agent state
class AgentState(TypedDict):
    """Graph state: a message list merged via LangGraph's add_messages reducer."""

    messages: Annotated[list, add_messages]
def get_bedrock_client():
    """Return a boto3 Bedrock runtime client for the configured AWS region."""
    client = boto3.client("bedrock-runtime", region_name=AWS_REGION)
    return client
def create_llm():
    """Instantiate the Bedrock chat model (Anthropic provider) from app config."""
    llm_kwargs = {
        "model_id": BEDROCK_MODEL_ARN,
        "region_name": AWS_REGION,
        "provider": "anthropic",
    }
    return ChatBedrock(**llm_kwargs)
def create_agent(file_content: str = ""):
    """Create and compile the LangGraph ReAct-style agent.

    Args:
        file_content: OCR text returned verbatim by the ``check`` tool. It is
            captured in a closure, so each agent instance serves one document set.

    Returns:
        A compiled LangGraph workflow, ready for ``invoke``/``ainvoke``.
    """

    # Define the check tool as a closure so it can capture file_content.
    # The docstring below is the tool description the LLM sees — keep as-is.
    @tool
    def check(expression: str) -> str:
        """Retrieves the values of the files associated with the input for additional information, if the json is not enough"""
        # NOTE: `expression` is intentionally ignored — the tool always returns
        # the full OCR text captured at agent-creation time.
        return file_content

    # Bind the tool to the LLM so it can emit tool calls.
    llm = create_llm()
    tools = [check]
    llm_with_tools = llm.bind_tools(tools)

    # Name -> tool lookup. Loop variable renamed so it no longer shadows the
    # imported `tool` decorator.
    tool_map = {t.name: t for t in tools}

    def call_model(state: AgentState) -> dict:
        """Call the LLM with the accumulated message history."""
        response = llm_with_tools.invoke(state["messages"])
        return {"messages": [response]}

    def call_tools(state: AgentState) -> dict:
        """Execute every tool call found on the last (AI) message."""
        last_message = state["messages"][-1]
        tool_messages = []
        for tool_call in last_message.tool_calls:
            tool_name = tool_call["name"]
            if tool_name in tool_map:
                result = tool_map[tool_name].invoke(tool_call["args"])
                tool_messages.append(
                    ToolMessage(content=str(result), tool_call_id=tool_call["id"])
                )
            else:
                # The model asked for a tool we don't have: report it back
                # instead of crashing, so the loop can recover.
                tool_messages.append(
                    ToolMessage(
                        content=f"Tool {tool_name} not found",
                        tool_call_id=tool_call["id"],
                    )
                )
        return {"messages": tool_messages}

    def should_continue(state: AgentState) -> Literal["tools", "end"]:
        """Route to tool execution while the model keeps requesting tools."""
        last_message = state["messages"][-1]
        if hasattr(last_message, "tool_calls") and last_message.tool_calls:
            return "tools"
        return "end"

    # Build the agent <-> tools loop: agent decides, tools execute, repeat
    # until the model stops emitting tool calls.
    workflow = StateGraph(AgentState)
    workflow.add_node("agent", call_model)
    workflow.add_node("tools", call_tools)
    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {
            "tools": "tools",
            "end": END,
        },
    )
    workflow.add_edge("tools", "agent")
    return workflow.compile()
async def run_agent(query: str, code: str, file_content: str = "") -> dict:
    """
    Run the agent with a given query.

    Args:
        query: The user's question or request
        code: The service code used to look up the RULES / MIN_DOC entries
        file_content: OCR text content for the check tool

    Returns:
        dict with keys "response" (the agent's final text), "input_tokens"
        and "output_tokens" (summed across all messages reporting usage).

    Raises:
        KeyError: if ``code`` has no entry in RULES or MIN_DOC.
    """
    agent = create_agent(file_content)

    # System prompt. Fixed from the previous version: the closing tags were
    # written as "<\a..." where \a is the BEL escape (injecting a control
    # character and breaking the tag), a stray quadruple-quote leaked an extra
    # '"' into the prompt, and acute accents stood in for apostrophes. The
    # JSON response contract (keys and "Aprovado"/"Reprovado" values) is
    # unchanged.
    system_prompt = (
        """You are an AI assistant responsible for checking whether a person is Allowed or Denied access to a medical procedure based on the input data. There are a few always-accepted criteria; if any of them is met, even a single one, the request is accepted. These criteria are:
<auto-accept-criteria>
"""
        + RULES[code]
        + """
</auto-accept-criteria>
If those criteria aren't met, you can check the documents to see if the following information is present; if so, approve the procedure:
<additional-information>"""
        + MIN_DOC[code]
        + """
</additional-information>
If the additional information is not present, but any of the auto-accept criteria are met, allow the procedure.
If there aren't any auto-accept criteria present, check the documents for the additional information, and if they are all present, even if not in the exact type of document specified in them, allow the procedure.
Your capabilities:
- You can check the OCR of all the attached documents, all at once, if the json input is not enough to determine whether it should be approved, using the check tool.
For every document, check if the name of the person in the json is present, and in the output list every document and whether it belongs to the person in the request.
You MUST respond ONLY with a valid JSON object — no markdown, no text outside the JSON.
Response format:
{
"status": "Aprovado" or "Reprovado",
"criterio": "brief description of the met criterion or reason for denial",
"documentos": [
{"nome": "document name", "pertence_ao_paciente": true or false}
]
}"""
    )

    user_message = query
    if file_content:
        user_message += "\n\n<documentos_anexados>\n" + file_content + "\n</documentos_anexados>"

    initial_state = {
        "messages": [
            SystemMessage(content=system_prompt),
            HumanMessage(content=user_message),
        ]
    }

    # Run the agent, tracing the call in Langfuse.
    langfuse_handler = CallbackHandler()
    config = {"callbacks": [langfuse_handler]}
    final_state = await agent.ainvoke(initial_state, config=config)

    # Extract the final response text.
    final_message = final_state["messages"][-1]
    response = final_message.content if hasattr(final_message, "content") else str(final_message)

    # Sum token usage across every message that carries usage metadata.
    input_tokens = 0
    output_tokens = 0
    for msg in final_state["messages"]:
        usage = getattr(msg, "usage_metadata", None)
        if usage:
            input_tokens += usage.get("input_tokens", 0)
            output_tokens += usage.get("output_tokens", 0)

    # flush() blocks on network I/O; run it off the event loop.
    await asyncio.to_thread(langfuse.flush)
    return {"response": response, "input_tokens": input_tokens, "output_tokens": output_tokens}