# Source code for grakel.kernels.weisfeiler_lehman

"""The weisfeiler lehman kernel :cite:`shervashidze2011weisfeiler`."""
# Author: Ioannis Siglidis <y.siglidis@gmail.com>
# License: BSD 3 clause
import collections
import warnings

try:
    from collections.abc import Iterable
except ImportError:  # Python 2 fallback
    from collections import Iterable

import numpy as np

from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted
from sklearn.externals import joblib

from grakel.graph import Graph
from grakel.kernels import Kernel
from grakel.kernels.vertex_histogram import VertexHistogram

# Python 2/3 cross-compatibility import
from six import iteritems
from six import itervalues


class WeisfeilerLehman(Kernel):
    """Compute the Weisfeiler Lehman Kernel.

    See :cite:`shervashidze2011weisfeiler`.

    Parameters
    ----------
    n_iter : int, default=5
        The number of iterations.

    base_graph_kernel : `grakel.kernels.Kernel` or tuple, default=VertexHistogram
        If tuple it must consist of a valid kernel class and a
        dictionary of parameters. The ``normalize``, ``verbose`` and
        ``n_jobs`` entries of that dictionary are overridden on
        initialization. ``None`` is accepted and falls back to
        `VertexHistogram`.

    Attributes
    ----------
    X : dict
        Holds a dictionary of fitted subkernel modules for all levels.

    _nx : number
        Holds the number of inputs.

    _n_iter : int
        Holds ``n_iter + 1``: the number of relabeling levels, level 0
        being the original labels.

    _base_graph_kernel : type
        The base kernel class, instantiated once per level with `_params`.

    _params : dict
        Keyword arguments passed to every base kernel instance.

    _inv_labels : dict
        An inverse dictionary, used for relabeling on each iteration.

    """

    _graph_format = "dictionary"

    def __init__(self, n_jobs=None, verbose=False, normalize=False,
                 n_iter=5, base_graph_kernel=VertexHistogram):
        """Initialise a `weisfeiler_lehman` kernel."""
        super(WeisfeilerLehman, self).__init__(
            n_jobs=n_jobs, verbose=verbose, normalize=normalize)

        self.n_iter = n_iter
        self.base_graph_kernel = base_graph_kernel
        self._initialized.update({"n_iter": False,
                                  "base_graph_kernel": False})
        self._base_graph_kernel = None

    def initialize(self):
        """Initialize all transformer arguments, needing initialization.

        Validates `base_graph_kernel` and `n_iter` and stores their parsed
        form in `_base_graph_kernel`, `_params` and `_n_iter`.

        Raises
        ------
        TypeError
            If `base_graph_kernel` is neither ``None``, a `Kernel` subclass
            nor a ``(Kernel subclass, params)`` pair, or if `n_iter` is not
            a positive ``int``.
        ValueError
            If the second element of a `base_graph_kernel` pair is not a
            ``dict``.

        """
        super(WeisfeilerLehman, self).initialize()
        if not self._initialized["base_graph_kernel"]:
            base_graph_kernel = self.base_graph_kernel
            if base_graph_kernel is None:
                base_graph_kernel, params = VertexHistogram, dict()
            elif (type(base_graph_kernel) is type and
                    issubclass(base_graph_kernel, Kernel)):
                params = dict()
            else:
                try:
                    base_graph_kernel, params = base_graph_kernel
                except Exception:
                    raise TypeError('Base kernel was not formulated in '
                                    'the correct way. '
                                    'Check documentation.')

                if not (type(base_graph_kernel) is type and
                        issubclass(base_graph_kernel, Kernel)):
                    raise TypeError('The first argument must be a valid '
                                    'grakel.kernel.kernel Object')
                if type(params) is not dict:
                    raise ValueError('If the second argument of base '
                                     'kernel exists, it must be a diction'
                                     'ary between parameters names and '
                                     'values')
                # Work on a copy: the dictionary belongs to the caller (it
                # is part of the public `base_graph_kernel` parameter) and
                # must not be altered by the overrides below.
                params = dict(params)

            # General parameters are dictated by this kernel, not by the
            # user supplied dictionary.
            params["normalize"] = False
            params["verbose"] = self.verbose
            params["n_jobs"] = None
            self._base_graph_kernel = base_graph_kernel
            self._params = params
            self._initialized["base_graph_kernel"] = True

        if not self._initialized["n_iter"]:
            if type(self.n_iter) is not int or self.n_iter <= 0:
                raise TypeError("'n_iter' must be a positive integer")
            # One extra level for the original (iteration 0) labels.
            self._n_iter = self.n_iter + 1
            self._initialized["n_iter"] = True

    def parse_input(self, X):
        """Parse input for weisfeiler lehman.

        Parameters
        ----------
        X : iterable
            For the input to pass the test, we must have:
            Each element must be an iterable with at most three features and
            at least one. The first that is obligatory is a valid graph
            structure (adjacency matrix or edge_dictionary) while the second
            is node_labels and the third edge_labels (that correspond to the
            given graph format). A valid input also consists of graph type
            objects.

        Returns
        -------
        base_graph_kernel : dict
            When called from fit: the fitted base kernels, keyed by level.

        K, base_graph_kernel : tuple
            When called from fit_transform: the sum of the kernel matrices
            of all levels, followed by the fitted base kernels.

        Raises
        ------
        ValueError
            If not called from fit or fit_transform, or if no graph was
            parsed.
        TypeError
            If `X` is not iterable or one of its elements is invalid.

        """
        if self._method_calling not in [1, 2]:
            raise ValueError('method call must be called either from fit ' +
                             'or fit-transform')
        elif hasattr(self, '_X_diag'):
            # Clean _X_diag value
            delattr(self, '_X_diag')

        # Input validation and parsing
        if not isinstance(X, Iterable):
            raise TypeError('input must be an iterable\n')

        nx = 0
        Gs_ed, L, distinct_values, extras = dict(), dict(), set(), dict()
        for (idx, x) in enumerate(iter(X)):
            is_iter = isinstance(x, Iterable)
            if is_iter:
                x = list(x)
            if is_iter and (len(x) == 0 or len(x) >= 2):
                if len(x) == 0:
                    warnings.warn('Ignoring empty element on index: '
                                  + str(idx))
                    continue
                elif len(x) > 2:
                    # Anything past the edge labels is forwarded untouched
                    # to the base kernel.
                    extra = tuple(x[3:]) if len(x) > 3 else tuple()
                    x = Graph(x[0], x[1], x[2],
                              graph_format=self._graph_format)
                    extra = (x.get_labels(purpose=self._graph_format,
                                          label_type="edge",
                                          return_none=True), ) + extra
                else:
                    x = Graph(x[0], x[1], {},
                              graph_format=self._graph_format)
                    extra = tuple()
            elif type(x) is Graph:
                x.desired_format(self._graph_format)
                el = x.get_labels(purpose=self._graph_format,
                                  label_type="edge", return_none=True)
                extra = tuple() if el is None else (el, )
            else:
                raise TypeError('each element of X must be either a ' +
                                'graph object or a list with at least ' +
                                'a graph like object and node labels ' +
                                'dict \n')
            Gs_ed[nx] = x.get_edge_dictionary()
            L[nx] = x.get_labels(purpose="dictionary")
            extras[nx] = extra
            distinct_values |= set(itervalues(L[nx]))
            nx += 1
        if nx == 0:
            raise ValueError('parsed input is empty')

        # Save the number of "fitted" graphs.
        self._nx = nx

        # Assign a consecutive number to each distinct original label.
        WL_labels_inverse = dict()
        label_count = 0
        for dv in sorted(distinct_values):
            WL_labels_inverse[dv] = label_count
            label_count += 1

        # Initalize an inverse dictionary of labels for all iterations
        self._inv_labels = dict()
        self._inv_labels[0] = WL_labels_inverse

        def generate_graphs(label_count, WL_labels_inverse):
            """Yield, level by level, the relabeled graphs of the input."""
            # Level 0: replace original labels by their numbers.
            new_graphs = list()
            for j in range(nx):
                new_labels = {k: WL_labels_inverse[v]
                              for (k, v) in iteritems(L[j])}
                L[j] = new_labels
                new_graphs.append((Gs_ed[j], new_labels) + extras[j])
            yield new_graphs

            for i in range(1, self._n_iter):
                label_set, WL_labels_inverse, L_temp = set(), dict(), dict()
                for j in range(nx):
                    # The temporary label of a vertex is its own label
                    # followed by the sorted labels of its neighbors.
                    L_temp[j] = dict()
                    for v in Gs_ed[j].keys():
                        credential = str(L[j][v]) + "," + \
                            str(sorted([L[j][n] for n in Gs_ed[j][v].keys()]))
                        L_temp[j][v] = credential
                        label_set.add(credential)

                # Numbering continues from the previous level, so labels
                # of different levels never coincide.
                for dv in sorted(label_set):
                    WL_labels_inverse[dv] = label_count
                    label_count += 1

                # Recalculate labels
                new_graphs = list()
                for j in range(nx):
                    new_labels = {k: WL_labels_inverse[v]
                                  for (k, v) in iteritems(L_temp[j])}
                    L[j] = new_labels
                    new_graphs.append((Gs_ed[j], new_labels) + extras[j])
                self._inv_labels[i] = WL_labels_inverse
                yield new_graphs

        base_graph_kernel = {i: self._base_graph_kernel(**self._params)
                             for i in range(self._n_iter)}
        levels = enumerate(generate_graphs(label_count, WL_labels_inverse))
        if self._parallel is None:
            if self._method_calling == 1:
                for (i, g) in levels:
                    base_graph_kernel[i].fit(g)
            elif self._method_calling == 2:
                # A list, not a generator: np.sum on a generator is a
                # deprecated code path that ignores `axis`.
                K = np.sum([base_graph_kernel[i].fit_transform(g)
                            for (i, g) in levels], axis=0)
        else:
            # NOTE(review): efit/efit_transform fit the kernels in place, so
            # this presumably relies on a shared-memory (thread) backend for
            # `self._parallel` -- confirm against Kernel.initialize.
            if self._method_calling == 1:
                self._parallel(joblib.delayed(efit)(base_graph_kernel[i], g)
                               for (i, g) in levels)
            elif self._method_calling == 2:
                K = np.sum(self._parallel(
                    joblib.delayed(efit_transform)(base_graph_kernel[i], g)
                    for (i, g) in levels), axis=0)

        if self._method_calling == 1:
            return base_graph_kernel
        elif self._method_calling == 2:
            return K, base_graph_kernel

    def fit_transform(self, X, y=None):
        """Fit and transform, on the same dataset.

        Parameters
        ----------
        X : iterable
            Each element must be an iterable with at most three features and
            at least one. The first that is obligatory is a valid graph
            structure (adjacency matrix or edge_dictionary) while the second
            is node_labels and the third edge_labels (that fitting the given
            graph format).

        y : Object, default=None
            Ignored argument, added for the pipeline.

        Returns
        -------
        K : numpy array, shape = [n_input_graphs, n_input_graphs]
            The kernel matrix between all pairs of input graphs, divided by
            the square root of the outer product of its diagonal when
            `normalize` is set.

        Raises
        ------
        ValueError
            If `X` is None.

        """
        self._method_calling = 2
        self._is_transformed = False
        self.initialize()
        if X is None:
            raise ValueError('transform input cannot be None')
        km, self.X = self.parse_input(X)

        self._X_diag = np.diagonal(km)
        if self.normalize:
            # errstate restores the floating point settings even if the
            # division raises.
            with np.errstate(divide='ignore'):
                km = np.nan_to_num(np.divide(
                    km, np.sqrt(np.outer(self._X_diag, self._X_diag))))
        return km

    def transform(self, X):
        """Calculate the kernel matrix, between given and fitted dataset.

        Parameters
        ----------
        X : iterable
            Each element must be an iterable with at most three features and
            at least one. The first that is obligatory is a valid graph
            structure (adjacency matrix or edge_dictionary) while the second
            is node_labels and the third edge_labels (that fitting the given
            graph format). Edge labels are not used here: every graph is
            rebuilt with an empty edge-label dictionary.

        Returns
        -------
        K : numpy array, shape = [n_targets, n_input_graphs]
            corresponding to the kernel matrix, a calculation between
            all pairs of graphs between target an features

        Raises
        ------
        ValueError
            If `X` is None, not iterable, holds an invalid element or no
            graph was parsed.

        """
        self._method_calling = 3
        # Check is fit had been called
        check_is_fitted(self, ['X', '_nx', '_inv_labels'])

        # Input validation and parsing
        if X is None:
            raise ValueError('transform input cannot be None')
        if not isinstance(X, Iterable):
            raise ValueError('input must be an iterable\n')

        nx = 0
        distinct_values = set()
        Gs_ed, L = dict(), dict()
        for (i, x) in enumerate(iter(X)):
            is_iter = isinstance(x, Iterable)
            if is_iter:
                x = list(x)
            if is_iter and len(x) in [0, 2, 3]:
                if len(x) == 0:
                    warnings.warn('Ignoring empty element on index: '
                                  + str(i))
                    continue
                x = Graph(x[0], x[1], {}, self._graph_format)
            elif type(x) is Graph:
                x.desired_format("dictionary")
            else:
                raise ValueError('each element of X must have at ' +
                                 'least one and at most 3 elements\n')
            Gs_ed[nx] = x.get_edge_dictionary()
            L[nx] = x.get_labels(purpose="dictionary")

            # Hold all the distinct values that were not seen on fit
            distinct_values |= set(
                v for v in itervalues(L[nx])
                if v not in self._inv_labels[0])
            nx += 1
        if nx == 0:
            raise ValueError('parsed input is empty')

        # Unseen labels are numbered after the fitted ones of their level.
        nl = len(self._inv_labels[0])
        WL_labels_inverse = {dv: idx for (idx, dv) in
                             enumerate(sorted(distinct_values), nl)}

        def generate_graphs(WL_labels_inverse, nl):
            """Yield, level by level, the relabeled graphs of the input."""
            # Level 0: fitted labels keep their number, new ones get theirs.
            new_graphs = list()
            for j in range(nx):
                new_labels = dict()
                for (k, v) in iteritems(L[j]):
                    if v in self._inv_labels[0]:
                        new_labels[k] = self._inv_labels[0][v]
                    else:
                        new_labels[k] = WL_labels_inverse[v]
                L[j] = new_labels
                new_graphs.append([Gs_ed[j], new_labels])
            yield new_graphs

            for i in range(1, self._n_iter):
                L_temp, label_set = dict(), set()
                nl += len(self._inv_labels[i])
                for j in range(nx):
                    # The temporary label of a vertex is its own label
                    # followed by the sorted labels of its neighbors.
                    L_temp[j] = dict()
                    for v in Gs_ed[j].keys():
                        credential = str(L[j][v]) + "," + \
                            str(sorted([L[j][n] for n in Gs_ed[j][v].keys()]))
                        L_temp[j][v] = credential
                        if credential not in self._inv_labels[i]:
                            label_set.add(credential)

                # Number the labels that were not seen on fit
                WL_labels_inverse = dict()
                for dv in sorted(label_set):
                    WL_labels_inverse[dv] = len(WL_labels_inverse) + nl

                # Recalculate labels
                new_graphs = list()
                for j in range(nx):
                    new_labels = dict()
                    for (k, v) in iteritems(L_temp[j]):
                        if v in self._inv_labels[i]:
                            new_labels[k] = self._inv_labels[i][v]
                        else:
                            new_labels[k] = WL_labels_inverse[v]
                    L[j] = new_labels
                    new_graphs.append([Gs_ed[j], new_labels])
                yield new_graphs

        levels = enumerate(generate_graphs(WL_labels_inverse, nl))
        if self._parallel is None:
            # Calculate the kernel matrix without parallelization.
            # A list, not a generator: np.sum on a generator is a
            # deprecated code path that ignores `axis`.
            K = np.sum([self.X[i].transform(g) for (i, g) in levels], axis=0)
        else:
            # Calculate the kernel matrix with parallelization
            K = np.sum(self._parallel(
                joblib.delayed(etransform)(self.X[i], g)
                for (i, g) in levels), axis=0)

        self._is_transformed = True
        if self.normalize:
            X_diag, Y_diag = self.diagonal()
            # errstate restores the floating point settings even if the
            # division raises.
            with np.errstate(divide='ignore'):
                K = np.nan_to_num(
                    np.divide(K, np.sqrt(np.outer(Y_diag, X_diag))))

        return K

    def diagonal(self):
        """Calculate the kernel matrix diagonal for fitted data.

        A funtion called on transform on a seperate dataset to apply
        normalization on the exterior. The diagonals are the sums of the
        diagonals of the base kernels of all levels.

        Parameters
        ----------
        None.

        Returns
        -------
        X_diag : np.array
            The diagonal of the kernel matrix, of the fitted data.
            This consists of kernel calculation for each element with itself.

        Y_diag : np.array
            The diagonal of the kernel matrix, of the transformed data.
            This consists of kernel calculation for each element with itself.
            Only returned if transform has been called.

        """
        # Check if fit had been called
        check_is_fitted(self, ['X'])

        # Sums are accumulated into copies (np.array copies by default), so
        # the arrays cached inside the base kernels are never modified.
        if hasattr(self, '_X_diag'):
            # The fitted diagonal is cached, only Y_diag may be needed.
            if self._is_transformed:
                Y_diag = np.array(self.X[0].diagonal()[1])
                for i in range(1, self._n_iter):
                    Y_diag += self.X[i].diagonal()[1]
        elif self._is_transformed:
            x, y = self.X[0].diagonal()
            X_diag, Y_diag = np.array(x), np.array(y)
            for i in range(1, self._n_iter):
                x, y = self.X[i].diagonal()
                X_diag += x
                Y_diag += y
            self._X_diag = X_diag
        else:
            # case sub kernel is only fitted
            X_diag = np.array(self.X[0].diagonal())
            for i in range(1, self._n_iter):
                X_diag += self.X[i].diagonal()
            self._X_diag = X_diag

        if self._is_transformed:
            return self._X_diag, Y_diag
        else:
            return self._X_diag
def efit(object, data):
    """Call ``fit`` of the given object on data, discarding its result."""
    fit = object.fit
    fit(data)


def efit_transform(object, data):
    """Call ``fit_transform`` of the given object on data and return it."""
    result = object.fit_transform(data)
    return result


def etransform(object, data):
    """Call ``transform`` of the given object on data and return it."""
    result = object.transform(data)
    return result