#!/usr/bin/env python
"""Direct local-barrier rules for KMC event rates.
``LocalBarrierModel`` is the lightweight alternative to fitting a local cluster
expansion. It is useful when the migration barrier can be written directly as a
small set of ordered rules:
* use a constant fallback barrier for every hop;
* count selected occupation states in an event local environment;
* count chemical species after mapping occupation states to species labels;
* match a short wildcard occupation pattern; or
* match an exact event/local-occupation entry.
The model works with the same compact KMC state used by the simulator. State
values are nonnegative integer indices. For binary models, ``occupied`` and
``template`` are aliases for state ``0`` and ``vacant`` and ``mismatch`` are
aliases for state ``1``. Multicomponent rules can use explicit state indices
such as ``2`` or ``3``. Rule order is significant: the first matching rule
supplies the selected property, usually ``barrier`` in meV. If no rule matches,
``default_properties`` are used when present. ``compute_probability`` returns an
event rate in Hz using temperature in K and attempt frequency in Hz.
Minimal setup::
from kmcpy.models import LocalBarrierModel
model = LocalBarrierModel.constant_barrier(300.0)
model.to("model.json")
Rule-based setup::
model = LocalBarrierModel(default_barrier=300.0)
model.add_state_count_rule(
name="crowded",
sites="local_env",
state="occupied",
min_count=3,
barrier=450.0,
)
The saved ``model.json`` can be referenced by ``model_file`` in a simulation
configuration. ``BaseModel.from_config`` dispatches to this class when the model
file carries this class' Monty metadata.
"""
from __future__ import annotations
import logging
from typing import Any, Optional, TYPE_CHECKING
import numpy as np
from kmcpy.event import Event, event_direction
from kmcpy.models.base import BaseModel, require_model_type
from kmcpy.simulator.state import State
from kmcpy.units import BOLTZMANN_CONSTANT_MEV_PER_K as K_B_MEV_PER_K
if TYPE_CHECKING:
from kmcpy.simulator.config import RuntimeConfig
logger = logging.getLogger(__name__)
STATE_VALUE_ALIASES = {
"occupied": 0,
"match": 0,
"template": 0,
"vacant": 1,
"vacancy": 1,
"mismatch": 1,
"other": 1,
}
SITE_SELECTORS = {
"canonical",
"local_env",
"mobile_ion",
"from",
"to",
"all",
}
RULE_TYPES = {
"constant",
"exact",
"pattern",
"state_count",
"species_count",
}
COUNT_KEYS = ("count", "min_count", "max_count")
def _normalize_index_sequence(
values: Any, field_name: str, allow_empty: bool = False
) -> tuple[int, ...]:
if not isinstance(values, (list, tuple)):
raise TypeError(f"'{field_name}' must be a list or tuple of integers")
if not values and not allow_empty:
raise ValueError(f"'{field_name}' must be non-empty")
normalized: list[int] = []
for value in values:
if isinstance(value, bool) or not isinstance(value, int):
raise TypeError(f"'{field_name}' must contain integers only")
normalized.append(int(value))
return tuple(normalized)
def _canonical_site_indices(
mobile_ion_indices: tuple[int, ...], local_env_indices: tuple[int, ...]
) -> tuple[int, ...]:
canonical: list[int] = []
seen: set[int] = set()
for site_index in mobile_ion_indices + local_env_indices:
if site_index in seen:
continue
seen.add(site_index)
canonical.append(site_index)
return tuple(canonical)
def _event_indices(event: Event) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
mobile_ion_indices = _normalize_index_sequence(
event.mobile_ion_indices,
"event.mobile_ion_indices",
)
local_env_indices = _normalize_index_sequence(
event.local_env_indices,
"event.local_env_indices",
allow_empty=True,
)
canonical_sites = _canonical_site_indices(mobile_ion_indices, local_env_indices)
return mobile_ion_indices, local_env_indices, canonical_sites
def _normalize_state_value(value: Any, field_name: str = "state") -> int:
message = (
f"'{field_name}' must be a nonnegative integer state index or one of "
f"{sorted(STATE_VALUE_ALIASES)}"
)
if isinstance(value, bool):
raise TypeError(message)
if isinstance(value, int):
if value < 0:
raise ValueError(message)
return int(value)
if isinstance(value, str):
token = value.strip().lower()
if token in STATE_VALUE_ALIASES:
return STATE_VALUE_ALIASES[token]
try:
state_value = int(token)
except ValueError:
raise ValueError(message) from None
if state_value < 0:
raise ValueError(message)
return state_value
raise ValueError(message)
def _normalize_occupations(values: Any, field_name: str = "occupations") -> tuple[int, ...]:
if not isinstance(values, (list, tuple)):
raise TypeError(f"'{field_name}' must be a list or tuple")
if not values:
raise ValueError(f"'{field_name}' must be non-empty")
return tuple(_normalize_state_value(value, field_name) for value in values)
def _normalize_pattern(values: Any) -> tuple[int | str, ...]:
if not isinstance(values, (list, tuple)):
raise TypeError("'pattern' must be a list or tuple")
if not values:
raise ValueError("'pattern' must be non-empty")
pattern: list[int | str] = []
for value in values:
if isinstance(value, str) and value.strip() == "*":
pattern.append("*")
else:
pattern.append(_normalize_state_value(value, "pattern"))
return tuple(pattern)
def _normalize_properties(properties: Any) -> dict[str, float]:
if not isinstance(properties, dict) or not properties:
raise ValueError("'properties' must be a non-empty object")
normalized: dict[str, float] = {}
for key, value in properties.items():
if not isinstance(key, str) or not key.strip():
raise ValueError("Property names must be non-empty strings")
if isinstance(value, bool) or not isinstance(value, (int, float)):
raise TypeError(f"Property '{key}' must be a numeric value")
normalized[key] = float(value)
return normalized
def _normalize_optional_properties(properties: Any) -> dict[str, float] | None:
if properties is None:
return None
return _normalize_properties(properties)
def _properties_from_rule(rule: dict[str, Any]) -> dict[str, float]:
properties = dict(rule.get("properties", {}))
if "barrier" in rule:
barrier = rule["barrier"]
if isinstance(barrier, bool) or not isinstance(barrier, (int, float)):
raise TypeError("'barrier' must be numeric")
properties.setdefault("barrier", float(barrier))
return _normalize_properties(properties)
def _normalize_sites_spec(value: Any, default: str) -> str | tuple[int, ...]:
if value is None:
return default
if isinstance(value, str):
token = value.strip()
if token not in SITE_SELECTORS:
raise ValueError(
f"Unsupported sites selector '{value}'. "
f"Supported selectors: {sorted(SITE_SELECTORS)}"
)
return token
return _normalize_index_sequence(value, "sites")
def _normalize_count_constraints(rule: dict[str, Any]) -> dict[str, int]:
constraints: dict[str, int] = {}
for key in COUNT_KEYS:
if key not in rule:
continue
value = rule[key]
if isinstance(value, bool) or not isinstance(value, int):
raise TypeError(f"'{key}' must be an integer")
if value < 0:
raise ValueError(f"'{key}' must be non-negative")
constraints[key] = int(value)
if not constraints:
raise ValueError(
"Count rules must provide at least one of "
"'count', 'min_count', or 'max_count'"
)
if "count" in constraints and (
"min_count" in constraints or "max_count" in constraints
):
raise ValueError("'count' cannot be combined with min_count or max_count")
return constraints
def _count_matches(count: int, constraints: dict[str, int]) -> bool:
if "count" in constraints and count != constraints["count"]:
return False
if "min_count" in constraints and count < constraints["min_count"]:
return False
if "max_count" in constraints and count > constraints["max_count"]:
return False
return True
def _normalize_site_species(site_species: Any) -> dict[int, dict[int, str]]:
if site_species is None:
return {}
if not isinstance(site_species, dict):
raise TypeError("'site_species' must be a mapping")
normalized: dict[int, dict[int, str]] = {}
for site_key, state_mapping in site_species.items():
try:
site_index = int(site_key)
except (TypeError, ValueError) as exc:
raise TypeError("site_species keys must be site indices") from exc
if not isinstance(state_mapping, dict):
raise TypeError(
"site_species values must map occupation states to species strings"
)
normalized_state_mapping: dict[int, str] = {}
for state_key, species in state_mapping.items():
state_value = _normalize_state_value(state_key, "site_species state")
if not isinstance(species, str) or not species.strip():
raise ValueError("site_species species values must be non-empty strings")
normalized_state_mapping[state_value] = species
normalized[site_index] = normalized_state_mapping
return normalized
def _site_species_as_dict(site_species: dict[int, dict[int, str]]) -> dict[str, dict[str, str]]:
return {
str(site_index): {
str(state_value): species
for state_value, species in state_mapping.items()
}
for site_index, state_mapping in site_species.items()
}
def _coerce_rule_type(rule: dict[str, Any]) -> str:
rule_type = rule.get("type")
if rule_type is None:
if "occupations" in rule and "mobile_ion_indices" in rule:
return "exact"
if "pattern" in rule:
return "pattern"
if "species" in rule:
return "species_count"
if "state" in rule or "occupation" in rule:
return "state_count"
return "constant"
if not isinstance(rule_type, str) or not rule_type.strip():
raise ValueError("Rule 'type' must be a non-empty string")
rule_type = rule_type.strip()
if rule_type not in RULE_TYPES:
raise ValueError(
f"Unsupported local barrier rule type '{rule_type}'. "
f"Supported types: {sorted(RULE_TYPES)}"
)
return rule_type
def _normalize_rule(rule: dict[str, Any], default_name: str) -> dict[str, Any]:
if not isinstance(rule, dict):
raise TypeError("Each local barrier rule must be a dictionary")
rule_type = _coerce_rule_type(rule)
normalized: dict[str, Any] = {
"name": str(rule.get("name") or default_name),
"type": rule_type,
"properties": _properties_from_rule(rule),
}
if "mobile_ion_indices" in rule:
normalized["mobile_ion_indices"] = _normalize_index_sequence(
rule["mobile_ion_indices"], "mobile_ion_indices"
)
if "local_env_indices" in rule:
normalized["local_env_indices"] = _normalize_index_sequence(
rule["local_env_indices"], "local_env_indices", allow_empty=True
)
if rule_type == "exact":
mobile_ion_indices = _normalize_index_sequence(
rule.get("mobile_ion_indices"), "mobile_ion_indices"
)
local_env_indices = _normalize_index_sequence(
rule.get("local_env_indices"),
"local_env_indices",
allow_empty=True,
)
canonical_sites = _canonical_site_indices(mobile_ion_indices, local_env_indices)
occupations = _normalize_occupations(rule.get("occupations"))
if len(occupations) != len(canonical_sites):
raise ValueError(
"Exact rule occupation length must match canonical site count "
f"({len(canonical_sites)}), got {len(occupations)}"
)
normalized["mobile_ion_indices"] = mobile_ion_indices
normalized["local_env_indices"] = local_env_indices
normalized["occupations"] = occupations
normalized["canonical_site_indices"] = canonical_sites
elif rule_type == "pattern":
normalized["sites"] = _normalize_sites_spec(rule.get("sites"), "canonical")
normalized["pattern"] = _normalize_pattern(rule.get("pattern"))
elif rule_type == "state_count":
normalized["sites"] = _normalize_sites_spec(rule.get("sites"), "local_env")
if "occupation" in rule:
normalized["state"] = _normalize_state_value(rule["occupation"], "occupation")
else:
normalized["state"] = _normalize_state_value(rule.get("state"), "state")
normalized.update(_normalize_count_constraints(rule))
elif rule_type == "species_count":
normalized["sites"] = _normalize_sites_spec(rule.get("sites"), "local_env")
species = rule.get("species")
if isinstance(species, str):
normalized["species"] = (species,)
elif isinstance(species, (list, tuple)) and species:
normalized_species = []
for item in species:
if not isinstance(item, str) or not item.strip():
raise ValueError("'species' entries must be non-empty strings")
normalized_species.append(item)
normalized["species"] = tuple(normalized_species)
else:
raise ValueError("'species' must be a string or non-empty list")
normalized.update(_normalize_count_constraints(rule))
return normalized
def _rule_as_dict(rule: dict[str, Any]) -> dict[str, Any]:
payload: dict[str, Any] = {
"name": rule["name"],
"type": rule["type"],
"properties": dict(rule["properties"]),
}
for key in ("mobile_ion_indices", "local_env_indices", "occupations", "pattern"):
if key in rule:
payload[key] = list(rule[key])
if "sites" in rule:
sites = rule["sites"]
payload["sites"] = list(sites) if isinstance(sites, tuple) else sites
if "state" in rule:
payload["state"] = int(rule["state"])
if "species" in rule:
species = rule["species"]
payload["species"] = species[0] if len(species) == 1 else list(species)
for key in COUNT_KEYS:
if key in rule:
payload[key] = int(rule[key])
return payload
[docs]
class LocalBarrierModel(BaseModel):
"""
Choose migration barriers from ordered local-environment rules.
``LocalBarrierModel`` stores a list of simple rule dictionaries and evaluates
them against a ``State`` and ``Event``. Each rule returns a dictionary of
numeric properties; by default ``compute`` returns the ``barrier`` property.
``compute_probability`` then evaluates the Arrhenius rate when the current
endpoint states match a mobile-vacancy hop.
Parameters:
rules: Ordered rule dictionaries. The first matching rule is used.
name: Human-readable model name.
default_properties: Property dictionary used when no rule matches.
default_barrier: Shortcut for ``default_properties={"barrier": value}``.
default_property: Property returned by ``compute`` when
``property_name`` is not supplied.
probability_mode: Probability calculation mode. Currently only
``"barrier_arrhenius"`` is supported.
probability_property: Property used as the barrier in
``compute_probability``.
site_species: Mapping used by ``species_count`` rules. The shape is
``{site_index: {occupation_state: species}}``. For example,
``{10: {0: "P", 1: "Si", 2: "Al"}}`` means site 10 is counted as
P, Si, or Al depending on its current state index.
Supported rule types:
``constant``
Always matches and returns its properties. In most cases,
``default_barrier`` is clearer than an explicit constant rule.
``state_count``
Counts how many selected sites have the requested state index.
Binary aliases such as ``occupied``/``0`` and ``vacant``/``1`` are
accepted for convenience.
``species_count``
Counts species labels after applying ``site_species``.
``pattern``
Matches selected occupations against a pattern containing
nonnegative state indices, binary state aliases, or ``"*"``
wildcards.
``exact``
Matches a specific event and exact occupation vector. This is the
direct replacement for catalog-style local-environment tables.
Site selectors:
Rules can use ``sites="local_env"``, ``"mobile_ion"``,
``"canonical"``, ``"from"``, ``"to"``, ``"all"``, or an explicit list
of active-site indices. ``canonical`` means ``event.mobile_ion_indices``
followed by ``event.local_env_indices`` with duplicates removed.
Examples:
Constant barrier::
model = LocalBarrierModel.constant_barrier(300.0)
At least three occupied local-environment sites::
model = LocalBarrierModel(default_barrier=300.0)
model.add_state_count_rule(
name="crowded",
sites="local_env",
state="occupied",
min_count=3,
barrier=450.0,
)
More than three Si sites in the local environment::
model = LocalBarrierModel(
default_barrier=300.0,
site_species={
1: {0: "P", 1: "Si"},
2: {0: "Si", 1: "P"},
3: {0: "Si", 1: "P"},
4: {0: "Al", 1: "Si"},
},
)
model.add_species_count_rule(
name="si_rich",
sites="local_env",
species="Si",
min_count=4,
barrier=420.0,
)
Exact event/local-environment match::
model = LocalBarrierModel.from_exact_entries([
{
"mobile_ion_indices": [0, 1],
"local_env_indices": [1, 2, 3],
"occupations": [1, 0, 1, 0],
"properties": {"barrier": 250.0},
}
])
"""
MODEL_TYPE = "local_barrier"
PAYLOAD_KEY = "local_barrier"
SUPPORTED_PROBABILITY_MODE = "barrier_arrhenius"
BOLTZMANN_CONSTANT_MEV_PER_K = K_B_MEV_PER_K
def __init__(
self,
rules: Optional[list[dict[str, Any]]] = None,
name: str = "LocalBarrierModel",
default_properties: Optional[dict[str, float]] = None,
default_barrier: Optional[float] = None,
default_property: str = "barrier",
probability_mode: str = SUPPORTED_PROBABILITY_MODE,
probability_property: str = "barrier",
site_species: Optional[dict[Any, Any]] = None,
) -> None:
super().__init__(name=name)
self.name = name
self.default_property = default_property
self.probability_mode = probability_mode
self.probability_property = probability_property
self.site_species = _normalize_site_species(site_species)
self.default_properties = _normalize_optional_properties(default_properties)
if default_barrier is not None:
if isinstance(default_barrier, bool) or not isinstance(
default_barrier, (int, float)
):
raise TypeError("'default_barrier' must be numeric")
if self.default_properties is None:
self.default_properties = {}
self.default_properties.setdefault("barrier", float(default_barrier))
self.rules: list[dict[str, Any]] = []
self._exact_rule_keys: set[
tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]
] = set()
if rules is not None:
self.build(rules=rules)
[docs]
def fit(self, *args, **kwargs):
"""Local barrier rules are defined explicitly, not fitted."""
raise NotImplementedError(
"LocalBarrierModel does not support fit(). Provide rules or default_properties instead."
)
[docs]
@classmethod
def constant_barrier(
cls,
barrier: float,
name: str = "ConstantBarrierModel",
**kwargs,
) -> "LocalBarrierModel":
"""Construct a model that returns the same barrier for every event.
This is the simplest setup for smoke tests, toy simulations, or models
where all event rates share one activation barrier. The returned model
has no rules; the barrier is stored in ``default_properties``.
"""
return cls(name=name, default_barrier=barrier, **kwargs)
[docs]
@classmethod
def entry_from_event_state(
cls,
event: Event,
state: State,
properties: dict[str, float],
name: Optional[str] = None,
) -> dict[str, Any]:
"""Build an exact-match rule from a runtime event and state snapshot.
The occupation vector is sampled in canonical event-site order:
``mobile_ion_indices`` first, then ``local_env_indices`` with duplicate
site indices removed. Use this helper when turning a known event/state
snapshot into an exact rule without manually constructing the
occupation list.
"""
mobile_ion_indices, local_env_indices, canonical_sites = _event_indices(event)
try:
occupations = [
int(state.occupations[site_index]) for site_index in canonical_sites
]
except IndexError as exc:
raise IndexError(
"Event site index is out of range for provided simulation occupations"
) from exc
entry = {
"type": "exact",
"mobile_ion_indices": list(mobile_ion_indices),
"local_env_indices": list(local_env_indices),
"occupations": occupations,
"properties": dict(properties),
}
if name is not None:
entry["name"] = name
return entry
[docs]
@classmethod
def from_exact_entries(
cls,
entries: list[dict[str, Any] | Any],
name: str = "LocalBarrierModel",
default_properties: Optional[dict[str, float]] = None,
default_barrier: Optional[float] = None,
default_property: str = "barrier",
probability_mode: str = SUPPORTED_PROBABILITY_MODE,
probability_property: str = "barrier",
site_species: Optional[dict[Any, Any]] = None,
) -> "LocalBarrierModel":
"""Construct from exact event/local-occupation entries.
Each entry must contain ``mobile_ion_indices``, ``local_env_indices``,
``occupations``, and ``properties``. The ``occupations`` list is in
canonical site order: mobile-ion sites first, then local-environment
sites with duplicates removed. Duplicate exact entries are rejected.
"""
rules: list[dict[str, Any]] = []
for index, entry in enumerate(entries):
payload = entry.as_dict() if hasattr(entry, "as_dict") else dict(entry)
payload["type"] = "exact"
payload.setdefault("name", f"exact_{index}")
rules.append(payload)
return cls(
rules=rules,
name=name,
default_properties=default_properties,
default_barrier=default_barrier,
default_property=default_property,
probability_mode=probability_mode,
probability_property=probability_property,
site_species=site_species,
)
def _validate_probability_mode(self) -> None:
if self.probability_mode != self.SUPPORTED_PROBABILITY_MODE:
raise ValueError(
f"Unsupported probability mode '{self.probability_mode}' for LocalBarrierModel"
)
def _exact_key_for_rule(
self, rule: dict[str, Any]
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
return (
rule["mobile_ion_indices"],
rule["canonical_site_indices"],
rule["occupations"],
)
[docs]
def add_rule(self, rule: dict[str, Any]) -> None:
"""Add one normalized local barrier rule to the ordered rule list."""
normalized = _normalize_rule(rule, default_name=f"rule_{len(self.rules)}")
if normalized["type"] == "exact":
exact_key = self._exact_key_for_rule(normalized)
if exact_key in self._exact_rule_keys:
raise ValueError(
"Duplicate exact local-barrier rule detected: "
f"mobile_ion_indices={normalized['mobile_ion_indices']}, "
f"canonical_sites={normalized['canonical_site_indices']}, "
f"occupations={normalized['occupations']}"
)
self._exact_rule_keys.add(exact_key)
self.rules.append(normalized)
[docs]
def add_exact_rule(
self,
mobile_ion_indices: list[int] | tuple[int, ...],
local_env_indices: list[int] | tuple[int, ...],
occupations: list[int] | tuple[int, ...],
barrier: Optional[float] = None,
properties: Optional[dict[str, float]] = None,
name: Optional[str] = None,
) -> str:
"""Add an event-specific exact occupation rule.
``occupations`` must follow canonical site order for the supplied
``mobile_ion_indices`` and ``local_env_indices``. Use this rule type
when the barrier is known only for one exact event/environment pattern.
"""
rule = {
"type": "exact",
"mobile_ion_indices": list(mobile_ion_indices),
"local_env_indices": list(local_env_indices),
"occupations": list(occupations),
"properties": dict(properties or {}),
}
if barrier is not None:
rule["barrier"] = barrier
if name is not None:
rule["name"] = name
self.add_rule(rule)
return self.rules[-1]["name"]
[docs]
def add_state_count_rule(
self,
state: str | int,
barrier: Optional[float] = None,
properties: Optional[dict[str, float]] = None,
name: Optional[str] = None,
sites: str | list[int] | tuple[int, ...] = "local_env",
count: Optional[int] = None,
min_count: Optional[int] = None,
max_count: Optional[int] = None,
) -> str:
"""Add a rule based on the number of sites in an occupation state.
``state`` accepts any nonnegative integer state index. The binary
aliases ``"occupied"``/``0`` and ``"vacant"``/``1`` are also accepted.
Supply exactly one of ``count`` or a ``min_count``/``max_count`` range.
"""
rule = {
"type": "state_count",
"sites": sites,
"state": state,
"properties": dict(properties or {}),
}
if barrier is not None:
rule["barrier"] = barrier
if name is not None:
rule["name"] = name
if count is not None:
rule["count"] = count
if min_count is not None:
rule["min_count"] = min_count
if max_count is not None:
rule["max_count"] = max_count
self.add_rule(rule)
return self.rules[-1]["name"]
[docs]
def add_species_count_rule(
self,
species: str | list[str] | tuple[str, ...],
barrier: Optional[float] = None,
properties: Optional[dict[str, float]] = None,
name: Optional[str] = None,
sites: str | list[int] | tuple[int, ...] = "local_env",
count: Optional[int] = None,
min_count: Optional[int] = None,
max_count: Optional[int] = None,
) -> str:
"""Add a rule based on the number of sites currently carrying a species.
Species labels are looked up from ``site_species`` using each selected
site index and current occupation value. This is appropriate for rules
such as "use a higher barrier when at least four selected sites are Si".
"""
rule = {
"type": "species_count",
"sites": sites,
"species": species,
"properties": dict(properties or {}),
}
if barrier is not None:
rule["barrier"] = barrier
if name is not None:
rule["name"] = name
if count is not None:
rule["count"] = count
if min_count is not None:
rule["min_count"] = min_count
if max_count is not None:
rule["max_count"] = max_count
self.add_rule(rule)
return self.rules[-1]["name"]
[docs]
def add_pattern_rule(
self,
pattern: list[int | str] | tuple[int | str, ...],
barrier: Optional[float] = None,
properties: Optional[dict[str, float]] = None,
name: Optional[str] = None,
sites: str | list[int] | tuple[int, ...] = "canonical",
) -> str:
"""Add a wildcard occupation pattern rule.
Patterns can contain nonnegative integer state indices, binary state
aliases, or ``"*"`` wildcards. The pattern length must match the number
of selected sites.
"""
rule = {
"type": "pattern",
"sites": sites,
"pattern": list(pattern),
"properties": dict(properties or {}),
}
if barrier is not None:
rule["barrier"] = barrier
if name is not None:
rule["name"] = name
self.add_rule(rule)
return self.rules[-1]["name"]
[docs]
def build(
self,
rules: Optional[list[dict[str, Any]]] = None,
default_properties: Optional[dict[str, float]] = None,
default_barrier: Optional[float] = None,
default_property: Optional[str] = None,
probability_mode: Optional[str] = None,
probability_property: Optional[str] = None,
site_species: Optional[dict[Any, Any]] = None,
) -> None:
"""Replace this model's rule table."""
if default_properties is not None:
self.default_properties = _normalize_properties(default_properties)
if default_barrier is not None:
if self.default_properties is None:
self.default_properties = {}
self.default_properties.setdefault("barrier", float(default_barrier))
if default_property is not None:
self.default_property = default_property
if probability_mode is not None:
self.probability_mode = probability_mode
if probability_property is not None:
self.probability_property = probability_property
if site_species is not None:
self.site_species = _normalize_site_species(site_species)
self._validate_probability_mode()
if rules is None:
rules = []
if not isinstance(rules, list):
raise TypeError("'rules' must be a list")
self.rules = []
self._exact_rule_keys = set()
for rule in rules:
self.add_rule(rule)
def _selected_sites(
self,
sites: str | tuple[int, ...],
event: Event,
occupations: list[int],
) -> tuple[int, ...]:
mobile_ion_indices, local_env_indices, canonical_sites = _event_indices(event)
if isinstance(sites, tuple):
return sites
if sites == "canonical":
return canonical_sites
if sites == "local_env":
return local_env_indices
if sites == "mobile_ion":
return mobile_ion_indices
if sites == "from":
return (mobile_ion_indices[0],)
if sites == "to":
return (mobile_ion_indices[1],)
if sites == "all":
return tuple(range(len(occupations)))
raise ValueError(f"Unsupported sites selector '{sites}'")
def _occupation_pattern(
self, sites: tuple[int, ...], occupations: list[int]
) -> tuple[int, ...]:
try:
return tuple(int(occupations[site_index]) for site_index in sites)
except IndexError as exc:
raise IndexError(
"Rule site index is out of range for provided simulation occupations"
) from exc
def _event_constraints_match(self, rule: dict[str, Any], event: Event) -> bool:
mobile_ion_indices, local_env_indices, _ = _event_indices(event)
if (
"mobile_ion_indices" in rule
and rule["mobile_ion_indices"] != mobile_ion_indices
):
return False
if (
"local_env_indices" in rule
and rule["local_env_indices"] != local_env_indices
):
return False
return True
def _species_for_site(self, site_index: int, occupation: int) -> str:
if site_index not in self.site_species:
raise ValueError(
"species_count rules require site_species for every counted site; "
f"missing site {site_index}"
)
state_mapping = self.site_species[site_index]
if occupation not in state_mapping:
raise ValueError(
"species_count rules require site_species entries for all "
f"occupation states; missing site {site_index}, state {occupation}"
)
return state_mapping[occupation]
def _rule_matches(self, rule: dict[str, Any], simulation_state: State, event: Event) -> bool:
if not self._event_constraints_match(rule, event):
return False
occupations = simulation_state.occupations
rule_type = rule["type"]
if rule_type == "constant":
return True
if rule_type == "exact":
_, _, canonical_sites = _event_indices(event)
if canonical_sites != rule["canonical_site_indices"]:
return False
return (
self._occupation_pattern(canonical_sites, occupations)
== rule["occupations"]
)
if rule_type == "pattern":
sites = self._selected_sites(rule["sites"], event, occupations)
current_pattern = self._occupation_pattern(sites, occupations)
expected_pattern = rule["pattern"]
if len(current_pattern) != len(expected_pattern):
raise ValueError(
f"Pattern rule '{rule['name']}' has length {len(expected_pattern)} "
f"but selected {len(current_pattern)} sites"
)
return all(
expected == "*" or expected == actual
for expected, actual in zip(expected_pattern, current_pattern)
)
if rule_type == "state_count":
sites = self._selected_sites(rule["sites"], event, occupations)
current_pattern = self._occupation_pattern(sites, occupations)
count = sum(1 for value in current_pattern if value == rule["state"])
return _count_matches(count, rule)
if rule_type == "species_count":
species_to_count = set(rule["species"])
sites = self._selected_sites(rule["sites"], event, occupations)
current_pattern = self._occupation_pattern(sites, occupations)
count = sum(
1
for site_index, occupation in zip(sites, current_pattern)
if self._species_for_site(site_index, occupation) in species_to_count
)
return _count_matches(count, rule)
raise ValueError(f"Unsupported rule type '{rule_type}'")
def _matched_properties(self, simulation_state: State, event: Event) -> dict[str, float]:
if simulation_state is None:
raise ValueError("simulation_state is required")
if event is None:
raise ValueError("event is required")
for rule in self.rules:
if self._rule_matches(rule, simulation_state, event):
return rule["properties"]
if self.default_properties is not None:
return self.default_properties
_, _, canonical_sites = _event_indices(event)
occupation_pattern = self._occupation_pattern(
canonical_sites, simulation_state.occupations
)
raise KeyError(
"No local barrier rule matched and no default_properties were provided: "
f"mobile_ion_indices={tuple(event.mobile_ion_indices)}, "
f"local_env_indices={tuple(event.local_env_indices)}, "
f"canonical_sites={canonical_sites}, "
f"occupations={occupation_pattern}"
)
[docs]
def compute(
self,
simulation_state: State,
event: Event,
property_name: Optional[str] = None,
) -> float:
"""Compute a barrier/property value by local rule matching."""
properties = self._matched_properties(
simulation_state=simulation_state, event=event
)
selected_property = property_name or self.default_property
if selected_property not in properties:
raise KeyError(
f"Property '{selected_property}' not found in matched local barrier rule"
)
return float(properties[selected_property])
[docs]
def compute_probability(
self,
event: Event,
runtime_config: "RuntimeConfig",
simulation_state: State,
) -> float:
"""Compute event rate in Hz from a selected meV barrier."""
self._validate_probability_mode()
barrier = self.compute(
simulation_state=simulation_state,
event=event,
property_name=self.probability_property,
)
occupations = simulation_state.occupations
hop_factor = 1.0 if event_direction(occupations, event) != 0 else 0.0
temperature = runtime_config.temperature
attempt_frequency = runtime_config.attempt_frequency
probability = hop_factor * attempt_frequency * np.exp(
-barrier / (self.BOLTZMANN_CONSTANT_MEV_PER_K * temperature)
)
return float(probability)
def __str__(self) -> str:
return (
f"LocalBarrierModel(name={self.name}, rules={len(self.rules)}, "
f"default_property={self.default_property}, "
f"probability_property={self.probability_property})"
)
def __repr__(self) -> str:
return (
"LocalBarrierModel("
f"name={self.name!r}, rules={len(self.rules)}, "
f"default_property={self.default_property!r}, "
f"probability_mode={self.probability_mode!r}, "
f"probability_property={self.probability_property!r})"
)
[docs]
def as_dict(self) -> dict[str, Any]:
"""Serialize model payload."""
return {
"@module": self.__class__.__module__,
"@class": self.__class__.__name__,
"name": self.name,
"default_property": self.default_property,
"probability_mode": self.probability_mode,
"probability_property": self.probability_property,
"default_properties": (
dict(self.default_properties)
if self.default_properties is not None
else None
),
"site_species": _site_species_as_dict(self.site_species),
"rules": [_rule_as_dict(rule) for rule in self.rules],
}
[docs]
def to(self, filename: str, indent: int = 2) -> None:
"""Write this local barrier model."""
from monty.serialization import dumpfn
dumpfn(self.as_dict(), filename, indent=indent)
[docs]
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "LocalBarrierModel":
"""Deserialize from in-memory payload."""
if not isinstance(data, dict):
raise ValueError("LocalBarrierModel payload must be a JSON object")
if data.get("model_type") == cls.MODEL_TYPE and cls.PAYLOAD_KEY in data:
data = data[cls.PAYLOAD_KEY]
return cls(
rules=data.get("rules", []),
name=data.get("name", "LocalBarrierModel"),
default_properties=data.get("default_properties"),
default_barrier=data.get("default_barrier"),
default_property=data.get("default_property", "barrier"),
probability_mode=data.get(
"probability_mode", cls.SUPPORTED_PROBABILITY_MODE
),
probability_property=data.get("probability_property", "barrier"),
site_species=data.get("site_species"),
)
[docs]
@classmethod
def from_file(cls, filename: str) -> "LocalBarrierModel":
"""Load from a model file or direct model payload."""
from monty.serialization import loadfn
payload = loadfn(filename, cls=None)
if isinstance(payload, dict) and "filetype" in payload:
payload = require_model_type(payload, cls.MODEL_TYPE).get(cls.PAYLOAD_KEY)
if not isinstance(payload, dict):
raise ValueError(
"Local barrier model file is missing object key "
f"'{cls.PAYLOAD_KEY}'"
)
return cls.from_dict(payload)