From 224fbc2fc5f6ebe1cc671da2635dbe2d73b43288 Mon Sep 17 00:00:00 2001 From: Andy Eschbacher Date: Sat, 19 Nov 2016 09:05:35 +0000 Subject: [PATCH] move to class based markov --- src/pg/sql/11_markov.sql | 5 +- .../crankshaft/space_time_dynamics/markov.py | 177 ++++--- .../test/test_space_time_dynamics.py | 501 ++++++++++-------- 3 files changed, 366 insertions(+), 317 deletions(-) diff --git a/src/pg/sql/11_markov.sql b/src/pg/sql/11_markov.sql index 1124abd..da02c66 100644 --- a/src/pg/sql/11_markov.sql +++ b/src/pg/sql/11_markov.sql @@ -22,10 +22,11 @@ CREATE OR REPLACE FUNCTION RETURNS TABLE (trend NUMERIC, trend_up NUMERIC, trend_down NUMERIC, volatility NUMERIC, rowid INT) AS $$ - from crankshaft.space_time_dynamics import spatial_markov_trend + from crankshaft.space_time_dynamics import Markov + markov = Markov() ## TODO: use named parameters or a dictionary - return spatial_markov_trend(subquery, time_cols, num_classes, w_type, num_ngbrs, permutations, geom_col, id_col) + return markov.spatial_trend(subquery, time_cols, num_classes, w_type, num_ngbrs, permutations, geom_col, id_col) $$ LANGUAGE plpythonu; -- input table format: identical to above but in a predictable format diff --git a/src/py/crankshaft/crankshaft/space_time_dynamics/markov.py b/src/py/crankshaft/crankshaft/space_time_dynamics/markov.py index ae788d7..7984e0c 100644 --- a/src/py/crankshaft/crankshaft/space_time_dynamics/markov.py +++ b/src/py/crankshaft/crankshaft/space_time_dynamics/markov.py @@ -8,92 +8,104 @@ import pysal as ps import plpy import crankshaft.pysal_utils as pu -def spatial_markov_trend(subquery, time_cols, num_classes=7, - w_type='knn', num_ngbrs=5, permutations=0, - geom_col='the_geom', id_col='cartodb_id'): - """ - Predict the trends of a unit based on: - 1. history of its transitions to different classes (e.g., 1st quantile -> 2nd quantile) - 2. average class of its neighbors - Inputs: - @param subquery string: e.g., SELECT the_geom, cartodb_id, - interesting_time_column FROM table_name - @param time_cols list of strings: list of strings of column names - @param num_classes (optional): number of classes to break distribution - of values into. Currently uses quantile bins. - @param w_type string (optional): weight type ('knn' or 'queen') - @param num_ngbrs int (optional): number of neighbors (if knn type) - @param permutations int (optional): number of permutations for test - stats - @param geom_col string (optional): name of column which contains the - geometries - @param id_col string (optional): name of column which has the ids of - the table +class QueryRunner: + def get_result(self, query): + try: + data = plpy.execute(query) - Outputs: - @param trend_up float: probablity that a geom will move to a higher - class - @param trend_down float: probablity that a geom will move to a lower - class - @param trend float: (trend_up - trend_down) / trend_static - @param volatility float: a measure of the volatility based on - probability stddev(prob array) - """ + if len(data) == 0: + return zip([None], [None], [None], [None], [None]) - if len(time_cols) < 2: - plpy.error('More than one time column needs to be passed') + return data + except plpy.SPIError, err: + plpy.error('Analysis failed: %s' % err) - qvals = {"id_col": id_col, - "time_cols": time_cols, - "geom_col": geom_col, - "subquery": subquery, - "num_ngbrs": num_ngbrs} - try: - query_result = plpy.execute( - pu.construct_neighbor_query(w_type, qvals) - ) - if len(query_result) == 0: - return zip([None], [None], [None], [None], [None]) - except plpy.SPIError, e: - plpy.debug('Query failed with exception %s: %s' % (err, pu.construct_neighbor_query(w_type, qvals))) - plpy.error('Analysis failed: %s' % e) - return zip([None], [None], [None], [None], [None]) +class Markov: + def __init__(self, query_runner=None): + if query_runner is None: + self.query_runner = QueryRunner() + else: + self.query_runner = query_runner - ## build weight - weights = pu.get_weight(query_result, w_type) - weights.transform = 'r' + def spatial_trend(self, subquery, time_cols, num_classes=7, + w_type='knn', num_ngbrs=5, permutations=0, + geom_col='the_geom', id_col='cartodb_id'): + """ + Predict the trends of a unit based on: + 1. history of its transitions to different classes (e.g., 1st + quantile -> 2nd quantile) + 2. average class of its neighbors - ## prep time data - t_data = get_time_data(query_result, time_cols) + Inputs: + @param subquery string: e.g., SELECT the_geom, cartodb_id, + interesting_time_column FROM table_name + @param time_cols list of strings: list of strings of column names + @param num_classes (optional): number of classes to break + distribution of values into. Currently uses quantile bins. + @param w_type string (optional): weight type ('knn' or 'queen') + @param num_ngbrs int (optional): number of neighbors (if knn type) + @param permutations int (optional): number of permutations for test + stats + @param geom_col string (optional): name of column which contains + the geometries + @param id_col string (optional): name of column which has the ids + of the table - plpy.debug('shape of t_data %d, %d' % t_data.shape) - plpy.debug('number of weight objects: %d, %d' % (weights.sparse).shape) - plpy.debug('first num elements: %f' % t_data[0, 0]) + Outputs: + @param trend_up float: probablity that a geom will move to a higher + class + @param trend_down float: probablity that a geom will move to a + lower class + @param trend float: (trend_up - trend_down) / trend_static + @param volatility float: a measure of the volatility based on + probability stddev(prob array) + """ - sp_markov_result = ps.Spatial_Markov(t_data, - weights, - k=num_classes, - fixed=False, - permutations=permutations) + if len(time_cols) < 2: + plpy.error('More than one time column needs to be passed') - ## get lag classes - lag_classes = ps.Quantiles( - ps.lag_spatial(weights, t_data[:, -1]), - k=num_classes).yb + qvals = {"id_col": id_col, + "time_cols": time_cols, + "geom_col": geom_col, + "subquery": subquery, + "num_ngbrs": num_ngbrs} - ## look up probablity distribution for each unit according to class and lag class - prob_dist = get_prob_dist(sp_markov_result.P, - lag_classes, - sp_markov_result.classes[:, -1]) + query = pu.construct_neighbor_query(w_type, qvals) - ## find the ups and down and overall distribution of each cell - trend_up, trend_down, trend, volatility = get_prob_stats(prob_dist, - sp_markov_result.classes[:, -1]) + query_result = self.query_runner.get_result(query) + + # build weight + weights = pu.get_weight(query_result, w_type) + weights.transform = 'r' + + # prep time data + t_data = get_time_data(query_result, time_cols) + + sp_markov_result = ps.Spatial_Markov(t_data, + weights, + k=num_classes, + fixed=False, + permutations=permutations) + + # get lag classes + lag_classes = ps.Quantiles( + ps.lag_spatial(weights, t_data[:, -1]), + k=num_classes).yb + + # look up probablity distribution for each unit according to class and + # lag class + prob_dist = get_prob_dist(sp_markov_result.P, + lag_classes, + sp_markov_result.classes[:, -1]) + + # find the ups and down and overall distribution of each cell + trend_up, trend_down, trend, volatility = get_prob_stats(prob_dist, sp_markov_result.classes[:, -1]) + + # output the results + return zip(trend, trend_up, trend_down, volatility, weights.id_order) - ## output the results - return zip(trend, trend_up, trend_down, volatility, weights.id_order) def get_time_data(markov_data, time_cols): """ @@ -103,7 +115,8 @@ def get_time_data(markov_data, time_cols): return np.array([[x['attr' + str(i)] for x in markov_data] for i in range(1, num_attrs+1)], dtype=float).transpose() -## not currently used + +# not currently used def rebin_data(time_data, num_time_per_bin): """ Convert an n x l matrix into an (n/m) x l matrix where the values are @@ -131,14 +144,16 @@ def rebin_data(time_data, num_time_per_bin): """ if time_data.shape[1] % num_time_per_bin == 0: - ## if fit is perfect, then use it + # if fit is perfect, then use it n_max = time_data.shape[1] / num_time_per_bin else: - ## fit remainders into an additional column + # fit remainders into an additional column n_max = time_data.shape[1] / num_time_per_bin + 1 - return np.array([time_data[:, num_time_per_bin * i:num_time_per_bin * (i+1)].mean(axis=1) - for i in range(n_max)]).T + return np.array( + [time_data[:, num_time_per_bin * i:num_time_per_bin * (i+1)].mean(axis=1) + for i in range(n_max)]).T + def get_prob_dist(transition_matrix, lag_indices, unit_indices): """ @@ -157,6 +172,7 @@ def get_prob_dist(transition_matrix, lag_indices, unit_indices): return np.array([transition_matrix[(lag_indices[i], unit_indices[i])] for i in range(len(lag_indices))]) + def get_prob_stats(prob_dist, unit_indices): """ get the statistics of the probability distributions @@ -179,11 +195,12 @@ def get_prob_stats(prob_dist, unit_indices): trend_up[i] = prob_dist[i, (unit_indices[i]+1):].sum() trend_down[i] = prob_dist[i, :unit_indices[i]].sum() if prob_dist[i, unit_indices[i]] > 0.0: - trend[i] = (trend_up[i] - trend_down[i]) / prob_dist[i, unit_indices[i]] + trend[i] = (trend_up[i] - trend_down[i]) / ( + prob_dist[i, unit_indices[i]]) else: trend[i] = None - ## calculate volatility of distribution + # calculate volatility of distribution volatility = prob_dist.std(axis=1) return trend_up, trend_down, trend, volatility diff --git a/src/py/crankshaft/test/test_space_time_dynamics.py b/src/py/crankshaft/test/test_space_time_dynamics.py index 54ffc9d..21f3afc 100644 --- a/src/py/crankshaft/test/test_space_time_dynamics.py +++ b/src/py/crankshaft/test/test_space_time_dynamics.py @@ -9,81 +9,100 @@ import unittest # # import sys # sys.modules['plpy'] = plpy -from helper import plpy, fixture_file +from helper import fixture_file +from crankshaft.space_time_dynamics import Markov import crankshaft.space_time_dynamics as std from crankshaft import random_seeds +from crankshaft.clustering import QueryRunner import json + +class FakeQueryRunner(QueryRunner): + def __init__(self, data): + self.mock_result = data + + def get_result(self, query): + return self.mock_result + + class SpaceTimeTests(unittest.TestCase): """Testing class for Markov Functions.""" def setUp(self): - plpy._reset() + # plpy._reset() self.params = {"id_col": "cartodb_id", "time_cols": ['dec_2013', 'jan_2014', 'feb_2014'], "subquery": "SELECT * FROM a_list", "geom_col": "the_geom", "num_ngbrs": 321} - self.neighbors_data = json.loads(open(fixture_file('neighbors_markov.json')).read()) + self.neighbors_data = json.loads( + open(fixture_file('neighbors_markov.json')).read()) self.markov_data = json.loads(open(fixture_file('markov.json')).read()) - self.time_data = np.array([i * np.ones(10, dtype=float) for i in range(10)]).T + self.time_data = np.array([i * np.ones(10, dtype=float) + for i in range(10)]).T self.transition_matrix = np.array([ - [[ 0.96341463, 0.0304878 , 0.00609756, 0. , 0. ], - [ 0.06040268, 0.83221477, 0.10738255, 0. , 0. ], - [ 0. , 0.14 , 0.74 , 0.12 , 0. ], - [ 0. , 0.03571429, 0.32142857, 0.57142857, 0.07142857], - [ 0. , 0. , 0. , 0.16666667, 0.83333333]], - [[ 0.79831933, 0.16806723, 0.03361345, 0. , 0. ], - [ 0.0754717 , 0.88207547, 0.04245283, 0. , 0. ], - [ 0.00537634, 0.06989247, 0.8655914 , 0.05913978, 0. ], - [ 0. , 0. , 0.06372549, 0.90196078, 0.03431373], - [ 0. , 0. , 0. , 0.19444444, 0.80555556]], - [[ 0.84693878, 0.15306122, 0. , 0. , 0. ], - [ 0.08133971, 0.78947368, 0.1291866 , 0. , 0. ], - [ 0.00518135, 0.0984456 , 0.79274611, 0.0984456 , 0.00518135], - [ 0. , 0. , 0.09411765, 0.87058824, 0.03529412], - [ 0. , 0. , 0. , 0.10204082, 0.89795918]], - [[ 0.8852459 , 0.09836066, 0. , 0.01639344, 0. ], - [ 0.03875969, 0.81395349, 0.13953488, 0. , 0.00775194], - [ 0.0049505 , 0.09405941, 0.77722772, 0.11881188, 0.0049505 ], - [ 0. , 0.02339181, 0.12865497, 0.75438596, 0.09356725], - [ 0. , 0. , 0. , 0.09661836, 0.90338164]], - [[ 0.33333333, 0.66666667, 0. , 0. , 0. ], - [ 0.0483871 , 0.77419355, 0.16129032, 0.01612903, 0. ], - [ 0.01149425, 0.16091954, 0.74712644, 0.08045977, 0. ], - [ 0. , 0.01036269, 0.06217617, 0.89637306, 0.03108808], - [ 0. , 0. , 0. , 0.02352941, 0.97647059]]] + [[0.96341463, 0.0304878, 0.00609756, 0., 0.], + [0.06040268, 0.83221477, 0.10738255, 0., 0.], + [0., 0.14, 0.74, 0.12, 0.], + [0., 0.03571429, 0.32142857, 0.57142857, 0.07142857], + [0., 0., 0., 0.16666667, 0.83333333]], + [[0.79831933, 0.16806723, 0.03361345, 0., 0.], + [0.0754717, 0.88207547, 0.04245283, 0., 0.], + [0.00537634, 0.06989247, 0.8655914, 0.05913978, 0.], + [0., 0., 0.06372549, 0.90196078, 0.03431373], + [0., 0., 0., 0.19444444, 0.80555556]], + [[0.84693878, 0.15306122, 0., 0., 0.], + [0.08133971, 0.78947368, 0.1291866, 0., 0.], + [0.00518135, 0.0984456, 0.79274611, 0.0984456, 0.00518135], + [0., 0., 0.09411765, 0.87058824, 0.03529412], + [0., 0., 0., 0.10204082, 0.89795918]], + [[0.8852459, 0.09836066, 0., 0.01639344, 0.], + [0.03875969, 0.81395349, 0.13953488, 0., 0.00775194], + [0.0049505, 0.09405941, 0.77722772, 0.11881188, 0.0049505], + [0., 0.02339181, 0.12865497, 0.75438596, 0.09356725], + [0., 0., 0., 0.09661836, 0.90338164]], + [[0.33333333, 0.66666667, 0., 0., 0.], + [0.0483871, 0.77419355, 0.16129032, 0.01612903, 0.], + [0.01149425, 0.16091954, 0.74712644, 0.08045977, 0.], + [0., 0.01036269, 0.06217617, 0.89637306, 0.03108808], + [0., 0., 0., 0.02352941, 0.97647059]]] ) def test_spatial_markov(self): """Test Spatial Markov.""" - data = [ { 'id': d['id'], - 'attr1': d['y1995'], - 'attr2': d['y1996'], - 'attr3': d['y1997'], - 'attr4': d['y1998'], - 'attr5': d['y1999'], - 'attr6': d['y2000'], - 'attr7': d['y2001'], - 'attr8': d['y2002'], - 'attr9': d['y2003'], - 'attr10': d['y2004'], - 'attr11': d['y2005'], - 'attr12': d['y2006'], - 'attr13': d['y2007'], - 'attr14': d['y2008'], - 'attr15': d['y2009'], - 'neighbors': d['neighbors'] } for d in self.neighbors_data] - print(str(data[0])) - plpy._define_result('select', data) + data = [{'id': d['id'], + 'attr1': d['y1995'], + 'attr2': d['y1996'], + 'attr3': d['y1997'], + 'attr4': d['y1998'], + 'attr5': d['y1999'], + 'attr6': d['y2000'], + 'attr7': d['y2001'], + 'attr8': d['y2002'], + 'attr9': d['y2003'], + 'attr10': d['y2004'], + 'attr11': d['y2005'], + 'attr12': d['y2006'], + 'attr13': d['y2007'], + 'attr14': d['y2008'], + 'attr15': d['y2009'], + 'neighbors': d['neighbors']} for d in self.neighbors_data] + # print(str(data[0])) + markov = Markov(FakeQueryRunner(data)) random_seeds.set_random_seeds(1234) - result = std.spatial_markov_trend('subquery', ['y1995', 'y1996', 'y1997', 'y1998', 'y1999', 'y2000', 'y2001', 'y2002', 'y2003', 'y2004', 'y2005', 'y2006', 'y2007', 'y2008', 'y2009'], 5, 'knn', 5, 0, 'the_geom', 'cartodb_id') + result = markov.spatial_trend('subquery', + ['y1995', 'y1996', 'y1997', 'y1998', + 'y1999', 'y2000', 'y2001', 'y2002', + 'y2003', 'y2004', 'y2005', 'y2006', + 'y2007', 'y2008', 'y2009'], + 5, 'knn', 5, 0, 'the_geom', + 'cartodb_id') - self.assertTrue(result != None) + self.assertTrue(result is not None) result = [(row[0], row[1], row[2], row[3], row[4]) for row in result] print result[0] expected = self.markov_data @@ -94,173 +113,178 @@ class SpaceTimeTests(unittest.TestCase): def test_get_time_data(self): """Test get_time_data""" - data = [ { 'attr1': d['y1995'], - 'attr2': d['y1996'], - 'attr3': d['y1997'], - 'attr4': d['y1998'], - 'attr5': d['y1999'], - 'attr6': d['y2000'], - 'attr7': d['y2001'], - 'attr8': d['y2002'], - 'attr9': d['y2003'], - 'attr10': d['y2004'], - 'attr11': d['y2005'], - 'attr12': d['y2006'], - 'attr13': d['y2007'], - 'attr14': d['y2008'], - 'attr15': d['y2009'] } for d in self.neighbors_data] + data = [{'attr1': d['y1995'], + 'attr2': d['y1996'], + 'attr3': d['y1997'], + 'attr4': d['y1998'], + 'attr5': d['y1999'], + 'attr6': d['y2000'], + 'attr7': d['y2001'], + 'attr8': d['y2002'], + 'attr9': d['y2003'], + 'attr10': d['y2004'], + 'attr11': d['y2005'], + 'attr12': d['y2006'], + 'attr13': d['y2007'], + 'attr14': d['y2008'], + 'attr15': d['y2009']} for d in self.neighbors_data] - result = std.get_time_data(data, ['y1995', 'y1996', 'y1997', 'y1998', 'y1999', 'y2000', 'y2001', 'y2002', 'y2003', 'y2004', 'y2005', 'y2006', 'y2007', 'y2008', 'y2009']) + result = std.get_time_data(data, ['y1995', 'y1996', 'y1997', 'y1998', + 'y1999', 'y2000', 'y2001', 'y2002', + 'y2003', 'y2004', 'y2005', 'y2006', + 'y2007', 'y2008', 'y2009']) - ## expected was prepared from PySAL example: - ### f = ps.open(ps.examples.get_path("usjoin.csv")) - ### pci = np.array([f.by_col[str(y)] for y in range(1995, 2010)]).transpose() - ### rpci = pci / (pci.mean(axis = 0)) + # expected was prepared from PySAL example: + # f = ps.open(ps.examples.get_path("usjoin.csv")) + # pci = np.array([f.by_col[str(y)] + # for y in range(1995, 2010)]).transpose() + # rpci = pci / (pci.mean(axis = 0)) - expected = np.array([[ 0.87654416, 0.863147, 0.85637567, 0.84811668, 0.8446154, 0.83271652 - , 0.83786314, 0.85012593, 0.85509656, 0.86416612, 0.87119375, 0.86302631 - , 0.86148267, 0.86252252, 0.86746356], - [ 0.9188951, 0.91757931, 0.92333258, 0.92517289, 0.92552388, 0.90746978 - , 0.89830489, 0.89431991, 0.88924794, 0.89815176, 0.91832091, 0.91706054 - , 0.90139505, 0.87897455, 0.86216858], - [ 0.82591007, 0.82548596, 0.81989793, 0.81503235, 0.81731522, 0.78964559 - , 0.80584442, 0.8084998, 0.82258551, 0.82668196, 0.82373724, 0.81814804 - , 0.83675961, 0.83574199, 0.84647177], - [ 1.09088176, 1.08537689, 1.08456418, 1.08415404, 1.09898841, 1.14506948 - , 1.12151133, 1.11160697, 1.10888621, 1.11399806, 1.12168029, 1.13164797 - , 1.12958508, 1.11371818, 1.09936775], - [ 1.10731446, 1.11373944, 1.13283638, 1.14472559, 1.15910025, 1.16898201 - , 1.17212488, 1.14752303, 1.11843284, 1.11024964, 1.11943471, 1.11736468 - , 1.10863242, 1.09642516, 1.07762337], - [ 1.42269757, 1.42118434, 1.44273502, 1.43577571, 1.44400684, 1.44184737 - , 1.44782832, 1.41978227, 1.39092208, 1.4059372, 1.40788646, 1.44052766 - , 1.45241216, 1.43306098, 1.4174431 ], - [ 1.13073885, 1.13110513, 1.11074708, 1.13364636, 1.13088149, 1.10888138 - , 1.11856629, 1.13062931, 1.11944984, 1.12446239, 1.11671008, 1.10880034 - , 1.08401709, 1.06959206, 1.07875225], - [ 1.04706124, 1.04516831, 1.04253372, 1.03239987, 1.02072545, 0.99854316 - , 0.9880258, 0.99669587, 0.99327676, 1.01400905, 1.03176742, 1.040511 - , 1.01749645, 0.9936394, 0.98279746], - [ 0.98996986, 1.00143564, 0.99491, 1.00188408, 1.00455845, 0.99127006 - , 0.97925917, 0.9683482, 0.95335147, 0.93694787, 0.94308213, 0.92232874 - , 0.91284091, 0.89689833, 0.88928858], - [ 0.87418391, 0.86416601, 0.84425695, 0.8404494, 0.83903044, 0.8578708 - , 0.86036185, 0.86107306, 0.8500772, 0.86981998, 0.86837929, 0.87204141 - , 0.86633032, 0.84946077, 0.83287146], - [ 1.14196118, 1.14660262, 1.14892712, 1.14909594, 1.14436624, 1.14450183 - , 1.12349752, 1.12596664, 1.12213996, 1.1119989, 1.10257792, 1.10491258 - , 1.11059842, 1.10509795, 1.10020097], - [ 0.97282463, 0.96700147, 0.96252588, 0.9653878, 0.96057687, 0.95831051 - , 0.94480909, 0.94804195, 0.95430286, 0.94103989, 0.92122519, 0.91010201 - , 0.89280392, 0.89298243, 0.89165385], - [ 0.94325468, 0.96436902, 0.96455242, 0.95243009, 0.94117647, 0.9480927 - , 0.93539182, 0.95388718, 0.94597005, 0.96918424, 0.94781281, 0.93466815 - , 0.94281559, 0.96520315, 0.96715441], - [ 0.97478408, 0.98169225, 0.98712809, 0.98474769, 0.98559897, 0.98687073 - , 0.99237486, 0.98209969, 0.9877653, 0.97399471, 0.96910087, 0.98416665 - , 0.98423613, 0.99823861, 0.99545704], - [ 0.85570269, 0.85575915, 0.85986132, 0.85693406, 0.8538012, 0.86191535 - , 0.84981451, 0.85472102, 0.84564835, 0.83998883, 0.83478547, 0.82803648 - , 0.8198736, 0.82265395, 0.8399404 ], - [ 0.87022047, 0.85996258, 0.85961813, 0.85689572, 0.83947136, 0.82785597 - , 0.86008789, 0.86776298, 0.86720209, 0.8676334, 0.89179317, 0.94202108 - , 0.9422231, 0.93902708, 0.94479184], - [ 0.90134907, 0.90407738, 0.90403991, 0.90201769, 0.90399238, 0.90906632 - , 0.92693339, 0.93695966, 0.94242697, 0.94338265, 0.91981796, 0.91108804 - , 0.90543476, 0.91737138, 0.94793657], - [ 1.1977611, 1.18222564, 1.18439158, 1.18267865, 1.19286723, 1.20172869 - , 1.21328691, 1.22624778, 1.22397075, 1.23857042, 1.24419893, 1.23929384 - , 1.23418676, 1.23626739, 1.26754398], - [ 1.24919678, 1.25754773, 1.26991161, 1.28020651, 1.30625667, 1.34790023 - , 1.34399863, 1.32575181, 1.30795492, 1.30544841, 1.30303302, 1.32107766 - , 1.32936244, 1.33001241, 1.33288462], - [ 1.06768004, 1.03799276, 1.03637303, 1.02768449, 1.03296093, 1.05059016 - , 1.03405057, 1.02747623, 1.03162734, 0.9961416, 0.97356208, 0.94241549 - , 0.92754547, 0.92549227, 0.92138102], - [ 1.09475614, 1.11526796, 1.11654299, 1.13103948, 1.13143264, 1.13889622 - , 1.12442212, 1.13367018, 1.13982256, 1.14029944, 1.11979401, 1.10905389 - , 1.10577769, 1.11166825, 1.09985155], - [ 0.76530058, 0.76612841, 0.76542451, 0.76722683, 0.76014284, 0.74480073 - , 0.76098396, 0.76156903, 0.76651952, 0.76533288, 0.78205934, 0.76842416 - , 0.77487118, 0.77768683, 0.78801192], - [ 0.98391336, 0.98075816, 0.98295341, 0.97386015, 0.96913803, 0.97370819 - , 0.96419154, 0.97209861, 0.97441313, 0.96356162, 0.94745352, 0.93965462 - , 0.93069645, 0.94020973, 0.94358232], - [ 0.83561828, 0.82298088, 0.81738502, 0.81748588, 0.80904801, 0.80071489 - , 0.83358256, 0.83451613, 0.85175032, 0.85954307, 0.86790024, 0.87170334 - , 0.87863799, 0.87497981, 0.87888675], - [ 0.98845573, 1.02092428, 0.99665283, 0.99141823, 0.99386619, 0.98733195 - , 0.99644997, 0.99669587, 1.02559097, 1.01116651, 0.99988024, 0.97906749 - , 0.99323123, 1.00204939, 0.99602148], - [ 1.14930913, 1.15241949, 1.14300962, 1.14265542, 1.13984683, 1.08312397 - , 1.05192626, 1.04230892, 1.05577278, 1.08569751, 1.12443486, 1.08891079 - , 1.08603695, 1.05997314, 1.02160943], - [ 1.11368269, 1.1057147, 1.11893431, 1.13778669, 1.1432272, 1.18257029 - , 1.16226243, 1.16009196, 1.14467789, 1.14820235, 1.12386598, 1.12680236 - , 1.12357937, 1.1159258, 1.12570828], - [ 1.30379431, 1.30752186, 1.31206366, 1.31532267, 1.30625667, 1.31210239 - , 1.29989156, 1.29203193, 1.27183516, 1.26830786, 1.2617743, 1.28656675 - , 1.29734097, 1.29390205, 1.29345446], - [ 0.83953719, 0.82701448, 0.82006005, 0.81188876, 0.80294864, 0.78772975 - , 0.82848011, 0.8259679, 0.82435705, 0.83108634, 0.84373784, 0.83891093 - , 0.84349247, 0.85637272, 0.86539395], - [ 1.23450087, 1.2426022, 1.23537935, 1.23581293, 1.24522626, 1.2256767 - , 1.21126648, 1.19377804, 1.18355337, 1.19674434, 1.21536573, 1.23653297 - , 1.27962009, 1.27968392, 1.25907738], - [ 0.9769662, 0.97400719, 0.98035944, 0.97581531, 0.95543282, 0.96480308 - , 0.94686376, 0.93679073, 0.92540049, 0.92988835, 0.93442917, 0.92100464 - , 0.91475304, 0.90249622, 0.9021363 ], - [ 0.84986886, 0.8986851, 0.84295997, 0.87280534, 0.85659368, 0.88937573 - , 0.894401, 0.90448993, 0.95495898, 0.92698333, 0.94745352, 0.92562488 - , 0.96635366, 1.02520312, 1.0394296 ], - [ 1.01922808, 1.00258203, 1.00974428, 1.00303417, 0.99765073, 1.00759019 - , 0.99192968, 0.99747298, 0.99550759, 0.97583768, 0.9610168, 0.94779638 - , 0.93759089, 0.93353431, 0.94121705], - [ 0.86367411, 0.85558932, 0.85544346, 0.85103025, 0.84336613, 0.83434854 - , 0.85813595, 0.84667961, 0.84374558, 0.85951183, 0.87194227, 0.89455097 - , 0.88283929, 0.90349491, 0.90600675], - [ 1.00947534, 1.00411055, 1.00698819, 0.99513687, 0.99291086, 1.00581626 - , 0.98850522, 0.99291168, 0.98983209, 0.97511924, 0.96134615, 0.96382634 - , 0.95011401, 0.9434686, 0.94637765], - [ 1.05712571, 1.05459419, 1.05753012, 1.04880786, 1.05103857, 1.04800023 - , 1.03024941, 1.04200483, 1.0402554, 1.03296979, 1.02191682, 1.02476275 - , 1.02347523, 1.02517684, 1.04359571], - [ 1.07084189, 1.06669497, 1.07937623, 1.07387988, 1.0794043, 1.0531801 - , 1.07452771, 1.09383478, 1.1052447, 1.10322136, 1.09167939, 1.08772756 - , 1.08859544, 1.09177338, 1.1096083 ], - [ 0.86719222, 0.86628896, 0.86675156, 0.86425632, 0.86511809, 0.86287327 - , 0.85169796, 0.85411285, 0.84886336, 0.84517414, 0.84843858, 0.84488343 - , 0.83374329, 0.82812044, 0.82878599], - [ 0.88389211, 0.92288667, 0.90282398, 0.91229186, 0.92023286, 0.92652175 - , 0.94278865, 0.93682452, 0.98655146, 0.992237, 0.9798497, 0.93869677 - , 0.96947771, 1.00362626, 0.98102351], - [ 0.97082064, 0.95320233, 0.94534081, 0.94215593, 0.93967, 0.93092109 - , 0.92662519, 0.93412152, 0.93501274, 0.92879506, 0.92110542, 0.91035556 - , 0.90430364, 0.89994694, 0.90073864], - [ 0.95861858, 0.95774543, 0.98254811, 0.98919472, 0.98684824, 0.98882205 - , 0.97662234, 0.95601578, 0.94905385, 0.94934888, 0.97152609, 0.97163004 - , 0.9700702, 0.97158948, 0.95884908], - [ 0.83980439, 0.84726737, 0.85747, 0.85467221, 0.8556751, 0.84818516 - , 0.85265681, 0.84502402, 0.82645665, 0.81743586, 0.83550406, 0.83338919 - , 0.83511679, 0.82136617, 0.80921874], - [ 0.95118156, 0.9466212, 0.94688098, 0.9508583, 0.9512441, 0.95440787 - , 0.96364363, 0.96804412, 0.97136214, 0.97583768, 0.95571724, 0.96895368 - , 0.97001634, 0.97082733, 0.98782366], - [ 1.08910044, 1.08248968, 1.08492895, 1.08656923, 1.09454249, 1.10558188 - , 1.1214086, 1.12292577, 1.13021031, 1.13342735, 1.14686068, 1.14502975 - , 1.14474747, 1.14084037, 1.16142926], - [ 1.06336033, 1.07365823, 1.08691496, 1.09764846, 1.11669863, 1.11856702 - , 1.09764283, 1.08815849, 1.08044313, 1.09278827, 1.07003204, 1.08398066 - , 1.09831768, 1.09298232, 1.09176125], - [ 0.79772065, 0.78829196, 0.78581151, 0.77615922, 0.77035744, 0.77751194 - , 0.79902974, 0.81437881, 0.80788828, 0.79603865, 0.78966436, 0.79949807 - , 0.80172182, 0.82168155, 0.85587911], - [ 1.0052447, 1.00007696, 1.00475899, 1.00613942, 1.00639561, 1.00162979 - , 0.99860739, 1.00814981, 1.00574316, 0.99030032, 0.97682565, 0.97292596 - , 0.96519561, 0.96173403, 0.95890284], - [ 0.95808419, 0.9382568, 0.9654441, 0.95561201, 0.96987289, 0.96608031 - , 0.99727185, 1.00781194, 1.03484236, 1.05333619, 1.0983263, 1.1704974 - , 1.17025154, 1.18730553, 1.14242645]]) + expected = np.array( + [[0.87654416, 0.863147, 0.85637567, 0.84811668, 0.8446154, + 0.83271652, 0.83786314, 0.85012593, 0.85509656, 0.86416612, + 0.87119375, 0.86302631, 0.86148267, 0.86252252, 0.86746356], + [0.9188951, 0.91757931, 0.92333258, 0.92517289, 0.92552388, + 0.90746978, 0.89830489, 0.89431991, 0.88924794, 0.89815176, + 0.91832091, 0.91706054, 0.90139505, 0.87897455, 0.86216858], + [0.82591007, 0.82548596, 0.81989793, 0.81503235, 0.81731522, + 0.78964559, 0.80584442, 0.8084998, 0.82258551, 0.82668196, + 0.82373724, 0.81814804, 0.83675961, 0.83574199, 0.84647177], + [1.09088176, 1.08537689, 1.08456418, 1.08415404, 1.09898841, + 1.14506948, 1.12151133, 1.11160697, 1.10888621, 1.11399806, + 1.12168029, 1.13164797, 1.12958508, 1.11371818, 1.09936775], + [1.10731446, 1.11373944, 1.13283638, 1.14472559, 1.15910025, + 1.16898201, 1.17212488, 1.14752303, 1.11843284, 1.11024964, + 1.11943471, 1.11736468, 1.10863242, 1.09642516, 1.07762337], + [1.42269757, 1.42118434, 1.44273502, 1.43577571, 1.44400684, + 1.44184737, 1.44782832, 1.41978227, 1.39092208, 1.4059372, + 1.40788646, 1.44052766, 1.45241216, 1.43306098, 1.4174431], + [1.13073885, 1.13110513, 1.11074708, 1.13364636, 1.13088149, + 1.10888138, 1.11856629, 1.13062931, 1.11944984, 1.12446239, + 1.11671008, 1.10880034, 1.08401709, 1.06959206, 1.07875225], + [1.04706124, 1.04516831, 1.04253372, 1.03239987, 1.02072545, + 0.99854316, 0.9880258, 0.99669587, 0.99327676, 1.01400905, + 1.03176742, 1.040511, 1.01749645, 0.9936394, 0.98279746], + [0.98996986, 1.00143564, 0.99491, 1.00188408, 1.00455845, + 0.99127006, 0.97925917, 0.9683482, 0.95335147, 0.93694787, + 0.94308213, 0.92232874, 0.91284091, 0.89689833, 0.88928858], + [0.87418391, 0.86416601, 0.84425695, 0.8404494, 0.83903044, + 0.8578708, 0.86036185, 0.86107306, 0.8500772, 0.86981998, + 0.86837929, 0.87204141, 0.86633032, 0.84946077, 0.83287146], + [1.14196118, 1.14660262, 1.14892712, 1.14909594, 1.14436624, + 1.14450183, 1.12349752, 1.12596664, 1.12213996, 1.1119989, + 1.10257792, 1.10491258, 1.11059842, 1.10509795, 1.10020097], + [0.97282463, 0.96700147, 0.96252588, 0.9653878, 0.96057687, + 0.95831051, 0.94480909, 0.94804195, 0.95430286, 0.94103989, + 0.92122519, 0.91010201, 0.89280392, 0.89298243, 0.89165385], + [0.94325468, 0.96436902, 0.96455242, 0.95243009, 0.94117647, + 0.9480927, 0.93539182, 0.95388718, 0.94597005, 0.96918424, + 0.94781281, 0.93466815, 0.94281559, 0.96520315, 0.96715441], + [0.97478408, 0.98169225, 0.98712809, 0.98474769, 0.98559897, + 0.98687073, 0.99237486, 0.98209969, 0.9877653, 0.97399471, + 0.96910087, 0.98416665, 0.98423613, 0.99823861, 0.99545704], + [0.85570269, 0.85575915, 0.85986132, 0.85693406, 0.8538012, + 0.86191535, 0.84981451, 0.85472102, 0.84564835, 0.83998883, + 0.83478547, 0.82803648, 0.8198736, 0.82265395, 0.8399404], + [0.87022047, 0.85996258, 0.85961813, 0.85689572, 0.83947136, + 0.82785597, 0.86008789, 0.86776298, 0.86720209, 0.8676334, + 0.89179317, 0.94202108, 0.9422231, 0.93902708, 0.94479184], + [0.90134907, 0.90407738, 0.90403991, 0.90201769, 0.90399238, + 0.90906632, 0.92693339, 0.93695966, 0.94242697, 0.94338265, + 0.91981796, 0.91108804, 0.90543476, 0.91737138, 0.94793657], + [1.1977611, 1.18222564, 1.18439158, 1.18267865, 1.19286723, + 1.20172869, 1.21328691, 1.22624778, 1.22397075, 1.23857042, + 1.24419893, 1.23929384, 1.23418676, 1.23626739, 1.26754398], + [1.24919678, 1.25754773, 1.26991161, 1.28020651, 1.30625667, + 1.34790023, 1.34399863, 1.32575181, 1.30795492, 1.30544841, + 1.30303302, 1.32107766, 1.32936244, 1.33001241, 1.33288462], + [1.06768004, 1.03799276, 1.03637303, 1.02768449, 1.03296093, + 1.05059016, 1.03405057, 1.02747623, 1.03162734, 0.9961416, + 0.97356208, 0.94241549, 0.92754547, 0.92549227, 0.92138102], + [1.09475614, 1.11526796, 1.11654299, 1.13103948, 1.13143264, + 1.13889622, 1.12442212, 1.13367018, 1.13982256, 1.14029944, + 1.11979401, 1.10905389, 1.10577769, 1.11166825, 1.09985155], + [0.76530058, 0.76612841, 0.76542451, 0.76722683, 0.76014284, + 0.74480073, 0.76098396, 0.76156903, 0.76651952, 0.76533288, + 0.78205934, 0.76842416, 0.77487118, 0.77768683, 0.78801192], + [0.98391336, 0.98075816, 0.98295341, 0.97386015, 0.96913803, + 0.97370819, 0.96419154, 0.97209861, 0.97441313, 0.96356162, + 0.94745352, 0.93965462, 0.93069645, 0.94020973, 0.94358232], + [0.83561828, 0.82298088, 0.81738502, 0.81748588, 0.80904801, + 0.80071489, 0.83358256, 0.83451613, 0.85175032, 0.85954307, + 0.86790024, 0.87170334, 0.87863799, 0.87497981, 0.87888675], + [0.98845573, 1.02092428, 0.99665283, 0.99141823, 0.99386619, + 0.98733195, 0.99644997, 0.99669587, 1.02559097, 1.01116651, + 0.99988024, 0.97906749, 0.99323123, 1.00204939, 0.99602148], + [1.14930913, 1.15241949, 1.14300962, 1.14265542, 1.13984683, + 1.08312397, 1.05192626, 1.04230892, 1.05577278, 1.08569751, + 1.12443486, 1.08891079, 1.08603695, 1.05997314, 1.02160943], + [1.11368269, 1.1057147, 1.11893431, 1.13778669, 1.1432272, + 1.18257029, 1.16226243, 1.16009196, 1.14467789, 1.14820235, + 1.12386598, 1.12680236, 1.12357937, 1.1159258, 1.12570828], + [1.30379431, 1.30752186, 1.31206366, 1.31532267, 1.30625667, + 1.31210239, 1.29989156, 1.29203193, 1.27183516, 1.26830786, + 1.2617743, 1.28656675, 1.29734097, 1.29390205, 1.29345446], + [0.83953719, 0.82701448, 0.82006005, 0.81188876, 0.80294864, + 0.78772975, 0.82848011, 0.8259679, 0.82435705, 0.83108634, + 0.84373784, 0.83891093, 0.84349247, 0.85637272, 0.86539395], + [1.23450087, 1.2426022, 1.23537935, 1.23581293, 1.24522626, + 1.2256767, 1.21126648, 1.19377804, 1.18355337, 1.19674434, + 1.21536573, 1.23653297, 1.27962009, 1.27968392, 1.25907738], + [0.9769662, 0.97400719, 0.98035944, 0.97581531, 0.95543282, + 0.96480308, 0.94686376, 0.93679073, 0.92540049, 0.92988835, + 0.93442917, 0.92100464, 0.91475304, 0.90249622, 0.9021363], + [0.84986886, 0.8986851, 0.84295997, 0.87280534, 0.85659368, + 0.88937573, 0.894401, 0.90448993, 0.95495898, 0.92698333, + 0.94745352, 0.92562488, 0.96635366, 1.02520312, 1.0394296], + [1.01922808, 1.00258203, 1.00974428, 1.00303417, 0.99765073, + 1.00759019, 0.99192968, 0.99747298, 0.99550759, 0.97583768, + 0.9610168, 0.94779638, 0.93759089, 0.93353431, 0.94121705], + [0.86367411, 0.85558932, 0.85544346, 0.85103025, 0.84336613, + 0.83434854, 0.85813595, 0.84667961, 0.84374558, 0.85951183, + 0.87194227, 0.89455097, 0.88283929, 0.90349491, 0.90600675], + [1.00947534, 1.00411055, 1.00698819, 0.99513687, 0.99291086, + 1.00581626, 0.98850522, 0.99291168, 0.98983209, 0.97511924, + 0.96134615, 0.96382634, 0.95011401, 0.9434686, 0.94637765], + [1.05712571, 1.05459419, 1.05753012, 1.04880786, 1.05103857, + 1.04800023, 1.03024941, 1.04200483, 1.0402554, 1.03296979, + 1.02191682, 1.02476275, 1.02347523, 1.02517684, 1.04359571], + [1.07084189, 1.06669497, 1.07937623, 1.07387988, 1.0794043, + 1.0531801, 1.07452771, 1.09383478, 1.1052447, 1.10322136, + 1.09167939, 1.08772756, 1.08859544, 1.09177338, 1.1096083], + [0.86719222, 0.86628896, 0.86675156, 0.86425632, 0.86511809, + 0.86287327, 0.85169796, 0.85411285, 0.84886336, 0.84517414, + 0.84843858, 0.84488343, 0.83374329, 0.82812044, 0.82878599], + [0.88389211, 0.92288667, 0.90282398, 0.91229186, 0.92023286, + 0.92652175, 0.94278865, 0.93682452, 0.98655146, 0.992237, + 0.9798497, 0.93869677, 0.96947771, 1.00362626, 0.98102351], + [0.97082064, 0.95320233, 0.94534081, 0.94215593, 0.93967, + 0.93092109, 0.92662519, 0.93412152, 0.93501274, 0.92879506, + 0.92110542, 0.91035556, 0.90430364, 0.89994694, 0.90073864], + [0.95861858, 0.95774543, 0.98254811, 0.98919472, 0.98684824, + 0.98882205, 0.97662234, 0.95601578, 0.94905385, 0.94934888, + 0.97152609, 0.97163004, 0.9700702, 0.97158948, 0.95884908], + [0.83980439, 0.84726737, 0.85747, 0.85467221, 0.8556751, + 0.84818516, 0.85265681, 0.84502402, 0.82645665, 0.81743586, + 0.83550406, 0.83338919, 0.83511679, 0.82136617, 0.80921874], + [0.95118156, 0.9466212, 0.94688098, 0.9508583, 0.9512441, + 0.95440787, 0.96364363, 0.96804412, 0.97136214, 0.97583768, + 0.95571724, 0.96895368, 0.97001634, 0.97082733, 0.98782366], + [1.08910044, 1.08248968, 1.08492895, 1.08656923, 1.09454249, + 1.10558188, 1.1214086, 1.12292577, 1.13021031, 1.13342735, + 1.14686068, 1.14502975, 1.14474747, 1.14084037, 1.16142926], + [1.06336033, 1.07365823, 1.08691496, 1.09764846, 1.11669863, + 1.11856702, 1.09764283, 1.08815849, 1.08044313, 1.09278827, + 1.07003204, 1.08398066, 1.09831768, 1.09298232, 1.09176125], + [0.79772065, 0.78829196, 0.78581151, 0.77615922, 0.77035744, + 0.77751194, 0.79902974, 0.81437881, 0.80788828, 0.79603865, + 0.78966436, 0.79949807, 0.80172182, 0.82168155, 0.85587911], + [1.0052447, 1.00007696, 1.00475899, 1.00613942, 1.00639561, + 1.00162979, 0.99860739, 1.00814981, 1.00574316, 0.99030032, + 0.97682565, 0.97292596, 0.96519561, 0.96173403, 0.95890284], + [0.95808419, 0.9382568, 0.9654441, 0.95561201, 0.96987289, + 0.96608031, 0.99727185, 1.00781194, 1.03484236, 1.05333619, + 1.0983263, 1.1704974, 1.17025154, 1.18730553, 1.14242645]]) self.assertTrue(np.allclose(result, expected)) self.assertTrue(type(result) == type(expected)) @@ -268,32 +292,35 @@ class SpaceTimeTests(unittest.TestCase): def test_rebin_data(self): """Test rebin_data""" - ## sample in double the time (even case since 10 % 2 = 0): - ## (0+1)/2, (2+3)/2, (4+5)/2, (6+7)/2, (8+9)/2 - ## = 0.5, 2.5, 4.5, 6.5, 8.5 + # sample in double the time (even case since 10 % 2 = 0): + # (0+1)/2, (2+3)/2, (4+5)/2, (6+7)/2, (8+9)/2 + # = 0.5, 2.5, 4.5, 6.5, 8.5 ans_even = np.array([(i + 0.5) * np.ones(10, dtype=float) for i in range(0, 10, 2)]).T - self.assertTrue(np.array_equal(std.rebin_data(self.time_data, 2), ans_even)) + self.assertTrue( + np.array_equal(std.rebin_data(self.time_data, 2), ans_even)) - ## sample in triple the time (uneven since 10 % 3 = 1): - ## (0+1+2)/3, (3+4+5)/3, (6+7+8)/3, (9)/1 - ## = 1, 4, 7, 9 - ans_odd = np.array([i * np.ones(10, dtype=float) - for i in (1, 4, 7, 9)]).T - self.assertTrue(np.array_equal(std.rebin_data(self.time_data, 3), ans_odd)) + # sample in triple the time (uneven since 10 % 3 = 1): + # (0+1+2)/3, (3+4+5)/3, (6+7+8)/3, (9)/1 + # = 1, 4, 7, 9 + ans_odd = np.array([i * np.ones(10, dtype=float) + for i in (1, 4, 7, 9)]).T + self.assertTrue( + np.array_equal(std.rebin_data(self.time_data, 3), ans_odd)) def test_get_prob_dist(self): """Test get_prob_dist""" lag_indices = np.array([1, 2, 3, 4]) unit_indices = np.array([1, 3, 2, 4]) answer = np.array([ - [ 0.0754717 , 0.88207547, 0.04245283, 0. , 0. ], - [ 0. , 0. , 0.09411765, 0.87058824, 0.03529412], - [ 0.0049505 , 0.09405941, 0.77722772, 0.11881188, 0.0049505 ], - [ 0. , 0. , 0. , 0.02352941, 0.97647059] + [0.0754717, 0.88207547, 0.04245283, 0., 0.], + [0., 0., 0.09411765, 0.87058824, 0.03529412], + [0.0049505, 0.09405941, 0.77722772, 0.11881188, 0.0049505], + [0., 0., 0., 0.02352941, 0.97647059] ]) - result = std.get_prob_dist(self.transition_matrix, lag_indices, unit_indices) + result = std.get_prob_dist(self.transition_matrix, + lag_indices, unit_indices) self.assertTrue(np.array_equal(result, answer)) @@ -301,16 +328,20 @@ class SpaceTimeTests(unittest.TestCase): """Test get_prob_stats""" probs = np.array([ - [ 0.0754717 , 0.88207547, 0.04245283, 0. , 0. ], - [ 0. , 0. , 0.09411765, 0.87058824, 0.03529412], - [ 0.0049505 , 0.09405941, 0.77722772, 0.11881188, 0.0049505 ], - [ 0. , 0. , 0. , 0.02352941, 0.97647059] + [0.0754717, 0.88207547, 0.04245283, 0., 0.], + [0., 0., 0.09411765, 0.87058824, 0.03529412], + [0.0049505, 0.09405941, 0.77722772, 0.11881188, 0.0049505], + [0., 0., 0., 0.02352941, 0.97647059] ]) unit_indices = np.array([1, 3, 2, 4]) answer_up = np.array([0.04245283, 0.03529412, 0.12376238, 0.]) answer_down = np.array([0.0754717, 0.09411765, 0.0990099, 0.02352941]) - answer_trend = np.array([-0.03301887 / 0.88207547, -0.05882353 / 0.87058824, 0.02475248 / 0.77722772, -0.02352941 / 0.97647059]) - answer_volatility = np.array([ 0.34221495, 0.33705421, 0.29226542, 0.38834223]) + answer_trend = np.array([-0.03301887 / 0.88207547, + -0.05882353 / 0.87058824, + 0.02475248 / 0.77722772, + -0.02352941 / 0.97647059]) + answer_volatility = np.array([0.34221495, 0.33705421, + 0.29226542, 0.38834223]) result = std.get_prob_stats(probs, unit_indices) result_up = result[0]