Files
AI-inovyo-assistende-db/scripts/athena_query_paginated.py
2026-05-14 15:29:03 -03:00

242 lines
7.5 KiB
Python

#!/usr/bin/env python3
"""
AWS Athena Query Script with Pagination
This script executes an Athena query and retrieves all results using pagination
to overcome the 1000 row limit per API call.
"""
import csv
import json
import sys
import time
from datetime import datetime
from typing import List, Dict, Any, Optional

import boto3
class AthenaQueryExecutor:
    """Execute Athena queries with automatic pagination and result retrieval.

    Rows are returned as ``{column_name: value}`` dicts whose values are the
    strings reported by Athena; cells without a ``VarCharValue`` (presumably
    SQL NULLs) are rendered as ``''``.
    """

    # Upper bound this class allows for GetQueryResults' MaxResults parameter.
    _API_MAX_RESULTS = 1000

    def __init__(
        self,
        database: str,
        output_location: str,
        region_name: str = 'us-east-1',
        max_results_per_page: int = 1000
    ):
        """
        Initialize the Athena query executor.

        Args:
            database: The Athena database name
            output_location: S3 location for query results (e.g., 's3://bucket/path/')
            region_name: AWS region name
            max_results_per_page: Maximum results per API call; clamped to
                the range 1..1000
        """
        self.client = boto3.client('athena', region_name=region_name)
        self.database = database
        self.output_location = output_location
        # A non-positive page size is meaningless, so clamp both ends instead
        # of passing an unusable value to every page request.
        self.max_results_per_page = max(1, min(max_results_per_page, self._API_MAX_RESULTS))

    def execute_query(self, query: str, wait: bool = True, timeout: Optional[float] = None) -> str:
        """
        Submit an Athena query and return its execution ID.

        Args:
            query: SQL query string
            wait: Whether to block until the query finishes
            timeout: Maximum seconds to wait when ``wait`` is true; ``None``
                (the default) waits indefinitely

        Returns:
            Query execution ID

        Raises:
            RuntimeError: If waiting and the query ends FAILED or CANCELLED.
            TimeoutError: If waiting and ``timeout`` elapses first (the query
                itself is left running).
        """
        response = self.client.start_query_execution(
            QueryString=query,
            QueryExecutionContext={'Database': self.database},
            ResultConfiguration={'OutputLocation': self.output_location}
        )
        query_execution_id = response['QueryExecutionId']
        print(f"Query submitted. Execution ID: {query_execution_id}")
        if wait:
            self._wait_for_query_completion(query_execution_id, timeout=timeout)
        return query_execution_id

    def _wait_for_query_completion(
        self,
        query_execution_id: str,
        poll_interval: float = 2,
        timeout: Optional[float] = None,
    ):
        """
        Poll the query's state until it succeeds, fails, or times out.

        Args:
            query_execution_id: The query execution ID
            poll_interval: Seconds between status checks
            timeout: Maximum seconds to wait; ``None`` waits indefinitely.
                The deadline is checked once per poll, so the actual wait can
                overshoot it by up to one ``poll_interval``.

        Raises:
            RuntimeError: If the query ends in FAILED or CANCELLED state.
            TimeoutError: If ``timeout`` elapses before the query finishes.
                The query is not cancelled.
        """
        print("Waiting for query to complete...")
        deadline = None if timeout is None else time.monotonic() + timeout
        while True:
            response = self.client.get_query_execution(
                QueryExecutionId=query_execution_id
            )
            execution = response['QueryExecution']
            state = execution['Status']['State']
            if state == 'SUCCEEDED':
                print("Query completed successfully!")
                # Treat Statistics as optional, like the individual counters
                # below, so a successful query never turns into a KeyError.
                stats = execution.get('Statistics', {})
                print(f"Data scanned: {stats.get('DataScannedInBytes', 0) / (1024**3):.2f} GB")
                print(f"Execution time: {stats.get('EngineExecutionTimeInMillis', 0) / 1000:.2f} seconds")
                return
            if state in ('FAILED', 'CANCELLED'):
                reason = execution['Status'].get('StateChangeReason', 'Unknown')
                raise RuntimeError(f"Query {state.lower()}: {reason}")
            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError(
                    f"Query {query_execution_id} still {state} after {timeout} seconds"
                )
            time.sleep(poll_interval)

    @staticmethod
    def _row_values(row: Dict[str, Any]) -> List[str]:
        """Return a result row's cell strings, using '' for cells with no VarCharValue."""
        return [field.get('VarCharValue', '') for field in row['Data']]

    @classmethod
    def _is_header_row(cls, row: Dict[str, Any], column_info: List[Dict[str, Any]]) -> bool:
        """Return True if *row* just repeats the column names (or labels).

        Heuristic limitation: a headerless result whose first data row happens
        to equal the column names would be misread as a header.
        """
        values = cls._row_values(row)
        names = [col['Name'] for col in column_info]
        labels = [col.get('Label', col['Name']) for col in column_info]
        return values == names or values == labels

    def get_all_results(self, query_execution_id: str) -> List[Dict[str, Any]]:
        """
        Retrieve all query results by following NextToken pagination.

        The first row of the first page is dropped only when it repeats the
        column names, i.e. when it is a header row. Result sets without a
        header row keep all of their data rows.

        Args:
            query_execution_id: The query execution ID

        Returns:
            List of result rows as dictionaries keyed by column name
        """
        all_results: List[Dict[str, Any]] = []
        columns: List[str] = []
        next_token: Optional[str] = None
        page_count = 0
        print("Fetching results with pagination...")
        while True:
            page_count += 1
            params: Dict[str, Any] = {
                'QueryExecutionId': query_execution_id,
                'MaxResults': self.max_results_per_page,
            }
            if next_token:
                params['NextToken'] = next_token
            response = self.client.get_query_results(**params)
            rows = response['ResultSet']['Rows']
            if page_count == 1:
                # Column metadata is read once, from the first page.
                column_info = response['ResultSet']['ResultSetMetadata']['ColumnInfo']
                columns = [col['Name'] for col in column_info]
                if rows and self._is_header_row(rows[0], column_info):
                    rows = rows[1:]
            all_results.extend(dict(zip(columns, self._row_values(row))) for row in rows)
            print(f"Page {page_count}: Retrieved {len(rows)} rows (Total: {len(all_results)})")
            next_token = response.get('NextToken')
            if not next_token:
                break
        print(f"\nTotal rows retrieved: {len(all_results)}")
        return all_results

    def query_and_fetch_all(self, query: str, timeout: Optional[float] = None) -> List[Dict[str, Any]]:
        """
        Execute a query, wait for it, and fetch all of its results.

        Args:
            query: SQL query string
            timeout: Maximum seconds to wait for completion; ``None`` waits
                indefinitely

        Returns:
            List of result rows as dictionaries

        Raises:
            RuntimeError: If the query fails or is cancelled.
            TimeoutError: If ``timeout`` elapses first.
        """
        query_execution_id = self.execute_query(query, wait=True, timeout=timeout)
        return self.get_all_results(query_execution_id)

    def export_to_csv(self, results: List[Dict[str, Any]], filename: str):
        """
        Export results to a UTF-8 CSV file with a header row.

        The header comes from the keys of the first result. If ``results`` is
        empty, nothing is written and no file is created.

        Args:
            results: List of result dictionaries
            filename: Output CSV filename
        """
        if not results:
            print("No results to export")
            return
        with open(filename, 'w', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=results[0].keys())
            writer.writeheader()
            writer.writerows(results)
        print(f"Results exported to {filename}")

    def export_to_json(self, results: List[Dict[str, Any]], filename: str):
        """
        Export results to a UTF-8 JSON file as an indented array.

        Unlike ``export_to_csv``, an empty list is still written (as ``[]``).

        Args:
            results: List of result dictionaries
            filename: Output JSON filename
        """
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        print(f"Results exported to {filename}")
def main():
    """Run an example Athena query, export the results, and print a preview.

    Replace the placeholder configuration below before running.

    Returns:
        Process exit status: 0 on success, 1 if any step raised.
    """
    # Configuration (placeholders)
    DATABASE = 'your_database_name'
    OUTPUT_LOCATION = 's3://your-bucket/athena-results/'
    REGION = 'us-east-1'
    # Example query
    QUERY = """
    SELECT *
    FROM your_table
    WHERE date >= '2024-01-01'
    LIMIT 5000
    """
    try:
        # Built inside the try so any exception raised while constructing the
        # client takes the same exit-status path as query/export errors.
        executor = AthenaQueryExecutor(
            database=DATABASE,
            output_location=OUTPUT_LOCATION,
            region_name=REGION
        )
        # Execute query and fetch all results
        results = executor.query_and_fetch_all(QUERY)
        # Export results
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        executor.export_to_csv(results, f'athena_results_{timestamp}.csv')
        executor.export_to_json(results, f'athena_results_{timestamp}.json')
        # Display sample results
        if results:
            print("\nFirst 5 results:")
            for i, row in enumerate(results[:5], 1):
                print(f"{i}. {row}")
    except Exception as e:
        # Top-level boundary: report on stderr so the message is not mixed into
        # the progress/result output on stdout, and signal failure via status.
        print(f"Error: {e}", file=sys.stderr)
        return 1
    return 0
if __name__ == '__main__':
    # Use sys.exit rather than the site-provided exit(), which is intended for
    # the interactive prompt and is unavailable when Python runs with -S.
    sys.exit(main())