#version 430 core

// Compute shader: traces one camera ray per invocation through a GridSize^3
// voxel grid plus a few analytic primitives, shades it with a stochastic path
// tracer, and blends the sample into imgOutput with weight 1/Samples.

// NOTE(review): precision qualifiers have no semantic effect in desktop GLSL;
// they are accepted for OpenGL ES portability only.
precision highp float;
precision highp int;

// +/- infinity, used as "no hit" sentinels. Relies on the compiler folding
// 1.0/0.0 to IEEE infinity.
const float INF = 1.0/0.0;
const float NINF = -1.0/0.0;
const float PI = 3.14159265359;
// Edge length (in voxels) of the cubic grid stored in VoxData.
const int GridSize = 16;

// 8x8 invocations per work group; main() uses gl_GlobalInvocationID.xy as the pixel.
layout (local_size_x = 8, local_size_y = 8, local_size_z = 1) in;

// Accumulated colour; read back and blended with the new sample in main().
layout(rgba32f, binding = 0) uniform image2D imgOutput;
// Environment map, sampled with an equirectangular mapping in GetSky().
layout(rgba8, binding = 1) uniform image2D Skybox;
// Blue-noise texture used by rand() when NoiseType != 0.
layout(rgba8, binding = 2) uniform image2D BluNoise;
// Per-pixel first-hit data written by main(): xyz = normal, w = hit distance.
layout(rgba32f, binding = 3) uniform image2D HitData;

// Per-voxel record layout.
// NOTE(review): not referenced in the visible shader code; VoxData is read as
// a flat int array instead. Possibly a leftover or planned layout — confirm.
struct Voxel
{
    int state;
    vec3 Lighting;
    int Samples;
};

// Voxel occupancy grid: GridSize^3 ints addressed through GetIndex().
// Trace() treats any non-zero entry as a solid voxel.
// Shader-storage binding points are a separate namespace from image units,
// so binding 2 here does not collide with the BluNoise image.
layout(std430, binding = 2) buffer VoxData
{
    int data[];
}; 

// Pixel handled by this invocation; set at the start of main().
ivec2 id;

uniform mat4 _CameraToWorld;
uniform mat4 _CameraInverseProjection;
// Multiplier applied to skybox samples in GetSky(). (Name spelling must match
// the host-side uniform lookup, so it is left as-is.)
uniform float SkyIntesity;
// NOTE(review): declared but not read in the visible code.
uniform vec3 Sunlight;
// Offset added to the blue-noise lookup coordinate in rand().
uniform int NoiseShift;
// 0 selects the sine-hash RNG in rand(); any other value selects blue noise.
uniform int NoiseType;
// The new sample is blended into imgOutput with weight 1/Samples.
uniform int Samples;
// Float copy of id, used by the sine-hash RNG.
vec2 _Pixel;
// Seed supplied by the host; copied into _Seed at the start of main().
uniform float _InSeed;
// Running RNG state; rand() increments it by 1 after every call.
float _Seed;

// Image sizes cached once in main().
ivec2 SkyboxDim;
ivec2 OutputDim;
ivec2 BluDim;
// Per-axis fract-space boundary the current ray is heading toward
// (1 = positive direction, 0 = negative, 0.5 = zero component);
// written by CalculateAxisShift() before each Trace().
vec3 AxisShift;

//! MetaData
// Distance and normal of the first surface hit; stored to HitData by main().
float Dist = 0;
vec3 Normal = vec3(0);
// True until Shade() has recorded the first surface hit.
bool FirstHit = true;

// A ray being traced together with its path throughput.
struct Ray
{
    vec3 origin;
    vec3 direction;   // normalized for camera rays (see CreateCameraRay)
    vec3 energy;      // per-channel throughput; all-zero ends the path loop in main()
};

// Builds a ray with the given origin and direction and unit (white) throughput.
Ray CreateRay(vec3 origin, vec3 direction)
{
    return Ray(origin, direction, vec3(1.0));
}

// Closest intersection found so far along a ray, plus its surface material.
struct RayHit
{
    vec3 position;
    float distance;   // ray parameter t of the hit; INF while nothing has been hit
    vec3 normal;
    vec3 albedo;      // diffuse reflectance (used by the diffuse branch of Shade)
    vec3 specular;    // specular reflectance (used by the mirror branch of Shade)
    vec3 emission;    // radiance emitted by the surface, returned by Shade
};

