#pragma once

#include <chrono>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <tuple>
#include <vector>

using std::tuple;
using std::vector;

using T = float;
using T3 = std::tuple<T, T, T>;
using TSphere = std::tuple<T3, T>;         // (center, radius)
using TCapsule = std::tuple<T3, T3, T>;    // (segment end a, segment end b, radius)
using TSegment = std::tuple<T3, T3>;       // (end a, end b)
using TTriangle = std::tuple<T3, T3, T3>;  // (vertex a, vertex b, vertex c)

constexpr T ZERO = static_cast<T>(0.0);
constexpr T ONE = static_cast<T>(1.0);
// In this header: threshold on a segment's squared length below which
// closest_pt_segment_segment treats that segment as a single point.
constexpr T EPSILON = static_cast<T>(1e-6);

// Forced-inlining hints (presumably S = small helpers, L = larger ones).
// __forceinline only exists on MSVC-compatible compilers, so map the macros to
// the GCC/Clang attribute elsewhere and fall back to plain `inline` otherwise.
// The non-MSVC expansions include `inline` so the functions defined in this
// header may be included from several translation units.
#if defined(_MSC_VER)
#define INLINE_S __forceinline
#define INLINE_L __forceinline
#elif defined(__GNUC__) || defined(__clang__)
#define INLINE_S inline __attribute__((always_inline))
#define INLINE_L inline __attribute__((always_inline))
#else
#define INLINE_S inline
#define INLINE_L inline
#endif

namespace by_copy {

	// Variant in which every geometric helper receives its Vector3 arguments by value.

	// Three-component vector of T. Trivially copyable; copy and move operations
	// are the implicitly generated member-wise ones (Rule of Zero).
	struct Vector3 {
		T x;
		T y;
		T z;

		INLINE_S Vector3() : x(ZERO), y(ZERO), z(ZERO) {}
		INLINE_S Vector3(T _x, T _y, T _z) : x(_x), y(_y), z(_z) {}
	};

	// Component-wise sum a + b.
	INLINE_S Vector3 operator+(Vector3 a, Vector3 b) {
		return Vector3(a.x + b.x, a.y + b.y, a.z + b.z);
	}

	// Component-wise difference a - b.
	INLINE_S Vector3 operator-(Vector3 a, Vector3 b) {
		return Vector3(a.x - b.x, a.y - b.y, a.z - b.z);
	}

	// Scales every component of a by v.
	INLINE_S Vector3 operator*(Vector3 a, T v) {
		return Vector3(a.x * v, a.y * v, a.z * v);
	}

	// Dot product of a and b.
	INLINE_S T dot(Vector3 a, Vector3 b) {
		return a.x * b.x + a.y * b.y + a.z * b.z;
	}

	// Cross product a x b.
	INLINE_S Vector3 cross(Vector3 a, Vector3 b) {
		return Vector3(
			a.y * b.z - a.z * b.y,
			-(a.x * b.z - a.z * b.x),
			a.x * b.y - a.y * b.x
		);
	}

	// Returns `value` limited to [min, max]; the lower bound is checked first.
	INLINE_S T clamp(T value, T min, T max) {
		if (value <= min) {
			return min;
		}
		if (value >= max) {
			return max;
		}
		return value;
	}

	// clamp(value, 0, 1).
	INLINE_S T clamp01(T value) {
		return clamp(value, ZERO, ONE);
	}

	// Squared distance from point c to the segment [a, b].
	// A degenerate segment (a == b) yields |c - a|^2 via the first early-out,
	// so the division below never sees f == 0.
	INLINE_L T sq_dist_point_segment(Vector3 a, Vector3 b, Vector3 c) {
		auto ab = b - a;
		auto ac = c - a;
		auto bc = c - b;

		// Projection of c onto ab, scaled by |ab|^2.
		auto e = dot(ac, ab);
		if (e <= ZERO) {
			return dot(ac, ac);  // closest point is a
		}

		auto f = dot(ab, ab);
		if (e >= f) {
			return dot(bc, bc);  // closest point is b
		}

		// Closest point is interior: |ac|^2 minus squared projected length.
		return dot(ac, ac) - e * e / f;
	}

	// True when the sphere (sphere_center, sphere_rad) and the capsule
	// (segment [capsule_a, capsule_b] swept by capsule_rad) overlap.
	// Touching counts as overlapping, matching test_sphere_sphere and
	// test_capsule_capsule.
	INLINE_S bool test_sphere_capsule(Vector3 sphere_center, T sphere_rad, Vector3 capsule_a, Vector3 capsule_b, T capsule_rad) {
		auto dist_sq = sq_dist_point_segment(capsule_a, capsule_b, sphere_center);
		auto rad = sphere_rad + capsule_rad;
		return dist_sq <= rad * rad;
	}

