prototype of model writing
This commit is contained in:
parent
8c4057bb7a
commit
c2be340c07
@ -42,6 +42,7 @@ CREATE OR REPLACE FUNCTION
|
||||
query TEXT,
|
||||
variable_name TEXT,
|
||||
target_table TEXT,
|
||||
model_name text DEFAULT NULL,
|
||||
n_estimators INTEGER DEFAULT 1200,
|
||||
max_depth INTEGER DEFAULT 3,
|
||||
subsample DOUBLE PRECISION DEFAULT 0.5,
|
||||
@ -58,24 +59,59 @@ AS $$
|
||||
'learning_rate': learning_rate,
|
||||
'min_samples_leaf': min_samples_leaf
|
||||
}
|
||||
feature_cols = set(plpy.execute('''
|
||||
all_cols = list(plpy.execute('''
|
||||
select * from ({query}) as _w limit 0
|
||||
'''.format(query=query)).colnames()) - set([variable_name, 'cartodb_id', ])
|
||||
'''.format(query=query)).colnames())
|
||||
feature_cols = [a for a in all_cols
|
||||
if a not in [variable_name, 'cartodb_id', ]]
|
||||
return seg.create_and_predict_segment(
|
||||
query,
|
||||
variable_name,
|
||||
feature_cols,
|
||||
target_table,
|
||||
model_params
|
||||
model_params,
|
||||
model_name=model_name
|
||||
)
|
||||
$$ LANGUAGE plpythonu VOLATILE PARALLEL UNSAFE;
|
||||
|
||||
-- CDB_RetrieveModelParams: return one attribute of a stored model as a
-- single-column table.
--
-- Arguments:
--   model_name  value matched against the `name` column of `model_storage`
--   param_name  name of the attribute to read from the unpickled model
--
-- Returns one `param` row per element of the attribute. The attribute must
-- be iterable; otherwise an exception is raised.
--
-- Errors:
--   * the lookup query fails           -> plpy.error with the SPI message
--   * no row exists for `model_name`   -> plpy.error (previously surfaced as
--                                         a bare IndexError)
--   * the attribute is not iterable    -> Exception
CREATE OR REPLACE FUNCTION
  CDB_RetrieveModelParams(
    model_name text,
    param_name text
  )
RETURNS TABLE(param numeric) AS $$

  import pickle
  from collections import Iterable

  plan = plpy.prepare('''
      SELECT model FROM model_storage
      WHERE name = $1;
  ''', ['text', ])

  try:
      model_encoded = plpy.execute(plan, [model_name, ])
  except plpy.SPIError as err:
      plpy.error('ERROR: {}'.format(err))

  # Guard against an unknown model name: indexing an empty result set would
  # otherwise fail with an uninformative IndexError.
  if len(model_encoded) == 0:
      plpy.error('ERROR: model `{}` not found in model_storage'.format(
          model_name))

  # NOTE(security): pickle.loads can execute arbitrary code while
  # deserializing. This is only safe as long as write access to
  # `model_storage` is restricted to trusted roles.
  model = pickle.loads(
      model_encoded[0]['model']
  )

  res = getattr(model, param_name)
  # The function returns a table, so the attribute has to be iterable.
  if not isinstance(res, Iterable):
      raise Exception('Cannot return `{}` as a table'.format(param_name))
  return res

$$ LANGUAGE plpythonu VOLATILE PARALLEL UNSAFE;
|
||||
|
||||
CREATE OR REPLACE FUNCTION
|
||||
CDB_CreateAndPredictSegment(
|
||||
query TEXT,
|
||||
variable TEXT,
|
||||
feature_columns TEXT[],
|
||||
target_query TEXT,
|
||||
model_name TEXT DEFAULT NULL,
|
||||
n_estimators INTEGER DEFAULT 1200,
|
||||
max_depth INTEGER DEFAULT 3,
|
||||
subsample DOUBLE PRECISION DEFAULT 0.5,
|
||||
@ -97,6 +133,7 @@ AS $$
|
||||
variable,
|
||||
feature_columns,
|
||||
target_query,
|
||||
model_params
|
||||
model_params,
|
||||
model_name=model_name
|
||||
)
|
||||
$$ LANGUAGE plpythonu VOLATILE PARALLEL UNSAFE;
|
||||
|
@ -2,11 +2,14 @@
|
||||
Segmentation creation and prediction
|
||||
"""
|
||||
|
||||
import pickle
|
||||
import plpy
|
||||
import numpy as np
|
||||
from sklearn.ensemble import GradientBoostingRegressor
|
||||
from sklearn import metrics
|
||||
from sklearn.cross_validation import train_test_split
|
||||
from crankshaft.analysis_data_provider import AnalysisDataProvider
|
||||
from crankshaft import model_storage
|
||||
|
||||
# NOTE: added optional param here
|
||||
|
||||
@ -51,6 +54,7 @@ class Segmentation(object):
|
||||
|
||||
def create_and_predict_segment(self, query, variable, feature_columns,
|
||||
target_query, model_params,
|
||||
model_name=None,
|
||||
id_col='cartodb_id'):
|
||||
"""
|
||||
generate a segment with machine learning
|
||||
@ -70,15 +74,23 @@ class Segmentation(object):
|
||||
(target, features, target_mean,
|
||||
feature_means) = self.clean_data(query, variable, feature_columns)
|
||||
|
||||
model, accuracy = train_model(target, features, model_params, 0.2)
|
||||
model_storage.create_model_table()
|
||||
|
||||
# find model if it exists and is specified
|
||||
if model_name is not None:
|
||||
model = model_storage.get_model(model_name)
|
||||
|
||||
if locals().get('model') is None:
|
||||
model, accuracy = train_model(target, features, model_params, 0.2)
|
||||
|
||||
result = self.predict_segment(model, feature_columns, target_query,
|
||||
feature_means)
|
||||
accuracy_array = [accuracy] * result.shape[0]
|
||||
|
||||
rowid = self.data_provider.get_segmentation_data(params)
|
||||
'''
|
||||
rowid = [{'ids': [2.9, 4.9, 4, 5, 6]}]
|
||||
'''
|
||||
|
||||
# store the model for later use
|
||||
model_storage.set_model(model, model_name, feature_columns)
|
||||
return zip(rowid[0]['ids'], result, accuracy_array)
|
||||
|
||||
def predict_segment(self, model, feature_columns, target_query,
|
||||
|
@ -41,7 +41,7 @@ setup(
|
||||
# The choice of component versions is dictated by what's
|
||||
# provisioned in the production servers.
|
||||
# IMPORTANT NOTE: please don't change this line. Instead issue a ticket to systems for evaluation.
|
||||
install_requires=['joblib==0.8.3', 'numpy==1.6.1', 'scipy==0.14.0', 'pysal==1.14.3', 'scikit-learn==0.14.1'],
|
||||
install_requires=['joblib==0.8.3', 'numpy==1.6.1', 'scipy==0.14.0', 'pysal==1.14.3', 'scikit-learn==0.14.1', 'petname==2.2'],
|
||||
|
||||
requires=['pysal', 'numpy', 'sklearn'],
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user