fixes syntax errors

This commit is contained in:
Andy Eschbacher 2016-10-12 21:13:51 +00:00
parent c47116571f
commit 361505fca9
2 changed files with 7 additions and 6 deletions

View File

@ -11,7 +11,7 @@ $$ LANGUAGE plpythonu;
-- Non-spatial k-means clustering -- Non-spatial k-means clustering
-- query: sql query to retrieve all the needed data -- query: sql query to retrieve all the needed data
CREATE OR REPLACE FUNCTION CDB_KMeansNonspatial(query TEXT, col_names TEXT[], no_clusters INTEGER, id_col TEXT DEFAULT 'cartodb_id') CREATE OR REPLACE FUNCTION CDB_KMeansNonspatial(query TEXT, colnames TEXT[], num_clusters INTEGER, id_col TEXT DEFAULT 'cartodb_id')
RETURNS TABLE(cluster_label text, cluster_center text, rowid bigint) AS $$ RETURNS TABLE(cluster_label text, cluster_center text, rowid bigint) AS $$
from crankshaft.clustering import kmeans_nonspatial from crankshaft.clustering import kmeans_nonspatial

View File

@ -38,7 +38,7 @@ def kmeans_nonspatial(query, colnames, num_clusters=5, id_col='cartodb_id'):
num_clusters (int): number of clusters (greater than zero) num_clusters (int): number of clusters (greater than zero)
id_col (string): name of the input id_column id_col (string): name of the input id_column
""" """
import numpy as np
id_colname = 'rowids' id_colname = 'rowids'
full_query = ''' full_query = '''
@ -55,13 +55,14 @@ def kmeans_nonspatial(query, colnames, num_clusters=5, id_col='cartodb_id'):
plpy.notice('query: %s' % full_query) plpy.notice('query: %s' % full_query)
# fill array with values for kmeans clustering # fill array with values for kmeans clustering
data = np.array([d[c] for c in d if c != 'id_colname'], cluster_columns = np.array([data[0][c] for c in data.colnames()
dtype=float).T if c != 'id_colname'],
dtype=float).T
except plpy.SPIError, err: except plpy.SPIError, err:
plpy.error('KMeans cluster failed: %s' % err) plpy.error('KMeans cluster failed: %s' % err)
kmeans = KMeans(n_clusters=num_clusters, random_state=0).fit(data) kmeans = KMeans(n_clusters=num_clusters, random_state=0).fit(cluster_columns)
# zip(ids, labels, means) # zip(ids, labels, means)
return zip(kmeans.labels_, map(str, kmeans.cluster_centers_), return zip(kmeans.labels_, map(str, kmeans.cluster_centers_),
d[0]['rowids']) data[0]['rowids'])