|
19 | 19 |
|
20 | 20 | import numpy as np
|
21 | 21 |
|
| 22 | +from fuzzylogic import defuzz |
| 23 | + |
22 | 24 | from .combinators import MAX, MIN, bounded_sum, product, simple_disjoint_sum
|
23 | 25 | from .functions import Membership, inv, normalize
|
24 | 26 |
|
@@ -481,50 +483,60 @@ def __eq__(self, other: object) -> bool:
|
481 | 483 | def __getitem__(self, key: Iterable[Set]) -> Set:
|
482 | 484 | return self.conditions[frozenset(key)]
|
483 | 485 |
|
484 |
| - def __call__(self, values: dict[Domain, float | int], method: str = "cog") -> float | None: |
485 |
| - """Calculate the infered value based on different methods. |
486 |
| - Default is center of gravity (cog). |
| 486 | + def __call__(self, values: dict[Domain, float], method=defuzz.cog) -> float | None: |
| 487 | + """ |
| 488 | + Calculate the inferred crisp value based on the fuzzy rules. |
| 489 | + The 'method' parameter should be one of the static methods from the DefuzzMethod class. |
487 | 490 | """
|
488 |
| - assert isinstance(values, dict), "Please make sure to pass a dict[Domain, float|int] as values." |
489 |
| - assert len(self.conditions) > 0, "No point in having a rule with no conditions, is there?" |
| 491 | + assert isinstance(values, dict), "Please pass a dict[Domain, float|int] as values." |
| 492 | + assert values, "No condition rules defined!" |
| 493 | + |
| 494 | + # Extract common target domain and build list of (then_set, firing_strength) |
| 495 | + sample_then_set = next(iter(self.conditions.values())) |
| 496 | + target_domain = getattr(sample_then_set, "domain", None) |
| 497 | + assert target_domain, "Target domain must be defined." |
| 498 | + |
| 499 | + target_weights: list[tuple[Set, float]] = [] |
| 500 | + for if_sets, then_set in self.conditions.items(): |
| 501 | + assert then_set.domain == target_domain, "All target sets must be in the same Domain." |
| 502 | + degrees = [] |
| 503 | + for s in if_sets: |
| 504 | + assert s.domain is not None, "Domain must be defined for all fuzzy sets." |
| 505 | + degrees.append(s(values[s.domain])) |
| 506 | + firing_strength = min(degrees, default=0) |
| 507 | + if firing_strength > 0: |
| 508 | + target_weights.append((then_set, firing_strength)) |
| 509 | + if not target_weights: |
| 510 | + return None |
| 511 | + |
| 512 | + # For center-of-gravity / centroid: |
| 513 | + if method == defuzz.cog: |
| 514 | + return defuzz.cog(target_weights) |
| 515 | + |
| 516 | + # For methods that rely on an aggregated membership function: |
| 517 | + points = list(target_domain.range) |
| 518 | + n = len(points) |
| 519 | + step = ( |
| 520 | + (target_domain._high - target_domain._low) / (n - 1) |
| 521 | + if n > 1 |
| 522 | + else (target_domain._high - target_domain._low) |
| 523 | + ) |
| 524 | + |
| 525 | + def aggregated_membership(x: float) -> float: |
| 526 | + # For each rule, limit its inferred output by its firing strength and then take the max |
| 527 | + return max(min(weight, then_set(x)) for then_set, weight in target_weights) |
| 528 | + |
490 | 529 | match method:
|
491 |
| - case "cog": |
492 |
| - # iterate over the conditions and calculate the actual values and weights contributing to cog |
493 |
| - target_weights: list[tuple[Set, float]] = [] |
494 |
| - target_domain = list(self.conditions.values())[0].domain |
495 |
| - assert target_domain is not None, "Target domain must be defined." |
496 |
| - for if_sets, then_set in self.conditions.items(): |
497 |
| - actual_values: list[float] = [] |
498 |
| - assert then_set.domain == target_domain, "All target sets must be in the same Domain." |
499 |
| - for s in if_sets: |
500 |
| - assert s.domain is not None, "Domains must be defined." |
501 |
| - actual_values.append(s(values[s.domain])) |
502 |
| - x = min(actual_values, default=0) |
503 |
| - if x > 0: |
504 |
| - target_weights.append((then_set, x)) |
505 |
| - if not target_weights: |
506 |
| - return None |
507 |
| - sum_weights = 0 |
508 |
| - sum_weighted_cogs: float = 0 |
509 |
| - for then_set, weight in target_weights: |
510 |
| - sum_weighted_cogs += then_set.center_of_gravity() * weight |
511 |
| - sum_weights += weight |
512 |
| - index = sum_weighted_cogs / sum_weights |
513 |
| - return (target_domain._high - target_domain._low) / len( # type: ignore |
514 |
| - target_domain.range |
515 |
| - ) * index + target_domain._low # type: ignore |
516 |
| - case "centroid": # centroid == center of mass == center of gravity for simple solids |
517 |
| - raise NotImplementedError("actually the same as 'cog' if densities are uniform.") |
518 |
| - case "bisector": |
519 |
| - raise NotImplementedError("Bisector method not implemented yet.") |
520 |
| - case "mom": |
521 |
| - raise NotImplementedError("Middle of max method not implemented yet.") |
522 |
| - case "som": |
523 |
| - raise NotImplementedError("Smallest of max method not implemented yet.") |
524 |
| - case "lom": |
525 |
| - raise NotImplementedError("Largest of max method not implemented yet.") |
| 530 | + case defuzz.bisector: |
| 531 | + return defuzz.bisector(aggregated_membership, points, step) |
| 532 | + case defuzz.mom: |
| 533 | + return defuzz.mom(aggregated_membership, points) |
| 534 | + case defuzz.som: |
| 535 | + return defuzz.som(aggregated_membership, points) |
| 536 | + case defuzz.lom: |
| 537 | + return defuzz.lom(aggregated_membership, points) |
526 | 538 | case _:
|
527 |
| - raise ValueError("Invalid method.") |
| 539 | + raise ValueError("Invalid defuzzification method specified.") |
528 | 540 |
|
529 | 541 |
|
530 | 542 | def rule_from_table(table: str, references: dict[str, float]) -> Rule:
|
|
0 commit comments