diff --git a/src/pg/sql/11_kmeans.sql b/src/pg/sql/11_kmeans.sql
index f20942f..4985c2f 100644
--- a/src/pg/sql/11_kmeans.sql
+++ b/src/pg/sql/11_kmeans.sql
@@ -1,21 +1,34 @@
-CREATE OR REPLACE FUNCTION CDB_KMeans(query text, no_clusters integer,no_init integer default 20)
-RETURNS table (cartodb_id integer, cluster_no integer) as $$
-
-    from crankshaft.clustering import kmeans
-    return kmeans(query,no_clusters,no_init)
+-- Spatial k-means clustering
 
-$$ language plpythonu;
+CREATE OR REPLACE FUNCTION CDB_KMeans(query text, no_clusters integer, no_init integer default 20)
+RETURNS table (cartodb_id integer, cluster_no integer) as $$
+
+    from crankshaft.clustering import kmeans
+    return kmeans(query, no_clusters, no_init)
+
+$$ LANGUAGE plpythonu;
+
+-- Non-spatial k-means clustering
+-- query: sql query to retrieve all the needed data
+
+CREATE OR REPLACE FUNCTION CDB_KMeansNonspatial(query TEXT, col_names TEXT[], no_clusters INTEGER, id_col TEXT DEFAULT 'cartodb_id')
+RETURNS TABLE(rowid BIGINT, cluster_no INTEGER) AS $$
+
+    from crankshaft.clustering import kmeans_nonspatial
+    return kmeans_nonspatial(query, col_names, no_clusters, id_col)
+
+$$ LANGUAGE plpythonu;
 
 
 CREATE OR REPLACE FUNCTION CDB_WeightedMeanS(state Numeric[],the_geom GEOMETRY(Point, 4326), weight NUMERIC)
-RETURNS Numeric[] AS 
+RETURNS Numeric[] AS
 $$
-DECLARE 
+DECLARE
     newX NUMERIC;
     newY NUMERIC;
     newW NUMERIC;
 BEGIN
-    IF weight IS NULL OR the_geom IS NULL THEN 
+    IF weight IS NULL OR the_geom IS NULL THEN
         newX = state[1];
         newY = state[2];
         newW = state[3];
@@ -30,12 +43,12 @@ END
 $$ LANGUAGE plpgsql;
 
 CREATE OR REPLACE FUNCTION CDB_WeightedMeanF(state Numeric[])
-RETURNS GEOMETRY AS 
+RETURNS GEOMETRY AS
 $$
 BEGIN
-    IF state[3] = 0 THEN 
+    IF state[3] = 0 THEN
         RETURN ST_SetSRID(ST_MakePoint(state[1],state[2]), 4326);
-    ELSE 
+    ELSE
         RETURN ST_SETSRID(ST_MakePoint(state[1]/state[3], state[2]/state[3]),4326);
     END IF;
 END
@@ -56,7 +69,7 @@ BEGIN
             SFUNC = CDB_WeightedMeanS,
             FINALFUNC = CDB_WeightedMeanF,
             STYPE = Numeric[],
-            INITCOND = "{0.0,0.0,0.0}" 
+            INITCOND = "{0.0,0.0,0.0}"
         );
     END IF;
 END
diff --git a/src/py/crankshaft/crankshaft/clustering/kmeans.py b/src/py/crankshaft/crankshaft/clustering/kmeans.py
index 4134062..ee2f304 100644
--- a/src/py/crankshaft/crankshaft/clustering/kmeans.py
+++ b/src/py/crankshaft/crankshaft/clustering/kmeans.py
@@ -1,18 +1,81 @@
 from sklearn.cluster import KMeans
+import numpy as np
 import plpy
 
 
-def kmeans(query, no_clusters, no_init=20):
-    data = plpy.execute('''select array_agg(cartodb_id order by cartodb_id) as ids,
-        array_agg(ST_X(the_geom) order by cartodb_id) xs,
-        array_agg(ST_Y(the_geom) order by cartodb_id) ys from ({query}) a
-        where the_geom is not null
-    '''.format(query=query))
-    xs = data[0]['xs']
-    ys = data[0]['ys']
+def kmeans(query, no_clusters, no_init=20):
+    """
+    Spatial k-means clustering on the X/Y coordinates of `the_geom`.
+
+    query (string): A SQL query returning at least `cartodb_id` and
+                    `the_geom`; rows with a NULL `the_geom` are skipped
+    no_clusters (int): number of clusters
+    no_init (int): number of times k-means is run with different
+                   centroid seeds (scikit-learn's `n_init`)
+
+    Returns (cartodb_id, cluster label) pairs.
+    """
+    full_query = '''
+        SELECT array_agg(cartodb_id ORDER BY cartodb_id) as ids,
+               array_agg(ST_X(the_geom) ORDER BY cartodb_id) xs,
+               array_agg(ST_Y(the_geom) ORDER BY cartodb_id) ys
+        FROM ({query}) As a
+        WHERE the_geom IS NOT NULL
+    '''.format(query=query)
+    try:
+        data = plpy.execute(full_query)
+    except plpy.SPIError as err:
+        plpy.error("KMeans cluster failed: %s" % err)
+
+    xs = data[0]['xs']
+    ys = data[0]['ys']
     ids = data[0]['ids']
 
-    km = KMeans(n_clusters= no_clusters, n_init=no_init)
-    labels = km.fit_predict(zip(xs,ys))
-    return zip(ids,labels)
+    km = KMeans(n_clusters=no_clusters, n_init=no_init)
+    labels = km.fit_predict(zip(xs, ys))
+    return zip(ids, labels)
 
+
+def kmeans_nonspatial(query, colnames, num_clusters=5, id_col='cartodb_id'):
+    """
+    Non-spatial k-means clustering on the columns listed in `colnames`.
+
+    query (string): A SQL query to retrieve the data required to do the
+                    k-means clustering analysis, like so:
+                    SELECT * FROM iris_flower_data
+    colnames (list): a list of the column names which contain the data of
+                     interest, like so: ["sepal_width", "petal_width",
+                                         "sepal_length", "petal_length"]
+    num_clusters (int): number of clusters (greater than zero)
+    id_col (string): name of the input id_column
+
+    Returns (row id, cluster label) pairs, matching the (rowid, cluster_no)
+    columns declared by CDB_KMeansNonspatial.
+    """
+    id_colname = 'rowids'
+    # one aliased array_agg per input column: col0, col1, ...
+    col_aliases = ['col{0}'.format(idx) for idx in range(len(colnames))]
+
+    full_query = '''
+        SELECT {cols}, array_agg({id_col}) As {id_colname}
+        FROM ({query}) As a
+    '''.format(query=query,
+               id_col=id_col,
+               id_colname=id_colname,
+               cols=', '.join(['array_agg({0}) As {1}'.format(val, alias)
+                               for val, alias in zip(colnames, col_aliases)]))
+
+    plpy.notice('query: %s' % full_query)
+    try:
+        db_resp = plpy.execute(full_query)
+    except plpy.SPIError as err:
+        plpy.error('KMeans cluster failed: %s' % err)
+
+    # fill array with values for kmeans clustering:
+    # one row per observation, one column per entry of colnames
+    data = np.array([db_resp[0][alias] for alias in col_aliases],
+                    dtype=float).T
+
+    km = KMeans(n_clusters=num_clusters, random_state=0).fit(data)
+
+    return zip(db_resp[0][id_colname], km.labels_)