323 lines
8.6 KiB
Python
323 lines
8.6 KiB
Python
#!/usr/bin/env python3
|
|
"""
Simple LangGraph Agent with Langfuse Integration

Demonstrates basic agent functionality with Langfuse observability.
"""
|
|
|
|
import json
|
|
import boto3
|
|
from botocore.exceptions import ClientError
|
|
from typing import TypedDict, Annotated
|
|
import operator
|
|
|
|
from langgraph.graph import StateGraph, END
|
|
from langchain_aws import ChatBedrock
|
|
from langchain_core.tools import tool
|
|
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage
|
|
|
|
from langfuse import Langfuse
|
|
from langfuse.langchain import CallbackHandler
|
|
|
|
|
|
# ==============================================
|
|
# SECRETS MANAGER
|
|
# ==============================================
|
|
|
|
def get_secret(
    secret_name: str = "assistente-db-secrets-manager",
    region_name: str = "us-east-1",
):
    """Fetch a JSON secret from AWS Secrets Manager and return it parsed.

    Args:
        secret_name: Name or ARN of the secret. Defaults to the secret this
            script was written against.
        region_name: AWS region that hosts the secret.

    Returns:
        The secret's ``SecretString`` decoded with ``json.loads``. Callers in
        this file index the result with string keys, i.e. they expect a JSON
        object.

    Raises:
        botocore.exceptions.ClientError: if ``GetSecretValue`` fails (missing
            secret, access denied, ...); propagated unchanged.
        ValueError: if the secret has no ``SecretString`` (a binary secret).
        json.JSONDecodeError: if the secret string is not valid JSON.
    """
    session = boto3.session.Session()
    client = session.client(service_name="secretsmanager", region_name=region_name)

    # ClientError is deliberately not caught: there is nothing useful to do
    # with it here, so it propagates to the caller with its traceback intact.
    response = client.get_secret_value(SecretId=secret_name)

    secret_string = response.get("SecretString")
    if secret_string is None:
        # Binary secrets are returned under "SecretBinary" instead.
        raise ValueError(
            f"Secret {secret_name!r} has no SecretString; binary secrets are not supported"
        )
    return json.loads(secret_string)
|
|
|
|
|
|
# ==============================================
|
|
# INITIALIZE LANGFUSE
|
|
# ==============================================
|
|
|
|
print("Initializing Langfuse...")
secrets = get_secret()

# Single definition of the Langfuse server URL used by this initialization.
# NOTE(review): this is plain HTTP to a raw IP address, so the API keys sent
# to the server presumably travel unencrypted; consider an HTTPS endpoint.
LANGFUSE_HOST = "http://98.92.98.83:3000"

langfuse = Langfuse(
    public_key=secrets['LANGFUSE-PUBLIC-KEY'],
    secret_key=secrets['LANGFUSE-SECRET-KEY'],
    host=LANGFUSE_HOST,
)

print("✓ Langfuse initialized successfully")
print(f" Host: {LANGFUSE_HOST}")
|
|
|
|
|
|
# ==============================================
|
|
# DEFINE TOOLS
|
|
# ==============================================
|
|
|
|
@tool
def add_numbers(a: int, b: int) -> int:
    """
    Add two numbers together.

    Args:
        a: First number
        b: Second number

    Returns:
        The sum of a and b
    """
    # @tool turns this docstring into the tool description the LLM sees and
    # the parameter names/annotations into its argument schema, so both are
    # intentionally left untouched.
    print(f"🔧 [TOOL] Adding {a} + {b}")
    total = a + b
    return total
|
|
|
|
|
|
@tool
def multiply_numbers(a: int, b: int) -> int:
    """
    Multiply two numbers together.

    Args:
        a: First number
        b: Second number

    Returns:
        The product of a and b
    """
    # @tool turns this docstring into the tool description the LLM sees and
    # the parameter names/annotations into its argument schema, so both are
    # intentionally left untouched.
    print(f"🔧 [TOOL] Multiplying {a} * {b}")
    product = a * b
    return product
|
|
|
|
|
|
# ==============================================
|
|
# AGENT STATE
|
|
# ==============================================
|
|
|
|
class AgentState(TypedDict):
    """Graph state shared by the "model" and "tools" nodes.

    ``messages`` carries ``operator.add`` as its reducer, so the message list
    a node returns is concatenated onto the existing history rather than
    replacing it. ``current_step`` has no reducer, so an update overwrites
    the previous value.
    """

    # Conversation history; node updates are appended via list concatenation.
    messages: Annotated[list, operator.add]
    # Label of the most recent step (main() seeds it with "init").
    current_step: str
