"""Data models for SAM workflow execution."""
from __future__ import annotations
from pathlib import Path
from typing import Any, Iterable, Sequence
import numpy as np
import pandas as pd
from pydantic import BaseModel, ConfigDict, Field, model_validator
from equilibria.sam_tools.aggregation import aggregate_dataframe, load_mapping
from equilibria.sam_tools.balancing import RASBalanceResult, RASBalancer
from equilibria.sam_tools.enums import SAMFormat
class Sam(BaseModel):
    """Base container for a SAM held as a square, account-labelled DataFrame.

    On construction, and on every call to :meth:`replace_dataframe`, the
    frame is copied, its row and column labels are normalised to a
    ``MultiIndex`` of stripped strings, and its values are cast to ``float``.
    The frame must be square and rows and columns must carry the same set of
    account labels.
    """

    dataframe: pd.DataFrame

    model_config = ConfigDict(arbitrary_types_allowed=True)

    @staticmethod
    def _normalize_accounts(keys: Iterable[Any]) -> list[tuple[str, ...]]:
        """Turn raw labels into tuples of stripped strings.

        A tuple label keeps its arity; any other label becomes a 1-tuple.
        """
        return [
            tuple(
                str(part).strip()
                for part in (entry if isinstance(entry, tuple) else (entry,))
            )
            for entry in keys
        ]

    @staticmethod
    def _ensure_dataframe(frame: pd.DataFrame) -> pd.DataFrame:
        """Return a validated, normalised float copy of ``frame``.

        Raises:
            ValueError: If the frame is not two-dimensional, is not square,
                or its rows and columns do not hold the same set of accounts.
        """
        df = frame.copy()
        if df.ndim != 2:
            raise ValueError("El SAM debe ser una matriz bidimensional")
        rows, cols = df.shape
        if rows != cols:
            raise ValueError("El SAM debe ser cuadrado: filas y columnas iguales")
        # NOTE(review): MultiIndex.from_tuples rejects an empty list, so a 0x0
        # frame likely cannot pass through here -- confirm whether empty SAMs
        # need to be supported.
        df.index = pd.MultiIndex.from_tuples(Sam._normalize_accounts(df.index))
        df.columns = pd.MultiIndex.from_tuples(Sam._normalize_accounts(df.columns))
        # Set comparison only: the order of rows vs. columns is NOT enforced
        # here, so consumers must not assume position i is the same account
        # on both axes.
        if set(df.index.tolist()) != set(df.columns.tolist()):
            raise ValueError("Las mismas cuentas deben aparecer en filas y columnas")
        return df.astype(float)

    @model_validator(mode="after")
    def _validate(self) -> Sam:
        """Replace the supplied frame with its normalised copy after init."""
        normalized = self._ensure_dataframe(self.dataframe)
        # Bypass pydantic's __setattr__ and store the value directly.
        object.__setattr__(self, "dataframe", normalized)
        return self

    @staticmethod
    def _build_square_dataframe(
        matrix: np.ndarray,
        row_keys: Sequence[tuple[str, str]],
        col_keys: Sequence[tuple[str, str]],
    ) -> pd.DataFrame:
        """Embed a rectangular ``matrix`` into a square, zero-filled frame.

        The account list is the union of ``row_keys`` and ``col_keys`` in
        first-seen order (rows first, then columns). ``matrix[i, j]`` is
        written at the cell addressed by ``(row_keys[i], col_keys[j])``;
        every other cell is ``0.0``.
        """
        normalized_rows = [tuple(str(part).strip() for part in key) for key in row_keys]
        normalized_cols = [tuple(str(part).strip() for part in key) for key in col_keys]

        # dict.fromkeys de-duplicates while preserving first-seen order.
        combined = list(dict.fromkeys(normalized_rows + normalized_cols))
        position = {key: idx for idx, key in enumerate(combined)}

        size = len(combined)
        square = np.zeros((size, size), dtype=float)

        # Resolve target positions once instead of once per cell.
        row_positions = [position[key] for key in normalized_rows]
        col_positions = [position[key] for key in normalized_cols]
        for i, target_row in enumerate(row_positions):
            for j, target_col in enumerate(col_positions):
                square[target_row, target_col] = matrix[i, j]

        multi_index = pd.MultiIndex.from_tuples(combined)
        return pd.DataFrame(square, index=multi_index, columns=multi_index)

    @classmethod
    def from_matrix(
        cls,
        matrix: np.ndarray,
        row_keys: Sequence[tuple[str, str]],
        col_keys: Sequence[tuple[str, str]],
    ) -> Sam:
        """Build a ``Sam`` from a raw matrix and its row/column keys.

        Raises:
            ValueError: If ``matrix.shape`` differs from
                ``(len(row_keys), len(col_keys))``.
        """
        if matrix.shape != (len(row_keys), len(col_keys)):
            raise ValueError("La matriz no coincide con los índices provistos")
        df = cls._build_square_dataframe(matrix, row_keys, col_keys)
        return cls(dataframe=df)

    @property
    def row_keys(self) -> list[tuple[str, ...]]:
        """Row account labels, in frame order."""
        return [tuple(key) for key in self.dataframe.index.tolist()]

    @property
    def col_keys(self) -> list[tuple[str, ...]]:
        """Column account labels, in frame order."""
        return [tuple(key) for key in self.dataframe.columns.tolist()]

    @property
    def matrix(self) -> np.ndarray:
        """The frame's values as a writeable float array.

        The array is requested with ``copy=False`` and then flagged
        writeable. NOTE(review): whether in-place edits reach the stored
        DataFrame depends on pandas returning a view here -- use
        :meth:`update_matrix` when the update must be reflected.
        """
        matrix = self.dataframe.to_numpy(dtype=float, copy=False)
        matrix.setflags(write=True)
        return matrix

    def update_matrix(self, matrix: np.ndarray) -> None:
        """Replace the values with ``matrix``, keeping the current labels."""
        new_df = pd.DataFrame(
            matrix,
            index=self.dataframe.index,
            columns=self.dataframe.columns,
        )
        self.replace_dataframe(new_df)

    def to_dataframe(self) -> pd.DataFrame:
        """Return a copy of the stored frame."""
        return self.dataframe.copy()

    def replace_dataframe(self, frame: pd.DataFrame) -> None:
        """Validate ``frame`` and store its normalised copy.

        Raises:
            ValueError: Under the same conditions as construction.
        """
        normalized = self._ensure_dataframe(frame)
        object.__setattr__(self, "dataframe", normalized)

    def aggregate(self, mapping_path: Path) -> Sam:
        """Aggregate accounts using the mapping at ``mapping_path``.

        The aggregated labels are paired with the first level of the current
        first row key (``"RAW"`` when there are no rows). This instance is
        updated in place and also returned.
        """
        mapping, ordered = load_mapping(mapping_path)
        df = self.to_dataframe()
        aggregated = aggregate_dataframe(df, mapping, ordered)
        category = self.row_keys[0][0] if self.row_keys else "RAW"
        multi_index = pd.MultiIndex.from_tuples(
            [(category, label) for label in aggregated.index]
        )
        new_df = pd.DataFrame(
            aggregated.to_numpy(dtype=float),
            index=multi_index,
            columns=multi_index,
        )
        self.replace_dataframe(new_df)
        return self

    def balance_ras(
        self,
        *,
        ras_type: str = "arithmetic",
        tolerance: float = 1e-9,
        max_iterations: int = 200,
    ) -> RASBalanceResult:
        """Balance the matrix with ``RASBalancer`` and store the result.

        The keyword arguments are forwarded unchanged to
        ``RASBalancer.balance_dataframe``. The stored frame is replaced with
        ``result.matrix`` and the full result object is returned.
        """
        result = RASBalancer().balance_dataframe(
            self.to_dataframe(),
            ras_type=ras_type,
            tolerance=tolerance,
            max_iterations=max_iterations,
        )
        self.replace_dataframe(result.matrix)
        return result

    def _column_sums_in_row_order(self, col_sums: np.ndarray) -> np.ndarray:
        """Reorder per-column totals so entry ``i`` matches row account ``i``.

        Validation only guarantees that rows and columns share the same set
        of labels, not the same order. When labels are unique the totals are
        matched by label; otherwise (duplicate labels, or a label that cannot
        be located) the positional order is returned unchanged.
        """
        index = self.dataframe.index
        columns = self.dataframe.columns
        if index.equals(columns) or not columns.is_unique:
            return col_sums
        positions = columns.get_indexer(index)
        if (positions < 0).any():
            return col_sums
        return col_sums[positions]

    def balance_status(self, *, tolerance: float = 1e-9) -> dict[str, Any]:
        """Return balance diagnostics for row vs column totals.

        Each account's row total is compared with the total of the column
        carrying the same label, even when columns are stored in a different
        order than rows.

        Returns:
            A dict with keys ``is_balanced``, ``tolerance``,
            ``max_row_col_abs_diff``, ``worst_account``,
            ``worst_account_diff`` and ``total``.
        """
        matrix = self.matrix
        row_sums = matrix.sum(axis=1)
        col_sums = self._column_sums_in_row_order(matrix.sum(axis=0))
        diff = row_sums - col_sums
        if diff.size == 0:
            max_abs = 0.0
            worst_index = 0
        else:
            abs_diff = np.abs(diff)
            max_abs = float(np.max(abs_diff))
            worst_index = int(np.argmax(abs_diff))
        worst_key = self.row_keys[worst_index] if self.row_keys else ("", "")
        worst_diff = float(diff[worst_index]) if diff.size else 0.0
        return {
            "is_balanced": bool(max_abs <= float(tolerance)),
            "tolerance": float(tolerance),
            "max_row_col_abs_diff": max_abs,
            "worst_account": worst_key,
            "worst_account_diff": worst_diff,
            "total": float(matrix.sum()),
        }
