#!/usr/bin/env python
"""Tracker for monitoring mobile ion trajectories during kMC simulations."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Callable, Optional
import numpy as np
from pymatgen.core import Structure
from kmcpy.simulator.property import (
BUILTIN_PROPERTY_FIELDS,
BUILTIN_PROPERTY_UNITS,
PropertyPlan,
PropertyRecord,
PropertySpec,
append_record,
compute_transport_properties,
describe_property_calculations,
make_property_spec,
set_property_enabled_flag,
should_trigger,
validate_schedule,
)
from kmcpy.simulator.results import format_tracker_summary, write_tracker_results
from kmcpy.event import INVALID_STATE, event_direction
if TYPE_CHECKING:
from kmcpy.simulator.config import Configuration
from kmcpy.simulator.state import State
logger = logging.getLogger(__name__)

# Column order of the built-in result table (``Tracker.results``): the sampling
# time first, followed by the transport-summary metrics.
RESULT_FIELDS = (
    "time",
    "jump_diffusivity", "tracer_diffusivity", "conductivity",
    "correlation_factor", "havens_ratio", "msd",
)

# Units for every built-in result column. "time" is listed first; the metric
# units come from the property module and take precedence on a key clash.
RESULT_UNITS = {"time": "s"}
RESULT_UNITS.update(BUILTIN_PROPERTY_UNITS)
class CallbackExecutionError(RuntimeError):
    """Raised when an attached property callback fails and cannot be recovered.

    The tracker raises this when a callback fails and it has no ``on_error``
    handler, when the handler itself raises, or when the handler returns a
    falsy value. The underlying exception is chained as ``__cause__``.
    """
def _create_result_store() -> dict[str, list[float]]:
    """Build a fresh result table with one empty column per built-in field.

    Every column gets its own list object, so appending to one column never
    affects another.
    """
    store: dict[str, list[float]] = {}
    for column_name in RESULT_FIELDS:
        store[column_name] = []
    return store
def _append_result_row(store: dict[str, list[float]], sim_time: float, metrics: dict[str, float]) -> None:
"""Append one built-in summary sample to the result table."""
store["time"].append(sim_time)
store["jump_diffusivity"].append(metrics["jump_diffusivity"])
store["tracer_diffusivity"].append(metrics["tracer_diffusivity"])
store["conductivity"].append(metrics["conductivity"])
store["correlation_factor"].append(metrics["correlation_factor"])
store["havens_ratio"].append(metrics["havens_ratio"])
store["msd"].append(metrics["msd"])
class Tracker:
    """Track trajectories and evaluate attached properties for each sampling point.

    The tracker accumulates each mobile ion's Cartesian displacement and hop
    count from accepted events. It samples the built-in transport summary on
    the global schedule and runs user-attached property callbacks on their own
    schedule, falling back to the global one.

    Built-in result units are available through ``Tracker.result_units`` and are
    also written next to the result CSV by ``write_results``.
    """

    def __init__(
        self,
        config: "Configuration",
        structure: Structure,
        initial_state: Optional["State"] = None,
        property_plan: Optional[PropertyPlan] = None,
        default_property_interval: Optional[int] = None,
        hop_state_lookup: Any = None,
    ) -> None:
        """Initialize tracker state, trajectory arrays, and built-in sampling.

        Args:
            config: Simulation configuration. Species, charge, temperature,
                dimension, hop distance and attempt frequency are read from it.
            structure: Structure supplying site coordinates, lattice and volume.
            initial_state: Shared simulation state. Required despite the
                ``None`` default.
            property_plan: Optional sampling recipe applied after setup.
            default_property_interval: Step interval used when
                ``property_plan`` defines neither a step nor a time interval.
            hop_state_lookup: Optional object exposing ``mobile_state_by_site``.
                When given, it decides which sites hold mobile ions.

        Raises:
            ValueError: If ``initial_state`` is ``None``.
        """
        logger.info("Initializing Tracker ...")
        if initial_state is None:
            raise ValueError("State must be provided to Tracker")
        self.config = config
        self.structure = structure
        self.state = initial_state
        self.hop_state_lookup = hop_state_lookup
        self._initialize_mobile_ion_tracking(initial_state.occupations)
        self.results = _create_result_store()
        self.current_pass = 0
        # Default global schedule: step interval 1. A property plan or
        # set_global_property_frequency may override it.
        self._global_interval: Optional[int] = 1
        self._global_time_interval: Optional[float] = None
        self._enabled_builtin_properties = {
            name: True for name in BUILTIN_PROPERTY_FIELDS
        }
        self._properties: dict[str, PropertySpec] = {}
        self._property_records: dict[str, list[PropertyRecord]] = {}
        self._last_summary_trigger_time: Optional[float] = None
        if property_plan is not None:
            self.apply_property_plan(
                property_plan,
                default_interval=default_property_interval,
            )
        logger.info("number of mobile ion specie = %d", self.n_mobile_ion_specie)
        if self.n_mobile_ion_specie == 0:
            # np.mean over an empty (0, 3) array would warn and log nan.
            logger.warning(
                "No mobile ions of species %s were found in the initial state.",
                self.mobile_ion_specie,
            )
        else:
            logger.info(
                "Center of mass (%s): %s",
                self.mobile_ion_specie,
                np.mean(self.r0, axis=0),
            )

    def _initialize_mobile_ion_tracking(self, initial_occ: list[int]) -> None:
        """Locate mobile ions in ``initial_occ`` and allocate trajectory arrays.

        Sets the following attributes:

        - ``n_mobile_ion_specie_site``
        - ``mobile_ion_specie_locations`` (site index of each tracked ion)
        - ``n_mobile_ion_specie``
        - ``displacement`` (n x 3)
        - ``hop_counter`` (n)
        - ``r0`` (initial Cartesian positions, n x 3)
        """
        if self.hop_state_lookup is None:
            # NOTE(review): substring match, so a label like "N" would also
            # count "Na" sites. The slice below assumes mobile-ion sites are
            # the leading sites of the structure.
            self.n_mobile_ion_specie_site = sum(
                1 for el in self.structure.species if self.mobile_ion_specie in el.symbol
            )
            initial_active_occ = np.array(initial_occ[0 : self.n_mobile_ion_specie_site])
            # In this mode an occupation code of 0 marks a site holding a mobile ion.
            mobile_state_mask = initial_active_occ == 0
        else:
            # asarray keeps the element-wise comparisons below well-defined even
            # if the lookup hands back a plain sequence.
            mobile_states = np.asarray(self.hop_state_lookup.mobile_state_by_site)
            self.n_mobile_ion_specie_site = int(np.sum(mobile_states != INVALID_STATE))
            initial_active_occ = np.array(initial_occ[0 : len(mobile_states)])
            # A site holds a mobile ion when it can host one and its current
            # occupation equals that site's mobile-ion state code.
            mobile_state_mask = (
                (mobile_states != INVALID_STATE)
                & (initial_active_occ == mobile_states)
            )
        self.mobile_ion_specie_locations = np.where(mobile_state_mask)[0]
        self.n_mobile_ion_specie = len(self.mobile_ion_specie_locations)
        logger.debug("Initial mobile ion locations = %s", self.mobile_ion_specie_locations)
        self.displacement = np.zeros((self.n_mobile_ion_specie, 3))
        self.hop_counter = np.zeros(self.n_mobile_ion_specie, dtype=np.int64)
        self.r0 = self.frac_coords[self.mobile_ion_specie_locations] @ self.latt.matrix

    @property
    def occ_initial(self) -> list:
        """Return the current occupations of the shared simulation state.

        Despite the name, this is not a snapshot of the initial occupations. It
        reads ``self.state.occupations`` on every access.
        """
        return self.state.occupations

    @property
    def frac_coords(self) -> np.ndarray:
        """Return structure fractional coordinates."""
        return self.structure.frac_coords

    @property
    def latt(self):
        """Return structure lattice."""
        return self.structure.lattice

    @property
    def volume(self) -> float:
        """Return structure volume."""
        return self.structure.volume

    @property
    def dimension(self) -> int:
        """Return simulation dimensionality."""
        return self.config.dimension

    @property
    def q(self) -> float:
        """Return mobile ion charge."""
        return self.config.mobile_ion_charge

    @property
    def elem_hop_distance(self) -> float:
        """Return elementary hop distance."""
        return self.config.elementary_hop_distance

    @property
    def temperature(self) -> float:
        """Return simulation temperature."""
        return self.config.temperature

    @property
    def v(self) -> float:
        """Return attempt frequency."""
        return self.config.attempt_frequency

    @property
    def time(self) -> float:
        """Return current simulation time from the shared state."""
        return self.state.time

    @property
    def mobile_ion_specie(self) -> str:
        """Return tracked mobile ion species label."""
        return self.config.mobile_ion_specie

    @property
    def result_units(self) -> dict[str, str]:
        """Return a copy of the units for built-in result fields."""
        return dict(RESULT_UNITS)

    def set_global_property_frequency(
        self,
        interval: Optional[int] = None,
        time_interval: Optional[float] = None,
    ) -> None:
        """Set the global sampling schedule.

        The global schedule drives the built-in transport summary. It is also
        the fallback for attached callbacks that define no schedule of their own.

        Args:
            interval: Step interval, or ``None``.
            time_interval: Simulation-time interval, or ``None``.

        Raises:
            Whatever ``validate_schedule`` raises for an invalid combination.
        """
        validate_schedule(interval=interval, time_interval=time_interval)
        self._global_interval = interval
        self._global_time_interval = time_interval

    def apply_property_plan(
        self,
        property_plan: PropertyPlan,
        default_interval: Optional[int] = None,
    ) -> None:
        """Apply a property sampling recipe to this tracker.

        Sets the global schedule, toggles built-in fields, and attaches the
        plan's fresh callback specs. ``default_interval`` is used only when the
        plan specifies neither a step nor a time interval.
        """
        interval = property_plan.global_interval
        time_interval = property_plan.global_time_interval
        if interval is None and time_interval is None and default_interval is not None:
            interval = default_interval
        self.set_global_property_frequency(
            interval=interval,
            time_interval=time_interval,
        )
        for property_name, enabled in property_plan.builtin_enabled.items():
            self.set_property_enabled(property_name, enabled)
        for spec in property_plan.fresh_attachment_specs():
            self.attach_spec(spec)

    def attach(
        self,
        func: Callable[["State", int, float], Any],
        interval: Optional[int] = None,
        time_interval: Optional[float] = None,
        name: Optional[str] = None,
        store: bool = True,
        max_records: Optional[int] = None,
        on_error: Optional[Callable[[Exception, "State", int, float], bool]] = None,
        enabled: bool = True,
    ) -> str:
        """Attach one property callback to this tracker.

        Args:
            func: Callback invoked as ``func(state, step, sim_time)``.
            interval: Step interval; ``None`` falls back to the global schedule.
            time_interval: Time interval; ``None`` falls back to the global schedule.
            name: Property name; the spec factory resolves it when ``None``.
            store: Whether returned values are recorded.
            max_records: Optional record cap, passed to the spec factory.
            on_error: Optional handler ``(exc, state, step, sim_time) -> bool``.
                A truthy return keeps the simulation running.
            enabled: Whether the callback starts enabled.

        Returns:
            The name under which the property was attached.
        """
        spec = make_property_spec(
            func,
            interval=interval,
            time_interval=time_interval,
            name=name,
            store=store,
            max_records=max_records,
            on_error=on_error,
            enabled=enabled,
            existing_names=set(self._properties),
        )
        self._properties[spec.name] = spec
        self._property_records[spec.name] = []
        return spec.name

    def attach_spec(self, spec: PropertySpec) -> str:
        """Attach a prevalidated property specification to this tracker.

        The spec's fields are re-submitted through ``attach``. The stored spec
        is therefore a new object, not ``spec`` itself.
        """
        return self.attach(
            spec.callback,
            interval=spec.interval,
            time_interval=spec.time_interval,
            name=spec.name,
            store=spec.store,
            max_records=spec.max_records,
            on_error=spec.on_error,
            enabled=spec.enabled,
        )

    def detach(self, name: str) -> None:
        """Detach a previously attached property callback and drop its records.

        Raises:
            ValueError: If no property named ``name`` is attached.
        """
        if name not in self._properties:
            raise ValueError(f"Property '{name}' is not attached")
        del self._properties[name]
        self._property_records.pop(name, None)

    def clear_attachments(self) -> None:
        """Remove all user-attached property callbacks and their records."""
        self._properties.clear()
        self._property_records.clear()

    def list_attachments(self) -> list[str]:
        """Return names of user-attached properties in attachment order."""
        return list(self._properties)

    def list_property_calculations(self) -> dict[str, list[str]]:
        """Return enabled/disabled built-ins and currently attached callbacks."""
        return describe_property_calculations(
            builtin_enabled=self._enabled_builtin_properties,
            attached_properties=self._properties,
        )

    def set_property_enabled(self, name: str, enabled: bool) -> None:
        """Enable or disable a built-in summary field or an attached callback."""
        set_property_enabled_flag(
            builtin_enabled=self._enabled_builtin_properties,
            attached_properties=self._properties,
            name=name,
            enabled=enabled,
        )

    def _compute_transport_summary(self) -> dict[str, float]:
        """Compute the built-in transport summary from current tracker state."""
        # NOTE(review): this uses the state's clock, while the result row is
        # stamped with the sim_time passed to sample_properties. Assumes
        # callers pass state.time; confirm.
        return compute_transport_properties(
            self.displacement,
            self.hop_counter,
            sim_time=float(self.time),
            dimension=self.dimension,
            n_mobile_ion_specie=self.n_mobile_ion_specie,
            elementary_hop_distance=self.elem_hop_distance,
            volume=self.volume,
            mobile_ion_charge=self.q,
            temperature=self.temperature,
            enabled=self._enabled_builtin_properties,
        )

    def _sample_transport_summary(self, step: int, sim_time: float) -> None:
        """Append a built-in transport row when the global schedule is due."""
        if not should_trigger(
            step=step,
            sim_time=sim_time,
            interval=self._global_interval,
            time_interval=self._global_time_interval,
            last_trigger_time=self._last_summary_trigger_time,
        ):
            return
        metrics = self._compute_transport_summary()
        _append_result_row(self.results, sim_time, metrics)
        self._last_summary_trigger_time = sim_time

    def _handle_callback_error(
        self,
        spec: PropertySpec,
        exc: Exception,
        step: int,
        sim_time: float,
    ) -> None:
        """Apply the callback's ``on_error`` policy to a failure.

        Returns normally only when the handler exists, does not raise, and
        returns a truthy value.

        Raises:
            CallbackExecutionError: If there is no handler, the handler
                raises, or the handler returns a falsy value.
        """
        if spec.on_error is None:
            raise CallbackExecutionError(
                f"Property callback '{spec.name}' failed at step={step}, time={sim_time}"
            ) from exc
        try:
            keep_running = bool(spec.on_error(exc, self.state, step, sim_time))
        except Exception as handler_exc:
            raise CallbackExecutionError(
                f"Error handler for callback '{spec.name}' failed"
            ) from handler_exc
        if not keep_running:
            raise CallbackExecutionError(
                f"Property callback '{spec.name}' failed and requested termination"
            ) from exc

    def _latest_property_value(self, name: str) -> Any:
        """Return the latest recorded value for ``name``, or NaN if none exists."""
        records = self._property_records.get(name, [])
        if not records:
            return float("nan")
        return records[-1].value

    def sample_properties(self, step: int, sim_time: float) -> None:
        """Sample built-in metrics and run every attached callback that is due.

        Args:
            step: Current step index, checked against step intervals.
            sim_time: Current simulation time. It is checked against time
                intervals and stored with each sample.

        Raises:
            CallbackExecutionError: If a callback fails and its error policy
                does not allow the run to continue.
        """
        self._sample_transport_summary(step=step, sim_time=sim_time)
        # Iterate a snapshot because callbacks or error handlers may attach or
        # detach properties while we loop.
        for spec in list(self._properties.values()):
            # Skip specs detached (or replaced) earlier in this pass. A detached
            # callback is then never run, and its missing records list is never
            # indexed.
            if self._properties.get(spec.name) is not spec or not spec.enabled:
                continue
            interval = spec.interval if spec.interval is not None else self._global_interval
            time_interval = (
                spec.time_interval
                if spec.time_interval is not None
                else self._global_time_interval
            )
            if not should_trigger(
                step=step,
                sim_time=sim_time,
                interval=interval,
                time_interval=time_interval,
                last_trigger_time=spec.last_trigger_time,
            ):
                continue
            try:
                value = spec.callback(self.state, step, sim_time)
            except Exception as exc:
                # Raises unless on_error allows continuing; nothing is stored.
                self._handle_callback_error(spec=spec, exc=exc, step=step, sim_time=sim_time)
            else:
                # The callback itself may have detached this property.
                records = self._property_records.get(spec.name)
                if spec.store and records is not None:
                    append_record(
                        records=records,
                        spec=spec,
                        step=step,
                        sim_time=sim_time,
                        value=value,
                    )
            spec.last_trigger_step = step
            spec.last_trigger_time = sim_time

    def get_property_records(
        self, name: Optional[str] = None
    ) -> dict[str, list[dict[str, Any]]] | list[dict[str, Any]]:
        """Return stored callback records as shallow-copied dicts.

        Args:
            name: If given, return only that property's records as a list.
                Otherwise return a mapping of every property name to its list.

        Raises:
            ValueError: If ``name`` has no record list.
        """
        if name is not None:
            if name not in self._property_records:
                raise ValueError(f"Property '{name}' has no stored records")
            return [record.__dict__.copy() for record in self._property_records[name]]
        return {
            key: [record.__dict__.copy() for record in records]
            for key, records in self._property_records.items()
        }

    def update(self, event, dt) -> None:
        """Record one accepted event using the current pre-event State.

        Moves the hopping ion's tracked location, adds the wrapped hop vector to
        its displacement, and increments its hop counter. ``dt`` is unused. If
        the event does not match the current occupations, an error is logged
        and nothing changes.
        """
        _ = dt
        occupations = self.state.occupations
        direction = event_direction(occupations, event)
        if direction == 0:
            logger.error("Proposed event does not match current endpoint occupations")
            return
        logger.debug(
            "Tracker update: event=%s direction=%d time=%.6f",
            event.mobile_ion_indices,
            direction,
            self.time,
        )
        mobile_ion_index = self._record_mobile_ion_hop(event, direction)
        displacement = self._wrapped_hop_displacement(event, direction)
        self.displacement[mobile_ion_index] += displacement
        self.hop_counter[mobile_ion_index] += 1

    def _record_mobile_ion_hop(self, event, direction: int) -> int:
        """Move the tracked mobile ion to its destination site and return its row index.

        ``direction == 1`` hops from ``mobile_ion_indices[0]`` to
        ``mobile_ion_indices[1]``. Any other direction hops the reverse way.

        Raises:
            RuntimeError: If no tracked ion sits on the source site.
        """
        from_site, to_site = event.mobile_ion_indices
        source_site = from_site if direction == 1 else to_site
        destination_site = to_site if direction == 1 else from_site
        matches = np.where(self.mobile_ion_specie_locations == source_site)[0]
        if len(matches) == 0:
            raise RuntimeError(
                "Tracker mobile-ion locations are inconsistent with the "
                f"accepted event source site {source_site}."
            )
        mobile_ion_index = int(matches[0])
        self.mobile_ion_specie_locations[mobile_ion_index] = destination_site
        return mobile_ion_index

    def _wrapped_hop_displacement(self, event, direction: int) -> np.ndarray:
        """Return the periodic-wrapped Cartesian displacement for one accepted hop.

        Each fractional component is wrapped into [-0.5, 0.5] by subtracting its
        nearest integer. For strongly skewed cells this is not guaranteed to be
        the shortest Cartesian vector.
        """
        from_site, to_site = event.mobile_ion_indices
        displacement_frac = direction * (
            self.frac_coords[to_site] - self.frac_coords[from_site]
        )
        displacement_frac = displacement_frac - np.round(displacement_frac)
        return np.array(self.latt.get_cartesian_coords(displacement_frac))

    def update_current_pass(self, current_pass: int) -> None:
        """Update the current pass index used in logging/output."""
        self.current_pass = current_pass

    def show_current_info(self) -> None:
        """Log the current pass and the latest sampled summary and attachments."""
        if not self.results["time"]:
            logger.info("Pass %d has no sampled properties yet.", self.current_pass)
            return
        attached_values = {
            name: self._latest_property_value(name)
            for name in self.list_attachments()
        }
        logger.info(
            "Tracker Summary:%s",
            format_tracker_summary(
                current_pass=self.current_pass,
                results=self.results,
                result_units=self.result_units,
                attached_values=attached_values,
            ),
        )

    def return_current_info(self) -> tuple[float, float, float, float, float, float, float]:
        """Return latest sampled summary values for testing/reporting.

        Units follow ``Tracker.result_units`` and tuple order is:
        ``time``, ``msd``, ``jump_diffusivity``, ``tracer_diffusivity``,
        ``conductivity``, ``havens_ratio``, ``correlation_factor``.

        Raises:
            ValueError: If no summary has been sampled yet.
        """
        if not self.results["time"]:
            raise ValueError("No property samples are available. Increase sampling frequency.")
        return (
            self.results["time"][-1],
            self.results["msd"][-1],
            self.results["jump_diffusivity"][-1],
            self.results["tracer_diffusivity"][-1],
            self.results["conductivity"][-1],
            self.results["havens_ratio"][-1],
            self.results["correlation_factor"][-1],
        )

    def write_results(self, label: str | None = None) -> None:
        """Write trajectory arrays, built-in summaries, and custom-property records."""
        write_tracker_results(
            label=label,
            current_pass=self.current_pass,
            displacement=self.displacement,
            hop_counter=self.hop_counter,
            occupations=self.state.occupations,
            results=self.results,
            result_units=self.result_units,
            property_records=self._property_records,
        )