import * as tf from "@tensorflow/tfjs";
import { loadGraphModel } from "@tensorflow/tfjs-converter";
import CustomShaderMaterial from "three-custom-shader-material/vanilla";

// Material that draws an albedo texture and fades fragments out according to
// a screen-space depth or mask texture. It is published on `document` so
// other scripts can reach it.
//
// The METRIC_DEPTH define selects one of two fragment paths:
//   * METRIC_DEPTH on:  `depthMap` is sampled at the fragment's screen
//     position (transformed by `textureMatrix`), decoded from its R/G
//     channels and scaled by `fragToMeters`. Alpha falls to 0 as the
//     fragment's camera-space distance (-camSpacePosition.z) exceeds that
//     depth, is forced back towards 1 once the sampled depth passes 2.5, and
//     is finally blended by `depthStrength` (0 = no fading at all).
//   * METRIC_DEPTH off: the screen UV is aspect-fitted to `maskMap`, scaled
//     by `zoomInv`, and alpha becomes (1 - maskMap.a), or 0 outside the
//     [0, 1] UV range. `lightingMap` is mixed over the albedo by its alpha.
//
// NOTE(review): `THREE` is not imported in the visible part of this file;
// presumably it is provided as a global - confirm.
document.depthMaterial = new CustomShaderMaterial({
  baseMaterial: THREE.MeshStandardMaterial,
  // Vertex stage: forwards the raw UV, the UV remapped into the `uvClip`
  // rectangle (xy = origin, zw = opposite corner) and, for METRIC_DEPTH, the
  // camera-space position.
  vertexShader: `
    uniform vec4 uvClip;

    varying vec2 vUV;
    varying vec2 vUVClip;
#if METRIC_DEPTH
    varying vec3 camSpacePosition;
#endif

    void main() {
      vUV = uv;
      vUVClip = uv * (uvClip.zw - uvClip.xy) + uvClip.xy;
#if METRIC_DEPTH
      vec4 camSpacePosition4 = modelViewMatrix * vec4(position, 1.0);
      camSpacePosition = camSpacePosition4.xyz;
#endif
    }
  `,
  // Fragment stage: computes the fade alpha described above and writes the
  // alpha-multiplied colour to csm_DiffuseColor. `albedoColor.a` controls how
  // much of `albedoColor.rgb` replaces the sampled albedo.
  fragmentShader: `
    uniform sampler2D albedoMap;
    uniform sampler2D depthMap;
    uniform vec4 albedoColor;
    uniform vec2 screenDimInv;
    uniform float zoomInv;

#if METRIC_DEPTH
    uniform mat4 textureMatrix;
    uniform float fragToMeters;
    uniform float depthStrength;

    varying vec3 camSpacePosition;
#else
    uniform sampler2D maskMap;
    uniform sampler2D lightingMap;
#endif

    varying vec2 vUV;
    varying vec2 vUVClip;

  #if METRIC_DEPTH
    float getDepthMeters(vec2 uv) {
      vec2 packedDepthAndVisibility = texture2D(depthMap, uv).rg;
      return dot(packedDepthAndVisibility, vec2(255.0, 256.0 * 255.0)) * fragToMeters;
    }
  #endif

    void main() {
      vec2 screenUV = screenDimInv * gl_FragCoord.xy;
#if METRIC_DEPTH
      vec4 depthUV = textureMatrix * vec4(screenUV.x, 1.0 - screenUV.y, 0, 1);
#else
      ivec2 textureDims = textureSize(maskMap, 0);
      float textureAspect = float(textureDims.x) / float(textureDims.y);
      float windowAspect = screenDimInv.y / screenDimInv.x;
      vec2 depthUV = screenUV;
      depthUV -= 0.5;
      if (textureAspect > windowAspect) {
        depthUV.x *= windowAspect / textureAspect;
      } else {
        depthUV.y *= textureAspect / windowAspect;
      }
      depthUV *= zoomInv;
      depthUV += 0.5;
#endif
#if METRIC_DEPTH
      float depth = getDepthMeters(depthUV.xy);
      float alpha = mix(1.0 - clamp(3.0 * (-camSpacePosition.z - depth), 0.0, 1.0), 1.0, clamp(depth - 2.5, 0.0, 1.0));
      alpha = mix(1.0, alpha, depthStrength);
#else
      float alpha = 1.0 - texture2D(maskMap, depthUV).a;
      alpha *= 1.0 - float(abs(depthUV.x - 0.5) > 0.5 || abs(depthUV.y - 0.5) > 0.5);
#endif

    vec4 albedo = texture2D(albedoMap, vUVClip);
    albedo.rgb = mix(albedo.rgb, albedoColor.rgb, albedoColor.a);

#if METRIC_DEPTH
      csm_DiffuseColor = alpha * albedo;
#else
      vec4 lighting = texture2D(lightingMap, depthUV);
      csm_DiffuseColor = alpha * vec4(mix(albedo.rgb, lighting.rgb, lighting.a), albedo.a);
#endif
    }
  `,
  // Textures start out undefined and must be assigned by the caller before
  // the material is rendered.
  uniforms: {
    albedoMap: { value: undefined },
    // NOTE(review): neither shader string above declares or reads `normalMap`.
    normalMap: { value: undefined },
    depthMap: { value: undefined },
    // Alpha 0 leaves the sampled albedo untouched.
    albedoColor: { value: new THREE.Vector4(0.0, 0.0, 0.0, 0.0) },
    // Captured from the window size at construction time; nothing in the
    // visible code refreshes it on resize - TODO confirm this is handled
    // elsewhere.
    screenDimInv: { value: new THREE.Vector2(1.0 / window.innerWidth, 1.0 / window.innerHeight) },
    // (0, 0, 1, 1) maps the UVs onto themselves, i.e. no clipping.
    uvClip: { value: new THREE.Vector4(0.0, 0.0, 1.0, 1.0) },
    zoomInv: { value: 1.0 },
    // METRIC_DEPTH:
    textureMatrix: { value: new THREE.Matrix4() },
    fragToMeters: { value: 0.0 },
    // 0 disables the depth-based fade (alpha stays 1).
    depthStrength: { value: 0.0 },
    // not METRIC_DEPTH:
    maskMap: { value: undefined },
    lightingMap: { value: undefined },
  },
  silent: true,
  // The mask/lighting path is the default; the shaders test this with
  // `#if METRIC_DEPTH`.
  defines: {
    METRIC_DEPTH: false
  },
  transparent: true
});
// Material that draws a solid red colour whose opacity is the product of
//   * the same depth-based fade used by the METRIC_DEPTH path of
//     `document.depthMaterial` (sampled depth vs. -camSpacePosition.z,
//     blended by `depthStrength`), and
//   * a rim term: 1 + dot(normalized camera-space position, normalized
//     camera-space normal).
// The result is written directly to csm_FragColor. It is published on
// `document` so other scripts can reach it.
//
// NOTE(review): `THREE` is not imported in the visible part of this file;
// presumably it is provided as a global - confirm.
document.glowMaterial = new CustomShaderMaterial({
  baseMaterial: THREE.MeshStandardMaterial,
  // Vertex stage: forwards the camera-space position and normal.
  vertexShader: `
    varying vec3 camSpacePosition;
    varying vec3 camSpaceNormal;

    void main() {
      vec4 camSpacePosition4 = modelViewMatrix * vec4(position, 1.0);
      camSpacePosition = camSpacePosition4.xyz;
      camSpaceNormal = normalMatrix * normal;
    }
  `,
  // Fragment stage: depth fade multiplied by the rim term, output as alpha of
  // a constant (1, 0, 0) colour.
  fragmentShader: `
    uniform sampler2D depthMap;
    uniform vec2 screenDimInv;

    uniform mat4 textureMatrix;
    uniform float fragToMeters;
    uniform float depthStrength;

    varying vec3 camSpacePosition;
    varying vec3 camSpaceNormal;

    float getDepthMeters(vec2 uv) {
      vec2 packedDepthAndVisibility = texture2D(depthMap, uv).rg;
      return dot(packedDepthAndVisibility, vec2(255.0, 256.0 * 255.0)) * fragToMeters;
    }

    void main() {
      vec2 screenUV = screenDimInv * gl_FragCoord.xy;
      vec4 depthUV = textureMatrix * vec4(screenUV.x, 1.0 - screenUV.y, 0, 1);

      float depth = getDepthMeters(depthUV.xy);
      float alpha = mix(1.0 - clamp(3.0 * (-camSpacePosition.z - depth), 0.0, 1.0), 1.0, clamp(depth - 2.5, 0.0, 1.0));
      alpha = mix(1.0, alpha, depthStrength);

      float angle = 1.0 + dot(normalize(camSpacePosition), normalize(camSpaceNormal));
      vec4 albedo = vec4(vec3(1, 0, 0), alpha * angle);
      csm_FragColor = albedo;
    }
  `,
  uniforms: {
    // Must be assigned by the caller before the material is rendered.
    depthMap: { value: undefined },
    // Captured from the window size at construction time; nothing in the
    // visible code refreshes it on resize - TODO confirm this is handled
    // elsewhere.
    screenDimInv: { value: new THREE.Vector2(1.0 / window.innerWidth, 1.0 / window.innerHeight) },
    textureMatrix: { value: new THREE.Matrix4() },
    fragToMeters: { value: 0.0 },
    // 0 disables the depth-based fade (the fade factor stays 1).
    depthStrength: { value: 0.0 },
  },
  silent: true,
  transparent: true
});
// For midas debugging:
// document.depthMaterial = new THREE.ShaderMaterial({
//   vertexShader: `
//     uniform sampler2D depthMap;
//     uniform float depthNear;
//     uniform float depthFar;

