diff --git a/src/py/crankshaft/crankshaft/segmentation/segmentation.py b/src/py/crankshaft/crankshaft/segmentation/segmentation.py index 319ba21..c3e99fa 100644 --- a/src/py/crankshaft/crankshaft/segmentation/segmentation.py +++ b/src/py/crankshaft/crankshaft/segmentation/segmentation.py @@ -7,11 +7,10 @@ from sklearn.ensemble import GradientBoostingRegressor from sklearn import metrics from sklearn.cross_validation import train_test_split from crankshaft.analysis_data_provider import AnalysisDataProvider -from mock_plpy import MockCursor - # NOTE: added optional param here + class Segmentation(object): """ Add docstring @@ -82,7 +81,7 @@ class Segmentation(object): ''' rowid = [{'ids': [2.9, 4.9, 4, 5, 6]}] ''' - return zip(rowid[0]['id_col'], result, accuracy_array) + return zip(rowid[0]['ids'], result, accuracy_array) def predict_segment(self, model, feature_columns, target_query, feature_means): @@ -101,33 +100,20 @@ class Segmentation(object): "feature_columns": feature_columns} results = [] - cursor = self.data_provider.get_segmentation_predict_data(params) - cursor = MockCursor(cursor) + cursors = self.data_provider.get_segmentation_predict_data(params) ''' - cursor = [{'feature_columns': [{'features': (0.81140362630858487, - 0.65758478086896821, - 0)}]}] - + cursors = [{'features': [[m1[0],m2[0],m3[0]],[m1[1],m2[1],m3[1]], + [m1[2],m2[2],m3[2]]]}] ''' while True: - batch = [] - rows = cursor.fetch(batch_size) + rows = cursors.fetch(batch_size) if not rows: break - for row in rows: - max = len(rows[0]['feature_columns']) - for c in range(max): - batch = np.append(batch, np.row_stack([np.array(row - ['feature_columns'] - [c] - ['features'])]) - .astype(float)) - # batch = np.row_stack([np.array(row['features']) - # for row in rows]).astype(float) - co = len(rows[0]['feature_columns'][0]['features']) - batch = batch.reshape((batch_size, co)) + batch = np.row_stack([np.array(row['features']) + for row in rows]).astype(float) + batch = replace_nan_with_mean(batch, feature_means)[0] prediction = model.predict(batch) results.append(prediction)