"""
CONSIDER A SPHERICAL HUMAN:
Packing Density Bounds for Coffins, Clown Cars, and Coach Class

In the grand tradition of physics, we approximate the human body as a sphere
and investigate how many fit inside real-world enclosures.

This is a serious realization of a joke idea.
"""

import numpy as np
import json
import os

import pyvista as pv
import matplotlib.pyplot as plt
import matplotlib as mpl

from containers import (
    VENUES, get_volume, get_bounding_box,
    pack_spheres_fcc_in_venue, pack_spheres_cubic_in_venue,
    make_container_mesh, CONTAINMENT_CHECKS,
)

OUTPUT_DIR = "results_spherical"

# PyVista renders off-screen: every scene is written straight to an image file.
pv.OFF_SCREEN = True

# Create the results directory up front; every experiment writes into it.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Matplotlib house style: serif text, bold titles, no top/right spines,
# light dashed grid lines, an off-white canvas and dark-grey "ink".
plt.rcParams.update({
    'font.family': 'serif',
    'font.size': 11,
    'axes.titlesize': 14,
    'axes.titleweight': 'bold',
    'axes.spines.top': False,
    'axes.spines.right': False,
    'grid.alpha': 0.4,
    'grid.linestyle': '--',
})
plt.rcParams.update({
    key: '#FAFAFA'
    for key in ('figure.facecolor', 'axes.facecolor', 'savefig.facecolor')
})
plt.rcParams.update({
    key: '#333333'
    for key in ('axes.edgecolor', 'xtick.color', 'ytick.color',
                'axes.labelcolor', 'text.color')
})


# ============================================================
# Spherical Human Models
# ============================================================

# The Meatball: volume-equivalent sphere
# Average human body volume = 65.22 L (direct measurement via underwater weighing)
# Source: Nagao et al. 1995, BioNumbers BNID 109718
HUMAN_VOLUME_GLOBAL = 0.06522  # m^3

# The Freedom Sphere: American adult
# Average American adult mass = 83.6 kg (CDC NHANES 2015-2018, combined sexes)
# Tissue density ~985 kg/m^3
AMERICAN_MASS = 83.6     # kg
TISSUE_DENSITY = 985.0   # kg/m^3
AMERICAN_VOLUME = AMERICAN_MASS / TISSUE_DENSITY  # 0.0849 m^3

# The Hamster Ball: height-bounding sphere
# Global average human height ~1.70 m (NCD Risk Factor Collaboration, 2016)
HUMAN_HEIGHT = 1.70  # m

# The Junior: child (age 8, 50th percentile)
# Mass = 25.6 kg (CDC growth charts), height = 128 cm
# Child tissue density ~1010 kg/m^3 (less adipose tissue)
CHILD_MASS = 25.6        # kg
CHILD_DENSITY = 1010.0   # kg/m^3
CHILD_VOLUME = CHILD_MASS / CHILD_DENSITY  # 0.02535 m^3
CHILD_HEIGHT = 1.28      # m


def _radius_from_volume(volume):
    """Return the radius (m) of a sphere whose volume is `volume` (m^3)."""
    return (3 * volume / (4 * np.pi)) ** (1/3)


def _sphere_model(radius, description, color, derivation):
    """Build one sphere-model record.

    `description` may contain an ``{r...}`` replacement field, which is filled
    with `radius` so the text always reports the radius actually used.
    """
    return {
        "radius": radius,
        "description": description.format(r=radius),
        "color": color,
        "derivation": derivation,
    }