//     varying vec4 worldPos;
//     varying vec2 vUV;

//     void main() {
//       vUV = uv;
//       worldPos = vec4((mix(depthFar, depthNear, texture2D(depthMap, uv).r) * normalize((inverse(projectionMatrix) * vec4(2.0 * uv - 1.0, 0.5, 1)).xyz)), 1);
//       gl_Position = projectionMatrix * (viewMatrix * worldPos);
//     }
//   `,
//   fragmentShader: `
//     uniform sampler2D depthMap;

//     varying vec4 worldPos;
//     varying vec2 vUV;

//     void main() {
//       // gl_FragColor = vec4(texture2D(depthMap, vUV).rgb, 1.0);
//       gl_FragColor = vec4(fract(abs(worldPos.xyz)), 1.0);
//     }
//   `,
//   uniforms: {
//     depthMap: { value: undefined },
//     depthNear: { value: 0.0 },
//     depthFar: { value: 4.0 },
//   },
//   transparent: true
// });

// From https://github.com/timmh/monocular_depth_estimation_demo
/**
 * Flattens a tensor into a plain `{ data, shape }` record.
 *
 * The values are read synchronously through `tensor.dataSync()`, so this call
 * blocks until they are available. The tensor itself is not disposed.
 *
 * @param {{ dataSync: Function, shape: number[] }} tensor - Tensor to snapshot.
 * @returns {{ data: *, shape: number[] }} The tensor's values and its shape.
 */
