Initial commit

This commit is contained in:
2026-05-14 15:29:03 -03:00
parent 82ac556ecc
commit 54bcf081f6
31 changed files with 3132 additions and 518 deletions

View File

@@ -0,0 +1,549 @@
"""
Additional Tool Suggestions for AWS Athena LangGraph Agent
This file contains suggested tools that can be added to your LangGraph agent
to enhance its capabilities with AWS Athena queries and data analysis.
Based on: BDAgent.py
"""
import boto3
import time
from langchain_core.tools import tool
from typing import List, Dict, Any, Optional
# ==============================================
# CONFIGURATION
# ==============================================
WORKGROUP = "iceberg-workgroup"
DATABASE = "dnx_warehouse"
# Initialize Athena client (reuse from your main file)
session = boto3.Session()
athena = session.client("athena", region_name="us-east-1")
# ==============================================
# HELPER FUNCTION
# ==============================================
def execute_athena_query(query: str, database: str = DATABASE, workgroup: str = WORKGROUP) -> Dict[str, Any]:
"""
Helper function to execute Athena queries and wait for results.
Args:
query: SQL query string
database: Database name
workgroup: Athena workgroup
Returns:
Dictionary with query results or error information
"""
try:
print(f"Executing Athena query...")
response = athena.start_query_execution(
QueryString=query,
QueryExecutionContext={"Database": database},
WorkGroup=workgroup
)
query_execution_id = response["QueryExecutionId"]
print(f"QueryExecutionId: {query_execution_id}")
# Wait for query completion
while True:
result = athena.get_query_execution(QueryExecutionId=query_execution_id)
state = result["QueryExecution"]["Status"]["State"]
if state in ["SUCCEEDED", "FAILED", "CANCELLED"]:
print(f"Query state: {state}")
break
print("Waiting for query execution...")
time.sleep(1)
if state == "SUCCEEDED":
output = athena.get_query_results(QueryExecutionId=query_execution_id)
return {
"status": "success",
"results": output["ResultSet"]["Rows"],
"query_execution_id": query_execution_id
}
else:
error_message = result["QueryExecution"]["Status"].get("StateChangeReason", "Unknown error")
return {
"status": "failed",
"error": error_message,
"query_execution_id": query_execution_id
}
except Exception as e:
return {
"status": "error",
"error": str(e)
}
# ==============================================
# SUGGESTED TOOLS
# ==============================================
@tool
def get_table_schema(table_name: str) -> str:
"""
Gets the schema (column names and types) of a specific table.
Args:
table_name: Name of the table to describe
Returns:
Schema information including column names and data types
"""
query = f"DESCRIBE {DATABASE}.{table_name};"
result = execute_athena_query(query)
if result["status"] == "success":
print(f"\n🔧 [TOOL CALLED] get_table_schema for {table_name}")
return result["results"]
else:
return f"Error getting schema: {result.get('error', 'Unknown error')}"
@tool
def list_available_tables() -> str:
"""
Lists all available tables in the database.
Useful when user wants to know what data is available.
Returns:
List of all tables in the database
"""
query = f"SHOW TABLES IN {DATABASE};"
result = execute_athena_query(query)
if result["status"] == "success":
print(f"\n🔧 [TOOL CALLED] list_available_tables")
return result["results"]
else:
return f"Error listing tables: {result.get('error', 'Unknown error')}"
@tool
def count_table_rows(table_name: str) -> str:
"""
Counts the total number of rows in a specific table.
Args:
table_name: Name of the table to count rows from
Returns:
Total number of rows in the table
"""
query = f"SELECT COUNT(*) as total_rows FROM {DATABASE}.{table_name};"
result = execute_athena_query(query)
if result["status"] == "success":
print(f"\n🔧 [TOOL CALLED] count_table_rows for {table_name}")
return result["results"]
else:
return f"Error counting rows: {result.get('error', 'Unknown error')}"
@tool
def get_column_statistics(table_name: str, column_name: str) -> str:
"""
Gets statistical information about a numeric column (min, max, avg, count).
Args:
table_name: Name of the table
column_name: Name of the numeric column to analyze
Returns:
Statistical summary of the column
"""
query = f"""
SELECT
MIN({column_name}) as min_value,
MAX({column_name}) as max_value,
AVG({column_name}) as avg_value,
COUNT({column_name}) as count_values,
COUNT(DISTINCT {column_name}) as distinct_values
FROM {DATABASE}.{table_name};
"""
result = execute_athena_query(query)
if result["status"] == "success":
print(f"\n🔧 [TOOL CALLED] get_column_statistics for {table_name}.{column_name}")
return result["results"]
else:
return f"Error getting statistics: {result.get('error', 'Unknown error')}"
@tool
def get_distinct_values(table_name: str, column_name: str, limit: int = 100) -> str:
"""
Gets distinct values from a specific column, useful for categorical data analysis.
Args:
table_name: Name of the table
column_name: Name of the column to get distinct values from
limit: Maximum number of distinct values to return (default: 100)
Returns:
List of distinct values in the column
"""
query = f"""
SELECT DISTINCT {column_name}, COUNT(*) as frequency
FROM {DATABASE}.{table_name}
GROUP BY {column_name}
ORDER BY frequency DESC
LIMIT {limit};
"""
result = execute_athena_query(query)
if result["status"] == "success":
print(f"\n🔧 [TOOL CALLED] get_distinct_values for {table_name}.{column_name}")
return result["results"]
else:
return f"Error getting distinct values: {result.get('error', 'Unknown error')}"
@tool
def filter_answers_by_condition(table_name: str, column_name: str, condition: str) -> str:
"""
Filters answers based on a specific condition.
Args:
table_name: Name of the table (e.g., 'survey_name_respostas')
column_name: Column to apply the filter on
condition: SQL condition (e.g., "> 5", "= 'Yes'", "IS NOT NULL")
Returns:
Filtered results matching the condition
"""
query = f"""
SELECT *
FROM {DATABASE}.{table_name}
WHERE {column_name} {condition}
LIMIT 1000;
"""
result = execute_athena_query(query)
if result["status"] == "success":
print(f"\n🔧 [TOOL CALLED] filter_answers_by_condition")
return result["results"]
else:
return f"Error filtering data: {result.get('error', 'Unknown error')}"
@tool
def aggregate_survey_responses(survey_name: str, column_name: str, aggregation: str = "COUNT") -> str:
"""
Performs aggregation on survey responses (COUNT, SUM, AVG, etc.).
Args:
survey_name: Name of the survey (will query survey_name_respostas table)
column_name: Column to aggregate
aggregation: Type of aggregation (COUNT, SUM, AVG, MIN, MAX)
Returns:
Aggregated result
"""
table_name = f"{survey_name}_respostas"
query = f"""
SELECT {aggregation}({column_name}) as result
FROM {DATABASE}.{table_name};
"""
result = execute_athena_query(query)
if result["status"] == "success":
print(f"\n🔧 [TOOL CALLED] aggregate_survey_responses - {aggregation} on {column_name}")
return result["results"]
else:
return f"Error performing aggregation: {result.get('error', 'Unknown error')}"
@tool
def group_by_analysis(table_name: str, group_column: str, agg_column: str, agg_function: str = "COUNT") -> str:
"""
Performs GROUP BY analysis to understand distribution of responses.
Args:
table_name: Name of the table
group_column: Column to group by
agg_column: Column to aggregate (use '*' for COUNT(*))
agg_function: Aggregation function (COUNT, SUM, AVG, MIN, MAX)
Returns:
Grouped analysis results
"""
if agg_column == '*':
agg_expr = f"{agg_function}(*)"
else:
agg_expr = f"{agg_function}({agg_column})"
query = f"""
SELECT
{group_column},
{agg_expr} as aggregated_value
FROM {DATABASE}.{table_name}
GROUP BY {group_column}
ORDER BY aggregated_value DESC
LIMIT 100;
"""
result = execute_athena_query(query)
if result["status"] == "success":
print(f"\n🔧 [TOOL CALLED] group_by_analysis")
return result["results"]
else:
return f"Error performing group by analysis: {result.get('error', 'Unknown error')}"
@tool
def calculate_percentage_distribution(table_name: str, column_name: str) -> str:
"""
Calculates percentage distribution of values in a column.
Useful for understanding response distributions.
Args:
table_name: Name of the table
column_name: Column to analyze
Returns:
Percentage distribution of values
"""
query = f"""
SELECT
{column_name},
COUNT(*) as count,
ROUND(COUNT(*) * 100.0 / SUM(COUNT(*)) OVER (), 2) as percentage
FROM {DATABASE}.{table_name}
GROUP BY {column_name}
ORDER BY count DESC;
"""
result = execute_athena_query(query)
if result["status"] == "success":
print(f"\n🔧 [TOOL CALLED] calculate_percentage_distribution for {column_name}")
return result["results"]
else:
return f"Error calculating distribution: {result.get('error', 'Unknown error')}"
@tool
def compare_columns(table_name: str, column1: str, column2: str) -> str:
"""
Compares two columns to find correlations or patterns in survey responses.
Args:
table_name: Name of the table
column1: First column to compare
column2: Second column to compare
Returns:
Cross-tabulation of the two columns
"""
query = f"""
SELECT
{column1},
{column2},
COUNT(*) as frequency
FROM {DATABASE}.{table_name}
GROUP BY {column1}, {column2}
ORDER BY frequency DESC
LIMIT 100;
"""
result = execute_athena_query(query)
if result["status"] == "success":
print(f"\n🔧 [TOOL CALLED] compare_columns: {column1} vs {column2}")
return result["results"]
else:
return f"Error comparing columns: {result.get('error', 'Unknown error')}"
@tool
def search_text_in_responses(table_name: str, column_name: str, search_term: str) -> str:
"""
Searches for specific text in text-based response columns.
Useful for open-ended survey questions.
Args:
table_name: Name of the table
column_name: Column to search in
search_term: Text to search for
Returns:
Matching responses
"""
query = f"""
SELECT {column_name}, COUNT(*) as occurrences
FROM {DATABASE}.{table_name}
WHERE LOWER({column_name}) LIKE LOWER('%{search_term}%')
GROUP BY {column_name}
LIMIT 100;
"""
result = execute_athena_query(query)
if result["status"] == "success":
print(f"\n🔧 [TOOL CALLED] search_text_in_responses for '{search_term}'")
return result["results"]
else:
return f"Error searching text: {result.get('error', 'Unknown error')}"
@tool
def get_null_count(table_name: str, column_name: str) -> str:
"""
Counts NULL or missing values in a specific column.
Useful for data quality analysis.
Args:
table_name: Name of the table
column_name: Column to check for NULL values
Returns:
Count of NULL values and percentage of total
"""
query = f"""
SELECT
COUNT(*) as total_rows,
COUNT({column_name}) as non_null_count,
COUNT(*) - COUNT({column_name}) as null_count,
ROUND((COUNT(*) - COUNT({column_name})) * 100.0 / COUNT(*), 2) as null_percentage
FROM {DATABASE}.{table_name};
"""
result = execute_athena_query(query)
if result["status"] == "success":
print(f"\n🔧 [TOOL CALLED] get_null_count for {table_name}.{column_name}")
return result["results"]
else:
return f"Error checking null values: {result.get('error', 'Unknown error')}"
@tool
def get_date_range_data(table_name: str, date_column: str, start_date: str, end_date: str) -> str:
"""
Retrieves data within a specific date range.
Args:
table_name: Name of the table
date_column: Name of the date column
start_date: Start date in format 'YYYY-MM-DD'
end_date: End date in format 'YYYY-MM-DD'
Returns:
Data within the specified date range
"""
query = f"""
SELECT *
FROM {DATABASE}.{table_name}
WHERE {date_column} BETWEEN DATE '{start_date}' AND DATE '{end_date}'
LIMIT 1000;
"""
result = execute_athena_query(query)
if result["status"] == "success":
print(f"\n🔧 [TOOL CALLED] get_date_range_data from {start_date} to {end_date}")
return result["results"]
else:
return f"Error getting date range data: {result.get('error', 'Unknown error')}"
@tool
def execute_custom_query(sql_query: str) -> str:
"""
Executes a custom SQL query. USE WITH CAUTION.
Only use this when other specific tools don't meet the requirement.
Args:
sql_query: Complete SQL query to execute
Returns:
Query results
"""
# Add basic validation
if any(keyword in sql_query.upper() for keyword in ["DROP", "DELETE", "INSERT", "UPDATE", "ALTER", "CREATE"]):
return "Error: Write operations are not allowed. Only SELECT queries are permitted."
result = execute_athena_query(sql_query)
if result["status"] == "success":
print(f"\n🔧 [TOOL CALLED] execute_custom_query")
return result["results"]
else:
return f"Error executing query: {result.get('error', 'Unknown error')}"
# ==============================================
# TOOL MAPPING FOR EASY INTEGRATION
# ==============================================
SUGGESTED_TOOLS = [
get_table_schema,
list_available_tables,
count_table_rows,
get_column_statistics,
get_distinct_values,
filter_answers_by_condition,
aggregate_survey_responses,
group_by_analysis,
calculate_percentage_distribution,
compare_columns,
search_text_in_responses,
get_null_count,
get_date_range_data,
execute_custom_query
]
TOOLS_MAP = {
"get_table_schema": get_table_schema,
"list_available_tables": list_available_tables,
"count_table_rows": count_table_rows,
"get_column_statistics": get_column_statistics,
"get_distinct_values": get_distinct_values,
"filter_answers_by_condition": filter_answers_by_condition,
"aggregate_survey_responses": aggregate_survey_responses,
"group_by_analysis": group_by_analysis,
"calculate_percentage_distribution": calculate_percentage_distribution,
"compare_columns": compare_columns,
"search_text_in_responses": search_text_in_responses,
"get_null_count": get_null_count,
"get_date_range_data": get_date_range_data,
"execute_custom_query": execute_custom_query
}
# ==============================================
# INTEGRATION EXAMPLE
# ==============================================
"""
To integrate these tools into your BDAgent.py:
1. Import the tools:
from app.utils.AthenaToolsSuggestions import SUGGESTED_TOOLS, TOOLS_MAP
2. Update the create_bedrock_llm function:
tools = [consult_answers] + SUGGESTED_TOOLS
llm_with_tools = llm.bind_tools(tools)
3. Update the call_tools function:
tools_map = {
"consult_answers": consult_answers,
**TOOLS_MAP # Add all suggested tools
}
4. Update the system prompt to include descriptions of new tools:
Your available tools:
- consult_answers: Get answers for a specific question by shortname
- get_table_schema: Get column names and types of a table
- list_available_tables: List all available tables
- count_table_rows: Count total rows in a table
- get_column_statistics: Get statistics (min, max, avg) for numeric columns
- get_distinct_values: Get unique values and frequencies from a column
- calculate_percentage_distribution: Calculate percentage distribution of values
- group_by_analysis: Perform GROUP BY analysis on data
- compare_columns: Compare two columns to find patterns
- And more...
"""