# Derivation notes are generated from the constants above so they cannot
# drift out of sync with the values the experiments actually use.
SPHERE_MODELS = {
    "The Meatball": _sphere_model(
        _radius_from_volume(HUMAN_VOLUME_GLOBAL),
        "Volume-equivalent adult sphere (r={r:.3f}m). A dense beach ball.",
        "#E74C3C",
        f"V_body = {HUMAN_VOLUME_GLOBAL * 1000:.2f} L, r = (3V/4pi)^(1/3)",
    ),
    "The Hamster Ball": _sphere_model(
        HUMAN_HEIGHT / 2,
        "Height-bounding sphere (r={r:.3f}m). Full-body hamster ball.",
        "#F39C12",
        f"d = human height = {HUMAN_HEIGHT:.2f}m",
    ),
    "The Freedom Sphere": _sphere_model(
        _radius_from_volume(AMERICAN_VOLUME),
        "American adult volume sphere (r={r:.3f}m). Supersized.",
        "#3498DB",
        f"m = {AMERICAN_MASS:.1f} kg (CDC NHANES), rho = {TISSUE_DENSITY:.0f} kg/m^3",
    ),
    "The Junior": _sphere_model(
        _radius_from_volume(CHILD_VOLUME),
        "Child (age 8) volume sphere (r={r:.3f}m). Compact.",
        "#27AE60",
        f"m = {CHILD_MASS:.1f} kg (CDC 50th percentile), rho = {CHILD_DENSITY:.0f} kg/m^3",
    ),
}

# Kepler limit: densest possible packing of equal spheres, pi / (3*sqrt(2))
KEPLER_DENSITY = np.pi / (3 * np.sqrt(2))  # ~0.7405


# ============================================================
# Rendering
# ============================================================

def render_sphere_packing(centers, radius, venue, filename, title="",
                          window_size=(1400, 900), max_render=5000,
                          export_3d=True):
    """Render spheres packed in a real container geometry to an image file.

    Also exports a .ply 3D model file for manual inspection.

    Args:
        centers: Indexable sequence of sphere centre coordinates.
        radius: Sphere radius (m).
        venue: Venue spec dict (as stored in ``containers.VENUES``).
        filename: Path of the screenshot to write.
        title: Optional title drawn on the image.
        window_size: Off-screen render window size in pixels.
        max_render: Maximum number of spheres drawn; larger packings are
            evenly subsampled.
        export_3d: When True and at most 500 spheres were drawn, also write
            the drawn spheres merged with the container mesh to a ``.ply``
            file next to `filename`. Export failures are printed and ignored.
    """
    n_total = len(centers)
    if n_total <= max_render:
        indices = range(n_total)
    else:
        # Evenly subsample so very large packings stay renderable.
        indices = np.linspace(0, n_total - 1, max_render, dtype=int)

    # Lower sphere mesh resolution for large counts to keep rendering fast
    if n_total > 2000:
        theta_res, phi_res = 6, 6
    elif n_total > 500:
        theta_res, phi_res = 8, 8
    else:
        theta_res, phi_res = 16, 16

    palette = [
        '#D4956A', '#E8C4A8', '#C68642', '#FFDBAC', '#8D5524',
        '#F1C27D', '#A0785A', '#E8B89D', '#C4867A', '#D2A679',
    ]

    sphere_meshes = []
    pl = pv.Plotter(off_screen=True, window_size=window_size)
    # Release the render window even if meshing or the screenshot fails.
    try:
        pl.set_background('#F0EDE8')

        for i, idx in enumerate(indices):
            sphere = pv.Sphere(radius=radius, center=centers[idx],
                               theta_resolution=theta_res, phi_resolution=phi_res)
            c = palette[i % len(palette)]
            pl.add_mesh(sphere, color=c, smooth_shading=True, specular=0.5, ambient=0.2)
            sphere_meshes.append(sphere)

        container_mesh = make_container_mesh(venue)
        pl.add_mesh(container_mesh, color=venue.get("color", "#333333"),
                    style='wireframe', line_width=2.5, opacity=0.7)

        if title:
            pl.add_title(title, font_size=12, color='#333333')

        pl.reset_camera()
        pl.camera.azimuth = 35
        pl.camera.elevation = 25

        # Bounding box is needed for both the zoom heuristic and the lights;
        # compute it once.
        bb = get_bounding_box(venue)
        Lx, Ly, Lz = bb

        # Zoom in more for elongated venues (747, King's Chamber) so spheres are visible
        aspect = max(bb) / min(bb) if min(bb) > 0 else 1
        if aspect > 5:
            pl.camera.zoom(1.8)
        elif aspect > 3:
            pl.camera.zoom(1.2)
        else:
            pl.camera.zoom(0.85)

        pl.add_light(pv.Light(position=(Lx*2, Ly*2, Lz*3), intensity=0.7))
        pl.add_light(pv.Light(position=(-Lx, -Ly, Lz*2), intensity=0.3))

        pl.screenshot(filename)
    finally:
        pl.close()

    # Export 3D model (PLY) for manual inspection (skip if too many spheres).
    # Best-effort: a failed export is reported but does not abort the run.
    if export_3d and sphere_meshes and len(sphere_meshes) <= 500:
        ply_path = filename.rsplit('.', 1)[0] + '.ply'
        try:
            combined = sphere_meshes[0]
            for m in sphere_meshes[1:]:
                combined = combined.merge(m)
            combined = combined.merge(container_mesh)
            combined.save(ply_path)
        except Exception as e:
            print(f"    (3D export failed: {e})")


