diff --git a/src/py/crankshaft/test/helper.py b/src/py/crankshaft/test/helper.py index 7d28b94..b273354 100644 --- a/src/py/crankshaft/test/helper.py +++ b/src/py/crankshaft/test/helper.py @@ -2,6 +2,7 @@ import unittest from mock_plpy import MockPlPy plpy = MockPlPy() +from mock_plpy import MockDBResponse import sys sys.modules['plpy'] = plpy diff --git a/src/py/crankshaft/test/mock_plpy.py b/src/py/crankshaft/test/mock_plpy.py index a982ebe..05d0f21 100644 --- a/src/py/crankshaft/test/mock_plpy.py +++ b/src/py/crankshaft/test/mock_plpy.py @@ -1,12 +1,13 @@ import re + class MockCursor: def __init__(self, data): self.cursor_pos = 0 self.data = data def fetch(self, batch_size): - batch = self.data[self.cursor_pos : self.cursor_pos + batch_size] + batch = self.data[self.cursor_pos:self.cursor_pos + batch_size] self.cursor_pos += batch_size return batch @@ -45,8 +46,22 @@ class MockPlPy: data = self.execute(query) return MockCursor(data) - def execute(self, query): # TODO: additional arguments - for result in self.results: - if result[0].match(query): - return result[1] - return [] + # TODO: additional arguments + def execute(self, query): + for result in self.results: + if result[0].match(query): + return result[1] + return [] + + +class MockDBResponse: + def __init__(self, data, colnames=None): + self.data = data + if colnames is None: + self.colnames = data[0].keys() + else: + self.colnames = colnames + + + def colnames(self): + return self.colnames