Adds initial files
scripts/process_images_batch.py (404 lines, executable file)

@@ -0,0 +1,404 @@
#!/usr/bin/env python3
"""
Batch process images from S3 using AWS Textract.

Iterates through folders (prefixes) in an S3 bucket and processes any PDF, PNG,
or JPEG files that haven't been processed yet (checking for existing Textract
output files). Saves both JSON and plain text outputs locally.
"""

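# Example invocation (illustrative values: 'my-bucket' is a placeholder,
# 'imagens' is the default base prefix used in main() below):
#
#   python scripts/process_images_batch.py my-bucket imagens
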
import io
import json
import os
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional

import boto3
from PyPDF2 import PdfReader


def get_s3_client():
    """Initialize and return an AWS S3 client."""
    # Region is hard-coded; Textract requires the S3 bucket to be in the same
    # region as the Textract endpoint, so keep these two clients in sync.
    return boto3.client('s3', region_name='us-east-2')


def get_textract_client():
    """Initialize and return an AWS Textract client."""
    return boto3.client('textract', region_name='us-east-2')


def get_pdf_page_count(pdf_bytes: bytes) -> int:
    """
    Get the number of pages in a PDF file.

    Args:
        pdf_bytes: PDF file content as bytes

    Returns:
        int: Number of pages in the PDF
    """
    try:
        pdf_reader = PdfReader(io.BytesIO(pdf_bytes))
        return len(pdf_reader.pages)
    except Exception as e:
        print(f"  Warning: Could not determine page count: {str(e)}")
        # Fall back to one page so the file is still attempted via the sync API.
        return 1


def is_already_processed(s3_key: str, output_dir: Path) -> bool:
    """
    Check if an image has already been processed by looking for its output file.

    Args:
        s3_key: S3 object key
        output_dir: Directory where output files are stored

    Returns:
        bool: True if the output file exists, False otherwise
    """
    filename = Path(s3_key).stem
    output_file = output_dir / f"{filename}_textract.json"
    return output_file.exists()


def process_image_from_s3(bucket_name: str, s3_key: str) -> Optional[Dict]:
    """
    Process an image file from S3 with AWS Textract.
    Supports PDF, PNG, and JPEG formats.
    Uses the async API (start_document_text_detection) for multi-page PDFs,
    and the sync API (detect_document_text) for single-page PDFs and images.

    Args:
        bucket_name: S3 bucket name
        s3_key: S3 object key

    Returns:
        dict: Textract response containing detected text, or None on failure
    """
    textract = get_textract_client()
    s3 = get_s3_client()

    try:
        # Verify the object exists first
        try:
            s3.head_object(Bucket=bucket_name, Key=s3_key)
        except Exception as e:
            print(f"  Error accessing S3 object: {str(e)}")
            print(f"  Bucket: {bucket_name}")
            print(f"  Key: {s3_key}")
            return None

        file_ext = Path(s3_key).suffix.lower()

        # For images (PNG, JPEG), always use the sync API
        if file_ext in ['.png', '.jpg', '.jpeg']:
            print("  Processing image with sync API")
            response = textract.detect_document_text(
                Document={
                    'S3Object': {
                        'Bucket': bucket_name,
                        'Name': s3_key
                    }
                }
            )
            return response

        # For PDFs, check the page count to decide which API to use
        if file_ext == '.pdf':
            # Download the PDF to check its page count
            response = s3.get_object(Bucket=bucket_name, Key=s3_key)
            pdf_bytes = response['Body'].read()
            page_count = get_pdf_page_count(pdf_bytes)

            print(f"  PDF has {page_count} page(s)")

            # Use the async API for multi-page PDFs
            if page_count > 1:
                print("  Using async API (start_document_text_detection) for multi-page PDF")
                response = textract.start_document_text_detection(
                    DocumentLocation={
                        'S3Object': {
                            'Bucket': bucket_name,
                            'Name': s3_key
                        }
                    }
                )
                job_id = response['JobId']
                print(f"  Started async job: {job_id}")

                # Poll until the job completes
                while True:
                    result = textract.get_document_text_detection(JobId=job_id)
                    status = result['JobStatus']

                    if status == 'SUCCEEDED':
                        # Results are paginated: follow NextToken so long
                        # documents are not silently truncated.
                        blocks = result.get('Blocks', [])
                        next_token = result.get('NextToken')
                        while next_token:
                            page = textract.get_document_text_detection(
                                JobId=job_id, NextToken=next_token
                            )
                            blocks.extend(page.get('Blocks', []))
                            next_token = page.get('NextToken')
                        result['Blocks'] = blocks
                        return result
                    elif status == 'FAILED':
                        print(f"  Job failed: {result.get('StatusMessage', 'Unknown error')}")
                        return None

                    time.sleep(2)
            else:
                # Use the sync API for single-page PDFs
                print("  Using sync API (detect_document_text) for single-page PDF")
                response = textract.detect_document_text(
                    Document={
                        'S3Object': {
                            'Bucket': bucket_name,
                            'Name': s3_key
                        }
                    }
                )
                return response

        # Unsupported extension (should not happen given upstream filtering)
        return None

    except Exception as e:
        print(f"  Error processing {s3_key}: {str(e)}")
        return None


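# Background note for the function above: Textract's synchronous
# DetectDocumentText only accepts single-page documents, while multi-page
# PDFs must go through the asynchronous StartDocumentTextDetection /
# GetDocumentTextDetection job flow, hence the page-count branch.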
def extract_text_from_response(response: Dict) -> str:
    """
    Extract plain text from a Textract response.

    Args:
        response: Textract API response

    Returns:
        str: Extracted text, one detected line per output line
    """
    if not response:
        return ""

    text_lines = []
    for block in response.get('Blocks', []):
        # LINE blocks carry whole detected lines; WORD blocks would duplicate them.
        if block['BlockType'] == 'LINE':
            text_lines.append(block['Text'])

    return '\n'.join(text_lines)


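# Illustrative example: a response whose Blocks contain LINE entries with
# Text 'Hello' and 'World' yields 'Hello\nWorld' from the function above.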
def save_textract_output(s3_key: str, response: Dict, output_dir: Path):
    """
    Save the Textract response to a JSON file and a plain text file locally.

    For an S3 key ending in 'doc.pdf', this writes 'doc_textract.json'
    and 'doc.txt' into output_dir.

    Args:
        s3_key: S3 object key
        response: Textract API response
        output_dir: Directory to save output files
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    filename = Path(s3_key).stem

    # Extract text
    extracted_text = extract_text_from_response(response)

    # Save JSON output, annotated with the extracted text and source key
    json_output_file = output_dir / f"{filename}_textract.json"
    if response:
        response['extracted_text'] = extracted_text
        response['source_s3_key'] = s3_key

        with open(json_output_file, 'w', encoding='utf-8') as f:
            json.dump(response, f, indent=2, ensure_ascii=False)

        print(f"  ✓ Saved JSON to: {json_output_file.name}")

    # Save plain text output
    text_output_file = output_dir / f"{filename}.txt"
    with open(text_output_file, 'w', encoding='utf-8') as f:
        f.write(extracted_text)

    print(f"  ✓ Saved text to: {text_output_file.name}")


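# The saved JSON keeps the full Blocks list (including each block's Geometry),
# so downstream consumers can recover page layout, not just the plain text.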
def get_supported_images_from_s3(bucket_name: str, prefix: str) -> List[str]:
    """
    Get the list of supported image files in an S3 prefix (folder).
    Filters out files whose names contain 'script' (case-insensitive);
    see the illustrative example after this function.

    Args:
        bucket_name: S3 bucket name
        prefix: S3 prefix (folder path)

    Returns:
        List of S3 keys for supported image files, sorted alphabetically
    """
    s3 = get_s3_client()
    supported_extensions = {'.pdf', '.png', '.jpg', '.jpeg'}
    images = []

    # Ensure the prefix ends with / if it's not empty
    if prefix and not prefix.endswith('/'):
        prefix += '/'

    paginator = s3.get_paginator('list_objects_v2')
    pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix, Delimiter='/')

    for page in pages:
        for obj in page.get('Contents', []):
            key = obj['Key']
            file_path = Path(key)

            # A supported extension also rules out folder placeholder keys,
            # which end in '/' and therefore have no suffix.
            if file_path.suffix.lower() in supported_extensions:
                # Filter out files containing 'script' (case-insensitive)
                if 'script' not in file_path.name.lower():
                    images.append(key)

    return sorted(images)


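# Illustrative example (hypothetical keys): 'imagens/a/doc.pdf' and
# 'imagens/a/scan.PNG' are returned by the function above, while
# 'imagens/a/script_notes.pdf' is dropped by the 'script' name filter.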
def get_folders_from_s3(bucket_name: str, base_prefix: str = '') -> List[str]:
    """
    Get the list of folders (prefixes) in an S3 bucket.

    Args:
        bucket_name: S3 bucket name
        base_prefix: Base prefix to search under

    Returns:
        List of folder prefixes
    """
    s3 = get_s3_client()
    folders = []

    # Ensure the prefix ends with / if it's not empty
    if base_prefix and not base_prefix.endswith('/'):
        base_prefix += '/'

    paginator = s3.get_paginator('list_objects_v2')
    pages = paginator.paginate(Bucket=bucket_name, Prefix=base_prefix, Delimiter='/')

    # With Delimiter='/', S3 groups keys below each immediate subfolder into
    # CommonPrefixes, so this returns only direct children of base_prefix.
    for page in pages:
        for prefix_info in page.get('CommonPrefixes', []):
            folders.append(prefix_info['Prefix'])

    return folders


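# Illustrative example (hypothetical keys): with base_prefix 'imagens' and
# objects 'imagens/a/x.pdf' and 'imagens/b/y.png', the function above
# returns ['imagens/a/', 'imagens/b/'].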
def process_folder(bucket_name: str, prefix: str, output_base_dir: Path, skip_existing: bool = True):
    """
    Process all images in an S3 folder (prefix).

    Args:
        bucket_name: S3 bucket name
        prefix: S3 prefix (folder path)
        output_base_dir: Base directory for output files
        skip_existing: Whether to skip already processed files
    """
    folder_name = prefix.rstrip('/').split('/')[-1] or 'root'
    output_dir = output_base_dir / folder_name

    print(f"\n{'='*80}")
    print(f"Processing folder: {prefix}")
    print(f"{'='*80}")

    images = get_supported_images_from_s3(bucket_name, prefix)

    if not images:
        print("  No supported images found (PDF, PNG, JPEG)")
        return

    print(f"  Found {len(images)} image(s)")

    processed_count = 0
    skipped_count = 0
    error_count = 0

    for s3_key in images:
        filename = Path(s3_key).name
        print(f"\n  Processing: {filename}")

        # Check if already processed
        if skip_existing and is_already_processed(s3_key, output_dir):
            print("  ⊘ Skipped (already processed)")
            skipped_count += 1
            continue

        # Process with Textract
        response = process_image_from_s3(bucket_name, s3_key)

        if response:
            # Save output (both JSON and text)
            save_textract_output(s3_key, response, output_dir)

            # Print a summary
            num_blocks = len(response.get('Blocks', []))
            text_length = len(extract_text_from_response(response))
            print(f"  ℹ Extracted {text_length} characters, {num_blocks} blocks")

            processed_count += 1

            # Small delay to avoid rate limiting
            time.sleep(0.5)
        else:
            error_count += 1

    print(f"\n  Summary for {folder_name}:")
    print(f"    Processed: {processed_count}")
    print(f"    Skipped: {skipped_count}")
    print(f"    Errors: {error_count}")


def main():
    """Main entry point for the script."""
    # Get the bucket name and prefix from the environment or the command line;
    # command-line arguments take precedence.
    bucket_name = os.environ.get('S3_BUCKET_NAME')
    base_prefix = os.environ.get('S3_BASE_PREFIX', 'imagens')

    if len(sys.argv) > 1:
        bucket_name = sys.argv[1]
    if len(sys.argv) > 2:
        base_prefix = sys.argv[2]

    if not bucket_name:
        print("Error: S3 bucket name not provided.")
        print("\nUsage:")
        print("  python process_images_batch.py <bucket_name> [base_prefix]")
        print("\nOr set environment variables:")
        print("  export S3_BUCKET_NAME=my-bucket")
        print("  export S3_BASE_PREFIX=imagens")
        print("  python process_images_batch.py")
        sys.exit(1)

    # Output goes next to this script, one subdirectory per S3 folder
    script_dir = Path(__file__).parent
    output_base_dir = script_dir / "textract_output"

    print(f"S3 Bucket: {bucket_name}")
    print(f"Base prefix: {base_prefix}")
    print(f"Output directory: {output_base_dir}")

    # Get all folders (prefixes) in the bucket
    print("\nScanning S3 bucket for folders...")
    folders = get_folders_from_s3(bucket_name, base_prefix)

    if not folders:
        print(f"\nNo subdirectories found under '{base_prefix}'.")
        print("Processing files in the base prefix instead...")
        folders = [base_prefix]
    else:
        print(f"\nFound {len(folders)} folder(s) to process")

    # Process each folder, continuing past per-folder failures
    total_start = time.time()

    for prefix in folders:
        try:
            process_folder(bucket_name, prefix, output_base_dir)
        except Exception as e:
            print(f"\nError processing folder {prefix}: {str(e)}")
            import traceback
            traceback.print_exc()
            continue

    total_time = time.time() - total_start

    print(f"\n{'='*80}")
    print("Batch processing complete!")
    print(f"Total time: {total_time:.2f} seconds")
    print(f"{'='*80}")


if __name__ == '__main__':
    main()