tests for gwr predict
This commit is contained in:
parent
0815db4661
commit
7725cada13
@ -51,13 +51,20 @@ class GWRTest(unittest.TestCase):
|
||||
|
||||
# data for GWR prediction
|
||||
self.data_predict = copy.deepcopy(self.data)
|
||||
self.ids_of_unknowns = [13083, ]
|
||||
self.ids_of_unknowns = [13083, 13009, 13281, 13115, 13247, 13169]
|
||||
self.idx_ids_of_unknowns = [self.data_predict[0]['rowid'].index(idx)
|
||||
for idx in self.ids_of_unknowns]
|
||||
|
||||
for idx in self.idx_ids_of_unknowns:
|
||||
self.data_predict[0]['dep_var'][idx] = None
|
||||
|
||||
self.predicted_knowns = {13009: 10.879,
|
||||
13083: 4.5259,
|
||||
13115: 9.4022,
|
||||
13169: 6.0793,
|
||||
13247: 8.1608,
|
||||
13281: 13.886}
|
||||
|
||||
# params, with ind_vars in same ordering as query above
|
||||
self.params = {'subquery': 'select * from table',
|
||||
'dep_var': 'pctbach',
|
||||
@ -102,8 +109,7 @@ class GWRTest(unittest.TestCase):
|
||||
places=4)
|
||||
|
||||
def test_gwr_predict(self):
|
||||
"""
|
||||
"""
|
||||
"""Testing for GWR_Predict"""
|
||||
gwr = GWR(FakeDataProvider(self.data_predict))
|
||||
gwr_resp = gwr.gwr_predict(self.params['subquery'],
|
||||
self.params['dep_var'],
|
||||
@ -114,17 +120,11 @@ class GWRTest(unittest.TestCase):
|
||||
# unpack response
|
||||
coeffs, stand_errs, t_vals, \
|
||||
r_squareds, predicteds, rowid = zip(*gwr_resp)
|
||||
threshold = 0.05
|
||||
threshold = 0.01
|
||||
|
||||
print("{0}, {1}, {2}, {3}".format('known', 'predicted', 'diff', 'id'))
|
||||
for i, idx in enumerate(self.idx_ids_of_unknowns):
|
||||
|
||||
known_val = self.data[0]['dep_var'][idx]
|
||||
known_val = self.predicted_knowns[rowid[i]]
|
||||
predicted_val = predicteds[i]
|
||||
test_val = abs(known_val - predicted_val) / known_val
|
||||
|
||||
print("{0}, {1}, {2}, {3}".format(self.data[0]['dep_var'][idx],
|
||||
predicteds[i],
|
||||
test_val,
|
||||
rowid[i]))
|
||||
self.assertTrue(test_val < threshold)
|
||||
|
Loading…
Reference in New Issue
Block a user