diff --git a/CHANGELOG.md b/CHANGELOG.md index 64f31e7b1..f7c3d04cd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,7 @@ and adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - For the contract checking `new_setattr` function, any variables that depend only on `klass` are now defined in the outer function, efficiency of code was improved, and the attribute value is now restored to the original value if the `_check_invariants` call raises an error. +- Added new function `validate_invariants` which takes in an object and checks that the representation invariants of the object are satisfied. ### Bug Fixes diff --git a/docs/contracts/index.md b/docs/contracts/index.md index 2a85cc3fb..3dfe1f797 100644 --- a/docs/contracts/index.md +++ b/docs/contracts/index.md @@ -50,8 +50,7 @@ AssertionError: divide argument '2' did not match type annotation for parameter The `python_ta.contracts` module offers two functions for enabling contract checking. The first, `check_all_contracts`, enables contract checking for all functions and classes defined within a module or set of modules. -The second, `check_contracts`, is a decorator allowing more fine-grained control over which -functions/classes have contract checking enabled. +The second, `check_contracts`, is a decorator allowing more fine-grained control over which functions/classes have contract checking enabled. ```{eval-rst} .. autofunction:: python_ta.contracts.check_all_contracts @@ -61,6 +60,12 @@ functions/classes have contract checking enabled. .. autofunction:: python_ta.contracts.check_contracts(func_or_class) ``` +You can pass an object into the function `validate_invariants` to manually check the representation invariants of the object. + +```{eval-rst} +.. autofunction:: python_ta.contracts.validate_invariants(object) +``` + You can set the `ENABLE_CONTRACT_CHECKING` constant to `True` to enable all contract checking. ```{eval-rst} diff --git a/python_ta/contracts/__init__.py b/python_ta/contracts/__init__.py index cc1286470..a686f5644 100644 --- a/python_ta/contracts/__init__.py +++ b/python_ta/contracts/__init__.py @@ -137,27 +137,7 @@ def add_class_invariants(klass: type) -> None: # This means the class has already been decorated return - # Update representation invariants from this class' docstring and those of its superclasses. - rep_invariants: List[Tuple[str, CodeType]] = [] - - # Iterate over all inherited classes except builtins - for cls in reversed(klass.__mro__): - if "__representation_invariants__" in cls.__dict__: - rep_invariants.extend(cls.__representation_invariants__) - elif cls.__module__ != "builtins": - assertions = parse_assertions(cls, parse_token="Representation Invariant") - # Try compiling assertions - for assertion in assertions: - try: - compiled = compile(assertion, "", "eval") - except: - _debug( - f"Warning: representation invariant {assertion} could not be parsed as a valid Python expression" - ) - continue - rep_invariants.append((assertion, compiled)) - - setattr(klass, "__representation_invariants__", rep_invariants) + _set_invariants(klass) klass_mod = _get_module(klass) cls_annotations = typing.get_type_hints(klass, localns=klass_mod.__dict__) @@ -603,3 +583,39 @@ def _debug(msg: str) -> None: return print("[PyTA]", msg, file=sys.stderr) + + +def _set_invariants(klass: type) -> None: + """Retrieve and set the representation invariants of this class""" + # Update representation invariants from this class' docstring and those of its superclasses. + rep_invariants: List[Tuple[str, CodeType]] = [] + + # Iterate over all inherited classes except builtins + for cls in reversed(klass.__mro__): + if "__representation_invariants__" in cls.__dict__: + rep_invariants.extend(cls.__representation_invariants__) + elif cls.__module__ != "builtins": + assertions = parse_assertions(cls, parse_token="Representation Invariant") + # Try compiling assertions + for assertion in assertions: + try: + compiled = compile(assertion, "", "eval") + except: + _debug( + f"Warning: representation invariant {assertion} could not be parsed as a valid Python expression" + ) + continue + rep_invariants.append((assertion, compiled)) + + setattr(klass, "__representation_invariants__", rep_invariants) + + +def validate_invariants(obj: object) -> None: + """Check that the representation invariants of obj are satisfied.""" + klass = obj.__class__ + klass_mod = _get_module(klass) + + try: + _check_invariants(obj, klass, klass_mod.__dict__) + except PyTAContractError as e: + raise AssertionError(str(e)) from None diff --git a/tests/test_validate_invariants.py b/tests/test_validate_invariants.py new file mode 100644 index 000000000..4bb0af3ae --- /dev/null +++ b/tests/test_validate_invariants.py @@ -0,0 +1,48 @@ +""" +Test suite for checking the functionality of validate_invariants. +""" + +from typing import List + +import pytest + +from python_ta.contracts import check_contracts, validate_invariants + + +@check_contracts +class Person: + """A custom data type that represents data for a person. + + Representation Invariants: + - self.age >= 0 + - len(self.friends) > 1 + """ + + given_name: str + age: int + friends: List[str] + + def __init__(self, given_name: str, age: int, friends: List[str]) -> None: + """Initialize a new Person object.""" + self.given_name = given_name + self.age = age + self.friends = friends + + +def test_no_errors() -> None: + """Checks that validate_invariants does not raise an error when representation invariants are satisfied.""" + person = Person("Jim", 50, ["Pam", "Dwight"]) + + try: + validate_invariants(person) + except AssertionError: + pytest.fail("validate_invariants has incorrectly raised an AssertionError") + + +def test_raise_error() -> None: + """Checks that validate_invariants raises an error when representation invariants are violated.""" + person = Person("Jim", 50, ["Pam", "Dwight"]) + person.friends.pop() + + with pytest.raises(AssertionError): + validate_invariants(person)