Source code for equilibria.sam_tools.mip_to_sam_transforms

"""MIP to SAM structural transformations."""

from __future__ import annotations

from typing import Any

import numpy as np
import pandas as pd

from equilibria.sam_tools.models import Sam
from equilibria.sam_tools.sam_transforms import ensure_key
from equilibria.sam_tools.selectors import norm_text, norm_text_lower


DEFAULT_VA_FACTOR_SHARES = {"L": 0.65, "K": 0.35}

DEFAULT_FACTOR_TO_HOUSEHOLD_SHARES = {
    "L": {"hh": 0.95, "gvt": 0.05},
    "K": {"hh": 0.50, "firm": 0.45, "gvt": 0.05},
}

DEFAULT_TAX_RATES = {
    "production_tax": 0.10,
    "import_tariff": 0.05,
    "direct_tax": 0.15,
}


def _get_va_aggregate_row(sam: Sam) -> tuple[str, str] | None:
    """Find the VA aggregate row in the SAM (raw or normalized form)."""
    for cat, elem in sam.row_keys:
        cat_lower = norm_text_lower(cat)
        elem_lower = norm_text_lower(elem)
        if cat_lower == "va" and "aggregate" in elem_lower:
            return (cat, elem)
        if "va" in elem_lower and "aggregate" in elem_lower:
            return (cat, elem)
    return None


def _get_fd_columns(sam: Sam) -> list[tuple[str, str]]:
    """Get final demand columns (those with FD category or specific labels)."""
    fd_cols: list[tuple[str, str]] = []
    for cat, elem in sam.col_keys:
        cat_lower = norm_text_lower(cat)
        elem_lower = norm_text_lower(elem)
        if cat_lower == "fd" or elem_lower in {"hh", "gov", "gvt", "inv", "exp", "export"}:
            fd_cols.append((cat, elem))
    return fd_cols


def _get_sector_columns(sam: Sam) -> list[tuple[str, str]]:
    """Get sector (J) columns from SAM."""
    return [(cat, elem) for cat, elem in sam.col_keys if norm_text_lower(cat) == "j"]


def _get_commodity_keys(sam: Sam) -> list[tuple[str, str]]:
    """Get commodity (I) keys from SAM."""
    return [(cat, elem) for cat, elem in sam.row_keys if norm_text_lower(cat) == "i"]


_FD_TOKENS = ("hh", "hogar", "gov", "gob", "inv", "inversion", "exp", "export")


def _is_fd_label(elem_lower: str) -> bool:
    return any(tok in elem_lower for tok in _FD_TOKENS)


def _is_special_row_label(elem_lower: str) -> bool:
    return ("va" in elem_lower and "aggregate" in elem_lower) or (
        "imp" in elem_lower and "total" in elem_lower
    )