# ============================================================
# Experiments
# ============================================================

def experiment_1_venue_specs():
    """Print each venue's type, description, dimensions and capacity.

    For every entry in VENUES this reports the bounding box, the true
    container volume (from get_volume), the rated capacity ("?" when the
    venue defines none), its source ("N/A" when missing), and the fraction
    of the bounding box that the real shape fills.
    """
    banner = "=" * 70
    print(banner)
    print("EXPERIMENT 1: Real Venue Geometries")
    print(banner)

    for venue_name, spec in VENUES.items():
        actual_vol = get_volume(spec)
        dims = get_bounding_box(spec)
        capacity = spec.get("rated_capacity", "?")

        print(f"\n  {venue_name}")
        print(f"    Type: {spec['type']}")
        print(f"    {spec['description']}")
        print(f"    Bounding box: {dims[0]:.2f} x {dims[1]:.2f} x {dims[2]:.2f}m")
        print(f"    Actual volume: {actual_vol:.2f} m3")
        print(f"    Rated capacity: {capacity} humans")
        print(f"    Source: {spec.get('source', 'N/A')}")

        # Shape fill = true volume / bounding-box volume (1.0 means the
        # venue fills its bounding box exactly).
        box_vol = dims[0] * dims[1] * dims[2]
        fill_ratio = actual_vol / box_vol if box_vol > 0 else 0
        print(f"    Box volume: {box_vol:.2f} m3  (shape fill: {fill_ratio:.0%})")


