diff --git a/src/py/crankshaft/crankshaft/clustering/kmeans.py b/src/py/crankshaft/crankshaft/clustering/kmeans.py index 2477d80..fe6831f 100644 --- a/src/py/crankshaft/crankshaft/clustering/kmeans.py +++ b/src/py/crankshaft/crankshaft/clustering/kmeans.py @@ -31,7 +31,7 @@ class Kmeans: labels = km.fit_predict(zip(xs, ys)) return zip(ids, labels) - def nonspatial(self, subquery, colnames, num_clusters=5, + def nonspatial(self, subquery, colnames, no_clusters=5, standardize=True, id_col='cartodb_id'): """ Inputs: @@ -43,7 +43,7 @@ class Kmeans: 'petal_width', 'sepal_length', 'petal_length'] - num_clusters (int): number of clusters (greater than zero) + no_clusters (int): number of clusters (greater than zero) id_col (string): name of the input id_column Output: @@ -71,7 +71,7 @@ class Kmeans: else: cluster_columns = _extract_columns(data, len(colnames)) - kmeans = KMeans(n_clusters=num_clusters, + kmeans = KMeans(n_clusters=no_clusters, random_state=0).fit(cluster_columns) centers = [json.dumps(dict(zip(colnames, c)))