diff --git a/src/py/crankshaft/crankshaft/space_time_dynamics/markov.py b/src/py/crankshaft/crankshaft/space_time_dynamics/markov.py index 911d558..db2be6f 100644 --- a/src/py/crankshaft/crankshaft/space_time_dynamics/markov.py +++ b/src/py/crankshaft/crankshaft/space_time_dynamics/markov.py @@ -54,7 +54,7 @@ def spatial_markov_trend(subquery, time_cols, num_time_per_bin, permutations, ge ## rebin time data if num_time_per_bin > 1: ## rebin - t_data = rebin_data(t_data, num_time_per_bin) + t_data = rebin_data(t_data, int(num_time_per_bin)) sp_markov_result = ps.Spatial_Markov(t_data, weights, k=7, fixed=False) @@ -68,7 +68,7 @@ def spatial_markov_trend(subquery, time_cols, num_time_per_bin, permutations, ge prob_dist = get_prob_dist(lag_classes, sp_markov_result.classes) ## find the ups and down and overall distribution of each cell - trend, trend_up, trend_down, volatility = get_prob_stats(prob_dist) + trend_up, trend_down, trend, volatility = get_prob_stats(prob_dist) ## output the results @@ -127,12 +127,29 @@ def get_prob_dist(transition_matrix, lag_indices, unit_indices): return np.array([transition_matrix[(lag_indices[i], unit_indices[i])] for i in range(len(lag_indices))]) def get_prob_stats(prob_dist, unit_indices): -# trend, trend_up, trend_down, volatility = get_prob_stats(prob_dist) + """ + get the statistics of the probability distributions - trend_up = np.array([prob_dist[:, i:].sum() for i in unit_indices]) - trend_down = np.array([prob_dist[:, :i].sum() for i in unit_indices]) - trend = trend_up - trend_down + Outputs: + @param trend_up ndarray(float): sum of probabilities for upward + movement (relative to the unit index of that prob) + @param trend_down ndarray(float): sum of probabilities for downard + movement (relative to the unit index of that prob) + @param trend ndarray(float): difference of upward and downward + movements + """ + + num_elements = len(prob_dist) + trend_up = np.empty(num_elements) + trend_down = np.empty(num_elements) + trend = np.empty(num_elements) + + for i in range(num_elements): + trend_up[i] = prob_dist[i, (unit_indices[i]+1):].sum() + trend_down[i] = prob_dist[i, :unit_indices[i]].sum() + trend[i] = (trend_up[i] - trend_down[i]) / prob_dist[i, unit_indices[i]] + + ## calculate volatility of distribution volatility = prob_dist.std(axis=1) - return trend_up, trend_down, trend, volatility diff --git a/src/py/crankshaft/test/test_space_time_dynamics.py b/src/py/crankshaft/test/test_space_time_dynamics.py index dd7b8b0..c35aea5 100644 --- a/src/py/crankshaft/test/test_space_time_dynamics.py +++ b/src/py/crankshaft/test/test_space_time_dynamics.py @@ -28,9 +28,9 @@ class SpaceTimeTests(unittest.TestCase): "num_ngbrs": 321} self.neighbors_data = json.loads(open(fixture_file('neighbors.json')).read()) self.moran_data = json.loads(open(fixture_file('moran.json')).read()) - + self.time_data = np.array([i * np.ones(10, dtype=float) for i in range(10)]).T - + self.transition_matrix = p = np.array([ [[ 0.96341463, 0.0304878 , 0.00609756, 0. , 0. ], [ 0.06040268, 0.83221477, 0.10738255, 0. , 0. ], @@ -61,7 +61,7 @@ class SpaceTimeTests(unittest.TestCase): # def test_spatial_markov(self): # """Test Spatial Markov.""" - # + # # ans = "SELECT i.\"cartodb_id\" As id, " \ # "i.\"dec_2013\"::numeric As attr1, " \ # "i.\"jan_2014\"::numeric As attr2, " \ @@ -80,7 +80,7 @@ class SpaceTimeTests(unittest.TestCase): # "i.\"jan_2014\" IS NOT NULL AND " \ # "i.\"feb_2014\" IS NOT NULL " \ # "ORDER BY i.\"cartodb_id\" ASC;" - # + # # subquery = self.params['subquery'] # time_cols = self.params['time_cols'] # num_time_per_bin = 1 @@ -89,23 +89,23 @@ class SpaceTimeTests(unittest.TestCase): # id_col = self.params['id_col'] # w_type = 'knn' # num_ngbrs = self.params['num_ngbrs'] - # + # # self.assertEqual(std.spatial_markov(subquery, time_cols, num_time_per_bin, permutations, geom_col, id_col, w_type, num_ngbrs), ans) - + def test_rebin_data(self): """Test rebin_data""" - ## sample in double the time (even case since 10 % 2 = 0): - ## (0+1)/2, (2+3)/2, (4+5)/2, (6+7)/2, (8+9)/2 + ## sample in double the time (even case since 10 % 2 = 0): + ## (0+1)/2, (2+3)/2, (4+5)/2, (6+7)/2, (8+9)/2 ## = 0.5, 2.5, 4.5, 6.5, 8.5 - ans_even = np.array([(i + 0.5) * np.ones(10, dtype=float) + ans_even = np.array([(i + 0.5) * np.ones(10, dtype=float) for i in range(0, 10, 2)]).T - + self.assertTrue(np.array_equal(std.rebin_data(self.time_data, 2), ans_even)) ## sample in triple the time (uneven since 10 % 3 = 1): - ## (0+1+2)/3, (3+4+5)/3, (6+7+8)/3, (9)/1 + ## (0+1+2)/3, (3+4+5)/3, (6+7+8)/3, (9)/1 ## = 1, 4, 7, 9 - ans_odd = np.array([i * np.ones(10, dtype=float) + ans_odd = np.array([i * np.ones(10, dtype=float) for i in (1, 4, 7, 9)]).T self.assertTrue(np.array_equal(std.rebin_data(self.time_data, 3), ans_odd)) def test_get_prob_dist(self): @@ -119,8 +119,31 @@ class SpaceTimeTests(unittest.TestCase): [ 0. , 0. , 0. , 0.02352941, 0.97647059] ]) result = std.get_prob_dist(self.transition_matrix, lag_indices, unit_indices) - + self.assertTrue(np.array_equal(result, answer)) - - - + + def test_get_prob_stats(self): + """Test get_prob_stats""" + + probs = np.array([ + [ 0.0754717 , 0.88207547, 0.04245283, 0. , 0. ], + [ 0. , 0. , 0.09411765, 0.87058824, 0.03529412], + [ 0.0049505 , 0.09405941, 0.77722772, 0.11881188, 0.0049505 ], + [ 0. , 0. , 0. , 0.02352941, 0.97647059] + ]) + unit_indices = np.array([1, 3, 2, 4]) + answer_up = np.array([0.04245283, 0.03529412, 0.12376238, 0.]) + answer_down = np.array([0.0754717, 0.09411765, 0.0990099, 0.02352941]) + answer_trend = np.array([-0.03301887 / 0.88207547, -0.05882353 / 0.87058824, 0.02475248 / 0.77722772, -0.02352941 / 0.97647059]) + answer_volatility = np.array([ 0.34221495, 0.33705421, 0.29226542, 0.38834223]) + + result = std.get_prob_stats(probs, unit_indices) + result_up = result[0] + result_down = result[1] + result_trend = result[2] + result_volatility = result[3] + + self.assertTrue(np.allclose(result_up, answer_up)) + self.assertTrue(np.allclose(result_down, answer_down)) + self.assertTrue(np.allclose(result_trend, answer_trend)) + self.assertTrue(np.allclose(result_volatility, answer_volatility))