# Source code for pyiecwind.parsing

"""Input parsing and validation helpers for pyIECWind."""

from __future__ import annotations

import math
import re
import warnings
from collections.abc import Sequence
from pathlib import Path
from typing import Protocol

from .models import (
    CASE_PREFIXES,
    CASE_TYPE_ORDER,
    DEFAULT_INPUT_FILENAME,
    FALSE_TOKENS,
    NONE_TOKENS,
    TRUE_TOKENS,
    IECParameters,
    IECWindWarning,
)

# Only `parse_input_file` is exported; every other name in this module is an
# internal helper (docs/architecture.md describes the module-surface convention).
__all__: list[str] = ["parse_input_file"]

# Every accepted spelling of an input key (after `_normalize_key`) mapped to the
# canonical field name the parsers store it under. Built from per-field alias
# groups so that each field's spellings sit together on one line; insertion
# order matches the group order below.
FIELD_ALIASES = {
    alias: canonical
    for canonical, aliases in (
        ("si_unit", ("si_unit", "si_units", "units")),
        ("t1", ("t1", "transient_start_time", "transient_start")),
        ("wtc", ("wtc", "wind_turbine_class", "turbine_class")),
        ("catg", ("catg", "turbulence_category", "category")),
        ("slope_deg", ("slope", "slope_deg", "inflow_angle", "inflow_inclination_deg")),
        ("iec_edition", ("iec_edition", "edition")),
        ("hh", ("hh", "hub_height", "hub_height_m_or_ft")),
        ("dia", ("dia", "rotor_diameter", "diameter")),
        ("vin", ("vin", "cut_in", "cut_in_speed")),
        ("vrated", ("vrated", "rated_speed")),
        ("vout", ("vout", "cut_out", "cut_out_speed")),
        ("condition", ("condition",)),
        ("conditions", ("conditions",)),
    )
    for alias in aliases
}


def _normalize_key(raw: str) -> str:
    """Canonicalise a user-written key: trim it, lower-case it, and turn '-' and ' ' into '_'."""

    lowered = raw.strip().lower()
    return lowered.translate(str.maketrans("- ", "__"))


def _assign_scalar_field(
    fields: dict[str, str],
    field_lines: dict[str, int],
    key: str,
    value: str,
    *,
    lineno: int,
) -> None:
    """Record ``value`` for the scalar ``key``, refusing a second definition.

    ``fields`` maps canonical key -> raw value text, and ``field_lines`` maps the
    same key -> the line that first defined it. Callers resolve aliases to the
    canonical key before calling, so defining one field twice under different
    aliases is caught as well. Conflicting lines (e.g. two ``wtc`` entries) are
    reported rather than resolved by letting the last one win.

    Raises
    ------
    ValueError
        If ``key`` has already been recorded.
    """

    if key not in fields:
        fields[key] = value
        field_lines[key] = lineno
        return
    first_line = field_lines[key]
    raise ValueError(
        f"Duplicate field '{key}' on line {lineno}; it was already set on line "
        f"{first_line}. Each scalar field may appear only once."
    )


# Accepted spellings of the ``si_unit`` flag, compared after ``.strip().upper()``.
# A token from the first set selects SI units; one from the second selects English units.
_SI_TRUE_TOKENS = {
    "T",
    "TRUE",
    ".TRUE.",
    "YES",
    "Y",
    "SI",
    "METRIC",
    "1",
}
_SI_FALSE_TOKENS = {
    "F",
    "FALSE",
    ".FALSE.",
    "NO",
    "N",
    "ENGLISH",
    "IMPERIAL",
    "US",
    "0",
}


def _parse_si_unit(raw: str, *, lineno: int | None = None) -> bool:
    """Interpret the unit-system flag, failing loudly on anything unrecognised.

    Returns ``True`` for SI units and ``False`` for English units. An unknown
    token is rejected instead of quietly meaning "English", so a typo such as
    ``si_unit = maybe`` cannot go unnoticed.

    Raises
    ------
    ValueError
        If ``raw`` is not in either token table. ``lineno``, when given, is
        quoted in the message.
    """

    token = raw.strip().upper()
    for accepted, verdict in ((_SI_TRUE_TOKENS, True), (_SI_FALSE_TOKENS, False)):
        if token in accepted:
            return verdict
    location = "" if lineno is None else f" on line {lineno}"
    raise ValueError(
        f"Cannot interpret si_unit value {raw!r}{location}. Use a boolean such as True/False (SI vs. English units)."
    )