def experiment_2_packing_primary():
    """Pack Meatball spherical humans into all venues with real geometry.

    Each venue is packed with both an FCC and a simple-cubic lattice and the
    larger count is kept. Prints a summary table, saves a two-panel chart
    (counts per lattice, achieved packing density) to
    ``packing_real_geometry.png`` and writes ``packing_results.json``.

    Returns:
        dict: venue name -> {"fcc", "cubic", "best", "volume_m3", "density",
        "kepler_pct", "rated_capacity", "ratio_to_rated"}. ``density`` is the
        fraction of the venue volume occupied by spheres; ``kepler_pct`` is
        that density as a fraction of the Kepler limit; ``ratio_to_rated`` is
        None when the venue has no positive rated capacity.
    """
    print("\n" + "=" * 70)
    print("EXPERIMENT 2: Meatball Packing in Real Geometries")
    print("=" * 70)

    r = SPHERE_MODELS["The Meatball"]["radius"]
    sphere_vol = (4/3) * np.pi * r**3
    # Packing fraction of a simple-cubic lattice: pi/6 (~52.4%).
    simple_cubic_density = np.pi / 6
    # Venue names can exceed the nominal 28-char column
    # ("King's Chamber (Great Pyramid)" is 30); widen it to keep rows aligned.
    name_w = max([28] + [len(n) for n in VENUES])

    all_results = {}

    print(f"\n  The Meatball: r={r:.4f}m  (volume = {sphere_vol*1000:.1f} L)")
    print(f"  {'Venue':<{name_w}} {'FCC':>6} {'Cubic':>6} {'Best':>6} {'Vol':>8} {'Density':>8} {'Rated':>6} {'Ratio':>6}")
    print("  " + "-" * (name_w + 54))

    for name, venue in VENUES.items():
        vol = get_volume(venue)
        count_fcc, _ = pack_spheres_fcc_in_venue(r, venue)
        count_cubic, _ = pack_spheres_cubic_in_venue(r, venue)
        best = max(count_fcc, count_cubic)
        density = best * sphere_vol / vol if vol > 0 else 0
        kepler_pct = density / KEPLER_DENSITY
        rated = venue.get("rated_capacity", 0)
        # No positive rated capacity -> no meaningful ratio. Use None (JSON
        # null) rather than float('inf'): json.dump would emit the
        # non-standard token `Infinity`, which strict JSON parsers reject.
        ratio = best / rated if rated > 0 else None

        all_results[name] = {
            "fcc": count_fcc,
            "cubic": count_cubic,
            "best": best,
            "volume_m3": round(vol, 2),
            "density": round(density, 4),
            "kepler_pct": round(kepler_pct, 4),
            "rated_capacity": rated,
            "ratio_to_rated": round(ratio, 1) if ratio is not None else None,
        }

        ratio_str = f"{ratio:>5.1f}x" if ratio is not None else f"{'-':>6}"
        print(f"  {name:<{name_w}} {count_fcc:>6} {count_cubic:>6} {best:>6} {vol:>7.2f}m3 {density:>7.1%} {rated:>6} {ratio_str}")

    # Bar chart: packing counts
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 8))

    names = list(all_results.keys())
    short_names = [n.split("(")[0].strip() for n in names]
    fcc_counts = [all_results[n]["fcc"] for n in names]
    cubic_counts = [all_results[n]["cubic"] for n in names]

    x = np.arange(len(names))
    w = 0.35
    ax1.barh(x - w/2, fcc_counts, w, label='FCC Lattice', color='#3498DB', edgecolor='white')
    ax1.barh(x + w/2, cubic_counts, w, label='Simple Cubic', color='#E74C3C', edgecolor='white')
    ax1.set_yticks(x)
    ax1.set_yticklabels(short_names, fontsize=9)
    ax1.set_xlabel("Spherical Humans Packed")
    ax1.set_xscale('symlog', linthresh=1)
    ax1.set_title(f"Meatball Packing (r={r:.3f}m)\nReal Container Geometries",
                  fontsize=14, fontweight='bold')
    ax1.legend()
    for i, n in enumerate(names):
        best = all_results[n]["best"]
        ax1.text(best + 2, i, str(best), va='center', fontsize=8, color='#555')

    # Density comparison; bar colour encodes the fraction of the Kepler limit
    densities = [all_results[n]["density"] * 100 for n in names]
    colors = plt.cm.RdYlGn([d / (KEPLER_DENSITY * 100) for d in densities])
    ax2.barh(x, densities, color=colors, edgecolor='white', height=0.6)
    ax2.axvline(x=KEPLER_DENSITY * 100, color='#E74C3C', linestyle='--', linewidth=2,
                label=f'Kepler limit ({KEPLER_DENSITY:.1%})')
    ax2.axvline(x=simple_cubic_density * 100, color='#3498DB', linestyle=':', linewidth=1.5,
                label=f'Simple cubic ({simple_cubic_density:.1%})')
    ax2.set_yticks(x)
    ax2.set_yticklabels(short_names, fontsize=9)
    ax2.set_xlabel("Packing Density (%)")
    ax2.set_title("Achieved Packing Density\n(real geometry vs theoretical limits)",
                  fontsize=14, fontweight='bold')
    ax2.legend(fontsize=8)

    plt.tight_layout(w_pad=3)
    plt.savefig(f"{OUTPUT_DIR}/packing_real_geometry.png", dpi=200, bbox_inches='tight')
    plt.close()
    print("\n  Saved packing_real_geometry.png")

    with open(f"{OUTPUT_DIR}/packing_results.json", "w") as f:
        json.dump(all_results, f, indent=2)

    return all_results


