279 lines
8.2 KiB
Python
279 lines
8.2 KiB
Python
import itertools
|
|
from operator import itemgetter
|
|
from typing import Dict, Iterable, Optional
|
|
|
|
from .._typing import T_bbox, T_num, T_obj, T_obj_list
|
|
from .clustering import cluster_objects
|
|
|
|
|
|
def objects_to_rect(objects: Iterable[T_obj]) -> Dict[str, T_num]:
    """
    Compute the tightest rectangle that encloses every object in
    `objects`.

    The rectangle is returned as a dict with "x0", "top", "x1", and
    "bottom" keys. The bounding box is computed by objects_to_bbox() and
    converted to dict form by bbox_to_rect().
    """
    enclosing_bbox = objects_to_bbox(objects)
    return bbox_to_rect(enclosing_bbox)
|
|
|
|
|
|
def objects_to_bbox(objects: Iterable[T_obj]) -> T_bbox:
    """
    Compute the tightest bounding box that encloses every object in
    `objects`.

    Each object's ("x0", "top", "x1", "bottom") values are pulled out
    with bbox_getter, and the resulting tuples are combined by
    merge_bboxes().
    """
    per_object_bboxes = (bbox_getter(obj) for obj in objects)
    return merge_bboxes(per_object_bboxes)
|
|
|
|
|
|
# Callable that reads the "x0", "top", "x1", and "bottom" items of an
# object (in that order) and returns them together as a tuple.
bbox_getter = itemgetter("x0", "top", "x1", "bottom")
|
|
|
|
|
|
def obj_to_bbox(obj: T_obj) -> T_bbox:
    """
    Extract an object's bounding box as an (x0, top, x1, bottom) tuple.

    The values are read from the object's "x0", "top", "x1", and
    "bottom" items via bbox_getter.
    """
    x0, top, x1, bottom = bbox_getter(obj)
    return (x0, top, x1, bottom)
|
|
|
|
|
|
def bbox_to_rect(bbox: T_bbox) -> Dict[str, T_num]:
    """
    Convert a bounding box into a rectangle, i.e. a dict with keys
    "x0", "top", "x1", and "bottom".

    The first four items of `bbox` are mapped, in order, onto those
    keys.
    """
    rect_keys = ("x0", "top", "x1", "bottom")
    return {key: bbox[position] for position, key in enumerate(rect_keys)}
|
|
|
|
|
|
def merge_bboxes(bboxes: Iterable[T_bbox]) -> T_bbox:
    """
    Given an iterable of bounding boxes, return the smallest bounding box
    that contains them all.

    Each bounding box is an (x0, top, x1, bottom) sequence. The result
    takes the minimum x0 and top and the maximum x1 and bottom.

    Raises ValueError if `bboxes` is empty, since there is no box to
    return in that case.
    """
    # Materialize first so an empty input can be reported clearly rather
    # than failing inside the tuple unpacking below.
    boxes = list(bboxes)
    if not boxes:
        raise ValueError("Cannot merge an empty iterable of bounding boxes.")
    x0, top, x1, bottom = zip(*boxes)
    return (min(x0), min(top), max(x1), max(bottom))
|
|
|
|
|
|
def get_bbox_overlap(a: T_bbox, b: T_bbox) -> Optional[T_bbox]:
    """
    Return the bounding box shared by `a` and `b`, or None if they do
    not overlap.

    Both arguments are (left, top, right, bottom) tuples. The overlap
    must have non-negative width and height, and the two may not both be
    zero: boxes that share part of an edge overlap (in a zero-width or
    zero-height box), while boxes that touch only at a single point do
    not.
    """
    a_x0, a_y0, a_x1, a_y1 = a
    b_x0, b_y0, b_x1, b_y1 = b

    left, right = max(a_x0, b_x0), min(a_x1, b_x1)
    top, bottom = max(a_y0, b_y0), min(a_y1, b_y1)

    width = right - left
    height = bottom - top

    overlaps = height >= 0 and width >= 0 and height + width > 0
    if not overlaps:
        return None
    return (left, top, right, bottom)
