# File: AI-coodex-rekog-image-labeling/label/infra/handler.py
# Last modified: 2026-05-14 14:07:04 -03:00 (736 lines, 28 KiB, Python)
import boto3
import json
import base64
from io import BytesIO
from PIL import Image, ImageDraw
import numpy as np
from scipy.optimize import linear_sum_assignment
import re
from pdf2image import convert_from_bytes
# Configuration
REGION = 'us-east-1'
CUSTOM_LABELS_PROJECT_ARN = 'modelid'  # NOTE(review): looks like a placeholder, not a real ARN — confirm before deploy
CONFIDENCE_THRESHOLD = 80


class InMemoryDiagramProcessor:
    """Process diagrams entirely in memory for Lambda.

    Pipeline:
      1. segment the page image into an overlapping grid,
      2. detect text per segment with Textract,
      3. white-out the detected text, then run Rekognition Custom Labels
         on the cleaned segment,
      4. map all detections back to global page coordinates,
      5. deduplicate overlapping detections across segments,
      6. match object detections to nearby text labels.
    """

    def __init__(self, region=REGION, custom_labels_arn=CUSTOM_LABELS_PROJECT_ARN):
        self.textract_client = boto3.client('textract', region_name=region)
        self.rekognition_client = boto3.client('rekognition', region_name=region)
        self.custom_labels_arn = custom_labels_arn
        self.region = region

    def segment_image(self, img, grid_size=(5, 5), overlap_percent=10):
        """Segment a PIL Image into a (rows, cols) grid with overlap (in-memory).

        Returns a list of (PIL Image, position_info) tuples, where
        position_info holds the segment's global pixel box.

        Bug fix: with overlap > 0 the stride is smaller than the segment
        size, so the original grid never reached the right/bottom image
        borders (e.g. 5 cols at 10% overlap left ~8% of the width
        uncovered). The last row/column are now clamped to the image edge.
        """
        img_width, img_height = img.size
        rows, cols = grid_size
        overlap_factor = overlap_percent / 100.0
        segment_width = img_width / cols
        segment_height = img_height / rows
        # Stride between segment origins; smaller than the segment size
        # whenever overlap_percent > 0.
        step_width = segment_width * (1 - overlap_factor)
        step_height = segment_height * (1 - overlap_factor)
        segments = []
        for row in range(rows):
            for col in range(cols):
                left = int(col * step_width)
                top = int(row * step_height)
                # Force the final row/column to reach the image border so
                # the whole page is covered.
                if col == cols - 1:
                    right = img_width
                else:
                    right = int(min(left + segment_width, img_width))
                if row == rows - 1:
                    bottom = img_height
                else:
                    bottom = int(min(top + segment_height, img_height))
                segment = img.crop((left, top, right, bottom))
                position_info = {
                    'row': row,
                    'col': col,
                    'left': left,
                    'top': top,
                    'right': right,
                    'bottom': bottom,
                    'width': right - left,
                    'height': bottom - top
                }
                segments.append((segment, position_info))
        return segments

    def pil_to_bytes(self, pil_image):
        """Encode a PIL Image as PNG bytes for AWS API calls."""
        buffer = BytesIO()
        pil_image.save(buffer, format='PNG')
        return buffer.getvalue()

    def detect_text_segment(self, segment_image):
        """Detect text in a PIL Image segment using Textract.

        Returns the raw DetectDocumentText response (Blocks use
        coordinates relative to the segment, not the page).
        """
        image_bytes = self.pil_to_bytes(segment_image)
        return self.textract_client.detect_document_text(
            Document={'Bytes': image_bytes}
        )

    def clean_text_from_segment(self, segment_image, textract_data,
                                shrink_percent=8.5, keep_regex_list=None, min_confidence=80):
        """White-out detected WORD blocks on a copy of the segment (in-memory).

        A word is kept (not erased) when its Textract confidence is below
        ``min_confidence`` — presumably diagram symbols misread as text,
        TODO confirm — or when it matches one of ``keep_regex_list``.

        Args:
            segment_image: PIL Image to clean (not modified; a copy is made).
            textract_data: DetectDocumentText response for this segment.
            shrink_percent: shrink the erase rectangle symmetrically by this
                percentage so surrounding linework is not clipped.
            keep_regex_list: regex pattern strings for words to preserve.
            min_confidence: words below this confidence are preserved.

        Returns:
            (cleaned image, {'words_removed': int, 'words_kept': int})

        Bug fix: a word that was both low-confidence AND regex-matched was
        previously counted twice in ``words_kept``; each kept word is now
        counted exactly once.
        """
        compiled_patterns = []
        if keep_regex_list:
            for pattern in keep_regex_list:
                try:
                    compiled_patterns.append(re.compile(pattern))
                except re.error:
                    # Silently skip invalid user-supplied patterns.
                    pass
        img = segment_image.copy()
        width, height = img.size
        draw = ImageDraw.Draw(img)
        words_removed = 0
        words_kept = 0
        for block in textract_data['Blocks']:
            if block['BlockType'] != 'WORD':
                continue
            text = block['Text']
            confidence = block['Confidence']
            should_keep = confidence < min_confidence or any(
                pattern.match(text) for pattern in compiled_patterns
            )
            if should_keep:
                words_kept += 1
                continue
            # Textract bounding boxes are fractions of the segment size.
            bbox = block['Geometry']['BoundingBox']
            left = int(bbox['Left'] * width)
            top = int(bbox['Top'] * height)
            box_width = int(bbox['Width'] * width)
            box_height = int(bbox['Height'] * height)
            if shrink_percent > 0:
                shrink_factor = shrink_percent / 100
                width_reduction = int(box_width * shrink_factor / 2)
                height_reduction = int(box_height * shrink_factor / 2)
                left += width_reduction
                top += height_reduction
                box_width -= width_reduction * 2
                box_height -= height_reduction * 2
            draw.rectangle(
                [(left, top), (left + box_width, top + box_height)],
                fill='white'
            )
            words_removed += 1
        return img, {'words_removed': words_removed, 'words_kept': words_kept}

    def recognize_objects_segment(self, segment_image, min_confidence=CONFIDENCE_THRESHOLD):
        """Detect objects in a PIL Image using Rekognition Custom Labels.

        Returns {'custom_labels': [...], 'success': bool} and, on failure,
        an 'error' string. Failures are swallowed deliberately so one bad
        segment does not abort the whole page.
        """
        image_bytes = self.pil_to_bytes(segment_image)
        try:
            response = self.rekognition_client.detect_custom_labels(
                ProjectVersionArn=self.custom_labels_arn,
                Image={'Bytes': image_bytes},
                MinConfidence=min_confidence
            )
            return {
                'custom_labels': response.get('CustomLabels', []),
                'success': True
            }
        except Exception as e:
            # Best-effort per segment: report the error, keep processing.
            return {
                'custom_labels': [],
                'success': False,
                'error': str(e)
            }

    def calculate_iou(self, box1, box2):
        """Intersection-over-union of two {left, top, right, bottom} boxes."""
        x_left = max(box1['left'], box2['left'])
        y_top = max(box1['top'], box2['top'])
        x_right = min(box1['right'], box2['right'])
        y_bottom = min(box1['bottom'], box2['bottom'])
        if x_right < x_left or y_bottom < y_top:
            return 0.0
        intersection_area = (x_right - x_left) * (y_bottom - y_top)
        box1_area = (box1['right'] - box1['left']) * (box1['bottom'] - box1['top'])
        box2_area = (box2['right'] - box2['left']) * (box2['bottom'] - box2['top'])
        union_area = box1_area + box2_area - intersection_area
        if union_area == 0:
            # Both boxes degenerate (zero area).
            return 0.0
        return intersection_area / union_area

    def merge_bounding_boxes(self, boxes):
        """Union of multiple bounding boxes; None for an empty list."""
        if not boxes:
            return None
        return {
            'left': min(box['left'] for box in boxes),
            'top': min(box['top'] for box in boxes),
            'right': max(box['right'] for box in boxes),
            'bottom': max(box['bottom'] for box in boxes)
        }

    def deduplicate_detections(self, all_detections, iou_threshold=0.3):
        """Collapse duplicate object detections from overlapping segments.

        Greedy, per-label NMS-style grouping: detections are sorted by
        confidence, each ungrouped detection seeds a group and absorbs
        every remaining detection whose IoU with the SEED exceeds
        ``iou_threshold``. Groups of one pass through unchanged; larger
        groups are merged into one detection with the union bbox, the mean
        confidence, and a 'merged_from' count.
        """
        if not all_detections:
            return []
        # Group candidates by label so different classes never merge.
        detections_by_label = {}
        for det in all_detections:
            detections_by_label.setdefault(det['Name'], []).append(det)
        deduplicated = []
        for label, detections in detections_by_label.items():
            detections = sorted(detections, key=lambda x: x['Confidence'], reverse=True)
            groups = []
            used = set()
            for i, det in enumerate(detections):
                if i in used:
                    continue
                group = [det]
                used.add(i)
                for j, other_det in enumerate(detections):
                    if j in used or j == i:
                        continue
                    iou = self.calculate_iou(det['global_bbox'], other_det['global_bbox'])
                    if iou > iou_threshold:
                        group.append(other_det)
                        used.add(j)
                groups.append(group)
            for group in groups:
                if len(group) == 1:
                    deduplicated.append(group[0])
                else:
                    merged_bbox = self.merge_bounding_boxes([d['global_bbox'] for d in group])
                    merged_bbox['width'] = merged_bbox['right'] - merged_bbox['left']
                    merged_bbox['height'] = merged_bbox['bottom'] - merged_bbox['top']
                    avg_confidence = sum(d['Confidence'] for d in group) / len(group)
                    deduplicated.append({
                        'Name': label,
                        'Confidence': avg_confidence,
                        'global_bbox': merged_bbox,
                        'merged_from': len(group)
                    })
        return deduplicated

    def deduplicate_text_detections(self, all_text_detections, iou_threshold=0.5):
        """Collapse duplicate text detections from overlapping segments.

        Two detections are duplicates when their texts match
        case-insensitively AND their boxes overlap above ``iou_threshold``;
        the highest-confidence copy is kept.
        """
        if not all_text_detections:
            return []
        all_text_detections = sorted(all_text_detections, key=lambda x: x['confidence'], reverse=True)
        deduplicated = []
        used = set()
        for i, text_det in enumerate(all_text_detections):
            if i in used:
                continue
            used.add(i)
            for j, other_det in enumerate(all_text_detections):
                if j in used or j == i:
                    continue
                if text_det['text'].lower() == other_det['text'].lower():
                    iou = self.calculate_iou(text_det['global_bbox'], other_det['global_bbox'])
                    if iou > iou_threshold:
                        used.add(j)
            deduplicated.append(text_det)
        return deduplicated

    def get_bbox_center(self, bbox):
        """Center point (x, y) of a {left, top, width, height} box."""
        center_x = bbox['left'] + bbox['width'] / 2
        center_y = bbox['top'] + bbox['height'] / 2
        return (center_x, center_y)

    def calculate_distance(self, center1, center2):
        """Euclidean distance between two (x, y) points.

        Implemented with plain Python arithmetic so the result is a native
        float rather than a numpy scalar.
        """
        dx = center1[0] - center2[0]
        dy = center1[1] - center2[1]
        return (dx * dx + dy * dy) ** 0.5

    def _to_global_bbox(self, bbox, position_info):
        """Convert a segment-relative AWS BoundingBox (fractions of the
        segment) into a global page-pixel box dict."""
        seg_left = position_info['left']
        seg_top = position_info['top']
        seg_width = position_info['width']
        seg_height = position_info['height']
        global_left = seg_left + int(bbox['Left'] * seg_width)
        global_top = seg_top + int(bbox['Top'] * seg_height)
        global_width = int(bbox['Width'] * seg_width)
        global_height = int(bbox['Height'] * seg_height)
        return {
            'left': global_left,
            'top': global_top,
            'right': global_left + global_width,
            'bottom': global_top + global_height,
            'width': global_width,
            'height': global_height
        }

    def match_objects_to_text_by_type(self, objects, all_text_detections, max_distance=200):
        """Match detected objects to text labels; strategy depends on class.

        Three strategies:
          * VM-label objects (valve classes) are matched 1:1 to "VM-####"
            texts with the Hungarian algorithm on center distance, capped
            at ``max_distance``.
          * Two-label objects take the two texts whose centers fall inside
            the object box that are closest to its center (top/bottom).
          * Every other object takes the single closest text whose center
            falls inside its box.

        Returns a dict with 'matches', 'unmatched_objects',
        'unmatched_texts', counts, and 'matching_rate'.

        Bug fix: previously, when only one of (VM objects, VM texts) was
        non-empty, the whole Hungarian branch was skipped and those items
        vanished from the unmatched lists; they are now reported.
        """
        VM_LABEL_OBJECTS = ['globo', 'gaveta', 'retencao', 'espera']
        TWO_LABEL_OBJECTS = ['sis_con_dist', 'instrumento_local']
        vm_label_objects = []
        two_label_objects = []
        single_label_objects = []
        for obj in objects:
            obj_name = obj['Name'].lower()
            if obj_name in VM_LABEL_OBJECTS:
                vm_label_objects.append(obj)
            elif obj_name in TWO_LABEL_OBJECTS:
                two_label_objects.append(obj)
            else:
                single_label_objects.append(obj)
        # Split texts into valve tags ("VM-####") and everything else.
        vm_pattern = re.compile(r'VM-\d{4}')
        vm_texts = [t for t in all_text_detections if vm_pattern.search(t['text'])]
        other_texts = [t for t in all_text_detections if not vm_pattern.search(t['text'])]
        all_matches = []
        all_unmatched_objects = []
        all_unmatched_texts = []
        used_texts = set()
        # --- Part 1: VM objects <-> VM texts via Hungarian assignment ---
        if vm_label_objects and vm_texts:
            n_objects = len(vm_label_objects)
            n_texts = len(vm_texts)
            max_dim = max(n_objects, n_texts)
            # Square matrix padded with a huge sentinel cost so padded or
            # out-of-range pairs are recognizable after assignment.
            cost_matrix = np.full((max_dim, max_dim), 1e10)
            for i, obj in enumerate(vm_label_objects):
                obj_center = self.get_bbox_center(obj['global_bbox'])
                for j, text_data in enumerate(vm_texts):
                    text_center = self.get_bbox_center(text_data['global_bbox'])
                    distance = self.calculate_distance(obj_center, text_center)
                    if max_distance and distance > max_distance:
                        cost_matrix[i, j] = 1e10
                    else:
                        cost_matrix[i, j] = distance
            row_indices, col_indices = linear_sum_assignment(cost_matrix)
            matched_obj_indices = set()
            matched_text_indices = set()
            for obj_idx, text_idx in zip(row_indices, col_indices):
                # Skip padding rows/cols and infeasible (too-distant) pairs.
                if (obj_idx >= n_objects or text_idx >= n_texts or
                        cost_matrix[obj_idx, text_idx] >= 1e10):
                    continue
                # float(): numpy scalars would leak into the JSON output.
                distance = float(cost_matrix[obj_idx, text_idx])
                all_matches.append({
                    'object_name': vm_label_objects[obj_idx]['Name'],
                    'object_bbox': vm_label_objects[obj_idx]['global_bbox'],
                    'object_confidence': vm_label_objects[obj_idx]['Confidence'],
                    'text': vm_texts[text_idx]['text'],
                    'text_bbox': vm_texts[text_idx]['global_bbox'],
                    'text_confidence': vm_texts[text_idx]['confidence'],
                    'distance': distance,
                    'match_type': 'vm_label'
                })
                matched_obj_indices.add(obj_idx)
                matched_text_indices.add(text_idx)
            all_unmatched_objects.extend([vm_label_objects[i] for i in range(n_objects)
                                          if i not in matched_obj_indices])
            all_unmatched_texts.extend([vm_texts[j] for j in range(n_texts)
                                        if j not in matched_text_indices])
        else:
            # One side is empty: report the other side as unmatched instead
            # of silently dropping it.
            all_unmatched_objects.extend(vm_label_objects)
            all_unmatched_texts.extend(vm_texts)
        # --- Part 2: two-label objects take the two closest inner texts ---
        for obj in two_label_objects:
            obj_bbox = obj['global_bbox']
            obj_center_x = obj_bbox['left'] + obj_bbox['width'] / 2
            obj_center_y = obj_bbox['top'] + obj_bbox['height'] / 2
            texts_inside = []
            for text_data in other_texts:
                if id(text_data) in used_texts:
                    continue
                text_bbox = text_data['global_bbox']
                text_center_x = text_bbox['left'] + text_bbox['width'] / 2
                text_center_y = text_bbox['top'] + text_bbox['height'] / 2
                if (obj_bbox['left'] <= text_center_x <= obj_bbox['right'] and
                        obj_bbox['top'] <= text_center_y <= obj_bbox['bottom']):
                    texts_inside.append({
                        'text_data': text_data,
                        'distance_to_center': self.calculate_distance(
                            (obj_center_x, obj_center_y),
                            (text_center_x, text_center_y)
                        ),
                        'y_position': text_center_y
                    })
            if len(texts_inside) >= 2:
                # Two closest to the object center, then ordered vertically.
                texts_inside.sort(key=lambda t: t['distance_to_center'])
                closest_two = texts_inside[:2]
                closest_two.sort(key=lambda t: t['y_position'])
                top_text = closest_two[0]['text_data']
                bottom_text = closest_two[1]['text_data']
                all_matches.append({
                    'object_name': obj['Name'],
                    'object_bbox': obj_bbox,
                    'object_confidence': obj['Confidence'],
                    'text': f"{top_text['text']} / {bottom_text['text']}",
                    'text_top': top_text['text'],
                    'text_bottom': bottom_text['text'],
                    'text_bbox_top': top_text['global_bbox'],
                    'text_bbox_bottom': bottom_text['global_bbox'],
                    'text_confidence_top': top_text['confidence'],
                    'text_confidence_bottom': bottom_text['confidence'],
                    'distance': 0,
                    'match_type': 'two_labels'
                })
                used_texts.add(id(top_text))
                used_texts.add(id(bottom_text))
            else:
                all_unmatched_objects.append(obj)
        # --- Part 3: single-label objects take the closest inner text ---
        for obj in single_label_objects:
            obj_bbox = obj['global_bbox']
            obj_center_x = obj_bbox['left'] + obj_bbox['width'] / 2
            obj_center_y = obj_bbox['top'] + obj_bbox['height'] / 2
            texts_inside = []
            for text_data in other_texts:
                if id(text_data) in used_texts:
                    continue
                text_bbox = text_data['global_bbox']
                text_center_x = text_bbox['left'] + text_bbox['width'] / 2
                text_center_y = text_bbox['top'] + text_bbox['height'] / 2
                if (obj_bbox['left'] <= text_center_x <= obj_bbox['right'] and
                        obj_bbox['top'] <= text_center_y <= obj_bbox['bottom']):
                    texts_inside.append(text_data)
            if texts_inside:
                closest_text = min(texts_inside, key=lambda t: self.calculate_distance(
                    (obj_center_x, obj_center_y),
                    self.get_bbox_center(t['global_bbox'])
                ))
                distance_to_center = self.calculate_distance(
                    (obj_center_x, obj_center_y),
                    self.get_bbox_center(closest_text['global_bbox'])
                )
                all_matches.append({
                    'object_name': obj['Name'],
                    'object_bbox': obj_bbox,
                    'object_confidence': obj['Confidence'],
                    'text': closest_text['text'],
                    'text_bbox': closest_text['global_bbox'],
                    'text_confidence': closest_text['confidence'],
                    'distance': distance_to_center,
                    'match_type': 'single_label'
                })
                used_texts.add(id(closest_text))
            else:
                all_unmatched_objects.append(obj)
        # Everything that never got claimed stays unmatched.
        for text_data in other_texts:
            if id(text_data) not in used_texts:
                all_unmatched_texts.append(text_data)
        return {
            'matches': all_matches,
            'unmatched_objects': all_unmatched_objects,
            'unmatched_texts': all_unmatched_texts,
            'n_objects': len(objects),
            'n_texts': len(all_text_detections),
            'matching_rate': len(all_matches) / len(objects) if objects else 0
        }

    def process_diagram_inmemory(self, pil_image, grid_size=(5, 5), overlap_percent=10,
                                 keep_regex_list=None, min_confidence=80,
                                 custom_labels_confidence=80, iou_threshold=0.3,
                                 matching_max_distance=200, text_iou_threshold=0.5):
        """
        Complete in-memory pipeline: segment -> OCR -> clean -> detect ->
        deduplicate -> match. Returns only the matching results.

        Args:
            pil_image: full-page PIL Image.
            grid_size: (rows, cols) of the segmentation grid.
            overlap_percent: overlap between neighbouring segments.
            keep_regex_list: regex patterns for words that must not be erased.
            min_confidence: Textract confidence below which words are kept.
            custom_labels_confidence: MinConfidence for Custom Labels.
            iou_threshold: IoU threshold for object deduplication.
            matching_max_distance: max pixel distance for VM-label matching.
            text_iou_threshold: IoU threshold for text deduplication
                (previously hard-coded to 0.5; same default, now tunable).
        """
        # Step 1: segment the page.
        segments = self.segment_image(pil_image, grid_size, overlap_percent)
        all_global_detections = []
        all_text_detections = []
        # Steps 2-4: per-segment OCR, cleaning and object detection.
        for segment_image, position_info in segments:
            textract_data = self.detect_text_segment(segment_image)
            # Collect every OCR word with page-global coordinates.
            for block in textract_data['Blocks']:
                if block['BlockType'] == 'WORD':
                    all_text_detections.append({
                        'text': block['Text'],
                        'confidence': block['Confidence'],
                        'global_bbox': self._to_global_bbox(
                            block['Geometry']['BoundingBox'], position_info
                        )
                    })
            # Erase text so it does not confuse the object detector.
            cleaned_image, _ = self.clean_text_from_segment(
                segment_image, textract_data,
                keep_regex_list=keep_regex_list, min_confidence=min_confidence
            )
            detection_results = self.recognize_objects_segment(
                cleaned_image, min_confidence=custom_labels_confidence
            )
            if detection_results['success']:
                for label in detection_results['custom_labels']:
                    # Labels without geometry are image-level tags; skip them.
                    if 'Geometry' in label and 'BoundingBox' in label['Geometry']:
                        all_global_detections.append({
                            'Name': label['Name'],
                            'Confidence': label['Confidence'],
                            'global_bbox': self._to_global_bbox(
                                label['Geometry']['BoundingBox'], position_info
                            )
                        })
        # Step 5: deduplicate across overlapping segments.
        deduplicated_detections = self.deduplicate_detections(
            all_global_detections, iou_threshold=iou_threshold
        )
        deduplicated_text = self.deduplicate_text_detections(
            all_text_detections, iou_threshold=text_iou_threshold
        )
        # Step 6: match objects to text labels.
        return self.match_objects_to_text_by_type(
            objects=deduplicated_detections,
            all_text_detections=deduplicated_text,
            max_distance=matching_max_distance
        )
