Mirror of https://github.com/CartoDB/crankshaft.git (synced 2024-11-01 10:20:48 +08:00)

Commit c6f64ad2f4, parent 3dad9c6044: bug fixes and adding of internal docs

@@ -11,20 +11,25 @@ $$ LANGUAGE plpythonu;

 -- Non-spatial k-means clustering
 -- query: sql query to retrieve all the needed data
+-- colnames: text array of column names for doing the clustering analysis
+-- standardize: whether to scale variables to a mean of zero and a standard
+--              deviation of 1
+-- id_colname: name of the id column

 CREATE OR REPLACE FUNCTION CDB_KMeansNonspatial(
     query TEXT,
     colnames TEXT[],
     num_clusters INTEGER,
-    id_colname TEXT DEFAULT 'cartodb_id',
-    standarize BOOLEAN DEFAULT true
+    standardize BOOLEAN DEFAULT true,
+    id_colname TEXT DEFAULT 'cartodb_id'
 )
 RETURNS TABLE(cluster_label text, cluster_center json, silhouettes numeric, rowid bigint) AS $$

 from crankshaft.clustering import Kmeans
 kmeans = Kmeans()
 return kmeans.nonspatial(query, colnames, num_clusters,
-                         id_colname, standarize)
+                         standardize=standardize,
+                         id_col=id_colname)
 $$ LANGUAGE plpythonu;

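Because the PL/Python body now passes `standardize` and `id_colname` as keyword arguments, the SQL default order no longer has to line up with the Python signature. A minimal sketch of the equivalent call made directly from Python; the query and column names are illustrative values, not taken from the repo:

# Sketch: calling the clustering entry point the same way the PL/Python body
# above does. The table and column names are made-up examples.
from crankshaft.clustering import Kmeans

kmeans = Kmeans()
results = kmeans.nonspatial(
    'SELECT * FROM iris_flower_data',   # query (illustrative)
    ['sepal_width', 'petal_width'],     # colnames (illustrative)
    5,                                  # num_clusters
    standardize=True,                   # scale columns to mean 0, std dev 1
    id_col='cartodb_id')
# results: list of (cluster label, cluster center, silhouette, rowid) tuples
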
@@ -45,18 +45,34 @@ class AnalysisDataProvider:
             return pu.empty_zipped_array(2)

     def get_nonspatial_kmeans(self, params):
-        """fetch data for non-spatial kmeans"""
+        """
+            Fetch data for non-spatial k-means.
+
+            Inputs - a dict (params) with the following keys:
+                colnames: a (text) list of column names (e.g.,
+                          `['andy', 'cookie']`)
+                id_col: the name of the id column (e.g., `'cartodb_id'`)
+                subquery: the subquery for exposing the data (e.g.,
+                          SELECT * FROM favorite_things)
+            Output:
+                A SQL query for packaging the data for consumption within
+                `KMeans().nonspatial`. Format will be a list of length one,
+                with the first element a dict with keys ('rowid', 'attr1',
+                'attr2', ...)
+        """
         agg_cols = ', '.join(['array_agg({0}) As arr_col{1}'.format(val, idx+1)
                               for idx, val in enumerate(params['colnames'])])
-        print agg_cols
         query = '''
             SELECT {cols}, array_agg({id_col}) As rowid
             FROM ({subquery}) As a
         '''.format(subquery=params['subquery'],
                    id_col=params['id_col'],
-                   cols=agg_cols)
+                   cols=agg_cols).strip()
         try:
             data = plpy.execute(query)
+            if len(data) == 0:
+                plpy.error('No non-null-valued data to analyze. Check the '
+                           'rows and columns of all of the inputs')
             return data
         except plpy.SPIError, err:
             plpy.error('Analysis failed: %s' % err)

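For reference, running the query-building lines above outside the database with the example values from the new docstring yields the packaged-rows query; the single result row then holds one array per requested column plus the id array:

# Sketch of the query construction above, using the docstring's example
# values ('andy', 'cookie', favorite_things); only whitespace differs from
# what the provider builds.
params = {'colnames': ['andy', 'cookie'],
          'id_col': 'cartodb_id',
          'subquery': 'SELECT * FROM favorite_things'}

agg_cols = ', '.join(['array_agg({0}) As arr_col{1}'.format(val, idx + 1)
                      for idx, val in enumerate(params['colnames'])])

query = '''
    SELECT {cols}, array_agg({id_col}) As rowid
    FROM ({subquery}) As a
'''.format(subquery=params['subquery'],
           id_col=params['id_col'],
           cols=agg_cols).strip()

# query (whitespace reformatted):
#   SELECT array_agg(andy) As arr_col1,
#          array_agg(cookie) As arr_col2,
#          array_agg(cartodb_id) As rowid
#   FROM (SELECT * FROM favorite_things) As a
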
@@ -71,6 +87,9 @@ class AnalysisDataProvider:
                  "WHERE {geom_col} IS NOT NULL").format(**params)
         try:
             data = plpy.execute(query)
+            if len(data) == 0:
+                plpy.error('No non-null-valued data to analyze. Check the '
+                           'rows and columns of all of the inputs')
             return data
         except plpy.SPIError, err:
             plpy.error('Analysis failed: %s' % err)

@@ -32,40 +32,45 @@ class Kmeans:
         return zip(ids, labels)

     def nonspatial(self, subquery, colnames, num_clusters=5,
-                   id_col='cartodb_id', standarize=True):
+                   standardize=True, id_col='cartodb_id'):
         """
+            Inputs:
             query (string): A SQL query to retrieve the data required to do the
                             k-means clustering analysis, like so:
                             SELECT * FROM iris_flower_data
             colnames (list): a list of the column names which contain the data
-                             of interest, like so: ["sepal_width",
-                                                    "petal_width",
-                                                    "sepal_length",
-                                                    "petal_length"]
+                             of interest, like so: ['sepal_width',
+                                                    'petal_width',
+                                                    'sepal_length',
+                                                    'petal_length']
             num_clusters (int): number of clusters (greater than zero)
             id_col (string): name of the input id_column
+
+            Output:
+                A list of tuples with the following columns:
+                cluster labels: a label for the cluster that the row belongs to
+                centers: center of the cluster that this row belongs to
+                silhouettes: silhouette measure for this value
+                rowid: row that these values belong to (corresponds to the value in
+                       `id_col`)
         """
         import json
         from sklearn import metrics

-        out_id_colname = 'rowids'
         # TODO: need a random seed?
-        params = {"cols": colnames,
+        params = {"colnames": colnames,
                   "subquery": subquery,
                   "id_col": id_col}

-        data = self.data_provider.get_nonspatial_kmeans(params, standarize)
+        data = self.data_provider.get_nonspatial_kmeans(params)

         # fill array with values for k-means clustering
-        if standarize:
+        if standardize:
             cluster_columns = _scale_data(
                 _extract_columns(data, len(colnames)))
         else:
             cluster_columns = _extract_columns(data, len(colnames))

-        print str(cluster_columns)
-        # TODO: decide on optimal parameters for most cases
-        # Are there ways of deciding parameters based on inputs?
         kmeans = KMeans(n_clusters=num_clusters,
                         random_state=0).fit(cluster_columns)

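From here the method is thin glue over scikit-learn. Below is a self-contained sketch of the fit-and-score steps on the toy matrix used in the test suite further down; the silhouette call and the per-row center lookup are assumptions about lines not shown in this hunk:

# Standalone sketch of what nonspatial() wraps, using the test suite's toy
# data. How crankshaft formats centers and silhouettes is not shown in this
# diff, so those two lines are illustrative assumptions.
import numpy as np
from sklearn.cluster import KMeans
from sklearn import metrics

cluster_columns = np.array([[1, 2], [1, 4], [1, 0],
                            [4, 2], [4, 4], [4, 0]], dtype=float)
rowids = [1, 2, 3, 4, 5, 6]

kmeans = KMeans(n_clusters=2, random_state=0).fit(cluster_columns)
silhouettes = metrics.silhouette_samples(cluster_columns, kmeans.labels_)
centers = [kmeans.cluster_centers_[label] for label in kmeans.labels_]

# one tuple per input row, matching the documented output of nonspatial():
# (cluster label, center of its cluster, silhouette, rowid)
results = zip(kmeans.labels_, centers, silhouettes, rowids)
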
@@ -79,7 +84,7 @@ class Kmeans:
         return zip(kmeans.labels_,
                    centers,
                    silhouettes,
-                   data[0][out_id_colname])
+                   data[0]['rowid'])


 # -- Preprocessing steps

@@ -102,4 +107,5 @@ def _scale_data(features):
         features (numpy matrix): features of dimension (n_features, n_samples)
     """
     from sklearn.preprocessing import StandardScaler
-    return StandardScaler().fit_transform(features)
+    scaler = StandardScaler()
+    return scaler.fit_transform(features)

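The change above is behavior-preserving: `StandardScaler` still rescales each column of its input to zero mean and unit standard deviation, which is what the `standardize` flag ultimately toggles. A toy illustration with a made-up matrix:

# Toy illustration of _scale_data: after scaling, every column has mean ~0
# and standard deviation ~1. The matrix here is invented for the example.
import numpy as np
from sklearn.preprocessing import StandardScaler

features = np.array([[1.0, 200.0],
                     [2.0, 400.0],
                     [3.0, 600.0]])

scaled = StandardScaler().fit_transform(features)
print(scaled.mean(axis=0))  # approximately [0. 0.]
print(scaled.std(axis=0))   # approximately [1. 1.]
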
@@ -19,7 +19,7 @@ class FakeDataProvider(AnalysisDataProvider):
     def get_spatial_kmeans(self, query):
         return self.mocked_result

-    def get_nonspatial_kmeans(self, query, standarize):
+    def get_nonspatial_kmeans(self, query):
         return self.mocked_result

@@ -66,7 +66,7 @@ class KMeansNonspatialTest(unittest.TestCase):
         # http://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html#sklearn-cluster-kmeans
         data_raw = [OrderedDict([("arr_col1", [1, 1, 1, 4, 4, 4]),
                                  ("arr_col2", [2, 4, 0, 2, 4, 0]),
                                  ("rowids", [1, 2, 3, 4, 5, 6])])]
+                                 ("rowid", [1, 2, 3, 4, 5, 6])])]

         random_seeds.set_random_seeds(1234)
         kmeans = Kmeans(FakeDataProvider(data_raw))

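With the mocked rows now keyed by 'rowid', an end-to-end call through a fake provider looks roughly like the sketch below. The inline provider class stands in for the test module's FakeDataProvider, and the subquery and column-name arguments are placeholders, since the rest of the test is not shown in this diff:

# Rough sketch only: MinimalFakeProvider is a stand-in for the test's
# FakeDataProvider, and the nonspatial() arguments are placeholder values.
from collections import OrderedDict
from crankshaft.clustering import Kmeans

class MinimalFakeProvider(object):
    def __init__(self, mocked_result):
        self.mocked_result = mocked_result

    def get_nonspatial_kmeans(self, params):
        return self.mocked_result

data_raw = [OrderedDict([("arr_col1", [1, 1, 1, 4, 4, 4]),
                         ("arr_col2", [2, 4, 0, 2, 4, 0]),
                         ("rowid", [1, 2, 3, 4, 5, 6])])]

kmeans = Kmeans(MinimalFakeProvider(data_raw))
clusters = kmeans.nonspatial('SELECT * FROM anything', ['col1', 'col2'], 2)
labels = [row[0] for row in clusters]   # the toy data splits into two clusters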