mirror of
https://github.com/CartoDB/crankshaft.git
synced 2024-11-01 10:20:48 +08:00
adding passing tests
This commit is contained in:
parent
cfb40ddecd
commit
42e760b5d1
@ -54,7 +54,7 @@ def spatial_markov_trend(subquery, time_cols, num_time_per_bin, permutations, ge
|
||||
## rebin time data
|
||||
if num_time_per_bin > 1:
|
||||
## rebin
|
||||
t_data = rebin_data(t_data, num_time_per_bin)
|
||||
t_data = rebin_data(t_data, int(num_time_per_bin))
|
||||
|
||||
sp_markov_result = ps.Spatial_Markov(t_data, weights, k=7, fixed=False)
|
||||
|
||||
@ -68,7 +68,7 @@ def spatial_markov_trend(subquery, time_cols, num_time_per_bin, permutations, ge
|
||||
prob_dist = get_prob_dist(lag_classes, sp_markov_result.classes)
|
||||
|
||||
## find the ups and down and overall distribution of each cell
|
||||
trend, trend_up, trend_down, volatility = get_prob_stats(prob_dist)
|
||||
trend_up, trend_down, trend, volatility = get_prob_stats(prob_dist)
|
||||
|
||||
## output the results
|
||||
|
||||
@ -127,12 +127,29 @@ def get_prob_dist(transition_matrix, lag_indices, unit_indices):
|
||||
return np.array([transition_matrix[(lag_indices[i], unit_indices[i])] for i in range(len(lag_indices))])
|
||||
|
||||
def get_prob_stats(prob_dist, unit_indices):
|
||||
# trend, trend_up, trend_down, volatility = get_prob_stats(prob_dist)
|
||||
"""
|
||||
get the statistics of the probability distributions
|
||||
|
||||
trend_up = np.array([prob_dist[:, i:].sum() for i in unit_indices])
|
||||
trend_down = np.array([prob_dist[:, :i].sum() for i in unit_indices])
|
||||
trend = trend_up - trend_down
|
||||
Outputs:
|
||||
@param trend_up ndarray(float): sum of probabilities for upward
|
||||
movement (relative to the unit index of that prob)
|
||||
@param trend_down ndarray(float): sum of probabilities for downard
|
||||
movement (relative to the unit index of that prob)
|
||||
@param trend ndarray(float): difference of upward and downward
|
||||
movements
|
||||
"""
|
||||
|
||||
num_elements = len(prob_dist)
|
||||
trend_up = np.empty(num_elements)
|
||||
trend_down = np.empty(num_elements)
|
||||
trend = np.empty(num_elements)
|
||||
|
||||
for i in range(num_elements):
|
||||
trend_up[i] = prob_dist[i, (unit_indices[i]+1):].sum()
|
||||
trend_down[i] = prob_dist[i, :unit_indices[i]].sum()
|
||||
trend[i] = (trend_up[i] - trend_down[i]) / prob_dist[i, unit_indices[i]]
|
||||
|
||||
## calculate volatility of distribution
|
||||
volatility = prob_dist.std(axis=1)
|
||||
|
||||
|
||||
return trend_up, trend_down, trend, volatility
|
||||
|
@ -122,5 +122,28 @@ class SpaceTimeTests(unittest.TestCase):
|
||||
|
||||
self.assertTrue(np.array_equal(result, answer))
|
||||
|
||||
def test_get_prob_stats(self):
|
||||
"""Test get_prob_stats"""
|
||||
|
||||
probs = np.array([
|
||||
[ 0.0754717 , 0.88207547, 0.04245283, 0. , 0. ],
|
||||
[ 0. , 0. , 0.09411765, 0.87058824, 0.03529412],
|
||||
[ 0.0049505 , 0.09405941, 0.77722772, 0.11881188, 0.0049505 ],
|
||||
[ 0. , 0. , 0. , 0.02352941, 0.97647059]
|
||||
])
|
||||
unit_indices = np.array([1, 3, 2, 4])
|
||||
answer_up = np.array([0.04245283, 0.03529412, 0.12376238, 0.])
|
||||
answer_down = np.array([0.0754717, 0.09411765, 0.0990099, 0.02352941])
|
||||
answer_trend = np.array([-0.03301887 / 0.88207547, -0.05882353 / 0.87058824, 0.02475248 / 0.77722772, -0.02352941 / 0.97647059])
|
||||
answer_volatility = np.array([ 0.34221495, 0.33705421, 0.29226542, 0.38834223])
|
||||
|
||||
result = std.get_prob_stats(probs, unit_indices)
|
||||
result_up = result[0]
|
||||
result_down = result[1]
|
||||
result_trend = result[2]
|
||||
result_volatility = result[3]
|
||||
|
||||
self.assertTrue(np.allclose(result_up, answer_up))
|
||||
self.assertTrue(np.allclose(result_down, answer_down))
|
||||
self.assertTrue(np.allclose(result_trend, answer_trend))
|
||||
self.assertTrue(np.allclose(result_volatility, answer_volatility))
|
||||
|
Loading…
Reference in New Issue
Block a user