|
|
|
|
|
|
# ==============================================
|
|
# AGENT NODES
|
|
# ==============================================
|
|
|
|
def create_bedrock_llm(
    inference_profile_arn: str,
    region: str = "us-east-1",
    *,
    max_tokens: int = 2048,
    temperature: float = 0.7,
):
    """Create a Bedrock-hosted Anthropic chat model with the math tools bound.

    Args:
        inference_profile_arn: Bedrock model ID or inference-profile ARN,
            passed to ChatBedrock as ``model_id``.
        region: AWS region for the Bedrock runtime.
        max_tokens: Forwarded as ``max_tokens`` in ``model_kwargs``
            (default 2048, as before).
        temperature: Forwarded as ``temperature`` in ``model_kwargs``
            (default 0.7, as before).

    Returns:
        The result of ``bind_tools`` with ``add_numbers`` and
        ``multiply_numbers``, so model responses may contain tool calls.
    """
    llm = ChatBedrock(
        model_id=inference_profile_arn,
        region_name=region,
        model_kwargs={
            "max_tokens": max_tokens,
            "temperature": temperature,
        },
        # Given explicitly: an inference-profile ARN does not carry the
        # provider prefix that ChatBedrock could otherwise infer it from.
        provider="anthropic",
    )

    # Keep this list in sync with the tools_map used by call_tools.
    return llm.bind_tools([add_numbers, multiply_numbers])
|
|
|
|
|
|
def call_model(state: AgentState, llm) -> dict:
    """Invoke the tool-bound LLM on the conversation so far, traced by Langfuse.

    Args:
        state: Current graph state; only ``state["messages"]`` is read.
        llm: Chat model exposing ``invoke`` (create_agent passes the result
            of ``create_bedrock_llm``).

    Returns:
        A partial state update: the model's reply, to be appended to
        ``messages`` by the ``operator.add`` reducer, and ``current_step``
        set to ``"model_called"``.
    """
    print("[MODEL] Calling Bedrock with Langfuse tracing...")

    # A new Langfuse callback handler is attached to every model call.
    langfuse_handler = CallbackHandler()
    config = {
        # NOTE(review): thread_id is only meaningful to a LangGraph
        # checkpointer; a direct llm.invoke likely ignores it.
        "configurable": {"thread_id": "simple_agent_demo"},
        "callbacks": [langfuse_handler],
    }

    response = llm.invoke(state["messages"], config=config)

    # LangGraph applies only the dict a node returns; assigning into `state`
    # would be discarded, so current_step must travel in the update itself.
    return {"messages": [response], "current_step": "model_called"}
|
|
|
|
|
|
def call_tools(state: AgentState) -> dict:
    """Execute the tool calls requested by the most recent message.

    Every entry in ``tool_calls`` receives exactly one ``ToolMessage`` with
    the matching ``tool_call_id``. An unknown tool name, or an exception
    raised while running a tool, is reported to the model as an error
    message instead of aborting the whole graph. This way the model sees the
    failure, and each tool call still gets a result.

    Args:
        state: Current graph state; only the last message is inspected.

    Returns:
        A partial state update. If tools were requested, it holds the tool
        results to append to ``messages`` and
        ``current_step="tools_executed"``. Otherwise it holds no new messages
        and ``current_step="no_tools"``.
    """
    print("[TOOLS] Checking for tool calls...")

    last_message = state["messages"][-1]
    tool_calls = getattr(last_message, "tool_calls", None)

    if not tool_calls:
        print("[TOOLS] No tool calls found")
        return {"messages": [], "current_step": "no_tools"}

    print(f"[TOOLS] Found {len(tool_calls)} tool call(s)")

    # Keep in sync with the tools bound in create_bedrock_llm.
    tools_map = {
        "add_numbers": add_numbers,
        "multiply_numbers": multiply_numbers,
    }

    tool_messages = []
    for tool_call in tool_calls:
        tool_name = tool_call["name"]
        tool_args = tool_call["args"]

        print(f"[TOOLS] Executing: {tool_name}({tool_args})")

        tool_func = tools_map.get(tool_name)
        if tool_func is None:
            content = (
                f"Error: unknown tool '{tool_name}'. "
                f"Available tools: {', '.join(tools_map)}"
            )
        else:
            try:
                content = str(tool_func.invoke(tool_args))
            except Exception as exc:  # report to the model instead of crashing the run
                print(f"[TOOLS] {tool_name} failed: {exc!r}")
                content = f"Error executing {tool_name}: {exc}"

        tool_messages.append(
            ToolMessage(content=content, tool_call_id=tool_call["id"])
        )

    return {"messages": tool_messages, "current_step": "tools_executed"}