# ==================== LAMBDA HANDLER ====================
def lambda_handler(event, context):
    """
    AWS Lambda entry point: rasterize one PDF page, run the in-memory
    diagram pipeline, and return the object/text matches as JSON.

    Accepted event shapes:

    1. Inline PDF:
       {"pdf_base64": "<base64-encoded-pdf>", "config": {...}}
    2. PDF stored in S3:
       {"s3_bucket": "bucket-name", "s3_key": "path/to/file.pdf", "config": {...}}

    The optional "config" dict tunes segmentation, text cleaning, detection,
    deduplication and matching; defaults are applied below.
    """
    try:
        # API Gateway delivers the payload as a JSON string under "body";
        # direct invocations pass the dict itself.
        raw_body = event.get('body')
        payload = json.loads(raw_body) if isinstance(raw_body, str) else event

        cfg = payload.get('config', {})
        pipeline_options = {
            'grid_size': tuple(cfg.get('grid_size', [5, 5])),
            'overlap_percent': cfg.get('overlap_percent', 10),
            'keep_regex_list': cfg.get('keep_regex_list', [r'\+', r'.*[Xx].*', r'\*', r'\\']),
            'min_confidence': cfg.get('min_confidence', 80),
            'custom_labels_confidence': cfg.get('custom_labels_confidence', 60),
            'iou_threshold': cfg.get('iou_threshold', 0.3),
            'matching_max_distance': cfg.get('matching_max_distance', 200),
        }
        custom_labels_arn = cfg.get('custom_labels_arn', CUSTOM_LABELS_PROJECT_ARN)
        dpi = cfg.get('dpi', 200)

        # Locate the PDF bytes: inline base64 takes precedence over S3.
        if 'pdf_base64' in payload:
            pdf_bytes = base64.b64decode(payload['pdf_base64'])
        elif 's3_bucket' in payload and 's3_key' in payload:
            s3_client = boto3.client('s3')
            s3_object = s3_client.get_object(
                Bucket=payload['s3_bucket'],
                Key=payload['s3_key']
            )
            pdf_bytes = s3_object['Body'].read()
        else:
            return {
                'statusCode': 400,
                'body': json.dumps({
                    'error': 'Must provide either pdf_base64 or s3_bucket/s3_key'
                })
            }

        # Rasterize exactly one page. Config "page" is 0-indexed;
        # pdf2image counts pages from 1.
        page_num = cfg.get('page', 0)
        pages = convert_from_bytes(
            pdf_bytes, dpi=dpi, first_page=page_num + 1, last_page=page_num + 1
        )
        if not pages:
            return {
                'statusCode': 400,
                'body': json.dumps({
                    'error': 'Could not convert PDF to image'
                })
            }

        processor = InMemoryDiagramProcessor(
            region=REGION,
            custom_labels_arn=custom_labels_arn
        )
        results = processor.process_diagram_inmemory(
            pil_image=pages[0], **pipeline_options
        )

        # Respond with the matches only, plus a small summary.
        matches = results['matches']
        return {
            'statusCode': 200,
            'headers': {
                'Content-Type': 'application/json'
            },
            'body': json.dumps({
                'matches': matches,
                'summary': {
                    'total_matches': len(matches),
                    'unmatched_objects': len(results['unmatched_objects']),
                    'unmatched_texts': len(results['unmatched_texts']),
                    'matching_rate': results['matching_rate']
                }
            })
        }
    except Exception as e:
        # Boundary handler: log the full traceback, return a 500 payload.
        print(f"Error processing diagram: {str(e)}")
        import traceback
        traceback.print_exc()
        return {
            'statusCode': 500,
            'body': json.dumps({
                'error': str(e),
                'error_type': type(e).__name__
            })
        }