# Source code for regparser.tree.struct

import hashlib
import re
from json import JSONEncoder

import six
from lxml import etree

from regparser.tree.depth.markers import MARKERLESS


class Node(object):
    """Mutable tree node: text, a label (list of strings), an optional
    title, a node type and a list of child nodes. Equality and ordering are
    both derived from `repr`, so they cover text, children, label, title and
    node_type (but not source_xml or tagged_text)."""
    # Values used for `node_type`
    APPENDIX = u'appendix'
    INTERP = u'interp'
    REGTEXT = u'regtext'
    SUBPART = u'subpart'
    EMPTYPART = u'emptypart'
    EXTRACT = u'extract'
    NOTE = u'note'

    # Label component which flags a node as an interpretation (see `depth`)
    INTERP_MARK = 'Interp'

    # Final label components of this shape are treated as markerless
    MARKERLESS_REGEX = re.compile(r'p\d+')

    def __init__(self, text='', children=None, label=None, title=None,
                 node_type=REGTEXT, source_xml=None, tagged_text=''):
        """Build a node.

        :param text: node text; coerced to a text string
        :param children: iterable of child nodes; copied into a new list
        :param label: iterable of label components; each is converted with
            `str` and empty-string components are dropped
        :param title: optional title; falsy values are stored as None
        :param node_type: one of the class-level type constants
        :param source_xml: stored as given
        :param tagged_text: stored as given
        """
        if children is None:
            children = []
        if label is None:
            label = []

        self.text = six.text_type(text)

        # defensive copy
        self.children = list(children)

        self.label = [str(part) for part in label if part != '']
        title = six.text_type(title or '')
        self.title = title or None
        self.node_type = node_type
        self.source_xml = source_xml
        self.tagged_text = tagged_text

    def __repr__(self):
        text = ("Node(text={0}, children={1}, label={2}, title={3}, "
                "node_type={4})")
        return text.format(
            repr(self.text), repr(self.children), repr(self.label),
            repr(self.title), repr(self.node_type)
        )

    def __lt__(self, other):
        return repr(self) < repr(other)

    def __eq__(self, other):
        return repr(self) == repr(other)

    def __ne__(self, other):
        # Python 2 does not derive != from ==; without this, two equal nodes
        # would also compare as "not equal" there
        return not self.__eq__(other)

    @property
    def cfr_part(self):
        """First label component, or None when the label is empty"""
        if self.label:
            return self.label[0]

    def label_id(self):
        """Label components joined with dashes"""
        return '-'.join(self.label)

    def depth(self):
        """Inspect the label and type to determine the node's depth"""
        second = (self.label[1:2] or [""])[0]
        second_is_digit = second[:1].isdigit()
        is_interp = self.INTERP_MARK in self.label
        is_root = len(self.label) <= 1
        if self.node_type in (self.SUBPART, self.EMPTYPART):
            # Subparts all on the same level
            return 2
        elif not second_is_digit or is_root or is_interp:
            return len(self.label)
        else:
            # Add one for the subpart level
            return len(self.label) + 1

    @classmethod
    def is_markerless_label(cls, label):
        """Truthy when the final label component matches MARKERLESS_REGEX or
        equals MARKERLESS. Returns None for an empty label; otherwise the
        result is a regex match object or a boolean."""
        if not label:
            return None

        return (cls.MARKERLESS_REGEX.match(label[-1]) or
                label[-1] == MARKERLESS)

    def is_markerless(self):
        """Boolean form of `is_markerless_label` applied to this node"""
        return bool(self.is_markerless_label(self.label))

    def is_section(self):
        """Sections are contained within subparts/subject groups. They are not
        part of the appendix"""
        return len(self.label) == 2 and self.label[1][:1].isdigit()

    def walk(self, fn):
        """See walk(node, fn)"""
        return walk(self, fn)
class NodeEncoder(JSONEncoder):
    """JSON encoder which knows how to serialize Node objects"""

    def default(self, obj):
        """Serialize a Node as a dict of its instance attributes, leaving out
        `tagged_text`, `source_xml` and `child_labels`, and also `title` when
        it is None. Anything else is handed to the base encoder."""
        if not isinstance(obj, Node):
            return super(NodeEncoder, self).default(obj)

        serialized = obj.__dict__.copy()
        if obj.title is None:
            del serialized['title']
        for excluded in ('tagged_text', 'source_xml', 'child_labels'):
            serialized.pop(excluded, None)
        return serialized
