test changes

This commit is contained in:
Ubuntu 2017-01-31 20:48:56 +00:00
parent 959747c623
commit cbd95fa0a2
3 changed files with 46 additions and 16 deletions

View File

@ -78,7 +78,7 @@ class AnalysisDataProvider(object):
"features": feature_columns} "features": feature_columns}
""" """
columns = ', '.join(['array_agg("{col}") As "{col}"'.format(col=col) columns = ', '.join(['array_agg("{col}") As "{col}"'.format(col=col)
for col in params['feature_columns']]) for col in params['features']])
query = ''' query = '''
SELECT SELECT
array_agg("{target}") As target, array_agg("{target}") As target,

View File

@ -69,7 +69,7 @@ class Segmentation(object):
params = {"subquery": target_query, params = {"subquery": target_query,
"id_col": id_col, "id_col": id_col,
"feature_columns": features} "feature_columns": feature_columns}
target, features, target_mean, \ target, features, target_mean, \
feature_means = self.clean_data(variable, feature_columns, query) feature_means = self.clean_data(variable, feature_columns, query)
@ -101,6 +101,9 @@ class Segmentation(object):
results = [] results = []
cursors = self.data_provider.get_segmentation_predict_data(params) cursors = self.data_provider.get_segmentation_predict_data(params)
# cursors = [{'': ,
# '': }]
#
while True: while True:
rows = cursors.fetch(batch_size) rows = cursors.fetch(batch_size)
if not rows: if not rows:
@ -127,6 +130,14 @@ class Segmentation(object):
data = self.data_provider.get_segmentation_model_data(params) data = self.data_provider.get_segmentation_model_data(params)
'''
data: [{'target': [2.9, 4.9, 4, 5, 6]},
{'feature1': [1,2,3,4]}, {'feature2' : [2,3,4,5]}
]
'''
[{target: [dsdfs]}]
# extract target data from plpy object # extract target data from plpy object
target = np.array(data[0]['target']) target = np.array(data[0]['target'])

View File

@ -4,6 +4,21 @@ from helper import plpy, fixture_file
from crankshaft.segmentation import Segmentation from crankshaft.segmentation import Segmentation
import json import json
class RawDataProvider(AnalysisDataProvider):
    """Data provider stub that hands back the objects it was built with.

    Every ``get_segmentation_*`` method ignores its ``params`` argument and
    returns one of the three payloads passed to the constructor.
    """

    def __init__(self, raw_data1, raw_data2, raw_data3):
        """Store the three canned payloads.

        raw_data1 is returned by ``get_segmentation_data``,
        raw_data2 by ``get_segmentation_predict_data`` and
        raw_data3 by ``get_segmentation_model_data``.
        """
        # NOTE(review): the base-class __init__ is not invoked here;
        # confirm that AnalysisDataProvider needs no initialisation.
        self.raw_data1, self.raw_data2, self.raw_data3 = (
            raw_data1,
            raw_data2,
            raw_data3,
        )

    def get_segmentation_data(self, params):
        """Return the first stored payload; ``params`` is unused."""
        return self.raw_data1

    def get_segmentation_predict_data(self, params):
        """Return the second stored payload; ``params`` is unused."""
        return self.raw_data2

    def get_segmentation_model_data(self, params):
        """Return the third stored payload; ``params`` is unused."""
        return self.raw_data3
class SegmentationTest(unittest.TestCase): class SegmentationTest(unittest.TestCase):
"""Testing class for Moran's I functions""" """Testing class for Moran's I functions"""
@ -48,7 +63,11 @@ class SegmentationTest(unittest.TestCase):
'subsample' : 0.5, 'subsample' : 0.5,
'learning_rate': 0.01, 'learning_rate': 0.01,
'min_samples_leaf': 1} 'min_samples_leaf': 1}
seg = Segmentation() data = [{'target': [],
'x1': [],
'x2': [],
'x3': []}]
seg = Segmentation(RawDataProvider(test, train, predict))
''' '''
self, query, variable, feature_columns, self, query, variable, feature_columns,
target_query, model_params, target_query, model_params,