diff --git a/src/pg/sql/11_kmeans.sql b/src/pg/sql/11_kmeans.sql index 73e2f1d..87f07ea 100644 --- a/src/pg/sql/11_kmeans.sql +++ b/src/pg/sql/11_kmeans.sql @@ -8,24 +8,46 @@ RETURNS table (cartodb_id integer, cluster_no integer) as $$ $$ language plpythonu; -CREATE OR REPLACE FUNCTION CDB_WeightedMean(query text, weight_column text, category_column text default null ) -RETURNS table (the_geom geometry,class integer ) as $$ -BEGIN -RETURN QUERY - EXECUTE format( $string$ - select ST_SETSRID(st_makepoint(cx, cy),4326) the_geom, class from ( - select - %I as class, - sum(st_x(the_geom)*%I)/sum(%I) cx, - sum(st_y(the_geom)*%I)/sum(%I) cy - from (%s) a - group by %I - ) q - - $string$, category_column, weight_column,weight_column,weight_column,weight_column,query, category_column - ) - using the_geom - RETURN; -END +CREATE OR REPLACE FUNCTION CDB_WeightedMeanS(state Numeric[],the_geom GEOMETRY(Point, 4326), weight NUMERIC) +RETURNS Numeric[] AS +$$ +DECLARE + newX NUMERIC; + newY NUMERIC; + newW NUMERIC; +BEGIN + IF weight IS NULL OR the_geom IS NULL THEN + newX = state[1]; + newY = state[2]; + newW = state[3]; + ELSE + newX = state[1] + ST_X(the_geom)*weight; + newY = state[2] + ST_Y(the_geom)*weight; + newW = state[3] + weight; + END IF; + RETURN Array[newX,newY,newW]; + +END $$ LANGUAGE plpgsql; + +CREATE OR REPLACE FUNCTION CDB_WeightedMeanF(state Numeric[]) +RETURNS GEOMETRY AS +$$ +BEGIN + IF state[3] = 0 THEN + RETURN ST_SetSRID(ST_MakePoint(state[1],state[2]), 4326); + ELSE + RETURN ST_SETSRID(ST_MakePoint(state[1]/state[3], state[2]/state[3]),4326); + END IF; +END +$$ LANGUAGE plpgsql; + +CREATE AGGREGATE CDB_WeightedMean(the_geom geometry(Point, 4326), weight NUMERIC)( + SFUNC = CDB_WeightedMeanS, + FINALFUNC = CDB_WeightedMeanF, + STYPE = Numeric[], + INITCOND = "{0.0,0.0,0.0}" +); + + diff --git a/src/pg/test/expected/05_kmeans_test.out b/src/pg/test/expected/05_kmeans_test.out index 4e6db09..8c6ffa1 100644 --- a/src/pg/test/expected/05_kmeans_test.out +++ b/src/pg/test/expected/05_kmeans_test.out @@ -4,7 +4,7 @@ SELECT count(DISTINCT cluster_no) as clusters from cdb_crankshaft.cdb_kmeans('se clusters 2 (1 row) -SELECT count(*) clusters from cdb_crankshaft.cdb_WeightedMean( 'select *, code::INTEGER as cluster from ppoints' , 'value', 'cluster' ); +SELECT count(*) clusters from (select cdb_crankshaft.CDB_WeightedMean(the_geom, value::NUMERIC), code from ppoints group by code) p; clusters 52 (1 row) diff --git a/src/pg/test/sql/05_kmeans_test.sql b/src/pg/test/sql/05_kmeans_test.sql index a400e5e..2298b85 100644 --- a/src/pg/test/sql/05_kmeans_test.sql +++ b/src/pg/test/sql/05_kmeans_test.sql @@ -3,4 +3,4 @@ SELECT count(DISTINCT cluster_no) as clusters from cdb_crankshaft.cdb_kmeans('select * from ppoints', 2); -SELECT count(*) clusters from cdb_crankshaft.cdb_WeightedMean( 'select *, code::INTEGER as cluster from ppoints' , 'value', 'cluster' ); +SELECT count(*) clusters from (select cdb_crankshaft.CDB_WeightedMean(the_geom, value::NUMERIC), code from ppoints group by code) p;