"""
Basis functions and occupation management for lattice models in kMCpy.
This module provides different basis functions for converting between species and numerical values
used in cluster expansion and other lattice-based calculations, as well as an Occupation class
for managing site occupation states with support for custom basis functions via registry.
"""
import numpy as np
from abc import ABC, abstractmethod
from typing import List, Union, Tuple, Iterator, Dict, Type, Any
from monty.json import MSONable
# Basis function registry for extensibility
BASIS_REGISTRY: Dict[str, Type['BasisFunction']] = {}
[docs]
class BasisFunction(MSONable, ABC):
"""
Abstract base class for all basis functions.
This defines the interface that all basis functions must implement,
allowing users to create custom basis functions that work with the
Occupation class.
"""
def __init__(self):
self.name = self.__class__.__name__.lower().replace('basis', '')
@property
def uses_state_indices(self) -> bool:
"""Whether occupations store discrete species-state indices."""
return False
@property
@abstractmethod
def match_value(self) -> Union[int, float]:
"""Value representing specie matches template structure."""
pass
@property
@abstractmethod
def mismatch_value(self) -> Union[int, float]:
"""Value representing specie doesn't match template structure."""
pass
@property
@abstractmethod
def valid_values(self) -> set:
"""Set of all valid values in this basis."""
pass
@property
@abstractmethod
def basis_function(self) -> List[Union[int, float]]:
"""List defining the basis function values."""
pass
[docs]
def num_site_basis_functions(self, n_states: int) -> int:
"""Return number of non-constant site basis functions."""
return 1
[docs]
def state_value(self, state_index: int, n_states: int | None = None) -> Union[int, float]:
"""Return the occupation value used for a species-state index."""
return self.match_value if int(state_index) == 0 else self.mismatch_value
[docs]
def site_basis_values(self, n_states: int) -> np.ndarray:
"""Return basis values with shape ``(n_states, n_basis_functions)``."""
return np.array(
[
[self.state_value(state_index, n_states)]
for state_index in range(int(n_states))
],
dtype=float,
)
[docs]
def as_dict(self) -> dict:
"""Serialize the basis definition."""
return {
"@module": self.__class__.__module__,
"@class": self.__class__.__name__,
"name": self.name,
}
[docs]
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "BasisFunction":
"""Create a basis function from a Monty-style payload."""
if not isinstance(data, dict):
raise ValueError("BasisFunction.from_dict expects a dictionary")
target_cls: Type[BasisFunction] = cls
if cls is BasisFunction:
class_name = data.get("@class")
for registered_cls in BASIS_REGISTRY.values():
if registered_cls.__name__ == class_name:
target_cls = registered_cls
break
else:
basis_name = data.get("name")
if basis_name in BASIS_REGISTRY:
target_cls = BASIS_REGISTRY[basis_name]
else:
raise ValueError(f"Unknown basis payload: {data}")
kwargs = {
key: value
for key, value in data.items()
if not key.startswith("@") and key != "name"
}
return target_cls(**kwargs)
[docs]
def flip(self, value: Union[int, float]) -> Union[int, float]:
"""Flip between match and mismatch values."""
return self.mismatch_value if value == self.match_value else self.match_value
[docs]
def convert_to(self, value: Union[int, float], target_basis: 'BasisFunction') -> Union[int, float]:
"""Convert a value to another basis function."""
if value == self.mismatch_value:
return target_basis.mismatch_value
else:
return target_basis.match_value
def __str__(self) -> str:
return f"{self.__class__.__name__}(match={self.match_value}, mismatch={self.mismatch_value})"
[docs]
def register_basis(name: str):
"""
Decorator to register a basis function class.
Args:
name: Name to register the basis function under
"""
def decorator(cls: Type[BasisFunction]):
BASIS_REGISTRY[name] = cls
return cls
return decorator
[docs]
@register_basis('occupation')
class OccupationBasis(BasisFunction):
"""
Occupation basis function that maps between binary occupation states.
Uses [0, 1] representation for site occupancy.
In occupation basis:
- 0 = vacant (no atom present)
- 1 = occupied (atom present)
This is the most intuitive binary representation.
"""
@property
def match_value(self) -> int:
"""Value when site matches template (has atom = occupied)."""
return 1
@property
def mismatch_value(self) -> int:
"""Value when site doesn't match template (no atom = vacant)."""
return 0
# Physical interpretation properties
@property
def vacant_value(self) -> int:
"""Value representing vacant site (no atom)."""
return 0
@property
def occupied_value(self) -> int:
"""Value representing occupied site (atom present)."""
return 1
@property
def valid_values(self) -> set:
return {0, 1}
@property
def basis_function(self) -> List[int]:
return [0, 1]
[docs]
def flip(self, value: int) -> int:
"""Flip between match and mismatch."""
return 1 - value
[docs]
def is_occupied(self, value: int) -> bool:
"""Check if value represents occupied state."""
return value == self.occupied_value
[docs]
def is_vacant(self, value: int) -> bool:
"""Check if value represents vacant state."""
return value == self.vacant_value
[docs]
def to_chebyshev(self, value: int) -> int:
"""Convert occupation value to the Chebyshev species-state basis."""
return 1 if value == self.vacant_value else 0
[docs]
def from_chebyshev(self, value: int) -> int:
"""Convert the Chebyshev species-state value to occupation basis."""
return self.occupied_value if value == 0 else self.vacant_value
[docs]
@register_basis('chebyshev')
class ChebyshevBasis(BasisFunction):
"""
Discrete site-state Chebyshev basis used for cluster expansion.
kMCpy stores occupations as species-state indices. A site with ``q``
allowed species has states ``0`` through ``q - 1`` and ``q - 1``
non-constant Chebyshev basis functions.
"""
def __init__(self, max_states: int = 2):
super().__init__()
if int(max_states) < 2:
raise ValueError("ChebyshevBasis requires at least two states")
self.max_states = int(max_states)
@property
def uses_state_indices(self) -> bool:
"""Chebyshev occupations are stored as species-state indices."""
return True
@property
def match_value(self) -> int:
"""Value when site matches the first mapped species."""
return 0
@property
def mismatch_value(self) -> int:
"""Value when site is missing or matches another allowed species."""
return 1
@property
def occupied_value(self) -> int:
"""Alias for the first mapped species state."""
return self.match_value
@property
def vacant_value(self) -> int:
"""Alias for the missing/second-species state."""
return self.mismatch_value
@property
def valid_values(self) -> set:
return set(range(self.max_states))
@property
def basis_function(self) -> List[Union[int, float, list]]:
return self.site_basis_values(self.max_states).tolist()
[docs]
def flip(self, value: int) -> int:
"""Flip between match and mismatch."""
return self.mismatch_value if value == self.match_value else self.match_value
[docs]
def is_occupied(self, value: int) -> bool:
"""Check if value represents the first mapped species state."""
return value == self.occupied_value
[docs]
def is_vacant(self, value: int) -> bool:
"""Check if value represents the missing/second-species state."""
return value == self.vacant_value
[docs]
def to_occupation(self, value: int) -> int:
"""
Convert Chebyshev value to occupation basis.
Chebyshev: 0 (first species), 1 (missing/second species)
Occupation: 0 (vacant), 1 (occupied)
"""
return 1 if value == self.occupied_value else 0
[docs]
def from_occupation(self, value: int) -> int:
"""
Convert occupation value to Chebyshev basis.
Occupation: 0 (vacant), 1 (occupied)
Chebyshev: 0 (first species), 1 (missing/second species)
"""
return self.occupied_value if value == 1 else self.vacant_value
[docs]
def num_site_basis_functions(self, n_states: int) -> int:
"""Return ``q - 1`` non-constant basis functions for ``q`` states."""
n_states = int(n_states)
if n_states < 2:
return 1
return n_states - 1
[docs]
def state_value(self, state_index: int, n_states: int | None = None) -> int:
"""Return the occupation value for a species-state index."""
return int(state_index)
[docs]
def site_basis_values(self, n_states: int) -> np.ndarray:
"""
Return discrete Chebyshev values for a site's allowed species.
The species states are encoded as equally spaced points in ``[-1, 1]``.
For ``q`` species, the returned columns are ``T_1`` through
``T_{q-1}`` evaluated at those points.
"""
n_states = int(n_states)
if n_states < 2:
return np.ones((n_states, 1), dtype=float)
encoded = np.linspace(-1.0, 1.0, n_states)
return np.polynomial.chebyshev.chebvander(encoded, n_states - 1)[:, 1:]
[docs]
def as_dict(self) -> dict:
payload = super().as_dict()
payload["max_states"] = self.max_states
return payload
[docs]
def get_basis(name: str, **kwargs) -> BasisFunction:
"""
Get a basis function instance by name.
Args:
name: Name of the basis function
Returns:
Instance of the requested basis function
Raises:
ValueError: If basis name is not registered
"""
if name not in BASIS_REGISTRY:
raise ValueError(f"Unknown basis '{name}'. Available: {list(BASIS_REGISTRY.keys())}")
return BASIS_REGISTRY[name](**kwargs)
[docs]
class Occupation:
"""
Encapsulates site occupation data with basis conversion and validation.
Provides a clean interface for managing occupation arrays with different
basis representations. Supports any registered basis function, allowing
users to define custom basis functions.
Features:
- Automatic validation of occupation values using basis classes
- Basis conversion between any registered basis types
- Common operations like counting match/mismatch sites
- Immutable and mutable variants
- Efficient numpy operations under the hood
- Extensible via basis function registry
"""
def __init__(self, data: Union[List[int], Tuple[int], np.ndarray],
basis: Union[str, BasisFunction] = 'chebyshev', validate: bool = True):
"""
Initialize occupation array.
Args:
data: Occupation values as list, tuple, or numpy array
basis: Basis function - either a string name or BasisFunction instance
validate: Whether to validate occupation values against basis
Raises:
ValueError: If basis is invalid or data doesn't match basis constraints
"""
if isinstance(basis, str):
self._basis_obj = get_basis(basis)
self._basis_name = basis
elif isinstance(basis, BasisFunction):
self._basis_obj = basis
self._basis_name = basis.name
else:
raise ValueError(f"Invalid basis type. Must be string or BasisFunction instance")
self._data = np.array(data, dtype=type(self._basis_obj.match_value))
if validate:
self._validate()
def _validate(self):
"""Validate occupation values against the specified basis."""
invalid = set(self._data) - self._basis_obj.valid_values
if invalid:
raise ValueError(f"Invalid values {invalid} for {self._basis_name} basis. "
f"Must be in {self._basis_obj.valid_values}")
@property
def basis(self) -> str:
"""Get the basis type name."""
return self._basis_name
@property
def basis_obj(self) -> BasisFunction:
"""Get the basis function object."""
return self._basis_obj
@property
def data(self) -> np.ndarray:
"""Get the underlying numpy array (read-only view)."""
return self._data.copy() # Return copy to maintain immutability
@property
def values(self) -> List[int]:
"""Get occupation values as a list."""
return self._data.tolist()
def __len__(self) -> int:
"""Return number of sites."""
return len(self._data)
def __getitem__(self, index: Union[int, slice, np.ndarray, List[int]]) -> Union[int, 'Occupation']:
"""Get occupation value(s) at index."""
if isinstance(index, slice):
return Occupation(self._data[index], basis=self._basis_obj, validate=False)
elif isinstance(index, (list, tuple, np.ndarray)):
# Handle array-like indexing (returns new Occupation object)
return Occupation(self._data[index], basis=self._basis_obj, validate=False)
else:
# Single index (returns scalar)
return self._data[index].item()
def __setitem__(self, index: Union[int, slice], value: Union[int, List[int]]):
"""Set occupation value(s) at index."""
self._data[index] = value
if hasattr(self, '_validate'): # Only validate if not in constructor
self._validate()
def __iter__(self) -> Iterator[int]:
"""Iterate over occupation values."""
return iter(self._data.tolist())
def __eq__(self, other) -> bool:
"""Check equality with another Occupation object."""
if not isinstance(other, Occupation):
return False
return self._basis_name == other._basis_name and np.array_equal(self._data, other._data)
def __ne__(self, other) -> bool:
"""Check inequality with another Occupation object."""
return not self.__eq__(other)
[docs]
def equivalent_to(self, other: 'Occupation') -> bool:
"""
Check if two Occupation objects represent the same occupation pattern,
regardless of basis type.
This is useful for comparing occupations that might be in different bases
but represent the same physical state.
Args:
other: Another Occupation object to compare with
Returns:
True if both objects represent the same occupation pattern
"""
if not isinstance(other, Occupation):
return False
# Convert both to occupation basis for comparison
self_occ = self.to_basis('occupation')
other_occ = other.to_basis('occupation')
return np.array_equal(self_occ._data, other_occ._data)
[docs]
def array_equal(self, array: Union[List, Tuple, np.ndarray]) -> bool:
"""
Check if the underlying data array equals the given array.
This is useful for unit tests where you want to compare the raw data
without creating another Occupation object.
Args:
array: Array-like object to compare with
Returns:
True if the arrays are equal
"""
return np.array_equal(self._data, np.array(array))
def __repr__(self) -> str:
"""String representation."""
return f"Occupation({self.values}, basis='{self._basis_name}')"
def __str__(self) -> str:
"""User-friendly string representation."""
return f"{self._basis_name.title()} occupation: {self.values}"
# Convenience methods for common operations using basis objects
[docs]
def count_mismatch(self) -> int:
"""Count sites with mismatch (specie doesn't match template)."""
return int(np.sum(self._data == self._basis_obj.mismatch_value))
[docs]
def count_match(self) -> int:
"""Count sites with match (specie matches template)."""
return int(np.sum(self._data == self._basis_obj.match_value))
# Physical occupation counts (use basis object's definitions)
[docs]
def count_occupied(self) -> int:
"""Count sites with occupied state."""
return int(np.sum(self._data == self._basis_obj.occupied_value))
[docs]
def count_vacant(self) -> int:
"""Count sites with vacant state."""
return int(np.sum(self._data == self._basis_obj.vacant_value))
[docs]
def get_mismatch_indices(self) -> List[int]:
"""Get indices of sites with mismatch (specie doesn't match template)."""
return np.where(self._data == self._basis_obj.mismatch_value)[0].tolist()
[docs]
def get_match_indices(self) -> List[int]:
"""Get indices of sites with match (specie matches template)."""
return np.where(self._data == self._basis_obj.match_value)[0].tolist()
[docs]
def get_occupied_indices(self) -> List[int]:
"""Get indices of occupied sites."""
return np.where(self._data == self._basis_obj.occupied_value)[0].tolist()
[docs]
def get_vacant_indices(self) -> List[int]:
"""Get indices of vacant sites."""
return np.where(self._data == self._basis_obj.vacant_value)[0].tolist()
[docs]
def flip(self, indices: Union[int, List[int]]) -> 'Occupation':
"""
Return new Occupation with flipped values at specified indices.
Args:
indices: Site index or list of indices to flip
Returns:
New Occupation object with flipped values
"""
new_data = self._data.copy()
if isinstance(indices, int):
indices = [indices]
for idx in indices:
new_data[idx] = self._basis_obj.flip(new_data[idx])
return Occupation(new_data, basis=self._basis_obj, validate=False)
[docs]
def flip_inplace(self, indices: Union[int, List[int]]) -> None:
"""
Flip values at specified indices in-place.
Args:
indices: Site index or list of indices to flip
"""
if isinstance(indices, int):
indices = [indices]
for idx in indices:
self._data[idx] = self._basis_obj.flip(self._data[idx])
[docs]
def to_basis(self, target_basis: Union[str, BasisFunction]) -> 'Occupation':
"""
Convert to different basis representation using basis objects.
Args:
target_basis: Target basis - either string name or BasisFunction instance
Returns:
New Occupation object in target basis
"""
if isinstance(target_basis, str):
target_basis_obj = get_basis(target_basis)
else:
target_basis_obj = target_basis
if target_basis_obj.__class__ == self._basis_obj.__class__:
return Occupation(self._data, basis=target_basis_obj, validate=False)
# General conversion using basis function interface
converted = np.array([self._basis_obj.convert_to(val, target_basis_obj) for val in self._data])
return Occupation(converted, basis=target_basis_obj, validate=False)
[docs]
def copy(self) -> 'Occupation':
"""Create a deep copy of this Occupation."""
return Occupation(self._data.copy(), basis=self._basis_obj, validate=False)
[docs]
@classmethod
def zeros(cls, n_sites: int, basis: Union[str, BasisFunction] = 'chebyshev') -> 'Occupation':
"""
Create Occupation with all sites vacant (no atoms).
Args:
n_sites: Number of sites
basis: Basis function or name
Returns:
Occupation object with all sites vacant
"""
if isinstance(basis, str):
basis_obj = get_basis(basis)
else:
basis_obj = basis
# Use vacant_value from basis object
data = np.full(n_sites, basis_obj.vacant_value, dtype=type(basis_obj.vacant_value))
return cls(data, basis=basis_obj, validate=False)
[docs]
@classmethod
def ones(cls, n_sites: int, basis: Union[str, BasisFunction] = 'chebyshev') -> 'Occupation':
"""
Create Occupation with all sites occupied (atoms present).
Args:
n_sites: Number of sites
basis: Basis function or name
Returns:
Occupation object with all sites occupied
"""
if isinstance(basis, str):
basis_obj = get_basis(basis)
else:
basis_obj = basis
# Use occupied_value from basis object
data = np.full(n_sites, basis_obj.occupied_value, dtype=type(basis_obj.occupied_value))
return cls(data, basis=basis_obj, validate=False)
[docs]
@classmethod
def random(cls, n_sites: int, basis: Union[str, BasisFunction] = 'chebyshev',
fill_fraction: float = 0.5, seed: int = None) -> 'Occupation':
"""
Create random Occupation using appropriate basis values.
Args:
n_sites: Number of sites
basis: Basis function or name
fill_fraction: Fraction of sites to be occupied (atoms present) (0.0 to 1.0)
seed: Random seed for reproducibility
Returns:
Random Occupation object
"""
if seed is not None:
np.random.seed(seed)
# Generate random boolean array for occupied sites
occupied = np.random.random(n_sites) < fill_fraction
if isinstance(basis, str):
basis_obj = get_basis(basis)
else:
basis_obj = basis
# Use occupied_value and vacant_value from basis object
data = np.where(occupied, basis_obj.occupied_value, basis_obj.vacant_value)
return cls(data, basis=basis_obj, validate=False)