550 lines
17 KiB
Python
550 lines
17 KiB
Python
"""
Additional Tool Suggestions for AWS Athena LangGraph Agent

This file contains suggested tools that can be added to your LangGraph agent
to enhance its capabilities with AWS Athena queries and data analysis.

Based on: BDAgent.py
"""
|
|
|
|
import boto3
|
|
import time
|
|
from langchain_core.tools import tool
|
|
from typing import List, Dict, Any, Optional
|
|
|
|
|
|
# ==============================================
|
|
# CONFIGURATION
|
|
# ==============================================
|
|
# Athena workgroup and database that every tool in this module targets.
WORKGROUP = "iceberg-workgroup"
DATABASE = "dnx_warehouse"

# Shared boto3 session and Athena client (reuse the ones from your main file).
# Created once at import time so all tools share the same client.
session = boto3.Session()
athena = session.client(service_name="athena", region_name="us-east-1")
|
|
|
|
|
|
# ==============================================
|
|
# HELPER FUNCTION
|
|
# ==============================================
|
|
def execute_athena_query(
    query: str,
    database: str = DATABASE,
    workgroup: str = WORKGROUP,
    max_rows: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Helper function to execute Athena queries and wait for results.

    Submits the query, polls once per second until Athena reports a terminal
    state, and then collects the result rows from every result page.

    Args:
        query: SQL query string
        database: Database name used as the query execution context
        workgroup: Athena workgroup the query is submitted to
        max_rows: Optional cap on the number of entries returned in "results"
            (counted as returned by Athena, i.e. including any header row).
            None, the default, returns every row.

    Returns:
        Dictionary whose "status" key is one of:
            "success" - also holds "results" (list of result rows) and
                "query_execution_id"
            "failed"  - the query ended FAILED or CANCELLED; also holds "error"
                (the state change reason) and "query_execution_id"
            "error"   - an exception was raised while talking to Athena;
                "error" holds its message
    """
    terminal_states = ("SUCCEEDED", "FAILED", "CANCELLED")

    try:
        print("Executing Athena query...")
        response = athena.start_query_execution(
            QueryString=query,
            QueryExecutionContext={"Database": database},
            WorkGroup=workgroup,
        )

        query_execution_id = response["QueryExecutionId"]
        print(f"QueryExecutionId: {query_execution_id}")

        # Wait for query completion
        while True:
            result = athena.get_query_execution(QueryExecutionId=query_execution_id)
            status = result["QueryExecution"]["Status"]
            state = status["State"]

            if state in terminal_states:
                print(f"Query state: {state}")
                break

            print("Waiting for query execution...")
            time.sleep(1)

        if state != "SUCCEEDED":
            return {
                "status": "failed",
                "error": status.get("StateChangeReason", "Unknown error"),
                "query_execution_id": query_execution_id,
            }

        # get_query_results is paginated: a single call returns only the first
        # page, so reading just that page silently drops the remaining rows.
        # Follow NextToken until the result set is exhausted or the cap is hit.
        rows: List[Dict[str, Any]] = []
        page_args: Dict[str, Any] = {"QueryExecutionId": query_execution_id}
        while True:
            page = athena.get_query_results(**page_args)
            rows.extend(page["ResultSet"]["Rows"])

            if max_rows is not None and len(rows) >= max_rows:
                rows = rows[:max_rows]
                break

            next_token = page.get("NextToken")
            if not next_token:
                break
            page_args["NextToken"] = next_token

        return {
            "status": "success",
            "results": rows,
            "query_execution_id": query_execution_id,
        }

    except Exception as e:
        # Deliberate catch-all boundary: the tools built on this helper expect
        # an error dictionary rather than an exception.
        return {
            "status": "error",
            "error": str(e),
        }
|
|
|
|
|
|
# ==============================================
|
|
# SUGGESTED TOOLS
|
|
# ==============================================
|
|
|
|
@tool
def get_table_schema(table_name: str) -> str:
    """
    Gets the schema (column names and types) of a specific table.

    Args:
        table_name: Name of the table to describe

    Returns:
        Schema information including column names and data types
    """
    # DESCRIBE is issued against the module-level DATABASE.
    outcome = execute_athena_query(f"DESCRIBE {DATABASE}.{table_name};")

    # Failure path first: hand the caller a readable error string.
    if outcome["status"] != "success":
        return f"Error getting schema: {outcome.get('error', 'Unknown error')}"

    print(f"\n🔧 [TOOL CALLED] get_table_schema for {table_name}")
    return outcome["results"]
|
|
|
|
|
|
@tool
def list_available_tables() -> str:
    """
    Lists all available tables in the database.
    Useful when user wants to know what data is available.

    Returns:
        List of all tables in the database
    """
    outcome = execute_athena_query(f"SHOW TABLES IN {DATABASE};")

    # Failure path first: hand the caller a readable error string.
    if outcome["status"] != "success":
        return f"Error listing tables: {outcome.get('error', 'Unknown error')}"

    print(f"\n🔧 [TOOL CALLED] list_available_tables")
    return outcome["results"]
|
|
|
|
|
|
@tool
def count_table_rows(table_name: str) -> str:
    """
    Counts the total number of rows in a specific table.

    Args:
        table_name: Name of the table to count rows from

    Returns:
        Total number of rows in the table
    """
    sql = f"SELECT COUNT(*) as total_rows FROM {DATABASE}.{table_name};"
    outcome = execute_athena_query(sql)

    # Failure path first: hand the caller a readable error string.
    if outcome["status"] != "success":
        return f"Error counting rows: {outcome.get('error', 'Unknown error')}"

    print(f"\n🔧 [TOOL CALLED] count_table_rows for {table_name}")
    return outcome["results"]
|
|
|
|
|
|
@tool
def get_column_statistics(table_name: str, column_name: str) -> str:
    """
    Gets statistical information about a numeric column (min, max, avg, count).

    Args:
        table_name: Name of the table
        column_name: Name of the numeric column to analyze

    Returns:
        Statistical summary of the column
    """
    # One pass over the table computing every summary value at once.
    sql = f"""
        SELECT
            MIN({column_name}) as min_value,
            MAX({column_name}) as max_value,
            AVG({column_name}) as avg_value,
            COUNT({column_name}) as count_values,
            COUNT(DISTINCT {column_name}) as distinct_values
        FROM {DATABASE}.{table_name};
    """
    outcome = execute_athena_query(sql)

    # Failure path first: hand the caller a readable error string.
    if outcome["status"] != "success":
        return f"Error getting statistics: {outcome.get('error', 'Unknown error')}"

    print(f"\n🔧 [TOOL CALLED] get_column_statistics for {table_name}.{column_name}")
    return outcome["results"]
|
|
|
|
|
|
@tool
def get_distinct_values(table_name: str, column_name: str, limit: int = 100) -> str:
    """
    Gets distinct values from a specific column, useful for categorical data analysis.

    Args:
        table_name: Name of the table
        column_name: Name of the column to get distinct values from
        limit: Maximum number of distinct values to return (default: 100)

    Returns:
        List of distinct values in the column
    """
    # `limit` is interpolated into the SQL text, so force it to a plain
    # non-negative integer before building the query.
    try:
        row_limit = int(limit)
    except (TypeError, ValueError):
        return f"Error getting distinct values: limit must be an integer, got {limit!r}"
    if row_limit < 0:
        return f"Error getting distinct values: limit must not be negative, got {row_limit}"

    # GROUP BY already yields one row per value, so no DISTINCT is needed.
    query = f"""
        SELECT {column_name}, COUNT(*) as frequency
        FROM {DATABASE}.{table_name}
        GROUP BY {column_name}
        ORDER BY frequency DESC
        LIMIT {row_limit};
    """
    result = execute_athena_query(query)

    if result["status"] == "success":
        print(f"\n🔧 [TOOL CALLED] get_distinct_values for {table_name}.{column_name}")
        return result["results"]
    else:
        return f"Error getting distinct values: {result.get('error', 'Unknown error')}"
|
|
|
|
|
|
@tool
def filter_answers_by_condition(table_name: str, column_name: str, condition: str) -> str:
    """
    Filters answers based on a specific condition.

    Args:
        table_name: Name of the table (e.g., 'survey_name_respostas')
        column_name: Column to apply the filter on
        condition: SQL condition (e.g., "> 5", "= 'Yes'", "IS NOT NULL")

    Returns:
        Filtered results matching the condition
    """
    # NOTE(review): `condition` is a raw SQL fragment by design and is spliced
    # into the WHERE clause unmodified, so it can carry arbitrary SQL. Confirm
    # the Athena role used here is read-only.
    sql = f"""
        SELECT *
        FROM {DATABASE}.{table_name}
        WHERE {column_name} {condition}
        LIMIT 1000;
    """
    outcome = execute_athena_query(sql)

    # Failure path first: hand the caller a readable error string.
    if outcome["status"] != "success":
        return f"Error filtering data: {outcome.get('error', 'Unknown error')}"

    print(f"\n🔧 [TOOL CALLED] filter_answers_by_condition")
    return outcome["results"]
|
|
|
|
|
|
@tool
def aggregate_survey_responses(survey_name: str, column_name: str, aggregation: str = "COUNT") -> str:
    """
    Performs aggregation on survey responses (COUNT, SUM, AVG, etc.).

    Args:
        survey_name: Name of the survey (will query survey_name_respostas table)
        column_name: Column to aggregate
        aggregation: Type of aggregation (COUNT, SUM, AVG, MIN, MAX)

    Returns:
        Aggregated result
    """
    # `aggregation` is used as a SQL function name, so it must be a bare
    # identifier. Anything with spaces, parentheses, quotes or semicolons is
    # rejected instead of being spliced into the query.
    agg_name = aggregation.strip()
    if not agg_name.isidentifier():
        return f"Error performing aggregation: invalid aggregation function {aggregation!r}"

    table_name = f"{survey_name}_respostas"
    query = f"""
        SELECT {agg_name}({column_name}) as result
        FROM {DATABASE}.{table_name};
    """
    result = execute_athena_query(query)

    if result["status"] == "success":
        print(f"\n🔧 [TOOL CALLED] aggregate_survey_responses - {aggregation} on {column_name}")
        return result["results"]
    else:
        return f"Error performing aggregation: {result.get('error', 'Unknown error')}"
|
|
|
|
|
|
@tool
def group_by_analysis(table_name: str, group_column: str, agg_column: str, agg_function: str = "COUNT") -> str:
    """
    Performs GROUP BY analysis to understand distribution of responses.

    Args:
        table_name: Name of the table
        group_column: Column to group by
        agg_column: Column to aggregate (use '*' for COUNT(*))
        agg_function: Aggregation function (COUNT, SUM, AVG, MIN, MAX)

    Returns:
        Grouped analysis results
    """
    # `agg_function` is used as a SQL function name, so it must be a bare
    # identifier. Anything else is rejected instead of being spliced in.
    func_name = agg_function.strip()
    if not func_name.isidentifier():
        return f"Error performing group by analysis: invalid aggregation function {agg_function!r}"

    # '*' is only meaningful as COUNT(*); report that up front rather than
    # sending a query that cannot succeed.
    if agg_column == "*" and func_name.upper() != "COUNT":
        return "Error performing group by analysis: agg_column '*' is only valid with COUNT"

    # The same expression covers both '*' and a named column, so no branch is
    # needed to build it.
    agg_expr = f"{func_name}({agg_column})"

    query = f"""
        SELECT
            {group_column},
            {agg_expr} as aggregated_value
        FROM {DATABASE}.{table_name}
        GROUP BY {group_column}
        ORDER BY aggregated_value DESC
        LIMIT 100;
    """
    result = execute_athena_query(query)

    if result["status"] == "success":
        print(f"\n🔧 [TOOL CALLED] group_by_analysis")
        return result["results"]
    else:
        return f"Error performing group by analysis: {result.get('error', 'Unknown error')}"
|
|
|
|
|
|
@tool
def calculate_percentage_distribution(table_name: str, column_name: str) -> str:
    """
    Calculates percentage distribution of values in a column.
    Useful for understanding response distributions.

    Args:
        table_name: Name of the table
        column_name: Column to analyze

    Returns:
        Percentage distribution of values
    """
    # The window SUM over the grouped counts is the table total, so each
    # group's share is its count divided by that total.
    sql = f"""
        SELECT
            {column_name},
            COUNT(*) as count,
            ROUND(COUNT(*) * 100.0 / SUM(COUNT(*)) OVER (), 2) as percentage
        FROM {DATABASE}.{table_name}
        GROUP BY {column_name}
        ORDER BY count DESC;
    """
    outcome = execute_athena_query(sql)

    # Failure path first: hand the caller a readable error string.
    if outcome["status"] != "success":
        return f"Error calculating distribution: {outcome.get('error', 'Unknown error')}"

    print(f"\n🔧 [TOOL CALLED] calculate_percentage_distribution for {column_name}")
    return outcome["results"]
|
|
|
|
|
|
@tool
def compare_columns(table_name: str, column1: str, column2: str) -> str:
    """
    Compares two columns to find correlations or patterns in survey responses.

    Args:
        table_name: Name of the table
        column1: First column to compare
        column2: Second column to compare

    Returns:
        Cross-tabulation of the two columns
    """
    # Counts each (column1, column2) combination; the 100 most frequent
    # combinations are returned.
    sql = f"""
        SELECT
            {column1},
            {column2},
            COUNT(*) as frequency
        FROM {DATABASE}.{table_name}
        GROUP BY {column1}, {column2}
        ORDER BY frequency DESC
        LIMIT 100;
    """
    outcome = execute_athena_query(sql)

    # Failure path first: hand the caller a readable error string.
    if outcome["status"] != "success":
        return f"Error comparing columns: {outcome.get('error', 'Unknown error')}"

    print(f"\n🔧 [TOOL CALLED] compare_columns: {column1} vs {column2}")
    return outcome["results"]
|
|
|
|
|
|
@tool
def search_text_in_responses(table_name: str, column_name: str, search_term: str) -> str:
    """
    Searches for specific text in text-based response columns.
    Useful for open-ended survey questions.

    Args:
        table_name: Name of the table
        column_name: Column to search in
        search_term: Text to search for

    Returns:
        Matching responses
    """
    # The term is embedded in a SQL string literal. Double any single quote so
    # a term such as "don't" stays inside the literal instead of ending it,
    # which would break the query or let the term inject SQL.
    escaped_term = search_term.replace("'", "''")

    query = f"""
        SELECT {column_name}, COUNT(*) as occurrences
        FROM {DATABASE}.{table_name}
        WHERE LOWER({column_name}) LIKE LOWER('%{escaped_term}%')
        GROUP BY {column_name}
        LIMIT 100;
    """
    result = execute_athena_query(query)

    if result["status"] == "success":
        print(f"\n🔧 [TOOL CALLED] search_text_in_responses for '{search_term}'")
        return result["results"]
    else:
        return f"Error searching text: {result.get('error', 'Unknown error')}"
|
|
|
|
|
|
@tool
def get_null_count(table_name: str, column_name: str) -> str:
    """
    Counts NULL or missing values in a specific column.
    Useful for data quality analysis.

    Args:
        table_name: Name of the table
        column_name: Column to check for NULL values

    Returns:
        Count of NULL values and percentage of total
    """
    # NULLIF turns a zero row count into NULL so the percentage does not
    # divide by zero on an empty table; null_percentage is then NULL.
    query = f"""
        SELECT
            COUNT(*) as total_rows,
            COUNT({column_name}) as non_null_count,
            COUNT(*) - COUNT({column_name}) as null_count,
            ROUND((COUNT(*) - COUNT({column_name})) * 100.0 / NULLIF(COUNT(*), 0), 2) as null_percentage
        FROM {DATABASE}.{table_name};
    """
    result = execute_athena_query(query)

    if result["status"] == "success":
        print(f"\n🔧 [TOOL CALLED] get_null_count for {table_name}.{column_name}")
        return result["results"]
    else:
        return f"Error checking null values: {result.get('error', 'Unknown error')}"
|
|
|
|
|
|
@tool
def get_date_range_data(table_name: str, date_column: str, start_date: str, end_date: str) -> str:
    """
    Retrieves data within a specific date range.

    Args:
        table_name: Name of the table
        date_column: Name of the date column
        start_date: Start date in format 'YYYY-MM-DD'
        end_date: End date in format 'YYYY-MM-DD'

    Returns:
        Data within the specified date range
    """
    # Local import keeps this block self-contained.
    from datetime import date

    # Both dates end up inside SQL DATE literals. Parse them first and use the
    # re-serialized YYYY-MM-DD form, so malformed input gets a clear error and
    # cannot smuggle extra SQL into the query.
    try:
        first_day = date.fromisoformat(start_date).isoformat()
        last_day = date.fromisoformat(end_date).isoformat()
    except (TypeError, ValueError):
        return (
            "Error getting date range data: start_date and end_date must be "
            "valid dates in 'YYYY-MM-DD' format"
        )

    query = f"""
        SELECT *
        FROM {DATABASE}.{table_name}
        WHERE {date_column} BETWEEN DATE '{first_day}' AND DATE '{last_day}'
        LIMIT 1000;
    """
    result = execute_athena_query(query)

    if result["status"] == "success":
        print(f"\n🔧 [TOOL CALLED] get_date_range_data from {start_date} to {end_date}")
        return result["results"]
    else:
        return f"Error getting date range data: {result.get('error', 'Unknown error')}"
|
|
|
|
|
|
@tool
def execute_custom_query(sql_query: str) -> str:
    """
    Executes a custom SQL query. USE WITH CAUTION.
    Only use this when other specific tools don't meet the requirement.

    Args:
        sql_query: Complete SQL query to execute

    Returns:
        Query results
    """
    # Local import keeps this block self-contained.
    import re

    # Basic validation: reject statements containing a write keyword.
    # Keywords are matched as whole words, so a read-only query that mentions
    # identifiers such as created_at, updated_at or is_deleted is not rejected
    # (a plain substring test would refuse all of those).
    # NOTE(review): this is a keyword screen, not a SQL parser; a keyword
    # inside a string literal is still rejected. The Athena role should be
    # read-only as the real safeguard.
    write_keywords = r"\b(?:DROP|DELETE|INSERT|UPDATE|ALTER|CREATE|MERGE)\b"
    if re.search(write_keywords, sql_query, flags=re.IGNORECASE):
        return "Error: Write operations are not allowed. Only SELECT queries are permitted."

    result = execute_athena_query(sql_query)

    if result["status"] == "success":
        print(f"\n🔧 [TOOL CALLED] execute_custom_query")
        return result["results"]
    else:
        return f"Error executing query: {result.get('error', 'Unknown error')}"
|
|
|
|
|
|
# ==============================================
|
|
# TOOL MAPPING FOR EASY INTEGRATION
|
|
# ==============================================
|
|
|
|
# Every suggested tool, in the order it is defined above. Concatenate this
# list with your existing tools before binding them to the LLM.
SUGGESTED_TOOLS = [
    # Metadata / discovery
    get_table_schema,
    list_available_tables,
    count_table_rows,
    # Single-column profiling
    get_column_statistics,
    get_distinct_values,
    # Filtering and aggregation
    filter_answers_by_condition,
    aggregate_survey_responses,
    group_by_analysis,
    calculate_percentage_distribution,
    compare_columns,
    # Text search, data quality, date ranges
    search_text_in_responses,
    get_null_count,
    get_date_range_data,
    # Escape hatch
    execute_custom_query,
]

# Name -> tool lookup used when dispatching tool calls by name.
TOOLS_MAP = dict(
    get_table_schema=get_table_schema,
    list_available_tables=list_available_tables,
    count_table_rows=count_table_rows,
    get_column_statistics=get_column_statistics,
    get_distinct_values=get_distinct_values,
    filter_answers_by_condition=filter_answers_by_condition,
    aggregate_survey_responses=aggregate_survey_responses,
    group_by_analysis=group_by_analysis,
    calculate_percentage_distribution=calculate_percentage_distribution,
    compare_columns=compare_columns,
    search_text_in_responses=search_text_in_responses,
    get_null_count=get_null_count,
    get_date_range_data=get_date_range_data,
    execute_custom_query=execute_custom_query,
)
|
|
|
|
|
|
# ==============================================
|
|
# INTEGRATION EXAMPLE
|
|
# ==============================================
|
|
"""
To integrate these tools into your BDAgent.py:

1. Import the tools:
    from app.utils.AthenaToolsSuggestions import SUGGESTED_TOOLS, TOOLS_MAP

2. Update the create_bedrock_llm function:
    tools = [consult_answers] + SUGGESTED_TOOLS
    llm_with_tools = llm.bind_tools(tools)

3. Update the call_tools function:
    tools_map = {
        "consult_answers": consult_answers,
        **TOOLS_MAP  # Add all suggested tools
    }

4. Update the system prompt to include descriptions of new tools:
    Your available tools:
    - consult_answers: Get answers for a specific question by shortname
    - get_table_schema: Get column names and types of a table
    - list_available_tables: List all available tables
    - count_table_rows: Count total rows in a table
    - get_column_statistics: Get statistics (min, max, avg) for numeric columns
    - get_distinct_values: Get unique values and frequencies from a column
    - calculate_percentage_distribution: Calculate percentage distribution of values
    - group_by_analysis: Perform GROUP BY analysis on data
    - compare_columns: Compare two columns to find patterns
    - And more...
"""
|