// Returns an empty hit record: distance INF means "nothing hit yet";
// every other field is zero.
RayHit CreateRayHit()
{
    const vec3 zero = vec3(0.0);
    // Field order: position, distance, normal, albedo, specular, emission.
    return RayHit(zero, INF, zero, zero, zero, zero);
}

// Primary ray through the NDC point uv (each axis in [-1, 1]).
Ray CreateCameraRay(vec2 uv)
{
    // Camera position: the translation part of the camera-to-world matrix.
    vec4 eyeWorld = _CameraToWorld * vec4(0.0f, 0.0f, 0.0f, 1.0f);

    // Unproject to view space, then rotate (w = 0) into world space.
    vec4 viewDir = _CameraInverseProjection * vec4(uv, 0.0f, 1.0f);
    vec4 worldDir = _CameraToWorld * vec4(viewDir.xyz, 0.0f);

    return CreateRay(eyeWorld.xyz, normalize(worldDir.xyz));
}

// Environment radiance seen along ray.direction: equirectangular lookup in
// Skybox (polar angle -> row, azimuth -> column), scaled by SkyIntesity.
vec4 GetSky(Ray ray)
{
    // Rounding can push a normalized component just past +/-1, where acos is
    // undefined (NaN); clamp first.
    float theta = acos(clamp(ray.direction.y, -1.0, 1.0)); // [0, PI]

    // atan(y, x) is undefined when both arguments are 0 (looking straight up
    // or down). Any azimuth is valid there, so fall back to 0.
    float phi = 0.0;
    if (ray.direction.x != 0.0 || ray.direction.z != 0.0)
        phi = atan(ray.direction.x, -ray.direction.z); // [-PI, PI]
    if (phi < 0.0)
        phi += 2.0 * PI; // [0, 2*PI]

    // Normalize both angles to [0, 1] and map to texel coordinates.
    float normalizedTheta = theta / PI;
    float normalizedPhi = phi / (2.0 * PI);

    int pixelX = int(normalizedPhi * float(SkyboxDim.x));
    int pixelY = int(normalizedTheta * float(SkyboxDim.y));

    // The upper end of each range maps one past the last texel; clamp into the image.
    pixelX = clamp(pixelX, 0, SkyboxDim.x - 1);
    pixelY = clamp(pixelY, 0, SkyboxDim.y - 1);

    return imageLoad(Skybox, ivec2(pixelX, pixelY)) * SkyIntesity;
}

//! Voxel Traversal
void CalculateAxisShift(vec3 Dir)
{
    // sign() gives -1 / 0 / +1 per component, remapped to 0 / 0.5 / 1: the
    // fract-space voxel boundary the ray is moving toward on each axis.
    AxisShift = (sign(Dir) + vec3(1.0)) / 2.0;
}

// Lowers t to the parametric distance from fract-space coordinate O to the
// boundary Shift along direction component D, if that distance is positive
// and smaller than the current t.
void Time(float O,float D,float Shift, inout float t)
{
    // When not heading toward the upper boundary, aim 0.00001 past the target
    // so that floor() of the new position lands in the neighbouring voxel.
    float nudge = (Shift == 1.0) ? 0.0 : 0.00001;
    float candidate = -(O - Shift + nudge) / D;

    // Only strictly forward crossings count; keep the nearest one.
    if (candidate > 0.0 && candidate < t)
        t = candidate;
}

// Distance along ray.direction from CurrentPos to the nearest voxel boundary
// it will cross; INF if no axis yields a forward crossing.
// Requires AxisShift to have been set for this ray (CalculateAxisShift).
float CalcTime(Ray ray,vec3 CurrentPos)
{
    vec3 local = fract(CurrentPos); // position inside the current unit voxel
    float nearest = INF;
    for (int axis = 0; axis < 3; ++axis)
        Time(local[axis], ray.direction[axis], AxisShift[axis], nearest);
    return nearest;
}

// Flattens a voxel coordinate into an index into data[]: x varies slowest,
// z fastest (index = z + y*GridSize + x*GridSize^2).
int GetIndex(vec3 WorldCoord)
{
    float rowOffset = WorldCoord.y * GridSize;
    float sliceOffset = WorldCoord.x * GridSize * GridSize;
    return int(WorldCoord.z + rowOffset + sliceOffset);
}

//! Main Stochastic Rendering Functions
// Average of the three colour channels; Shade() uses it to weight lobe selection.
float energy(vec3 color)
{
    const vec3 channelWeights = vec3(1.0f / 3.0f);
    return dot(channelWeights, color);
}

