diff --git a/src/py/crankshaft/crankshaft/analysis_data_provider.py b/src/py/crankshaft/crankshaft/analysis_data_provider.py index 02131b0..932aff2 100644 --- a/src/py/crankshaft/crankshaft/analysis_data_provider.py +++ b/src/py/crankshaft/crankshaft/analysis_data_provider.py @@ -73,12 +73,12 @@ class AnalysisDataProvider(object): def get_segmentation_model_data(self, params): """ fetch data for Segmentation - params = {"subquery": query, - "target": variable, - "features": feature_columns} + params = {"subquery": query, + "target": variable, + "features": feature_columns} """ columns = ', '.join(['array_agg("{col}") As "{col}"'.format(col=col) - for col in params['feature_columns']]) + for col in params['features']]) query = ''' SELECT array_agg("{target}") As target, diff --git a/src/py/crankshaft/crankshaft/segmentation/segmentation.py b/src/py/crankshaft/crankshaft/segmentation/segmentation.py index 105c2f0..9840ff0 100644 --- a/src/py/crankshaft/crankshaft/segmentation/segmentation.py +++ b/src/py/crankshaft/crankshaft/segmentation/segmentation.py @@ -69,7 +69,7 @@ class Segmentation(object): params = {"subquery": target_query, "id_col": id_col, - "feature_columns": features} + "feature_columns": feature_columns} target, features, target_mean, \ feature_means = self.clean_data(variable, feature_columns, query) @@ -101,6 +101,9 @@ class Segmentation(object): results = [] cursors = self.data_provider.get_segmentation_predict_data(params) + # cursors = [{'': , + # '': }] + # while True: rows = cursors.fetch(batch_size) if not rows: @@ -127,6 +130,14 @@ class Segmentation(object): data = self.data_provider.get_segmentation_model_data(params) + ''' + data: [{'target': [2.9, 4.9, 4, 5, 6]}, + {'feature1': [1,2,3,4]}, {'feature2' : [2,3,4,5]} + ] + ''' + + [{target: [dsdfs]}] + # extract target data from plpy object target = np.array(data[0]['target']) diff --git a/src/py/crankshaft/test/test_segmentation.py b/src/py/crankshaft/test/test_segmentation.py index b6fbb00..c0638fe 100644 --- a/src/py/crankshaft/test/test_segmentation.py +++ b/src/py/crankshaft/test/test_segmentation.py @@ -4,6 +4,21 @@ from helper import plpy, fixture_file from crankshaft.segmentation import Segmentation import json +class RawDataProvider(AnalysisDataProvider): + def __init__(self, raw_data1, raw_data2, raw_data3): + self.raw_data1 = raw_data1 + self.raw_data2 = raw_data2 + self.raw_data3 = raw_data3 + + def get_segmentation_data(self, params): + return self.raw_data1 + + def get_segmentation_predict_data(self, params): + return self.raw_data2 + + def get_segmentation_model_data(self, params): + return self.raw_data3 + class SegmentationTest(unittest.TestCase): """Testing class for Moran's I functions""" @@ -36,19 +51,23 @@ class SegmentationTest(unittest.TestCase): ids = [{'cartodb_ids': range(len(test_data))}] - rows = [{'x1': 0,'x2':0,'x3':0,'y':0,'cartodb_id':0}] + rows = [{'x1': 0, 'x2': 0, 'x3': 0, 'y': 0, 'cartodb_id': 0}] - plpy._define_result('select \* from \(select \* from training\) a limit 1',rows) - plpy._define_result('.*from \(select \* from training\) as a' ,training_data) - plpy._define_result('select array_agg\(cartodb\_id order by cartodb\_id\) as cartodb_ids from \(.*\) a',ids) - plpy._define_result('.*select \* from test.*' ,test_data) + plpy._define_result('select \* from \(select \* from training\) a limit 1', rows) + plpy._define_result('.*from \(select \* from training\) as a', training_data) + plpy._define_result('select array_agg\(cartodb\_id order by cartodb\_id\) as cartodb_ids from \(.*\) a', ids) + plpy._define_result('.*select \* from test.*', test_data) - model_parameters = {'n_estimators': 1200, - 'max_depth': 3, - 'subsample' : 0.5, - 'learning_rate': 0.01, - 'min_samples_leaf': 1} - seg = Segmentation() + model_parameters = {'n_estimators': 1200, + 'max_depth': 3, + 'subsample' : 0.5, + 'learning_rate': 0.01, + 'min_samples_leaf': 1} + data = [{'target': [], + 'x1': [], + 'x2': [], + 'x3': []}] + seg = Segmentation(RawDataProvider(test, train, predict)) ''' self, query, variable, feature_columns, target_query, model_params,