|
|
|
|
|
|
def calculate_area(bbox: T_bbox) -> T_num:
    """
    Return the area (width times height) of a (left, top, right, bottom)
    bounding box.

    Raises ValueError if the box's left edge lies beyond its right edge
    or its top edge lies beyond its bottom edge.
    """
    x0, y0, x1, y1 = bbox
    inverted = x0 > x1 or y0 > y1
    if inverted:
        raise ValueError(f"{bbox} has a negative width or height.")
    width = x1 - x0
    height = y1 - y0
    return width * height
|
|
|
|
|
|
def clip_obj(obj: T_obj, bbox: T_bbox) -> Optional[T_obj]:
    """
    Return a copy of `obj` trimmed to the part of it that overlaps
    `bbox`, or None if get_bbox_overlap() reports no overlap.

    In the copy, "x0", "top", "x1", and "bottom" are set to the overlap;
    "doctop", when the object has one, is shifted by the same amount as
    "top"; and "width" and "height" are recomputed from the new
    coordinates. The input object is not modified.
    """
    overlap = get_bbox_overlap(obj_to_bbox(obj), bbox)
    if overlap is None:
        return None

    overlap_rect = bbox_to_rect(overlap)
    clipped = dict(obj)
    clipped.update(
        {side: overlap_rect[side] for side in ("x0", "top", "x1", "bottom")}
    )

    # Keep the document-level offset in step with the new top edge.
    top_shift = overlap_rect["top"] - obj["top"]
    if "doctop" in clipped:
        clipped["doctop"] = obj["doctop"] + top_shift

    clipped["width"] = clipped["x1"] - clipped["x0"]
    clipped["height"] = clipped["bottom"] - clipped["top"]
    return clipped
|
|
|
|
|
|
def intersects_bbox(objs: Iterable[T_obj], bbox: T_bbox) -> T_obj_list:
    """
    Return the objects in `objs` whose bounding box has an overlap with
    `bbox`, as determined by get_bbox_overlap(). Input order is kept.
    """
    intersecting = []
    for obj in objs:
        overlap = get_bbox_overlap(obj_to_bbox(obj), bbox)
        if overlap is not None:
            intersecting.append(obj)
    return intersecting
|
|
|
|
|
|
def within_bbox(objs: Iterable[T_obj], bbox: T_bbox) -> T_obj_list:
    """
    Return the objects in `objs` that lie fully within `bbox`, i.e.
    those whose overlap with `bbox` equals their own bounding box.
    Input order is kept.
    """
    contained = []
    for obj in objs:
        obj_bbox = obj_to_bbox(obj)
        if get_bbox_overlap(obj_bbox, bbox) == obj_bbox:
            contained.append(obj)
    return contained
|
|
|
|
|
|
def outside_bbox(objs: Iterable[T_obj], bbox: T_bbox) -> T_obj_list:
    """
    Return the objects in `objs` that lie fully outside `bbox`, i.e.
    those for which get_bbox_overlap() finds no overlap. Input order is
    kept.
    """
    outside = []
    for obj in objs:
        overlap = get_bbox_overlap(obj_to_bbox(obj), bbox)
        if overlap is None:
            outside.append(obj)
    return outside
|
|
|
|
|
|
def crop_to_bbox(objs: Iterable[T_obj], bbox: T_bbox) -> T_obj_list:
    """
    Clip every object in `objs` to `bbox` with clip_obj() and return the
    clipped copies, dropping objects for which clip_obj() yields nothing
    (those that do not intersect the bbox).
    """
    cropped = []
    for obj in objs:
        clipped = clip_obj(obj, bbox)
        if clipped:
            cropped.append(clipped)
    return cropped
