# Source code for galaxywitness.witness_complex

import math
import os

import multiprocessing as mp
from joblib import Parallel, delayed
from joblib import dump

import numpy as np


from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import floyd_warshall

# gudhi is needed to construct a simplex tree and to plot the persistence diagram.
import gudhi

from sklearn.metrics import pairwise_distances

from galaxywitness.base_complex import BaseComplex

# Hard-coded plotting constants.

# Cap on the number of witness points to scatter-plot.
MAX_N_PLOT = 10000

# How many snapshots of the filtration the animation helpers render.
NUMBER_OF_FRAMES = 6


class WitnessComplex(BaseComplex):
    r"""
    Main class for handling data about the point cloud and the simplex tree
    of a filtered witness complex.

    :param landmarks: set of landmarks in :math:`\mathbb{R}^d`.
    :type landmarks: np.array of shape *n_landmarks x d*
    :param witnesses: set of witnesses in :math:`\mathbb{R}^d`.
    :type witnesses: np.array of shape *n_witnesses x d*
    :param landmarks_idxs: indices of landmarks in witnesses array
    :type landmarks_idxs: np.array[int]
    :param n_jobs: number of jobs passed to ``pairwise_distances``
    :type n_jobs: int
    :param isomap_eps: if positive, witness-to-landmark distances are replaced
        by shortest-path distances in a graph whose edges are the
        witness-landmark pairs closer than ``isomap_eps``
    :type isomap_eps: float
    """

    def __init__(self, landmarks, witnesses, landmarks_idxs, n_jobs=-1, isomap_eps=0):
        """
        Constructor: store the point clouds and compute ``self.distances``,
        whose entry ``[i, j]`` is the distance from witness ``i`` to landmark ``j``.
        """
        super().__init__(landmarks)
        self.landmarks = landmarks
        self.witnesses = witnesses
        self.landmarks_idxs = landmarks_idxs
        self.distances = pairwise_distances(witnesses, landmarks, n_jobs=n_jobs)

        if isomap_eps > 0:
            landmark_nodes = np.asarray(landmarks_idxs, dtype=np.intp)
            self.distances_isomap = self._isomap_geodesics(isomap_eps, landmark_nodes)
            # Overwrite in place (keeps the dtype returned by pairwise_distances):
            # column j becomes the geodesic distance to node landmarks_idxs[j].
            self.distances[:, :] = self.distances_isomap[:, landmark_nodes]

    def _isomap_geodesics(self, isomap_eps, landmark_nodes):
        """
        All-pairs shortest-path distances on the isomap proximity graph.

        The graph has one node per witness. Witness ``i`` is joined to node
        ``landmark_nodes[j]`` (landmark ``j``) when
        ``self.distances[i, j] < isomap_eps``; the edge length is that distance.

        NOTE(review): only witness-landmark pairs become edges (no
        witness-witness distances are computed), so geodesics alternate
        between witnesses and landmarks -- confirm this is the intended graph.

        :param isomap_eps: strict upper bound on edge lengths
        :param landmark_nodes: integer array, ``landmarks_idxs`` as indices
        :returns: dense array of shape *n_witnesses x n_witnesses*; pairs with
            no connecting path get ``inf`` (``floyd_warshall`` convention)
        """
        n_witnesses = self.distances.shape[0]
        witness_ids, landmark_cols = np.nonzero(self.distances < isomap_eps)
        graph = np.zeros((n_witnesses, n_witnesses))
        graph[witness_ids, landmark_nodes[landmark_cols]] = self.distances[witness_ids, landmark_cols]
        # csr_matrix keeps only non-zero entries, so zero distances (e.g. a
        # landmark to itself) are not edges.
        return floyd_warshall(csgraph=csr_matrix(graph), directed=False)

    def compute_simplicial_complex(self, d_max, r_max=None, n_jobs=1, custom=False):
        """
        Compute filtered simplicial complex

        :param d_max: max dimension of simplices in the simplex tree
        :type d_max: int
        :param r_max: max filtration value; ``None`` means no limit
        :type r_max: float
        :param n_jobs: number of jobs for the custom construction (``1`` runs
            in-process, negative values follow joblib's convention)
        :type n_jobs: int
        :param custom: use this class's own witness complex construction
            instead of ``gudhi.EuclideanStrongWitnessComplex``
        :type custom: bool
        """
        if custom:
            if n_jobs == 1:
                self.compute_simplicial_complex_single(d_max=d_max, r_max=r_max)
            else:
                self.compute_simplicial_complex_parallel(d_max=d_max, r_max=r_max, n_jobs=n_jobs)
            return

        # GUDHI's threshold is ``max_alpha_square``, hence r_max is squared;
        # r_max=None (no limit) maps to an infinite threshold.
        max_alpha_square = math.inf if r_max is None else r_max ** 2
        witness_complex = gudhi.EuclideanStrongWitnessComplex(witnesses=self.witnesses,
                                                              landmarks=self.landmarks)
        self.simplex_tree = witness_complex.create_simplex_tree(max_alpha_square=max_alpha_square,
                                                                limit_dimension=d_max)
        self.simplex_tree_computed = True

    def _update_register_simplex(self, simplicial_complex_temp, i_add, i_dist):
        """
        Edges joining ``i_add`` to every other registered vertex.

        For every entry ``[[v], ...]`` of ``simplicial_complex_temp`` with
        ``v != i_add``, returns ``[[v, i_add], i_dist]``. Not used by the
        construction methods of this class.
        """
        return [
            [[*entry[0], i_add], i_dist]
            for entry in simplicial_complex_temp
            if entry[0][0] != i_add and len(entry[0]) == 1
        ]

    @staticmethod
    def _keep_min_filtration(simplices, simplex, value):
        """Store ``value`` for ``simplex`` unless an equal or smaller value is stored."""
        current = simplices.get(simplex)
        if current is None or value < current:
            simplices[simplex] = value

    @classmethod
    def _collect_witnessed_simplices(cls, distances, r_max=None):
        """
        Vertices and edges witnessed by the rows (witnesses) of ``distances``.

        For one witness, the landmarks at distance ``< r_max`` (all of them
        when ``r_max`` is None) are ordered by distance. Each becomes a vertex
        whose filtration is its distance, and each pair becomes an edge whose
        filtration is the larger of the two distances.

        :param distances: 2-D array, one row per witness, one column per landmark
        :param r_max: strict upper bound on the distances taken into account
        :returns: dict mapping sorted landmark-index tuples to the smallest
            filtration value over all witnesses
        """
        simplices = {}
        for row in distances:
            # Stable sort: ties keep increasing landmark index, as before.
            order = np.argsort(row, kind='stable')
            if r_max is not None:
                order = order[row[order] < r_max]
            witnessed = order.tolist()
            values = row[order].tolist()
            for position, (landmark, value) in enumerate(zip(witnessed, values)):
                cls._keep_min_filtration(simplices, (landmark,), value)
                # Every earlier landmark is at most as far, so ``value`` is
                # the edge's filtration (the larger of the two distances).
                for earlier in witnessed[:position]:
                    edge = (earlier, landmark) if earlier < landmark else (landmark, earlier)
                    cls._keep_min_filtration(simplices, edge, value)
        return simplices

    def _store_simplex_tree(self, simplices, d_max):
        """
        Insert ``simplices`` (simplex tuple -> filtration) into a new
        ``gudhi.SimplexTree``, expand it up to dimension ``d_max`` and store it
        in ``self.simplex_tree``.
        """
        simplex_tree = gudhi.SimplexTree()
        # Increasing filtration, vertices before edges on ties: every face is
        # inserted before its cofaces with a value not above theirs.
        for simplex, value in sorted(simplices.items(), key=lambda item: (item[1], len(item[0]))):
            simplex_tree.insert(list(simplex), filtration=value)
        simplex_tree.expansion(d_max)
        self.simplex_tree = simplex_tree
        self.simplex_tree_computed = True

    def compute_simplicial_complex_single(self, d_max, r_max=None):
        """
        Build the custom witness complex in the current process.

        :param d_max: dimension passed to ``SimplexTree.expansion``
        :type d_max: int
        :param r_max: only distances strictly below this value are used;
            ``None`` keeps all of them
        :type r_max: float
        """
        simplices = self._collect_witnessed_simplices(self.distances, r_max)
        self._store_simplex_tree(simplices, d_max)

    def compute_simplicial_complex_parallel(self, d_max=math.inf, r_max=math.inf, n_jobs=-1):
        """
        Build the custom witness complex with witnesses split across joblib jobs.

        NOTE(review): gudhi's ``expansion`` expects an integer dimension; the
        ``math.inf`` default for ``d_max`` may be rejected -- confirm.

        :param d_max: dimension passed to ``SimplexTree.expansion``
        :type d_max: int
        :param r_max: only distances strictly below this value are used;
            ``None`` keeps all of them
        :type r_max: float
        :param n_jobs: number of jobs; negative values mean
            ``cpu_count() + 1 + n_jobs`` (so ``-1`` uses every CPU)
        :type n_jobs: int
        """
        if n_jobs < 0:
            n_jobs = max(1, mp.cpu_count() + 1 + n_jobs)

        chunks = np.array_split(self.distances, n_jobs)
        # Each job receives only its own chunk; joblib memory-maps large
        # numpy arguments for its workers on its own.
        partials = Parallel(n_jobs=n_jobs)(
            delayed(self._collect_witnessed_simplices)(chunk, r_max) for chunk in chunks
        )

        simplices = {}
        for partial in partials:
            for simplex, value in partial.items():
                self._keep_min_filtration(simplices, simplex, value)
        self._store_simplex_tree(simplices, d_max)

    def _animate_filtration(self, path_to_save, backend):
        """
        Draw ``NUMBER_OF_FRAMES`` snapshots of the filtration at evenly spaced
        values up to the maximal filtration value.

        :raises RuntimeError: if the simplex tree has not been computed
        """
        if not getattr(self, 'simplex_tree_computed', False):
            raise RuntimeError('simplex tree is not computed; call compute_simplicial_complex first')
        max_filtration = max((value for _, value in self.simplex_tree.get_filtration()), default=0.0)
        for num in range(1, NUMBER_OF_FRAMES + 1):
            # num / NUMBER_OF_FRAMES is exactly 1.0 on the last frame, so that
            # frame uses the maximal filtration value itself.
            self.draw_simplicial_complex(num, max_filtration * (num / NUMBER_OF_FRAMES),
                                         backend, path_to_save)

    def animate_simplex_tree(self, path_to_save):
        """
        Draw animation of filtration (powered by matplotlib)

        :param path_to_save: place, where we are saving files
        :type path_to_save: str
        """
        self._animate_filtration(path_to_save, 'mpl')

    def animate_simplex_tree_plotly(self, path_to_save):
        """
        Draw animation of filtration (powered by plotly)

        :param path_to_save: place, where we are saving files
        :type path_to_save: str
        """
        self._animate_filtration(path_to_save, 'plotly')
# def tomato(self, max_fil_val=10): # """ # ToMATo clustering with automatic choice of number of clusters. # Hence, clustering depends on filtered complex construction and # max value of filtration. # """ # t = Tomato(density_type = 'manual', graph_type='manual') # t.fit(self.get_adjacency_list(max_fil_val), weights=self.density_class.foo(self.points)) # t.n_clusters_ = self.betti[0] # return t