def _normalize_case_options_text(text: str) -> str:
    """Trim ``text`` and remove one enclosing pair of square brackets, if present."""

    trimmed = text.strip()
    bracketed = trimmed.startswith("[") and trimmed.endswith("]")
    return trimmed[1:-1].strip() if bracketed else trimmed


def _split_case_options(text: str) -> list[str]:
    """Split a comma-separated options cell into trimmed, non-empty items.

    Enclosing brackets are removed first. An empty cell, or one consisting of a
    "none" token, yields an empty list.
    """

    body = _normalize_case_options_text(text)
    if body and body.upper() not in NONE_TOKENS:
        pieces = (piece.strip() for piece in body.split(","))
        return [piece for piece in pieces if piece]
    return []


def _expand_case_row(case_type: str, options: list[str], *, lineno: int) -> list[str]:
    """Combine the condition-code prefix of ``case_type`` with each option.

    Options for ``NWP`` keep their original spelling; options for every other
    case type are upper-cased. ``lineno`` is used only in the error message.

    Raises
    ------
    ValueError
        If ``case_type`` has no entry in ``CASE_PREFIXES``.
    """

    if case_type not in CASE_PREFIXES:
        raise ValueError(f"Unknown case type on line {lineno}: {case_type}")
    prefix = CASE_PREFIXES[case_type]
    keep_spelling = case_type == "NWP"
    return [f"{prefix}{option if keep_spelling else option.upper()}" for option in options]


def _parse_case_row(line: str, *, lineno: int) -> list[str]:
    """Turn one row of the OpenFAST-style ``CASES`` table into condition codes.

    Columns are separated by two or more whitespace characters:
    ``<case_type>  <enable flag>  <options>``. Everything from a fourth column
    onward is ignored. A false or "none" flag, or an options cell that yields
    no items, produces no codes.

    Raises
    ------
    ValueError
        If the row has fewer than three columns, names an unknown case type,
        or carries an unrecognised enable flag.
    """

    columns = re.split(r"\s{2,}", line, maxsplit=3)
    if len(columns) < 3:
        raise ValueError(
            f"Cannot parse case row on line {lineno}: {line!r}. "
            "Expected '<case_type><spaces><True/False/None><spaces><options>'."
        )

    case_type, flag = (column.strip().upper() for column in columns[:2])
    if case_type not in CASE_TYPE_ORDER:
        raise ValueError(f"Unknown case type on line {lineno}: {case_type}")
    if flag in NONE_TOKENS or flag in FALSE_TOKENS:
        return []
    if flag not in TRUE_TOKENS:
        raise ValueError(f"Case enable flag on line {lineno} must be True, False, or None. Got: {columns[1]!r}")

    options = _split_case_options(columns[2].strip())
    return _expand_case_row(case_type, options, lineno=lineno) if options else []


def _group_conditions_by_type(conditions: Sequence[str]) -> dict[str, list[str]]:
    """Bucket condition codes by their three-character case-type prefix.

    The result has one list per entry of ``CASE_TYPE_ORDER``, in that order, and
    each list holds the remainder of every matching code after its prefix. The
    prefix is compared upper-cased; codes whose prefix is not a known case type
    are dropped.
    """

    buckets: dict[str, list[str]] = {name: [] for name in CASE_TYPE_ORDER}
    for code in conditions:
        bucket = buckets.get(code[:3].upper())
        if bucket is not None:
            bucket.append(code[3:])
    return buckets


