"""Active-site order utilities for compact KMC occupation vectors."""
from __future__ import annotations
import hashlib
import json
from dataclasses import dataclass, field
from typing import Any, Mapping, Sequence
import numpy as np
from monty.json import MSONable
from pymatgen.core import Structure
from kmcpy.structure.species import (
normalize_species,
species_equivalent,
species_label,
)
ACTIVE_SITE_ORDER_FORMAT = "kmcpy.active_site_order.v1"
ORIGINAL_SITE_PROPERTY = "_kmcpy_original_site_index"
PRIMITIVE_SITE_PROPERTY = "_kmcpy_primitive_site_index"
PRIMITIVE_ACTIVE_SITE_PROPERTY = "_kmcpy_primitive_active_site_index"
ACTIVE_SITE_PROPERTY = "_kmcpy_active_site_index"
[docs]
@dataclass(frozen=True)
class ActiveSiteOrder(MSONable):
"""Map full template sites to compact mutable active-site indices."""
primitive_site_count: int
original_site_count: int
primitive_active_indices: tuple[int, ...]
active_to_original: tuple[int, ...]
active_to_primitive: tuple[int, ...]
fixed_original_indices: tuple[int, ...]
supercell_shape: tuple[int, int, int]
species_by_primitive_site: tuple[str, ...]
allowed_species_by_primitive_site: tuple[tuple[str, ...], ...]
fingerprint: str
template_structure: Structure | None = field(
default=None, repr=False, compare=False
)
[docs]
@classmethod
def from_lattice_structure(
cls,
lattice_structure,
supercell_shape: Sequence[int] | None = None,
) -> "ActiveSiteOrder":
"""Build from a ``LatticeStructure`` instance."""
return cls.from_structure_and_mapping(
lattice_structure.template_structure,
lattice_structure.site_mapping,
supercell_shape=supercell_shape,
)
[docs]
@classmethod
def from_structure_and_mapping(
cls,
template_structure: Structure,
site_mapping: Mapping[Any, Any],
supercell_shape: Sequence[int] | None = None,
) -> "ActiveSiteOrder":
"""Build an active-site order from a full template and site mapping."""
shape = _normalize_supercell_shape(supercell_shape)
allowed_species = _allowed_species_by_site(
template_structure, site_mapping
)
primitive_active_indices = tuple(
index
for index, allowed in enumerate(allowed_species)
if len(allowed) > 1
)
if not primitive_active_indices:
raise ValueError(
"site_mapping does not define any mutable active sites"
)
primitive_active_lookup = {
original_index: active_index
for active_index, original_index in enumerate(primitive_active_indices)
}
full_structure = _make_supercell_with_properties(
template_structure=template_structure,
primitive_active_lookup=primitive_active_lookup,
supercell_shape=shape,
)
active_to_original = []
active_to_primitive = []
fixed_original_indices = []
for full_index, site in enumerate(full_structure):
primitive_index = int(site.properties[PRIMITIVE_SITE_PROPERTY])
if primitive_index in primitive_active_lookup:
active_to_original.append(full_index)
active_to_primitive.append(primitive_index)
else:
fixed_original_indices.append(full_index)
species_by_primitive_site = tuple(
species_label(site.specie) for site in template_structure
)
allowed_species_by_primitive_site = tuple(
tuple(species_label(specie) for specie in allowed)
for allowed in allowed_species
)
fingerprint = _fingerprint(
{
"format": ACTIVE_SITE_ORDER_FORMAT,
"primitive_site_count": len(template_structure),
"original_site_count": len(full_structure),
"primitive_active_indices": primitive_active_indices,
"active_to_original": active_to_original,
"active_to_primitive": active_to_primitive,
"supercell_shape": shape,
"species_by_primitive_site": species_by_primitive_site,
"allowed_species_by_primitive_site": (
allowed_species_by_primitive_site
),
}
)
return cls(
primitive_site_count=len(template_structure),
original_site_count=len(full_structure),
primitive_active_indices=primitive_active_indices,
active_to_original=tuple(active_to_original),
active_to_primitive=tuple(active_to_primitive),
fixed_original_indices=tuple(fixed_original_indices),
supercell_shape=shape,
species_by_primitive_site=species_by_primitive_site,
allowed_species_by_primitive_site=allowed_species_by_primitive_site,
fingerprint=fingerprint,
template_structure=template_structure.copy(),
)
[docs]
@classmethod
def from_dict(cls, data: Mapping[str, Any]) -> "ActiveSiteOrder":
"""Restore serialized active-site order metadata."""
if data.get("format") != ACTIVE_SITE_ORDER_FORMAT:
raise ValueError(
f"Unsupported active-site order format: {data.get('format')}"
)
return cls(
primitive_site_count=int(data["primitive_site_count"]),
original_site_count=int(data["original_site_count"]),
primitive_active_indices=tuple(
int(index) for index in data["primitive_active_indices"]
),
active_to_original=tuple(
int(index) for index in data["active_to_original"]
),
active_to_primitive=tuple(
int(index) for index in data["active_to_primitive"]
),
fixed_original_indices=tuple(
int(index) for index in data.get("fixed_original_indices", ())
),
supercell_shape=tuple(int(value) for value in data["supercell_shape"]),
species_by_primitive_site=tuple(data["species_by_primitive_site"]),
allowed_species_by_primitive_site=tuple(
tuple(site_species) for site_species in data[
"allowed_species_by_primitive_site"
]
),
fingerprint=str(data["fingerprint"]),
template_structure=None,
)
@property
def active_site_count(self) -> int:
return len(self.active_to_original)
@property
def original_to_active(self) -> dict[int, int]:
return {
int(original_index): active_index
for active_index, original_index in enumerate(self.active_to_original)
}
@property
def primitive_to_active(self) -> dict[int, int]:
return {
int(original_index): active_index
for active_index, original_index in enumerate(self.primitive_active_indices)
}
[docs]
def as_dict(self) -> dict[str, Any]:
"""Serialize active-site order metadata without storing the full structure."""
return {
"format": ACTIVE_SITE_ORDER_FORMAT,
"primitive_site_count": self.primitive_site_count,
"original_site_count": self.original_site_count,
"primitive_active_indices": list(self.primitive_active_indices),
"active_to_original": list(self.active_to_original),
"active_to_primitive": list(self.active_to_primitive),
"fixed_original_indices": list(self.fixed_original_indices),
"supercell_shape": list(self.supercell_shape),
"species_by_primitive_site": list(self.species_by_primitive_site),
"allowed_species_by_primitive_site": [
list(species) for species in self.allowed_species_by_primitive_site
],
"fingerprint": self.fingerprint,
}
[docs]
def assert_same_order(self, other: "ActiveSiteOrder | Mapping[str, Any]") -> None:
"""Raise if another order or metadata payload describes a different site space."""
other_order = (
ActiveSiteOrder.from_dict(other)
if isinstance(other, Mapping)
else other
)
if self.fingerprint != other_order.fingerprint:
raise ValueError(
"Active-site order metadata does not match the current "
"site_mapping and structure."
)
[docs]
def validate_active_indices(
self,
indices: Sequence[int],
field_name: str = "indices",
) -> None:
"""Validate compact active-site indices."""
invalid = [
int(index)
for index in indices
if int(index) < 0 or int(index) >= self.active_site_count
]
if invalid:
raise IndexError(
f"{field_name} contains indices outside the active-site range: {invalid}"
)
[docs]
def select_active_values(self, values: Sequence[Any]) -> list[Any]:
"""Return compact active-site values from active or full-supercell input."""
values = list(values)
if len(values) == self.active_site_count:
return values
if len(values) == self.original_site_count:
return [values[index] for index in self.active_to_original]
raise ValueError(
"Occupation length must match either the active-site count "
f"({self.active_site_count}) or full site count ({self.original_site_count}); "
f"got {len(values)}."
)
[docs]
def full_structure_with_properties(self) -> Structure:
"""Return the full supercell with index-space site properties."""
if self.template_structure is None:
raise ValueError(
"Cannot build structures from serialized active-site order metadata only"
)
primitive_active_lookup = self.primitive_to_active
return _make_supercell_with_properties(
template_structure=self.template_structure,
primitive_active_lookup=primitive_active_lookup,
supercell_shape=self.supercell_shape,
)
[docs]
def active_structure(self) -> Structure:
"""Return active sites only, ordered by compact active index."""
full_structure = self.full_structure_with_properties()
active_sites = [
full_structure[original_index]
for original_index in self.active_to_original
]
active_structure = type(full_structure).from_sites(active_sites)
active_structure.add_site_property(
ACTIVE_SITE_PROPERTY, list(range(self.active_site_count))
)
return active_structure
[docs]
def filter_active_structure(
self,
structure: Structure,
tol: float = 1e-2,
) -> Structure:
"""Filter a possibly full structure down to active sites by index metadata or position."""
if (
len(structure) == self.active_site_count
and ACTIVE_SITE_PROPERTY in structure.site_properties
):
return structure.copy()
if len(structure) == self.original_site_count:
active_sites = [structure[index] for index in self.active_to_original]
active_structure = type(structure).from_sites(active_sites)
active_structure.add_site_property(
ACTIVE_SITE_PROPERTY, list(range(self.active_site_count))
)
return active_structure
full_reference = self.full_structure_with_properties()
original_to_active = self.original_to_active
kept_by_active_index = {}
for site in structure:
distances = full_reference.lattice.get_all_distances(
np.array([site.frac_coords]),
full_reference.frac_coords,
)[0]
original_index = int(np.argmin(distances))
if float(distances[original_index]) > tol:
raise ValueError(
"Input structure contains a site that cannot be mapped to "
"the active-site template."
)
active_index = original_to_active.get(original_index)
if active_index is None:
continue
if active_index in kept_by_active_index:
raise ValueError(
"Input structure maps multiple sites to the same active site"
)
kept_by_active_index[active_index] = site
active_sites = [
kept_by_active_index[index]
for index in sorted(kept_by_active_index)
]
active_structure = type(structure).from_sites(active_sites)
active_structure.add_site_property(
ACTIVE_SITE_PROPERTY, sorted(kept_by_active_index)
)
return active_structure
def _normalize_supercell_shape(
supercell_shape: Sequence[int] | None,
) -> tuple[int, int, int]:
if supercell_shape is None:
return (1, 1, 1)
shape = tuple(int(value) for value in supercell_shape)
if len(shape) != 3:
raise ValueError("supercell_shape must contain three integers")
if any(value <= 0 for value in shape):
raise ValueError("supercell_shape values must be positive")
return shape
def _make_supercell_with_properties(
template_structure: Structure,
primitive_active_lookup: Mapping[int, int],
supercell_shape: tuple[int, int, int],
) -> Structure:
from kmcpy.structure.sites import make_kmc_supercell, structure_from_sites
base = template_structure.copy()
primitive_indices = list(range(len(base)))
base.add_site_property(PRIMITIVE_SITE_PROPERTY, primitive_indices)
base.add_site_property(
PRIMITIVE_ACTIVE_SITE_PROPERTY,
[primitive_active_lookup.get(index, -1) for index in primitive_indices],
)
supercell = make_kmc_supercell(
structure_from_sites(base.sites),
supercell_shape,
)
supercell.add_site_property(
ORIGINAL_SITE_PROPERTY, list(range(len(supercell)))
)
active_indices = []
for site in supercell:
primitive_index = int(site.properties[PRIMITIVE_SITE_PROPERTY])
active_indices.append(primitive_active_lookup.get(primitive_index, -1))
supercell.add_site_property(ACTIVE_SITE_PROPERTY, active_indices)
return supercell
def _allowed_species_by_site(
template_structure: Structure,
site_mapping: Mapping[Any, Any],
) -> list[tuple[Any, ...]]:
entries = [
(normalize_species(key), _normalize_allowed_species(value))
for key, value in site_mapping.items()
]
allowed_species = []
for index, site in enumerate(template_structure):
matches = [
allowed
for key_species, allowed in entries
if species_equivalent(site.specie, key_species)
]
if not matches:
raise ValueError(
"No site_mapping entry found for template site "
f"{index} with species {site.species_string}."
)
allowed_species.append(matches[0])
return allowed_species
def _normalize_allowed_species(value: Any) -> tuple[Any, ...]:
if isinstance(value, (list, tuple)):
return tuple(normalize_species(item) for item in value)
return (normalize_species(value),)
def _fingerprint(payload: Mapping[str, Any]) -> str:
canonical = json.dumps(payload, sort_keys=True, default=str, separators=(",", ":"))
return hashlib.sha256(canonical.encode("utf-8")).hexdigest()