From 62a3bd3107365b63fee2617eeba4f2a6589d45c3 Mon Sep 17 00:00:00 2001 From: Roman Joeres Date: Sat, 16 Mar 2024 00:17:51 +0100 Subject: [PATCH] Bug fix in agglomerative clustering due to new version restrictions --- datasail/cluster/clustering.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/datasail/cluster/clustering.py b/datasail/cluster/clustering.py index 66c4272..d5f4b6e 100644 --- a/datasail/cluster/clustering.py +++ b/datasail/cluster/clustering.py @@ -1,6 +1,7 @@ from typing import Dict, Tuple, List, Union, Optional, Literal import numpy as np +import sklearn from sklearn.cluster import AgglomerativeClustering, SpectralClustering from datasail.cluster.caching import load_from_cache, store_to_cache @@ -194,11 +195,18 @@ def additional_clustering( ) else: cluster_matrix = np.array(dataset.cluster_distance, dtype=float) - ca = AgglomerativeClustering( - n_clusters=n_clusters, - affinity="precomputed", - linkage=linkage, - ) + if sklearn.__version__ < '1.4': + ca = AgglomerativeClustering( + n_clusters=n_clusters, + affinity="precomputed", + linkage=linkage, + ) + else: + ca = AgglomerativeClustering( + n_clusters=n_clusters, + metric="precomputed", + linkage=linkage, + ) LOGGER.info( f"Clustering based on distances. " f"Distances above {np.average(dataset.cluster_distance) * 0.9} cannot end up in same cluster."