def _parse_condition_value(value: str, *, lineno: int) -> str | None:
    """Interpret one condition entry, returning an upper-cased code or ``None``.

    Accepted forms:

    * ``<code>`` -- returned upper-cased, exactly as written;
    * ``<true-token> <code ...>`` -- the code words re-joined with single spaces;
    * ``<false-token> <code ...>`` -- ``None`` (the condition is switched off);
    * ``<none-token> ...`` -- ``None``.

    Raises
    ------
    ValueError
        If ``value`` is blank, or a true/false toggle has no code after it.
    """

    words = value.split()
    if not words:
        raise ValueError(f"Missing condition code on line {lineno}.")

    head = words[0].upper()
    if head in TRUE_TOKENS | FALSE_TOKENS:
        if len(words) == 1:
            raise ValueError(f"Condition toggle on line {lineno} must be followed by a condition code.")
        if head in TRUE_TOKENS:
            return " ".join(words[1:]).upper()
        return None
    return None if head in NONE_TOKENS else value.upper()


def _append_condition_value(conditions: list[str], value: str, *, lineno: int) -> None:
    """Parse ``value`` and append the resulting code to ``conditions`` in place.

    Entries that parse to ``None`` (switched off, or a "none" token) are skipped.
    """

    code = _parse_condition_value(value, lineno=lineno)
    if code is None:
        return
    conditions.append(code)


def _build_parameters(
    *,
    si_unit: bool,
    t1: float,
    wtc: int,
    catg: str,
    slope_deg: float,
    iec_edition: int,
    hh_raw: float,
    dia_raw: float,
    vin_raw: float,
    vrated_raw: float,
    vout_raw: float,
    conditions: list[str],
    legacy: bool = False,
) -> IECParameters:
    """Validate already-typed inputs and assemble an :class:`IECParameters`.

    Every input layout funnels through here. When ``si_unit`` is false, the hub
    height, the rotor diameter and the three reference wind speeds are divided
    by 3.2808 (feet per metre). This happens before the speed checks and before
    the values are stored. ``t1`` and ``slope_deg`` are used as given.

    Parameters
    ----------
    si_unit : bool
        True for SI input, False for English units.
    t1 : float
        Transient start time.
    wtc : int
        Wind turbine class; must be 1, 2 or 3.
    catg : str
        Turbulence category; upper-cased, then must be A, B or C.
    slope_deg : float
        Inflow inclination in degrees; ``|slope_deg| > 8`` only warns.
    iec_edition : int
        IEC edition, 1 or 3. Any other value is an error unless ``legacy``.
    hh_raw, dia_raw, vin_raw, vrated_raw, vout_raw : float
        Hub height, rotor diameter, and cut-in / rated / cut-out speeds, in the
        file's unit system.
    conditions : list of str
        Condition codes; must be non-empty.
    legacy : bool, default False
        Coerce an unsupported edition to 3 (with a warning) instead of raising.

    Returns
    -------
    IECParameters

    Raises
    ------
    ValueError
        For a non-finite (NaN/inf) numeric input, an out-of-range class,
        category or edition, a non-positive diameter, a hub height not above
        the rotor radius, speeds that are not strictly increasing
        (cut-in < rated < cut-out), or an empty ``conditions`` list.

    Warns
    -----
    IECWindWarning
        If ``|slope_deg| > 8``, or when an unsupported edition is coerced.
    """

    # NaN compares False against everything, so without this guard a "nan" in
    # the input would pass every range check below (and "inf" most of them).
    numeric_inputs = {
        "t1": t1,
        "slope_deg": slope_deg,
        "hh": hh_raw,
        "dia": dia_raw,
        "vin": vin_raw,
        "vrated": vrated_raw,
        "vout": vout_raw,
    }
    non_finite = [f"{name}={value!r}" for name, value in numeric_inputs.items() if not math.isfinite(value)]
    if non_finite:
        raise ValueError(f"Numeric input(s) must be finite; got {', '.join(non_finite)}.")

    len_convert = 1.0 if si_unit else 3.2808

    if wtc not in (1, 2, 3):
        raise ValueError(f"Wind turbine class must be 1, 2, or 3. Got: {wtc}")

    catg = catg.upper()
    if catg not in ("A", "B", "C"):
        raise ValueError(f"Turbulence category must be A, B, or C. Got: {catg!r}")

    if abs(slope_deg) > 8.0:
        warnings.warn(
            f"IEC specifies a maximum inclination angle of 8 deg; you specified {slope_deg:.2f} degrees.",
            IECWindWarning,
            stacklevel=2,
        )

    if iec_edition not in (1, 3):
        # Fail closed for a scientific tool: an unsupported edition is an error
        # unless the caller explicitly opts into legacy coercion.
        if not legacy:
            raise ValueError(
                f"Unsupported IEC edition {iec_edition}; only editions 1 and 3 are supported. "
                "Pass legacy=True to coerce unsupported editions to edition 3."
            )
        warnings.warn(
            f"IEC edition should be 1 or 3. Got: {iec_edition}. "
            "Coercing to edition 3 (Alpha=0.14) because legacy mode is enabled.",
            IECWindWarning,
            stacklevel=2,
        )
        iec_edition = 3

    # Geometry is checked in the file's own units; both sides share them.
    if dia_raw <= 0.0:
        raise ValueError(f"Rotor diameter must be positive. Got: {dia_raw}")
    if hh_raw <= dia_raw / 2.0:
        raise ValueError(
            f"Hub height ({hh_raw}) must be greater than rotor radius ({dia_raw / 2.0:.2f}). Check your input file."
        )

    hh = hh_raw / len_convert
    dia = dia_raw / len_convert
    vin = vin_raw / len_convert
    vrated = vrated_raw / len_convert
    vout = vout_raw / len_convert

    if vrated <= vin:
        raise ValueError(f"Rated speed ({vrated:.2f}) must exceed cut-in ({vin:.2f}).")
    if vout <= vrated:
        raise ValueError(f"Cut-out speed ({vout:.2f}) must exceed rated ({vrated:.2f}).")
    if not conditions:
        raise ValueError("No wind conditions found in input file.")

    return IECParameters(
        si_unit=si_unit,
        t1=t1,
        wtc=wtc,
        catg=catg,
        slope_deg=slope_deg,
        iec_edition=iec_edition,
        hh=hh,
        dia=dia,
        vin=vin,
        vrated=vrated,
        vout=vout,
        conditions=tuple(conditions),
    )