def experiment_3_all_models():
    """Compare all sphere models across all venues — the big heatmap.

    For each (venue, model) pair the better of the FCC and simple-cubic
    lattice counts is kept. Prints the model summary and a count table,
    saves ``heatmap_all_models.png`` (colour = log10(count + 1)) and writes
    the counts to ``all_models_results.json``.

    Returns:
        dict: venue name -> {model name -> best sphere count}.
    """
    print("\n" + "=" * 70)
    print("EXPERIMENT 3: All Sphere Models x All Venues")
    print("=" * 70)

    model_names = list(SPHERE_MODELS.keys())
    venue_names = list(VENUES.keys())

    # Print sphere model summary
    print("\n  Sphere Models:")
    for mname, model in SPHERE_MODELS.items():
        r = model["radius"]
        v = (4/3) * np.pi * r**3
        print(f"    {mname:<22} r={r:.4f}m  V={v*1000:.1f}L  ({model['derivation']})")
    print()

    results = {}
    for venue_name in venue_names:
        venue = VENUES[venue_name]
        results[venue_name] = {}
        for model_name in model_names:
            r = SPHERE_MODELS[model_name]["radius"]
            count_fcc, _ = pack_spheres_fcc_in_venue(r, venue)
            count_cubic, _ = pack_spheres_cubic_in_venue(r, venue)
            results[venue_name][model_name] = max(count_fcc, count_cubic)

    # Print table. Widths adapt to the labels: with a fixed 16-char column,
    # model names of 16+ chars ("The Hamster Ball", "The Freedom Sphere")
    # ran into their neighbours with no separating space.
    venue_w = max([28] + [len(v) for v in venue_names])
    col_w = max([16] + [len(m) + 2 for m in model_names])
    header = f"  {'Venue':<{venue_w}}" + "".join(f"{m:>{col_w}}" for m in model_names)
    print(header)
    print("  " + "-" * (venue_w + col_w * len(model_names)))
    for vn in venue_names:
        row = f"  {vn:<{venue_w}}" + "".join(f"{results[vn][mn]:>{col_w}}" for mn in model_names)
        print(row)

    # Heatmap on a log scale so single-digit and thousand-sphere cells are
    # both distinguishable (+1 keeps zero counts finite).
    matrix = np.array([[results[v][m] for m in model_names] for v in venue_names], dtype=float)
    log_matrix = np.log10(matrix + 1)

    fig, ax = plt.subplots(figsize=(12, 8))
    im = ax.imshow(log_matrix, cmap='YlOrRd', aspect='auto')

    for i in range(len(venue_names)):
        for j in range(len(model_names)):
            val = int(matrix[i, j])
            # White text on the darkest cells, dark text elsewhere.
            color = 'white' if log_matrix[i, j] > log_matrix.max() * 0.65 else '#333'
            ax.text(j, i, str(val), ha='center', va='center', fontsize=10,
                    fontweight='bold', color=color)

    ax.set_xticks(range(len(model_names)))
    ax.set_xticklabels(model_names, rotation=25, ha='right', fontsize=11)
    ax.set_yticks(range(len(venue_names)))
    short = [n.split("(")[0].strip() for n in venue_names]
    ax.set_yticklabels(short, fontsize=10)
    ax.set_title("Spherical Humans Packed: All Models x All Venues\n(real container geometries, log color scale)",
                 fontsize=15, fontweight='bold', pad=15)

    cbar = plt.colorbar(im, ax=ax, shrink=0.8)
    cbar.set_label("log10(count + 1)", fontsize=10)

    plt.tight_layout()
    plt.savefig(f"{OUTPUT_DIR}/heatmap_all_models.png", dpi=200, bbox_inches='tight')
    plt.close()
    print("\n  Saved heatmap_all_models.png")

    # Save full results
    with open(f"{OUTPUT_DIR}/all_models_results.json", "w") as f:
        json.dump(results, f, indent=2)

    return results


