#!/usr/bin/env python3
"""
AWS Athena Query Script with Pagination

This script executes an Athena query and retrieves all results using
pagination to overcome the 1000 row limit per API call.
"""

import csv
import json
import sys
import time
from datetime import datetime
from typing import Any, Dict, List, Optional

# Largest page size this script will ever request from GetQueryResults.
_MAX_PAGE_SIZE = 1000

# Query states after which polling stops with an error.
_FAILURE_STATES = ('FAILED', 'CANCELLED')


class AthenaQueryError(Exception):
    """Raised when a query ends in the FAILED or CANCELLED state."""


class AthenaQueryExecutor:
    """Execute Athena queries with automatic pagination and result retrieval."""

    def __init__(
        self,
        database: str,
        output_location: str,
        region_name: str = 'us-east-1',
        max_results_per_page: int = 1000,
        *,
        client: Optional[Any] = None,
    ):
        """
        Initialize the Athena query executor.

        Args:
            database: The Athena database name
            output_location: S3 location for query results
                (e.g., 's3://bucket/path/')
            region_name: AWS region name; only used when ``client`` is not
                supplied
            max_results_per_page: Maximum results per API call; clamped to
                the range 1..1000
            client: Optional pre-built Athena client (for example one created
                from a custom boto3 session). When omitted, a client is
                created with ``boto3.client('athena', region_name=...)``.
        """
        if client is None:
            # Imported here so the module can be imported (and used with an
            # injected client) on machines where boto3 is not installed.
            import boto3

            client = boto3.client('athena', region_name=region_name)

        self.client = client
        self.database = database
        self.output_location = output_location
        # Clamp on both sides: a value below 1 would otherwise be forwarded
        # to the API unchanged.
        self.max_results_per_page = max(
            1, min(max_results_per_page, _MAX_PAGE_SIZE)
        )

    def execute_query(self, query: str, wait: bool = True) -> str:
        """
        Execute an Athena query and return the query execution ID.

        Args:
            query: SQL query string
            wait: Whether to wait for query completion

        Returns:
            Query execution ID

        Raises:
            AthenaQueryError: If ``wait`` is true and the query fails or is
                cancelled.
        """
        response = self.client.start_query_execution(
            QueryString=query,
            QueryExecutionContext={'Database': self.database},
            ResultConfiguration={'OutputLocation': self.output_location},
        )
        query_execution_id = response['QueryExecutionId']
        print(f"Query submitted. Execution ID: {query_execution_id}")

        if wait:
            self._wait_for_query_completion(query_execution_id)

        return query_execution_id

    def _wait_for_query_completion(
        self, query_execution_id: str, poll_interval: float = 2
    ):
        """
        Wait for query to complete execution.

        Polls ``get_query_execution`` until the state is SUCCEEDED, FAILED or
        CANCELLED. There is no timeout: any other state keeps the loop
        polling.

        Args:
            query_execution_id: The query execution ID
            poll_interval: Seconds between status checks

        Raises:
            AthenaQueryError: If the query ends as FAILED or CANCELLED.
        """
        print("Waiting for query to complete...")

        while True:
            response = self.client.get_query_execution(
                QueryExecutionId=query_execution_id
            )
            execution = response['QueryExecution']
            status = execution['Status']
            state = status['State']

            if state == 'SUCCEEDED':
                print("Query completed successfully!")
                # Read defensively so a response without a Statistics block
                # does not turn a successful query into a KeyError.
                stats = execution.get('Statistics') or {}
                print(f"Data scanned: {stats.get('DataScannedInBytes', 0) / (1024**3):.2f} GB")
                print(f"Execution time: {stats.get('EngineExecutionTimeInMillis', 0) / 1000:.2f} seconds")
                return

            if state in _FAILURE_STATES:
                reason = status.get('StateChangeReason', 'Unknown')
                raise AthenaQueryError(f"Query {state.lower()}: {reason}")

            time.sleep(poll_interval)

    def get_all_results(self, query_execution_id: str) -> List[Dict[str, Any]]:
        """
        Retrieve all query results using pagination.

        Column names are taken from the first page's result-set metadata.
        Cells without a ``VarCharValue`` are returned as empty strings.

        Args:
            query_execution_id: The query execution ID

        Returns:
            List of result rows as dictionaries
        """
        all_results: List[Dict[str, Any]] = []
        columns: List[str] = []
        next_token = None
        page_count = 0

        print("Fetching results with pagination...")

        while True:
            page_count += 1

            # Build request parameters
            params = {
                'QueryExecutionId': query_execution_id,
                'MaxResults': self.max_results_per_page,
            }
            if next_token:
                params['NextToken'] = next_token

            # Get results page
            response = self.client.get_query_results(**params)
            result_set = response['ResultSet']
            rows = result_set['Rows']

            if page_count == 1:
                columns = [
                    col['Name']
                    for col in result_set['ResultSetMetadata']['ColumnInfo']
                ]
                # The first row of the first page is treated as the header
                # and dropped. NOTE(review): this assumes the statement
                # returns a header row - confirm for non-SELECT statements.
                rows = rows[1:]

            # Convert rows to dictionaries
            all_results.extend(
                dict(zip(
                    columns,
                    (field.get('VarCharValue', '') for field in row['Data']),
                ))
                for row in rows
            )

            print(f"Page {page_count}: Retrieved {len(rows)} rows (Total: {len(all_results)})")

            # Check if there are more results
            next_token = response.get('NextToken')
            if not next_token:
                break

        print(f"\nTotal rows retrieved: {len(all_results)}")
        return all_results

    def query_and_fetch_all(self, query: str) -> List[Dict[str, Any]]:
        """
        Execute query and fetch all results in one call.

        Args:
            query: SQL query string

        Returns:
            List of result rows as dictionaries

        Raises:
            AthenaQueryError: If the query fails or is cancelled.
        """
        query_execution_id = self.execute_query(query, wait=True)
        return self.get_all_results(query_execution_id)

    def export_to_csv(self, results: List[Dict[str, Any]], filename: str):
        """
        Export results to CSV file.

        The header is taken from the keys of the first result. Nothing is
        written when ``results`` is empty.

        Args:
            results: List of result dictionaries
            filename: Output CSV filename
        """
        if not results:
            print("No results to export")
            return

        with open(filename, 'w', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=list(results[0]))
            writer.writeheader()
            writer.writerows(results)

        print(f"Results exported to {filename}")

    def export_to_json(self, results: List[Dict[str, Any]], filename: str):
        """
        Export results to JSON file.

        Args:
            results: List of result dictionaries
            filename: Output JSON filename
        """
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False)

        print(f"Results exported to {filename}")


