dlib/tools/python/test/test_vector.py
2018-01-16 21:26:56 -05:00

171 lines
3.5 KiB
Python

from dlib import vector, vectors, vectorss, dot
try:
import cPickle as pickle # Use cPickle on Python 2.7
except ImportError:
import pickle
from pytest import raises
def test_vector_empty_init():
    """A vector constructed without arguments reports itself as empty."""
    empty = vector()

    # Size is checked through both len() and the shape tuple.
    assert len(empty) == 0
    assert empty.shape == (0, 1)

    # Text forms: str() is the empty string, repr() shows an empty list.
    assert str(empty) == ""
    assert repr(empty) == "dlib.vector([])"
def test_vector_init_with_number():
    """Constructing a vector from the integer 3 gives three zero entries."""
    size = 3
    zeros = vector(size)

    assert len(zeros) == size
    assert zeros.shape == (size, 1)

    # str() prints one entry per line; repr() prints a list-like form.
    assert str(zeros) == "0\n0\n0"
    assert repr(zeros) == "dlib.vector([0, 0, 0])"
def test_vector_set_size():
    """Exercise set_size() and resize() on a vector that starts with 3 entries."""
    vec = vector(3)

    # Shrink to zero entries via set_size().
    vec.set_size(0)
    assert len(vec) == 0
    assert vec.shape == (0, 1)

    # Grow via resize(); every entry must then compare equal to zero.
    grown = 10
    vec.resize(grown)
    assert len(vec) == grown
    assert vec.shape == (grown, 1)
    assert all(vec[idx] == 0 for idx in range(grown))
def test_vector_init_with_list():
    """Constructing a vector from a Python list keeps the values in order."""
    source = [1, 2, 3]
    vec = vector(source)

    assert len(vec) == 3
    assert vec.shape == (3, 1)

    # str() prints one entry per line; repr() prints a list-like form.
    assert str(vec) == "1\n2\n3"
    assert repr(vec) == "dlib.vector([1, 2, 3])"
def test_vector_getitem():
    """Indexing accepts both non-negative and negative positions."""
    vec = vector([1, 2, 3])

    # First entry via a non-negative index, last entry via a negative one.
    assert vec[0] == 1
    assert vec[-1] == 3

    # Index 1 and index -2 must read the same (middle) entry.
    assert vec[1] == vec[-2]
def test_vector_slice():
    """Slicing with positive, negative and mixed bounds returns the right entries."""

    def check(piece, expected):
        # A slice result must match `expected` in length and entry by entry.
        assert len(piece) == len(expected)
        for pos, want in enumerate(expected):
            assert piece[pos] == want

    vec = vector([1, 2, 3, 4, 5])

    # Each slice is taken immediately before it is checked.
    check(vec[1:4], [2, 3, 4])    # both bounds non-negative
    check(vec[-3:-1], [3, 4])     # both bounds negative
    check(vec[1:-2], [2, 3])      # mixed bounds
def test_vector_invalid_getitem():
    """Indexing just past either end of a 3-entry vector raises IndexError."""
    vec = vector([1, 2, 3])

    # -4 is one before the first entry; 3 is one past the last entry.
    for bad_index in (-4, 3):
        with raises(IndexError):
            vec[bad_index]
def test_vector_init_with_negative_number():
    """Constructing a vector with a negative size must raise an exception."""
    # Callable form of raises(): invokes vector(-3) and expects it to raise.
    raises(Exception, vector, -3)
def test_dot():
    """dot() gives 1, 0 and -1 for the three vector pairs checked below."""
    unit_x = vector([1, 0])
    unit_y = vector([0, 1])
    neg_unit_x = vector([-1, 0])

    # (left operand, right operand, expected dot product)
    expectations = (
        (unit_x, unit_x, 1),
        (unit_x, unit_y, 0),
        (unit_x, neg_unit_x, -1),
    )
    for lhs, rhs, expected in expectations:
        assert dot(lhs, rhs) == expected
def test_vector_serialization():
    """A vector survives a pickle round trip (protocol 2) with the same str()."""
    original = vector([1, 2, 3])

    # Round trip through pickle using protocol 2.
    restored = pickle.loads(pickle.dumps(original, 2))

    # The comparison is made on the string forms of the two vectors.
    assert str(original) == str(restored)
def generate_test_vectors():
    """Build a `vectors` container holding three 3-entry vectors (values 0..8).

    Returns:
        The populated `vectors` instance; its length is asserted to be 3.
    """
    rows = ([0, 1, 2], [3, 4, 5], [6, 7, 8])

    container = vectors()
    for row in rows:
        container.append(vector(row))

    assert len(container) == 3
    return container
def generate_test_vectorss():
    """Build a `vectorss` container holding three freshly generated `vectors`.

    Returns:
        The populated `vectorss` instance; its length is asserted to be 3.
    """
    nested = vectorss()
    for _ in range(3):
        # Each element is a new container from generate_test_vectors().
        nested.append(generate_test_vectors())

    assert len(nested) == 3
    return nested
def test_vectors_serialization():
    """A `vectors` container compares equal to its pickle round trip."""
    original = generate_test_vectors()

    # Round trip through pickle using protocol 2.
    restored = pickle.loads(pickle.dumps(original, 2))

    assert original == restored
def test_vectors_clear():
    """clear() leaves a populated `vectors` container with length 0."""
    container = generate_test_vectors()

    container.clear()

    assert len(container) == 0
def test_vectors_resize():
    """resize(100) on an empty `vectors` yields 100 zero-length elements."""
    target = 100
    container = vectors()

    container.resize(target)

    assert len(container) == target
    # Every element created by the resize must itself have length 0.
    assert all(len(container[idx]) == 0 for idx in range(target))
def test_vectors_extend():
    """extend() with a list of two vectors grows an empty `vectors` to length 2."""
    container = vectors()

    additions = [vector(row) for row in ([1, 2, 3], [4, 5, 6])]
    container.extend(additions)

    assert len(container) == 2
def test_vectorss_serialization():
    """A `vectorss` container compares equal to its pickle round trip."""
    original = generate_test_vectorss()

    # Round trip through pickle using protocol 2.
    restored = pickle.loads(pickle.dumps(original, 2))

    assert original == restored
def test_vectorss_clear():
    """clear() leaves a populated `vectorss` container with length 0."""
    nested = generate_test_vectorss()

    nested.clear()

    assert len(nested) == 0
def test_vectorss_resize():
    """resize(100) on an empty `vectorss` yields 100 zero-length elements."""
    target = 100
    nested = vectorss()

    nested.resize(target)

    assert len(nested) == target
    # Every element created by the resize must itself have length 0.
    assert all(len(nested[idx]) == 0 for idx in range(target))
def test_vectorss_extend():
    """extend() with a list of two `vectors` grows an empty `vectorss` to length 2."""
    nested = vectorss()

    # Two independently generated `vectors` containers to add in one call.
    additions = [generate_test_vectors() for _ in range(2)]
    nested.extend(additions)

    assert len(nested) == 2