def _best_lattice_packing(radius, venue):
    """Pack spheres of `radius` into `venue` with FCC and simple-cubic lattices.

    Returns:
        (count, centers, method) for whichever lattice fits more spheres;
        FCC wins ties. `method` is "FCC" or "Cubic".
    """
    count_fcc, centers_fcc = pack_spheres_fcc_in_venue(radius, venue)
    count_cubic, centers_cubic = pack_spheres_cubic_in_venue(radius, venue)
    if count_fcc >= count_cubic:
        return count_fcc, centers_fcc, "FCC"
    return count_cubic, centers_cubic, "Cubic"


def _scene_stem(venue_name, suffix=""):
    """File-name stem for a venue: drop any parenthetical, spaces -> '_'."""
    return venue_name.split("(")[0].strip().replace(" ", "_") + suffix


def experiment_4_visualizations():
    """Render 3D sphere packing in real container shapes.

    Renders the Meatball packing for every venue, plus Hamster Ball packings
    for a few elongated/odd venues and Junior (child) packings for the VW
    Beetle and the coffin. Venues where nothing fits are reported and
    skipped; hard-coded venue names missing from VENUES are also skipped.
    """
    print("\n" + "=" * 70)
    print("EXPERIMENT 4: 3D Visualizations (Real Geometry)")
    print("=" * 70)

    r = SPHERE_MODELS["The Meatball"]["radius"]

    for venue_name, venue in VENUES.items():
        count, centers, method = _best_lattice_packing(r, venue)
        if count == 0:
            print(f"  {venue_name}: 0 spheres fit (container too small!)")
            continue
        render_sphere_packing(
            centers, r, venue,
            f"{OUTPUT_DIR}/scene_{_scene_stem(venue_name)}.png",
            title=f"{venue_name}: {count} spherical humans ({method})",
        )
        print(f"  {venue_name}: {count} ({method})")

    # Hamster ball in the Boeing and King's Chamber for comedy
    r_big = SPHERE_MODELS["The Hamster Ball"]["radius"]
    for venue_name in ["Boeing 747-400", "King's Chamber (Great Pyramid)", "ISS Destiny Lab"]:
        venue = VENUES.get(venue_name)
        if venue is None:
            print(f"  {venue_name} (hamster ball): venue not defined, skipped")
            continue
        count, centers, _ = _best_lattice_packing(r_big, venue)
        if count == 0:
            print(f"  {venue_name} (hamster ball): 0 fit")
            continue
        render_sphere_packing(
            centers, r_big, venue,
            f"{OUTPUT_DIR}/scene_{_scene_stem(venue_name, '_hamsterball')}.png",
            title=f"{venue_name}: {count} hamster-ball humans (r={r_big:.2f}m)",
        )
        print(f"  {venue_name} (hamster ball): {count}")

    # The Junior (child) in VW Beetle and Coffin
    r_child = SPHERE_MODELS["The Junior"]["radius"]
    for venue_name in ["Volkswagen Beetle (Classic)", "Standard Coffin"]:
        venue = VENUES.get(venue_name)
        if venue is None:
            print(f"  {venue_name} (The Junior): venue not defined, skipped")
            continue
        count, centers, _ = _best_lattice_packing(r_child, venue)
        if count == 0:
            print(f"  {venue_name} (The Junior): 0 fit")
            continue
        render_sphere_packing(
            centers, r_child, venue,
            f"{OUTPUT_DIR}/scene_{_scene_stem(venue_name, '_funsize')}.png",
            title=f"{venue_name}: {count} The Junior spherical children (r={r_child:.3f}m)",
        )
        print(f"  {venue_name} (The Junior): {count}")


