diff --git a/src/pg/sql/11_kmeans.sql b/src/pg/sql/11_kmeans.sql index 6a9d1a9..175ab6b 100644 --- a/src/pg/sql/11_kmeans.sql +++ b/src/pg/sql/11_kmeans.sql @@ -15,14 +15,14 @@ CREATE OR REPLACE FUNCTION CDB_KMeansNonspatial( query TEXT, colnames TEXT[], num_clusters INTEGER, - id_col TEXT DEFAULT 'cartodb_id', + id_colname TEXT DEFAULT 'cartodb_id', standarize BOOLEAN DEFAULT true ) -RETURNS TABLE(cluster_label text, cluster_center text, rowid bigint) AS $$ +RETURNS TABLE(cluster_label text, cluster_center json, rowid bigint) AS $$ from crankshaft.clustering import kmeans_nonspatial return kmeans_nonspatial(query, colnames, num_clusters, - id_col, standarize) + id_colname, standarize) $$ LANGUAGE plpythonu; diff --git a/src/py/crankshaft/crankshaft/clustering/kmeans.py b/src/py/crankshaft/crankshaft/clustering/kmeans.py index 6e972e5..5bd7830 100644 --- a/src/py/crankshaft/crankshaft/clustering/kmeans.py +++ b/src/py/crankshaft/crankshaft/clustering/kmeans.py @@ -40,6 +40,7 @@ def kmeans_nonspatial(query, colnames, num_clusters=5, num_clusters (int): number of clusters (greater than zero) id_col (string): name of the input id_column """ + import json out_id_colname = 'rowids' # TODO: need a random seed? @@ -60,7 +61,7 @@ def kmeans_nonspatial(query, colnames, num_clusters=5, # fill array with values for k-means clustering if standarize: cluster_columns = _scale_data( - _extract_columns(db_resp, id_col=out_id_colname)) + _extract_columns(db_resp, out_id_colname)) else: cluster_columns = _extract_columns(db_resp, id_col=out_id_colname) @@ -69,8 +70,9 @@ def kmeans_nonspatial(query, colnames, num_clusters=5, kmeans = KMeans(n_clusters=num_clusters, random_state=0).fit(cluster_columns) + centers = [json.dumps(dict(zip(colnames, c))) for c in kmeans.cluster_centers_[kmeans.labels_]] return zip(kmeans.labels_, - map(str, kmeans.cluster_centers_[kmeans.labels_]), + centers, db_resp[0][out_id_colname])