diff --git a/src/pg/sql/11_kmeans.sql b/src/pg/sql/11_kmeans.sql index 59fcf59..2db57e0 100644 --- a/src/pg/sql/11_kmeans.sql +++ b/src/pg/sql/11_kmeans.sql @@ -11,7 +11,7 @@ $$ LANGUAGE plpythonu; -- Non-spatial k-means clustering -- query: sql query to retrieve all the needed data -CREATE OR REPLACE FUNCTION CDB_KMeansNonspatial(query TEXT, col_names TEXT[], no_clusters INTEGER, id_col TEXT DEFAULT 'cartodb_id') +CREATE OR REPLACE FUNCTION CDB_KMeansNonspatial(query TEXT, colnames TEXT[], num_clusters INTEGER, id_col TEXT DEFAULT 'cartodb_id') RETURNS TABLE(cluster_label text, cluster_center text, rowid bigint) AS $$ from crankshaft.clustering import kmeans_nonspatial diff --git a/src/py/crankshaft/crankshaft/clustering/kmeans.py b/src/py/crankshaft/crankshaft/clustering/kmeans.py index ee2f304..091e87b 100644 --- a/src/py/crankshaft/crankshaft/clustering/kmeans.py +++ b/src/py/crankshaft/crankshaft/clustering/kmeans.py @@ -38,7 +38,7 @@ def kmeans_nonspatial(query, colnames, num_clusters=5, id_col='cartodb_id'): num_clusters (int): number of clusters (greater than zero) id_col (string): name of the input id_column """ - + import numpy as np id_colname = 'rowids' full_query = ''' @@ -55,13 +55,14 @@ def kmeans_nonspatial(query, colnames, num_clusters=5, id_col='cartodb_id'): plpy.notice('query: %s' % full_query) # fill array with values for kmeans clustering - data = np.array([d[c] for c in d if c != 'id_colname'], - dtype=float).T + cluster_columns = np.array([data[0][c] for c in data.colnames() + if c != 'id_colname'], + dtype=float).T except plpy.SPIError, err: plpy.error('KMeans cluster failed: %s' % err) - kmeans = KMeans(n_clusters=num_clusters, random_state=0).fit(data) + kmeans = KMeans(n_clusters=num_clusters, random_state=0).fit(cluster_columns) # zip(ids, labels, means) return zip(kmeans.labels_, map(str, kmeans.cluster_centers_), - d[0]['rowids']) + data[0]['rowids'])