mirror of
https://github.com/CartoDB/crankshaft.git
synced 2024-11-01 10:20:48 +08:00
edits to clean up code
This commit is contained in:
parent
ee723aa3dc
commit
9c2f68fcaf
@ -12,6 +12,9 @@ from crankshaft.analysis_data_provider import AnalysisDataProvider
|
||||
# NOTE: added optional param here
|
||||
|
||||
class Segmentation(object):
|
||||
"""
|
||||
Add docstring
|
||||
"""
|
||||
|
||||
def __init__(self, data_provider=None):
|
||||
if data_provider is None:
|
||||
@ -67,8 +70,9 @@ class Segmentation(object):
|
||||
params = {"subquery": target_query,
|
||||
"id_col": id_col}
|
||||
|
||||
target, features, target_mean,
|
||||
feature_means = clean_data(variable, feature_columns, query)
|
||||
target, features, target_mean, \
|
||||
feature_means = self.clean_data(variable, feature_columns, query)
|
||||
|
||||
model, accuracy = train_model(target, features, model_params, 0.2)
|
||||
result = self.predict_segment(model, feature_columns, target_query,
|
||||
feature_means)
|
||||
@ -112,7 +116,10 @@ class Segmentation(object):
|
||||
return np.concatenate(results)
|
||||
|
||||
|
||||
def clean_data(self, query, variable, feature_columns):
|
||||
def clean_data(self, query, variable, feature_columns):
|
||||
"""
|
||||
Add docstring
|
||||
"""
|
||||
params = {"subquery": query,
|
||||
"target": variable,
|
||||
"features": feature_columns}
|
||||
@ -171,7 +178,7 @@ def train_model(target, features, model_params, test_split):
|
||||
@param test_split: The fraction of the data to be withheld for
|
||||
testing the model / calculating the accuracy
|
||||
"""
|
||||
features_train, features_test,
|
||||
features_train, features_test, \
|
||||
target_train, target_test = train_test_split(features, target,
|
||||
test_size=test_split)
|
||||
model = GradientBoostingRegressor(**model_params)
|
||||
|
Loading…
Reference in New Issue
Block a user