const serializeTensor = (tensor) => {
  const data = tensor.dataSync();
  const { shape } = tensor;
  return { data, shape };
};
/**
 * Rebuilds a tensor from a `{ data, shape }` record such as the one produced
 * by `serializeTensor`.
 *
 * @param {{ data: *, shape: number[] }} serialized - Values and shape.
 * @returns {*} Whatever `tf.tensor(data, shape)` returns; the caller owns it.
 */
const deserializeTensor = (serialized) => {
  const { data, shape } = serialized;
  return tf.tensor(data, shape);
};

// The loaded graph model; stays undefined until `modelLoaded()` has resolved.
let model;
// In-flight or completed load, shared so the model files are fetched and
// parsed at most once instead of on every inference.
let modelLoadPromise = null;

/**
 * Ensures the MiDaS graph model is loaded into `model`.
 *
 * Concurrent and repeated calls share a single load. If the load fails the
 * user is alerted, the error is logged and rethrown (so callers do not go on
 * to use an undefined `model`), and the cached promise is cleared so that a
 * later call can retry.
 *
 * @returns {Promise<void>} Resolves once `model` is ready to use.
 * @throws Rejects with the error raised by `loadGraphModel`.
 */
const modelLoaded = async () => {
  if (modelLoadPromise === null) {
    modelLoadPromise = loadGraphModel("models/midas_u8/model.json").then(
      (loadedModel) => {
        model = loadedModel;
      },
      (err) => {
        // Forget the failed attempt so the next call retries the download.
        modelLoadPromise = null;
        console.error("Failed to load model", err);
        alert("Failed to load model");
        throw err;
      }
    );
  }
  return modelLoadPromise;
};

