Skip to content

Commit c3f73b1

Browse files
committed
completely refactored defuzzification and put in new defuzz.py
1 parent 931d571 commit c3f73b1

File tree

2 files changed

+128
-41
lines changed

2 files changed

+128
-41
lines changed

src/fuzzylogic/classes.py

Lines changed: 53 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
import numpy as np
2121

22+
from fuzzylogic import defuzz
23+
2224
from .combinators import MAX, MIN, bounded_sum, product, simple_disjoint_sum
2325
from .functions import Membership, inv, normalize
2426

@@ -481,50 +483,60 @@ def __eq__(self, other: object) -> bool:
481483
def __getitem__(self, key: Iterable[Set]) -> Set:
482484
return self.conditions[frozenset(key)]
483485

484-
def __call__(self, values: dict[Domain, float | int], method: str = "cog") -> float | None:
485-
"""Calculate the infered value based on different methods.
486-
Default is center of gravity (cog).
486+
def __call__(self, values: dict[Domain, float], method=defuzz.cog) -> float | None:
487+
"""
488+
Calculate the inferred crisp value based on the fuzzy rules.
489+
The 'method' parameter should be one of the static methods from the DefuzzMethod class.
487490
"""
488-
assert isinstance(values, dict), "Please make sure to pass a dict[Domain, float|int] as values."
489-
assert len(self.conditions) > 0, "No point in having a rule with no conditions, is there?"
491+
assert isinstance(values, dict), "Please pass a dict[Domain, float|int] as values."
492+
assert values, "No condition rules defined!"
493+
494+
# Extract common target domain and build list of (then_set, firing_strength)
495+
sample_then_set = next(iter(self.conditions.values()))
496+
target_domain = getattr(sample_then_set, "domain", None)
497+
assert target_domain, "Target domain must be defined."
498+
499+
target_weights: list[tuple[Set, float]] = []
500+
for if_sets, then_set in self.conditions.items():
501+
assert then_set.domain == target_domain, "All target sets must be in the same Domain."
502+
degrees = []
503+
for s in if_sets:
504+
assert s.domain is not None, "Domain must be defined for all fuzzy sets."
505+
degrees.append(s(values[s.domain]))
506+
firing_strength = min(degrees, default=0)
507+
if firing_strength > 0:
508+
target_weights.append((then_set, firing_strength))
509+
if not target_weights:
510+
return None
511+
512+
# For center-of-gravity / centroid:
513+
if method == defuzz.cog:
514+
return defuzz.cog(target_weights)
515+
516+
# For methods that rely on an aggregated membership function:
517+
points = list(target_domain.range)
518+
n = len(points)
519+
step = (
520+
(target_domain._high - target_domain._low) / (n - 1)
521+
if n > 1
522+
else (target_domain._high - target_domain._low)
523+
)
524+
525+
def aggregated_membership(x: float) -> float:
526+
# For each rule, limit its inferred output by its firing strength and then take the max
527+
return max(min(weight, then_set(x)) for then_set, weight in target_weights)
528+
490529
match method:
491-
case "cog":
492-
# iterate over the conditions and calculate the actual values and weights contributing to cog
493-
target_weights: list[tuple[Set, float]] = []
494-
target_domain = list(self.conditions.values())[0].domain
495-
assert target_domain is not None, "Target domain must be defined."
496-
for if_sets, then_set in self.conditions.items():
497-
actual_values: list[float] = []
498-
assert then_set.domain == target_domain, "All target sets must be in the same Domain."
499-
for s in if_sets:
500-
assert s.domain is not None, "Domains must be defined."
501-
actual_values.append(s(values[s.domain]))
502-
x = min(actual_values, default=0)
503-
if x > 0:
504-
target_weights.append((then_set, x))
505-
if not target_weights:
506-
return None
507-
sum_weights = 0
508-
sum_weighted_cogs: float = 0
509-
for then_set, weight in target_weights:
510-
sum_weighted_cogs += then_set.center_of_gravity() * weight
511-
sum_weights += weight
512-
index = sum_weighted_cogs / sum_weights
513-
return (target_domain._high - target_domain._low) / len( # type: ignore
514-
target_domain.range
515-
) * index + target_domain._low # type: ignore
516-
case "centroid": # centroid == center of mass == center of gravity for simple solids
517-
raise NotImplementedError("actually the same as 'cog' if densities are uniform.")
518-
case "bisector":
519-
raise NotImplementedError("Bisector method not implemented yet.")
520-
case "mom":
521-
raise NotImplementedError("Middle of max method not implemented yet.")
522-
case "som":
523-
raise NotImplementedError("Smallest of max method not implemented yet.")
524-
case "lom":
525-
raise NotImplementedError("Largest of max method not implemented yet.")
530+
case defuzz.bisector:
531+
return defuzz.bisector(aggregated_membership, points, step)
532+
case defuzz.mom:
533+
return defuzz.mom(aggregated_membership, points)
534+
case defuzz.som:
535+
return defuzz.som(aggregated_membership, points)
536+
case defuzz.lom:
537+
return defuzz.lom(aggregated_membership, points)
526538
case _:
527-
raise ValueError("Invalid method.")
539+
raise ValueError("Invalid defuzzification method specified.")
528540

529541

530542
def rule_from_table(table: str, references: dict[str, float]) -> Rule:

src/fuzzylogic/defuzz.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
if TYPE_CHECKING:
6+
from .classes import Membership, Set
7+
8+
9+
def cog(target_weights: list[tuple[Set, float]]) -> float:
10+
"""
11+
Defuzzify using the center-of-gravity (or centroid) method.
12+
target_weights: list of tuples (then_set, weight)
13+
14+
The COG is defined by the formula:
15+
16+
COG = (∑ μᵢ × xᵢ) / (∑ μᵢ)
17+
18+
where:
19+
• μᵢ is the membership value for the iᵗʰ element,
20+
• xᵢ is the corresponding value for the iᵗʰ element in the output domain.
21+
22+
"""
23+
sum_weights = sum(weight for _, weight in target_weights)
24+
sum_weighted_cogs = sum(then_set.center_of_gravity() * weight for then_set, weight in target_weights)
25+
return sum_weighted_cogs / sum_weights
26+
27+
28+
def bisector(
29+
aggregated_membership: Membership,
30+
points: list[float],
31+
step: float,
32+
) -> float:
33+
"""
34+
Defuzzify via the bisector method.
35+
aggregated_membership: function mapping crisp value x -> membership degree (typically in [0,1])
36+
points: discretized points in the target domain
37+
step: spacing between points
38+
"""
39+
total_area = sum(aggregated_membership(x) * step for x in points)
40+
half_area = total_area / 2.0
41+
cumulative = 0.0
42+
for x in points:
43+
cumulative += aggregated_membership(x) * step
44+
if cumulative >= half_area:
45+
return x
46+
return points[-1]
47+
48+
49+
def mom(aggregated_membership: Membership, points: list[float]) -> float | None:
50+
"""
51+
Mean of Maxima (MOM): average the x-values where the aggregated membership is maximal.
52+
"""
53+
max_points = _get_max_points(aggregated_membership, points)
54+
return sum(max_points) / len(max_points) if max_points else None
55+
56+
57+
def som(aggregated_membership: Membership, points: list[float]) -> float | None:
58+
"""
59+
Smallest of Maxima: return the smallest x-value at which the aggregated membership is maximal.
60+
"""
61+
return min(_get_max_points(aggregated_membership, points), default=None)
62+
63+
64+
def lom(aggregated_membership: Membership, points: list[float]) -> float | None:
65+
"""
66+
Largest of Maxima: return the largest x-value at which the aggregated membership is maximal.
67+
"""
68+
return max(_get_max_points(aggregated_membership, points), default=None)
69+
70+
71+
def _get_max_points(aggregated_membership: Membership, points: list[float]) -> list[float]:
72+
values_points = [(x, aggregated_membership(x)) for x in points]
73+
max_value = max(y for (_, y) in values_points)
74+
tol = 1e-6 # tolerance for floating point comparisons
75+
return [x for (x, y) in values_points if abs(y - max_value) < tol]

0 commit comments

Comments
 (0)