// Next pseudo-random value for this invocation; advances _Seed by 1.
// NoiseType 0: sine hash of the pixel and seed, in [0, 1).
// Otherwise: red channel of the blue-noise texel at the seed-shifted pixel,
// wrapped into the texture.
float rand()
{
    float value;
    if (NoiseType != 0)
    {
        vec2 shifted = vec2(id) + _Seed + float(NoiseShift);
        ivec2 texelBase = ivec2(shifted);
        ivec2 texel = ivec2(mod(vec2(texelBase), vec2(BluDim)));
        value = imageLoad(BluNoise, texel).x;
    }
    else
    {
        float phase = _Seed / 100.0f * dot(_Pixel, vec2(12.9898f, 78.233f));
        value = fract(sin(phase) * 43758.5453f);
    }

    _Seed += 1.0f;
    return value;
}

// One standard-normal sample via the Box-Muller transform (consumes two rand() calls).
float RandomValueNormDistb()
{
    float theta = 2 * PI * rand();
    // rand() can return exactly 0 (sine hash at pixel (0,0) or with seed 0, or
    // a zero blue-noise texel). log(0) = -inf would make the sample inf/NaN
    // and poison the direction built from it, so bound the input away from 0.
    float u = max(rand(), 1.0e-7);
    float rho = sqrt(-2 * log(u));
    return rho * cos(theta);
}

// Direction uniformly distributed on the unit sphere: normalize three
// independent standard-normal samples (drawn in x, y, z order).
vec3 RandomDirection()
{
    vec3 g;
    g.x = RandomValueNormDistb();
    g.y = RandomValueNormDistb();
    g.z = RandomValueNormDistb();
    return normalize(g);
}

// Uniformly distributed direction in the hemisphere around normal: a random
// sphere direction, mirrored to the normal's side when it points below.
vec3 SampHemisphere(vec3 normal)
{
    vec3 dir = RandomDirection();
    // Flip only when strictly below the surface. Multiplying by
    // sign(dot(normal, dir)) returned a zero vector for directions
    // perpendicular to the normal, which breaks the voxel walk in Trace().
    return dot(normal, dir) < 0.0 ? -dir : dir;
}
// Dot product clamped to [0, 1].
float sdot(vec3 x, vec3 y)
{
    return clamp(dot(x, y), 0.0, 1.0);
}

// Intersects the ray with an axis-aligned plane located at coordinate Shift
// on one axis. Axis = (ray origin component, ray direction component) on
// that axis. Updates bestHit if the plane is ahead and closer.
// NOTE(review): the stored normal is always +X regardless of which axis the
// plane is on or which side the ray comes from — confirm before using this
// for Y/Z planes.
void IntersectGroundPlane(Ray ray, inout RayHit bestHit, vec3 Spec, vec3 Alb,vec3 Emis,vec2 Axis,float Shift)
{
    float originOnAxis = Axis.x;
    float dirOnAxis = Axis.y;
    float t = -(originOnAxis - Shift) / dirOnAxis;

    // Written as a negated conjunction so a NaN t is rejected.
    if (!(t > 0 && t < bestHit.distance))
        return;

    bestHit.distance = t;
    bestHit.position = ray.origin + t * ray.direction;
    bestHit.normal = vec3(1.0f, 0.0f, 0.0f);
    bestHit.specular = Spec;
    bestHit.albedo = Alb;
    bestHit.emission = Emis;
}

// Intersects the ray with a sphere (sphere.xyz = centre, sphere.w = radius)
// and updates bestHit if the intersection is ahead and closer. Uses the near
// root, or the far root when the near one is not ahead (origin inside).
// The quadratic is written for a unit-length ray.direction.
void IntersectSphere(Ray ray, inout RayHit bestHit, vec4 sphere, vec3 Spec, vec3 Alb,vec3 Emis)
{
    vec3 centreToOrigin = ray.origin - sphere.xyz;
    float halfB = -dot(ray.direction, centreToOrigin);
    float disc = halfB * halfB - dot(centreToOrigin, centreToOrigin) + sphere.w * sphere.w;
    if (disc < 0)
        return; // no real roots: the ray misses

    float root = sqrt(disc);
    float t = halfB - root;
    if (!(t > 0))
        t = halfB + root;

    if (!(t > 0 && t < bestHit.distance))
        return;

    bestHit.distance = t;
    bestHit.position = ray.origin + t * ray.direction;
    bestHit.normal = normalize(bestHit.position - sphere.xyz);
    bestHit.specular = Spec;
    bestHit.albedo = Alb;
    bestHit.emission = Emis;
}

