mirror of
https://github.com/CartoDB/crankshaft.git
synced 2024-11-01 10:20:48 +08:00
adds psql connector
This commit is contained in:
parent
8b061bac72
commit
5be5a48894
@ -1,35 +1,59 @@
|
||||
import numpy as np
|
||||
from base.gwr import GWR
|
||||
from base.sel_bw import Sel_BW
|
||||
from gwr.base.gwr import GWR
|
||||
from gwr.base.sel_bw import Sel_BW
|
||||
|
||||
def gwr(subquery, dep_var, ind_vars, fixed=False, kernel='bisquare'):
|
||||
|
||||
def gwr(subquery, dep_var, ind_vars,
|
||||
fixed=False, kernel='bisquare'):
|
||||
"""
|
||||
subquery: 'select * from interesting_table'
|
||||
subquery: 'select * from demographics'
|
||||
dep_var: 'pctbachelor'
|
||||
ind_vars: ['intercept', 'pctpov', 'pctrural', 'pctblack']
|
||||
fixed: False (kNN) or True ('distance')
|
||||
kernel: 'bisquare' (default), or 'exponential', 'gaussian'
|
||||
"""
|
||||
|
||||
query_result = subquery
|
||||
rowid = np.array(query_result[0]['rowid'])
|
||||
# query_result = subquery
|
||||
# rowid = np.array(query_result[0]['rowid'])
|
||||
params = {'geom_col': 'the_geom',
|
||||
'id_col': 'cartodb_id',
|
||||
'subquery': subquery,
|
||||
'dep_var': dep_var,
|
||||
'ind_vars': ind_vars}
|
||||
|
||||
try:
|
||||
query_result = plpy.execute(pu.gwr_query(params))
|
||||
except plpy.SPIError, err:
|
||||
plpy.error('Analysis failed: %s' % err)
|
||||
|
||||
# TODO: should x, y be centroids? point on surface?
|
||||
# lat, long coordinates
|
||||
x = np.array(query_result[0]['x'])
|
||||
y = np.array(query_result[0]['y'])
|
||||
coords = zip(x, y)
|
||||
|
||||
Y = query_result[0]['dep'].reshape((-1, 1))
|
||||
# extract dependent variable
|
||||
Y = query_result[0]['dep_var'].reshape((-1, 1))
|
||||
|
||||
n = Y.shape[0]
|
||||
k = len(ind_vars)
|
||||
X = np.zeros((n, k))
|
||||
|
||||
for attr in range(0, k):
|
||||
attr_name = 'attr' + str(attr+1)
|
||||
X[:, attr] = np.array(query_result[0][attr_name]).flatten()
|
||||
attr_name = 'attr' + str(attr + 1)
|
||||
X[:, attr] = np.array(
|
||||
query_result[0][attr_name]).flatten()
|
||||
|
||||
bw = Sel_BW(coords, Y, X, fixed=fixed, kernel=kernel).search()
|
||||
model = GWR(coords, Y, X, bw, fixed=fixed, kernel=kernel).fit()
|
||||
# calculate bandwidth
|
||||
bw = Sel_BW(coords, Y, X,
|
||||
fixed=fixed, kernel=kernel).search()
|
||||
model = GWR(coords, Y, X, bw,
|
||||
fixed=fixed, kernel=kernel).fit()
|
||||
|
||||
# TODO: iterate from 0, n-1 and fill objects like this, for a
|
||||
# column called coeffs:
|
||||
# {'pctrural': ..., 'pctpov': ..., ...}
|
||||
# Follow the same structure for other outputs
|
||||
coefficients = model.params.reshape((-1,))
|
||||
t_vals = model.tvalues.reshape((-1,))
|
||||
stand_errs = model.bse.reshape((-1))
|
||||
@ -39,4 +63,5 @@ def gwr(subquery, dep_var, ind_vars, fixed=False, kernel='bisquare'):
|
||||
rowid = np.tile(rowid, k+1).reshape((-1,))
|
||||
var_name = np.tile(ind_vars, k+1).reshape((-1,))
|
||||
|
||||
return zip(coefficients, stand_errs, t_vals, predicted, residuals, r_squared, rowid, var_name)
|
||||
return zip(coefficients, stand_errs, t_vals, predicted,
|
||||
residuals, r_squared, rowid)
|
||||
|
Loading…
Reference in New Issue
Block a user