From 6260a080a698f39a07fb6e35a3bc3eb75e3fd593 Mon Sep 17 00:00:00 2001 From: Andy Eschbacher Date: Thu, 15 Mar 2018 11:41:47 -0400 Subject: [PATCH] adds fuller test suite for segmentation --- src/pg/sql/05_segmentation.sql | 21 ++- src/pg/test/expected/06_segmentation_test.out | 50 ++++++ src/pg/test/fixtures/ml_values.sql | 4 +- src/pg/test/sql/06_segmentation_test.sql | 160 +++++++++++++++--- .../crankshaft/segmentation/segmentation.py | 16 +- 5 files changed, 207 insertions(+), 44 deletions(-) diff --git a/src/pg/sql/05_segmentation.sql b/src/pg/sql/05_segmentation.sql index a93f3e8..4a0cfa0 100644 --- a/src/pg/sql/05_segmentation.sql +++ b/src/pg/sql/05_segmentation.sql @@ -25,19 +25,20 @@ AS $$ def unpack2D(data): dimension = data.pop(0) - a = np.array(data, dtype=float) - return a.reshape(len(a)/dimension, dimension) + a = np.array(data, dtype=np.float64) + return a.reshape(int(len(a)/dimension), int(dimension)) - return seg.create_and_predict_segment_agg(np.array(target, dtype=float), - unpack2D(features), - unpack2D(target_features), - target_ids, - model_params) + return seg.create_and_predict_segment_agg( + np.array(target, dtype=np.float64), + unpack2D(features), + unpack2D(target_features), + target_ids, + model_params) $$ LANGUAGE plpythonu VOLATILE PARALLEL RESTRICTED; CREATE OR REPLACE FUNCTION - CDB_CreateAndPredictSegment ( + CDB_CreateAndPredictSegment( query TEXT, variable_name TEXT, target_table TEXT, @@ -57,9 +58,13 @@ AS $$ 'learning_rate': learning_rate, 'min_samples_leaf': min_samples_leaf } + feature_cols = set(plpy.execute(''' + select * from ({query}) as _w limit 0 + '''.format(query=query)).colnames()) - set([variable_name, 'cartodb_id', ]) return seg.create_and_predict_segment( query, variable_name, + feature_cols, target_table, model_params ) diff --git a/src/pg/test/expected/06_segmentation_test.out b/src/pg/test/expected/06_segmentation_test.out index a4c17a9..227a6c0 100644 --- a/src/pg/test/expected/06_segmentation_test.out +++ b/src/pg/test/expected/06_segmentation_test.out @@ -25,3 +25,53 @@ t t t (20 rows) +_cdb_random_seeds + +(1 row) +within_tolerance +t +t +t +t +t +t +t +t +t +t +t +t +t +t +t +t +t +t +t +t +(20 rows) +_cdb_random_seeds + +(1 row) +within_tolerance +t +t +t +t +t +t +t +t +t +t +t +t +t +t +t +t +t +t +t +t +(20 rows) diff --git a/src/pg/test/fixtures/ml_values.sql b/src/pg/test/fixtures/ml_values.sql index c87a10f..59fffd5 100644 --- a/src/pg/test/fixtures/ml_values.sql +++ b/src/pg/test/fixtures/ml_values.sql @@ -1,7 +1,7 @@ SET client_min_messages TO WARNING; \set ECHO none -CREATE TABLE ml_values (cartodb_id integer, target float, the_geom geometry, x1 float , x2 float, x3 float, class text); -INSERT INTO ml_values(cartodb_id, target,x1,x2,x3, class) VALUES +CREATE TABLE ml_values (cartodb_id integer, target float, the_geom geometry, x1 float, x2 float, x3 float, class text); +INSERT INTO ml_values(cartodb_id, target, x1, x2, x3, class) VALUES (0,1.24382137034,0.811403626309,0.657584780869,0,'train'), (1,1.72727475342,0.447764244847,0.528687533966,1,'train'), (2,3.32104694099,0.62774565606,0.832647155118,2,'train'), diff --git a/src/pg/test/sql/06_segmentation_test.sql b/src/pg/test/sql/06_segmentation_test.sql index 2675422..e2aa51d 100644 --- a/src/pg/test/sql/06_segmentation_test.sql +++ b/src/pg/test/sql/06_segmentation_test.sql @@ -3,31 +3,141 @@ \i test/fixtures/ml_values.sql SELECT cdb_crankshaft._cdb_random_seeds(1234); +-- second version (query, not specifying features) WITH expected AS ( - SELECT generate_series(1000,1020) AS id, unnest(ARRAY[ - 4.5656517130822492, - 1.7928053473230694, - 1.0283378773916563, - 2.6586517814904593, - 2.9699056242935944, - 3.9550646059951347, - 4.1662572444459745, - 3.8126334839264162, - 1.8809821053623488, - 1.6349065129019873, - 3.0391288591472954, - 3.3035970359672553, - 1.5835471589451968, - 3.7530378537263638, - 1.0833589653009252, - 3.8104965452882897, - 2.665217959294802, - 1.5850334252802472, - 3.679401198805563, - 3.5332033186588636 - ]) AS expected LIMIT 20 + SELECT + generate_series(1000, 1020) AS id, + unnest(ARRAY[4.5656517130822492, + 1.7928053473230694, + 1.0283378773916563, + 2.6586517814904593, + 2.9699056242935944, + 3.9550646059951347, + 4.1662572444459745, + 3.8126334839264162, + 1.8809821053623488, + 1.6349065129019873, + 3.0391288591472954, + 3.3035970359672553, + 1.5835471589451968, + 3.7530378537263638, + 1.0833589653009252, + 3.8104965452882897, + 2.665217959294802, + 1.5850334252802472, + 3.679401198805563, + 3.5332033186588636 ]) AS expected + LIMIT 20 +), training as ( + SELECT + array_agg(target)::numeric[] as target, + cdb_crankshaft.CDB_PyAgg(Array[x1, x2, x3]::numeric[]) as features + FROM (SELECT * FROM ml_values ORDER BY cartodb_id asc) as _w + WHERE class = 'train' +), testing As ( + SELECT + cdb_crankshaft.CDB_PyAgg(Array[x1, x2, x3]::numeric[]) as features, + array_agg(cartodb_id)::numeric[] as cartodb_ids + FROM (SELECT * FROM ml_values ORDER BY cartodb_id asc) as _w + WHERE class = 'test' ), prediction AS ( - SELECT cartodb_id::integer id, prediction - FROM cdb_crankshaft.CDB_CreateAndPredictSegment('SELECT target, x1, x2, x3 FROM ml_values WHERE class = $$train$$','target', Array['x1', 'x2', 'x3'], 'SELECT cartodb_id, x1, x2, x3 FROM ml_values WHERE class = $$test$$') + SELECT + * + FROM + cdb_crankshaft.CDB_CreateAndPredictSegment( + (SELECT target FROM training), + (SELECT features FROM training), + (SELECT features FROM testing), + (SELECT cartodb_ids FROM testing) + ) +) +SELECT + abs(e.expected - p.prediction) <= 1e-1 AS within_tolerance +FROM expected e, prediction p +WHERE e.id = p.cartodb_id +LIMIT 20; +SELECT cdb_crankshaft._cdb_random_seeds(1234); + +-- second version (query, not specifying features) +WITH expected AS ( + SELECT + generate_series(1000, 1020) AS id, + unnest(ARRAY[4.5656517130822492, + 1.7928053473230694, + 1.0283378773916563, + 2.6586517814904593, + 2.9699056242935944, + 3.9550646059951347, + 4.1662572444459745, + 3.8126334839264162, + 1.8809821053623488, + 1.6349065129019873, + 3.0391288591472954, + 3.3035970359672553, + 1.5835471589451968, + 3.7530378537263638, + 1.0833589653009252, + 3.8104965452882897, + 2.665217959294802, + 1.5850334252802472, + 3.679401198805563, + 3.5332033186588636 ]) AS expected + LIMIT 20 +), prediction AS ( + SELECT + cartodb_id::integer id, + prediction + FROM cdb_crankshaft.CDB_CreateAndPredictSegment( + 'SELECT target, x1, x2, x3 FROM ml_values WHERE class = $$train$$ ORDER BY cartodb_id asc', + 'target', + 'SELECT cartodb_id, x1, x2, x3 FROM ml_values WHERE class = $$test$$ ORDER BY cartodb_id asc' + ) + LIMIT 20 +) +SELECT + abs(e.expected - p.prediction) <= 1e-1 AS within_tolerance +FROM expected e, prediction p +WHERE e.id = p.id; + +SELECT cdb_crankshaft._cdb_random_seeds(1234); +-- third version (query, specifying features) +WITH expected AS ( + SELECT + generate_series(1000, 1020) AS id, + unnest(ARRAY[4.5656517130822492, + 1.7928053473230694, + 1.0283378773916563, + 2.6586517814904593, + 2.9699056242935944, + 3.9550646059951347, + 4.1662572444459745, + 3.8126334839264162, + 1.8809821053623488, + 1.6349065129019873, + 3.0391288591472954, + 3.3035970359672553, + 1.5835471589451968, + 3.7530378537263638, + 1.0833589653009252, + 3.8104965452882897, + 2.665217959294802, + 1.5850334252802472, + 3.679401198805563, + 3.5332033186588636 ]) AS expected + LIMIT 20 +), prediction AS ( + SELECT + cartodb_id::integer id, + prediction + FROM cdb_crankshaft.CDB_CreateAndPredictSegment( + 'SELECT target, x1, x2, x3 FROM ml_values WHERE class = $$train$$', + 'target', + Array['x1', 'x2', 'x3'], + 'SELECT cartodb_id, x1, x2, x3 FROM ml_values WHERE class = $$test$$' + ) LIMIT 20 -) SELECT abs(e.expected - p.prediction) <= 1e-1 AS within_tolerance FROM expected e, prediction p WHERE e.id = p.id; +) +SELECT + abs(e.expected - p.prediction) <= 1e-1 AS within_tolerance +FROM expected e, prediction p +WHERE e.id = p.id; diff --git a/src/py/crankshaft/crankshaft/segmentation/segmentation.py b/src/py/crankshaft/crankshaft/segmentation/segmentation.py index 613fca6..10b9a84 100644 --- a/src/py/crankshaft/crankshaft/segmentation/segmentation.py +++ b/src/py/crankshaft/crankshaft/segmentation/segmentation.py @@ -29,26 +29,25 @@ class Segmentation(object): straight form the SQL calling the function. Input: - @param target: The 1D array of lenth NSamples containing the + @param target: The 1D array of length NSamples containing the target variable we want the model to predict @param features: The 2D array of size NSamples * NFeatures that - form the imput to the model + form the input to the model @param target_ids: A 1D array of target_ids that will be used to associate the results of the prediction with the rows which they come from @param model_parameters: A dictionary containing parameters for the model. """ - clean_target = replace_nan_with_mean(target) - clean_features = replace_nan_with_mean(features) - target_features = replace_nan_with_mean(target_features) + clean_target, _ = replace_nan_with_mean(target) + clean_features, _ = replace_nan_with_mean(features) + target_features, _ = replace_nan_with_mean(target_features) model, accuracy = train_model(clean_target, clean_features, model_parameters, 0.2) prediction = model.predict(target_features) accuracy_array = [accuracy] * prediction.shape[0] - return zip(target_ids, prediction, - np.full(prediction.shape, accuracy_array)) + return zip(target_ids, prediction, accuracy_array) def create_and_predict_segment(self, query, variable, feature_columns, target_query, model_params, @@ -65,7 +64,6 @@ class Segmentation(object): scikit learn page for [GradientBoostingRegressor] (http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.GradientBoostingRegressor.html) """ - params = {"subquery": target_query, "id_col": id_col} @@ -198,7 +196,7 @@ def train_model(target, features, model_params, test_split): Input: @param target: 1D Array of the variable that the model is to be trained to predict - @param features: 2D Array NSamples *NFeatures to use in trining + @param features: 2D Array NSamples *NFeatures to use in training the model @param model_params: A dictionary of model parameters, the full specification can be found on the