"""
data_protection_public.py
HouseClick.net — AI Accountant Agent

PURPOSE
-------
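This file is the publicly auditable implementation of the user data
encryption layer for aiaccountant.houseclick.net. Key management (DEK
wrapping/unwrapping, key versions, blind-index peppers, Paillier keys),
access-policy checks and key destruction are delegated to
``core.data_protection_secure``; audit events are written through
``core.data_protection_audit.append_event``.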

WHAT THIS FILE GUARANTEES
--------------------------
1. Sensitive financial fields are encrypted before database storage.
2. Service administrators cannot decrypt user data.
3. Decryption is only allowed for approved output actions.
4. Decryption attempts are written to an immutable audit trail.
5. Sums and scalar multiples of monetary amounts can be computed on encrypted
   values via Paillier partially homomorphic encryption (PHE).
"""

from __future__ import annotations

import base64
import hashlib
import hmac
import json
import secrets
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from decimal import Decimal, ROUND_HALF_UP
from typing import Any, Union

from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from phe import paillier

from core import data_protection_secure as _dp_secure
from core.data_protection_audit import append_event


class IntegrityError(Exception):
    """Signals that an encrypted payload failed verification (tampered, corrupt, or bound to a different context)."""


@dataclass(frozen=True)
class EncryptedField:
    """
    Immutable envelope describing one AES-GCM encrypted field value.

    Ciphertext, nonce and associated data are base64 text so the whole
    envelope can be serialised as JSON.
    """

    # AESGCM.encrypt output (ciphertext with appended tag), base64 encoded.
    ciphertext_b64: str
    # Fresh per-encryption data key, wrapped by the KMS (format is KMS-defined).
    wrapped_dek: str
    # 12-byte GCM nonce, base64 encoded.
    nonce_b64: str
    # Associated data binding the value to tenant, field and record; base64.
    aad_b64: str
    # Version of the key-encryption key used to wrap the DEK.
    key_version: int
    # Algorithm label, e.g. "AES-256-GCM".
    alg: str
    # ISO-8601 UTC timestamp of when the value was encrypted.
    created_at: str
    # Owning tenant; also selects the KEK when the DEK is unwrapped.
    user_id: str


@dataclass(frozen=True)
class DeletionCertificate:
    """Immutable receipt issued once a data subject's keys have been erased."""

    # Identifier of the data subject whose keys were destroyed.
    user_id: str
    # ISO-8601 UTC time the certificate was issued.
    timestamp: str
    # Identity that approved the erasure (also recorded as the audit actor).
    authorised_by: str
    # Free-text justification supplied by the caller.
    reason: str
    # Random UUID4 string identifying this certificate.
    certificate_id: str


# Token -> {"token_type", "value", "created_at", ...} mapping for tokenize_identifier().
# NOTE(review): this is a plain module-level dict, so entries live only in this
# process (lost on restart, not shared between workers) — confirm intended.
_TOKEN_VAULT: dict[str, dict[str, str]] = {}

# Lower-case key names whose values redact_for_logs() masks. They are matched
# exactly against lower-cased dict keys, and as substrings of lower-cased
# plain-string payloads. "full_name" and "notes" are included because they
# are encrypted as sensitive fields by the record helpers in this module.
_SENSITIVE_KEYS = {
    "email",
    "phone",
    "tax_id",
    "full_name",
    "amount",
    "vat_amount",
    "employee_name",
    "salary",
    "form_content",
    "extracted_json",
    "vendor_name",
    "notes",
    "gross_salary",
    "net_salary",
    "employer_tax",
}


def _utc_now() -> str:
    return datetime.now(timezone.utc).isoformat()


def _make_aad(user_id: str, field_name: str, record_id: str) -> bytes:
    return f"tenant:{user_id}|field:{field_name}|record:{record_id}".encode("utf-8")


def _pack_value(value: Any) -> bytes:
    payload = {"value": value}
    return json.dumps(payload, ensure_ascii=False, separators=(",", ":")).encode("utf-8")


def _unpack_value(raw: bytes) -> Any:
    payload = json.loads(raw.decode("utf-8"))
    return payload["value"]


def _normalize_for_hash(value: str) -> str:
    return value.strip().lower()


def _serialize_encrypted_field(field: EncryptedField) -> str:
    return json.dumps(
        {
            "ciphertext_b64": field.ciphertext_b64,
            "wrapped_dek": field.wrapped_dek,
            "nonce_b64": field.nonce_b64,
            "aad_b64": field.aad_b64,
            "key_version": field.key_version,
            "alg": field.alg,
            "created_at": field.created_at,
            "user_id": field.user_id,
        },
        separators=(",", ":"),
    )


def _deserialize_encrypted_field(raw: Union[str, dict[str, Any], EncryptedField]) -> EncryptedField:
    """
    Coerce a stored envelope into an EncryptedField.

    An EncryptedField is returned unchanged; a string is parsed as JSON; any
    other value is used directly as a mapping. Each member is coerced with
    str() (or int() for ``key_version``); a missing key raises KeyError.
    """
    if isinstance(raw, EncryptedField):
        return raw
    if isinstance(raw, str):
        data = json.loads(raw)
    else:
        data = raw
    # Field name -> coercion, in declaration order (lookups happen in this order).
    converters = (
        ("ciphertext_b64", str),
        ("wrapped_dek", str),
        ("nonce_b64", str),
        ("aad_b64", str),
        ("key_version", int),
        ("alg", str),
        ("created_at", str),
        ("user_id", str),
    )
    return EncryptedField(**{name: cast(data[name]) for name, cast in converters})


def _is_phe_payload(raw: str) -> bool:
    try:
        payload = json.loads(raw)
    except Exception:
        return False
    return isinstance(payload, dict) and "ciphertext" in payload and "exponent" in payload


def encrypt_field(
    value: str | int | float | dict,
    field_name: str,
    record_id: str,
    user_id: str,
) -> EncryptedField:
    """
    Envelope-encrypt one field value before it is stored.

    Each call draws a fresh 256-bit data key (DEK) and 96-bit nonce. The
    value is packed as JSON and sealed with AES-256-GCM. The AAD binds the
    value to ``user_id``, ``field_name`` and ``record_id``. The DEK is then
    wrapped through ``_dp_secure.kms_wrap_dek`` under the active key version.
    """
    def to_b64(blob: bytes) -> str:
        return base64.b64encode(blob).decode("ascii")

    data_key = secrets.token_bytes(32)
    cipher = AESGCM(data_key)
    iv = secrets.token_bytes(12)
    associated = _make_aad(user_id, field_name, record_id)
    sealed = cipher.encrypt(iv, _pack_value(value), associated)

    version = _dp_secure.get_active_key_version()
    wrapped = _dp_secure.kms_wrap_dek(data_key, user_id=user_id, key_version=version)

    return EncryptedField(
        ciphertext_b64=to_b64(sealed),
        wrapped_dek=wrapped,
        nonce_b64=to_b64(iv),
        aad_b64=to_b64(associated),
        key_version=version,
        alg="AES-256-GCM",
        created_at=_utc_now(),
        user_id=user_id,
    )


def _open_envelope(envelope: EncryptedField, field_name: str, record_id: str) -> Any:
    """
    Authenticate and decrypt one AES-GCM envelope with a single DEK unwrap.

    Raises:
        IntegrityError: malformed base64, AAD that does not match
            (owner, field, record), DEK unwrap failure, or AES-GCM tag failure.
    """
    expected_aad = _make_aad(envelope.user_id, field_name, record_id)
    try:
        actual_aad = base64.b64decode(envelope.aad_b64)
        nonce = base64.b64decode(envelope.nonce_b64)
        ciphertext = base64.b64decode(envelope.ciphertext_b64)
    except ValueError as exc:  # binascii.Error is a ValueError subclass
        raise IntegrityError("Malformed base64 in encrypted envelope") from exc

    # The stored AAD must describe exactly this (owner, field, record) triple;
    # otherwise the ciphertext was moved between rows, fields or tenants.
    if actual_aad != expected_aad:
        raise IntegrityError("AAD mismatch detected")

    try:
        dek = _dp_secure.kms_unwrap_dek(
            envelope.wrapped_dek,
            user_id=envelope.user_id,
            key_version=envelope.key_version,
        )
        # AESGCM.decrypt verifies the tag, so no separate integrity pass is needed.
        plaintext = AESGCM(dek).decrypt(nonce, ciphertext, actual_aad)
    except Exception as exc:
        # Unwrap and tag failures both surface as IntegrityError so callers and
        # the audit trail see a single failure type.
        raise IntegrityError("Ciphertext integrity check failed") from exc
    return _unpack_value(plaintext)


def decrypt_for_output(
    encrypted_field: Union[EncryptedField, str],
    field_name: str,
    record_id: str,
    actor_context: _dp_secure.ActorContext,
) -> str | int | float | dict:
    """
    Decrypt a field for an authorised output action and audit the outcome.

    Flow:
      1. ``_dp_secure.assert_actor_allowed`` gates the action. A refusal is
         audited as BLOCKED and the PermissionError is re-raised.
      2. Paillier JSON payloads are decrypted via ``_decrypt_phe_json``
         (audited as ALLOWED).
      3. AES-GCM envelopes are checked for tenant ownership (role "user"),
         then for AAD binding and the GCM tag. The DEK is unwrapped once.

    Audit rows are written for BLOCKED, ALLOWED and INTEGRITY_FAIL outcomes.
    Other errors propagate without an audit row, e.g. an envelope that is not
    valid JSON, or a Paillier decryption error.

    Args:
        encrypted_field: EncryptedField, its JSON serialisation, or a Paillier
            ciphertext JSON payload.
        field_name: logical field name; part of the expected AAD.
        record_id: record identifier; part of the expected AAD.
        actor_context: caller identity, role and requested action.

    Returns:
        The decrypted value (a Decimal for Paillier payloads).

    Raises:
        PermissionError: policy refusal, or a "user" actor reading another
            user's envelope.
        IntegrityError: malformed envelope encoding, AAD mismatch, DEK unwrap
            failure or authentication-tag failure.
    """
    action = actor_context.action

    def audit(key_version: int, result: str, reason: str) -> None:
        # One audit row per outcome; actor/action/field/record are fixed per call.
        append_event(
            actor_id=actor_context.actor_id,
            actor_role=actor_context.role,
            action=action,
            field_name=field_name,
            record_id=record_id,
            key_version=key_version,
            result=result,
            reason=reason,
        )

    try:
        _dp_secure.assert_actor_allowed(actor_context, action)
    except PermissionError as exc:
        # The key version is only known without parsing when given an object.
        blocked_version = (
            encrypted_field.key_version if isinstance(encrypted_field, EncryptedField) else 0
        )
        audit(blocked_version, "BLOCKED", str(exc))
        raise

    if isinstance(encrypted_field, str) and _is_phe_payload(encrypted_field):
        # NOTE(review): the Paillier key is looked up by the *actor* id because the
        # payload carries no owner. For non-"user" actors this presumably selects
        # the wrong key; confirm against the key model in data_protection_secure.
        value = _decrypt_phe_json(encrypted_field, actor_context.actor_id)
        audit(_dp_secure.get_active_key_version(), "ALLOWED", "policy_check_passed_phe")
        return value

    envelope = _deserialize_encrypted_field(encrypted_field)
    if actor_context.role == "user" and str(actor_context.actor_id) != str(envelope.user_id):
        audit(envelope.key_version, "BLOCKED", "User cannot decrypt another user's data")
        raise PermissionError("User cannot decrypt another user's data")

    try:
        value = _open_envelope(envelope, field_name, record_id)
    except IntegrityError:
        audit(envelope.key_version, "INTEGRITY_FAIL", "ciphertext_tamper_or_mismatch")
        raise

    audit(envelope.key_version, "ALLOWED", "policy_check_passed")
    return value


def hash_for_lookup(value: str, field_name: str) -> str:
    """
    Compute a blind index for exact-match search on an encrypted field.

    The value is trimmed and lower-cased, prefixed with ``field_name:``, and
    HMAC-SHA-256'd under the field's pepper from ``_dp_secure.get_pepper``.
    Returns the hex digest.
    """
    message = f"{field_name}:{_normalize_for_hash(value)}".encode("utf-8")
    key = _dp_secure.get_pepper(field_name)
    return hmac.new(key, message, hashlib.sha256).hexdigest()


def encrypt_for_arithmetic(amount: Decimal, user_id: str) -> str:
    """
    Encrypt a monetary amount with the user's Paillier public key.

    Non-Decimal inputs are converted through ``str()``. The amount is scaled
    to integer cents (x100, ROUND_HALF_UP) before encryption.

    Returns:
        Compact JSON ``{"ciphertext": <base64 of the ciphertext's decimal
        string>, "exponent": <phe exponent>}``.
    """
    as_decimal = amount if isinstance(amount, Decimal) else Decimal(str(amount))
    cents = int((as_decimal * 100).quantize(Decimal("1"), rounding=ROUND_HALF_UP))
    encrypted = _dp_secure.get_phe_public_key(user_id).encrypt(cents)
    raw_ciphertext = str(encrypted.ciphertext()).encode("utf-8")
    payload = {
        "ciphertext": base64.b64encode(raw_ciphertext).decode("ascii"),
        "exponent": encrypted.exponent,
    }
    return json.dumps(payload, separators=(",", ":"))


def phe_add(ct_a: str, ct_b: str, user_id: str) -> str:
    """
    Add two Paillier ciphertext JSON payloads without decrypting them.

    Inputs and output use the ``{"ciphertext": <base64 of decimal str>,
    "exponent": <int>}`` shape produced by encrypt_for_arithmetic(). Both
    operands are rebuilt under ``user_id``'s public key, so they must have
    been encrypted for that same user.

    Raises:
        json.JSONDecodeError / KeyError / ValueError: malformed payloads.
    """
    payload_a = json.loads(ct_a)
    payload_b = json.loads(ct_b)
    pub_key = _dp_secure.get_phe_public_key(user_id)

    def to_number(payload: dict[str, Any]) -> paillier.EncryptedNumber:
        # int() on the exponent (as _decrypt_phe_json does) stops a string
        # exponent from reaching phe's exponent comparisons during addition.
        return paillier.EncryptedNumber(
            pub_key,
            int(base64.b64decode(payload["ciphertext"]).decode("utf-8")),
            exponent=int(payload["exponent"]),
        )

    total = to_number(payload_a) + to_number(payload_b)
    encoded = base64.b64encode(str(total.ciphertext()).encode("utf-8")).decode("ascii")
    return json.dumps({"ciphertext": encoded, "exponent": total.exponent}, separators=(",", ":"))


def tokenize_identifier(value: str, token_type: str, *, owner_id: str | None = None) -> str:
    """
    Replace a sensitive identifier with an opaque random UUID4 token.

    The token -> value mapping is kept in the module-level ``_TOKEN_VAULT``.

    Args:
        value: the identifier being tokenized.
        token_type: caller-defined category label stored with the entry.
        owner_id: optional data-subject id. It is stored as ``str`` under
            ``"owner_id"`` so the subject's entries can be located later,
            e.g. for erasure.

    Returns:
        The new token string.
    """
    token = str(uuid.uuid4())
    entry: dict[str, str] = {
        "token_type": token_type,
        "value": value,
        "created_at": _utc_now(),
    }
    if owner_id is not None:
        entry["owner_id"] = str(owner_id)
    _TOKEN_VAULT[token] = entry
    return token


def redact_for_logs(payload: dict | str) -> dict | str:
    """
    Removes or masks PII/financial data before logging.
    """
    if isinstance(payload, str):
        return "[REDACTED]" if any(k in payload.lower() for k in _SENSITIVE_KEYS) else payload

    redacted: dict[str, Any] = {}
    for key, value in payload.items():
        lowered = str(key).lower()
        if lowered in _SENSITIVE_KEYS:
            redacted[key] = 0.0 if isinstance(value, (int, float, Decimal)) else "[REDACTED]"
        elif isinstance(value, dict):
            redacted[key] = redact_for_logs(value)
        else:
            redacted[key] = value
    return redacted


def verify_integrity(encrypted_field: EncryptedField) -> bool:
    """
    Check the AES-GCM authentication tag of a stored envelope.

    The DEK is unwrapped through the KMS and a full decrypt is performed;
    the plaintext is discarded. Returns True on success. Any failure (KMS
    unwrap, base64 decoding, tag verification) is re-raised as
    IntegrityError with the original exception chained.
    """
    try:
        data_key = _dp_secure.kms_unwrap_dek(
            encrypted_field.wrapped_dek,
            user_id=encrypted_field.user_id,
            key_version=encrypted_field.key_version,
        )
        cipher = AESGCM(data_key)
        nonce, sealed, associated = (
            base64.b64decode(part)
            for part in (
                encrypted_field.nonce_b64,
                encrypted_field.ciphertext_b64,
                encrypted_field.aad_b64,
            )
        )
        cipher.decrypt(nonce, sealed, associated)
    except Exception as exc:
        raise IntegrityError("Ciphertext integrity check failed") from exc
    return True


def _decrypt_phe_json(ct_json: str, user_id: str) -> Decimal:
    """
    Decrypt a Paillier ciphertext JSON payload and convert cents to a 2dp Decimal.

    Args:
        ct_json: ``{"ciphertext": <base64 of decimal str>, "exponent": <int>}``.
        user_id: selects the Paillier private key via ``_dp_secure``.

    Returns:
        The decrypted cent value divided by 100, quantized to 0.01 with
        ROUND_HALF_UP. This is the same rounding mode encrypt_for_arithmetic
        uses when scaling to cents.
    """
    data = json.loads(ct_json)
    priv_key = _dp_secure.get_phe_private_key(user_id)
    ct_obj = paillier.EncryptedNumber(
        priv_key.public_key,
        int(base64.b64decode(data["ciphertext"]).decode("utf-8")),
        exponent=int(data["exponent"]),
    )
    decrypted = priv_key.decrypt(ct_obj)
    # phe decodes a negative exponent (e.g. after multiplying by a float
    # scalar) to a float. Going through str() keeps binary-float noise out
    # of the Decimal; integer results are exact already.
    cents = Decimal(decrypted) if isinstance(decrypted, int) else Decimal(str(decrypted))
    return (cents / Decimal("100")).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)


def phe_scalar_multiply(ct_json: str, scalar: float, user_id: str) -> str:
    """
    Multiply a Paillier ciphertext JSON payload by a plaintext scalar.

    Public entry point only; the homomorphic operation itself is delegated,
    with identical arguments, to ``_dp_secure.phe_scalar_multiply``.
    """
    delegate = _dp_secure.phe_scalar_multiply
    return delegate(ct_json, scalar, user_id)


def destroy_subject_data(user_id: str, reason: str, authorised_by: str) -> DeletionCertificate:
    """
    Erase a data subject via key destruction and token-vault cleanup.

    Steps:
      1. ``_dp_secure.destroy_user_keys(user_id)``. Presumably this makes
         ciphertexts wrapped under the user's keys unrecoverable; that depends
         on the KMS implementation, so confirm.
      2. Drop token-vault entries tagged with ``owner_id == str(user_id)``,
         plus entries whose tokenized value equals ``user_id``.
      3. Append an ALLOWED audit event and return a DeletionCertificate.

    NOTE(review): vault entries created without an ``owner_id`` (e.g. a
    tokenized e-mail) cannot be attributed to the subject and are not removed.
    """
    _dp_secure.destroy_user_keys(user_id)

    subject = str(user_id)
    doomed = [
        token
        for token, meta in _TOKEN_VAULT.items()
        if meta.get("owner_id") == subject or meta.get("value") == user_id
    ]
    for token in doomed:
        _TOKEN_VAULT.pop(token, None)

    cert = DeletionCertificate(
        user_id=user_id,
        timestamp=_utc_now(),
        authorised_by=authorised_by,
        reason=reason,
        certificate_id=str(uuid.uuid4()),
    )
    append_event(
        actor_id=authorised_by,
        actor_role="system",
        action="gdpr_destroy_subject_data",
        field_name="*",
        record_id=user_id,
        key_version=_dp_secure.get_active_key_version(),
        result="ALLOWED",
        reason=f"erasure_completed:{reason}",
    )
    return cert


# ============================================================
# User-Specific Encryption Helpers
# ============================================================

def hash_email_for_lookup(email: str) -> str:
    """Return the blind-index HMAC used to find a user by e-mail address at login."""
    field = "email"
    return hash_for_lookup(email, field)


def encrypt_user_record(user_data: dict, user_id: str) -> dict:
    """
    Encrypt a user's sensitive profile fields for a database write.

    Fields are stripped of surrounding whitespace (e-mail is also
    lower-cased). Empty fields are skipped. ``full_name`` falls back to
    ``name``. Every field is sealed under record id ``user_<user_id>``.
    E-mail and tax id additionally get blind indexes.

    Returns:
        Any of ``email_ciphertext``, ``email_hash``, ``full_name_ciphertext``,
        ``phone_ciphertext``, ``tax_id_ciphertext`` and ``tax_id_hash``.
    """
    record_id = f"user_{user_id}"
    # (field name, raw-value reader, lower-case it?, emit blind index?)
    layout = (
        ("email", lambda d: d.get("email"), True, True),
        ("full_name", lambda d: d.get("full_name") or d.get("name"), False, False),
        ("phone", lambda d: d.get("phone"), False, False),
        ("tax_id", lambda d: d.get("tax_id"), False, True),
    )

    encrypted: dict[str, Any] = {}
    for field, read, lowercase, indexed in layout:
        cleaned = str(read(user_data) or "").strip()
        if lowercase:
            cleaned = cleaned.lower()
        if not cleaned:
            continue
        encrypted[f"{field}_ciphertext"] = _serialize_encrypted_field(
            encrypt_field(cleaned, field, record_id, user_id)
        )
        if indexed:
            encrypted[f"{field}_hash"] = hash_for_lookup(cleaned, field)
    return encrypted


def decrypt_user_for_output(
    user_record: dict,
    actor_context: _dp_secure.ActorContext,
    requesting_user_id: str,
) -> dict:
    """
    Decrypt a user's profile fields for the subject or another permitted actor.

    A "user" actor may only read their own record. For each of email,
    full_name, phone and tax_id, the ``<field>_ciphertext`` or
    ``<field>_enc`` value is decrypted under record id
    ``user_<requesting_user_id>``. Without a ciphertext, a legacy plaintext
    column is copied if present (``name`` for full_name).

    Raises:
        PermissionError: a "user" actor requested someone else's data.
    """
    if actor_context.role == "user" and str(actor_context.actor_id) != str(requesting_user_id):
        raise PermissionError(
            f"User {actor_context.actor_id} cannot access data of {requesting_user_id}"
        )

    record_id = f"user_{requesting_user_id}"
    # (output/field name, legacy plaintext column)
    layout = (
        ("email", "email"),
        ("full_name", "name"),
        ("phone", "phone"),
        ("tax_id", "tax_id"),
    )
    output: dict[str, Any] = {}
    for field, plain_column in layout:
        blob = user_record.get(f"{field}_ciphertext") or user_record.get(f"{field}_enc")
        if blob:
            output[field] = decrypt_for_output(blob, field, record_id, actor_context)
        elif user_record.get(plain_column):
            output[field] = user_record[plain_column]
    return output


def find_user_by_email(email: str, db_session) -> dict | None:
    """
    Look a user up by the e-mail blind index (``User.email_hash``).

    Returns None when no row matches. Otherwise returns a dict of the row's
    columns. The ``*_enc`` columns are exposed under ``*_ciphertext`` keys
    so the dict can be passed straight to decrypt_user_for_output. Optional
    columns read as None when the model lacks them.
    """
    from core.database import User

    lookup_key = hash_for_lookup(email, "email")
    match = db_session.query(User).filter(User.email_hash == lookup_key).first()
    if not match:
        return None

    snapshot: dict[str, Any] = {
        "id": match.id,
        "email": match.email,
        "password_hash": match.password_hash,
        "name": match.name,
        "phone": match.phone,
        "tax_id": match.tax_id,
    }
    # (output key, model attribute) for columns that may not exist on User.
    optional_columns = (
        ("email_hash", "email_hash"),
        ("tax_id_hash", "tax_id_hash"),
        ("email_ciphertext", "email_enc"),
        ("full_name_ciphertext", "full_name_enc"),
        ("phone_ciphertext", "phone_enc"),
        ("tax_id_ciphertext", "tax_id_enc"),
    )
    snapshot.update({out: getattr(match, attr, None) for out, attr in optional_columns})
    return snapshot


# ============================================================
# Document-Specific Encryption Helpers
# ============================================================

def encrypt_document_record(document_data: dict, document_id: str, user_id: str) -> dict:
    """
    Encrypt a document's extracted content and compute a content hash.

    The content comes from ``extracted_json``, falling back to
    ``extracted_data`` only when the former is missing or None. Dicts,
    lists and tuples are serialised with json.dumps so that
    decrypt_document_for_output can json.loads them back. Strings are
    stored as-is and other scalars via str(). Falsy content yields {}.

    Returns:
        ``{"extracted_json_ciphertext": <envelope JSON>, "document_hash":
        <sha256 hex of the serialised content>}``, or ``{}``.
    """
    result: dict[str, Any] = {}
    extracted = document_data.get("extracted_json")
    if extracted is None:
        extracted = document_data.get("extracted_data")
    if not extracted:
        return result

    if isinstance(extracted, (dict, list, tuple)):
        # str() on a list would store a Python repr (single quotes), which
        # json.loads cannot parse back on the output path.
        payload = json.dumps(extracted, ensure_ascii=False)
    else:
        payload = str(extracted)

    result["extracted_json_ciphertext"] = _serialize_encrypted_field(
        encrypt_field(payload, "extracted_json", str(document_id), str(user_id))
    )
    # NOTE(review): an unkeyed SHA-256 over plaintext lets anyone holding the
    # database confirm guesses of low-entropy document content. Consider an
    # HMAC with a pepper; this needs a migration of stored hashes.
    result["document_hash"] = hashlib.sha256(payload.encode("utf-8")).hexdigest()
    return result


def decrypt_document_for_output(
    document_record: dict,
    actor_context: _dp_secure.ActorContext,
    user_id: str,
) -> dict:
    """
    Return a document's decrypted extracted content for an authorised output.

    A "user" actor may only read documents owned by ``user_id``. The stored
    envelope is the first truthy value among ``extracted_json_ciphertext``,
    ``extracted_data`` and ``extracted_json``. It is decrypted via
    decrypt_for_output with the record's ``id`` (or "unknown") as record id,
    then parsed as JSON when possible.

    Returns:
        ``{"extracted_json": <parsed or raw decrypted value>}``, or ``{}``
        when nothing is stored.
    """
    if actor_context.role == "user" and str(actor_context.actor_id) != str(user_id):
        raise PermissionError(f"User {actor_context.actor_id} cannot access document of {user_id}")

    stored = None
    for column in ("extracted_json_ciphertext", "extracted_data", "extracted_json"):
        stored = document_record.get(column)
        if stored:
            break
    if not stored:
        return {}

    decrypted = decrypt_for_output(
        stored,
        "extracted_json",
        str(document_record.get("id", "unknown")),
        actor_context,
    )
    try:
        content = json.loads(decrypted)
    except Exception:
        # Non-JSON text (or a non-string value) is returned unparsed.
        content = decrypted
    return {"extracted_json": content}


# ============================================================
# Transaction-Specific Encryption Helpers
# ============================================================

def encrypt_transaction_record(transaction_data: dict, transaction_id: str, user_id: str) -> dict:
    """
    Encrypt a transaction's sensitive fields for storage.

    ``vendor_name`` and ``notes`` are sealed when truthy. ``amount`` and
    ``vat_amount`` are sealed when not None, and each also gets a Paillier
    copy (``<field>_phe``) for encrypted aggregation.
    """
    tx_id = str(transaction_id)
    owner = str(user_id)
    out: dict[str, Any] = {}

    def seal(field: str, text: str) -> str:
        return _serialize_encrypted_field(encrypt_field(text, field, tx_id, owner))

    vendor = transaction_data.get("vendor_name")
    if vendor:
        out["vendor_name_ciphertext"] = seal("vendor_name", str(vendor))

    for field in ("amount", "vat_amount"):
        if transaction_data.get(field) is None:
            continue
        text = str(transaction_data[field])
        out[f"{field}_ciphertext"] = seal(field, text)
        out[f"{field}_phe"] = encrypt_for_arithmetic(Decimal(text), owner)

    notes = transaction_data.get("notes")
    if notes:
        out["notes_ciphertext"] = seal("notes", str(notes))

    return out


def decrypt_transaction_for_output(
    transaction_record: dict,
    actor_context: _dp_secure.ActorContext,
    user_id: str,
) -> dict:
    """
    Decrypt a transaction's sensitive fields for an authorised output.

    A "user" actor may only read transactions owned by ``user_id``. Each of
    vendor_name, amount, vat_amount and notes is read from
    ``<field>_ciphertext`` (or ``<field>_enc``) and decrypted with the
    record's ``id`` (or "unknown") as record id. Monetary fields are
    returned as Decimal.
    """
    if actor_context.role == "user" and str(actor_context.actor_id) != str(user_id):
        raise PermissionError(f"User {actor_context.actor_id} cannot access transaction of {user_id}")

    record_id = str(transaction_record.get("id", "unknown"))
    monetary = ("amount", "vat_amount")
    decoded: dict[str, Any] = {}
    for field in ("vendor_name", "amount", "vat_amount", "notes"):
        blob = transaction_record.get(f"{field}_ciphertext") or transaction_record.get(f"{field}_enc")
        if not blob:
            continue
        plain = decrypt_for_output(blob, field, record_id, actor_context)
        decoded[field] = Decimal(str(plain)) if field in monetary else plain
    return decoded


# ============================================================
# Payroll-Specific Encryption Helpers
# ============================================================

def encrypt_payroll_record(payroll_data: dict, payroll_id: str, user_id: str) -> dict:
    """
    Encrypt a payroll entry's sensitive fields for storage.

    ``employee_name`` is sealed when truthy. ``gross_salary``,
    ``net_salary`` and ``employer_tax`` are sealed when not None. Only
    gross salary also receives a Paillier copy (``gross_salary_phe``).
    """
    pay_id = str(payroll_id)
    owner = str(user_id)
    sealed: dict[str, Any] = {}

    def seal(field: str, text: str) -> str:
        return _serialize_encrypted_field(encrypt_field(text, field, pay_id, owner))

    name = payroll_data.get("employee_name")
    if name:
        sealed["employee_name_ciphertext"] = seal("employee_name", str(name))

    for field in ("gross_salary", "net_salary", "employer_tax"):
        if payroll_data.get(field) is None:
            continue
        text = str(payroll_data[field])
        sealed[f"{field}_ciphertext"] = seal(field, text)
        if field == "gross_salary":
            # NOTE(review): net_salary/employer_tax get no Paillier copy, so they
            # cannot be aggregated while encrypted; confirm this is intended.
            sealed["gross_salary_phe"] = encrypt_for_arithmetic(Decimal(text), owner)

    return sealed
