#!/usr/bin/env python3
import argparse
import os
import socket
import sys
import threading
import time
from datetime import datetime

import faultcore
from faultcore.shm_writer import SHM_SIZE


def ensure_shm_ready() -> str:
    """Make sure the faultcore config shared-memory segment exists and is sized.

    If ``FAULTCORE_CONFIG_SHM`` is already set in the environment, that name is
    reused. Otherwise a per-process name ``/faultcore_<pid>_config`` is stored in
    ``os.environ`` (presumably so faultcore picks it up — confirm).

    The backing file ``/dev/shm/<name without leading slashes>`` is created if
    missing, with mode 0o600 subject to umask. It is then truncated or extended
    to exactly ``SHM_SIZE`` bytes.

    Returns:
        The shared-memory name recorded in ``FAULTCORE_CONFIG_SHM``.
    """
    shm_name = os.environ.setdefault("FAULTCORE_CONFIG_SHM", f"/faultcore_{os.getpid()}_config")
    backing_file = "/dev/shm/" + shm_name.lstrip("/")

    descriptor = os.open(backing_file, os.O_CREAT | os.O_RDWR, 0o600)
    try:
        os.ftruncate(descriptor, SHM_SIZE)
    finally:
        # Only the file's size matters here; the descriptor itself is not kept.
        os.close(descriptor)

    return shm_name


def start_tcp_push_server(host: str) -> tuple[int, threading.Thread, threading.Event]:
    """Start a one-shot TCP server that pushes two 8-byte packets to its first client.

    The listening socket is bound to an ephemeral port on *host*. A daemon thread
    then does the following:
    - accepts a single connection;
    - sends ``b"pkt00001"`` followed by ``b"pkt00002"``;
    - sets the returned event;
    - closes both the connection and the listening socket.

    Args:
        host: Address to bind the listening socket to.

    Returns:
        ``(port, thread, sent_ready)``: the bound port, the server thread, and an
        event that is set once both ``sendall`` calls have returned.

    Raises:
        OSError: If the listening socket cannot be configured, bound, or put into
            listen mode. The socket is closed before the error propagates.
    """
    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.bind((host, 0))
        server.listen(1)
        port = server.getsockname()[1]
    except BaseException:
        # Setup failed before the server thread took ownership of the socket;
        # close it here so the descriptor does not leak.
        server.close()
        raise

    sent_ready = threading.Event()

    def run() -> None:
        conn = None
        try:
            # Blocks until a client connects; the thread is a daemon, so an
            # unconnected server does not keep the process alive.
            conn, _ = server.accept()
            conn.sendall(b"pkt00001")
            conn.sendall(b"pkt00002")
            sent_ready.set()
        finally:
            if conn is not None:
                conn.close()
            server.close()

    thread = threading.Thread(target=run, daemon=True)
    thread.start()
    return port, thread, sent_ready


def start_udp_push_sender(host: str, client_port: int) -> threading.Thread:
    """Send ``b"dat00001"`` then ``b"dat00002"`` to ``(host, client_port)`` in the background.

    A daemon thread opens a UDP socket and waits 50 ms before sending; the delay
    presumably lets the receiver enter its recvfrom() calls first (confirm). It
    then sends the two datagrams in order and closes the socket.

    Returns:
        The already-started sender thread.
    """
    destination = (host, client_port)

    def run() -> None:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sender:
            time.sleep(0.05)
            for payload in (b"dat00001", b"dat00002"):
                sender.sendto(payload, destination)

    worker = threading.Thread(target=run, daemon=True)
    worker.start()
    return worker


def start_udp_sequence_sender(host: str, client_port: int, count: int) -> threading.Thread:
    """Send *count* numbered datagrams to ``(host, client_port)`` in the background.

    Each payload is the zero-padded 8-digit ASCII sequence number: ``b"00000000"``,
    ``b"00000001"``, and so on. They are sent in increasing order from a daemon
    thread after an initial 50 ms pause. The sender socket is closed when done.

    Returns:
        The already-started sender thread.
    """
    destination = (host, client_port)

    def run() -> None:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sender:
            time.sleep(0.05)
            for payload in (f"{seq:08d}".encode() for seq in range(count)):
                sender.sendto(payload, destination)

    worker = threading.Thread(target=run, daemon=True)
    worker.start()
    return worker


@faultcore.packet_reorder(prob="100%", max_delay="100ms", window=2)
def recv_two_tcp(sock: socket.socket) -> tuple[bytes, bytes]:
    """Perform two ``recv(8)`` calls on *sock* under faultcore packet reordering.

    The caller (``run_tcp_case``) expects the two packets to come back swapped.

    Returns:
        The results of the first and second ``recv`` calls, in call order.
    """
    first = sock.recv(8)
    second = sock.recv(8)
    return first, second


@faultcore.packet_reorder(prob="100%", max_delay="100ms", window=2)
def recv_two_udp(sock: socket.socket) -> tuple[bytes, bytes]:
    """Receive two datagrams (up to 64 bytes each) under faultcore packet reordering.

    Sender addresses are discarded. The caller (``run_udp_case``) expects the
    payloads to come back swapped.

    Returns:
        The payloads of the first and second ``recvfrom`` calls, in call order.
    """
    payloads = [sock.recvfrom(64)[0] for _ in range(2)]
    return payloads[0], payloads[1]


@faultcore.packet_reorder(prob="100%", max_delay="100ms", window=2)
def recv_many_udp(sock: socket.socket, count: int) -> list[bytes]:
    """Receive *count* datagrams (up to 64 bytes each) under faultcore packet reordering.

    Each ``recvfrom`` blocks according to the socket's timeout setting. Sender
    addresses are discarded.

    Returns:
        The payloads in the order ``recvfrom`` delivered them.
    """
    return [sock.recvfrom(64)[0] for _ in range(count)]


