From 81f04dc9260fc9f5a05678efd0fc95cab0343b0f Mon Sep 17 00:00:00 2001 From: David Liu Date: Thu, 10 Aug 2023 22:26:23 -0400 Subject: [PATCH] Fix check_contracts handling of forward references in type annotations (#941) --- CHANGELOG.md | 1 + python_ta/contracts/__init__.py | 7 ++- .../test_class_contracts.py | 0 .../test_class_forward_reference.py | 53 +++++++++++++++++++ tests/{ => test_contracts}/test_contracts.py | 0 .../test_contracts_attr_value_restoration.py | 0 ...t_contracts_type_alias_abstract_network.py | 0 ...test_contracts_type_alias_abstract_ring.py | 4 +- 8 files changed, 63 insertions(+), 2 deletions(-) rename tests/{ => test_contracts}/test_class_contracts.py (100%) create mode 100644 tests/test_contracts/test_class_forward_reference.py rename tests/{ => test_contracts}/test_contracts.py (100%) rename tests/{ => test_contracts}/test_contracts_attr_value_restoration.py (100%) rename tests/{ => test_contracts}/test_contracts_type_alias_abstract_network.py (100%) rename tests/{ => test_contracts}/test_contracts_type_alias_abstract_ring.py (74%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3544b87af..b1e799ed3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - Make `graphviz` an optional dependency, and clarify the installation requirements for visualizing control flow graphs. +- Fix `check_contrats` handling of forward references in class type annotations when using `check_contracts` decorator. ## [2.6.0] - 2023-08-06 diff --git a/python_ta/contracts/__init__.py b/python_ta/contracts/__init__.py index 6bd3db719..65570282d 100644 --- a/python_ta/contracts/__init__.py +++ b/python_ta/contracts/__init__.py @@ -140,7 +140,7 @@ def add_class_invariants(klass: type) -> None: _set_invariants(klass) klass_mod = _get_module(klass) - cls_annotations = typing.get_type_hints(klass, localns=klass_mod.__dict__) + cls_annotations = None # This is a cached value set the first time new_setattr is called def new_setattr(self: klass, name: str, value: Any) -> None: """Set the value of the given attribute on self to the given value. @@ -150,6 +150,11 @@ def new_setattr(self: klass, name: str, value: Any) -> None: if not ENABLE_CONTRACT_CHECKING: super(klass, self).__setattr__(name, value) return + + nonlocal cls_annotations + if cls_annotations is None: + cls_annotations = typing.get_type_hints(klass, localns=klass_mod.__dict__) + if name in cls_annotations: try: _debug(f"Checking type of attribute {attr} for {klass.__qualname__} instance") diff --git a/tests/test_class_contracts.py b/tests/test_contracts/test_class_contracts.py similarity index 100% rename from tests/test_class_contracts.py rename to tests/test_contracts/test_class_contracts.py diff --git a/tests/test_contracts/test_class_forward_reference.py b/tests/test_contracts/test_class_forward_reference.py new file mode 100644 index 000000000..38bd92ae9 --- /dev/null +++ b/tests/test_contracts/test_class_forward_reference.py @@ -0,0 +1,53 @@ +"""Tests for forward references in instance attribute type annotations.""" +from __future__ import annotations + +from typing import Any, Optional + +import pytest + +from python_ta.contracts import check_contracts + + +@check_contracts +class Node: + """A node in a linked list.""" + + item: Any + next: Optional[Node] + + def __init__(self, item: Any) -> None: + self.item = item + self.next = None + + def bad_method(self) -> None: + """Set self.next to an invalid value, violating the type annotation.""" + self.next = 1 + + +def test_node_no_error() -> None: + """Test that a Node can be initialized without error.""" + Node(1) + + +def test_node_setattr_no_error() -> None: + """Test that assigning a valid attribute to a Node does not raise an error.""" + node = Node(1) + node.next = Node(2) + + assert node.next.item == 2 + + +def test_node_setattr_error() -> None: + """Test that assigning an invalid attribute to a Node raises an AssertionError.""" + node = Node(1) + + with pytest.raises(AssertionError): + node.next = 2 + + +def test_node_method_error() -> None: + """Test that violating a Node type annotation with a method call raises an AssertionError.""" + node = Node(1) + + with pytest.raises(AssertionError): + node.bad_method() diff --git a/tests/test_contracts.py b/tests/test_contracts/test_contracts.py similarity index 100% rename from tests/test_contracts.py rename to tests/test_contracts/test_contracts.py diff --git a/tests/test_contracts_attr_value_restoration.py b/tests/test_contracts/test_contracts_attr_value_restoration.py similarity index 100% rename from tests/test_contracts_attr_value_restoration.py rename to tests/test_contracts/test_contracts_attr_value_restoration.py diff --git a/tests/test_contracts_type_alias_abstract_network.py b/tests/test_contracts/test_contracts_type_alias_abstract_network.py similarity index 100% rename from tests/test_contracts_type_alias_abstract_network.py rename to tests/test_contracts/test_contracts_type_alias_abstract_network.py diff --git a/tests/test_contracts_type_alias_abstract_ring.py b/tests/test_contracts/test_contracts_type_alias_abstract_ring.py similarity index 74% rename from tests/test_contracts_type_alias_abstract_ring.py rename to tests/test_contracts/test_contracts_type_alias_abstract_ring.py index 4ac822c4d..dbb20460e 100644 --- a/tests/test_contracts_type_alias_abstract_ring.py +++ b/tests/test_contracts/test_contracts_type_alias_abstract_ring.py @@ -1,5 +1,7 @@ from python_ta.contracts import check_contracts -from tests.test_contracts_type_alias_abstract_network import AbstractNetwork +from tests.test_contracts.test_contracts_type_alias_abstract_network import ( + AbstractNetwork, +) def test_type_alias_as_type_annotation_for_class_attribute_no_error() -> None: