# Source code for galaxywitness.base_complex

from abc import abstractmethod
from collections import defaultdict

import numpy as np
import gudhi
from gudhi.clustering.tomato import Tomato

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from matplotlib import colors
import plotly.graph_objects as go

from galaxywitness.manual_density import ManualDensity

# Upper bound on how many points of the cloud are scattered on a 3D plot.
MAX_N_PLOT = 10_000
# Total number of pictures announced in the title of every animation frame.
NUMBER_OF_FRAMES = 6


class BaseComplex:
    """
    Base class for any type of complexes

    Stores a point cloud together with a filtered simplicial complex
    (``simplex_tree``) and offers persistence computations, plotting and
    ToMATo clustering on top of it. Descendant classes are expected to
    provide ``compute_simplicial_complex``, ``animate_simplex_tree`` and
    ``animate_simplex_tree_plotly``.

    Note: the class does not use ``abc.ABCMeta``, so the ``abstractmethod``
    markers are not enforced at instantiation time.
    """

    def __init__(self, points=None):
        """
        Constructor

        :param points: point cloud the complex is built on (Optional)
        :type points: np.array
        """
        self.simplex_tree = None
        self.points = points
        # Filled in by get_persistence_betti(); tomato() reads betti[0].
        self.betti = None
        self.simplex_tree_computed = False
        self.density_class = ManualDensity()

    def _require_simplex_tree(self):
        """
        Fail early when no simplex tree has been computed or loaded.

        :raises AssertionError: if ``simplex_tree_computed`` is not set
        """
        # Raised explicitly (instead of using an ``assert`` statement) so the
        # check survives ``python -O`` while keeping the same exception type.
        if not self.simplex_tree_computed:
            raise AssertionError(
                "simplex tree is not available: compute the complex or "
                "load one with external_simplex_tree() first")

    @abstractmethod
    def compute_simplicial_complex(self, *args):
        """
        Build the filtered simplicial complex (provided by descendants)
        """

    def external_simplex_tree(self, simplex_tree):
        """
        Load external filtered simplicial complex (as simplex tree) to
        DescendantComplex instance

        :param simplex_tree: external simplex tree
        :type simplex_tree: gudhi.SimplexTree
        """
        # TODO diversify restructured text for descendant classes
        self.simplex_tree = simplex_tree
        self.simplex_tree_computed = True

    def get_persistence_betti(self, dim, magnitudes):
        """
        Computation of persistence betti numbers

        An interval of dimension ``j`` is counted when its length
        (death - birth) is at least ``magnitudes[j]``. The result is also
        stored in ``self.betti``.

        :param dim: number of dimensions to process (dimensions 0...dim-1)
        :type dim: int
        :param magnitudes: levels of significance, one per dimension
        :type magnitudes: list[float]
        :return: persistence betti numbers for dimensions 0...dim-1
        :rtype: np.array
        :raises AssertionError: if no simplex tree is available
        """
        self._require_simplex_tree()

        self.simplex_tree.compute_persistence()
        betti = np.zeros(dim, dtype=int)
        for j in range(dim):
            intervals = self.simplex_tree.persistence_intervals_in_dimension(j)
            betti[j] = sum(1 for birth, death in intervals
                           if death - birth >= magnitudes[j])

        self.betti = betti
        return betti

    def _plot_persistence(self, plot_function, file_name, show, path_to_save):
        """
        Shared implementation of get_diagram() and get_barcode().

        :param plot_function: callable drawing the persistence on given axes
        :param file_name: name of the picture inside ``path_to_save``
        :type file_name: str
        :param show: show the figure?
        :type show: bool
        :param path_to_save: place, where we are saving files (or None)
        :type path_to_save: str
        """
        self._require_simplex_tree()

        fig, ax = plt.subplots()
        try:
            diag = self.simplex_tree.persistence()
            plot_function(diag, axes=ax, legend=True)
            if path_to_save is not None:
                fig.savefig(path_to_save + '/' + file_name, dpi=200)
            if show:
                plt.show()
        finally:
            # Close exactly the figure created here, even if plotting failed.
            plt.close(fig)

    def get_diagram(self, show=False, path_to_save=None):
        """
        Draw persistent diagram

        :param show: show diagram? (Optional)
        :type show: bool
        :param path_to_save: place, where we are saving files
        :type path_to_save: str
        :raises AssertionError: if no simplex tree is available
        """
        self._plot_persistence(gudhi.plot_persistence_diagram, 'diagram.png',
                               show, path_to_save)

    def get_barcode(self, show=False, path_to_save=None):
        """
        Draw barcode

        :param show: show barcode? (Optional)
        :type show: bool
        :param path_to_save: place, where we are saving files
        :type path_to_save: str
        :raises AssertionError: if no simplex tree is available
        """
        self._plot_persistence(gudhi.plot_persistence_barcode, 'barcode.png',
                               show, path_to_save)

    def _simplex_coordinates(self, simplex):
        """
        Return the x, y and z coordinate lists of the vertices of a simplex.

        :param simplex: indices of the simplex vertices in ``self.points``
        :type simplex: list[int]
        :rtype: tuple[list, list, list]
        """
        vertices = [self.points[vertex] for vertex in simplex]
        return ([p[0] for p in vertices],
                [p[1] for p in vertices],
                [p[2] for p in vertices])

    def _draw_mpl(self, simplices, title, num, path_to_save):
        """
        Render the point cloud and the given simplices with matplotlib.
        """
        fig = plt.figure()
        try:
            ax = fig.add_subplot(projection="3d")
            cloud = self.points[:MAX_N_PLOT]
            ax.scatter3D(cloud[:, 0], cloud[:, 1], cloud[:, 2],
                         s=1, linewidths=1, color='C1')
            ax.set_xlabel('X, Mpc')
            ax.set_ylabel('Y, Mpc')
            ax.set_zlabel('Z, Mpc')

            for x, y, z in simplices:
                color = colors.rgb2hex(np.random.rand(3))
                if len(x) == 2:
                    ax.plot(x, y, z, color=color)
                else:
                    poly = Poly3DCollection([list(zip(x, y, z))])
                    poly.set_color(color)
                    ax.add_collection3d(poly)

            ax.set_title(title)
            if path_to_save is not None:
                fig.savefig(path_to_save + f"/picture{num}.png", dpi=200)
            plt.show()
        finally:
            # One figure is created per frame; release it so repeated calls
            # do not accumulate open figures.
            plt.close(fig)

    def _draw_plotly(self, simplices, title, num, path_to_save):
        """
        Render the point cloud and the given simplices with plotly.
        """
        cloud = self.points[:MAX_N_PLOT]
        data = [go.Scatter3d(x=cloud[:, 0], y=cloud[:, 1], z=cloud[:, 2],
                             mode='markers',
                             marker=dict(size=1, color='blue'))]

        for x, y, z in simplices:
            color = colors.rgb2hex(np.random.rand(3))
            if len(x) == 2:
                data.append(go.Scatter3d(
                    x=x, y=y, z=z,
                    marker=dict(size=1, color='blue'),
                    line=dict(color=color, width=3)))
            else:
                data.append(go.Mesh3d(x=x, y=y, z=z,
                                      color=color, opacity=0.5))

        fig = go.Figure(data=data)
        fig.update_layout(
            title=title,
            scene=dict(
                xaxis_title='X, Mpc',
                yaxis_title='Y, Mpc',
                zaxis_title='Z, Mpc'))
        if path_to_save is not None:
            fig.write_image(path_to_save + f"/picture{num}.pdf")
        fig.show()

    def draw_simplicial_complex(self, num, filtration_val, backend, path_to_save=None):
        """
        Draw simplicial complex with filtration value filtration_val

        Only edges and triangles whose filtration value is strictly below
        ``filtration_val`` are drawn, each with a random color.

        :param num: number of step
        :type num: int
        :param filtration_val: filtration value
        :type filtration_val: float
        :param backend: backend for drawing, ``'mpl'`` or ``'plotly'``
        :type backend: str
        :param path_to_save: place, where we are saving files
        :type path_to_save: str
        :raises AssertionError: if no simplex tree is available
        :raises ValueError: if ``backend`` is not a supported backend
        """
        self._require_simplex_tree()
        # Reject unknown backends up front instead of walking the whole
        # filtration and then silently drawing nothing.
        if backend not in ('mpl', 'plotly'):
            raise ValueError(
                f"unknown backend {backend!r}: expected 'mpl' or 'plotly'")

        title = f"Animation of alpha filtration: picture #{num} of {NUMBER_OF_FRAMES}"
        simplices = [
            self._simplex_coordinates(simplex)
            for simplex, value in self.simplex_tree.get_filtration()
            if value < filtration_val and len(simplex) in (2, 3)
        ]

        if backend == 'mpl':
            self._draw_mpl(simplices, title, num, path_to_save)
        else:
            self._draw_plotly(simplices, title, num, path_to_save)

    def get_adjacency_list(self, max_fil_val):
        """
        Get adjacency list for vertices in 1-skeleton of filtrated
        simplicial complex

        :param max_fil_val: edges with a larger filtration value are skipped
        :type max_fil_val: float
        :return: entry ``i`` lists the neighbours of vertex ``i``, for every
            ``i`` in ``range(self.points.shape[0])``
        :rtype: list[list[int]]
        :raises AssertionError: if no simplex tree is available
        """
        self._require_simplex_tree()

        graph = defaultdict(list)
        for simplex, value in self.simplex_tree.get_skeleton(1):
            # The 1-skeleton also contains the vertices themselves.
            if value > max_fil_val or len(simplex) != 2:
                continue
            first, second = simplex
            graph[first].append(second)
            graph[second].append(first)

        return [graph[vertex] for vertex in range(self.points.shape[0])]

    @abstractmethod
    def animate_simplex_tree(self, path_to_save):
        """
        Draw animation of filtrated simplicial complex (powered by matplotlib)

        :param path_to_save: place, where we are saving files
        :type path_to_save: str
        """

    @abstractmethod
    def animate_simplex_tree_plotly(self, path_to_save):
        """
        Draw animation of filtrated simplicial complex (powered by plotly)

        :param path_to_save: place, where we are saving files
        :type path_to_save: str
        """

    def tomato(self, max_fil_val=7.5):
        """
        ToMATo clustering with automatic choice of number of clusters.
        Hence, clustering depends on filtered complex construction and
        max value of filtration.

        The number of clusters is taken from ``self.betti[0]``, so
        get_persistence_betti() has to be called (with ``dim >= 1``) first.

        :param max_fil_val: max filtration value of edges used as the graph
        :type max_fil_val: float
        :return: fitted clustering object
        :rtype: gudhi.clustering.tomato.Tomato
        :raises AssertionError: if no simplex tree is available
        :raises RuntimeError: if persistence betti numbers are not computed
        """
        self._require_simplex_tree()
        # Check before the (potentially expensive) fit rather than failing
        # afterwards with an opaque TypeError/IndexError on ``betti[0]``.
        if self.betti is None or len(self.betti) == 0:
            raise RuntimeError(
                "betti numbers are not available: call "
                "get_persistence_betti() with dim >= 1 before tomato()")

        tomato_clustering = Tomato(graph_type='manual', density_type='manual')
        tomato_clustering.fit(
            self.get_adjacency_list(max_fil_val),
            weights=self.density_class.dtm_density(self.points))
        tomato_clustering.n_clusters_ = self.betti[0]
        return tomato_clustering