Skip to content

Commit

Permalink
feat: evaluator module (#192)
Browse files Browse the repository at this point in the history
* feat: evaluator module
- Move segment rule evaluators to flag_engine.segments.evaluator module
- Add missing runtime type checks
  • Loading branch information
khvn26 authored Sep 18, 2023
1 parent 32f01ef commit 2064e9e
Show file tree
Hide file tree
Showing 4 changed files with 359 additions and 235 deletions.
121 changes: 115 additions & 6 deletions flag_engine/segments/evaluator.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,25 @@
import operator
import re
import typing
from contextlib import suppress
from functools import wraps

import semver

from flag_engine.environments.models import EnvironmentModel
from flag_engine.identities.models import IdentityModel
from flag_engine.identities.traits.models import TraitModel
from flag_engine.identities.traits.types import TraitValue
from flag_engine.segments import constants
from flag_engine.segments.models import (
SegmentConditionModel,
SegmentModel,
SegmentRuleModel,
)
from flag_engine.segments.types import ConditionOperator
from flag_engine.utils.hashing import get_hashed_percentage_for_object_ids

from ..environments.models import EnvironmentModel
from ..identities.traits.models import TraitModel
from . import constants
from .models import SegmentConditionModel, SegmentModel, SegmentRuleModel
from flag_engine.utils.semver import is_semver
from flag_engine.utils.types import get_casting_function


def get_identity_segments(
Expand Down Expand Up @@ -79,6 +92,7 @@ def _traits_match_segment_condition(
identity_id: typing.Union[int, str],
) -> bool:
if condition.operator == constants.PERCENTAGE_SPLIT:
assert condition.value
float_value = float(condition.value)
return (
get_hashed_percentage_for_object_ids([segment_id, identity_id])
Expand All @@ -95,4 +109,99 @@ def _traits_match_segment_condition(
if condition.operator == constants.IS_SET:
return trait is not None

return condition.matches_trait_value(trait.trait_value) if trait else False
return _matches_trait_value(condition, trait.trait_value) if trait else False


def _matches_trait_value(
condition: SegmentConditionModel,
trait_value: TraitValue,
) -> bool:
if match_func := MATCH_FUNCS_BY_OPERATOR.get(condition.operator):
return match_func(condition.value, trait_value)

return False


def _evaluate_not_contains(
segment_value: typing.Optional[str],
trait_value: TraitValue,
) -> bool:
return isinstance(trait_value, str) and str(segment_value) not in trait_value


def _evaluate_regex(
segment_value: typing.Optional[str],
trait_value: TraitValue,
) -> bool:
return (
trait_value is not None
and re.compile(str(segment_value)).match(str(trait_value)) is not None
)


def _evaluate_modulo(
segment_value: typing.Optional[str],
trait_value: TraitValue,
) -> bool:
if not isinstance(trait_value, (int, float)):
return False

if segment_value is None:
return False

try:
divisor_part, remainder_part = segment_value.split("|")
divisor = float(divisor_part)
remainder = float(remainder_part)
except ValueError:
return False

return trait_value % divisor == remainder


def _evaluate_in(segment_value: typing.Optional[str], trait_value: TraitValue) -> bool:
if segment_value:
if isinstance(trait_value, str):
return trait_value in segment_value.split(",")
if isinstance(trait_value, int) and not any(
trait_value is x for x in (False, True)
):
return str(trait_value) in segment_value.split(",")
return False


def _trait_value_typed(
func: typing.Callable[..., bool],
) -> typing.Callable[[typing.Optional[str], TraitValue], bool]:
@wraps(func)
def inner(
segment_value: typing.Optional[str],
trait_value: TraitValue,
) -> bool:
with suppress(TypeError, ValueError):
if isinstance(trait_value, str) and is_semver(segment_value):
trait_value = semver.VersionInfo.parse(
trait_value,
)
match_value = get_casting_function(trait_value)(segment_value)
return func(trait_value, match_value)
return False

return inner


MATCH_FUNCS_BY_OPERATOR: typing.Dict[
ConditionOperator, typing.Callable[[typing.Optional[str], TraitValue], bool]
] = {
constants.NOT_CONTAINS: _evaluate_not_contains,
constants.REGEX: _evaluate_regex,
constants.MODULO: _evaluate_modulo,
constants.IN: _evaluate_in,
constants.EQUAL: _trait_value_typed(operator.eq),
constants.GREATER_THAN: _trait_value_typed(operator.gt),
constants.GREATER_THAN_INCLUSIVE: _trait_value_typed(operator.ge),
constants.LESS_THAN: _trait_value_typed(operator.lt),
constants.LESS_THAN_INCLUSIVE: _trait_value_typed(operator.le),
constants.NOT_EQUAL: _trait_value_typed(operator.ne),
constants.CONTAINS: _trait_value_typed(operator.contains),
}
66 changes: 0 additions & 66 deletions flag_engine/segments/models.py
Original file line number Diff line number Diff line change
@@ -1,83 +1,17 @@
import re
import typing
from contextlib import suppress

import semver
from pydantic import BaseModel, Field

from flag_engine.features.models import FeatureStateModel
from flag_engine.segments import constants
from flag_engine.segments.types import ConditionOperator, RuleType
from flag_engine.utils.semver import is_semver
from flag_engine.utils.types import get_casting_function


class SegmentConditionModel(BaseModel):
_EXCEPTION_OPERATOR_METHODS = {
constants.NOT_CONTAINS: "evaluate_not_contains",
constants.REGEX: "evaluate_regex",
constants.MODULO: "evaluate_modulo",
constants.IN: "evaluate_in",
}

operator: ConditionOperator
value: typing.Optional[str] = None
property_: typing.Optional[str] = None

def matches_trait_value(self, trait_value: typing.Any) -> bool:
# TODO: move this logic to the evaluator module
with suppress(ValueError):
if type(self.value) is str and is_semver(self.value):
trait_value = semver.VersionInfo.parse(trait_value)
if self.operator in self._EXCEPTION_OPERATOR_METHODS:
evaluator_function = getattr(
self, self._EXCEPTION_OPERATOR_METHODS.get(self.operator)
)
return evaluator_function(trait_value)

matching_function_name = {
constants.EQUAL: "__eq__",
constants.GREATER_THAN: "__gt__",
constants.GREATER_THAN_INCLUSIVE: "__ge__",
constants.LESS_THAN: "__lt__",
constants.LESS_THAN_INCLUSIVE: "__le__",
constants.NOT_EQUAL: "__ne__",
constants.CONTAINS: "__contains__",
}.get(self.operator)
matching_function = getattr(
trait_value, matching_function_name, lambda v: False
)
to_same_type_as_trait_value = get_casting_function(trait_value)
return matching_function(to_same_type_as_trait_value(self.value))

return False

def evaluate_not_contains(self, trait_value: typing.Iterable) -> bool:
return self.value not in trait_value

def evaluate_regex(self, trait_value: str) -> bool:
return (
trait_value is not None
and re.compile(str(self.value)).match(str(trait_value)) is not None
)

def evaluate_modulo(self, trait_value: typing.Union[str, int, float, bool]) -> bool:
if type(trait_value) not in (int, float):
return False
try:
divisor, remainder = self.value.split("|")
divisor = float(divisor)
remainder = float(remainder)
except ValueError:
return False
return trait_value % divisor == remainder

def evaluate_in(self, trait_value) -> bool:
try:
return str(trait_value) in self.value.split(",")
except AttributeError:
return False


class SegmentRuleModel(BaseModel):
type: RuleType
Expand Down
Loading

0 comments on commit 2064e9e

Please sign in to comment.