/**
 * Runs the depth model on one serialized image and returns a serialized,
 * min-max normalized result.
 *
 * The input values are divided by 255 and the axes are reordered with
 * `[2, 0, 1]` plus a leading batch axis before the model runs. The model
 * output is reordered with `[1, 2, 0]` and rescaled to the 0..1 range.
 *
 * Every tensor created here is disposed before returning: tfjs tensors are
 * not garbage-collected, so leaving them alive leaks backend memory on each
 * call.
 *
 * NOTE: if the model output is constant (max === min) the normalization is
 * 0 / 0, exactly as in the unnormalized formula.
 *
 * @param {{ data: *, shape: number[] }} input - Serialized rank-3 tensor, as
 *   produced by `serializeTensor`.
 * @returns {Promise<{ data: *, shape: number[] }>} Serialized normalized output.
 */
const infer = async (input) => {
  await modelLoaded();

  // tf.tidy frees every intermediate tensor except the one it returns.
  const batched = tf.tidy(() => {
    const scaled = tf.div(deserializeTensor(input), 255);
    return tf.expandDims(tf.transpose(scaled, [2, 0, 1]));
  });

  // executeAsync cannot run inside tf.tidy, so its input is released by hand.
  let rawOutput;
  try {
    rawOutput = await model.executeAsync(batched);
  } finally {
    batched.dispose();
  }

  let normalized;
  try {
    normalized = tf.tidy(() => {
      const depth = tf.transpose(rawOutput, [1, 2, 0]);
      const lowest = tf.min(depth);
      const highest = tf.max(depth);
      return tf.div(tf.sub(depth, lowest), tf.sub(highest, lowest));
    });
  } finally {
    tf.dispose(rawOutput);
  }

  try {
    return serializeTensor(normalized);
  } finally {
    normalized.dispose();
  }
};

/**
 * Estimates a depth map for an image file and paints it onto a canvas.
 *
 * The image is decoded, resized to 256x256 for the model, run through
 * `infer`, resized back to the image's original dimensions and drawn with
 * `tf.browser.toPixels`.
 *
 * The returned promise rejects when the image cannot be decoded or when
 * inference fails, rather than staying pending forever. The temporary object
 * URL and all tensors created here are released on both success and failure.
 *
 * @param {Blob} file - Image data; passed to `URL.createObjectURL`.
 * @param {HTMLCanvasElement} outputCanvas - Canvas handed to
 *   `tf.browser.toPixels` as the drawing target.
 * @returns {Promise<*>} Resolves with the value `tf.browser.toPixels` yields.
 */
document.estimateDepth = async (file, outputCanvas) => {
  // Wraps only the callback-based image load in a promise; everything else
  // uses plain awaits so that errors propagate to the caller.
  const loadImage = (blob) =>
    new Promise((resolve, reject) => {
      const url = URL.createObjectURL(blob);
      const img = new Image();
      img.onload = () => {
        URL.revokeObjectURL(url);
        resolve(img);
      };
      img.onerror = () => {
        URL.revokeObjectURL(url);
        reject(new Error("Failed to load image for depth estimation"));
      };
      img.src = url;
    });

  const img = await loadImage(file);

  const canvas = document.createElement("canvas");
  canvas.width = img.width;
  canvas.height = img.height;
  const ctx = canvas.getContext("2d");
  ctx.drawImage(img, 0, 0);
  const imageData = ctx.getImageData(0, 0, img.width, img.height);

  // Remember the decoded size so the depth map can be scaled back to it.
  let originalSize;
  const resized = tf.tidy(() => {
    const pixels = tf.browser.fromPixels(imageData);
    originalSize = [pixels.shape[0], pixels.shape[1]];
    return tf.image.resizeBilinear(pixels, [256, 256]);
  });

  let serializedInput;
  try {
    serializedInput = serializeTensor(resized);
  } finally {
    resized.dispose();
  }

  const serializedDepth = await infer(serializedInput);

  const depth = tf.tidy(() =>
    tf.image.resizeBilinear(deserializeTensor(serializedDepth), originalSize)
  );
  try {
    return await tf.browser.toPixels(depth, outputCanvas);
  } finally {
    depth.dispose();
  }
};