|
|
|
|
|
|
def move_object(obj: T_obj, axis: str, value: T_num) -> T_obj:
    """
    Return a copy of `obj` shifted by `value` along `axis`.

    With axis "h", "x0" and "x1" are increased by `value`. With axis
    "v", "top" and "bottom" are increased by `value`, as is "doctop" if
    the object has one; if the object has "y0", then "y0" and "y1" are
    decreased by `value`. The copy is built with the object's own class,
    and the input object is left untouched.

    `axis` must be "h" or "v" (checked with an assertion).
    """
    assert axis in ("h", "v")
    if axis == "h":
        shifted = [(side, obj[side] + value) for side in ("x0", "x1")]
    elif axis == "v":
        shifted = [(side, obj[side] + value) for side in ("top", "bottom")]
        if "doctop" in obj:
            shifted.append(("doctop", obj["doctop"] + value))
        if "y0" in obj:
            # y0/y1 move in the opposite direction to top/bottom.
            shifted.extend((side, obj[side] - value) for side in ("y0", "y1"))
    # Later pairs override the original items of the same key.
    return obj.__class__(tuple(obj.items()) + tuple(shifted))
|
|
|
|
|
|
def snap_objects(objs: Iterable[T_obj], attr: str, tolerance: T_num) -> T_obj_list:
    """
    Align objects whose `attr` values fall into the same cluster.

    The objects are grouped by cluster_objects() on `attr` with the
    given `tolerance`; each object is then moved, via move_object(), so
    that its `attr` equals the mean `attr` of its cluster. The moved
    copies are returned cluster by cluster.

    `attr` must be one of "x0", "x1", "top", or "bottom"; any other
    value raises KeyError.
    """
    axis_for_attr = {"x0": "h", "x1": "h", "top": "v", "bottom": "v"}
    axis = axis_for_attr[attr]

    clusters = cluster_objects(list(objs), itemgetter(attr), tolerance)
    averages = [
        sum(member[attr] for member in cluster) / len(cluster)
        for cluster in clusters
    ]

    snapped = []
    for cluster, average in zip(clusters, averages):
        for member in cluster:
            snapped.append(move_object(member, axis, average - member[attr]))
    return snapped
|
|
|
|
|
|
def resize_object(obj: T_obj, key: str, value: T_num) -> T_obj:
    """
    Return a copy of `obj` with the edge named by `key` moved to `value`.

    `key` must be one of "x0", "x1", "top", or "bottom", and the new
    value may not cross the opposite edge; both conditions are checked
    with assertions. Dependent items are updated alongside the edge:

    - "x0" / "x1": "width" is recomputed from the new coordinates.
    - "top": "height" is adjusted, "doctop" is shifted if the object has
      one, and "y1" is shifted if the object has one.
    - "bottom": "height" is adjusted, and "y0" is shifted if the object
      has one.

    The copy is built with the object's own class; the input object is
    not modified.
    """
    assert key in ("x0", "x1", "top", "bottom")
    old_value = obj[key]
    diff = value - old_value
    new_items = [
        (key, value),
    ]
    if key == "x0":
        assert value <= obj["x1"]
        new_items.append(("width", obj["x1"] - value))
    elif key == "x1":
        assert value >= obj["x0"]
        new_items.append(("width", value - obj["x0"]))
    elif key == "top":
        assert value <= obj["bottom"]
        # "doctop" is optional on objects (move_object and clip_obj also
        # guard for it), so only shift it when it is actually present.
        if "doctop" in obj:
            new_items.append(("doctop", obj["doctop"] + diff))
        new_items.append(("height", obj["height"] - diff))
        if "y1" in obj:
            new_items.append(("y1", obj["y1"] - diff))
    elif key == "bottom":
        assert value >= obj["top"]
        new_items.append(("height", obj["height"] + diff))
        if "y0" in obj:
            new_items.append(("y0", obj["y0"] - diff))
    # Later pairs override the original items of the same key.
    return obj.__class__(tuple(obj.items()) + tuple(new_items))
