Skip to content

Commit

Permalink
Fix check_contracts handling of forward references in type annotations (
Browse files Browse the repository at this point in the history
  • Loading branch information
david-yz-liu authored Aug 11, 2023
1 parent df754d2 commit 81f04dc
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

- Make `graphviz` an optional dependency, and clarify the installation requirements for visualizing
control flow graphs.
- Fix `check_contrats` handling of forward references in class type annotations when using `check_contracts` decorator.

## [2.6.0] - 2023-08-06

Expand Down
7 changes: 6 additions & 1 deletion python_ta/contracts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def add_class_invariants(klass: type) -> None:
_set_invariants(klass)

klass_mod = _get_module(klass)
cls_annotations = typing.get_type_hints(klass, localns=klass_mod.__dict__)
cls_annotations = None # This is a cached value set the first time new_setattr is called

def new_setattr(self: klass, name: str, value: Any) -> None:
"""Set the value of the given attribute on self to the given value.
Expand All @@ -150,6 +150,11 @@ def new_setattr(self: klass, name: str, value: Any) -> None:
if not ENABLE_CONTRACT_CHECKING:
super(klass, self).__setattr__(name, value)
return

nonlocal cls_annotations
if cls_annotations is None:
cls_annotations = typing.get_type_hints(klass, localns=klass_mod.__dict__)

if name in cls_annotations:
try:
_debug(f"Checking type of attribute {attr} for {klass.__qualname__} instance")
Expand Down
File renamed without changes.
53 changes: 53 additions & 0 deletions tests/test_contracts/test_class_forward_reference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Tests for forward references in instance attribute type annotations."""
from __future__ import annotations

from typing import Any, Optional

import pytest

from python_ta.contracts import check_contracts


@check_contracts
class Node:
"""A node in a linked list."""

item: Any
next: Optional[Node]

def __init__(self, item: Any) -> None:
self.item = item
self.next = None

def bad_method(self) -> None:
"""Set self.next to an invalid value, violating the type annotation."""
self.next = 1


def test_node_no_error() -> None:
"""Test that a Node can be initialized without error."""
Node(1)


def test_node_setattr_no_error() -> None:
"""Test that assigning a valid attribute to a Node does not raise an error."""
node = Node(1)
node.next = Node(2)

assert node.next.item == 2


def test_node_setattr_error() -> None:
"""Test that assigning an invalid attribute to a Node raises an AssertionError."""
node = Node(1)

with pytest.raises(AssertionError):
node.next = 2


def test_node_method_error() -> None:
"""Test that violating a Node type annotation with a method call raises an AssertionError."""
node = Node(1)

with pytest.raises(AssertionError):
node.bad_method()
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from python_ta.contracts import check_contracts
from tests.test_contracts_type_alias_abstract_network import AbstractNetwork
from tests.test_contracts.test_contracts_type_alias_abstract_network import (
AbstractNetwork,
)


def test_type_alias_as_type_annotation_for_class_attribute_no_error() -> None:
Expand Down

0 comments on commit 81f04dc

Please sign in to comment.