Initial commit

label/infra/handler.py (new file, 736 lines)

@@ -0,0 +1,736 @@
import boto3
import json
import base64
from io import BytesIO
from PIL import Image, ImageDraw
import numpy as np
from scipy.optimize import linear_sum_assignment
import re
from pdf2image import convert_from_bytes

# Configuration
REGION = 'us-east-1'
CUSTOM_LABELS_PROJECT_ARN = 'modelid'
CONFIDENCE_THRESHOLD = 80
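# NOTE: CUSTOM_LABELS_PROJECT_ARN ('modelid') is a placeholder, not a real
# ARN; detect_custom_labels expects the full Rekognition Custom Labels
# project *version* ARN, e.g.
# arn:aws:rekognition:<region>:<account>:project/<name>/version/<name>/<ts>
# (format shown for illustration; substitute your own values).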

class InMemoryDiagramProcessor:
    """Process diagrams entirely in memory for Lambda."""

    def __init__(self, region=REGION, custom_labels_arn=CUSTOM_LABELS_PROJECT_ARN):
        self.textract_client = boto3.client('textract', region_name=region)
        self.rekognition_client = boto3.client('rekognition', region_name=region)
        self.custom_labels_arn = custom_labels_arn
        self.region = region

    def segment_image(self, img, grid_size=(5, 5), overlap_percent=10):
        """
        Segment a PIL Image into an overlapping grid (in memory).
        Returns a list of (PIL Image, position_info) tuples.
        """
        img_width, img_height = img.size
        rows, cols = grid_size

        overlap_factor = overlap_percent / 100.0
        segment_width = img_width / cols
        segment_height = img_height / rows

        step_width = segment_width * (1 - overlap_factor)
        step_height = segment_height * (1 - overlap_factor)

        segments = []

        for row in range(rows):
            for col in range(cols):
                left = int(col * step_width)
                top = int(row * step_height)
                # Anchor the last column/row to the image edge; with the
                # reduced step the grid would otherwise leave an uncovered
                # strip along the right and bottom borders.
                if col == cols - 1:
                    left = max(0, img_width - int(segment_width))
                if row == rows - 1:
                    top = max(0, img_height - int(segment_height))
                right = int(min(left + segment_width, img_width))
                bottom = int(min(top + segment_height, img_height))

                segment = img.crop((left, top, right, bottom))

                position_info = {
                    'row': row,
                    'col': col,
                    'left': left,
                    'top': top,
                    'right': right,
                    'bottom': bottom,
                    'width': right - left,
                    'height': bottom - top
                }

                segments.append((segment, position_info))

        return segments
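
    # Example: a 2000x1500 px page with grid_size=(5, 5) and
    # overlap_percent=10 yields 25 segments of roughly 400x300 px with a
    # step of 360x270 px, so adjacent segments share a ~40/30 px band.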

    def pil_to_bytes(self, pil_image):
        """Convert a PIL Image to PNG bytes for AWS API calls."""
        buffer = BytesIO()
        pil_image.save(buffer, format='PNG')
        return buffer.getvalue()

    def detect_text_segment(self, segment_image):
        """Detect text in a PIL Image segment using Textract."""
        image_bytes = self.pil_to_bytes(segment_image)

        result = self.textract_client.detect_document_text(
            Document={'Bytes': image_bytes}
        )

        return result
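
    # Note: the synchronous detect_document_text call accepts inline image
    # bytes only up to 10 MB, so each PNG segment must stay under that
    # limit; grid segmentation keeps typical 200-dpi segments well below it.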

    def clean_text_from_segment(self, segment_image, textract_data,
                                shrink_percent=8.5, keep_regex_list=None, min_confidence=80):
        """Remove detected text from a PIL Image segment (in memory)."""
        compiled_patterns = []
        if keep_regex_list:
            for pattern in keep_regex_list:
                try:
                    compiled_patterns.append(re.compile(pattern))
                except re.error:
                    pass  # skip invalid patterns silently

        img = segment_image.copy()
        width, height = img.size
        draw = ImageDraw.Draw(img)

        words_removed = 0
        words_kept = 0

        for block in textract_data['Blocks']:
            if block['BlockType'] == 'WORD':
                text = block['Text']
                confidence = block['Confidence']

                # Keep low-confidence words (they may be diagram linework,
                # not text) and words matching a keep pattern. Counting once
                # here avoids double-counting a word that satisfies both.
                should_keep = confidence < min_confidence
                if not should_keep and compiled_patterns:
                    for pattern in compiled_patterns:
                        if pattern.match(text):
                            should_keep = True
                            break

                if should_keep:
                    words_kept += 1
                    continue

                bbox = block['Geometry']['BoundingBox']
                left = int(bbox['Left'] * width)
                top = int(bbox['Top'] * height)
                box_width = int(bbox['Width'] * width)
                box_height = int(bbox['Height'] * height)

                if shrink_percent > 0:
                    shrink_factor = shrink_percent / 100
                    width_reduction = int(box_width * shrink_factor / 2)
                    height_reduction = int(box_height * shrink_factor / 2)

                    left += width_reduction
                    top += height_reduction
                    box_width -= width_reduction * 2
                    box_height -= height_reduction * 2

                draw.rectangle(
                    [(left, top), (left + box_width, top + box_height)],
                    fill='white'
                )
                words_removed += 1

        return img, {'words_removed': words_removed, 'words_kept': words_kept}
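
    # shrink_percent trims the erase rectangle a little on every side so the
    # white fill is less likely to clip diagram linework touching the text
    # box (intent inferred from how the value is applied above).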

    def recognize_objects_segment(self, segment_image, min_confidence=CONFIDENCE_THRESHOLD):
        """Recognize objects in a PIL Image using Rekognition Custom Labels."""
        image_bytes = self.pil_to_bytes(segment_image)

        try:
            response = self.rekognition_client.detect_custom_labels(
                ProjectVersionArn=self.custom_labels_arn,
                Image={'Bytes': image_bytes},
                MinConfidence=min_confidence
            )

            return {
                'custom_labels': response.get('CustomLabels', []),
                'success': True
            }
        except Exception as e:
            return {
                'custom_labels': [],
                'success': False,
                'error': str(e)
            }

    def calculate_iou(self, box1, box2):
        """Calculate the intersection-over-union (IoU) of two bounding boxes."""
        x_left = max(box1['left'], box2['left'])
        y_top = max(box1['top'], box2['top'])
        x_right = min(box1['right'], box2['right'])
        y_bottom = min(box1['bottom'], box2['bottom'])

        if x_right < x_left or y_bottom < y_top:
            return 0.0

        intersection_area = (x_right - x_left) * (y_bottom - y_top)

        box1_area = (box1['right'] - box1['left']) * (box1['bottom'] - box1['top'])
        box2_area = (box2['right'] - box2['left']) * (box2['bottom'] - box2['top'])
        union_area = box1_area + box2_area - intersection_area

        if union_area == 0:
            return 0.0

        return intersection_area / union_area
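
    # Worked example: boxes (0, 0)-(100, 100) and (50, 50)-(150, 150)
    # intersect in a 50x50 region, so IoU = 2500 / (10000 + 10000 - 2500)
    # ≈ 0.143, below the default 0.3 merge threshold used by
    # deduplicate_detections below.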

    def merge_bounding_boxes(self, boxes):
        """Merge multiple bounding boxes into their common envelope."""
        if not boxes:
            return None

        return {
            'left': min(box['left'] for box in boxes),
            'top': min(box['top'] for box in boxes),
            'right': max(box['right'] for box in boxes),
            'bottom': max(box['bottom'] for box in boxes)
        }

    def deduplicate_detections(self, all_detections, iou_threshold=0.3):
        """Remove duplicate detections via greedy, NMS-style IoU grouping and box merging."""
        if not all_detections:
            return []

        detections_by_label = {}
        for det in all_detections:
            label = det['Name']
            if label not in detections_by_label:
                detections_by_label[label] = []
            detections_by_label[label].append(det)

        deduplicated = []

        for label, detections in detections_by_label.items():
            detections = sorted(detections, key=lambda x: x['Confidence'], reverse=True)

            groups = []
            used = set()

            for i, det in enumerate(detections):
                if i in used:
                    continue

                group = [det]
                used.add(i)

                for j, other_det in enumerate(detections):
                    if j in used or j == i:
                        continue

                    iou = self.calculate_iou(det['global_bbox'], other_det['global_bbox'])

                    if iou > iou_threshold:
                        group.append(other_det)
                        used.add(j)

                groups.append(group)

            for group in groups:
                if len(group) == 1:
                    deduplicated.append(group[0])
                else:
                    merged_bbox = self.merge_bounding_boxes([d['global_bbox'] for d in group])
                    merged_bbox['width'] = merged_bbox['right'] - merged_bbox['left']
                    merged_bbox['height'] = merged_bbox['bottom'] - merged_bbox['top']

                    avg_confidence = sum(d['Confidence'] for d in group) / len(group)

                    merged_detection = {
                        'Name': label,
                        'Confidence': avg_confidence,
                        'global_bbox': merged_bbox,
                        'merged_from': len(group)
                    }

                    deduplicated.append(merged_detection)

        return deduplicated

    def deduplicate_text_detections(self, all_text_detections, iou_threshold=0.5):
        """Remove duplicate text detections, keeping the most confident copy."""
        if not all_text_detections:
            return []

        all_text_detections = sorted(all_text_detections, key=lambda x: x['confidence'], reverse=True)

        deduplicated = []
        used = set()

        for i, text_det in enumerate(all_text_detections):
            if i in used:
                continue

            used.add(i)

            for j, other_det in enumerate(all_text_detections):
                if j in used or j == i:
                    continue

                # The same text (case-insensitive) in roughly the same place
                # is treated as a duplicate from an overlapping segment.
                if text_det['text'].lower() == other_det['text'].lower():
                    iou = self.calculate_iou(text_det['global_bbox'], other_det['global_bbox'])

                    if iou > iou_threshold:
                        used.add(j)

            deduplicated.append(text_det)

        return deduplicated

    def get_bbox_center(self, bbox):
        """Return the center point (x, y) of a bounding box."""
        center_x = bbox['left'] + bbox['width'] / 2
        center_y = bbox['top'] + bbox['height'] / 2
        return (center_x, center_y)

    def calculate_distance(self, center1, center2):
        """Calculate the Euclidean distance between two points."""
        return np.sqrt((center1[0] - center2[0])**2 + (center1[1] - center2[1])**2)

    def match_objects_to_text_by_type(self, objects, all_text_detections, max_distance=200):
        """Match detected objects to nearby text labels based on object type."""
        VM_LABEL_OBJECTS = ['globo', 'gaveta', 'retencao', 'espera']
        TWO_LABEL_OBJECTS = ['sis_con_dist', 'instrumento_local']

        vm_label_objects = []
        two_label_objects = []
        single_label_objects = []

        for obj in objects:
            obj_name = obj['Name'].lower()
            if obj_name in VM_LABEL_OBJECTS:
                vm_label_objects.append(obj)
            elif obj_name in TWO_LABEL_OBJECTS:
                two_label_objects.append(obj)
            else:
                single_label_objects.append(obj)

        vm_pattern = re.compile(r'VM-\d{4}')
        vm_texts = [t for t in all_text_detections if vm_pattern.search(t['text'])]
        other_texts = [t for t in all_text_detections if not vm_pattern.search(t['text'])]

        all_matches = []
        all_unmatched_objects = []
        all_unmatched_texts = []
        used_texts = set()
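
        # The VM-tag matching below is a one-to-one assignment problem. The
        # cost matrix is padded to a square of size max(n_objects, n_texts)
        # with a large sentinel (1e10) so every row and column can be
        # assigned; assignments that land on padded or over-distance cells
        # are discarded after solving.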

        # Part 1: Match VM-#### objects using the Hungarian algorithm
        if vm_label_objects and vm_texts:
            n_objects = len(vm_label_objects)
            n_texts = len(vm_texts)

            max_dim = max(n_objects, n_texts)
            cost_matrix = np.full((max_dim, max_dim), 1e10)

            for i, obj in enumerate(vm_label_objects):
                obj_center = self.get_bbox_center(obj['global_bbox'])

                for j, text_data in enumerate(vm_texts):
                    text_center = self.get_bbox_center(text_data['global_bbox'])
                    distance = self.calculate_distance(obj_center, text_center)

                    if max_distance and distance > max_distance:
                        cost_matrix[i, j] = 1e10
                    else:
                        cost_matrix[i, j] = distance

            row_indices, col_indices = linear_sum_assignment(cost_matrix)

            matched_obj_indices = set()
            matched_text_indices = set()

            for obj_idx, text_idx in zip(row_indices, col_indices):
                if (obj_idx >= n_objects or text_idx >= n_texts or
                        cost_matrix[obj_idx, text_idx] >= 1e10):
                    continue

                distance = cost_matrix[obj_idx, text_idx]

                match = {
                    'object_name': vm_label_objects[obj_idx]['Name'],
                    'object_bbox': vm_label_objects[obj_idx]['global_bbox'],
                    'object_confidence': vm_label_objects[obj_idx]['Confidence'],
                    'text': vm_texts[text_idx]['text'],
                    'text_bbox': vm_texts[text_idx]['global_bbox'],
                    'text_confidence': vm_texts[text_idx]['confidence'],
                    'distance': distance,
                    'match_type': 'vm_label'
                }

                all_matches.append(match)
                matched_obj_indices.add(obj_idx)
                matched_text_indices.add(text_idx)

            all_unmatched_objects.extend([vm_label_objects[i] for i in range(n_objects)
                                          if i not in matched_obj_indices])
            all_unmatched_texts.extend([vm_texts[j] for j in range(n_texts)
                                        if j not in matched_text_indices])
        else:
            # No candidates on one side: report everything as unmatched
            # instead of silently dropping it from the results.
            all_unmatched_objects.extend(vm_label_objects)
            all_unmatched_texts.extend(vm_texts)

        # Part 2: Match two-label objects to the two texts nearest their
        # center (top label / bottom label)
        for obj in two_label_objects:
            obj_bbox = obj['global_bbox']
            obj_center_x = obj_bbox['left'] + obj_bbox['width'] / 2
            obj_center_y = obj_bbox['top'] + obj_bbox['height'] / 2

            texts_inside = []
            for text_data in other_texts:
                if id(text_data) in used_texts:
                    continue

                text_bbox = text_data['global_bbox']
                text_center_x = text_bbox['left'] + text_bbox['width'] / 2
                text_center_y = text_bbox['top'] + text_bbox['height'] / 2

                if (obj_bbox['left'] <= text_center_x <= obj_bbox['right'] and
                        obj_bbox['top'] <= text_center_y <= obj_bbox['bottom']):

                    distance_to_center = self.calculate_distance(
                        (obj_center_x, obj_center_y),
                        (text_center_x, text_center_y)
                    )

                    texts_inside.append({
                        'text_data': text_data,
                        'distance_to_center': distance_to_center,
                        'y_position': text_center_y
                    })

            if len(texts_inside) >= 2:
                texts_inside.sort(key=lambda t: t['distance_to_center'])
                closest_two = texts_inside[:2]
                closest_two.sort(key=lambda t: t['y_position'])

                top_text = closest_two[0]['text_data']
                bottom_text = closest_two[1]['text_data']

                match = {
                    'object_name': obj['Name'],
                    'object_bbox': obj_bbox,
                    'object_confidence': obj['Confidence'],
                    'text': f"{top_text['text']} / {bottom_text['text']}",
                    'text_top': top_text['text'],
                    'text_bottom': bottom_text['text'],
                    'text_bbox_top': top_text['global_bbox'],
                    'text_bbox_bottom': bottom_text['global_bbox'],
                    'text_confidence_top': top_text['confidence'],
                    'text_confidence_bottom': bottom_text['confidence'],
                    'distance': 0,
                    'match_type': 'two_labels'
                }

                all_matches.append(match)
                used_texts.add(id(top_text))
                used_texts.add(id(bottom_text))
            else:
                all_unmatched_objects.append(obj)

        # Part 3: Match single-label objects to the closest text inside
        # their bounding box
        for obj in single_label_objects:
            obj_bbox = obj['global_bbox']
            obj_center_x = obj_bbox['left'] + obj_bbox['width'] / 2
            obj_center_y = obj_bbox['top'] + obj_bbox['height'] / 2

            texts_inside = []
            for text_data in other_texts:
                if id(text_data) in used_texts:
                    continue

                text_bbox = text_data['global_bbox']
                text_center_x = text_bbox['left'] + text_bbox['width'] / 2
                text_center_y = text_bbox['top'] + text_bbox['height'] / 2

                if (obj_bbox['left'] <= text_center_x <= obj_bbox['right'] and
                        obj_bbox['top'] <= text_center_y <= obj_bbox['bottom']):
                    texts_inside.append(text_data)

            if texts_inside:
                closest_text = min(texts_inside, key=lambda t: self.calculate_distance(
                    (obj_center_x, obj_center_y),
                    (t['global_bbox']['left'] + t['global_bbox']['width'] / 2,
                     t['global_bbox']['top'] + t['global_bbox']['height'] / 2)
                ))

                text_center_x = closest_text['global_bbox']['left'] + closest_text['global_bbox']['width'] / 2
                text_center_y = closest_text['global_bbox']['top'] + closest_text['global_bbox']['height'] / 2
                distance_to_center = self.calculate_distance(
                    (obj_center_x, obj_center_y),
                    (text_center_x, text_center_y)
                )

                match = {
                    'object_name': obj['Name'],
                    'object_bbox': obj_bbox,
                    'object_confidence': obj['Confidence'],
                    'text': closest_text['text'],
                    'text_bbox': closest_text['global_bbox'],
                    'text_confidence': closest_text['confidence'],
                    'distance': distance_to_center,
                    'match_type': 'single_label'
                }

                all_matches.append(match)
                used_texts.add(id(closest_text))
            else:
                all_unmatched_objects.append(obj)

        for text_data in other_texts:
            if id(text_data) not in used_texts:
                all_unmatched_texts.append(text_data)

        return {
            'matches': all_matches,
            'unmatched_objects': all_unmatched_objects,
            'unmatched_texts': all_unmatched_texts,
            'n_objects': len(objects),
            'n_texts': len(all_text_detections),
            'matching_rate': len(all_matches) / len(objects) if objects else 0
        }

    def process_diagram_inmemory(self, pil_image, grid_size=(5, 5), overlap_percent=10,
                                 keep_regex_list=None, min_confidence=80,
                                 custom_labels_confidence=80, iou_threshold=0.3,
                                 matching_max_distance=200):
        """
        Complete in-memory pipeline: segment, OCR, clean, detect, deduplicate, match.
        Returns only the matching results.
        """
        img_width, img_height = pil_image.size

        # Step 1: Segment the page into an overlapping grid
        segments = self.segment_image(pil_image, grid_size, overlap_percent)

        all_global_detections = []
        all_text_detections = []

        # Steps 2-4: OCR, text removal, and object detection per segment
        for segment_image, position_info in segments:
            # Detect text
            textract_data = self.detect_text_segment(segment_image)

            # Extract text, converting Textract's normalized segment-local
            # coordinates to global pixel coordinates
            for block in textract_data['Blocks']:
                if block['BlockType'] == 'WORD':
                    bbox = block['Geometry']['BoundingBox']

                    seg_left = position_info['left']
                    seg_top = position_info['top']
                    seg_width = position_info['width']
                    seg_height = position_info['height']

                    global_left = seg_left + int(bbox['Left'] * seg_width)
                    global_top = seg_top + int(bbox['Top'] * seg_height)
                    global_width = int(bbox['Width'] * seg_width)
                    global_height = int(bbox['Height'] * seg_height)

                    all_text_detections.append({
                        'text': block['Text'],
                        'confidence': block['Confidence'],
                        'global_bbox': {
                            'left': global_left,
                            'top': global_top,
                            'right': global_left + global_width,
                            'bottom': global_top + global_height,
                            'width': global_width,
                            'height': global_height
                        }
                    })

            # Erase text so it does not confuse the object detector
            cleaned_image, _ = self.clean_text_from_segment(
                segment_image, textract_data,
                keep_regex_list=keep_regex_list, min_confidence=min_confidence
            )

            # Recognize objects on the cleaned segment
            detection_results = self.recognize_objects_segment(
                cleaned_image, min_confidence=custom_labels_confidence
            )

            if detection_results['success']:
                labels = detection_results['custom_labels']

                for label in labels:
                    if 'Geometry' in label and 'BoundingBox' in label['Geometry']:
                        bbox = label['Geometry']['BoundingBox']

                        seg_left = position_info['left']
                        seg_top = position_info['top']
                        seg_width = position_info['width']
                        seg_height = position_info['height']

                        global_left = seg_left + int(bbox['Left'] * seg_width)
                        global_top = seg_top + int(bbox['Top'] * seg_height)
                        global_width = int(bbox['Width'] * seg_width)
                        global_height = int(bbox['Height'] * seg_height)

                        global_detection = {
                            'Name': label['Name'],
                            'Confidence': label['Confidence'],
                            'global_bbox': {
                                'left': global_left,
                                'top': global_top,
                                'right': global_left + global_width,
                                'bottom': global_top + global_height,
                                'width': global_width,
                                'height': global_height
                            }
                        }

                        all_global_detections.append(global_detection)

        # Step 5: Deduplicate detections that appear in overlapping segments
        deduplicated_detections = self.deduplicate_detections(
            all_global_detections, iou_threshold=iou_threshold
        )

        deduplicated_text = self.deduplicate_text_detections(
            all_text_detections, iou_threshold=0.5
        )

        # Step 6: Match objects to text
        matching_results = self.match_objects_to_text_by_type(
            objects=deduplicated_detections,
            all_text_detections=deduplicated_text,
            max_distance=matching_max_distance
        )

        return matching_results
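
    # Direct (non-Lambda) usage sketch; the file name and parameters are
    # illustrative only:
    #   processor = InMemoryDiagramProcessor()
    #   page = convert_from_bytes(open('diagram.pdf', 'rb').read(), dpi=200)[0]
    #   results = processor.process_diagram_inmemory(page)
    #   print(results['matching_rate'], len(results['matches']))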


# ==================== LAMBDA HANDLER ====================

def lambda_handler(event, context):
    """
    AWS Lambda handler function.

    Expected event formats:
    1. PDF as base64 in the body:
       {
           "pdf_base64": "<base64-encoded-pdf>",
           "config": {
               "grid_size": [5, 5],
               "overlap_percent": 10,
               ...
           }
       }

    2. PDF in S3:
       {
           "s3_bucket": "bucket-name",
           "s3_key": "path/to/file.pdf",
           "config": {...}
       }
    """

    try:
        # Parse the event (API Gateway wraps the payload in a JSON string body)
        if isinstance(event.get('body'), str):
            body = json.loads(event['body'])
        else:
            body = event

        # Extract configuration
        config = body.get('config', {})
        grid_size = tuple(config.get('grid_size', [5, 5]))
        overlap_percent = config.get('overlap_percent', 10)
        keep_regex_list = config.get('keep_regex_list', [r'\+', r'.*[Xx].*', r'\*', r'\\'])
        min_confidence = config.get('min_confidence', 80)
        custom_labels_confidence = config.get('custom_labels_confidence', 60)
        iou_threshold = config.get('iou_threshold', 0.3)
        matching_max_distance = config.get('matching_max_distance', 200)
        custom_labels_arn = config.get('custom_labels_arn', CUSTOM_LABELS_PROJECT_ARN)
        dpi = config.get('dpi', 200)

        # Get the PDF bytes
        if 'pdf_base64' in body:
            # PDF provided as base64 in the request
            pdf_bytes = base64.b64decode(body['pdf_base64'])
        elif 's3_bucket' in body and 's3_key' in body:
            # PDF stored in S3
            s3_client = boto3.client('s3')
            response = s3_client.get_object(
                Bucket=body['s3_bucket'],
                Key=body['s3_key']
            )
            pdf_bytes = response['Body'].read()
        else:
            return {
                'statusCode': 400,
                'body': json.dumps({
                    'error': 'Must provide either pdf_base64 or s3_bucket/s3_key'
                })
            }

        # Convert the PDF to an image (first page by default, or the page
        # given in config)
        page_num = config.get('page', 0)  # 0-indexed
        images = convert_from_bytes(pdf_bytes, dpi=dpi, first_page=page_num + 1, last_page=page_num + 1)

        if not images:
            return {
                'statusCode': 400,
                'body': json.dumps({
                    'error': 'Could not convert PDF to image'
                })
            }

        diagram_image = images[0]

        # Initialize the processor
        processor = InMemoryDiagramProcessor(
            region=REGION,
            custom_labels_arn=custom_labels_arn
        )

        # Process the diagram
        matching_results = processor.process_diagram_inmemory(
            pil_image=diagram_image,
            grid_size=grid_size,
            overlap_percent=overlap_percent,
            keep_regex_list=keep_regex_list,
            min_confidence=min_confidence,
            custom_labels_confidence=custom_labels_confidence,
            iou_threshold=iou_threshold,
            matching_max_distance=matching_max_distance
        )

        # Return only the matches
        return {
            'statusCode': 200,
            'headers': {
                'Content-Type': 'application/json'
            },
            'body': json.dumps({
                'matches': matching_results['matches'],
                'summary': {
                    'total_matches': len(matching_results['matches']),
                    'unmatched_objects': len(matching_results['unmatched_objects']),
                    'unmatched_texts': len(matching_results['unmatched_texts']),
                    'matching_rate': matching_results['matching_rate']
                }
            })
        }

    except Exception as e:
        print(f"Error processing diagram: {str(e)}")
        import traceback
        traceback.print_exc()

        return {
            'statusCode': 500,
            'body': json.dumps({
                'error': str(e),
                'error_type': type(e).__name__
            })
        }
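

# Minimal local smoke test, not part of the Lambda runtime. Assumes valid
# AWS credentials and a sample PDF at ./sample.pdf (both names are
# illustrative, not part of the deployed handler).
if __name__ == '__main__':
    with open('sample.pdf', 'rb') as f:
        test_event = {
            'pdf_base64': base64.b64encode(f.read()).decode('utf-8'),
            'config': {'grid_size': [5, 5], 'page': 0}
        }
    print(lambda_handler(test_event, None))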