|
|
|
|
|
|
def curve_to_edges(curve: T_obj) -> T_obj_list:
    """
    Break a curve into one "curve_edge" object per pair of consecutive
    points in curve["pts"].

    Each edge spans the two points: "x0"/"x1" and "top"/"bottom" are the
    min/max of the points' first and second coordinates, "width" and
    "height" are the absolute coordinate differences, and "doctop" is
    the edge's top offset by the curve's own (doctop - top). The
    "orientation" is "v" when the two points share their first
    coordinate, "h" when they share their second, and None otherwise.
    """
    points = curve["pts"]
    edges = []
    for start, end in zip(points, points[1:]):
        start_x, start_y = start[0], start[1]
        end_x, end_y = end[0], end[1]

        if start_x == end_x:
            orientation = "v"
        elif start_y == end_y:
            orientation = "h"
        else:
            orientation = None

        edge_top = min(start_y, end_y)
        edges.append(
            {
                "object_type": "curve_edge",
                "x0": min(start_x, end_x),
                "x1": max(start_x, end_x),
                "top": edge_top,
                "doctop": edge_top + (curve["doctop"] - curve["top"]),
                "bottom": max(start_y, end_y),
                "width": abs(start_x - end_x),
                "height": abs(start_y - end_y),
                "orientation": orientation,
            }
        )
    return edges
|
|
|
|
|
|
def rect_to_edges(rect: T_obj) -> T_obj_list:
    """
    Split a rectangle into its four "rect_edge" objects, returned in the
    order [top, bottom, left, right].

    Every edge starts as a copy of `rect` and then overrides the items
    that differ: the horizontal edges get height 0 and orientation "h",
    the vertical edges get width 0 and orientation "v", and each edge's
    coordinates are collapsed onto the corresponding side of the
    rectangle. The input rectangle is not modified.
    """
    top_edge = dict(
        rect,
        object_type="rect_edge",
        height=0,
        y0=rect["y1"],
        bottom=rect["top"],
        orientation="h",
    )
    bottom_edge = dict(
        rect,
        object_type="rect_edge",
        height=0,
        y1=rect["y0"],
        top=rect["top"] + rect["height"],
        doctop=rect["doctop"] + rect["height"],
        orientation="h",
    )
    left_edge = dict(
        rect,
        object_type="rect_edge",
        width=0,
        x1=rect["x0"],
        orientation="v",
    )
    right_edge = dict(
        rect,
        object_type="rect_edge",
        width=0,
        x0=rect["x1"],
        orientation="v",
    )
    return [top_edge, bottom_edge, left_edge, right_edge]
|
|
|
|
|
|
def line_to_edge(line: T_obj) -> T_obj:
    """
    Return a copy of `line` with an added "orientation" item: "h" when
    the line's "top" equals its "bottom", and "v" otherwise.
    """
    is_horizontal = line["top"] == line["bottom"]
    return dict(line, orientation="h" if is_horizontal else "v")
|
|
|
|
|
|
def obj_to_edges(obj: T_obj) -> T_obj_list:
    """
    Convert an object into a list of edge objects, based on its
    "object_type".

    Objects whose type already contains "_edge" are returned as-is in a
    one-item list; a "line" is converted by line_to_edge(); a "rect" or
    "curve" is expanded by rect_to_edges() or curve_to_edges(). Any
    other type raises KeyError.
    """
    kind = obj["object_type"]
    if "_edge" in kind:
        return [obj]
    if kind == "line":
        return [line_to_edge(obj)]
    converters = {"rect": rect_to_edges, "curve": curve_to_edges}
    return converters[kind](obj)
|
|
|
|
|
|
def filter_edges(
    edges: Iterable[T_obj],
    orientation: Optional[str] = None,
    edge_type: Optional[str] = None,
    min_length: T_num = 1,
) -> T_obj_list:
    """
    Return the edges that satisfy all of the given criteria.

    - orientation: if not None, keep only edges with this "orientation"
      ("v" or "h"); any other value raises ValueError.
    - edge_type: if not None, keep only edges whose "object_type" equals
      it.
    - min_length: keep only edges at least this long, where length is
      the edge's "height" for "v" edges and its "width" for all others.
    """
    if orientation not in ("v", "h", None):
        raise ValueError("Orientation must be 'v' or 'h'")

    def keep(edge: T_obj) -> bool:
        length_key = "height" if edge["orientation"] == "v" else "width"
        if edge_type is not None and edge["object_type"] != edge_type:
            return False
        if orientation is not None and edge["orientation"] != orientation:
            return False
        return bool(edge[length_key] >= min_length)

    return [edge for edge in edges if keep(edge)]
|