Feat: Adds base project
back/app/backend/agent_bedrock.py (new file, 120 lines)
@@ -0,0 +1,120 @@
import operator
from typing import TypedDict, Annotated

from langchain_aws import ChatBedrockConverse
from langchain_core.messages import ToolMessage
from langgraph.graph import StateGraph, END

from .config import REGION, AWS_ACCOUNT


class AgentState(TypedDict):
    # messages is append-only: operator.add is the LangGraph reducer that
    # concatenates each node's returned messages onto the running list.
    messages: Annotated[list, operator.add]
    current_step: str
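
# For example, if the model node returns {"messages": [ai_msg]} and the tools
# node then returns {"messages": [tool_msg]}, the operator.add reducer leaves
# state["messages"] == [..., ai_msg, tool_msg].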


def create_bedrock_llm(model_id: str, region: str = REGION, tools: list = None):
    """
    Create a ChatBedrockConverse instance from a base model ID.

    Args:
        model_id: Bedrock model ID (e.g., anthropic.claude-haiku-4-5-20251001-v1:0)
        region: AWS region (default: REGION env var)
        tools: List of LangChain tools to bind to the model

    Returns:
        ChatBedrockConverse runnable with the given tools bound
    """
    # Reference mapping from base model IDs to this account's cross-region
    # inference-profile ARNs. NOTE: currently unused; the call below builds
    # the inference profile ID from PREFIX instead.
    MODEL_ARNS = {
        "anthropic.claude-haiku-4-5-20251001-v1:0": f"arn:aws:bedrock:{REGION}:{AWS_ACCOUNT}:inference-profile/us.anthropic.claude-haiku-4-5-20251001-v1:0",
        "anthropic.claude-sonnet-4-5-20250929-v1:0": f"arn:aws:bedrock:{REGION}:{AWS_ACCOUNT}:inference-profile/global.anthropic.claude-sonnet-4-5-20250929-v1:0",
        "meta.llama4-maverick-17b-instruct-v1:0": f"arn:aws:bedrock:{REGION}:{AWS_ACCOUNT}:inference-profile/us.meta.llama4-maverick-17b-instruct-v1:0",
        "meta.llama4-scout-17b-instruct-v1:0": f"arn:aws:bedrock:{REGION}:{AWS_ACCOUNT}:inference-profile/us.meta.llama4-scout-17b-instruct-v1:0",
        "amazon.nova-lite-v1:0": f"arn:aws:bedrock:{REGION}:{AWS_ACCOUNT}:inference-profile/us.amazon.nova-lite-v1:0",
        "amazon.nova-pro-v1:0": f"arn:aws:bedrock:{REGION}:{AWS_ACCOUNT}:inference-profile/us.amazon.nova-pro-v1:0",
        "amazon.nova-2-lite-v1:0": f"arn:aws:bedrock:{REGION}:{AWS_ACCOUNT}:inference-profile/global.amazon.nova-2-lite-v1:0",
    }
    # Explicit provider hint for ChatBedrockConverse: the profile-prefixed
    # model ID starts with a region scope rather than the provider name.
    PROVIDER = {
        "anthropic.claude-haiku-4-5-20251001-v1:0": "anthropic",
        "anthropic.claude-sonnet-4-5-20250929-v1:0": "anthropic",
        "meta.llama4-maverick-17b-instruct-v1:0": "meta",
        "meta.llama4-scout-17b-instruct-v1:0": "meta",
        "amazon.nova-lite-v1:0": "amazon",
        "amazon.nova-pro-v1:0": "amazon",
        "amazon.nova-2-lite-v1:0": "amazon",
    }
    # Region scope of each model's inference profile ("us" or "global").
    PREFIX = {
        "anthropic.claude-haiku-4-5-20251001-v1:0": "us",
        "anthropic.claude-sonnet-4-5-20250929-v1:0": "global",
        "meta.llama4-maverick-17b-instruct-v1:0": "us",
        "meta.llama4-scout-17b-instruct-v1:0": "us",
        "amazon.nova-lite-v1:0": "us",
        "amazon.nova-pro-v1:0": "us",
        "amazon.nova-2-lite-v1:0": "global",
    }
    # Address the model through its cross-region inference profile: the
    # profile ID is the region-scope prefix plus the base model ID.
    llm = ChatBedrockConverse(
        model_id=PREFIX[model_id] + "." + model_id,
        region_name=region,
        provider=PROVIDER[model_id],
        max_tokens=2048,
        temperature=0.7,
    )
    return llm.bind_tools(tools or [])
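
# For instance, create_bedrock_llm("amazon.nova-pro-v1:0") targets the
# "us.amazon.nova-pro-v1:0" inference profile in REGION.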


def call_model(state: AgentState, llm) -> AgentState:
    """Call the LLM with tools."""
    response = llm.invoke(state["messages"])
    # LangGraph applies the returned dict as the state update; mutating
    # `state` in place is discarded, so current_step is returned here.
    return {"messages": [response], "current_step": "model_called"}


def call_tools(state: AgentState, tools_map: dict) -> AgentState:
    """Execute any tool calls from the LLM response."""
    last_message = state["messages"][-1]

    if hasattr(last_message, "tool_calls") and last_message.tool_calls:
        tool_messages = []
        # Each tool_call is a dict with "name", "args", and "id" keys.
        for tool_call in last_message.tool_calls:
            result = tools_map[tool_call["name"]].invoke(tool_call["args"])
            tool_messages.append(ToolMessage(content=str(result), tool_call_id=tool_call["id"]))

        return {"messages": tool_messages, "current_step": "tools_executed"}
    else:
        return {"messages": [], "current_step": "no_tools"}


def should_continue(state: AgentState) -> str:
    """Determine if we should continue to tools or end."""
    last_message = state["messages"][-1]
    if hasattr(last_message, "tool_calls") and last_message.tool_calls:
        return "tools"
    return "end"


def create_agent(model_id: str, region: str = REGION, tools: list = None):
    """
    Create a LangGraph agent that calls Bedrock through a cross-region
    inference profile, with optional tools.

    Args:
        model_id: Base Bedrock model ID, resolved to an inference profile
            by create_bedrock_llm
        region: AWS region
        tools: List of LangChain tools to bind to the model

    Returns:
        Compiled LangGraph workflow
    """
    tools = tools or []
    llm = create_bedrock_llm(model_id, region, tools)
    tools_map = {t.name: t for t in tools}

    # model -> tools -> model ... : loop until the LLM stops requesting tools.
    workflow = StateGraph(AgentState)
    workflow.add_node("model", lambda state: call_model(state, llm))
    workflow.add_node("tools", lambda state: call_tools(state, tools_map))
    workflow.set_entry_point("model")
    workflow.add_conditional_edges("model", should_continue, {"tools": "tools", "end": END})
    workflow.add_edge("tools", "model")

    return workflow.compile()
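
# Usage sketch (commented out; illustrative only). It assumes a hypothetical
# @tool-decorated `get_weather` helper, which is not part of this module:
#
#   from langchain_core.messages import HumanMessage
#   from langchain_core.tools import tool
#
#   @tool
#   def get_weather(city: str) -> str:
#       """Return a canned weather report for a city."""
#       return f"Sunny in {city}"
#
#   agent = create_agent("amazon.nova-lite-v1:0", tools=[get_weather])
#   state = {"messages": [HumanMessage("What's the weather in Paris?")], "current_step": ""}
#   print(agent.invoke(state)["messages"][-1].content)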