class FullNodeEncoder(JSONEncoder):
    """JSON encoder for Nodes which keeps every field"""
    FIELDS = {'text', 'children', 'label', 'title', 'node_type',
              'source_xml', 'tagged_text'}

    def default(self, obj):
        """Serialize a Node as a dict with one entry per name in FIELDS
        (None for a missing attribute). A non-None `source_xml` is converted
        to text via `etree.tounicode`. Anything else is handed to the base
        encoder."""
        if not isinstance(obj, Node):
            return super(FullNodeEncoder, self).default(obj)

        encoded = {}
        for name in self.FIELDS:
            encoded[name] = getattr(obj, name, None)
        if obj.source_xml is not None:
            encoded['source_xml'] = etree.tounicode(obj.source_xml)
        return encoded
def full_node_decode_hook(d):
    """JSON object hook: a dict whose keys are exactly
    FullNodeEncoder.FIELDS becomes a Node (with a truthy `source_xml` parsed
    by `etree.fromstring`); any other dict is returned untouched."""
    if set(d.keys()) != FullNodeEncoder.FIELDS:
        return d

    node = Node(**dict(d))
    if node.source_xml:
        node.source_xml = etree.fromstring(node.source_xml)
    return node
def frozen_node_decode_hook(d):
    """JSON object hook: a dict whose keys are exactly
    FullNodeEncoder.FIELDS becomes the pooled FrozenNode built from every
    field except `source_xml`; any other dict is returned untouched."""
    if set(d.keys()) != FullNodeEncoder.FIELDS:
        return d

    kwargs = {key: value for key, value in d.items()
              if key != 'source_xml'}
    return FrozenNode(**kwargs).prototype()
def walk(node, fn):
    """Call fn on every node of the tree in pre-order (a node before its
    children) and return the list of all results which are not None. fn must
    accept a single node."""
    outcome = fn(node)
    collected = [] if outcome is None else [outcome]
    for child in node.children:
        collected.extend(walk(child, fn))
    return collected
def filter_walk(node, fn):
    """Collect, in pre-order, every node of the tree for which fn (called
    with the node's label) returns something truthy."""
    def keep_if_matching(candidate):
        return candidate if fn(candidate.label) else None

    return walk(node, keep_if_matching)
def find_first(root, predicate):
    """Return the first node (in pre-order) satisfying the predicate, or
    None when no node does. The whole tree is always traversed."""
    matches = walk(root, lambda n: n if predicate(n) else None)
    if not matches:
        return None
    return matches[0]
def find(root, label):
    """Locate the node in the tree whose label id equals `label`. `label`
    may be a label id string or a Node (whose own label id is then used).
    Returns None when nothing matches."""
    target = label.label_id() if isinstance(label, Node) else label

    def has_target_label(n):
        return n.label_id() == target

    return find_first(root, has_target_label)
def find_parent(root, label):
    """Locate the _parent_ of the node with this label: the first node in
    the tree which has a direct child whose label id equals `label`. `label`
    may be a label id string or a Node. Returns None when nothing matches."""
    target = label.label_id() if isinstance(label, Node) else label

    def has_child(n):
        for child in n.children:
            if child.label_id() == target:
                return True
        return False

    return find_first(root, has_child)
def merge_duplicates(nodes):
    """Given a list of nodes with the same-length label, merge any duplicates
    (by combining their children).

    The first node seen for each label is kept, in its original position;
    the children of every later node with the same label are appended to it
    in order of appearance. Done in a single pass so that long inputs with
    many duplicates neither recurse deeply nor rescan all pairs repeatedly.

    :param nodes: list of objects exposing `label` and `children`
    :return: `nodes` itself when it contains no duplicates, otherwise a new
        list holding only the kept nodes
    """
    first_seen = {}     # label (as a tuple) -> the node kept for that label
    merged = []
    for node in nodes:
        key = tuple(node.label)
        if key in first_seen:
            first_seen[key].children.extend(node.children)
        else:
            first_seen[key] = node
            merged.append(node)

    # Hand back the caller's own list when nothing had to be merged
    if len(merged) == len(nodes):
        return nodes
    return merged
def treeify(nodes):
    """Given a list of nodes, convert those nodes into the appropriate tree
    structure based on their labels. This assumes that all nodes will fall
    under a set of 'root' nodes, which have the min-length label.

    :param nodes: list of nodes; the chosen roots have their `children`
        replaced in the process
    :return: the list of root nodes (duplicate roots merged), each with its
        descendants attached beneath it; `nodes` itself when it is empty
    """
    if not nodes:
        return nodes

    # Roots are the nodes with the shortest label, in their original order
    min_len = min(len(node.label) for node in nodes)
    with_min = merge_duplicates(
        [node for node in nodes if len(node.label) == min_len])

    roots = []
    for root in with_min:
        prefix = root.label
        # A root may have an empty label; only strip a trailing interp mark
        # when there is a final component to inspect
        if prefix and prefix[-1] == Node.INTERP_MARK:
            prefix = prefix[:-1]

        # Everything sharing the root's label prefix, the root (and its
        # duplicates) excepted, belongs somewhere beneath it
        children = [n for n in nodes
                    if n.label != root.label
                    and n.label[:len(prefix)] == prefix]
        root.children = root.children + treeify(children)
        roots.append(root)
    return roots