// Slab-test intersection with an axis-aligned cube centred at position.xyz
// with half-extent position.w. Updates idealHit (distance, point, outward
// normal, material) if the cube is hit ahead of the ray and closer.
void IntersectCube(Ray ray, inout RayHit idealHit,vec3 Spec,vec3 Alb,vec4 position,vec3 Emis)
{
    vec3 cubeMin = position.xyz - position.w;
    vec3 cubeMax = position.xyz + position.w;

    vec3 tMin = (cubeMin - ray.origin) / ray.direction;
    vec3 tMax = (cubeMax - ray.origin) / ray.direction;

    vec3 t1 = min(tMin, tMax); // per-axis entry distances
    vec3 t2 = max(tMin, tMax); // per-axis exit distances

    float tNear = max(max(t1.x, t1.y), t1.z);
    float tFar = min(min(t2.x, t2.y), t2.z);

    // Miss, or the whole cube lies behind the ray. Negated form so NaNs miss.
    if (!(tNear <= tFar && tFar > 0.0))
        return;

    // If the origin is inside the cube (tNear <= 0), report the exit point,
    // as IntersectSphere does. Previously tNear itself was stored, which put
    // the hit behind the origin with a negative distance that no later hit
    // could ever beat.
    bool inside = !(tNear > 0.0);
    float t = inside ? tFar : tNear;
    if (!(t < idealHit.distance))
        return;

    // Outward normal of the face that was hit.
    vec3 outward;
    if (!inside)
    {
        // Entry face: the axis entered last; its normal opposes the ray.
        if (t1.x > t1.y && t1.x > t1.z)
            outward = vec3(-sign(ray.direction.x), 0, 0);
        else if (t1.y > t1.z)
            outward = vec3(0, -sign(ray.direction.y), 0);
        else
            outward = vec3(0, 0, -sign(ray.direction.z));
    }
    else
    {
        // Exit face: the axis exited first; its normal follows the ray.
        if (t2.x < t2.y && t2.x < t2.z)
            outward = vec3(sign(ray.direction.x), 0, 0);
        else if (t2.y < t2.z)
            outward = vec3(0, sign(ray.direction.y), 0);
        else
            outward = vec3(0, 0, sign(ray.direction.z));
    }

    idealHit.distance = t;
    idealHit.position = ray.origin + t * ray.direction;
    idealHit.normal = outward;
    idealHit.specular = Spec;
    idealHit.albedo = Alb;
    idealHit.emission = Emis;
}

// Finds the closest hit along ray among the voxel grid (first solid voxel
// reached by a voxel-boundary walk), one sphere, and one emissive cube.
// Requires AxisShift to have been set for this ray (CalculateAxisShift).
RayHit Trace(Ray ray)
{
    RayHit bestHit = CreateRayHit();

    // A straight line crosses at most GridSize boundaries per axis, so
    // 3*GridSize steps reach the far side of the grid. Extra slack covers
    // short steps caused by rounding near boundaries.
    const int MaxSteps = 4 * GridSize;
    const vec3 gridMax = vec3(GridSize);

    // NOTE(review): the walk only starts if ray.origin is already inside
    // (0, GridSize)^3. A camera outside the grid sees no voxels — confirm
    // this is intended.
    vec3 walkPos = ray.origin;
    for (int i = 0; i < MaxSteps; i++)
    {
        if (any(lessThanEqual(walkPos, vec3(0.0))) || any(greaterThanEqual(walkPos, gridMax)))
            break;

        vec3 cell = floor(walkPos);
        if (data[GetIndex(cell)] != 0)
        {
            // Solid voxel: intersect its unit cube exactly against the original ray.
            IntersectCube(ray, bestHit, vec3(0), vec3(1,1,0), vec4(cell + .5, .5), vec3(0,0,0));
            break;
        }

        float tStep = CalcTime(ray, walkPos);
        // A zero or NaN direction gives no forward crossing. Stop here. Otherwise
        // the position would become NaN, pass the bounds test (NaN compares
        // false), and index data[] with an undefined value.
        if (!(tStep < INF))
            break;
        walkPos = walkPos + ray.direction * tStep;
    }

    // IntersectGroundPlane(ray, bestHit,vec3(0),vec3(1,.5,0),vec3(0),vec2(ray.origin.x,ray.direction.x),0);
    // IntersectGroundPlane(ray, bestHit,vec3(0),vec3(1,.5,0),vec3(0),vec2(ray.origin.y,ray.direction.y),-10);
    // IntersectGroundPlane(ray, bestHit,vec3(0),vec3(1,.5,0),vec3(0),vec2(ray.origin.z,ray.direction.z),0);
    IntersectSphere(ray,bestHit,vec4(0,7,0,1),vec3(0),vec3(1,1,0),vec3(0));
    IntersectCube(ray,bestHit,vec3(0),vec3(0),vec4(7,7,7,2),vec3(1));
    return bestHit;
}