	// True when the two spheres overlap or touch.
	INLINE_S bool test_sphere_sphere(Vector3 sphere_center_a, T sphere_rad_a, Vector3 sphere_center_b, T sphere_rad_b) {
		auto diff = sphere_center_b - sphere_center_a;
		auto dist2 = dot(diff, diff);
		auto rad = sphere_rad_a + sphere_rad_b;
		return dist2 <= rad * rad;
	}

	// Closest points between segments S1 = [p1, q1] and S2 = [p2, q2].
	// Returns (squared distance, s, t, c1, c2), where
	// c1 = p1 + s * (q1 - p1) and c2 = p2 + t * (q2 - p2), with s, t in [0, 1].
	// Segments whose squared length is <= EPSILON are treated as points.
	INLINE_L std::tuple<T, T, T, Vector3, Vector3> closest_pt_segment_segment(
		Vector3 p1, Vector3 q1, Vector3 p2, Vector3 q2)
	{
		T s = ZERO;
		T t = ZERO;

		auto d1 = q1 - p1;
		auto d2 = q2 - p2;
		auto r = p1 - p2;
		auto a = dot(d1, d1);  // squared length of S1
		auto e = dot(d2, d2);  // squared length of S2
		auto f = dot(d2, r);

		if (a <= EPSILON && e <= EPSILON) {
			// Both segments degenerate into points.
			return std::make_tuple(dot(r, r), s, t, p1, p2);
		}

		if (a <= EPSILON) {
			// Only S1 degenerates: s = 0, so t = f / e.
			t = clamp01(f / e);
		}
		else {
			auto c = dot(d1, r);
			if (e <= EPSILON) {
				// Only S2 degenerates: t = 0, so s = -c / a.
				s = clamp01(-c / a);
			}
			else {
				auto b = dot(d1, d2);
				auto denom = a * e - b * b;

				// Non-parallel: closest point on line 1 to line 2, clamped to S1.
				// Parallel (denom == 0): keep the arbitrary choice s = 0.
				if (denom != ZERO) {
					s = clamp01((b * f - c * e) / denom);
				}

				// Point on line 2 closest to S1(s). This must run for the
				// non-parallel case as well, otherwise t stays 0 and c2 is
				// always p2.
				t = (b * s + f) / e;

				// If t left [0, 1], clamp it and recompute s for the clamped t.
				if (t < ZERO) {
					t = ZERO;
					s = clamp01(-c / a);
				}
				else if (t > ONE) {
					t = ONE;
					s = clamp01((b - c) / a);
				}
			}
		}

		auto c1 = p1 + d1 * s;
		auto c2 = p2 + d2 * t;
		auto diff = c1 - c2;
		return std::make_tuple(dot(diff, diff), s, t, c1, c2);
	}

	// True when the two capsules overlap or touch.
	INLINE_S bool test_capsule_capsule(Vector3 capsule0_a, Vector3 capsule0_b, T capsule0_rad,
		Vector3 capsule1_a, Vector3 capsule1_b, T capsule1_rad)
	{
		auto closest = closest_pt_segment_segment(capsule0_a, capsule0_b, capsule1_a, capsule1_b);
		auto dist_sq = std::get<0>(closest);
		auto rad = capsule0_rad + capsule1_rad;
		return dist_sq <= rad * rad;
	}

	// One-sided segment/triangle test for segment [p, q] and triangle (a, b, c).
	// It only reports a hit when the segment crosses from the side that the
	// normal n = (b - a) x (c - a) points to (p on or in front of the plane)
	// to the other side (q on or behind it). Both endpoints touching the plane
	// count, as do hits on triangle edges. Degenerate triangles (n == 0) never hit.
	// Arguments are taken by value like every other helper in this namespace.
	INLINE_L bool test_segment_triangle(Vector3 p, Vector3 q, Vector3 a, Vector3 b, Vector3 c) {
		auto ab = b - a;
		auto ac = c - a;
		auto qp = p - q;

		auto n = cross(ab, ac);
		auto d = dot(qp, n);
		if (d <= ZERO) {
			// Segment parallel to the plane or running against the normal.
			return false;
		}

		// Plane-crossing parameter, scaled by d (division deferred).
		auto ap = p - a;
		auto t = dot(ap, n);
		if (t < ZERO || t > d) {
			return false;
		}

		// Barycentric coordinates of the crossing point, scaled by d.
		auto e = cross(qp, ap);
		auto v = dot(ac, e);
		if (v < ZERO || v > d) {
			return false;
		}

		auto w = -dot(ab, e);
		if (w < ZERO || v + w > d) {
			return false;
		}

		return true;
	}