def main():
    """Example usage of the AthenaQueryExecutor.

    Returns:
        0 on success, 1 if any step raised an exception.
    """
    # Configuration
    DATABASE = 'your_database_name'
    OUTPUT_LOCATION = 's3://your-bucket/athena-results/'
    REGION = 'us-east-1'

    # Example query
    QUERY = """
        SELECT *
        FROM your_table
        WHERE date >= '2024-01-01'
        LIMIT 5000
    """

    # Execute query and fetch all results
    try:
        # Initialize executor
        executor = AthenaQueryExecutor(
            database=DATABASE,
            output_location=OUTPUT_LOCATION,
            region_name=REGION,
        )

        results = executor.query_and_fetch_all(QUERY)

        # Export results
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        executor.export_to_csv(results, f'athena_results_{timestamp}.csv')
        executor.export_to_json(results, f'athena_results_{timestamp}.json')

        # Display sample results
        if results:
            print("\nFirst 5 results:")
            for i, row in enumerate(results[:5], 1):
                print(f"{i}. {row}")

    except Exception as e:
        # Top-level boundary of the script: report on stderr and signal
        # failure through the exit status.
        print(f"Error: {e}", file=sys.stderr)
        return 1

    return 0


if __name__ == '__main__':
    # sys.exit is always available; the bare exit() helper is only injected
    # by the site module.
    sys.exit(main())