def _finalize_parsed_fields(fields: dict[str, str], conditions: list[str], *, legacy: bool = False) -> IECParameters:
    """Convert collected raw field text into typed values and build the parameters.

    ``fields`` maps canonical field names (the values of ``FIELD_ALIASES``) to
    their raw text. All eleven scalar fields are required. They are converted
    as follows:

    * ``si_unit`` is parsed with :func:`_parse_si_unit`;
    * ``wtc`` and ``iec_edition`` are parsed as integers;
    * ``catg`` is passed through verbatim;
    * every other field is parsed as a float.

    Raises
    ------
    ValueError
        If fields are missing, or a numeric field cannot be converted (the
        message names the field and quotes its raw text). Also raised by any
        validation performed downstream.
    """

    required = (
        "si_unit",
        "t1",
        "wtc",
        "catg",
        "slope_deg",
        "iec_edition",
        "hh",
        "dia",
        "vin",
        "vrated",
        "vout",
    )
    missing = [name for name in required if name not in fields]
    if missing:
        raise ValueError(f"Missing required input field(s): {', '.join(missing)}.")

    # float()/int() errors ("could not convert string to float: '90 m'") do not
    # say which field was at fault, so re-raise with the field name attached.
    def as_float(name: str) -> float:
        raw = fields[name]
        try:
            return float(raw)
        except ValueError as exc:
            raise ValueError(f"Invalid value for '{name}': expected a number, got {raw!r}.") from exc

    def as_int(name: str) -> int:
        raw = fields[name]
        try:
            return int(raw)
        except ValueError as exc:
            raise ValueError(f"Invalid value for '{name}': expected an integer, got {raw!r}.") from exc

    return _build_parameters(
        si_unit=_parse_si_unit(fields["si_unit"]),
        t1=as_float("t1"),
        wtc=as_int("wtc"),
        catg=fields["catg"],
        slope_deg=as_float("slope_deg"),
        iec_edition=as_int("iec_edition"),
        hh_raw=as_float("hh"),
        dia_raw=as_float("dia"),
        vin_raw=as_float("vin"),
        vrated_raw=as_float("vrated"),
        vout_raw=as_float("vout"),
        conditions=conditions,
        legacy=legacy,
    )


