prototype of model writing

This commit is contained in:
Andy Eschbacher 2018-03-16 16:21:00 -04:00
parent 8c4057bb7a
commit c2be340c07
3 changed files with 58 additions and 9 deletions

View File

@ -42,6 +42,7 @@ CREATE OR REPLACE FUNCTION
query TEXT,
variable_name TEXT,
target_table TEXT,
model_name text DEFAULT NULL,
n_estimators INTEGER DEFAULT 1200,
max_depth INTEGER DEFAULT 3,
subsample DOUBLE PRECISION DEFAULT 0.5,
@ -58,24 +59,59 @@ AS $$
'learning_rate': learning_rate,
'min_samples_leaf': min_samples_leaf
}
feature_cols = set(plpy.execute('''
all_cols = list(plpy.execute('''
select * from ({query}) as _w limit 0
'''.format(query=query)).colnames()) - set([variable_name, 'cartodb_id', ])
'''.format(query=query)).colnames())
feature_cols = [a for a in all_cols
if a not in [variable_name, 'cartodb_id', ]]
return seg.create_and_predict_segment(
query,
variable_name,
feature_cols,
target_table,
model_params
model_params,
model_name=model_name
)
$$ LANGUAGE plpythonu VOLATILE PARALLEL UNSAFE;
CREATE OR REPLACE FUNCTION
CDB_RetrieveModelParams(
model_name text,
param_name text
)
RETURNS TABLE(param numeric) AS $$
import pickle
from collections import Iterable
plan = plpy.prepare('''
SELECT model FROM model_storage
WHERE name = $1;
''', ['text', ])
try:
model_encoded = plpy.execute(plan, [model_name, ])
except plpy.SPIError as err:
plpy.error('ERROR: {}'.format(err))
model = pickle.loads(
model_encoded[0]['model']
)
res = getattr(model, param_name)
if not isinstance(res, Iterable):
raise Exception('Cannot return `{}` as a table'.format(param_name))
return res
$$ LANGUAGE plpythonu VOLATILE PARALLEL UNSAFE;
CREATE OR REPLACE FUNCTION
CDB_CreateAndPredictSegment(
query TEXT,
variable TEXT,
feature_columns TEXT[],
target_query TEXT,
model_name TEXT DEFAULT NULL,
n_estimators INTEGER DEFAULT 1200,
max_depth INTEGER DEFAULT 3,
subsample DOUBLE PRECISION DEFAULT 0.5,
@ -97,6 +133,7 @@ AS $$
variable,
feature_columns,
target_query,
model_params
model_params,
model_name=model_name
)
$$ LANGUAGE plpythonu VOLATILE PARALLEL UNSAFE;

View File

@ -2,11 +2,14 @@
Segmentation creation and prediction
"""
import pickle
import plpy
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor
from sklearn import metrics
from sklearn.cross_validation import train_test_split
from crankshaft.analysis_data_provider import AnalysisDataProvider
from crankshaft import model_storage
# NOTE: added optional param here
@ -51,6 +54,7 @@ class Segmentation(object):
def create_and_predict_segment(self, query, variable, feature_columns,
target_query, model_params,
model_name=None,
id_col='cartodb_id'):
"""
generate a segment with machine learning
@ -70,15 +74,23 @@ class Segmentation(object):
(target, features, target_mean,
feature_means) = self.clean_data(query, variable, feature_columns)
model, accuracy = train_model(target, features, model_params, 0.2)
model_storage.create_model_table()
# find model if it exists and is specified
if model_name is not None:
model = model_storage.get_model(model_name)
if locals().get('model') is None:
model, accuracy = train_model(target, features, model_params, 0.2)
result = self.predict_segment(model, feature_columns, target_query,
feature_means)
accuracy_array = [accuracy] * result.shape[0]
rowid = self.data_provider.get_segmentation_data(params)
'''
rowid = [{'ids': [2.9, 4.9, 4, 5, 6]}]
'''
# store the model for later use
model_storage.set_model(model, model_name, feature_columns)
return zip(rowid[0]['ids'], result, accuracy_array)
def predict_segment(self, model, feature_columns, target_query,

View File

@ -41,7 +41,7 @@ setup(
# The choice of component versions is dictated by what's
# provisioned in the production servers.
# IMPORTANT NOTE: please don't change this line. Instead issue a ticket to systems for evaluation.
install_requires=['joblib==0.8.3', 'numpy==1.6.1', 'scipy==0.14.0', 'pysal==1.14.3', 'scikit-learn==0.14.1'],
install_requires=['joblib==0.8.3', 'numpy==1.6.1', 'scipy==0.14.0', 'pysal==1.14.3', 'scikit-learn==0.14.1', 'petname==2.2'],
requires=['pysal', 'numpy', 'sklearn'],