[docs] def normalize_mip_accounts(sam: Sam, op: dict[str, Any]) -> dict[str, Any]: """ Convert RAW MIP labels to initial PEP structure (J, I, VA, FD). The MIP raw matrix is square-padded so FD labels (HH/GOV/INV/EXP) appear on the row axis with zero data. We classify columns first using the FD closed vocabulary; everything else in the columns is a sector (J), and rows that match an FD label are dropped as spurious zero-rows. """ df = sam.to_dataframe() raw_keys = list(df.index) # Build new keys with proper categories new_keys: list[tuple[str, str]] = [] old_to_new_row: dict[tuple[str, str], tuple[str, str] | None] = {} old_to_new_col: dict[tuple[str, str], tuple[str, str] | None] = {} # Process columns first to get the canonical sector vocabulary. # VA/IMP columns are square-padding artifacts and get dropped. sector_labels: set[str] = set() for cat, elem in df.columns: elem_lower = norm_text_lower(elem) if _is_special_row_label(elem_lower): old_to_new_col[(cat, elem)] = None continue if _is_fd_label(elem_lower): new_key = ("FD", norm_text(elem)) else: new_key = ("J", norm_text(elem)) sector_labels.add(elem_lower) old_to_new_col[(cat, elem)] = new_key if new_key not in new_keys: new_keys.append(new_key) # Process rows: VA / IMP go to their own categories, sector labels become # commodities, FD labels are dropped (square-padding artifact). for cat, elem in raw_keys: elem_lower = norm_text_lower(elem) if "va" in elem_lower and "aggregate" in elem_lower: new_key = ("VA", "aggregate") elif "imp" in elem_lower and "total" in elem_lower: new_key = ("IMP", "total") elif _is_fd_label(elem_lower) and elem_lower not in sector_labels: old_to_new_row[(cat, elem)] = None continue else: new_key = ("I", norm_text(elem)) # Commodities old_to_new_row[(cat, elem)] = new_key if new_key not in new_keys: new_keys.append(new_key) # Build new matrix - just copy values with new labels n = len(new_keys) new_matrix = np.zeros((n, n), dtype=float) key_index = {key: idx for idx, key in enumerate(new_keys)} for i, row_key in enumerate(raw_keys): for j, col_key in enumerate(df.columns): value = float(df.iloc[i, j]) if abs(value) <= 1e-14: continue new_row = old_to_new_row.get(row_key) new_col = old_to_new_col.get(col_key) if new_row and new_col and new_row in key_index and new_col in key_index: new_matrix[key_index[new_row], key_index[new_col]] += value # Create new dataframe multi_index = pd.MultiIndex.from_tuples(new_keys) new_df = pd.DataFrame(new_matrix, index=multi_index, columns=multi_index) sam.replace_dataframe(new_df) return { "raw_accounts": len(raw_keys), "normalized_accounts": len(new_keys), "commodities": len([k for k in new_keys if k[0] == "I"]), "sectors": len([k for k in new_keys if k[0] == "J"]), "final_demand": len([k for k in new_keys if k[0] == "FD"]), }
[docs] def create_make_matrix(sam: Sam, op: dict[str, Any]) -> dict[str, Any]: """ Add diagonal make-matrix entries so each sector J has income matching its cost. A standard MIP records sector costs in J columns (intermediate + VA + IMP) and commodity supply in I rows, but it does not record sector *output* per commodity. SAM closure requires J row income to equal J column cost. We close the loop with a diagonal make matrix: each sector j sells exactly its column total to commodity j, i.e. df.loc[("J", j), ("I", j)] = sum of ("J", j) column. This assumes one-to-one sector-commodity correspondence, which holds for MIPs where sectors and commodities share the same labels. """ df = sam.to_dataframe() total_make = 0.0 pairs: list[tuple[str, str]] = [] for col_key in sam.col_keys: if norm_text_lower(col_key[0]) != "j": continue sector_label = col_key[1] commodity_key = ("I", sector_label) if commodity_key not in sam.row_keys: continue cost = float(df.loc[:, col_key].sum()) if cost <= 1e-14: continue sector_row_key = ("J", sector_label) if sector_row_key not in sam.row_keys: ensure_key(sam, sector_row_key) df = sam.to_dataframe() df.loc[sector_row_key, commodity_key] += cost total_make += cost pairs.append((sector_label, sector_label)) sam.replace_dataframe(df) return {"sectors_made": len(pairs), "total_make": total_make}
[docs] def disaggregate_va_to_factors(sam: Sam, op: dict[str, Any]) -> dict[str, Any]: """ Split VA aggregate row into L (labor) and K (capital) using shares. Before: ("VA", "aggregate") → ("J", "agr") = 100 After: ("L", "labor") → ("J", "agr") = 65 # 100 * 0.65 ("K", "capital") → ("J", "agr") = 35 # 100 * 0.35 """ shares = op.get("va_factor_shares", DEFAULT_VA_FACTOR_SHARES) # Validate shares sum to 1.0 total_share = sum(shares.values()) if abs(total_share - 1.0) > 1e-6: raise ValueError(f"va_factor_shares must sum to 1.0, got {total_share}") df = sam.to_dataframe() # Find VA row va_key = _get_va_aggregate_row(sam) if va_key is None: return {"factors_created": [], "total_va": 0.0, "error": "VA aggregate row not found"} if va_key not in df.index: return {"factors_created": [], "total_va": 0.0, "error": f"VA key {va_key} not in index"} # Create L and K rows l_key = ("L", "labor") k_key = ("K", "capital") ensure_key(sam, l_key) ensure_key(sam, k_key) df = sam.to_dataframe() total_va = 0.0 # Distribute VA to L and K for each sector for col_key in sam.col_keys: if norm_text_lower(col_key[0]) != "j": continue va_value = float(df.loc[va_key, col_key]) if abs(va_value) <= 1e-14: continue # Distribute to factors df.loc[l_key, col_key] = va_value * shares.get("L", 0.0) df.loc[k_key, col_key] = va_value * shares.get("K", 0.0) # Clear original VA df.loc[va_key, col_key] = 0.0 total_va += va_value sam.replace_dataframe(df) return { "factors_created": ["L", "K"], "total_va": total_va, "shares": shares, }
[docs] def create_factor_income_distribution(sam: Sam, op: dict[str, Any]) -> dict[str, Any]: """ Route factor income to institutions (households, firms, government). Before: ("L", "labor") → ("J", "agr") = 65 # Only production cost After: ("L", "labor") → ("J", "agr") = 65 # Keep production cost ("AG", "hh") → ("L", "labor") = 61.75 # 65 * 0.95 to households ("AG", "gvt") → ("L", "labor") = 3.25 # 65 * 0.05 taxes """ factor_shares = op.get("factor_to_household_shares", DEFAULT_FACTOR_TO_HOUSEHOLD_SHARES) df = sam.to_dataframe() total_distributed = 0.0 distribution_by_factor: dict[str, float] = {} # Get factor keys (L, K) factor_keys = [(cat, elem) for cat, elem in sam.row_keys if norm_text_lower(cat) in {"l", "k"}] for factor_key in factor_keys: factor_cat = factor_key[0] factor_total = 0.0 # Calculate total factor income from all sectors for col_key in sam.col_keys: if norm_text_lower(col_key[0]) == "j": factor_total += float(df.loc[factor_key, col_key]) if abs(factor_total) <= 1e-14: continue # Get distribution for this factor shares = factor_shares.get(factor_cat, {}) # Create institution accounts and distribute for institution, share in shares.items(): inst_key = ("AG", norm_text_lower(institution)) ensure_key(sam, inst_key) df = sam.to_dataframe() income = factor_total * share # Add income from factor to institution df.loc[inst_key, factor_key] += income sam.replace_dataframe(df) total_distributed += income distribution_by_factor[factor_cat] = factor_total return { "total_distributed": total_distributed, "distribution_by_factor": distribution_by_factor, "institutions_created": list({inst for shares in factor_shares.values() for inst in shares.keys()}), }
[docs] def create_household_expenditure(sam: Sam, op: dict[str, Any]) -> dict[str, Any]: """ Convert final demand "HH" from MIP to AG.hh → I flows. Before: ("I", "agr") → ("FD", "HH") = 20 # Final demand temporary After: ("AG", "hh") → ("I", "agr") = 20 # Household consumption """ df = sam.to_dataframe() # Find HH final demand column hh_fd_key = None for cat, elem in sam.col_keys: elem_lower = norm_text_lower(elem) if "hh" in elem_lower or "hogar" in elem_lower: hh_fd_key = (cat, elem) break if hh_fd_key is None: return {"converted": 0, "total_expenditure": 0.0, "error": "HH final demand not found"} # Ensure AG.hh exists hh_key = ("AG", "hh") ensure_key(sam, hh_key) df = sam.to_dataframe() total_expenditure = 0.0 # Transfer I → FD.HH flows to I → AG.hh: households pay for the commodity, # i.e. df.loc[I_row, hh_col] (hh column = hh expenditure on commodity I). for row_key in sam.row_keys: if norm_text_lower(row_key[0]) != "i": continue value = float(df.loc[row_key, hh_fd_key]) if abs(value) <= 1e-14: continue df.loc[row_key, hh_fd_key] = 0.0 df.loc[row_key, hh_key] += value total_expenditure += value sam.replace_dataframe(df) return { "converted": 1, "total_expenditure": total_expenditure, }
[docs] def create_government_flows(sam: Sam, op: dict[str, Any]) -> dict[str, Any]: """ Consolidate government flows (fiscal revenues + expenditure). Creates: - ("AG", "ti") → ("AG", "gvt"): Indirect taxes - ("AG", "tm") → ("AG", "gvt"): Import tariffs - ("AG", "gvt") → ("I", *): Government consumption """ tax_rates = op.get("tax_rates", DEFAULT_TAX_RATES) df = sam.to_dataframe() # Ensure accounts exist gvt_key = ("AG", "gvt") ti_key = ("AG", "ti") tm_key = ("AG", "tm") for key in [gvt_key, ti_key, tm_key]: ensure_key(sam, key) df = sam.to_dataframe() # Calculate indirect taxes on production production_tax_rate = tax_rates.get("production_tax", 0.10) total_ti = 0.0 for col_key in sam.col_keys: if norm_text_lower(col_key[0]) == "j": # Tax on sector output (sum of commodity inputs) sector_output = 0.0 for row_key in sam.row_keys: if norm_text_lower(row_key[0]) == "i": sector_output += float(df.loc[row_key, col_key]) if sector_output > 1e-14: tax = sector_output * production_tax_rate df.loc[ti_key, col_key] += tax total_ti += tax # Calculate import tariffs import_tariff_rate = tax_rates.get("import_tariff", 0.05) total_tm = 0.0 imp_key = None for cat, elem in sam.row_keys: if norm_text_lower(cat) == "imp": imp_key = (cat, elem) break if imp_key is not None and imp_key in df.index: for col_key in sam.col_keys: if norm_text_lower(col_key[0]) == "i": import_value = float(df.loc[imp_key, col_key]) if import_value > 1e-14: tariff = import_value * import_tariff_rate df.loc[tm_key, col_key] += tariff total_tm += tariff # Route tax revenues to government (gvt receives from ti/tm). # SAM convention: df.loc[receiver_row, payer_col]. if total_ti > 1e-14: df.loc[gvt_key, ti_key] += total_ti if total_tm > 1e-14: df.loc[gvt_key, tm_key] += total_tm # Transfer government consumption from FD gov_fd_key = None for cat, elem in sam.col_keys: elem_lower = norm_text_lower(elem) if "gov" in elem_lower or "gob" in elem_lower or "gobierno" in elem_lower: gov_fd_key = (cat, elem) break # Move I → FD.GOV flows to I → AG.gvt: government pays for the commodity. # SAM convention: df.loc[I_row, gvt_col] (gvt column = gvt expenditure on I). total_gov_consumption = 0.0 if gov_fd_key is not None: for row_key in sam.row_keys: if norm_text_lower(row_key[0]) != "i": continue value = float(df.loc[row_key, gov_fd_key]) if abs(value) <= 1e-14: continue df.loc[row_key, gov_fd_key] = 0.0 df.loc[row_key, gvt_key] += value total_gov_consumption += value sam.replace_dataframe(df) return { "total_ti": total_ti, "total_tm": total_tm, "total_gov_consumption": total_gov_consumption, "tax_rates": tax_rates, }
[docs] def create_row_account(sam: Sam, op: dict[str, Any]) -> dict[str, Any]: """ Create rest-of-world (ROW) account for foreign trade. Before: ("IMP", "total") → ("I", "agr") = 15 # Imports ("I", "agr") → ("FD", "EXP") = 10 # Exports After: ("AG", "row") → ("I", "agr") = 15 # Import supply ("I", "agr") → ("AG", "row") = 10 # Export demand (moved to X later) """ df = sam.to_dataframe() # Ensure ROW account exists row_acct_key = ("AG", "row") ensure_key(sam, row_acct_key) df = sam.to_dataframe() total_imports = 0.0 total_exports = 0.0 # Handle imports: IMP → I becomes AG.row → I imp_key = None for cat, elem in sam.row_keys: if norm_text_lower(cat) == "imp": imp_key = (cat, elem) break if imp_key is not None and imp_key in df.index: # In the normalized SAM, imports sit in IMP row × J (sector) columns: # they are intermediate import use by sectors. We route them to the # ROW account (AG.row) so the import flow becomes AG.row → J. for col_key in sam.col_keys: if norm_text_lower(col_key[0]) != "j": continue value = float(df.loc[imp_key, col_key]) if abs(value) <= 1e-14: continue df.loc[imp_key, col_key] = 0.0 df.loc[row_acct_key, col_key] += value total_imports += value # Handle exports: I → FD.EXP becomes I → AG.row exp_fd_key = None for cat, elem in sam.col_keys: elem_lower = norm_text_lower(elem) if "exp" in elem_lower or "export" in elem_lower: exp_fd_key = (cat, elem) break if exp_fd_key is not None: for row_key in sam.row_keys: if norm_text_lower(row_key[0]) != "i": continue value = float(df.loc[row_key, exp_fd_key]) if abs(value) <= 1e-14: continue df.loc[row_key, exp_fd_key] = 0.0 df.loc[row_key, row_acct_key] += value # Picked up by convert_exports_to_x_on_sam total_exports += value sam.replace_dataframe(df) return { "total_imports": total_imports, "total_exports": total_exports, }
[docs] def create_investment_account(sam: Sam, op: dict[str, Any]) -> dict[str, Any]: """ Create savings-investment closure. Creates: - ("AG", *) → ("OTH", "inv"): Savings from institutions - ("I", *) → ("OTH", "inv"): Investment demand """ df = sam.to_dataframe() # Ensure investment account exists inv_key = ("OTH", "inv") ensure_key(sam, inv_key) df = sam.to_dataframe() # Transfer investment demand from FD inv_fd_key = None for cat, elem in sam.col_keys: elem_lower = norm_text_lower(elem) if "inv" in elem_lower or "inversion" in elem_lower: inv_fd_key = (cat, elem) break total_investment = 0.0 if inv_fd_key is not None: for row_key in sam.row_keys: if norm_text_lower(row_key[0]) != "i": continue value = float(df.loc[row_key, inv_fd_key]) if abs(value) <= 1e-14: continue df.loc[row_key, inv_fd_key] = 0.0 df.loc[row_key, inv_key] += value total_investment += value # Calculate institutional savings (residual balancing). # SAM convention: df.loc[row, col] means row receives from col, i.e. col pays row. # If income > expenditure the institution has a surplus and pays into investment: # OTH.inv (row, receiver) ← AG.x (col, payer). So df.loc[inv_key, ag_key] += savings. # If income < expenditure the institution dissaves: investment pays into the # institution, df.loc[ag_key, inv_key] += -savings. total_savings = 0.0 for ag_key in sam.row_keys: if norm_text_lower(ag_key[0]) != "ag": continue # Skip tax accounts elem_lower = norm_text_lower(ag_key[1]) if elem_lower in {"ti", "tm", "tx", "td"}: continue # Income = sum of row income = float(df.loc[ag_key, :].sum()) # Expenditure = sum of column if ag_key in df.columns: expenditure = float(df.loc[:, ag_key].sum()) else: expenditure = 0.0 savings = income - expenditure if savings > 1e-14: df.loc[inv_key, ag_key] += savings total_savings += savings elif savings < -1e-14: df.loc[ag_key, inv_key] += -savings total_savings += savings sam.replace_dataframe(df) return { "total_investment": total_investment, "total_savings": total_savings, "balance_diff": abs(total_investment - total_savings), }