def _parse_legacy_input_file(raw_lines: list[str], *, legacy: bool = False) -> IECParameters:
    while len(raw_lines) < 17:
        raw_lines.append("")

    def first_token(line: str) -> str:
        tokens = line.strip().split()
        if not tokens:
            raise ValueError(f"Expected a value but got an empty line: {line!r}")
        return tokens[0]

    def line_val(idx: int, name: str) -> str:
        try:
            return first_token(raw_lines[idx])
        except (IndexError, ValueError) as exc:
            raise ValueError(f"Premature end of file reading '{name}' at line {idx + 1}.") from exc

    si_unit = _parse_si_unit(line_val(2, "units specifier"), lineno=3)

    conditions: list[str] = []
    for raw in raw_lines[16:]:
        stripped = raw.strip()
        if not stripped:
            break
        conditions.append(stripped.upper())

    return _build_parameters(
        si_unit=si_unit,
        t1=float(line_val(3, "transient start time")),
        wtc=int(line_val(5, "wind turbine class")),
        catg=line_val(6, "turbulence category"),
        slope_deg=float(line_val(7, "wind inflow angle")),
        iec_edition=int(line_val(8, "IEC edition for wind shear exponent")),
        hh_raw=float(line_val(10, "hub height")),
        dia_raw=float(line_val(11, "rotor diameter")),
        vin_raw=float(line_val(12, "cut-in wind speed")),
        vrated_raw=float(line_val(13, "rated wind speed")),
        vout_raw=float(line_val(14, "cut-out wind speed")),
        conditions=conditions,
        legacy=legacy,
    )


def _parse_keyed_input_file(raw_lines: list[str], *, legacy: bool = False) -> IECParameters:
    """Parse the keyed layout: one ``key = value`` or ``key: value`` per line.

    Blank lines and lines starting with ``!`` or ``#`` are skipped. A
    ``conditions:`` header (optionally with a first code after the colon) opens
    a list. Following lines are condition entries, either ``- <entry>`` or bare
    text, until a line containing ``=`` or ``:`` ends the list. Keys go through
    ``_normalize_key`` and ``FIELD_ALIASES``; ``condition``/``conditions`` keys
    append a condition, and every other key sets a scalar field once.

    Raises
    ------
    ValueError
        For a line with no separator, an unknown key, a duplicate field, a bad
        condition entry, or any error from finalising the fields.
    """

    fields: dict[str, str] = {}
    field_lines: dict[str, int] = {}
    conditions: list[str] = []
    collecting = False  # True while inside a ``conditions:`` list

    for lineno, raw_line in enumerate(raw_lines, start=1):
        text = raw_line.strip()
        if not text or text[0] in "!#":
            continue

        if collecting:
            if text.startswith("-"):
                _append_condition_value(conditions, text[1:].strip(), lineno=lineno)
                continue
            if "=" not in text and ":" not in text:
                _append_condition_value(conditions, text, lineno=lineno)
                continue
            collecting = False

        head, colon, tail = text.partition(":")
        if colon and head.strip().lower() == "conditions":
            collecting = True
            first_entry = tail.strip()
            if first_entry:
                _append_condition_value(conditions, first_entry, lineno=lineno)
            continue

        # '=' takes precedence over ':' as the key/value separator.
        if "=" in text:
            separator = "="
        elif ":" in text:
            separator = ":"
        else:
            raise ValueError(f"Cannot parse keyed input line {lineno}: {raw_line!r}")
        raw_key, _, raw_value = text.partition(separator)

        key = FIELD_ALIASES.get(_normalize_key(raw_key))
        if key is None:
            raise ValueError(f"Unknown input key on line {lineno}: {raw_key!r}")

        value = raw_value.strip()
        if key in ("condition", "conditions"):
            _append_condition_value(conditions, value, lineno=lineno)
        else:
            _assign_scalar_field(fields, field_lines, key, value, lineno=lineno)

    return _finalize_parsed_fields(fields, conditions, legacy=legacy)