def run_tcp_case(host: str) -> None:
    """Check that two TCP ``recv`` calls under packet_reorder return packets swapped.

    Connects to a one-shot push server on *host*, waits up to 1 s for it to
    finish sending, then reads both packets through ``recv_two_tcp``.

    Raises:
        RuntimeError: If the server does not signal within 1 s, or if the packets
            are not received as ``pkt00002`` then ``pkt00001``.
    """
    port, _server_thread, pushed = start_tcp_push_server(host)
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Fully blocking reads. NOTE(review): there is no timeout, so missing data
        # would block here indefinitely.
        sock.settimeout(None)
        sock.connect((host, port))
        # The server sets the event only after both sendall() calls have returned,
        # so both packets have been handed to the kernel before we read.
        if not pushed.wait(timeout=1):
            raise RuntimeError("tcp push server did not send packets in time")
        first, second = recv_two_tcp(sock)

    print(f"tcp recv order: {first!r} then {second!r}")
    if (first, second) != (b"pkt00002", b"pkt00001"):
        raise RuntimeError("tcp recv reorder failed: expected swapped order")


def run_udp_case(host: str) -> None:
    """Check that two UDP ``recvfrom`` calls under packet_reorder return datagrams swapped.

    Binds a client socket to an ephemeral port on *host*. A background sender
    pushes ``dat00001`` then ``dat00002`` to it, and both are read through
    ``recv_two_udp``.

    Raises:
        RuntimeError: If the datagrams are not received as ``dat00002`` then ``dat00001``.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as client:
        # Bind first so the sender has a concrete port to target.
        client.bind((host, 0))
        # NOTE(review): fully blocking with no timeout; a dropped datagram would
        # hang the probe.
        client.settimeout(None)
        _, client_port = client.getsockname()
        start_udp_push_sender(host, client_port)
        first, second = recv_two_udp(client)

    print(f"udp recvfrom order: {first!r} then {second!r}")
    if (first, second) != (b"dat00002", b"dat00001"):
        raise RuntimeError("udp recvfrom reorder failed: expected swapped order")


def run_udp_stress_case(host: str, count: int) -> None:
    """Check that packet_reorder swaps every consecutive pair across *count* datagrams.

    A background sender pushes sequence numbers ``0 .. count-1``. With a reorder
    window of 2, the expected arrival order is ``1, 0, 3, 2, ...``.

    Args:
        host: Address to bind the client socket to and to send from.
        count: Number of datagrams; must be an even integer >= 2.

    Raises:
        ValueError: If *count* is smaller than 2 or odd.
        RuntimeError: If the number of received datagrams differs from *count*,
            or any pair is not swapped.
    """
    valid_count = count >= 2 and count % 2 == 0
    if not valid_count:
        raise ValueError("count must be an even integer >= 2")

    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as client:
        client.bind((host, 0))
        client.settimeout(None)
        start_udp_sequence_sender(host, client.getsockname()[1], count)
        received = recv_many_udp(client, count)

    if len(received) != count:
        raise RuntimeError(f"udp stress recv mismatch: expected {count}, got {len(received)}")

    for lead in range(0, count, 2):
        trail = lead + 1
        # Inside each pair the higher sequence number must arrive first.
        want_lead = f"{trail:08d}".encode()
        want_trail = f"{lead:08d}".encode()
        got_lead, got_trail = received[lead], received[trail]
        if (got_lead, got_trail) != (want_lead, want_trail):
            raise RuntimeError(
                "udp stress reorder failed at pair "
                f"{lead}//{trail}: got {got_lead!r},{got_trail!r} expected {want_lead!r},{want_trail!r}"
            )
    print(f"udp stress reorder: PASS (count={count})")


def main() -> int:
    """Entry point for the downlink packet-reorder integration probe.

    Parses the command line and ensures the faultcore config shared-memory
    segment exists. It then runs the TCP, UDP and/or UDP-stress cases selected
    by ``--mode``.

    Returns:
        Process exit status: ``0`` when every selected case passed, ``1`` when
        shared-memory setup or any case raised. Invalid command-line arguments
        are handled by argparse, which exits on its own.
    """
    parser = argparse.ArgumentParser(description="FaultCore downlink reorder integration probe")
    parser.add_argument("--host", default="127.0.0.1", help="bind host for local test servers")
    parser.add_argument("--port", type=int, default=9000, help="unused, kept for integration runner compatibility")
    parser.add_argument("--mode", choices=["tcp", "udp", "stress", "all"], default="all")
    parser.add_argument("--count", type=int, default=100, help="packet count for stress mode (must be even)")
    args = parser.parse_args()

    print(f"[{datetime.now().isoformat()}] reorder downlink integration mode={args.mode} host={args.host}")

    # Top-level boundary: any failure (including shared-memory setup, e.g. no
    # writable /dev/shm) is reported as a single "ERROR: ..." line with exit
    # status 1, rather than escaping as a raw traceback.
    try:
        shm_name = ensure_shm_ready()
        print(f"using shm: {shm_name}")

        if args.mode in {"tcp", "all"}:
            run_tcp_case(args.host)
        if args.mode in {"udp", "all"}:
            run_udp_case(args.host)
        if args.mode in {"stress", "all"}:
            run_udp_stress_case(args.host, args.count)
    except Exception as exc:  # noqa: BLE001
        print(f"ERROR: {exc}")
        return 1

    print("reorder downlink integration: PASS")
    return 0


if __name__ == "__main__":
    # main() returns the exit status; SystemExit hands it to the interpreter.
    raise SystemExit(main())