def experiment_5_reality_check():
    """Compare spherical packing predictions to rated/world-record capacities.

    For every venue the best (FCC vs simple-cubic) count is computed for each
    sphere model. Prints a comparison table, saves ``reality_check.png``
    (rated capacity, world records, Meatball and Junior counts) and
    ``child_advantage_factor.png`` (Junior count / Meatball count), and writes
    the rows to ``reality_check.json``.

    Returns:
        list[dict]: one row per venue with keys "venue", "rated",
        "world_record" and one key per sphere-model name (best count).
    """
    print("\n" + "=" * 70)
    print("EXPERIMENT 5: Reality Check")
    print("(Spherical humans vs actual humans — who wins?)")
    print("=" * 70)

    comparisons = []

    for name, venue in VENUES.items():
        rated = venue.get("rated_capacity", 0)
        world_record = venue.get("world_record", None)

        row = {"venue": name, "rated": rated, "world_record": world_record}

        for model_name in SPHERE_MODELS:
            r = SPHERE_MODELS[model_name]["radius"]
            count_fcc, _ = pack_spheres_fcc_in_venue(r, venue)
            count_cubic, _ = pack_spheres_cubic_in_venue(r, venue)
            row[model_name] = max(count_fcc, count_cubic)

        comparisons.append(row)

    # Print table; the venue column widens for names longer than 28 chars.
    name_w = max([28] + [len(row["venue"]) for row in comparisons])
    print(f"\n  {'Venue':<{name_w}} {'Rated':>6} {'WR':>6} {'Meatball':>10} {'Hamster':>10} {'Freedom':>10} {'Child':>10}")
    print("  " + "-" * (name_w + 62))
    for row in comparisons:
        wr = str(row["world_record"]) if row["world_record"] else "-"
        print(f"  {row['venue']:<{name_w}} {row['rated']:>6} {wr:>6} "
              f"{row['The Meatball']:>10} {row['The Hamster Ball']:>10} "
              f"{row['The Freedom Sphere']:>10} {row['The Junior']:>10}")

    # Chart: Meatball vs Rated
    fig, ax = plt.subplots(figsize=(14, 8))

    venue_labels = [r["venue"].split("(")[0].strip() for r in comparisons]
    rated_vals = [r["rated"] for r in comparisons]
    meatball_vals = [r["The Meatball"] for r in comparisons]
    child_vals = [r["The Junior"] for r in comparisons]

    x = np.arange(len(venue_labels))
    w = 0.25

    ax.bar(x - w, rated_vals, w, label='Rated Capacity', color='#2ECC71', edgecolor='white')
    ax.bar(x, meatball_vals, w, label='Meatball (adult)', color='#E74C3C', edgecolor='white')
    ax.bar(x + w, child_vals, w, label='The Junior (child)', color='#3498DB', edgecolor='white')

    # Mark world records (drawn over the rated-capacity bar)
    for i, row in enumerate(comparisons):
        if row["world_record"]:
            ax.plot(i - w, row["world_record"], 'k*', markersize=15, zorder=5)
            ax.annotate(f'WR: {row["world_record"]}',
                        (i - w, row["world_record"]),
                        textcoords="offset points", xytext=(10, 10),
                        fontsize=8, fontweight='bold')

    ax.set_xticks(x)
    ax.set_xticklabels(venue_labels, rotation=35, ha='right', fontsize=9)
    ax.set_ylabel("Humans Packed")
    ax.set_yscale('symlog', linthresh=1)
    ax.set_title("Reality Check: Spherical vs Actual Humans\n"
                 "(rated capacity, world records, and spherical predictions)",
                 fontsize=14, fontweight='bold')
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig(f"{OUTPUT_DIR}/reality_check.png", dpi=200, bbox_inches='tight')
    plt.close()
    print("\n  Saved reality_check.png")

    # Child Advantage Factor (CAF) = Junior count / Meatball count:
    #   * inf  when children fit but no adult does;
    #   * None when nothing fits at all (0/0 is undefined, not infinite).
    caf = []
    for c, m in zip(child_vals, meatball_vals):
        if m > 0:
            caf.append(c / m)
        elif c > 0:
            caf.append(float('inf'))
        else:
            caf.append(None)
    # Infinite/undefined values can't be drawn as bar lengths: plot them as
    # zero-width bars and label them explicitly below.
    caf_plot = [v if v is not None and v != float('inf') else 0 for v in caf]
    colors = ['#333333' if v is None or v == float('inf')
              else '#E74C3C' if v > 2.5
              else '#F39C12' if v > 1.5
              else '#3498DB'
              for v in caf]

    fig, ax = plt.subplots(figsize=(12, 7))
    bars = ax.barh(venue_labels, caf_plot, color=colors, edgecolor='white', height=0.6)
    ax.axvline(x=1, color='#333', linewidth=1.5, linestyle='-')
    ax.set_xlabel("Child Advantage Factor (The Junior count / Meatball count)")
    ax.set_title("The Child Advantage Factor (CAF)\n"
                 "How many more spherical children fit than adults?",
                 fontsize=14, fontweight='bold')
    for bar, val, n_children in zip(bars, caf, child_vals):
        y_mid = bar.get_y() + bar.get_height() / 2
        # Labels for zero-width bars use dark text: white text would be
        # invisible against the near-white axes background.
        if val is None:
            ax.text(0.15, y_mid, 'n/a (nothing fits)', va='center', fontsize=9,
                    color='#333333')
        elif val == float('inf'):
            ax.text(0.15, y_mid, f'\u221e ({n_children} children, 0 adults)',
                    va='center', fontsize=9, fontweight='bold', color='#333333')
        elif val > 0:
            ax.text(bar.get_width() + 0.05, y_mid, f'{val:.1f}x',
                    va='center', fontsize=9)

    plt.tight_layout()
    plt.savefig(f"{OUTPUT_DIR}/child_advantage_factor.png", dpi=200, bbox_inches='tight')
    plt.close()
    print("  Saved child_advantage_factor.png")

    # Save results
    with open(f"{OUTPUT_DIR}/reality_check.json", "w") as f:
        json.dump(comparisons, f, indent=2, default=str)

    return comparisons


# ============================================================
# MAIN
# ============================================================

if __name__ == "__main__":
    banner = "=" * 70
    print(banner)
    print("  CONSIDER A SPHERICAL HUMAN")
    print("  Packing Density Bounds for Coffins, Clown Cars, and Coach Class")
    print(banner)
    print()

    # Summarise every sphere model (radius, diameter, volume in litres)
    # before running the experiments.
    print("Sphere Models:")
    for label, spec in SPHERE_MODELS.items():
        radius = spec["radius"]
        volume_l = (4/3) * np.pi * radius**3 * 1000
        print(f"  {label:<22}  r = {radius:.4f}m  d = {2 * radius:.3f}m  V = {volume_l:.1f}L")
    print()

    experiment_1_venue_specs()
    results = experiment_2_packing_primary()
    all_models = experiment_3_all_models()
    experiment_4_visualizations()
    reality = experiment_5_reality_check()

    print("\n" + banner)
    print("All experiments complete.")
    print(f"Results saved to ./{OUTPUT_DIR}/")
    print(banner)