|
|
|
|
|
|
def should_continue(state: AgentState) -> str:
    """Route the graph after a model turn.

    Returns ``"tools"`` when the newest message has a truthy ``tool_calls``
    attribute, and ``"end"`` otherwise (attribute missing, ``None`` or empty).
    """
    newest = state["messages"][-1]
    pending_calls = getattr(newest, "tool_calls", None)

    if not pending_calls:
        print("[ROUTER] No more tool calls, ending...")
        return "end"

    print("[ROUTER] Routing to tools...")
    return "tools"
|
|
|
|
|
|
# ==============================================
|
|
# CREATE AGENT
|
|
# ==============================================
|
|
|
|
def create_agent(inference_profile_arn: str, region: str = "us-east-1"):
    """Build and compile the model <-> tools LangGraph loop.

    The graph enters at "model". After each model turn, should_continue
    routes either to "tools" or to END, and "tools" always hands control
    back to "model".

    Args:
        inference_profile_arn: Bedrock model ID / inference-profile ARN.
        region: AWS region for Bedrock.

    Returns:
        The compiled graph, ready for ``invoke``.
    """
    bound_llm = create_bedrock_llm(inference_profile_arn, region)

    def model_node(state):
        # Supplies the tool-bound LLM to call_model. Deliberately left
        # unannotated, exactly like the lambda it stands in for.
        return call_model(state, bound_llm)

    graph = StateGraph(AgentState)
    graph.add_node("model", model_node)
    graph.add_node("tools", call_tools)

    graph.set_entry_point("model")
    graph.add_conditional_edges(
        "model",
        should_continue,
        {"tools": "tools", "end": END},
    )
    graph.add_edge("tools", "model")

    return graph.compile()
|
|
|
|
|
|
# ==============================================
|
|
# MAIN EXECUTION
|
|
# ==============================================
|
|
|
|
def _print_transcript(messages) -> None:
    """Print a numbered, human-readable line for each conversation message."""
    for i, msg in enumerate(messages, 1):
        if isinstance(msg, SystemMessage):
            print(f"{i}. System: [System prompt configured]")
        elif isinstance(msg, HumanMessage):
            print(f"{i}. User: {msg.content}")
        elif isinstance(msg, AIMessage):
            # Turns that requested tools are summarized instead of printed.
            if getattr(msg, "tool_calls", None):
                print(f"{i}. AI: [Calling tools...]")
            else:
                print(f"{i}. AI: {msg.content}")
        elif isinstance(msg, ToolMessage):
            print(f"{i}. Tool Result: {msg.content}")


def main():
    """Run one demo query through the agent and print the conversation.

    Builds the graph, sends a fixed arithmetic question, and prints every
    message in the final state. Buffered Langfuse data is flushed in a
    ``finally`` block, so traces of a failed run are still sent before the
    exception propagates.
    """
    print("\n" + "=" * 60)
    print("Simple LangGraph Agent with Langfuse")
    print("=" * 60)

    # Configuration
    INFERENCE_PROFILE_ARN = "arn:aws:bedrock:us-east-1:305427701314:application-inference-profile/b3umwd5jpd0u"
    REGION = "us-east-1"

    # System prompt
    SYSTEM_PROMPT = """You are a helpful math assistant with access to calculation tools.

Your available tools:
- add_numbers: Add two numbers
- multiply_numbers: Multiply two numbers

When a user asks you to perform calculations:
1. Break down the problem into steps
2. Use the appropriate tools
3. Show your reasoning clearly
4. Provide the final answer

Always use the tools rather than calculating in your head."""

    print(f"\nUsing inference profile: {INFERENCE_PROFILE_ARN}")
    print(f"Region: {REGION}")
    print("=" * 60)

    agent = create_agent(INFERENCE_PROFILE_ARN, REGION)

    user_query = "What is (10 + 5) * 3?"
    initial_state = {
        "messages": [
            SystemMessage(content=SYSTEM_PROMPT),
            HumanMessage(content=user_query),
        ],
        "current_step": "init",
    }

    print(f"\nUser Query: {user_query}\n")
    print("-" * 60)

    try:
        final_state = agent.invoke(initial_state)

        print("-" * 60)
        print("\n[FINAL RESULT]\n")
        _print_transcript(final_state["messages"])

        print("\n" + "=" * 60)
        print("Agent completed successfully!")
    finally:
        # Flush even when the run fails: the traces of a failing run are the
        # ones most worth inspecting.
        print("\nFlushing data to Langfuse...")
        langfuse.flush()
        print("✓ Data sent to Langfuse")

    print("\nView your traces at: http://98.92.98.83:3000")
    print("=" * 60)


if __name__ == "__main__":
    main()
|