def _parse_openfast_input_file(raw_lines: list[str], *, legacy: bool = False) -> IECParameters:
    """Parse the OpenFAST-style table layout.

    Each scalar line is ``<value><2+ spaces><key>[<2+ spaces>- <comment>]``.
    Blank lines are skipped, and so are comment lines starting with ``!`` or
    ``#``. A comment line whose text (after the marker and optional whitespace)
    begins with ``CASES``, in any letter case, opens the cases section. From
    then on, a line whose first token is a known case type is parsed as a
    case-table row and expanded into condition codes.

    Raises
    ------
    ValueError
        For an unparseable line, an unknown key, a duplicate scalar field, a bad
        case row or condition entry, or any error from finalising the fields.
    """

    fields: dict[str, str] = {}
    field_lines: dict[str, int] = {}
    conditions: list[str] = []
    in_cases_section = False
    # Accept the section header after either comment marker ("! CASES",
    # "# Cases", "!CASES", ...). Comments and format directives may use '!' or
    # '#'. With only "! CASES" recognised, a '#' header left the case rows to be
    # rejected as unknown keys.
    cases_marker = re.compile(r"[!#]\s*CASES", re.IGNORECASE)

    for lineno, raw_line in enumerate(raw_lines, start=1):
        stripped = raw_line.strip()
        if not stripped:
            continue
        if stripped.startswith(("!", "#")):
            if cases_marker.match(stripped):
                in_cases_section = True
            continue

        if in_cases_section:
            first_token = re.split(r"\s+", stripped, maxsplit=1)[0].upper()
            if first_token in CASE_TYPE_ORDER:
                conditions.extend(_parse_case_row(stripped, lineno=lineno))
                continue

        # ``stripped`` is non-empty and starts with a non-space character, so
        # the value column parts[0] is never empty.
        parts = re.split(r"\s{2,}", stripped, maxsplit=2)
        if len(parts) < 2:
            raise ValueError(
                f"Cannot parse OpenFAST-style line {lineno}: {raw_line!r}. "
                "Expected '<value><spaces><key><spaces>- <comment>'."
            )

        value = parts[0].strip()
        key = FIELD_ALIASES.get(_normalize_key(parts[1]))
        if key is None:
            raise ValueError(f"Unknown input key on line {lineno}: {raw_line!r}")

        if key in {"condition", "conditions"}:
            _append_condition_value(conditions, value, lineno=lineno)
        else:
            _assign_scalar_field(fields, field_lines, key, value, lineno=lineno)

    return _finalize_parsed_fields(fields, conditions, legacy=legacy)


class _LayoutParser(Protocol):
    """Call signature shared by the per-layout parsers registered in ``_FORMAT_PARSERS``."""

    def __call__(self, raw_lines: list[str], *, legacy: bool = ...) -> IECParameters:
        """Parse the file's lines and return validated parameters."""
        ...


# Layout ids accepted by the ``! format: <id>`` comment directive (documented in
# docs/data_sources.rst). A declared id takes precedence over auto-detection.
# Each layout answers to a versioned id and a short alias.
_FORMAT_PARSERS: dict[str, _LayoutParser] = {
    alias: parser
    for parser, aliases in (
        (_parse_openfast_input_file, ("openfast-table-v1", "openfast")),
        (_parse_keyed_input_file, ("keyed-v1", "keyed")),
        (_parse_legacy_input_file, ("legacy-v1", "legacy")),
    )
    for alias in aliases
}

# Matches a whole (stripped) comment line such as ``! format: keyed-v1``. The
# marker may be ``!`` or ``#``, the word may be ``format`` or ``format_version``
# in any case, and the separator may be ``:`` or ``=``. Group 1 is the layout id.
_FORMAT_DIRECTIVE = re.compile(r"^[!#]\s*format(?:_version)?\s*[:=]\s*(\S+)\s*$", flags=re.I)


def _detect_declared_format(raw_lines: list[str]) -> str | None:
    """Return the lower-cased layout id from the first ``format`` directive, or ``None``.

    The directive (``! format: <id>``, ``# format_version = <id>``, ...) is a
    comment line, and the keyed and OpenFAST parsers skip comment lines. The
    legacy parser reads values by position, so there the directive must sit on
    a line it does not read. When no directive is found, the caller
    auto-detects the layout instead.
    """

    directives = (_FORMAT_DIRECTIVE.match(line.strip()) for line in raw_lines)
    found = next((hit for hit in directives if hit is not None), None)
    if found is None:
        return None
    return found.group(1).strip().lower()


