diff --git a/src/pg/sql/11_kmeans.sql b/src/pg/sql/11_kmeans.sql index 175ab6b..c9ae131 100644 --- a/src/pg/sql/11_kmeans.sql +++ b/src/pg/sql/11_kmeans.sql @@ -18,7 +18,7 @@ CREATE OR REPLACE FUNCTION CDB_KMeansNonspatial( id_colname TEXT DEFAULT 'cartodb_id', standarize BOOLEAN DEFAULT true ) -RETURNS TABLE(cluster_label text, cluster_center json, rowid bigint) AS $$ +RETURNS TABLE(cluster_label text, cluster_center json, silhouettes numeric, rowid bigint) AS $$ from crankshaft.clustering import kmeans_nonspatial return kmeans_nonspatial(query, colnames, num_clusters, diff --git a/src/py/crankshaft/crankshaft/clustering/kmeans.py b/src/py/crankshaft/crankshaft/clustering/kmeans.py index 52139d1..d070bf9 100644 --- a/src/py/crankshaft/crankshaft/clustering/kmeans.py +++ b/src/py/crankshaft/crankshaft/clustering/kmeans.py @@ -76,11 +76,12 @@ def kmeans_nonspatial(query, colnames, num_clusters=5, for c in kmeans.cluster_centers_[kmeans.labels_]] silhouettes = metrics.silhouette_samples(cluster_columns, - labels, + kmeans.labels_, metric='sqeuclidean') return zip(kmeans.labels_, centers, + silhouettes, db_resp[0][out_id_colname])