class SamTable(BaseModel):
    """Table-level SAM object with source metadata and editable matrix access.

    Keys, matrix and DataFrame access are all delegated to the wrapped
    ``Sam`` instance; the remaining fields record where the table came from.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # The wrapped matrix container that every accessor below delegates to.
    sam: Sam
    # Location and format label of the source the table was read from.
    source_path: Path
    source_format: str
    # Optional extras; presumably the unparsed source sheet and the offsets
    # at which the numeric block starts -- TODO confirm against the loaders.
    raw_df: pd.DataFrame | None = None
    data_start_row: int | None = None
    data_start_col: int | None = None

    @property
    def row_keys(self) -> list[tuple[str, ...]]:
        """Row account labels of the wrapped SAM."""
        wrapped = self.sam
        return wrapped.row_keys

    @property
    def col_keys(self) -> list[tuple[str, ...]]:
        """Column account labels of the wrapped SAM."""
        wrapped = self.sam
        return wrapped.col_keys

    @property
    def matrix(self) -> np.ndarray:
        """Numeric values of the wrapped SAM."""
        wrapped = self.sam
        return wrapped.matrix

    @matrix.setter
    def matrix(self, value: np.ndarray) -> None:
        """Replace the wrapped SAM's values via ``Sam.update_matrix``."""
        wrapped = self.sam
        wrapped.update_matrix(value)

    def to_dataframe(self) -> pd.DataFrame:
        """Return the wrapped SAM's ``to_dataframe()`` result."""
        wrapped = self.sam
        return wrapped.to_dataframe()
class SAMWorkflowConfig(BaseModel):
    """Resolved workflow config from YAML.

    ``country`` and ``report_path`` accept ``None`` but declare no default,
    so they must still be passed explicitly. ``input_options`` and
    ``transforms`` default to a fresh empty container per instance.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Workflow identity.
    name: str
    country: str | None

    # Where the SAM is read from and in which format.
    input_path: Path
    input_format: SAMFormat

    # Where the result is written and in which format.
    output_path: Path
    output_format: SAMFormat

    # Free-form options and the list of transform specifications.
    input_options: dict[str, Any] = Field(default_factory=dict)
    transforms: list[dict[str, Any]] = Field(default_factory=list)

    # Report destination (may be None) and the symbol name used on output.
    report_path: Path | None
    output_symbol: str