diff --git a/src/py/crankshaft/crankshaft/clustering/kmeans.py b/src/py/crankshaft/crankshaft/clustering/kmeans.py index ac0ce4d..84f83f7 100644 --- a/src/py/crankshaft/crankshaft/clustering/kmeans.py +++ b/src/py/crankshaft/crankshaft/clustering/kmeans.py @@ -1,5 +1,6 @@ from sklearn.cluster import KMeans import plpy +import numpy as np def kmeans(query, no_clusters, no_init=20): @@ -39,7 +40,6 @@ def kmeans_nonspatial(query, colnames, num_clusters=5, num_clusters (int): number of clusters (greater than zero) id_col (string): name of the input id_column """ - import numpy as np out_id_colname = 'rowids' # TODO: need a random seed? @@ -54,14 +54,13 @@ def kmeans_nonspatial(query, colnames, num_clusters=5, try: db_resp = plpy.execute(full_query) - plpy.notice('query: %s' % full_query) except plpy.SPIError, err: plpy.error('k-means cluster analysis failed: %s' % err) # fill array with values for kmeans clustering if standarize: cluster_columns = _scale_data( - _extract_columns(db_resp, id_col='cartodb_id')) + _extract_columns(db_resp, id_col=out_id_colname)) else: cluster_columns = _extract_columns(db_resp) @@ -70,7 +69,7 @@ def kmeans_nonspatial(query, colnames, num_clusters=5, kmeans = KMeans(n_clusters=num_clusters, random_state=0).fit(cluster_columns) - return zip(kmeans.predict(X), + return zip(kmeans.labels_, map(str, kmeans.cluster_centers_[kmeans.labels_]), db_resp[0][out_id_colname])