	// Converts an (x, y, z) tuple into a Vector3.
	INLINE_S Vector3 ToVec3(T3 v) {
		return Vector3(std::get<0>(v), std::get<1>(v), std::get<2>(v));
	}

	struct Sphere {
		Vector3 center;
		T radius;
		Sphere(Vector3 c, T r) : center(c), radius(r) {}
	};

	// Segment [a, b] swept by `radius`.
	struct Capsule {
		Vector3 a, b;
		T radius;
		Capsule(Vector3 _a, Vector3 _b, T r) : a(_a), b(_b), radius(r) {}
	};

	struct Segment {
		Vector3 a, b;
		Segment(Vector3 _a, Vector3 _b) : a(_a), b(_b) {}
	};

	struct Triangle {
		Vector3 a, b, c;
		Triangle(Vector3 _a, Vector3 _b, Vector3 _c) : a(_a), b(_b), c(_c) {}
	};

	// Counts overlaps for every unordered sphere pair, every sphere/capsule pair,
	// every unordered capsule pair and every segment/triangle pair.
	// Returns (overlap count, elapsed wall time of the counting loops in ms).
	// Converting the tuple inputs into structs happens before the clock starts.
	// `inline` because this function is defined in a header.
	inline std::tuple<std::int64_t, std::int64_t> run_test(
		std::vector<TSphere> const& _spheres,
		std::vector<TCapsule> const& _capsules,
		std::vector<TSegment> const& _segments,
		std::vector<TTriangle> const& _triangles)
	{
		std::vector<Sphere> spheres;
		spheres.reserve(_spheres.size());
		for (auto const& sphere : _spheres) {
			spheres.emplace_back(ToVec3(std::get<0>(sphere)), std::get<1>(sphere));
		}

		std::vector<Capsule> capsules;
		capsules.reserve(_capsules.size());
		for (auto const& capsule : _capsules) {
			capsules.emplace_back(
				ToVec3(std::get<0>(capsule)),
				ToVec3(std::get<1>(capsule)),
				std::get<2>(capsule));
		}

		std::vector<Segment> segments;
		segments.reserve(_segments.size());
		for (auto const& segment : _segments) {
			segments.emplace_back(
				ToVec3(std::get<0>(segment)),
				ToVec3(std::get<1>(segment)));
		}

		std::vector<Triangle> triangles;
		triangles.reserve(_triangles.size());
		for (auto const& triangle : _triangles) {
			triangles.emplace_back(
				ToVec3(std::get<0>(triangle)),
				ToVec3(std::get<1>(triangle)),
				ToVec3(std::get<2>(triangle)));
		}

		std::int64_t num_overlaps = 0;

		// steady_clock is monotonic; high_resolution_clock may alias the
		// adjustable system_clock and is unsuitable for measuring intervals.
		auto const start = std::chrono::steady_clock::now();

		// All sphere-sphere intersections
		for (std::size_t i = 0; i < spheres.size(); ++i) {
			auto a = spheres[i];
			for (std::size_t j = i + 1; j < spheres.size(); ++j) {
				auto const& b = spheres[j];
				if (test_sphere_sphere(a.center, a.radius, b.center, b.radius)) {
					++num_overlaps;
				}
			}
		}

		// All sphere-capsule
		for (std::size_t i = 0; i < spheres.size(); ++i) {
			auto s = spheres[i];
			for (std::size_t j = 0; j < capsules.size(); ++j) {
				auto const& c = capsules[j];
				if (test_sphere_capsule(s.center, s.radius, c.a, c.b, c.radius)) {
					++num_overlaps;
				}
			}
		}

		// All capsule-capsule
		for (std::size_t i = 0; i < capsules.size(); ++i) {
			auto c1 = capsules[i];
			for (std::size_t j = i + 1; j < capsules.size(); ++j) {
				auto const& c2 = capsules[j];
				if (test_capsule_capsule(c1.a, c1.b, c1.radius, c2.a, c2.b, c2.radius)) {
					++num_overlaps;
				}
			}
		}

		// All segment-triangle
		for (std::size_t i = 0; i < segments.size(); ++i) {
			auto s = segments[i];
			for (std::size_t j = 0; j < triangles.size(); ++j) {
				auto const& t = triangles[j];
				if (test_segment_triangle(s.a, s.b, t.a, t.b, t.c)) {
					++num_overlaps;
				}
			}
		}

		auto const end = std::chrono::steady_clock::now();
		std::int64_t const elapsed_ms =
			std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();

		return std::make_tuple(num_overlaps, elapsed_ms);
	}

} // namespace by_copy





