import csv
import json
from io import StringIO
from itertools import chain
from typing import Any, Dict, List, Optional, Set, TextIO

from . import utils
from ._typing import T_obj, T_obj_list
from .convert import CSV_COLS_REQUIRED, CSV_COLS_TO_PREPEND, Serializer


class Container(object):
    """Shared accessors and exporters for anything exposing ``objects``.

    Subclasses must implement ``pages``, ``objects`` and ``to_dict``; every
    other member here is derived from those three.
    """

    # Names of the instance attributes used as caches by the edge
    # properties; ``flush_cache`` removes these by default.
    cached_properties = ["_rect_edges", "_curve_edges", "_edges", "_objects"]

    @property
    def pages(self) -> Optional[List[Any]]:  # pragma: nocover
        """Pages held by this container, or ``None``; subclasses implement."""
        raise NotImplementedError

    @property
    def objects(self) -> Dict[str, T_obj_list]:  # pragma: nocover
        """Mapping of object-type name to object list; subclasses implement."""
        raise NotImplementedError

    def to_dict(
        self, object_types: Optional[List[str]] = None
    ) -> Dict[str, Any]:  # pragma: nocover
        """Return a dict representation; subclasses implement."""
        raise NotImplementedError

    def flush_cache(self, properties: Optional[List[str]] = None) -> None:
        """Delete cached attributes from this instance.

        Args:
            properties: Attribute names to drop. When ``None``, the names in
                ``cached_properties`` are used.
        """
        if properties is None:
            targets = self.cached_properties
        else:
            targets = properties
        for attr_name in targets:
            # Only attributes that currently exist are removed.
            if hasattr(self, attr_name):
                delattr(self, attr_name)

    @property
    def rects(self) -> T_obj_list:
        """Objects stored under the ``"rect"`` key (empty list if absent)."""
        return self.objects.get("rect", [])

    @property
    def lines(self) -> T_obj_list:
        """Objects stored under the ``"line"`` key (empty list if absent)."""
        return self.objects.get("line", [])

    @property
    def curves(self) -> T_obj_list:
        """Objects stored under the ``"curve"`` key (empty list if absent)."""
        return self.objects.get("curve", [])

    @property
    def images(self) -> T_obj_list:
        """Objects stored under the ``"image"`` key (empty list if absent)."""
        return self.objects.get("image", [])

    @property
    def chars(self) -> T_obj_list:
        """Objects stored under the ``"char"`` key (empty list if absent)."""
        return self.objects.get("char", [])

    @property
    def textboxverticals(self) -> T_obj_list:
        """Objects stored under ``"textboxvertical"`` (empty list if absent)."""
        return self.objects.get("textboxvertical", [])

    @property
    def textboxhorizontals(self) -> T_obj_list:
        """Objects stored under ``"textboxhorizontal"`` (empty list if absent)."""
        return self.objects.get("textboxhorizontal", [])

    @property
    def textlineverticals(self) -> T_obj_list:
        """Objects stored under ``"textlinevertical"`` (empty list if absent)."""
        return self.objects.get("textlinevertical", [])

    @property
    def textlinehorizontals(self) -> T_obj_list:
        """Objects stored under ``"textlinehorizontal"`` (empty list if absent)."""
        return self.objects.get("textlinehorizontal", [])

    @property
    def rect_edges(self) -> T_obj_list:
        """Flattened ``utils.rect_to_edges`` output for every rect.

        The list is built on first access and cached on ``_rect_edges``.
        """
        if not hasattr(self, "_rect_edges"):
            per_rect = (utils.rect_to_edges(rect) for rect in self.rects)
            self._rect_edges: T_obj_list = list(chain.from_iterable(per_rect))
        return self._rect_edges

    @property
    def curve_edges(self) -> T_obj_list:
        """Flattened ``utils.curve_to_edges`` output for every curve.

        The list is built on first access and cached on ``_curve_edges``.
        """
        if not hasattr(self, "_curve_edges"):
            per_curve = (utils.curve_to_edges(curve) for curve in self.curves)
            self._curve_edges: T_obj_list = list(chain.from_iterable(per_curve))
        return self._curve_edges

    @property
    def edges(self) -> T_obj_list:
        """Line edges, then rect edges, then curve edges, as one list.

        The list is built on first access and cached on ``_edges``.
        """
        if not hasattr(self, "_edges"):
            from_lines = [utils.line_to_edge(line) for line in self.lines]
            self._edges: T_obj_list = (
                from_lines + self.rect_edges + self.curve_edges
            )
        return self._edges

    @property
    def horizontal_edges(self) -> T_obj_list:
        """Edges whose ``"orientation"`` value equals ``"h"``."""

        def is_horizontal(edge: T_obj) -> bool:
            return bool(edge["orientation"] == "h")

        return [edge for edge in self.edges if is_horizontal(edge)]

    @property
    def vertical_edges(self) -> T_obj_list:
        """Edges whose ``"orientation"`` value equals ``"v"``."""

        def is_vertical(edge: T_obj) -> bool:
            return bool(edge["orientation"] == "v")

        return [edge for edge in self.edges if is_vertical(edge)]

    def to_json(
        self,
        stream: Optional[TextIO] = None,
        object_types: Optional[List[str]] = None,
        include_attrs: Optional[List[str]] = None,
        exclude_attrs: Optional[List[str]] = None,
        precision: Optional[int] = None,
        indent: Optional[int] = None,
    ) -> Optional[str]:
        """Serialize ``to_dict(object_types)`` as JSON.

        Args:
            stream: Destination for the JSON text. When ``None`` the text is
                returned instead of written.
            object_types: Forwarded to ``to_dict``.
            include_attrs: Forwarded to ``Serializer``.
            exclude_attrs: Forwarded to ``Serializer``.
            precision: Forwarded to ``Serializer``.
            indent: Forwarded to ``json.dump`` / ``json.dumps``.

        Returns:
            The JSON string when ``stream`` is ``None``; otherwise ``None``.
        """
        raw = self.to_dict(object_types)
        serializer = Serializer(
            precision=precision,
            include_attrs=include_attrs,
            exclude_attrs=exclude_attrs,
        )
        payload = serializer.serialize(raw)

        if stream is not None:
            json.dump(payload, stream, indent=indent)
            return None
        return json.dumps(payload, indent=indent)

    def to_csv(
        self,
        stream: Optional[TextIO] = None,
        object_types: Optional[List[str]] = None,
        precision: Optional[int] = None,
        include_attrs: Optional[List[str]] = None,
        exclude_attrs: Optional[List[str]] = None,
    ) -> Optional[str]:
        """Write the objects of every page as CSV rows.

        Args:
            stream: Destination for the CSV text. When ``None`` an in-memory
                buffer is used and its contents are returned.
            object_types: Type names to export; each is looked up on the page
                as the attribute ``<name>s``. Defaults to every key of
                ``self.objects`` plus ``"annot"``.
            precision: Forwarded to ``Serializer``.
            include_attrs: Forwarded to ``Serializer``.
            exclude_attrs: Forwarded to ``Serializer``.

        Returns:
            The CSV text when ``stream`` is ``None``; otherwise ``None``.
        """
        return_text = stream is None
        target: TextIO = StringIO() if stream is None else stream

        if object_types is None:
            object_types = list(self.objects.keys()) + ["annot"]

        rows: List[Any] = []
        seen_fields: Set[str] = set()

        # A container without pages is exported as if it were a single page.
        if self.pages is None:
            pages = [self]
        else:
            pages = self.pages

        serializer = Serializer(
            precision=precision,
            include_attrs=include_attrs,
            exclude_attrs=exclude_attrs,
        )

        for page in pages:
            for type_name in object_types:
                objs = getattr(page, type_name + "s")
                if len(objs) == 0:
                    continue
                rows.extend(serializer.serialize(objs))
                # Candidate columns come from the first object of the type;
                # keys whose value is exactly a dict are left out.
                seen_fields |= {
                    key for key, value in objs[0].items() if type(value) is not dict
                }

        reserved = set(CSV_COLS_REQUIRED + CSV_COLS_TO_PREPEND)
        remaining = sorted(set(seen_fields) - reserved)
        optional_cols = CSV_COLS_TO_PREPEND + remaining

        # Required columns are always emitted; the others must pass the
        # serializer's attribute filter.
        columns = CSV_COLS_REQUIRED + [
            col for col in optional_cols if serializer.attr_filter(col)
        ]

        writer = csv.DictWriter(
            target,
            fieldnames=columns,
            extrasaction="ignore",
            quoting=csv.QUOTE_MINIMAL,
            escapechar="\\",
        )
        writer.writeheader()
        writer.writerows(rows)

        if not return_text:
            return None
        target.seek(0)
        return target.read()