Initial commit

scripts/AthenaToolsSuggestions.py (new file, 549 lines)
@@ -0,0 +1,549 @@
"""
|
||||
Additional Tool Suggestions for AWS Athena LangGraph Agent
|
||||
|
||||
This file contains suggested tools that can be added to your LangGraph agent
|
||||
to enhance its capabilities with AWS Athena queries and data analysis.
|
||||
|
||||
Based on: BDAgent.py
|
||||
"""
|
||||
|
||||
import boto3
|
||||
import time
|
||||
from langchain_core.tools import tool
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
|
||||
# ==============================================
|
||||
# CONFIGURATION
|
||||
# ==============================================
|
||||
WORKGROUP = "iceberg-workgroup"
|
||||
DATABASE = "dnx_warehouse"
|
||||
|
||||
# Initialize Athena client (reuse from your main file)
|
||||
session = boto3.Session()
|
||||
athena = session.client("athena", region_name="us-east-1")
|
||||
|
||||
|
||||
# ==============================================
|
||||
# HELPER FUNCTION
|
||||
# ==============================================
|
||||
def execute_athena_query(query: str, database: str = DATABASE, workgroup: str = WORKGROUP) -> Dict[str, Any]:
|
||||
"""
|
||||
Helper function to execute Athena queries and wait for results.
|
||||
|
||||
Args:
|
||||
query: SQL query string
|
||||
database: Database name
|
||||
workgroup: Athena workgroup
|
||||
|
||||
Returns:
|
||||
Dictionary with query results or error information
|
||||
"""
|
||||
try:
|
||||
print(f"Executing Athena query...")
|
||||
response = athena.start_query_execution(
|
||||
QueryString=query,
|
||||
QueryExecutionContext={"Database": database},
|
||||
WorkGroup=workgroup
|
||||
)
|
||||
|
||||
query_execution_id = response["QueryExecutionId"]
|
||||
print(f"QueryExecutionId: {query_execution_id}")
|
||||
|
||||
# Wait for query completion
|
||||
while True:
|
||||
result = athena.get_query_execution(QueryExecutionId=query_execution_id)
|
||||
state = result["QueryExecution"]["Status"]["State"]
|
||||
|
||||
if state in ["SUCCEEDED", "FAILED", "CANCELLED"]:
|
||||
print(f"Query state: {state}")
|
||||
break
|
||||
|
||||
print("Waiting for query execution...")
|
||||
time.sleep(1)
|
||||
|
||||
if state == "SUCCEEDED":
|
||||
output = athena.get_query_results(QueryExecutionId=query_execution_id)
|
||||
return {
|
||||
"status": "success",
|
||||
"results": output["ResultSet"]["Rows"],
|
||||
"query_execution_id": query_execution_id
|
||||
}
|
||||
else:
|
||||
error_message = result["QueryExecution"]["Status"].get("StateChangeReason", "Unknown error")
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": error_message,
|
||||
"query_execution_id": query_execution_id
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "error",
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
# ==============================================
|
||||
# SUGGESTED TOOLS
|
||||
# ==============================================
|
||||
|
||||
@tool
|
||||
def get_table_schema(table_name: str) -> str:
|
||||
"""
|
||||
Gets the schema (column names and types) of a specific table.
|
||||
|
||||
Args:
|
||||
table_name: Name of the table to describe
|
||||
|
||||
Returns:
|
||||
Schema information including column names and data types
|
||||
"""
|
||||
query = f"DESCRIBE {DATABASE}.{table_name};"
|
||||
result = execute_athena_query(query)
|
||||
|
||||
if result["status"] == "success":
|
||||
print(f"\n🔧 [TOOL CALLED] get_table_schema for {table_name}")
|
||||
return result["results"]
|
||||
else:
|
||||
return f"Error getting schema: {result.get('error', 'Unknown error')}"
|
||||
|
||||
|
||||
@tool
|
||||
def list_available_tables() -> str:
|
||||
"""
|
||||
Lists all available tables in the database.
|
||||
Useful when user wants to know what data is available.
|
||||
|
||||
Returns:
|
||||
List of all tables in the database
|
||||
"""
|
||||
query = f"SHOW TABLES IN {DATABASE};"
|
||||
result = execute_athena_query(query)
|
||||
|
||||
if result["status"] == "success":
|
||||
print(f"\n🔧 [TOOL CALLED] list_available_tables")
|
||||
return result["results"]
|
||||
else:
|
||||
return f"Error listing tables: {result.get('error', 'Unknown error')}"
|
||||
|
||||
|
||||
@tool
|
||||
def count_table_rows(table_name: str) -> str:
|
||||
"""
|
||||
Counts the total number of rows in a specific table.
|
||||
|
||||
Args:
|
||||
table_name: Name of the table to count rows from
|
||||
|
||||
Returns:
|
||||
Total number of rows in the table
|
||||
"""
|
||||
query = f"SELECT COUNT(*) as total_rows FROM {DATABASE}.{table_name};"
|
||||
result = execute_athena_query(query)
|
||||
|
||||
if result["status"] == "success":
|
||||
print(f"\n🔧 [TOOL CALLED] count_table_rows for {table_name}")
|
||||
return result["results"]
|
||||
else:
|
||||
return f"Error counting rows: {result.get('error', 'Unknown error')}"
|
||||
|
||||
|
||||
@tool
|
||||
def get_column_statistics(table_name: str, column_name: str) -> str:
|
||||
"""
|
||||
Gets statistical information about a numeric column (min, max, avg, count).
|
||||
|
||||
Args:
|
||||
table_name: Name of the table
|
||||
column_name: Name of the numeric column to analyze
|
||||
|
||||
Returns:
|
||||
Statistical summary of the column
|
||||
"""
|
||||
query = f"""
|
||||
SELECT
|
||||
MIN({column_name}) as min_value,
|
||||
MAX({column_name}) as max_value,
|
||||
AVG({column_name}) as avg_value,
|
||||
COUNT({column_name}) as count_values,
|
||||
COUNT(DISTINCT {column_name}) as distinct_values
|
||||
FROM {DATABASE}.{table_name};
|
||||
"""
|
||||
result = execute_athena_query(query)
|
||||
|
||||
if result["status"] == "success":
|
||||
print(f"\n🔧 [TOOL CALLED] get_column_statistics for {table_name}.{column_name}")
|
||||
return result["results"]
|
||||
else:
|
||||
return f"Error getting statistics: {result.get('error', 'Unknown error')}"
|
||||
|
||||
|
||||
@tool
|
||||
def get_distinct_values(table_name: str, column_name: str, limit: int = 100) -> str:
|
||||
"""
|
||||
Gets distinct values from a specific column, useful for categorical data analysis.
|
||||
|
||||
Args:
|
||||
table_name: Name of the table
|
||||
column_name: Name of the column to get distinct values from
|
||||
limit: Maximum number of distinct values to return (default: 100)
|
||||
|
||||
Returns:
|
||||
List of distinct values in the column
|
||||
"""
|
||||
query = f"""
|
||||
SELECT DISTINCT {column_name}, COUNT(*) as frequency
|
||||
FROM {DATABASE}.{table_name}
|
||||
GROUP BY {column_name}
|
||||
ORDER BY frequency DESC
|
||||
LIMIT {limit};
|
||||
"""
|
||||
result = execute_athena_query(query)
|
||||
|
||||
if result["status"] == "success":
|
||||
print(f"\n🔧 [TOOL CALLED] get_distinct_values for {table_name}.{column_name}")
|
||||
return result["results"]
|
||||
else:
|
||||
return f"Error getting distinct values: {result.get('error', 'Unknown error')}"
|
||||
|
||||
|
||||
@tool
|
||||
def filter_answers_by_condition(table_name: str, column_name: str, condition: str) -> str:
|
||||
"""
|
||||
Filters answers based on a specific condition.
|
||||
|
||||
Args:
|
||||
table_name: Name of the table (e.g., 'survey_name_respostas')
|
||||
column_name: Column to apply the filter on
|
||||
condition: SQL condition (e.g., "> 5", "= 'Yes'", "IS NOT NULL")
|
||||
|
||||
Returns:
|
||||
Filtered results matching the condition
|
||||
"""
|
||||
query = f"""
|
||||
SELECT *
|
||||
FROM {DATABASE}.{table_name}
|
||||
WHERE {column_name} {condition}
|
||||
LIMIT 1000;
|
||||
"""
|
||||
result = execute_athena_query(query)
|
||||
|
||||
if result["status"] == "success":
|
||||
print(f"\n🔧 [TOOL CALLED] filter_answers_by_condition")
|
||||
return result["results"]
|
||||
else:
|
||||
return f"Error filtering data: {result.get('error', 'Unknown error')}"
|
||||
|
||||
|
||||
@tool
|
||||
def aggregate_survey_responses(survey_name: str, column_name: str, aggregation: str = "COUNT") -> str:
|
||||
"""
|
||||
Performs aggregation on survey responses (COUNT, SUM, AVG, etc.).
|
||||
|
||||
Args:
|
||||
survey_name: Name of the survey (will query survey_name_respostas table)
|
||||
column_name: Column to aggregate
|
||||
aggregation: Type of aggregation (COUNT, SUM, AVG, MIN, MAX)
|
||||
|
||||
Returns:
|
||||
Aggregated result
|
||||
"""
|
||||
table_name = f"{survey_name}_respostas"
|
||||
query = f"""
|
||||
SELECT {aggregation}({column_name}) as result
|
||||
FROM {DATABASE}.{table_name};
|
||||
"""
|
||||
result = execute_athena_query(query)
|
||||
|
||||
if result["status"] == "success":
|
||||
print(f"\n🔧 [TOOL CALLED] aggregate_survey_responses - {aggregation} on {column_name}")
|
||||
return result["results"]
|
||||
else:
|
||||
return f"Error performing aggregation: {result.get('error', 'Unknown error')}"
|
||||
|
||||
|
||||
@tool
|
||||
def group_by_analysis(table_name: str, group_column: str, agg_column: str, agg_function: str = "COUNT") -> str:
|
||||
"""
|
||||
Performs GROUP BY analysis to understand distribution of responses.
|
||||
|
||||
Args:
|
||||
table_name: Name of the table
|
||||
group_column: Column to group by
|
||||
agg_column: Column to aggregate (use '*' for COUNT(*))
|
||||
agg_function: Aggregation function (COUNT, SUM, AVG, MIN, MAX)
|
||||
|
||||
Returns:
|
||||
Grouped analysis results
|
||||
"""
|
||||
if agg_column == '*':
|
||||
agg_expr = f"{agg_function}(*)"
|
||||
else:
|
||||
agg_expr = f"{agg_function}({agg_column})"
|
||||
|
||||
query = f"""
|
||||
SELECT
|
||||
{group_column},
|
||||
{agg_expr} as aggregated_value
|
||||
FROM {DATABASE}.{table_name}
|
||||
GROUP BY {group_column}
|
||||
ORDER BY aggregated_value DESC
|
||||
LIMIT 100;
|
||||
"""
|
||||
result = execute_athena_query(query)
|
||||
|
||||
if result["status"] == "success":
|
||||
print(f"\n🔧 [TOOL CALLED] group_by_analysis")
|
||||
return result["results"]
|
||||
else:
|
||||
return f"Error performing group by analysis: {result.get('error', 'Unknown error')}"
|
||||
|
||||
|
||||
@tool
|
||||
def calculate_percentage_distribution(table_name: str, column_name: str) -> str:
|
||||
"""
|
||||
Calculates percentage distribution of values in a column.
|
||||
Useful for understanding response distributions.
|
||||
|
||||
Args:
|
||||
table_name: Name of the table
|
||||
column_name: Column to analyze
|
||||
|
||||
Returns:
|
||||
Percentage distribution of values
|
||||
"""
|
||||
query = f"""
|
||||
SELECT
|
||||
{column_name},
|
||||
COUNT(*) as count,
|
||||
ROUND(COUNT(*) * 100.0 / SUM(COUNT(*)) OVER (), 2) as percentage
|
||||
FROM {DATABASE}.{table_name}
|
||||
GROUP BY {column_name}
|
||||
ORDER BY count DESC;
|
||||
"""
|
||||
result = execute_athena_query(query)
|
||||
|
||||
if result["status"] == "success":
|
||||
print(f"\n🔧 [TOOL CALLED] calculate_percentage_distribution for {column_name}")
|
||||
return result["results"]
|
||||
else:
|
||||
return f"Error calculating distribution: {result.get('error', 'Unknown error')}"
|
||||
|
||||
|
||||
@tool
|
||||
def compare_columns(table_name: str, column1: str, column2: str) -> str:
|
||||
"""
|
||||
Compares two columns to find correlations or patterns in survey responses.
|
||||
|
||||
Args:
|
||||
table_name: Name of the table
|
||||
column1: First column to compare
|
||||
column2: Second column to compare
|
||||
|
||||
Returns:
|
||||
Cross-tabulation of the two columns
|
||||
"""
|
||||
query = f"""
|
||||
SELECT
|
||||
{column1},
|
||||
{column2},
|
||||
COUNT(*) as frequency
|
||||
FROM {DATABASE}.{table_name}
|
||||
GROUP BY {column1}, {column2}
|
||||
ORDER BY frequency DESC
|
||||
LIMIT 100;
|
||||
"""
|
||||
result = execute_athena_query(query)
|
||||
|
||||
if result["status"] == "success":
|
||||
print(f"\n🔧 [TOOL CALLED] compare_columns: {column1} vs {column2}")
|
||||
return result["results"]
|
||||
else:
|
||||
return f"Error comparing columns: {result.get('error', 'Unknown error')}"
|
||||
|
||||
|
||||
@tool
|
||||
def search_text_in_responses(table_name: str, column_name: str, search_term: str) -> str:
|
||||
"""
|
||||
Searches for specific text in text-based response columns.
|
||||
Useful for open-ended survey questions.
|
||||
|
||||
Args:
|
||||
table_name: Name of the table
|
||||
column_name: Column to search in
|
||||
search_term: Text to search for
|
||||
|
||||
Returns:
|
||||
Matching responses
|
||||
"""
|
||||
query = f"""
|
||||
SELECT {column_name}, COUNT(*) as occurrences
|
||||
FROM {DATABASE}.{table_name}
|
||||
WHERE LOWER({column_name}) LIKE LOWER('%{search_term}%')
|
||||
GROUP BY {column_name}
|
||||
LIMIT 100;
|
||||
"""
|
||||
result = execute_athena_query(query)
|
||||
|
||||
if result["status"] == "success":
|
||||
print(f"\n🔧 [TOOL CALLED] search_text_in_responses for '{search_term}'")
|
||||
return result["results"]
|
||||
else:
|
||||
return f"Error searching text: {result.get('error', 'Unknown error')}"
|
||||
|
||||
|
||||
@tool
|
||||
def get_null_count(table_name: str, column_name: str) -> str:
|
||||
"""
|
||||
Counts NULL or missing values in a specific column.
|
||||
Useful for data quality analysis.
|
||||
|
||||
Args:
|
||||
table_name: Name of the table
|
||||
column_name: Column to check for NULL values
|
||||
|
||||
Returns:
|
||||
Count of NULL values and percentage of total
|
||||
"""
|
||||
query = f"""
|
||||
SELECT
|
||||
COUNT(*) as total_rows,
|
||||
COUNT({column_name}) as non_null_count,
|
||||
COUNT(*) - COUNT({column_name}) as null_count,
|
||||
ROUND((COUNT(*) - COUNT({column_name})) * 100.0 / COUNT(*), 2) as null_percentage
|
||||
FROM {DATABASE}.{table_name};
|
||||
"""
|
||||
result = execute_athena_query(query)
|
||||
|
||||
if result["status"] == "success":
|
||||
print(f"\n🔧 [TOOL CALLED] get_null_count for {table_name}.{column_name}")
|
||||
return result["results"]
|
||||
else:
|
||||
return f"Error checking null values: {result.get('error', 'Unknown error')}"
|
||||
|
||||
|
||||
@tool
|
||||
def get_date_range_data(table_name: str, date_column: str, start_date: str, end_date: str) -> str:
|
||||
"""
|
||||
Retrieves data within a specific date range.
|
||||
|
||||
Args:
|
||||
table_name: Name of the table
|
||||
date_column: Name of the date column
|
||||
start_date: Start date in format 'YYYY-MM-DD'
|
||||
end_date: End date in format 'YYYY-MM-DD'
|
||||
|
||||
Returns:
|
||||
Data within the specified date range
|
||||
"""
|
||||
query = f"""
|
||||
SELECT *
|
||||
FROM {DATABASE}.{table_name}
|
||||
WHERE {date_column} BETWEEN DATE '{start_date}' AND DATE '{end_date}'
|
||||
LIMIT 1000;
|
||||
"""
|
||||
result = execute_athena_query(query)
|
||||
|
||||
if result["status"] == "success":
|
||||
print(f"\n🔧 [TOOL CALLED] get_date_range_data from {start_date} to {end_date}")
|
||||
return result["results"]
|
||||
else:
|
||||
return f"Error getting date range data: {result.get('error', 'Unknown error')}"
|
||||
|
||||
|
||||
@tool
|
||||
def execute_custom_query(sql_query: str) -> str:
|
||||
"""
|
||||
Executes a custom SQL query. USE WITH CAUTION.
|
||||
Only use this when other specific tools don't meet the requirement.
|
||||
|
||||
Args:
|
||||
sql_query: Complete SQL query to execute
|
||||
|
||||
Returns:
|
||||
Query results
|
||||
"""
|
||||
# Add basic validation
|
||||
if any(keyword in sql_query.upper() for keyword in ["DROP", "DELETE", "INSERT", "UPDATE", "ALTER", "CREATE"]):
|
||||
return "Error: Write operations are not allowed. Only SELECT queries are permitted."
|
||||
|
||||
result = execute_athena_query(sql_query)
|
||||
|
||||
if result["status"] == "success":
|
||||
print(f"\n🔧 [TOOL CALLED] execute_custom_query")
|
||||
return result["results"]
|
||||
else:
|
||||
return f"Error executing query: {result.get('error', 'Unknown error')}"
|
||||
|
||||
|
||||
# ==============================================
|
||||
# TOOL MAPPING FOR EASY INTEGRATION
|
||||
# ==============================================
|
||||
|
||||
SUGGESTED_TOOLS = [
|
||||
get_table_schema,
|
||||
list_available_tables,
|
||||
count_table_rows,
|
||||
get_column_statistics,
|
||||
get_distinct_values,
|
||||
filter_answers_by_condition,
|
||||
aggregate_survey_responses,
|
||||
group_by_analysis,
|
||||
calculate_percentage_distribution,
|
||||
compare_columns,
|
||||
search_text_in_responses,
|
||||
get_null_count,
|
||||
get_date_range_data,
|
||||
execute_custom_query
|
||||
]
|
||||
|
||||
TOOLS_MAP = {
|
||||
"get_table_schema": get_table_schema,
|
||||
"list_available_tables": list_available_tables,
|
||||
"count_table_rows": count_table_rows,
|
||||
"get_column_statistics": get_column_statistics,
|
||||
"get_distinct_values": get_distinct_values,
|
||||
"filter_answers_by_condition": filter_answers_by_condition,
|
||||
"aggregate_survey_responses": aggregate_survey_responses,
|
||||
"group_by_analysis": group_by_analysis,
|
||||
"calculate_percentage_distribution": calculate_percentage_distribution,
|
||||
"compare_columns": compare_columns,
|
||||
"search_text_in_responses": search_text_in_responses,
|
||||
"get_null_count": get_null_count,
|
||||
"get_date_range_data": get_date_range_data,
|
||||
"execute_custom_query": execute_custom_query
|
||||
}
|
||||
|
||||
|
||||
# ==============================================
|
||||
# INTEGRATION EXAMPLE
|
||||
# ==============================================
|
||||
"""
|
||||
To integrate these tools into your BDAgent.py:
|
||||
|
||||
1. Import the tools:
|
||||
from app.utils.AthenaToolsSuggestions import SUGGESTED_TOOLS, TOOLS_MAP
|
||||
|
||||
2. Update the create_bedrock_llm function:
|
||||
tools = [consult_answers] + SUGGESTED_TOOLS
|
||||
llm_with_tools = llm.bind_tools(tools)
|
||||
|
||||
3. Update the call_tools function:
|
||||
tools_map = {
|
||||
"consult_answers": consult_answers,
|
||||
**TOOLS_MAP # Add all suggested tools
|
||||
}
|
||||
|
||||
4. Update the system prompt to include descriptions of new tools:
|
||||
Your available tools:
|
||||
- consult_answers: Get answers for a specific question by shortname
|
||||
- get_table_schema: Get column names and types of a table
|
||||
- list_available_tables: List all available tables
|
||||
- count_table_rows: Count total rows in a table
|
||||
- get_column_statistics: Get statistics (min, max, avg) for numeric columns
|
||||
- get_distinct_values: Get unique values and frequencies from a column
|
||||
- calculate_percentage_distribution: Calculate percentage distribution of values
|
||||
- group_by_analysis: Perform GROUP BY analysis on data
|
||||
- compare_columns: Compare two columns to find patterns
|
||||
- And more...
|
||||
"""

scripts/athena_query_paginated.py (new file, 241 lines)
@@ -0,0 +1,241 @@
#!/usr/bin/env python3
"""
AWS Athena Query Script with Pagination

This script executes an Athena query and retrieves all results using pagination
to overcome the 1000 row limit per API call.
"""

import boto3
import time
import csv
import json
import sys
from typing import List, Dict, Any
from datetime import datetime


class AthenaQueryExecutor:
    """Execute Athena queries with automatic pagination and result retrieval."""

    def __init__(
        self,
        database: str,
        output_location: str,
        region_name: str = 'us-east-1',
        max_results_per_page: int = 1000
    ):
        """
        Initialize the Athena query executor.

        Args:
            database: The Athena database name
            output_location: S3 location for query results (e.g., 's3://bucket/path/')
            region_name: AWS region name
            max_results_per_page: Maximum results per API call (max 1000)
        """
        self.client = boto3.client('athena', region_name=region_name)
        self.database = database
        self.output_location = output_location
        self.max_results_per_page = min(max_results_per_page, 1000)

    def execute_query(self, query: str, wait: bool = True) -> str:
        """
        Execute an Athena query and return the query execution ID.

        Args:
            query: SQL query string
            wait: Whether to wait for query completion

        Returns:
            Query execution ID
        """
        response = self.client.start_query_execution(
            QueryString=query,
            QueryExecutionContext={'Database': self.database},
            ResultConfiguration={'OutputLocation': self.output_location}
        )

        query_execution_id = response['QueryExecutionId']
        print(f"Query submitted. Execution ID: {query_execution_id}")

        if wait:
            self._wait_for_query_completion(query_execution_id)

        return query_execution_id

    def _wait_for_query_completion(self, query_execution_id: str, poll_interval: int = 2):
        """
        Wait for query to complete execution.

        Args:
            query_execution_id: The query execution ID
            poll_interval: Seconds between status checks
        """
        print("Waiting for query to complete...")

        while True:
            response = self.client.get_query_execution(
                QueryExecutionId=query_execution_id
            )

            state = response['QueryExecution']['Status']['State']

            if state == 'SUCCEEDED':
                print("Query completed successfully!")
                stats = response['QueryExecution']['Statistics']
                print(f"Data scanned: {stats.get('DataScannedInBytes', 0) / (1024**3):.2f} GB")
                print(f"Execution time: {stats.get('EngineExecutionTimeInMillis', 0) / 1000:.2f} seconds")
                break
            elif state in ['FAILED', 'CANCELLED']:
                reason = response['QueryExecution']['Status'].get('StateChangeReason', 'Unknown')
                raise Exception(f"Query {state.lower()}: {reason}")

            time.sleep(poll_interval)

    def get_all_results(self, query_execution_id: str) -> List[Dict[str, Any]]:
        """
        Retrieve all query results using pagination.

        Args:
            query_execution_id: The query execution ID

        Returns:
            List of result rows as dictionaries
        """
        all_results = []
        next_token = None
        page_count = 0

        print("Fetching results with pagination...")

        while True:
            page_count += 1

            # Build request parameters
            params = {
                'QueryExecutionId': query_execution_id,
                'MaxResults': self.max_results_per_page
            }

            if next_token:
                params['NextToken'] = next_token

            # Get results page
            response = self.client.get_query_results(**params)

            # Extract column names from first page
            if page_count == 1:
                columns = [col['Name'] for col in response['ResultSet']['ResultSetMetadata']['ColumnInfo']]
                # Skip header row in first page
                rows = response['ResultSet']['Rows'][1:]
            else:
                rows = response['ResultSet']['Rows']

            # Convert rows to dictionaries
            for row in rows:
                values = [field.get('VarCharValue', '') for field in row['Data']]
                all_results.append(dict(zip(columns, values)))

            print(f"Page {page_count}: Retrieved {len(rows)} rows (Total: {len(all_results)})")

            # Check if there are more results
            next_token = response.get('NextToken')
            if not next_token:
                break

        print(f"\nTotal rows retrieved: {len(all_results)}")
        return all_results

    def query_and_fetch_all(self, query: str) -> List[Dict[str, Any]]:
        """
        Execute query and fetch all results in one call.

        Args:
            query: SQL query string

        Returns:
            List of result rows as dictionaries
        """
        query_execution_id = self.execute_query(query, wait=True)
        return self.get_all_results(query_execution_id)

    def export_to_csv(self, results: List[Dict[str, Any]], filename: str):
        """
        Export results to CSV file.

        Args:
            results: List of result dictionaries
            filename: Output CSV filename
        """
        if not results:
            print("No results to export")
            return

        with open(filename, 'w', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=results[0].keys())
            writer.writeheader()
            writer.writerows(results)

        print(f"Results exported to {filename}")

    def export_to_json(self, results: List[Dict[str, Any]], filename: str):
        """
        Export results to JSON file.

        Args:
            results: List of result dictionaries
            filename: Output JSON filename
        """
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False)

        print(f"Results exported to {filename}")


def main():
    """Example usage of the AthenaQueryExecutor."""

    # Configuration
    DATABASE = 'your_database_name'
    OUTPUT_LOCATION = 's3://your-bucket/athena-results/'
    REGION = 'us-east-1'

    # Example query
    QUERY = """
    SELECT *
    FROM your_table
    WHERE date >= '2024-01-01'
    LIMIT 5000
    """

    # Initialize executor
    executor = AthenaQueryExecutor(
        database=DATABASE,
        output_location=OUTPUT_LOCATION,
        region_name=REGION
    )

    # Execute query and fetch all results
    try:
        results = executor.query_and_fetch_all(QUERY)

        # Export results
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        executor.export_to_csv(results, f'athena_results_{timestamp}.csv')
        executor.export_to_json(results, f'athena_results_{timestamp}.json')

        # Display sample results
        if results:
            print("\nFirst 5 results:")
            for i, row in enumerate(results[:5], 1):
                print(f"{i}. {row}")

    except Exception as e:
        print(f"Error: {e}")
        return 1

    return 0


if __name__ == '__main__':
    sys.exit(main())
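
For comparison, boto3 also ships a built-in paginator for GetQueryResults that handles the NextToken bookkeeping done manually in get_all_results(); a minimal sketch, assuming a query execution that has already SUCCEEDED (the execution id below is a placeholder):

import boto3

client = boto3.client("athena", region_name="us-east-1")
paginator = client.get_paginator("get_query_results")

rows = []
for page in paginator.paginate(QueryExecutionId="<query-execution-id>"):
    rows.extend(page["ResultSet"]["Rows"])

rows = rows[1:]  # the first row of the first page is the column header
print(f"Fetched {len(rows)} data rows")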

scripts/dynamodb_read_table.py (new file, 191 lines)
@@ -0,0 +1,191 @@
"""
|
||||
DynamoDB Table Reader Script
|
||||
|
||||
This script connects to AWS DynamoDB and reads all entries from a specified table.
|
||||
Outputs data in XML format with <period> tags containing the context XML content.
|
||||
|
||||
Usage:
|
||||
from dynamodb_read_table import read_table_as_xml
|
||||
xml_content = read_table_as_xml("my-table-name")
|
||||
"""
|
||||
|
||||
import re
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
|
||||
def clean_context_xml(context: str) -> str:
|
||||
"""
|
||||
Remove XML declaration and <relatorio> tags from context content.
|
||||
|
||||
Args:
|
||||
context: Raw XML content from DynamoDB
|
||||
|
||||
Returns:
|
||||
Cleaned XML content without declaration and relatorio tags
|
||||
"""
|
||||
# Remove XML declaration (e.g., <?xml version="1.0" encoding="UTF-8"?>)
|
||||
context = re.sub(r'<\?xml[^?]*\?>\s*', '', context)
|
||||
|
||||
# Remove opening <relatorio> tag (with any attributes)
|
||||
context = re.sub(r'<relatorio[^>]*>\s*', '', context)
|
||||
|
||||
# Remove closing </relatorio> tag
|
||||
context = re.sub(r'\s*</relatorio>', '', context)
|
||||
|
||||
return context.strip()
|
||||
|
||||
|
||||
def format_items_to_xml(items: list) -> str:
|
||||
"""
|
||||
Format all DynamoDB items to XML format.
|
||||
|
||||
Each item's 'period' field becomes a <period> tag,
|
||||
and the 'context' field's cleaned XML content is placed inside it.
|
||||
|
||||
Args:
|
||||
items: List of DynamoDB items
|
||||
|
||||
Returns:
|
||||
Complete XML formatted string with all items
|
||||
"""
|
||||
xml_parts = []
|
||||
|
||||
for item in items:
|
||||
period = item.get("period", "unknown")
|
||||
context = item.get("context", "")
|
||||
|
||||
# Clean the context XML
|
||||
cleaned_context = clean_context_xml(context)
|
||||
|
||||
xml_parts.append(f"<{period}>")
|
||||
xml_parts.append(cleaned_context)
|
||||
xml_parts.append(f"</{period}>")
|
||||
xml_parts.append("") # Empty line between entries
|
||||
|
||||
return "\n".join(xml_parts)
|
||||
|
||||
|
||||
def get_dynamodb_client(region_name: str = "us-east-1"):
|
||||
"""Create and return a DynamoDB client."""
|
||||
session = boto3.Session()
|
||||
return session.client("dynamodb", region_name=region_name)
|
||||
|
||||
|
||||
def get_dynamodb_resource(region_name: str = "us-east-1"):
|
||||
"""Create and return a DynamoDB resource for higher-level operations."""
|
||||
session = boto3.Session()
|
||||
return session.resource("dynamodb", region_name=region_name)
|
||||
|
||||
|
||||
def scan_table(table_name: str, region_name: str = "us-east-1") -> list:
|
||||
"""
|
||||
Scan a DynamoDB table and return all items.
|
||||
|
||||
Uses pagination to handle tables larger than 1MB response limit.
|
||||
|
||||
Args:
|
||||
table_name: Name of the DynamoDB table to scan
|
||||
region_name: AWS region where the table is located
|
||||
|
||||
Returns:
|
||||
List of all items in the table
|
||||
"""
|
||||
dynamodb = get_dynamodb_resource(region_name)
|
||||
table = dynamodb.Table(table_name)
|
||||
|
||||
items = []
|
||||
last_evaluated_key = None
|
||||
|
||||
try:
|
||||
while True:
|
||||
if last_evaluated_key:
|
||||
response = table.scan(ExclusiveStartKey=last_evaluated_key)
|
||||
else:
|
||||
response = table.scan()
|
||||
|
||||
items.extend(response.get("Items", []))
|
||||
|
||||
last_evaluated_key = response.get("LastEvaluatedKey")
|
||||
if not last_evaluated_key:
|
||||
break
|
||||
|
||||
print(f"Successfully scanned {len(items)} items from table '{table_name}'")
|
||||
return items
|
||||
|
||||
except ClientError as e:
|
||||
error_code = e.response["Error"]["Code"]
|
||||
error_message = e.response["Error"]["Message"]
|
||||
print(f"Error scanning table: {error_code} - {error_message}")
|
||||
raise
|
||||
|
||||
|
||||
def list_tables(region_name: str = "us-east-1") -> list:
|
||||
"""List all DynamoDB tables in the specified region."""
|
||||
client = get_dynamodb_client(region_name)
|
||||
|
||||
tables = []
|
||||
last_evaluated_table_name = None
|
||||
|
||||
try:
|
||||
while True:
|
||||
if last_evaluated_table_name:
|
||||
response = client.list_tables(ExclusiveStartTableName=last_evaluated_table_name)
|
||||
else:
|
||||
response = client.list_tables()
|
||||
|
||||
tables.extend(response.get("TableNames", []))
|
||||
|
||||
last_evaluated_table_name = response.get("LastEvaluatedTableName")
|
||||
if not last_evaluated_table_name:
|
||||
break
|
||||
|
||||
return tables
|
||||
|
||||
except ClientError as e:
|
||||
error_code = e.response["Error"]["Code"]
|
||||
error_message = e.response["Error"]["Message"]
|
||||
print(f"Error listing tables: {error_code} - {error_message}")
|
||||
raise
|
||||
|
||||
|
||||
def get_table_info(table_name: str, region_name: str = "us-east-1") -> dict:
|
||||
"""Get metadata information about a DynamoDB table."""
|
||||
client = get_dynamodb_client(region_name)
|
||||
|
||||
try:
|
||||
response = client.describe_table(TableName=table_name)
|
||||
table_info = response.get("Table", {})
|
||||
|
||||
return {
|
||||
"TableName": table_info.get("TableName"),
|
||||
"TableStatus": table_info.get("TableStatus"),
|
||||
"ItemCount": table_info.get("ItemCount"),
|
||||
"TableSizeBytes": table_info.get("TableSizeBytes"),
|
||||
"KeySchema": table_info.get("KeySchema"),
|
||||
"AttributeDefinitions": table_info.get("AttributeDefinitions"),
|
||||
"CreationDateTime": str(table_info.get("CreationDateTime")),
|
||||
}
|
||||
|
||||
except ClientError as e:
|
||||
error_code = e.response["Error"]["Code"]
|
||||
error_message = e.response["Error"]["Message"]
|
||||
print(f"Error describing table: {error_code} - {error_message}")
|
||||
raise
|
||||
|
||||
|
||||
def read_table_as_xml(table_name: str, region_name: str = "us-east-1") -> str:
|
||||
"""
|
||||
Read all entries from a DynamoDB table and return as XML string.
|
||||
|
||||
Args:
|
||||
table_name: Name of the DynamoDB table to read
|
||||
region_name: AWS region where the table is located (default: us-east-1)
|
||||
|
||||
Returns:
|
||||
XML formatted string with all items wrapped in <period> tags
|
||||
"""
|
||||
items = scan_table(table_name, region_name)
|
||||
return format_items_to_xml(items)
|
||||
if __name__=="__main__":
|
||||
print(read_table_as_xml("poc_dnx_monthly_summary","us-east-1"))
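
To make the output shape concrete, a small self-contained example of format_items_to_xml with a fabricated item (illustrative only, not real table data):

from dynamodb_read_table import format_items_to_xml

items = [{
    "period": "2025-01",  # hypothetical period key
    "context": '<?xml version="1.0"?><relatorio><nps>72</nps></relatorio>',
}]
print(format_items_to_xml(items))
# Prints (declaration and relatorio tags stripped, wrapped in the period tag):
# <2025-01>
# <nps>72</nps>
# </2025-01>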

scripts/secretsmanager.py (new file, 29 lines)
@@ -0,0 +1,29 @@
import json

import boto3
from botocore.exceptions import ClientError


def get_secret():
    secret_name = "assistente-db-secrets-manager"
    region_name = "us-east-1"

    # Create a Secrets Manager client
    session = boto3.session.Session()
    client = session.client(
        service_name='secretsmanager',
        region_name=region_name
    )

    try:
        get_secret_value_response = client.get_secret_value(
            SecretId=secret_name
        )
    except ClientError as e:
        # For a list of exceptions thrown, see
        # https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
        raise e

    secret = get_secret_value_response['SecretString']
    return secret


secrets = json.loads(get_secret())['LANGFUSE-SECRET-KEY']
print(secrets)
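
Since the same secret also holds the public key (see simple_agent_with_langfuse.py below), a small variant that parses the JSON once and reads both keys:

# Sketch: parse the secret JSON once and read both Langfuse keys from it
# (key names as used in simple_agent_with_langfuse.py).
all_secrets = json.loads(get_secret())
public_key = all_secrets["LANGFUSE-PUBLIC-KEY"]
secret_key = all_secrets["LANGFUSE-SECRET-KEY"]
# Avoid printing secret values outside one-off debugging.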

scripts/simple_agent_with_langfuse.py (new file, 322 lines)
@@ -0,0 +1,322 @@
#!/usr/bin/env python3
"""
Simple LangGraph Agent with Langfuse Integration
Demonstrates basic agent functionality with Langfuse observability
"""

import json
import boto3
from botocore.exceptions import ClientError
from typing import TypedDict, Annotated
import operator

from langgraph.graph import StateGraph, END
from langchain_aws import ChatBedrock
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage

from langfuse import Langfuse
from langfuse.langchain import CallbackHandler


# ==============================================
# SECRETS MANAGER
# ==============================================

def get_secret():
    """Fetch secrets from AWS Secrets Manager"""
    secret_name = "assistente-db-secrets-manager"
    region_name = "us-east-1"

    session = boto3.session.Session()
    client = session.client(
        service_name='secretsmanager',
        region_name=region_name
    )

    try:
        get_secret_value_response = client.get_secret_value(
            SecretId=secret_name
        )
    except ClientError as e:
        raise e

    secret = get_secret_value_response['SecretString']
    return json.loads(secret)


# ==============================================
# INITIALIZE LANGFUSE
# ==============================================

print("Initializing Langfuse...")
secrets = get_secret()

langfuse = Langfuse(
    public_key=secrets['LANGFUSE-PUBLIC-KEY'],
    secret_key=secrets['LANGFUSE-SECRET-KEY'],
    host="http://98.92.98.83:3000"
)

print("✓ Langfuse initialized successfully")
print("  Host: http://98.92.98.83:3000")


# ==============================================
# DEFINE TOOLS
# ==============================================

@tool
def add_numbers(a: int, b: int) -> int:
    """
    Add two numbers together.

    Args:
        a: First number
        b: Second number

    Returns:
        The sum of a and b
    """
    print(f"🔧 [TOOL] Adding {a} + {b}")
    return a + b


@tool
def multiply_numbers(a: int, b: int) -> int:
    """
    Multiply two numbers together.

    Args:
        a: First number
        b: Second number

    Returns:
        The product of a and b
    """
    print(f"🔧 [TOOL] Multiplying {a} * {b}")
    return a * b


# ==============================================
# AGENT STATE
# ==============================================

class AgentState(TypedDict):
    messages: Annotated[list, operator.add]
    current_step: str


# ==============================================
# AGENT NODES
# ==============================================

def create_bedrock_llm(inference_profile_arn: str, region: str = "us-east-1"):
    """Create a ChatBedrock instance with tools"""
    llm = ChatBedrock(
        model_id=inference_profile_arn,
        region_name=region,
        model_kwargs={
            "max_tokens": 2048,
            "temperature": 0.7,
        },
        provider="anthropic"
    )

    # Bind tools to the LLM
    tools = [add_numbers, multiply_numbers]
    llm_with_tools = llm.bind_tools(tools)

    return llm_with_tools


def call_model(state: AgentState, llm) -> AgentState:
    """Call the LLM with Langfuse callback"""
    print("[MODEL] Calling Bedrock with Langfuse tracing...")

    messages = state["messages"]

    # Create Langfuse callback handler
    langfuse_handler = CallbackHandler()
    config = {
        "configurable": {"thread_id": "simple_agent_demo"},
        "callbacks": [langfuse_handler]
    }

    response = llm.invoke(messages, config=config)
    # Return state updates (mutating `state` in place is not a reliable way
    # to update LangGraph state)
    return {"messages": [response], "current_step": "model_called"}


def call_tools(state: AgentState) -> AgentState:
    """Execute any tool calls from the LLM response"""
    print("[TOOLS] Checking for tool calls...")

    messages = state["messages"]
    last_message = messages[-1]

    if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
        print(f"[TOOLS] Found {len(last_message.tool_calls)} tool call(s)")

        tool_messages = []
        tools_map = {
            "add_numbers": add_numbers,
            "multiply_numbers": multiply_numbers
        }

        for tool_call in last_message.tool_calls:
            tool_name = tool_call["name"]
            tool_args = tool_call["args"]

            print(f"[TOOLS] Executing: {tool_name}({tool_args})")

            tool_func = tools_map[tool_name]
            result = tool_func.invoke(tool_args)

            tool_message = ToolMessage(
                content=str(result),
                tool_call_id=tool_call["id"]
            )
            tool_messages.append(tool_message)

        return {"messages": tool_messages, "current_step": "tools_executed"}
    else:
        print("[TOOLS] No tool calls found")
        return {"messages": [], "current_step": "no_tools"}


def should_continue(state: AgentState) -> str:
    """Determine if we should continue to tools or end"""
    messages = state["messages"]
    last_message = messages[-1]

    if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
        print("[ROUTER] Routing to tools...")
        return "tools"

    print("[ROUTER] No more tool calls, ending...")
    return "end"


# ==============================================
# CREATE AGENT
# ==============================================

def create_agent(inference_profile_arn: str, region: str = "us-east-1"):
    """Create a LangGraph agent with Langfuse observability"""

    llm = create_bedrock_llm(inference_profile_arn, region)

    workflow = StateGraph(AgentState)

    # Add nodes
    workflow.add_node("model", lambda state: call_model(state, llm))
    workflow.add_node("tools", call_tools)

    # Define workflow
    workflow.set_entry_point("model")

    workflow.add_conditional_edges(
        "model",
        should_continue,
        {
            "tools": "tools",
            "end": END
        }
    )

    workflow.add_edge("tools", "model")

    app = workflow.compile()

    return app


# ==============================================
# MAIN EXECUTION
# ==============================================

def main():
    """Main execution function"""

    print("\n" + "=" * 60)
    print("Simple LangGraph Agent with Langfuse")
    print("=" * 60)

    # Configuration
    INFERENCE_PROFILE_ARN = "arn:aws:bedrock:us-east-1:305427701314:application-inference-profile/b3umwd5jpd0u"
    REGION = "us-east-1"

    # System prompt
    SYSTEM_PROMPT = """You are a helpful math assistant with access to calculation tools.

Your available tools:
- add_numbers: Add two numbers
- multiply_numbers: Multiply two numbers

When a user asks you to perform calculations:
1. Break down the problem into steps
2. Use the appropriate tools
3. Show your reasoning clearly
4. Provide the final answer

Always use the tools rather than calculating in your head."""

    print(f"\nUsing inference profile: {INFERENCE_PROFILE_ARN}")
    print(f"Region: {REGION}")
    print("=" * 60)

    # Create the agent
    agent = create_agent(INFERENCE_PROFILE_ARN, REGION)

    # Example query
    user_query = "What is (10 + 5) * 3?"

    # Initialize state
    initial_state = {
        "messages": [
            SystemMessage(content=SYSTEM_PROMPT),
            HumanMessage(content=user_query)
        ],
        "current_step": "init"
    }

    print(f"\nUser Query: {user_query}\n")
    print("-" * 60)

    # Run the agent
    final_state = agent.invoke(initial_state)

    # Display results
    print("-" * 60)
    print("\n[FINAL RESULT]\n")

    for i, msg in enumerate(final_state["messages"], 1):
        if isinstance(msg, SystemMessage):
            print(f"{i}. System: [System prompt configured]")
        elif isinstance(msg, HumanMessage):
            print(f"{i}. User: {msg.content}")
        elif isinstance(msg, AIMessage):
            if hasattr(msg, 'tool_calls') and msg.tool_calls:
                print(f"{i}. AI: [Calling tools...]")
            else:
                print(f"{i}. AI: {msg.content}")
        elif isinstance(msg, ToolMessage):
            print(f"{i}. Tool Result: {msg.content}")

    print("\n" + "=" * 60)
    print("Agent completed successfully!")

    # Flush Langfuse data
    print("\nFlushing data to Langfuse...")
    langfuse.flush()
    print("✓ Data sent to Langfuse")

    print("\nView your traces at: http://98.92.98.83:3000")
    print("=" * 60)


if __name__ == "__main__":
    main()
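
To watch the graph execute node by node instead of waiting on invoke(), compiled LangGraph apps also expose .stream(); a minimal sketch reusing `agent` and `initial_state` as built in main() above:

# Sketch: stream per-node state updates rather than calling agent.invoke().
for update in agent.stream(initial_state, stream_mode="updates"):
    for node_name, node_output in update.items():
        print(f"[{node_name}] emitted {len(node_output['messages'])} message(s)")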

scripts/test_api.sh (new executable file, 12 lines)
@@ -0,0 +1,12 @@
#!/bin/bash

URL="http://alb-assistente-analitico-7e352f9-1039635730.us-east-1.elb.amazonaws.com:8000"

curl -X POST "${URL}/agent" \
  -H "Content-Type: application/json" \
  -d '{
    "query": "Qual o NPS de dezembro 2025?",
    "history": "",
    "model": "anthropic.claude-haiku-4-5-20251001-v1:0",
    "base": "bacio_transacional_loja_app"
  }'
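
The same smoke test from Python, e.g. as a starting point for a pytest suite; this sketch assumes the endpoint returns a JSON body:

import requests

URL = "http://alb-assistente-analitico-7e352f9-1039635730.us-east-1.elb.amazonaws.com:8000"

payload = {
    "query": "Qual o NPS de dezembro 2025?",  # "What was the NPS for December 2025?"
    "history": "",
    "model": "anthropic.claude-haiku-4-5-20251001-v1:0",
    "base": "bacio_transacional_loja_app",
}

response = requests.post(f"{URL}/agent", json=payload, timeout=120)
response.raise_for_status()
print(response.json())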

scripts/test_langfuse_connection.py (new file, 133 lines)
@@ -0,0 +1,133 @@
#!/usr/bin/env python3
"""
Test script to diagnose Langfuse connectivity issues
"""
import json
import sys

print("=" * 60)
print("Langfuse Connection Diagnostic Tool")
print("=" * 60)

# Test 1: Check if required modules are installed
print("\n1. Checking required modules...")
try:
    import boto3
    print(" ✓ boto3 installed")
except ImportError:
    print(" ✗ boto3 NOT installed")
    sys.exit(1)

try:
    from langfuse import Langfuse
    print(" ✓ langfuse installed")
except ImportError:
    print(" ✗ langfuse NOT installed")
    sys.exit(1)

try:
    from botocore.exceptions import ClientError
    print(" ✓ botocore installed")
except ImportError:
    print(" ✗ botocore NOT installed")
    sys.exit(1)

# Test 2: Fetch secrets from AWS Secrets Manager
print("\n2. Fetching secrets from AWS Secrets Manager...")
secrets = None
try:
    secret_name = "assistente-db-secrets-manager"
    region_name = "us-east-1"

    session = boto3.session.Session()
    client = session.client(
        service_name='secretsmanager',
        region_name=region_name
    )

    get_secret_value_response = client.get_secret_value(SecretId=secret_name)
    secret = get_secret_value_response['SecretString']
    secrets = json.loads(secret)

    print(" ✓ Successfully fetched secrets")
    print(f" ✓ Found keys: {list(secrets.keys())}")

    # Check for Langfuse keys
    has_public_key = 'LANGFUSE-PUBLIC-KEY' in secrets
    has_secret_key = 'LANGFUSE-SECRET-KEY' in secrets

    print(f" {'✓' if has_public_key else '✗'} LANGFUSE-PUBLIC-KEY {'found' if has_public_key else 'MISSING'}")
    print(f" {'✓' if has_secret_key else '✗'} LANGFUSE-SECRET-KEY {'found' if has_secret_key else 'MISSING'}")

    if not has_public_key or not has_secret_key:
        print("\n ⚠ Langfuse credentials not found in Secrets Manager!")
        secrets = None

except ClientError as e:
    print(f" ⚠ Error fetching secrets: {e}")
    print(" Will try with manual input...")
    secrets = None
except Exception as e:
    print(f" ⚠ Unexpected error: {e}")
    print(" Will try with manual input...")
    secrets = None

# If AWS Secrets Manager failed, try manual input
if secrets is None:
    print("\n Please provide Langfuse credentials manually:")
    public_key = input(" Enter LANGFUSE-PUBLIC-KEY: ").strip()
    secret_key = input(" Enter LANGFUSE-SECRET-KEY: ").strip()

    if not public_key or not secret_key:
        print(" ✗ Both keys are required!")
        sys.exit(1)

    secrets = {
        'LANGFUSE-PUBLIC-KEY': public_key,
        'LANGFUSE-SECRET-KEY': secret_key
    }
    print(" ✓ Manual credentials provided")

# Test 3: Initialize Langfuse client
print("\n3. Initializing Langfuse client...")
try:
    host = "http://98.92.98.83:3000"
    langfuse = Langfuse(
        public_key=secrets['LANGFUSE-PUBLIC-KEY'],
        secret_key=secrets['LANGFUSE-SECRET-KEY'],
        host=host
    )
    print(f" ✓ Langfuse client initialized with host: {host}")
except Exception as e:
    print(f" ✗ Error initializing Langfuse: {e}")
    sys.exit(1)

# Test 4: Test API connectivity
print("\n4. Testing Langfuse API connectivity...")
try:
    # Create a simple test trace
    trace = langfuse.trace(
        name="connection_test",
        user_id="diagnostic_script"
    )
    print(" ✓ Test trace created")

    # Try to flush
    langfuse.flush()
    print(" ✓ Data flushed to Langfuse")

    print("\n" + "=" * 60)
    print("SUCCESS: Langfuse connection is working properly!")
    print("=" * 60)

except Exception as e:
    print(f" ✗ Error communicating with Langfuse: {e}")
    print(f"\n Error details: {type(e).__name__}: {str(e)}")

    # Additional diagnostic info
    print("\n Troubleshooting tips:")
    print(" - Verify the Langfuse server is running")
    print(" - Check if the host URL is correct")
    print(" - Verify your API keys are valid in Langfuse UI")
    print(" - Check network connectivity to the Langfuse server")
    sys.exit(1)

scripts/teste.py (new file, 53 lines)
@@ -0,0 +1,53 @@
import boto3
import time

WORKGROUP = "iceberg-workgroup"
DATABASE = "dnx_warehouse"

session = boto3.Session()
athena = session.client("athena", region_name="us-east-1")

# ==============================================
# QUERY
# ==============================================

QUERY = """
SELECT title, shortname FROM AwsDataCatalog.dnx_warehouse.bacio_transacional_loja_app_pesquisa;
"""

print("Executing query on Athena...")
response = athena.start_query_execution(
    QueryString=QUERY,
    QueryExecutionContext={"Database": DATABASE},
    WorkGroup=WORKGROUP
)

query_execution_id = response["QueryExecutionId"]
print(f"QueryExecutionId: {query_execution_id}")

# ==============================================
# WAIT FOR RESULT
# ==============================================

while True:
    result = athena.get_query_execution(QueryExecutionId=query_execution_id)
    state = result["QueryExecution"]["Status"]["State"]

    if state in ["SUCCEEDED", "FAILED", "CANCELLED"]:
        print("Final state:", state)
        break

    print("Waiting for execution...")
    time.sleep(1)

# ==============================================
# RESULT
# ==============================================

if state == "SUCCEEDED":
    output = athena.get_query_results(QueryExecutionId=query_execution_id)
    print("\nResults:")
    for row in output["ResultSet"]["Rows"]:
        # Print only the first column (title) of each row
        print([col.get("VarCharValue", "") for col in row["Data"]][0])
else:
    print("Error executing the query.")

scripts/utils/dynamodb_read_table.py (new file, 209 lines)
@@ -0,0 +1,209 @@
"""
|
||||
DynamoDB Table Reader Script
|
||||
|
||||
This script connects to AWS DynamoDB and reads all entries from a specified table.
|
||||
Outputs data in XML format with <period> tags containing the context XML content.
|
||||
|
||||
Usage:
|
||||
from dynamodb_read_table import read_table_as_xml
|
||||
xml_content = read_table_as_xml("my-table-name")
|
||||
"""
|
||||
|
||||
import re
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
|
||||
def clean_context_xml(context: str) -> str:
|
||||
"""
|
||||
Remove XML declaration and <relatorio> tags from context content.
|
||||
|
||||
Args:
|
||||
context: Raw XML content from DynamoDB
|
||||
|
||||
Returns:
|
||||
Cleaned XML content without declaration and relatorio tags
|
||||
"""
|
||||
# Remove XML declaration (e.g., <?xml version="1.0" encoding="UTF-8"?>)
|
||||
context = re.sub(r'<\?xml[^?]*\?>\s*', '', context)
|
||||
|
||||
# Remove opening <relatorio> tag (with any attributes)
|
||||
context = re.sub(r'<relatorio[^>]*>\s*', '', context)
|
||||
|
||||
# Remove closing </relatorio> tag
|
||||
context = re.sub(r'\s*</relatorio>', '', context)
|
||||
|
||||
return context.strip()
|
||||
|
||||
|
||||
def remove_xml_declaration(content: str) -> str:
|
||||
"""
|
||||
Remove only the XML declaration from content.
|
||||
|
||||
Args:
|
||||
content: Raw XML content
|
||||
|
||||
Returns:
|
||||
Content without XML declaration (keeps relatorio tags)
|
||||
"""
|
||||
content = re.sub(r'<\?xml[^?]*\?>\s*', '', content)
|
||||
return content.strip()
|
||||
|
||||
|
||||
def format_items_to_xml(items: list) -> str:
|
||||
"""
|
||||
Format all DynamoDB items to XML format.
|
||||
|
||||
Each item's 'period' field becomes a <period> tag,
|
||||
and the 'context' and 'dados_consolidados' fields are placed inside it.
|
||||
|
||||
Args:
|
||||
items: List of DynamoDB items
|
||||
|
||||
Returns:
|
||||
Complete XML formatted string with all items
|
||||
"""
|
||||
xml_parts = []
|
||||
|
||||
for item in items:
|
||||
period = item.get("period", "unknown")
|
||||
context = item.get("context", "")
|
||||
dados_consolidados = item.get("dados_consolidados", "")
|
||||
|
||||
# Clean the XML content
|
||||
cleaned_context = clean_context_xml(context)
|
||||
cleaned_dados = remove_xml_declaration(dados_consolidados)
|
||||
|
||||
xml_parts.append(f"<{period}>")
|
||||
xml_parts.append(cleaned_context)
|
||||
if cleaned_dados:
|
||||
xml_parts.append(cleaned_dados)
|
||||
xml_parts.append(f"</{period}>")
|
||||
xml_parts.append("") # Empty line between entries
|
||||
|
||||
return "\n".join(xml_parts)
|
||||
|
||||
|
||||
def get_dynamodb_client(region_name: str = "us-east-1"):
|
||||
"""Create and return a DynamoDB client."""
|
||||
session = boto3.Session()
|
||||
return session.client("dynamodb", region_name=region_name)
|
||||
|
||||
|
||||
def get_dynamodb_resource(region_name: str = "us-east-1"):
|
||||
"""Create and return a DynamoDB resource for higher-level operations."""
|
||||
session = boto3.Session()
|
||||
return session.resource("dynamodb", region_name=region_name)
|
||||
|
||||
|
||||
def scan_table(table_name: str, region_name: str = "us-east-1") -> list:
|
||||
"""
|
||||
Scan a DynamoDB table and return all items.
|
||||
|
||||
Uses pagination to handle tables larger than 1MB response limit.
|
||||
|
||||
Args:
|
||||
table_name: Name of the DynamoDB table to scan
|
||||
region_name: AWS region where the table is located
|
||||
|
||||
Returns:
|
||||
List of all items in the table
|
||||
"""
|
||||
dynamodb = get_dynamodb_resource(region_name)
|
||||
table = dynamodb.Table(table_name)
|
||||
|
||||
items = []
|
||||
last_evaluated_key = None
|
||||
|
||||
try:
|
||||
while True:
|
||||
if last_evaluated_key:
|
||||
response = table.scan(ExclusiveStartKey=last_evaluated_key)
|
||||
else:
|
||||
response = table.scan()
|
||||
|
||||
items.extend(response.get("Items", []))
|
||||
|
||||
last_evaluated_key = response.get("LastEvaluatedKey")
|
||||
if not last_evaluated_key:
|
||||
break
|
||||
|
||||
print(f"Successfully scanned {len(items)} items from table '{table_name}'")
|
||||
return items
|
||||
|
||||
except ClientError as e:
|
||||
error_code = e.response["Error"]["Code"]
|
||||
error_message = e.response["Error"]["Message"]
|
||||
print(f"Error scanning table: {error_code} - {error_message}")
|
||||
raise
|
||||
|
||||
|
||||
def list_tables(region_name: str = "us-east-1") -> list:
|
||||
"""List all DynamoDB tables in the specified region."""
|
||||
client = get_dynamodb_client(region_name)
|
||||
|
||||
tables = []
|
||||
last_evaluated_table_name = None
|
||||
|
||||
try:
|
||||
while True:
|
||||
if last_evaluated_table_name:
|
||||
response = client.list_tables(ExclusiveStartTableName=last_evaluated_table_name)
|
||||
else:
|
||||
response = client.list_tables()
|
||||
|
||||
tables.extend(response.get("TableNames", []))
|
||||
|
||||
last_evaluated_table_name = response.get("LastEvaluatedTableName")
|
||||
if not last_evaluated_table_name:
|
||||
break
|
||||
|
||||
return tables
|
||||
|
||||
except ClientError as e:
|
||||
error_code = e.response["Error"]["Code"]
|
||||
error_message = e.response["Error"]["Message"]
|
||||
print(f"Error listing tables: {error_code} - {error_message}")
|
||||
raise
|
||||
|
||||
|
||||
def get_table_info(table_name: str, region_name: str = "us-east-1") -> dict:
|
||||
"""Get metadata information about a DynamoDB table."""
|
||||
client = get_dynamodb_client(region_name)
|
||||
|
||||
try:
|
||||
response = client.describe_table(TableName=table_name)
|
||||
table_info = response.get("Table", {})
|
||||
|
||||
return {
|
||||
"TableName": table_info.get("TableName"),
|
||||
"TableStatus": table_info.get("TableStatus"),
|
||||
"ItemCount": table_info.get("ItemCount"),
|
||||
"TableSizeBytes": table_info.get("TableSizeBytes"),
|
||||
"KeySchema": table_info.get("KeySchema"),
|
||||
"AttributeDefinitions": table_info.get("AttributeDefinitions"),
|
||||
"CreationDateTime": str(table_info.get("CreationDateTime")),
|
||||
}
|
||||
|
||||
except ClientError as e:
|
||||
error_code = e.response["Error"]["Code"]
|
||||
error_message = e.response["Error"]["Message"]
|
||||
print(f"Error describing table: {error_code} - {error_message}")
|
||||
raise
|
||||
|
||||
|
||||
def read_table_as_xml(table_name: str, region_name: str = "us-east-1") -> str:
|
||||
"""
|
||||
Read all entries from a DynamoDB table and return as XML string.
|
||||
|
||||
Args:
|
||||
table_name: Name of the DynamoDB table to read
|
||||
region_name: AWS region where the table is located (default: us-east-1)
|
||||
|
||||
Returns:
|
||||
XML formatted string with all items wrapped in <period> tags
|
||||
"""
|
||||
items = scan_table(table_name, region_name)
|
||||
return format_items_to_xml(items)
|
||||
if __name__=="__main__":
|
||||
print(read_table_as_xml("poc_dnx_monthly_summary","us-east-1"))