class FrozenNode(object):
    """Immutable interface for nodes. No guarantees about internal state.
    Every field is exposed through a read-only property, and each instance
    carries a SHA-256 digest of its fields and of its children's digests."""
    _pool = {}  # collection of all FrozenNodes, keyed by hash

    def __init__(self, text='', children=(), label=(), title='',
                 node_type=Node.REGTEXT, tagged_text=''):
        """Store the fields, compute the digest and register this instance in
        the pool unless an instance with the same digest is already there.

        :param text: node text; falsy values become ''
        :param children: iterable of FrozenNodes
        :param label: iterable of label components (strings)
        :param title: node title; falsy values become ''
        :param node_type: node type string
        :param tagged_text: falsy values become ''
        """
        self._text = text or ''
        self._children = tuple(children)
        self._label = tuple(label)
        self._title = title or ''
        self._node_type = node_type
        self._tagged_text = tagged_text or ''
        self._child_labels = tuple(c.label_id for c in self.children)
        self._label_id = '-'.join(self.label)
        self._hash = self._generate_hash()
        if self.hash not in FrozenNode._pool:
            FrozenNode._pool[self.hash] = self

    @property
    def text(self):
        return self._text

    @property
    def children(self):
        return self._children

    @property
    def label(self):
        return self._label

    @property
    def title(self):
        return self._title

    @property
    def node_type(self):
        return self._node_type

    @property
    def tagged_text(self):
        return self._tagged_text

    @property
    def hash(self):
        return self._hash

    @property
    def label_id(self):
        return self._label_id

    @property
    def child_labels(self):
        return self._child_labels

    def _generate_hash(self):
        """Called during instantiation. Digests all fields"""
        hasher = hashlib.sha256()
        hasher.update(self.text.encode('utf-8'))
        hasher.update(self.tagged_text.encode('utf-8'))
        hasher.update(self.title.encode('utf-8'))
        hasher.update(self.label_id.encode('utf-8'))
        hasher.update(self.node_type.encode('utf-8'))
        for child in self.children:
            hasher.update(child.hash.encode('utf-8'))
        return hasher.hexdigest()

    def __hash__(self):
        """As the hash property is already distinctive, re-use it"""
        return hash(self.hash)

    def __eq__(self, other):
        """We define equality as having the same fields except for children.
        Instead of recursively inspecting them, we compare only their hash
        (this is a Merkle tree)"""
        return (other.__class__ == self.__class__ and
                self.hash == other.hash and
                # Compare the fields to limit the effect of hash collisions
                self.text == other.text and
                self.title == other.title and
                self.node_type == other.node_type and
                self.tagged_text == other.tagged_text and
                self.label_id == other.label_id and
                [c.hash for c in self.children] ==
                [c.hash for c in other.children])

    def __ne__(self, other):
        # Python 2 does not derive != from ==; keep the two consistent
        return not self.__eq__(other)

    @staticmethod
    def from_node(node):
        """Convert a struct.Node (or similar) into a struct.FrozenNode. This
        also checks if this node has already been instantiated. If so, it
        returns the instantiated version (i.e. only one of each identical node
        exists in memory)"""
        children = [FrozenNode.from_node(n) for n in node.children]
        fresh = FrozenNode(text=node.text, children=children,
                           label=node.label, title=node.title or '',
                           node_type=node.node_type,
                           tagged_text=node.tagged_text)
        return fresh.prototype()

    # @todo - seems like something we could implement via __new__?
    def prototype(self):
        """When we instantiate a FrozenNode, we add it to _pool if we've not
        seen an identical FrozenNode before. If we have, we want to work with
        that previously seen version instead. This method returns the _first_
        FrozenNode with identical fields"""
        return FrozenNode._pool[self.hash]  # note this may not be self

    def clone(self, **kwargs):
        """Implement a namedtuple `_replace` style functionality, copying all
        fields that aren't explicitly replaced."""
        for field in ('text', 'children', 'label', 'title', 'node_type',
                      'tagged_text'):
            kwargs[field] = kwargs.get(field, getattr(self, field))
        fresh = FrozenNode(**kwargs)
        return fresh.prototype()