mirror of
https://github.com/CartoDB/crankshaft.git
synced 2024-11-01 10:20:48 +08:00
Merge pull request #171 from CartoDB/update-segmentation
Refactors segmentation to add analysis provider support
This commit is contained in:
commit
8c4057bb7a
@ -15,7 +15,8 @@ AS $$
|
||||
import numpy as np
|
||||
import plpy
|
||||
|
||||
from crankshaft.segmentation import create_and_predict_segment_agg
|
||||
from crankshaft.segmentation import Segmentation
|
||||
seg = Segmentation()
|
||||
model_params = {'n_estimators': n_estimators,
|
||||
'max_depth': max_depth,
|
||||
'subsample': subsample,
|
||||
@ -24,19 +25,20 @@ AS $$
|
||||
|
||||
def unpack2D(data):
|
||||
dimension = data.pop(0)
|
||||
a = np.array(data, dtype=float)
|
||||
return a.reshape(len(a)/dimension, dimension)
|
||||
a = np.array(data, dtype=np.float64)
|
||||
return a.reshape(int(len(a)/dimension), int(dimension))
|
||||
|
||||
return create_and_predict_segment_agg(np.array(target, dtype=float),
|
||||
unpack2D(features),
|
||||
unpack2D(target_features),
|
||||
target_ids,
|
||||
model_params)
|
||||
return seg.create_and_predict_segment_agg(
|
||||
np.array(target, dtype=np.float64),
|
||||
unpack2D(features),
|
||||
unpack2D(target_features),
|
||||
target_ids,
|
||||
model_params)
|
||||
|
||||
$$ LANGUAGE plpythonu VOLATILE PARALLEL RESTRICTED;
|
||||
|
||||
CREATE OR REPLACE FUNCTION
|
||||
CDB_CreateAndPredictSegment (
|
||||
CDB_CreateAndPredictSegment(
|
||||
query TEXT,
|
||||
variable_name TEXT,
|
||||
target_table TEXT,
|
||||
@ -47,7 +49,54 @@ CREATE OR REPLACE FUNCTION
|
||||
min_samples_leaf INTEGER DEFAULT 1)
|
||||
RETURNS TABLE (cartodb_id TEXT, prediction NUMERIC, accuracy NUMERIC)
|
||||
AS $$
|
||||
from crankshaft.segmentation import create_and_predict_segment
|
||||
model_params = {'n_estimators': n_estimators, 'max_depth':max_depth, 'subsample' : subsample, 'learning_rate': learning_rate, 'min_samples_leaf' : min_samples_leaf}
|
||||
return create_and_predict_segment(query,variable_name,target_table, model_params)
|
||||
from crankshaft.segmentation import Segmentation
|
||||
seg = Segmentation()
|
||||
model_params = {
|
||||
'n_estimators': n_estimators,
|
||||
'max_depth': max_depth,
|
||||
'subsample': subsample,
|
||||
'learning_rate': learning_rate,
|
||||
'min_samples_leaf': min_samples_leaf
|
||||
}
|
||||
feature_cols = set(plpy.execute('''
|
||||
select * from ({query}) as _w limit 0
|
||||
'''.format(query=query)).colnames()) - set([variable_name, 'cartodb_id', ])
|
||||
return seg.create_and_predict_segment(
|
||||
query,
|
||||
variable_name,
|
||||
feature_cols,
|
||||
target_table,
|
||||
model_params
|
||||
)
|
||||
$$ LANGUAGE plpythonu VOLATILE PARALLEL UNSAFE;
|
||||
|
||||
CREATE OR REPLACE FUNCTION
|
||||
CDB_CreateAndPredictSegment(
|
||||
query TEXT,
|
||||
variable TEXT,
|
||||
feature_columns TEXT[],
|
||||
target_query TEXT,
|
||||
n_estimators INTEGER DEFAULT 1200,
|
||||
max_depth INTEGER DEFAULT 3,
|
||||
subsample DOUBLE PRECISION DEFAULT 0.5,
|
||||
learning_rate DOUBLE PRECISION DEFAULT 0.01,
|
||||
min_samples_leaf INTEGER DEFAULT 1)
|
||||
RETURNS TABLE (cartodb_id TEXT, prediction NUMERIC, accuracy NUMERIC)
|
||||
AS $$
|
||||
from crankshaft.segmentation import Segmentation
|
||||
seg = Segmentation()
|
||||
model_params = {
|
||||
'n_estimators': n_estimators,
|
||||
'max_depth': max_depth,
|
||||
'subsample': subsample,
|
||||
'learning_rate': learning_rate,
|
||||
'min_samples_leaf': min_samples_leaf
|
||||
}
|
||||
return seg.create_and_predict_segment(
|
||||
query,
|
||||
variable,
|
||||
feature_columns,
|
||||
target_query,
|
||||
model_params
|
||||
)
|
||||
$$ LANGUAGE plpythonu VOLATILE PARALLEL UNSAFE;
|
||||
|
@ -25,3 +25,53 @@ t
|
||||
t
|
||||
t
|
||||
(20 rows)
|
||||
_cdb_random_seeds
|
||||
|
||||
(1 row)
|
||||
within_tolerance
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
(20 rows)
|
||||
_cdb_random_seeds
|
||||
|
||||
(1 row)
|
||||
within_tolerance
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
t
|
||||
(20 rows)
|
||||
|
4
src/pg/test/fixtures/ml_values.sql
vendored
4
src/pg/test/fixtures/ml_values.sql
vendored
@ -1,7 +1,7 @@
|
||||
SET client_min_messages TO WARNING;
|
||||
\set ECHO none
|
||||
CREATE TABLE ml_values (cartodb_id integer, target float, the_geom geometry, x1 float , x2 float, x3 float, class text);
|
||||
INSERT INTO ml_values(cartodb_id, target,x1,x2,x3, class) VALUES
|
||||
CREATE TABLE ml_values (cartodb_id integer, target float, the_geom geometry, x1 float, x2 float, x3 float, class text);
|
||||
INSERT INTO ml_values(cartodb_id, target, x1, x2, x3, class) VALUES
|
||||
(0,1.24382137034,0.811403626309,0.657584780869,0,'train'),
|
||||
(1,1.72727475342,0.447764244847,0.528687533966,1,'train'),
|
||||
(2,3.32104694099,0.62774565606,0.832647155118,2,'train'),
|
||||
|
@ -3,31 +3,141 @@
|
||||
\i test/fixtures/ml_values.sql
|
||||
SELECT cdb_crankshaft._cdb_random_seeds(1234);
|
||||
|
||||
-- second version (query, not specifying features)
|
||||
WITH expected AS (
|
||||
SELECT generate_series(1000,1020) AS id, unnest(ARRAY[
|
||||
4.5656517130822492,
|
||||
1.7928053473230694,
|
||||
1.0283378773916563,
|
||||
2.6586517814904593,
|
||||
2.9699056242935944,
|
||||
3.9550646059951347,
|
||||
4.1662572444459745,
|
||||
3.8126334839264162,
|
||||
1.8809821053623488,
|
||||
1.6349065129019873,
|
||||
3.0391288591472954,
|
||||
3.3035970359672553,
|
||||
1.5835471589451968,
|
||||
3.7530378537263638,
|
||||
1.0833589653009252,
|
||||
3.8104965452882897,
|
||||
2.665217959294802,
|
||||
1.5850334252802472,
|
||||
3.679401198805563,
|
||||
3.5332033186588636
|
||||
]) AS expected LIMIT 20
|
||||
SELECT
|
||||
generate_series(1000, 1020) AS id,
|
||||
unnest(ARRAY[4.5656517130822492,
|
||||
1.7928053473230694,
|
||||
1.0283378773916563,
|
||||
2.6586517814904593,
|
||||
2.9699056242935944,
|
||||
3.9550646059951347,
|
||||
4.1662572444459745,
|
||||
3.8126334839264162,
|
||||
1.8809821053623488,
|
||||
1.6349065129019873,
|
||||
3.0391288591472954,
|
||||
3.3035970359672553,
|
||||
1.5835471589451968,
|
||||
3.7530378537263638,
|
||||
1.0833589653009252,
|
||||
3.8104965452882897,
|
||||
2.665217959294802,
|
||||
1.5850334252802472,
|
||||
3.679401198805563,
|
||||
3.5332033186588636 ]) AS expected
|
||||
LIMIT 20
|
||||
), training as (
|
||||
SELECT
|
||||
array_agg(target)::numeric[] as target,
|
||||
cdb_crankshaft.CDB_PyAgg(Array[x1, x2, x3]::numeric[]) as features
|
||||
FROM (SELECT * FROM ml_values ORDER BY cartodb_id asc) as _w
|
||||
WHERE class = 'train'
|
||||
), testing As (
|
||||
SELECT
|
||||
cdb_crankshaft.CDB_PyAgg(Array[x1, x2, x3]::numeric[]) as features,
|
||||
array_agg(cartodb_id)::numeric[] as cartodb_ids
|
||||
FROM (SELECT * FROM ml_values ORDER BY cartodb_id asc) as _w
|
||||
WHERE class = 'test'
|
||||
), prediction AS (
|
||||
SELECT cartodb_id::integer id, prediction
|
||||
FROM cdb_crankshaft.CDB_CreateAndPredictSegment('SELECT target, x1, x2, x3 FROM ml_values WHERE class = $$train$$','target','SELECT cartodb_id, target, x1, x2, x3 FROM ml_values WHERE class = $$test$$')
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
cdb_crankshaft.CDB_CreateAndPredictSegment(
|
||||
(SELECT target FROM training),
|
||||
(SELECT features FROM training),
|
||||
(SELECT features FROM testing),
|
||||
(SELECT cartodb_ids FROM testing)
|
||||
)
|
||||
)
|
||||
SELECT
|
||||
abs(e.expected - p.prediction) <= 1e-1 AS within_tolerance
|
||||
FROM expected e, prediction p
|
||||
WHERE e.id = p.cartodb_id
|
||||
LIMIT 20;
|
||||
SELECT cdb_crankshaft._cdb_random_seeds(1234);
|
||||
|
||||
-- second version (query, not specifying features)
|
||||
WITH expected AS (
|
||||
SELECT
|
||||
generate_series(1000, 1020) AS id,
|
||||
unnest(ARRAY[4.5656517130822492,
|
||||
1.7928053473230694,
|
||||
1.0283378773916563,
|
||||
2.6586517814904593,
|
||||
2.9699056242935944,
|
||||
3.9550646059951347,
|
||||
4.1662572444459745,
|
||||
3.8126334839264162,
|
||||
1.8809821053623488,
|
||||
1.6349065129019873,
|
||||
3.0391288591472954,
|
||||
3.3035970359672553,
|
||||
1.5835471589451968,
|
||||
3.7530378537263638,
|
||||
1.0833589653009252,
|
||||
3.8104965452882897,
|
||||
2.665217959294802,
|
||||
1.5850334252802472,
|
||||
3.679401198805563,
|
||||
3.5332033186588636 ]) AS expected
|
||||
LIMIT 20
|
||||
), prediction AS (
|
||||
SELECT
|
||||
cartodb_id::integer id,
|
||||
prediction
|
||||
FROM cdb_crankshaft.CDB_CreateAndPredictSegment(
|
||||
'SELECT target, x1, x2, x3 FROM ml_values WHERE class = $$train$$ ORDER BY cartodb_id asc',
|
||||
'target',
|
||||
'SELECT cartodb_id, x1, x2, x3 FROM ml_values WHERE class = $$test$$ ORDER BY cartodb_id asc'
|
||||
)
|
||||
LIMIT 20
|
||||
) SELECT abs(e.expected - p.prediction) <= 1e-9 AS within_tolerance FROM expected e, prediction p WHERE e.id = p.id;
|
||||
)
|
||||
SELECT
|
||||
abs(e.expected - p.prediction) <= 1e-1 AS within_tolerance
|
||||
FROM expected e, prediction p
|
||||
WHERE e.id = p.id;
|
||||
|
||||
SELECT cdb_crankshaft._cdb_random_seeds(1234);
|
||||
-- third version (query, specifying features)
|
||||
WITH expected AS (
|
||||
SELECT
|
||||
generate_series(1000, 1020) AS id,
|
||||
unnest(ARRAY[4.5656517130822492,
|
||||
1.7928053473230694,
|
||||
1.0283378773916563,
|
||||
2.6586517814904593,
|
||||
2.9699056242935944,
|
||||
3.9550646059951347,
|
||||
4.1662572444459745,
|
||||
3.8126334839264162,
|
||||
1.8809821053623488,
|
||||
1.6349065129019873,
|
||||
3.0391288591472954,
|
||||
3.3035970359672553,
|
||||
1.5835471589451968,
|
||||
3.7530378537263638,
|
||||
1.0833589653009252,
|
||||
3.8104965452882897,
|
||||
2.665217959294802,
|
||||
1.5850334252802472,
|
||||
3.679401198805563,
|
||||
3.5332033186588636 ]) AS expected
|
||||
LIMIT 20
|
||||
), prediction AS (
|
||||
SELECT
|
||||
cartodb_id::integer id,
|
||||
prediction
|
||||
FROM cdb_crankshaft.CDB_CreateAndPredictSegment(
|
||||
'SELECT target, x1, x2, x3 FROM ml_values WHERE class = $$train$$',
|
||||
'target',
|
||||
Array['x1', 'x2', 'x3'],
|
||||
'SELECT cartodb_id, x1, x2, x3 FROM ml_values WHERE class = $$test$$'
|
||||
)
|
||||
LIMIT 20
|
||||
)
|
||||
SELECT
|
||||
abs(e.expected - p.prediction) <= 1e-1 AS within_tolerance
|
||||
FROM expected e, prediction p
|
||||
WHERE e.id = p.id;
|
||||
|
@ -16,7 +16,7 @@ def verify_data(func):
|
||||
plpy.error(NULL_VALUE_ERROR)
|
||||
else:
|
||||
return data
|
||||
except Exception as err:
|
||||
except plpy.SPIError as err:
|
||||
plpy.error('Analysis failed: {}'.format(err))
|
||||
|
||||
return []
|
||||
@ -25,26 +25,27 @@ def verify_data(func):
|
||||
|
||||
|
||||
class AnalysisDataProvider(object):
|
||||
"""Data fetching class for pl/python functions"""
|
||||
@verify_data
|
||||
def get_getis(self, w_type, params):
|
||||
def get_getis(self, w_type, params): # pylint: disable=no-self-use
|
||||
"""fetch data for getis ord's g"""
|
||||
query = pu.construct_neighbor_query(w_type, params)
|
||||
return plpy.execute(query)
|
||||
|
||||
@verify_data
|
||||
def get_markov(self, w_type, params):
|
||||
def get_markov(self, w_type, params): # pylint: disable=no-self-use
|
||||
"""fetch data for spatial markov"""
|
||||
query = pu.construct_neighbor_query(w_type, params)
|
||||
return plpy.execute(query)
|
||||
|
||||
@verify_data
|
||||
def get_moran(self, w_type, params):
|
||||
def get_moran(self, w_type, params): # pylint: disable=no-self-use
|
||||
"""fetch data for moran's i analyses"""
|
||||
query = pu.construct_neighbor_query(w_type, params)
|
||||
return plpy.execute(query)
|
||||
|
||||
@verify_data
|
||||
def get_nonspatial_kmeans(self, params):
|
||||
def get_nonspatial_kmeans(self, params): # pylint: disable=no-self-use
|
||||
"""
|
||||
Fetch data for non-spatial k-means.
|
||||
|
||||
@ -73,7 +74,57 @@ class AnalysisDataProvider(object):
|
||||
return plpy.execute(query)
|
||||
|
||||
@verify_data
|
||||
def get_spatial_kmeans(self, params):
|
||||
def get_segmentation_model_data(self, params): # pylint: disable=R0201
|
||||
"""
|
||||
fetch data for Segmentation
|
||||
params = {"subquery": query,
|
||||
"target": variable,
|
||||
"features": feature_columns}
|
||||
"""
|
||||
columns = ', '.join(['array_agg("{col}") As "{col}"'.format(col=col)
|
||||
for col in params['features']])
|
||||
query = '''
|
||||
SELECT
|
||||
array_agg("{target}") As target,
|
||||
{columns}
|
||||
FROM ({subquery}) As q
|
||||
'''.format(subquery=params['subquery'],
|
||||
target=params['target'],
|
||||
columns=columns)
|
||||
return plpy.execute(query)
|
||||
|
||||
@verify_data
|
||||
def get_segmentation_data(self, params): # pylint: disable=no-self-use
|
||||
"""
|
||||
params = {"subquery": target_query,
|
||||
"id_col": id_col}
|
||||
"""
|
||||
query = '''
|
||||
SELECT
|
||||
array_agg("{id_col}" ORDER BY "{id_col}") as "ids"
|
||||
FROM ({subquery}) as q
|
||||
'''.format(**params)
|
||||
return plpy.execute(query)
|
||||
|
||||
@verify_data
|
||||
def get_segmentation_predict_data(self, params): # pylint: disable=R0201
|
||||
"""
|
||||
fetch data for Segmentation
|
||||
params = {"subquery": target_query,
|
||||
"feature_columns": feature_columns}
|
||||
"""
|
||||
joined_features = ', '.join(['"{}"::numeric'.format(a)
|
||||
for a in params['feature_columns']])
|
||||
query = '''
|
||||
SELECT
|
||||
Array[{joined_features}] As features
|
||||
FROM ({subquery}) as q
|
||||
'''.format(subquery=params['subquery'],
|
||||
joined_features=joined_features)
|
||||
return plpy.cursor(query)
|
||||
|
||||
@verify_data
|
||||
def get_spatial_kmeans(self, params): # pylint: disable=no-self-use
|
||||
"""fetch data for spatial kmeans"""
|
||||
query = '''
|
||||
SELECT
|
||||
@ -86,13 +137,13 @@ class AnalysisDataProvider(object):
|
||||
return plpy.execute(query)
|
||||
|
||||
@verify_data
|
||||
def get_gwr(self, params):
|
||||
def get_gwr(self, params): # pylint: disable=no-self-use
|
||||
"""fetch data for gwr analysis"""
|
||||
query = pu.gwr_query(params)
|
||||
return plpy.execute(query)
|
||||
|
||||
@verify_data
|
||||
def get_gwr_predict(self, params):
|
||||
def get_gwr_predict(self, params): # pylint: disable=no-self-use
|
||||
"""fetch data for gwr predict"""
|
||||
query = pu.gwr_predict_query(params)
|
||||
return plpy.execute(query)
|
||||
|
@ -1 +1,2 @@
|
||||
"""Import all functions from for segmentation"""
|
||||
from segmentation import *
|
||||
|
@ -2,175 +2,227 @@
|
||||
Segmentation creation and prediction
|
||||
"""
|
||||
|
||||
import sklearn
|
||||
import numpy as np
|
||||
import plpy
|
||||
from sklearn.ensemble import GradientBoostingRegressor
|
||||
from sklearn import metrics
|
||||
from sklearn.cross_validation import train_test_split
|
||||
from crankshaft.analysis_data_provider import AnalysisDataProvider
|
||||
|
||||
# Lower level functions
|
||||
#----------------------
|
||||
# NOTE: added optional param here
|
||||
|
||||
def replace_nan_with_mean(array):
|
||||
|
||||
class Segmentation(object):
|
||||
"""
|
||||
Add docstring
|
||||
"""
|
||||
|
||||
def __init__(self, data_provider=None):
|
||||
if data_provider is None:
|
||||
self.data_provider = AnalysisDataProvider()
|
||||
else:
|
||||
self.data_provider = data_provider
|
||||
|
||||
def create_and_predict_segment_agg(self, target, features, target_features,
|
||||
target_ids, model_parameters):
|
||||
"""
|
||||
Version of create_and_predict_segment that works on arrays that come
|
||||
straight form the SQL calling the function.
|
||||
|
||||
Input:
|
||||
@param target: The 1D array of length NSamples containing the
|
||||
target variable we want the model to predict
|
||||
@param features: The 2D array of size NSamples * NFeatures that
|
||||
form the input to the model
|
||||
@param target_ids: A 1D array of target_ids that will be used
|
||||
to associate the results of the prediction with the rows which
|
||||
they come from
|
||||
@param model_parameters: A dictionary containing parameters for
|
||||
the model.
|
||||
"""
|
||||
clean_target, _ = replace_nan_with_mean(target)
|
||||
clean_features, _ = replace_nan_with_mean(features)
|
||||
target_features, _ = replace_nan_with_mean(target_features)
|
||||
|
||||
model, accuracy = train_model(clean_target, clean_features,
|
||||
model_parameters, 0.2)
|
||||
prediction = model.predict(target_features)
|
||||
accuracy_array = [accuracy] * prediction.shape[0]
|
||||
return zip(target_ids, prediction, accuracy_array)
|
||||
|
||||
def create_and_predict_segment(self, query, variable, feature_columns,
|
||||
target_query, model_params,
|
||||
id_col='cartodb_id'):
|
||||
"""
|
||||
generate a segment with machine learning
|
||||
Stuart Lynn
|
||||
@param query: subquery that data is pulled from for packaging
|
||||
@param variable: name of the target variable
|
||||
@param feature_columns: list of column names
|
||||
@target_query: The query to run to obtain the data to predict
|
||||
@param model_params: A dictionary of model parameters, the full
|
||||
specification can be found on the
|
||||
scikit learn page for [GradientBoostingRegressor]
|
||||
(http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.GradientBoostingRegressor.html)
|
||||
"""
|
||||
params = {"subquery": target_query,
|
||||
"id_col": id_col}
|
||||
|
||||
(target, features, target_mean,
|
||||
feature_means) = self.clean_data(query, variable, feature_columns)
|
||||
|
||||
model, accuracy = train_model(target, features, model_params, 0.2)
|
||||
result = self.predict_segment(model, feature_columns, target_query,
|
||||
feature_means)
|
||||
accuracy_array = [accuracy] * result.shape[0]
|
||||
|
||||
rowid = self.data_provider.get_segmentation_data(params)
|
||||
'''
|
||||
rowid = [{'ids': [2.9, 4.9, 4, 5, 6]}]
|
||||
'''
|
||||
return zip(rowid[0]['ids'], result, accuracy_array)
|
||||
|
||||
def predict_segment(self, model, feature_columns, target_query,
|
||||
feature_means):
|
||||
"""
|
||||
Use the provided model to predict the values for the new feature set
|
||||
Input:
|
||||
@param model: The pretrained model
|
||||
@features_col: A list of features to use in the
|
||||
model prediction (list of column names)
|
||||
@target_query: The query to run to obtain the data to predict
|
||||
on and the cartodb_ids associated with it.
|
||||
"""
|
||||
|
||||
batch_size = 1000
|
||||
params = {"subquery": target_query,
|
||||
"feature_columns": feature_columns}
|
||||
|
||||
results = []
|
||||
cursors = self.data_provider.get_segmentation_predict_data(params)
|
||||
'''
|
||||
cursors = [{'features': [[m1[0],m2[0],m3[0]],[m1[1],m2[1],m3[1]],
|
||||
[m1[2],m2[2],m3[2]]]}]
|
||||
'''
|
||||
|
||||
while True:
|
||||
rows = cursors.fetch(batch_size)
|
||||
if not rows:
|
||||
break
|
||||
batch = np.row_stack([np.array(row['features'])
|
||||
for row in rows]).astype(float)
|
||||
|
||||
batch = replace_nan_with_mean(batch, feature_means)[0]
|
||||
prediction = model.predict(batch)
|
||||
results.append(prediction)
|
||||
|
||||
# NOTE: we removed the cartodb_ids calculation in here
|
||||
return np.concatenate(results)
|
||||
|
||||
def clean_data(self, query, variable, feature_columns):
|
||||
"""
|
||||
Add docstring
|
||||
"""
|
||||
params = {"subquery": query,
|
||||
"target": variable,
|
||||
"features": feature_columns}
|
||||
|
||||
data = self.data_provider.get_segmentation_model_data(params)
|
||||
|
||||
'''
|
||||
data = [{'target': [2.9, 4.9, 4, 5, 6],
|
||||
'feature1': [1,2,3,4], 'feature2' : [2,3,4,5]}]
|
||||
'''
|
||||
|
||||
# extract target data from data_provider object
|
||||
target = np.array(data[0]['target'], dtype=float)
|
||||
|
||||
# put n feature data arrays into an n x m array of arrays
|
||||
features = np.column_stack([np.array(data[0][col])
|
||||
for col in feature_columns]).astype(float)
|
||||
|
||||
features, feature_means = replace_nan_with_mean(features)
|
||||
target, target_mean = replace_nan_with_mean(target)
|
||||
return target, features, target_mean, feature_means
|
||||
|
||||
|
||||
def replace_nan_with_mean(array, means=None):
|
||||
"""
|
||||
Input:
|
||||
@param array: an array of floats which may have null-valued entries
|
||||
@param array: an array of floats which may have null-valued
|
||||
entries
|
||||
Output:
|
||||
array with nans filled in with the mean of the dataset
|
||||
"""
|
||||
|
||||
# returns an array of rows and column indices
|
||||
indices = np.where(np.isnan(array))
|
||||
nanvals = np.isnan(array)
|
||||
indices = np.where(nanvals)
|
||||
|
||||
# iterate through entries which have nan values
|
||||
for row, col in zip(*indices):
|
||||
array[row, col] = np.mean(array[~np.isnan(array[:, col]), col])
|
||||
def loops(array, axis):
|
||||
try:
|
||||
return np.shape(array)[axis]
|
||||
except IndexError:
|
||||
return 1
|
||||
ran = loops(array, 1)
|
||||
|
||||
return array
|
||||
if means is None:
|
||||
means = {}
|
||||
|
||||
def get_data(variable, feature_columns, query):
|
||||
"""
|
||||
Fetch data from the database, clean, and package into
|
||||
numpy arrays
|
||||
Input:
|
||||
@param variable: name of the target variable
|
||||
@param feature_columns: list of column names
|
||||
@param query: subquery that data is pulled from for the packaging
|
||||
Output:
|
||||
prepared data, packaged into NumPy arrays
|
||||
"""
|
||||
if ran == 1:
|
||||
array = np.array(array)
|
||||
means[0] = np.mean(array[~np.isnan(array)])
|
||||
for row in zip(*indices):
|
||||
array[row] = means[0]
|
||||
else:
|
||||
for col in range(ran):
|
||||
means[col] = np.mean(array[~np.isnan(array[:, col]), col])
|
||||
for row, col in zip(*indices):
|
||||
array[row, col] = means[col]
|
||||
else:
|
||||
if ran == 1:
|
||||
for row in zip(*indices):
|
||||
array[row] = means[0]
|
||||
else:
|
||||
for row, col in zip(*indices):
|
||||
array[row, col] = means[col]
|
||||
|
||||
columns = ','.join(['array_agg("{col}") As "{col}"'.format(col=col) for col in feature_columns])
|
||||
|
||||
try:
|
||||
data = plpy.execute('''SELECT array_agg("{variable}") As target, {columns} FROM ({query}) As a'''.format(
|
||||
variable=variable,
|
||||
columns=columns,
|
||||
query=query))
|
||||
except Exception, e:
|
||||
plpy.error('Failed to access data to build segmentation model: %s' % e)
|
||||
|
||||
# extract target data from plpy object
|
||||
target = np.array(data[0]['target'])
|
||||
|
||||
# put n feature data arrays into an n x m array of arrays
|
||||
features = np.column_stack([np.array(data[0][col], dtype=float) for col in feature_columns])
|
||||
|
||||
return replace_nan_with_mean(target), replace_nan_with_mean(features)
|
||||
|
||||
# High level interface
|
||||
# --------------------
|
||||
|
||||
def create_and_predict_segment_agg(target, features, target_features, target_ids, model_parameters):
|
||||
"""
|
||||
Version of create_and_predict_segment that works on arrays that come stright form the SQL calling
|
||||
the function.
|
||||
|
||||
Input:
|
||||
@param target: The 1D array of lenth NSamples containing the target variable we want the model to predict
|
||||
@param features: Thw 2D array of size NSamples * NFeatures that form the imput to the model
|
||||
@param target_ids: A 1D array of target_ids that will be used to associate the results of the prediction with the rows which they come from
|
||||
@param model_parameters: A dictionary containing parameters for the model.
|
||||
"""
|
||||
|
||||
clean_target = replace_nan_with_mean(target)
|
||||
clean_features = replace_nan_with_mean(features)
|
||||
target_features = replace_nan_with_mean(target_features)
|
||||
|
||||
model, accuracy = train_model(clean_target, clean_features, model_parameters, 0.2)
|
||||
prediction = model.predict(target_features)
|
||||
accuracy_array = [accuracy]*prediction.shape[0]
|
||||
return zip(target_ids, prediction, np.full(prediction.shape, accuracy_array))
|
||||
|
||||
|
||||
|
||||
def create_and_predict_segment(query, variable, target_query, model_params):
|
||||
"""
|
||||
generate a segment with machine learning
|
||||
Stuart Lynn
|
||||
"""
|
||||
|
||||
## fetch column names
|
||||
try:
|
||||
columns = plpy.execute('SELECT * FROM ({query}) As a LIMIT 1 '.format(query=query))[0].keys()
|
||||
except Exception, e:
|
||||
plpy.error('Failed to build segmentation model: %s' % e)
|
||||
|
||||
## extract column names to be used in building the segmentation model
|
||||
feature_columns = set(columns) - set([variable, 'cartodb_id', 'the_geom', 'the_geom_webmercator'])
|
||||
## get data from database
|
||||
target, features = get_data(variable, feature_columns, query)
|
||||
|
||||
model, accuracy = train_model(target, features, model_params, 0.2)
|
||||
cartodb_ids, result = predict_segment(model, feature_columns, target_query)
|
||||
accuracy_array = [accuracy]*result.shape[0]
|
||||
return zip(cartodb_ids, result, accuracy_array)
|
||||
return array, means
|
||||
|
||||
|
||||
def train_model(target, features, model_params, test_split):
|
||||
"""
|
||||
Train the Gradient Boosting model on the provided data and calculate the accuracy of the model
|
||||
Train the Gradient Boosting model on the provided data to calculate
|
||||
the accuracy of the model
|
||||
Input:
|
||||
@param target: 1D Array of the variable that the model is to be trianed to predict
|
||||
@param features: 2D Array NSamples * NFeatures to use in trining the model
|
||||
@param model_params: A dictionary of model parameters, the full specification can be found on the
|
||||
scikit learn page for [GradientBoostingRegressor](http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.GradientBoostingRegressor.html)
|
||||
@parma test_split: The fraction of the data to be withheld for testing the model / calculating the accuray
|
||||
@param target: 1D Array of the variable that the model is to be
|
||||
trained to predict
|
||||
@param features: 2D Array NSamples *NFeatures to use in training
|
||||
the model
|
||||
@param model_params: A dictionary of model parameters, the full
|
||||
specification can be found on the
|
||||
scikit learn page for [GradientBoostingRegressor]
|
||||
(http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.GradientBoostingRegressor.html)
|
||||
@parma test_split: The fraction of the data to be withheld for
|
||||
testing the model / calculating the accuray
|
||||
"""
|
||||
features_train, features_test, target_train, target_test = train_test_split(features, target, test_size=test_split)
|
||||
features_train, features_test, \
|
||||
target_train, target_test = train_test_split(features, target,
|
||||
test_size=test_split)
|
||||
model = GradientBoostingRegressor(**model_params)
|
||||
model.fit(features_train, target_train)
|
||||
accuracy = calculate_model_accuracy(model, features, target)
|
||||
accuracy = calculate_model_accuracy(model, features_test, target_test)
|
||||
return model, accuracy
|
||||
|
||||
def calculate_model_accuracy(model, features, target):
|
||||
|
||||
def calculate_model_accuracy(model, features_test, target_test):
|
||||
"""
|
||||
Calculate the mean squared error of the model prediction
|
||||
Input:
|
||||
@param model: model trained from input features
|
||||
@param features: features to make a prediction from
|
||||
@param target: target to compare prediction to
|
||||
@param features_test: test features set to make prediction from
|
||||
@param target_test: test target set to compare predictions to
|
||||
Output:
|
||||
mean squared error of the model prection compared to the target
|
||||
mean squared error of the model prection compared target_test
|
||||
"""
|
||||
prediction = model.predict(features)
|
||||
return metrics.mean_squared_error(prediction, target)
|
||||
|
||||
def predict_segment(model, features, target_query):
|
||||
"""
|
||||
Use the provided model to predict the values for the new feature set
|
||||
Input:
|
||||
@param model: The pretrained model
|
||||
@features: A list of features to use in the model prediction (list of column names)
|
||||
@target_query: The query to run to obtain the data to predict on and the cartdb_ids associated with it.
|
||||
"""
|
||||
|
||||
batch_size = 1000
|
||||
joined_features = ','.join(['"{0}"::numeric'.format(a) for a in features])
|
||||
|
||||
try:
|
||||
cursor = plpy.cursor('SELECT Array[{joined_features}] As features FROM ({target_query}) As a'.format(
|
||||
joined_features=joined_features,
|
||||
target_query=target_query))
|
||||
except Exception, e:
|
||||
plpy.error('Failed to build segmentation model: %s' % e)
|
||||
|
||||
results = []
|
||||
|
||||
while True:
|
||||
rows = cursor.fetch(batch_size)
|
||||
if not rows:
|
||||
break
|
||||
batch = np.row_stack([np.array(row['features'], dtype=float) for row in rows])
|
||||
|
||||
#Need to fix this. Should be global mean. This will cause weird effects
|
||||
batch = replace_nan_with_mean(batch)
|
||||
prediction = model.predict(batch)
|
||||
results.append(prediction)
|
||||
|
||||
try:
|
||||
cartodb_ids = plpy.execute('''SELECT array_agg(cartodb_id ORDER BY cartodb_id) As cartodb_ids FROM ({0}) As a'''.format(target_query))[0]['cartodb_ids']
|
||||
except Exception, e:
|
||||
plpy.error('Failed to build segmentation model: %s' % e)
|
||||
|
||||
return cartodb_ids, np.concatenate(results)
|
||||
prediction = model.predict(features_test)
|
||||
return metrics.mean_squared_error(prediction, target_test)
|
||||
|
1
src/py/crankshaft/test/fixtures/data.json
vendored
Normal file
1
src/py/crankshaft/test/fixtures/data.json
vendored
Normal file
@ -0,0 +1 @@
|
||||
[{"ids": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100]}]
|
1
src/py/crankshaft/test/fixtures/model_data.json
vendored
Normal file
1
src/py/crankshaft/test/fixtures/model_data.json
vendored
Normal file
File diff suppressed because one or more lines are too long
1
src/py/crankshaft/test/fixtures/predict_data.json
vendored
Normal file
1
src/py/crankshaft/test/fixtures/predict_data.json
vendored
Normal file
@ -0,0 +1 @@
|
||||
[{"features": [[0.97297640975099997, 0.48162847641900003], [0.94720885324100002, 0.92519926071899994], [0.8264217730079999, 0.19415235826499999], [0.40411132589500004, 0.38843702575499994], [0.33854978708899996, 0.13416364950200002], [0.35582490007299999, 0.096314795897899999], [0.68616157039699999, 0.41675745974799999], [0.23344213791599999, 0.71210113960199994], [0.187353852663, 0.35711991569799995], [0.056479941924700003, 0.80824517339399993], [0.75088916614400003, 0.572151234131], [0.50246103346500004, 0.49712099904000001], [0.33471066946899997, 0.14859628011499998], [0.60793888599400003, 0.87417901532800002], [0.42749238417400004, 0.097680579671199988], [0.17386041095400001, 0.950866317121], [0.69179991520299999, 0.62516476948499999], [0.84292065094699997, 0.19294979300599999], [0.797120458074, 0.058631100303900001], [0.39566713420500005, 0.96256889448799998], [0.41760069426200003, 0.16947610752799999], [0.353538060524, 0.89931759966399993], [0.84031337913499993, 0.74075899320899996], [0.251836934939, 0.63771637374599999], [0.26998589843100002, 0.62860482510299998], [0.22862387681599999, 0.55551316083899993], [0.154559223986, 0.42489947463699995], [0.88445238717300001, 0.041340049733599997], [0.34388085383, 0.79776848695500002], [0.026095950094300002, 0.53555632848900003], [0.22821389194000002, 0.67315914298199997], [0.35382259735100002, 0.073131088591399995], [0.11108504124299999, 0.58760350502699998], [0.30541724734000003, 0.45383730649300003], [0.63908476061200004, 0.299226707285], [0.060675331022100001, 0.024030363590099999], [0.37411573949100002, 0.48261926695399998], [0.68008712032199992, 0.74278227822500009], [0.81078283291600006, 0.73578148610100003], [0.11804084458900001, 0.67352047988600006], [0.23648198865299999, 0.54946520524499998], [0.56246138984399996, 0.96654913930600006], [0.76249437673899989, 0.450702223969], [0.92400286800699993, 0.56661809273999997], [0.413103712525, 0.36844168088399998], [0.29401694488200003, 0.32987052741599998], [0.57119587292700003, 0.49035651293100002], [0.74037242300799999, 0.28066938607500003], [0.32431146912199998, 0.85648642227799998], [0.61177259413700003, 0.26440014588299998], [0.38144483824199998, 0.229178471927], [0.61478912278999998, 0.0332792237179], [0.39506149161100002, 0.81640329154900004], [0.92309519151199992, 0.66076039597499991], [0.737615452201, 0.235135236961], [0.64368138068500003, 0.40983272801299997], [0.96011821941400011, 0.48294852537400002], [0.81397312427699997, 0.694266791868], [0.16472588926500001, 0.79136948682200003], [0.62538739162000001, 0.58352242713799995], [0.586709961429, 0.52040796275799994], [0.30920667095499998, 0.54516843627099998], [0.83584993804700003, 0.49695224123699999], [0.28690881649200001, 0.99925119035900001], [0.26984583321200001, 0.940321403748], [0.87338723457800005, 0.80176187934499998], [0.95559172429499994, 0.45685424792700002], [0.39529067978400001, 0.89633782936100004], [0.98180058338499998, 0.36730602102700005], [0.50137731568599997, 0.92606654021300006], [0.72742655604899997, 0.376662449392], [0.16354554153799999, 0.12541796540399999], [0.88408208085500006, 0.10330853879799999], [0.43795633263400002, 0.35816882957900004], [0.61596499625299994, 0.31988646331699999], [0.295636219571, 0.63494760383299997], [0.57552353033299997, 0.012257362386], [0.79858186865700009, 0.225066238365], [0.55429278557100004, 0.73526463041500001], [0.447685806932, 0.67143491554699997], [0.42497690916399999, 0.182660253854], [0.492227688665, 0.16444651805500002], [0.46338713581500002, 0.46654784851499997], [0.55861373285899996, 0.73855313091300001], [0.147442147025, 0.15347305926800001], [0.87376257594500006, 0.54099499795700001], [0.38871958895900005, 0.94920731516299994], [0.37621131464300001, 0.335776604315], [0.59968417891600001, 0.33715395376199997], [0.54422177453599996, 0.598089524373], [0.82236256657000006, 0.44986426296600002], [0.638234177239, 0.48084368437299996], [0.50381001662400005, 0.300645579637], [0.71373630162799995, 0.61474740630800007], [0.039538912615400004, 0.60759494735999997], [0.62109308806700003, 0.26068279551199997], [0.080795357754100003, 0.40753672692800003], [0.61108858759999996, 0.79972473220100004], [0.67134808431199999, 0.10437712573499999], [0.10547807725199999, 0.0058468954790699993]]}]
|
1
src/py/crankshaft/test/fixtures/segmentation_result.json
vendored
Normal file
1
src/py/crankshaft/test/fixtures/segmentation_result.json
vendored
Normal file
@ -0,0 +1 @@
|
||||
[[4.6399276705817796, 0.0052868236922298225], [5.115554441401355, 0.0052868236922298225], [3.9279922238303424, 0.0052868236922298225], [3.3819641948267578, 0.0052868236922298225], [2.9132843041389509, 0.0052868236922298225], [2.876066696867833, 0.0052868236922298225], [4.0106272888112651, 0.0052868236922298225], [3.5783652270475974, 0.0052868236922298225], [2.9165716286821199, 0.0052868236922298225], [3.4108311334783568, 0.0052868236922298225], [4.3202132937804372, 0.0052868236922298225], [3.7479855400737048, 0.0052868236922298225], [2.9370765208742595, 0.0052868236922298225], [4.4630858731319449, 0.0052868236922298225], [2.9921697215186938, 0.0052868236922298225], [3.7783567974677217, 0.0052868236922298225], [4.2514291487926652, 0.0052868236922298225], [3.9658039808720535, 0.0052868236922298225], [3.723696295039459, 0.0052868236922298225], [4.2305764993690955, 0.0052868236922298225], [3.1241034993855421, 0.0052868236922298225], [4.0343877737948652, 0.0052868236922298225], [4.7864094703726359, 0.0052868236922298225], [3.4423141823770624, 0.0052868236922298225], [3.424225241703863, 0.0052868236922298225], [3.309201541170641, 0.0052868236922298225], [3.037867375630356, 0.0052868236922298225], [3.8380172470256544, 0.0052868236922298225], [3.8840548342704815, 0.0052868236922298225], [2.8781306594987903, 0.0052868236922298225], [3.4874554940106037, 0.0052868236922298225], [2.8254928573623284, 0.0052868236922298225], [3.0980811019970185, 0.0052868236922298225], [3.3153313414322114, 0.0052868236922298225], [3.7254807947737478, 0.0052868236922298225], [2.2352532389466111, 0.0052868236922298225], [3.398793991587584, 0.0052868236922298225], [4.393489711684496, 0.0052868236922298225], [4.6820658816158236, 0.0052868236922298225], [3.2930725801147198, 0.0052868236922298225], [3.3013108011535843, 0.0052868236922298225], [4.5169704979664962, 0.0052868236922298225], [4.2356395759837682, 0.0052868236922298225], [4.685867240919821, 0.0052868236922298225], [3.3666476683180364, 0.0052868236922298225], [3.1633810641520688, 0.0052868236922298225], [3.9284828602074846, 0.0052868236922298225], [3.8813794254923417, 0.0052868236922298225], [3.9767682468020018, 0.0052868236922298225], [3.6296971637437938, 0.0052868236922298225], [3.2336758867109574, 0.0052868236922298225], [3.3438434216857305, 0.0052868236922298225], [4.059745940545219, 0.0052868236922298225], [4.8003413624883429, 0.0052868236922298225], [3.8343150532526087, 0.0052868236922298225], [3.8884993452951977, 0.0052868236922298225], [4.5967216279010819, 0.0052868236922298225], [4.6317641832280811, 0.0052868236922298225], [3.5805166062443643, 0.0052868236922298225], [4.1049176867051367, 0.0052868236922298225], [3.9515389747788823, 0.0052868236922298225], [3.4250648002120125, 0.0052868236922298225], [4.4759157545508605, 0.0052868236922298225], [4.0134207861425963, 0.0052868236922298225], [3.8799241476802888, 0.0052868236922298225], [4.9781411173602796, 0.0052868236922298225], [4.5230126868924323, 0.0052868236922298225], [4.1529682867170568, 0.0052868236922298225], [4.4754108304977711, 0.0052868236922298225], [4.3132882554878655, 0.0052868236922298225], [4.0547786635287659, 0.0052868236922298225], [2.5688836012215037, 0.0052868236922298225], [3.889152819366271, 0.0052868236922298225], [3.3884811287288952, 0.0052868236922298225], [3.8286491083541225, 0.0052868236922298225], [3.4842580970352057, 0.0052868236922298225], [3.2207170727086329, 0.0052868236922298225], [3.9452244740355038, 0.0052868236922298225], [4.2400946327715978, 0.0052868236922298225], [3.8398869646230049, 0.0052868236922298225], [3.1242158541684319, 0.0052868236922298225], [3.2123888635213436, 0.0052868236922298225], [3.5900402737995578, 0.0052868236922298225], [4.2464905311370957, 0.0052868236922298225], [2.5886568078161565, 0.0052868236922298225], [4.6008521636045012, 0.0052868236922298225], [4.2038409929353815, 0.0052868236922298225], [3.3327313501720157, 0.0052868236922298225], [3.7948100469546913, 0.0052868236922298225], [4.0382728370257404, 0.0052868236922298225], [4.3126973580418575, 0.0052868236922298225], [3.976738340646583, 0.0052868236922298225], [3.4720389796281514, 0.0052868236922298225], [4.3014283833530316, 0.0052868236922298225], [3.0187012207036723, 0.0052868236922298225], [3.6486981350943344, 0.0052868236922298225], [2.8338354315095078, 0.0052868236922298225], [4.3507896147137961, 0.0052868236922298225], [3.4753809797796484, 0.0052868236922298225], [2.2399367208816638, 0.0052868236922298225]]
|
1
src/py/crankshaft/test/fixtures/true_result.json
vendored
Normal file
1
src/py/crankshaft/test/fixtures/true_result.json
vendored
Normal file
@ -0,0 +1 @@
|
||||
[[[4.4227215674645395]], [[5.2712118012993789]], [[3.6279373760418334]], [[3.38304104035302]], [[2.7761519796383083]], [[2.7263669419052903]], [[3.862757275091802]], [[3.7743654860778144]], [[2.9952706103894648]], [[3.7012102596745233]], [[4.2706362174772199]], [[3.7479335482775493]], [[2.7992585644337975]], [[4.6602663596480252]], [[2.8365997356035244]], [[4.1625232506719607]], [[4.288029411774362]], [[3.6502805624336396]], [[3.312942887719065]], [[4.5186384902849328]], [[2.9653532564494514]], [[4.3289422901142238]], [[4.7419880551200571]], [[3.6531881499003931]], [[3.6621884978514769]], [[3.4539621369025717]], [[3.0816377852518206]], [[3.4093586802263656]], [[4.1113582546549052]], [[3.1102565821185824]], [[3.6886391238733465]], [[2.6769960732095788]], [[3.3418345719183726]], [[3.3658004839965203]], [[3.5570805554883793]], [[2.1390737237132882]], [[3.5264121431452518]], [[4.5056952369329686]], [[4.6877372215758752]], [[3.5241022266554354]], [[3.4536533934696991]], [[4.7767903633790905]], [[4.0451460130466712]], [[4.5192404874918441]], [[3.3565389305543119]], [[3.1007664721556902]], [[3.837506835252591]], [[3.6718974066615448]], [[4.1994400482374701]], [[3.4464591829709863]], [[3.0305242012162878]], [[2.988742131620918]], [[4.2253988205149868]], [[4.7061635792179537]], [[3.5766936522234265]], [[3.7851875270538882]], [[4.4060743798682109]], [[4.6094932701511038]], [[3.8298278075415855]], [[4.1051259417055608]], [[3.9208808676586342]], [[3.5541468789732118]], [[4.2476793895442491]], [[4.4288656054562781]], [[4.285411557315129]], [[4.9136046105564342]], [[4.3470960822962557]], [[4.3856116783980914]], [[4.2073129171306984]], [[4.6041990539557842]], [[3.8444647328578898]], [[2.4961542431159094]], [[3.5327401988792424]], [[3.3732721581082883]], [[3.5637204210138624]], [[3.713349537021855]], [[2.8878000202718845]], [[3.6480052797146962]], [[4.3019684391870783]], [[4.0143985414914329]], [[3.0027858714530842]], [[3.0672345691071476]], [[3.6281764007528063]], [[4.315026861113993]], [[2.5281093390733806]], [[4.3926338598315251]], [[4.4814940137640589]], [[3.2358701805945751]], [[3.5738341758988197]], [[4.0125117105508474]], [[4.1332723757858041]], [[3.9190386346055655]], [[3.3570061842111683]], [[4.3000992650570122]], [[3.2744982636432503]], [[3.4530052231252344]], [[2.9362664904878524]], [[4.5160823458017774]], [[3.2157763779380728]], [[2.1699109068357223]]]
|
@ -43,6 +43,7 @@ class MockPlPy:
|
||||
self.infos.append(msg)
|
||||
|
||||
def error(self, msg):
|
||||
self.infos.append(msg)
|
||||
self.notices.append(msg)
|
||||
|
||||
def cursor(self, query):
|
||||
|
@ -1,64 +1,139 @@
|
||||
"""Tests for segmentation functionality"""
|
||||
import unittest
|
||||
import numpy as np
|
||||
from helper import plpy, fixture_file
|
||||
import crankshaft.segmentation as segmentation
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from crankshaft.analysis_data_provider import AnalysisDataProvider
|
||||
from crankshaft.segmentation import Segmentation
|
||||
from helper import fixture_file
|
||||
from mock_plpy import MockCursor
|
||||
|
||||
|
||||
class RawDataProvider(AnalysisDataProvider):
|
||||
"""Data Provider to overwrite the default SQL provider"""
|
||||
def __init__(self, data, model, predict):
|
||||
self.data = data
|
||||
self.model = model
|
||||
self.predict = predict
|
||||
|
||||
def get_segmentation_data(self, params): # pylint: disable=unused-argument
|
||||
"""return data"""
|
||||
return self.data
|
||||
|
||||
def get_segmentation_model_data(self, params): # pylint: disable=W0613
|
||||
"""return model data"""
|
||||
return self.model
|
||||
|
||||
def get_segmentation_predict_data(self, params): # pylint: disable=W0613
|
||||
"""return predict data"""
|
||||
return self.predict
|
||||
|
||||
|
||||
class SegmentationTest(unittest.TestCase):
|
||||
"""Testing class for Moran's I functions"""
|
||||
"""Testing class for Segmentation functions"""
|
||||
|
||||
def setUp(self):
|
||||
plpy._reset()
|
||||
|
||||
def generate_random_data(self,n_samples,random_state, row_type=False):
|
||||
x1 = random_state.uniform(size=n_samples)
|
||||
x2 = random_state.uniform(size=n_samples)
|
||||
x3 = random_state.randint(0, 4, size=n_samples)
|
||||
|
||||
y = x1+x2*x2+x3
|
||||
cartodb_id = range(len(x1))
|
||||
|
||||
if row_type:
|
||||
return [ {'features': vals} for vals in zip(x1,x2,x3)], y
|
||||
else:
|
||||
return [dict( zip(['x1','x2','x3','target', 'cartodb_id'],[x1,x2,x3,y,cartodb_id]))]
|
||||
self.params = {
|
||||
"query": 'SELECT * FROM segmentation_data',
|
||||
"variable": 'price',
|
||||
"feature_columns": ['m1', 'm2', 'm3', 'm4', 'm5', 'm6'],
|
||||
"target_query": 'SELECT * FROM segmentation_result',
|
||||
"id_col": 'cartodb_id',
|
||||
"model_params": {
|
||||
'n_estimators': 1200,
|
||||
'max_depth': 3,
|
||||
'subsample': 0.5,
|
||||
'learning_rate': 0.01,
|
||||
'min_samples_leaf': 1
|
||||
}
|
||||
}
|
||||
self.model_data = json.loads(
|
||||
open(fixture_file('model_data.json')).read())
|
||||
self.data = json.loads(
|
||||
open(fixture_file('data.json')).read())
|
||||
self.predict_data = json.loads(
|
||||
open(fixture_file('predict_data.json')).read())
|
||||
self.result_seg = json.loads(
|
||||
open(fixture_file('segmentation_result.json')).read())
|
||||
self.true_result = json.loads(
|
||||
open(fixture_file('true_result.json')).read())
|
||||
|
||||
def test_replace_nan_with_mean(self):
|
||||
"""test segmentation.test_replace_nan_with_mean"""
|
||||
from crankshaft.segmentation import replace_nan_with_mean
|
||||
test_array = np.array([1.2, np.nan, 3.2, np.nan, np.nan])
|
||||
result = replace_nan_with_mean(test_array, means=None)[0]
|
||||
expectation = np.array([1.2, 2.2, 3.2, 2.2, 2.2], dtype=float)
|
||||
self.assertItemsEqual(result, expectation)
|
||||
|
||||
def test_create_and_predict_segment(self):
|
||||
n_samples = 1000
|
||||
"""test segmentation.test_create_and_predict"""
|
||||
from crankshaft.segmentation import replace_nan_with_mean
|
||||
results = []
|
||||
feature_columns = ['m1', 'm2']
|
||||
feat = np.column_stack([np.array(self.model_data[0][col])
|
||||
for col in feature_columns]).astype(float)
|
||||
feature_means = replace_nan_with_mean(feat)[1]
|
||||
|
||||
random_state_train = np.random.RandomState(13)
|
||||
random_state_test = np.random.RandomState(134)
|
||||
training_data = self.generate_random_data(n_samples, random_state_train)
|
||||
test_data, test_y = self.generate_random_data(n_samples, random_state_test, row_type=True)
|
||||
# data_model is of the form:
|
||||
# [OrderedDict([('target', target),
|
||||
# ('features', feat),
|
||||
# ('target_mean', target_mean),
|
||||
# ('feature_means', feature_means),
|
||||
# ('feature_columns', feature_columns)])]
|
||||
data_model = self.model_data
|
||||
cursor = self.predict_data
|
||||
batch = []
|
||||
|
||||
batches = np.row_stack([np.array(row['features'])
|
||||
for row in cursor]).astype(float)
|
||||
batches = replace_nan_with_mean(batches, feature_means)[0]
|
||||
batch.append(batches)
|
||||
|
||||
ids = [{'cartodb_ids': range(len(test_data))}]
|
||||
rows = [{'x1': 0,'x2':0,'x3':0,'y':0,'cartodb_id':0}]
|
||||
data_predict = [OrderedDict([('features', d['features']),
|
||||
('batch', batch)])
|
||||
for d in self.predict_data]
|
||||
data_predict = MockCursor(data_predict)
|
||||
|
||||
plpy._define_result('select \* from \(select \* from training\) a limit 1',rows)
|
||||
plpy._define_result('.*from \(select \* from training\) as a' ,training_data)
|
||||
plpy._define_result('select array_agg\(cartodb\_id order by cartodb\_id\) as cartodb_ids from \(.*\) a',ids)
|
||||
plpy._define_result('.*select \* from test.*' ,test_data)
|
||||
model_parameters = {
|
||||
'n_estimators': 1200,
|
||||
'max_depth': 3,
|
||||
'subsample': 0.5,
|
||||
'learning_rate': 0.01,
|
||||
'min_samples_leaf': 1
|
||||
}
|
||||
data = [OrderedDict([('ids', d['ids'])])
|
||||
for d in self.data]
|
||||
|
||||
model_parameters = {'n_estimators': 1200,
|
||||
'max_depth': 3,
|
||||
'subsample' : 0.5,
|
||||
'learning_rate': 0.01,
|
||||
'min_samples_leaf': 1}
|
||||
seg = Segmentation(RawDataProvider(data, data_model,
|
||||
data_predict))
|
||||
|
||||
result = segmentation.create_and_predict_segment(
|
||||
'select * from training',
|
||||
'target',
|
||||
'select * from test',
|
||||
model_parameters)
|
||||
result = seg.create_and_predict_segment(
|
||||
'SELECT * FROM segmentation_test',
|
||||
'x_value',
|
||||
['m1', 'm2'],
|
||||
'SELECT * FROM segmentation_result',
|
||||
model_parameters,
|
||||
id_col='cartodb_id')
|
||||
results = [(row[1], row[2]) for row in result]
|
||||
zipped_values = zip(results, self.result_seg)
|
||||
pre_res = [r[0] for r in self.true_result]
|
||||
acc_res = [r[1] for r in self.result_seg]
|
||||
|
||||
prediction = [r[1] for r in result]
|
||||
# test values
|
||||
for (res_pre, _), (exp_pre, _) in zipped_values:
|
||||
diff = abs(res_pre - exp_pre) / np.mean([res_pre, exp_pre])
|
||||
self.assertTrue(diff <= 0.05, msg='diff: {}'.format(diff))
|
||||
diff = abs(res_pre - exp_pre) / np.mean([res_pre, exp_pre])
|
||||
self.assertTrue(diff <= 0.05, msg='diff: {}'.format(diff))
|
||||
prediction = [r[0] for r in results]
|
||||
|
||||
accuracy =np.sqrt(np.mean( np.square( np.array(prediction) - np.array(test_y))))
|
||||
accuracy = np.sqrt(np.mean(
|
||||
(np.array(prediction) - np.array(pre_res))**2
|
||||
))
|
||||
|
||||
self.assertEqual(len(result),len(test_data))
|
||||
self.assertTrue( result[0][2] < 0.01)
|
||||
self.assertTrue( accuracy < 0.5*np.mean(test_y) )
|
||||
self.assertEqual(len(results), len(self.result_seg))
|
||||
self.assertTrue(accuracy < 0.3 * np.mean(pre_res))
|
||||
self.assertTrue(results[0][1] < 0.01)
|
||||
|
Loading…
Reference in New Issue
Block a user