Source code for graphxplore.GraphDataScience.group_selector

import collections
from typing import Iterable, Dict, List, Tuple
from graphxplore.MetaDataHandling import MetaData
from graphxplore.DataMapping import MetaLattice, AggregatorType
from graphxplore.DataMapping.Conditionals import (LogicOperator, AggregatorOperator, InListOperator, StringOperatorType,
                                                  MetricOperatorType, AndOperator, OrOperator, AtomicOperator,
                                                  AlwaysTrueOperator, NegatedOperator)

[docs] class GroupSelector: """This class generates Cypher statements to select a group of primary keys (e.g. patient IDs) from a Neo4J database based on a :class:`~graphxplore.DataMapping.Conditionals.LogicOperator` object. Variables from inverted foreign table chains can be aggregated, and variables from foreign table chains can be used for singular comparison. Negations, conjunctions and disjunctions can be used as well. :param group_table: The name of the origin table for the group to select :param meta: The metadata of the database :param group_filter: The conditional describing the group, defaults to the tautology (all primary keys of ``group_table`` will be selected) """ def __init__(self, group_table : str, meta : MetaData, group_filter : LogicOperator = AlwaysTrueOperator()): """Constructor method """ if group_table not in meta.get_table_names(): raise AttributeError('Group table "' + group_table + '" not found in metadata') self.group_table = group_table self.meta = meta self.required_children = collections.defaultdict(lambda : collections.defaultdict(list)) self.required_ancestors = collections.defaultdict(lambda : collections.defaultdict(list)) full_lattice = MetaLattice.from_meta_data(self.meta) if not isinstance(group_filter, AlwaysTrueOperator): upward = set(full_lattice.get_relatives(self.group_table)) upward.add(self.group_table) downward = set(full_lattice.get_relatives(self.group_table, upward=False)) queue = [group_filter] while len(queue) > 0: current = queue.pop(0) if isinstance(current, AndOperator) or isinstance(current, OrOperator): queue += current.sub_operators elif isinstance(current, NegatedOperator): queue.append(current.pos_operator) elif isinstance(current, AtomicOperator): if current.table not in self.meta.get_table_names(): raise AttributeError('Filter table "' + current.table + '" not found in metadata') if current.variable not in self.meta.get_variable_names(current.table): raise AttributeError('Filter variable "' + current.variable + '" not found for "' + current.table + '" in metadata') var_info = self.meta.get_variable(current.table, current.variable) if current.data_type != var_info.data_type: raise AttributeError('Filter data type "' + current.data_type + '" does not match data type of variable "' + current.variable + '" for "' + current.table + '" in metadata') if isinstance(current, AggregatorOperator): if current.table not in downward: raise AttributeError('Filter table "' + current.table + '" is marked for aggregation, but has no foreign table chain to ' 'group table "' + group_table + '"') self.required_ancestors[current.table][current.variable].append(current) else: if current.table not in upward: raise AttributeError('group table "' + group_table + '" has no foreign table chain to filter table "' + current.table + '"') self.required_children[current.table][current.variable].append(current) elif isinstance(current, AlwaysTrueOperator): pass else: raise NotImplementedError('Condition not implemented') self.group_filter = group_filter self.children_lattice = full_lattice.get_sub_lattice_whitelist([self.group_table], self.required_children.keys()) # inverted lattice self.ancestor_lattice = MetaLattice(full_lattice.get_ancestor_lattice([self.group_table], self.required_ancestors.keys()).parents)
[docs] def get_cypher_query(self) -> str: """Generates the Cypher query to select the primary keys for the group. :return: Returns the generated query as a string """ query = 'match (x_0:' + self.group_table + ' {name:"' + self.meta.get_primary_key(self.group_table) + '"}) where x_0:Key\n' path_vars = {self.group_table: 'x_0'} attr_vars = collections.defaultdict(dict) path_var_counter = 1 attr_var_counter = 0 for lattice, required_tables in [(self.children_lattice, self.required_children), (self.ancestor_lattice, self.required_ancestors)]: if len(required_tables) == 0: continue shortest_paths = lattice.get_shortest_paths_to_required(self.group_table, required_tables.keys()) shortest_paths_sorted = sorted(shortest_paths.items(), key=lambda x : len(x[1])) for required, path in shortest_paths_sorted: sub_query = '' if required != self.group_table: first_unvisited_idx = len(path) for idx in range(len(path)): if path[idx] not in path_vars: first_unvisited_idx = idx break sub_query += 'match (' + path_vars[path[first_unvisited_idx - 1]] + ')' for idx in range(first_unvisited_idx, len(path)): curr = path[idx] # only have to name required tables and branching points if curr == required or len(lattice.children[curr]) > 1: curr_var = 'x_' + str(path_var_counter) path_var_counter += 1 else: curr_var = '' path_vars[curr] = curr_var curr_label = self.meta.get_label(curr) if curr_label == '': curr_label = curr curr_pk = self.meta.get_primary_key(curr) sub_query += '--(' + curr_var + ':' + curr_label + ' {name:"' + curr_pk + '"})' sub_query += '\n' path_var = path_vars[required] for var, operators in required_tables[required].items(): if var == self.meta.get_primary_key(required): attr_vars[required][var] = path_var continue attr_var = 'y_' + str(attr_var_counter) attr_var_counter += 1 attr_vars[required][var] = attr_var allow_empty_aggregation = True for operator in operators: if not isinstance(operator, AggregatorOperator) or operator.aggregator not in [AggregatorType.Count, AggregatorType.Concatenate, AggregatorType.List]: allow_empty_aggregation = False break if allow_empty_aggregation: sub_query += 'optional ' sub_query += 'match (' + path_var + ')--(' + attr_var + ':Attribute {name:"' + var + '"})\n' query += sub_query return query + self._get_with_where_clause(attr_vars) + 'return id(x_0) as x_0'
def _get_with_where_clause(self, attrs_vars : Dict[str, Dict[str, str]]) -> str: """Generates the last part containing the potential aggregation and condition checking in a where statement :param attrs_vars: The variable names used in the Cypher statement for each combination of table and variable in the metadata :return: Returns the last query part as string """ result = 'with x_0' children_cypher_vars = [] for table, var_dict in self.required_children.items(): for var, operators in var_dict.items(): cypher_var = attrs_vars[table][var] if cypher_var != 'x_0': children_cypher_vars.append(cypher_var) if len(children_cypher_vars) > 0: result += ',' + ','.join(sorted(children_cypher_vars)) agg_vars = {} for table, var_dict in self.required_ancestors.items(): for var, operators in var_dict.items(): var_to_agg = attrs_vars[table][var] + '.value' for operator in operators: lookup_tuple = (table, var, operator.aggregator) if lookup_tuple not in agg_vars: agg_var = 'z_' + str(len(agg_vars)) if operator.aggregator == AggregatorType.List: agg_statement = 'collect(distinct ' + var_to_agg + ')' elif operator.aggregator == AggregatorType.Count: agg_statement = 'count(' + var_to_agg + ')' elif operator.aggregator == AggregatorType.Concatenate: agg_statement = 'apoc.text.join(collect(toString(' + var_to_agg + ')), ";")' elif operator.aggregator == AggregatorType.Sum: agg_statement = 'sum(' + var_to_agg + ')' elif operator.aggregator == AggregatorType.Min: agg_statement = 'min(' + var_to_agg + ')' elif operator.aggregator == AggregatorType.Max: agg_statement = 'max(' + var_to_agg + ')' elif operator.aggregator == AggregatorType.Mean: agg_statement = 'avg(' + var_to_agg + ')' elif operator.aggregator == AggregatorType.Std: agg_statement = 'stDev(' + var_to_agg + ')' elif operator.aggregator == AggregatorType.Median: agg_statement = 'apoc.agg.median(' + var_to_agg + ')' elif operator.aggregator == AggregatorType.Amplitude: agg_statement = 'max(' + var_to_agg + ')-min(' + var_to_agg + ')' else: raise NotImplementedError('Aggregator type not implemented') agg_vars[lookup_tuple] = agg_var result += ',' + agg_statement + ' as ' + agg_var result += '\n' where_condition = self._get_cypher_condition(self.group_filter, attrs_vars, agg_vars) if where_condition != 'true': result += 'where ' + where_condition + '\n' return result @staticmethod def _get_cypher_condition(operator : LogicOperator, attrs_vars : Dict[str, Dict[str, str]], agg_vars : Dict[Tuple[str, str, AggregatorType], str]) -> str: """Recursively generates the condition checking of the Cypher statement as a where clause :param operator: The condition currently checked in the recursion :param attrs_vars: The variable names used in the Cypher statement for each combination of table and variable in the metadata :param agg_vars: The variables introduced for aggregation :return: Returns the where clause as a string """ if isinstance(operator, AlwaysTrueOperator): return 'true' if isinstance(operator, NegatedOperator): return 'not (' + GroupSelector._get_cypher_condition(operator.pos_operator, attrs_vars, agg_vars) + ')' if isinstance(operator, AndOperator) or isinstance(operator, OrOperator): concatenation = ') or (' if isinstance(operator, OrOperator) else ') and (' return '(' + concatenation.join((GroupSelector._get_cypher_condition(sub_operator, attrs_vars, agg_vars) for sub_operator in operator.sub_operators)) + ')' if isinstance(operator, AtomicOperator): if isinstance(operator, AggregatorOperator): to_check = agg_vars[(operator.table, operator.variable, operator.aggregator)] if operator.aggregator == AggregatorType.List: val = '"' + operator.value + '"' if isinstance(operator.value, str) else str(operator.value) return val + ' in ' + to_check else: to_check = attrs_vars[operator.table][operator.variable] + '.value' if isinstance(operator, InListOperator): return 'toString(' + to_check + ') in ["' + '","'.join(operator.ordered_white_list) + '"]' if isinstance(operator.compare, StringOperatorType): comparator = '=' if operator.compare == StringOperatorType.Equals else ' contains ' val = '"' + operator.value + '"' # metric comparison else: comparator = '=' if operator.compare == MetricOperatorType.Equals else operator.compare val = str(operator.value) return to_check + comparator + val raise AttributeError('Cypher condition not implemented for operator type')