import itertools
import math
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from typing import Optional, Dict, Tuple, Union
import collections
from enum import Enum
from graphxplore.GraphDataScience import GroupSelector
from graphxplore.Basis import GraphDatabaseUtils, GraphType
from graphxplore.MetaDataHandling import MetaData, VariableType, VariableInfo
from graphxplore.DataMapping import MetaLattice
[docs]
class HistogramYScaleType(str, Enum):
Count = 'Count'
Fraction = 'Fraction'
[docs]
class DashboardBuilder:
"""This class generates plots univariate and bivariate distributions by querying a
:class:`~graphxplore.Basis.BaseGraph.BaseGraph` stored in a Neo4J database. The plots are generated using the
plotly package. Additionally, subgroups of ``main_table`` primary keys can be defined to jointly plot and compare
distributions of groups.
:param meta: The metadata of the :class:`~graphxplore.Basis.BaseGraph.BaseGraph`
:param main_table: The origin table of primary keys used for the plot
:param base_graph_database: The name of the :class:`BaseGraph` Neo4J database
:param full_table_group: If ``True``, all primary keys of ``main_table`` are used as a group. Defaults to
``True``
:param groups: Dictionary of name and :class:`~graphxplore.GraphDataScience.GroupSelector` for the defined
subgroups. Must have ``main_table`` as their group table. Defaults to None
:param address: The address of the Neo4J DBMS
:param auth: User and password of the Neo4J DBMS
"""
def __init__(self, meta: MetaData, main_table: str, base_graph_database: str, full_table_group: bool = True,
groups: Optional[Dict[str, GroupSelector]] = None, address: str = GraphDatabaseUtils.get_neo4j_address(),
auth: Tuple[str, str] = ('neo4j', '')):
"""Constructor method
"""
if main_table not in meta.get_table_names():
raise AttributeError('Main table "' + main_table + '" of dashboard not in specified metadata')
self.meta = meta
self.main_table = main_table
if not full_table_group and (groups is None or len(groups) == 0):
raise AttributeError('You have to either specify at least one group, or set the flag "full_table_group" so '
'at least one group will be present')
self.full_table_group = full_table_group
if groups is not None:
for group_name, group_selector in groups.items():
if group_selector.group_table != main_table:
raise AttributeError('Group table of group "' + group_name
+ '" does not match main table of dashboard builder')
if group_selector.meta.to_dict() != self.meta.to_dict():
raise AttributeError('Metadata of group "' + group_name
+ '" does not match metadata of dashboard builder')
self.groups = groups
else:
self.groups = {}
if self.full_table_group:
self.groups['All of table "' + self.main_table + '"'] = GroupSelector(self.main_table, self.meta)
if GraphDatabaseUtils.check_graph_type_of_db(base_graph_database, address, auth) != GraphType.Base:
raise AttributeError('Database "' + base_graph_database + '" does not contain a base graph')
self.base_graph_database = base_graph_database
self.address = address
self.auth = auth
self.lattice = MetaLattice.from_meta_data(self.meta)
self.group_size = None
self.group_ids = None
[docs]
def get_variable_dist_plot(self, table: str, variable: str,
y_scale_type: Optional[HistogramYScaleType] = None) -> go.Figure:
"""Generates a :class:`plotly.graph_objects.Figure` for the univariate distribution of ``variable``. If
``variable`` is metric, a plot of multiple histograms (one for each group) is generated. If ``variable`` is
categorical, multiple pie charts are generated and combined into one plot. All necessary data is queried from
the Neo4J database
:param table: the table of the variable
:param variable: The variable for the distribution plot
:param y_scale_type: The y-scale type. If group sizes are very imbalanced, ``HistogramYScaleType.Fraction``
should be preferred
:return: Returns the plotted figure which can e.g. be used in streamlit or notebooks
"""
if table != self.main_table and table not in self.lattice.get_relatives(self.main_table):
raise AttributeError('"' + table + '" is not a foreign table (or foreign table of foreign table...) of "'
+ self.main_table + '"')
var_info = self.meta.get_variable(table, variable)
if var_info.variable_type not in [VariableType.Metric, VariableType.Categorical]:
raise AttributeError('Can only plot distribution for metric and categorical variables')
if var_info.variable_type == VariableType.Metric:
if y_scale_type is None:
raise AttributeError('For histogram plots of metric variables, the y-scale type must be specified')
dist_data = self._query_and_transform_dist_data(var_info)
if var_info.variable_type == VariableType.Metric:
hist_norm = None if y_scale_type == HistogramYScaleType.Count else 'probability'
return px.histogram(
dist_data, x=variable, color='group', barmode='overlay', histnorm=hist_norm, marginal='box')
else:
group_indices, nof_rows, nof_cols = self._get_subplot_indices()
specs = [[{'type': 'domain'} for i in range(nof_cols)] for j in range(nof_rows)]
fig = make_subplots(
rows=nof_rows, cols=nof_cols, specs=specs,
subplot_titles=tuple([group + ' (' + str(count) + ')' for group, count in self.group_size.items()]))
for group in self.groups.keys():
row, col = group_indices[group]
labels = list(dist_data[group].keys())
values = list(dist_data[group].values())
fig.add_trace(go.Pie(labels=labels, values=values, name=group), row, col)
fig.update_traces(textinfo='label+value')
fig.update_layout(legend_title_text='Categories of ' + variable)
return fig
[docs]
def get_correlation_plot(self, first_table: str, first_var: str, second_table: str, second_var: str) -> go.Figure:
"""Generates a :class:`plotly.graph_objects.Figure` for the bivariate distribution of ``first_variable`` and
``second_variable``. For two metric variables a scatter plot is generated, for a pair of metric and categorical
variables multiple box plots are generated, and for two categorical variables stacked bar plots are used. All
necessary data is queried from
the Neo4J database
:param first_table: The table of ``first_var``
:param first_var: The first variable for the distribution
:param second_table: The table of ``second_var``
:param second_var: The second variable for the distribution
:return: Returns the plotted figure which can e.g. be used in streamlit or notebooks
"""
children = self.lattice.get_relatives(self.main_table)
for table in [first_table, second_table]:
if table != self.main_table and table not in children:
raise AttributeError(
'"' + table + '" is not a foreign table (or foreign table of foreign table...) of "'
+ self.main_table + '"')
first_var_info = self.meta.get_variable(first_table, first_var)
second_var_info = self.meta.get_variable(second_table, second_var)
for var_info in [first_var_info, second_var_info]:
if var_info.variable_type not in [VariableType.Metric, VariableType.Categorical]:
raise AttributeError('Can only plot correlation for metric and categorical variables')
dist_data = self._query_and_transform_dist_data((first_var_info, second_var_info))
if first_var_info.variable_type == VariableType.Categorical and second_var_info.variable_type == VariableType.Categorical:
color_dict = {}
color_iter = itertools.cycle(px.colors.qualitative.Plotly + px.colors.qualitative.Pastel1)
for second_val in dist_data.keys():
color_dict[second_val] = next(color_iter)
group_indices, nof_rows, nof_cols = self._get_subplot_indices()
fig = make_subplots(
rows=nof_rows, cols=nof_cols, x_title=first_var,
subplot_titles=tuple([group + ' (' + str(count) + ')' for group, count in self.group_size.items()]))
first_group = True
for group in self.groups.keys():
row, col = group_indices[group]
for second_val, data in dist_data.items():
x_values, y_values = zip(*data[group].items())
fig.add_trace(go.Bar(x=x_values, y=y_values, name=second_val,
hoverinfo='name+y+x', marker_color=color_dict[second_val],
legendgroup=str(second_val), showlegend=True if first_group else False), row,
col)
first_group = False
fig.update_xaxes(type='category')
fig.update_layout(barmode='stack', yaxis_title='Count',
legend_title_text='Categories of ' + second_var)
return fig
else:
if first_var_info.variable_type == VariableType.Metric and second_var_info.variable_type == VariableType.Metric:
return px.scatter(dist_data, x=first_var, y=second_var, color='group',
marginal_x='histogram', marginal_y='histogram')
elif first_var_info.variable_type == VariableType.Metric:
return px.box(dist_data, x=second_var, y=first_var, color='group')
else:
return px.box(dist_data, x=first_var, y=second_var, color='group')
def _get_subplot_indices(self) -> Tuple[Dict[str, Tuple[int, int]], int, int]:
"""Generates alignment and indices of subplots for the groups
:return: Returns a dictionary with row and column index for each group, and number of rows and columns
"""
nof_cols = min(len(self.groups), 4)
nof_rows = math.ceil(len(self.groups) / 4)
result = {}
# plotly subplots are 1-indexed
curr_row = 1
curr_col = 1
for group in self.groups.keys():
result[group] = (curr_row, curr_col)
if curr_col == nof_cols:
curr_row += 1
curr_col = 1
else:
curr_col += 1
return result, nof_rows, nof_cols
def _get_cypher_query(self, var_info: Union[VariableInfo, Tuple[VariableInfo, VariableInfo]]) -> str:
"""Generates the Neo4J Cypher query to retrieve the data for the univariate or bivariate distribution
:param var_info: Either one variable info for univariate, or two infos for bivariate distributions
:return: Returns the query as string
"""
if isinstance(var_info, VariableInfo):
shortest_path = self.lattice.get_shortest_paths_to_required(
self.main_table, [var_info.table])[var_info.table]
query = 'match ' + '--'.join(('(x_' + str(i) + ':' + shortest_path[i] + ')'
for i in range(len(shortest_path))))
query += ('--(y:' + var_info.table + ' {name:"' + var_info.name
+ '"}) where x_0:Key return y.value as val, id(x_0) as member_id')
return query
else:
first_info, second_info = var_info
shortest_paths = self.lattice.get_shortest_paths_to_required(
self.main_table, [first_info.table, second_info.table])
first_shortest = shortest_paths[first_info.table]
second_shortest = shortest_paths[second_info.table]
last_common_idx = 0
for path_idx in range(min(len(first_shortest), len(second_shortest))):
if first_shortest[path_idx] == second_shortest[path_idx]:
last_common_idx = path_idx
else:
break
query = 'match ' + '--'.join(('(x_' + str(i) + ':' + first_shortest[i] + ')'
for i in range(len(first_shortest))))
query += ('--(y_0:' + first_info.table + ' {name:"' + first_info.name + '"}) where x_0:Key ')
query += ' match ' + '--'.join(('(' + ('x' if i == last_common_idx else 'z')
+ '_' + str(i) + (':' + second_shortest[i] if i > last_common_idx else '')
+ ')'
for i in range(last_common_idx, len(second_shortest))))
query += '--(y_1:' + second_info.table + ' {name:"' + second_info.name + '"}) '
query += 'return y_0.value as first_val, y_1.value as second_val, id(x_0) as member_id'
return query
def _query_and_transform_dist_data(self, var_info: Union[VariableInfo, Tuple[VariableInfo, VariableInfo]]) -> Dict:
"""Queries the Neo4J database for the distribution data and transforms it into the suitable format for plotly
:param var_info: Either one variable info for univariate, or two infos for bivariate distributions
:return: Returns the transformed data as dictionary of different formats depending on the variable types and
quantities
"""
if self.group_size is None:
self._query_group_members()
query = self._get_cypher_query(var_info)
records = GraphDatabaseUtils.execute_query(
query=query, database=self.base_graph_database, address=self.address, auth=self.auth)
if isinstance(var_info, VariableInfo):
if var_info.variable_type == VariableType.Metric:
result = {var_info.name: [], 'group': []}
for record in records:
member_id = record['member_id']
var_val = record['val']
for group in self.group_ids[member_id]:
result[var_info.name].append(var_val)
result['group'].append(group + ' (' + str(self.group_size[group]) + ')')
return result
else:
result = {group: collections.defaultdict(int) for group in self.groups.keys()}
for record in records:
member_id = record['member_id']
var_val = record['val']
for group in self.group_ids[member_id]:
result[group][var_val] += 1
return result
else:
first_info, second_info = var_info
if first_info.variable_type == VariableType.Categorical and second_info.variable_type == VariableType.Categorical:
all_first_vals = set()
triplets = collections.defaultdict(
lambda: collections.defaultdict(lambda: collections.defaultdict(int)))
for record in records:
member_id = record['member_id']
first_val = record['first_val']
all_first_vals.add(first_val)
second_val = record['second_val']
for group in self.group_ids[member_id]:
triplets[second_val][group][first_val] += 1
# fill up missing data with zeros and order
result = {}
first_val_list = list(all_first_vals)
for second_val, data in triplets.items():
filled_data = {}
for group in self.groups.keys():
group_data = {first_val : 0 if group not in data or first_val not in data[group]
else data[group][first_val]
for first_val in first_val_list}
filled_data[group] = group_data
result[second_val] = filled_data
return result
else:
result = {first_info.name: [], second_info.name: [], 'group': []}
for record in records:
member_id = record['member_id']
first_var_val = record['first_val']
second_var_val = record['second_val']
for group in self.group_ids[member_id]:
result[first_info.name].append(first_var_val)
result[second_info.name].append(second_var_val)
result['group'].append(group + ' (' + str(self.group_size[group]) + ')')
return result
def _query_group_members(self):
"""Queries the Neo4J node IDs for all groups and stores them in the object
"""
group_ids = collections.defaultdict(list)
group_size = {}
for group_name, group_selector in self.groups.items():
records = GraphDatabaseUtils.execute_query(
query=group_selector.get_cypher_query(), database=self.base_graph_database, address=self.address,
auth=self.auth)
group_size[group_name] = len(records)
for record in records:
group_ids[record['x_0']].append(group_name)
self.group_size = group_size
self.group_ids = group_ids