// Returns the radiance picked up at this path vertex and updates ray in
// place for the next bounce (origin, direction, throughput).
// Miss: returns the sky and zeroes ray.energy to end the path.
// Hit: records the first hit for HitData, picks a specular or diffuse lobe by
// Russian roulette, and returns the surface emission.
vec3 Shade(inout Ray ray, RayHit hit)
{
    if (!(hit.distance < INF))
    {
        ray.energy = vec3(0.0);
        return GetSky(ray).xyz;
    }

    if (FirstHit)
    {
        Dist = hit.distance;
        Normal = hit.normal;
        FirstHit = false;
    }

    // Energy conservation: albedo + specular must not exceed 1 per channel.
    hit.albedo = min(vec3(1.0) - hit.specular, hit.albedo);
    float specChance = energy(hit.specular);
    float diffChance = energy(hit.albedo);
    float sum = specChance + diffChance;

    // Offset along the normal to avoid re-hitting the same surface.
    ray.origin = hit.position + hit.normal * 0.001f;

    // Black, non-reflective surface (e.g. a pure emitter): nothing to scatter.
    // Without this guard the normalization below divides 0 by 0.
    if (!(sum > 0.0))
    {
        ray.energy = vec3(0.0);
        return hit.emission;
    }
    specChance /= sum;
    diffChance /= sum;

    // Dividing by the chosen lobe's probability keeps the estimator unbiased.
    // A zero diffuse probability forces the specular lobe. This also covers
    // roulette == 1.0, which a blue-noise texel can produce.
    float roulette = rand();
    if (roulette < specChance || !(diffChance > 0.0))
    {
        ray.direction = reflect(ray.direction, hit.normal);
        ray.energy *= (1.0f / specChance) * hit.specular;
    }
    else
    {
        // Lambertian BRDF (albedo/PI) sampled with a uniform hemisphere
        // (pdf 1/(2*PI)): weight = albedo * cos / (PI * pdf) = 2 * albedo * cos.
        ray.direction = SampHemisphere(hit.normal);
        ray.energy *= (1.0f / diffChance) * 2.0 * hit.albedo * sdot(hit.normal, ray.direction);
    }
    return hit.emission;
}

// Per-pixel entry point: traces up to 8 bounces from the camera, blends the
// result into imgOutput with weight 1/Samples, and stores the first-hit
// normal and distance in HitData.
void main() 
{
    id = ivec2(gl_GlobalInvocationID);
    _Pixel = id.xy;
    _Seed = _InSeed;
    SkyboxDim = imageSize(Skybox);
    OutputDim = imageSize(imgOutput);
    BluDim = imageSize(BluNoise);

    // Dispatches are made of whole 8x8 groups, so edge invocations may lie
    // outside the image. Skip them instead of tracing rays whose stores would
    // be discarded.
    // NOTE(review): assumes HitData has the same size as imgOutput — confirm.
    if (any(greaterThanEqual(id, OutputDim)))
        return;

    // Pixel centre mapped to NDC, [-1, 1] on both axes.
    vec2 uv = (vec2(id) + vec2(0.5f)) / vec2(OutputDim) * 2.0f - 1.0f;
    Ray ray = CreateCameraRay(uv);

    vec3 Result = vec3(0.0);
    for (int i = 0; i < 8; i++)
    {
        CalculateAxisShift(ray.direction);
        RayHit Hit = Trace(ray);
        // Capture the throughput before Shade() updates it for the next bounce.
        vec3 throughput = ray.energy;
        Result += throughput * Shade(ray, Hit);

        if (!any(bvec3(ray.energy)))
            break; // path terminated
    }

    // Running average. Clamp Samples so a value of 0 cannot produce an
    // infinite blend factor.
    float blend = 1.0 / float(max(Samples, 1));
    vec4 FResult = mix(imageLoad(imgOutput, id), vec4(Result, 1), blend);

    imageStore(HitData, id, vec4(Normal, Dist));
    imageStore(imgOutput, id, FResult);
}