# Auto-detection heuristics, compiled once at import instead of on every call.
# A keyed line starts with an identifier-like key followed by ``=`` or ``:``.
_KEYED_LINE = re.compile(r"^[A-Za-z_][A-Za-z0-9_ -]*\s*(=|:)\s*.*$")
# An OpenFAST-style line is ``<value><2+ spaces><key>``, optionally followed by
# a ``- comment``.
_OPENFAST_LINE = re.compile(r"^\S+\s{2,}[A-Za-z_][A-Za-z0-9_ -]*\s*(?:\s+-.*)?$")


def parse_input_file(filepath: str | Path = DEFAULT_INPUT_FILENAME, *, legacy: bool = False) -> IECParameters:
    """Read an input file and return validated :class:`IECParameters`.

    The OpenFAST-style table, keyed (``key = value``), and legacy positional
    layouts are auto-detected. Detection looks only at non-blank, non-comment
    lines. If any such line looks like an OpenFAST table row, the file is
    parsed as a table. Otherwise, if any line looks like ``key = value`` or
    ``key: value``, it is parsed as keyed. Failing both, it is parsed as
    legacy. A file may instead pin its layout explicitly with a
    ``! format: <id>`` comment directive (``openfast-table-v1``, ``keyed-v1``,
    or ``legacy-v1``); when present it overrides auto-detection. See
    ``docs/data_sources.rst`` for the full Input Format v1 specification.

    Parameters
    ----------
    filepath : str or pathlib.Path, optional
        Path to the input file. Defaults to ``pyiecwind.ipt`` in the current
        working directory.
    legacy : bool, default False
        If ``True``, unsupported IEC editions are coerced to edition 3 (emitting
        an :class:`~pyiecwind.IECWindWarning`) instead of raising. Provided for
        backward compatibility with older inputs.

    Returns
    -------
    IECParameters
        The validated, immutable parameter object.

    Raises
    ------
    FileNotFoundError
        If ``filepath`` does not exist or is not a regular file.
    ValueError
        For malformed input, an unknown key, a duplicate scalar field, an
        unrecognised ``si_unit`` token or ``format`` directive, an out-of-range
        value, or (unless ``legacy=True``) an unsupported edition.

    Examples
    --------
    >>> from pyiecwind import parse_input_file
    >>> params = parse_input_file("examples/sample_case.ipt")  # doctest: +SKIP
    >>> params.wtc  # doctest: +SKIP
    2
    """

    path = Path(filepath)
    # is_file() rather than exists(): a directory would otherwise pass the check
    # and then fail with a platform-specific OSError inside read_text().
    if not path.is_file():
        # Only point at the default filename when that is what was looked up;
        # for an explicit path the hint would be misleading.
        hint = ""
        if path == Path(DEFAULT_INPUT_FILENAME):
            hint = f" Ensure {DEFAULT_INPUT_FILENAME} is in the current working directory."
        raise FileNotFoundError(f"Cannot find input file '{path}'.{hint}")

    # Read with utf-8-sig so a leading byte-order mark (common in files saved by
    # Windows editors such as Notepad) is stripped instead of corrupting line 1.
    raw_lines = path.read_text(encoding="utf-8-sig").splitlines()

    declared = _detect_declared_format(raw_lines)
    if declared is not None:
        parser = _FORMAT_PARSERS.get(declared)
        if parser is None:
            supported = ", ".join(sorted(_FORMAT_PARSERS))
            raise ValueError(f"Unknown format directive '{declared}'. Supported values: {supported}.")
        return parser(raw_lines, legacy=legacy)

    # Only non-blank, non-comment lines vote on the layout.
    content = [text for text in (line.strip() for line in raw_lines) if text and not text.startswith(("!", "#"))]
    if any(_OPENFAST_LINE.match(text) for text in content):
        return _parse_openfast_input_file(raw_lines, legacy=legacy)
    if any(_KEYED_LINE.match(text) for text in content):
        return _parse_keyed_input_file(raw_lines, legacy=legacy)
    return _parse_legacy_input_file(raw_lines, legacy=legacy)