import * as tf from '@tensorflow/tfjs'
import palettes from '@/assets/palettes.json'

// OpenCV.js handle, read once when this module is first evaluated.
const { cv } = window.self

/**
 * Build the (id, value) parameter vectors that `stitch` passes to `ImgStitch#set`.
 *
 * The ids and values are opaque codes understood by the ImgStitch binding; the
 * trailing comments give the names this file associates with each id.
 *
 * OpenCV.js objects live on the WASM heap and are not garbage-collected, so the
 * caller should `.delete()` both returned vectors once they are no longer needed.
 *
 * @returns {{ paramTypes: object, paramValues: object }} a `cv.IntVector` of ids
 *   and a `cv.FloatVector` of values, in matching order.
 */
function getParametersStitch() {
  // Read OpenCV.js at call time, as `stitch` does, rather than relying on a
  // value captured when the module was first evaluated.
  const cv = window.self.cv

  const params = [
    [581, 1], // Rectify
    [305, 550], // Surface projection
    [610, 612], // Wave correction
    [583, 706], // Bundle adjustment type
    [582, 1] // Cam estimation
  ]

  const paramTypes = new cv.IntVector()
  const paramValues = new cv.FloatVector()
  for (const [id, value] of params) {
    paramTypes.push_back(id)
    paramValues.push_back(value)
  }

  return { paramTypes, paramValues }
}

/**
 * Convert an OpenCV.js `Mat` into a browser `ImageData` (RGBA, 8 bits per channel).
 *
 * Any depth is first rescaled to CV_8U (the same scale/shift rule as OpenCV.js's
 * `cv.imshow` helper). 1-channel images are then expanded with GRAY2RGBA and
 * 3-channel images with RGB2RGBA. The input `mat` is neither modified nor freed.
 *
 * @param {object} mat - a `cv.Mat` with 1, 3 or 4 channels.
 * @returns {ImageData|null} a copy of the pixels, or `null` if the Mat has no rows or columns.
 * @throws {Error} if `mat` is not a `cv.Mat`, or has an unsupported channel count.
 */
function imageDataFromMat(mat) {
  // Resolve OpenCV.js at call time, consistent with `stitch`.
  const cv = window.self.cv

  if (!(mat instanceof cv.Mat)) {
    throw new Error('not a valid opencv Mat instance')
  }

  if (mat.rows === 0 || mat.cols === 0) {
    return null
  }

  const img = new cv.Mat()
  try {
    // convert the mat type to cv.CV_8U.
    // The low 3 bits of an OpenCV type encode its depth (CV_8U .. CV_64F).
    const depth = mat.type() % 8
    const scale =
      depth <= cv.CV_8S ? 1.0 : depth <= cv.CV_32S ? 1.0 / 256.0 : 255.0
    const shift = depth === cv.CV_8S || depth === cv.CV_16S ? 128.0 : 0.0
    mat.convertTo(img, cv.CV_8U, scale, shift)

    // convert the img type to cv.CV_8UC4
    switch (img.type()) {
      case cv.CV_8UC1:
        cv.cvtColor(img, img, cv.COLOR_GRAY2RGBA)
        break
      case cv.CV_8UC3:
        cv.cvtColor(img, img, cv.COLOR_RGB2RGBA)
        break
      case cv.CV_8UC4:
        break
      default:
        throw new Error(
          'Bad number of channels (Source image must have 1, 3 or 4 channels)'
        )
    }

    // Copy the pixels out of the WASM heap before `img` is released below.
    return new ImageData(new Uint8ClampedArray(img.data), img.cols, img.rows)
  } finally {
    // Always free the temporary Mat, including when the channel check throws.
    img.delete()
  }
}

/**
 * Map a tensor of temperatures onto a colour palette and append an opaque alpha channel.
 *
 * Steps:
 * 1. Normalise the temperatures to [0, 1] between `tempMin` and `tempMax`.
 * 2. Scale them to palette row indices 0 .. palette.shape[0] - 1.
 * 3. Clip to that range and truncate to int32.
 * 4. Look the indices up with `gather`.
 *
 * A channel filled with 255 is then concatenated on the last axis.
 *
 * @param {tf.Tensor} temps - temperature values. Presumably 2-D (rows, cols);
 *   TODO confirm against callers.
 * @param {tf.Tensor} palette - colour table with one row per step. Presumably
 *   [steps, 3] RGB; confirm against palettes.json.
 * @param {{tempMin?: number, tempMax?: number}} [params] - explicit temperature range.
 *   A missing or NaN bound falls back to the min/max of `temps`.
 * @returns {tf.Tensor} the gathered colours with one extra alpha channel (255) on the
 *   last axis. The caller owns this tensor and should dispose it.
 */
function applyPalette(temps, palette, params = {}) {
  // 0 is a legitimate temperature bound. Only NaN counts as "not supplied"
  // among numbers; non-numbers keep the previous truthiness rule.
  const isSupplied = (value) =>
    typeof value === 'number' ? !Number.isNaN(value) : Boolean(value)

  // tf.tidy disposes every intermediate tensor created here; only the returned
  // tensor survives. `temps` and `palette` were created by the caller and are untouched.
  return tf.tidy(() => {
    const tempMin = isSupplied(params.tempMin)
      ? tf.tensor(params.tempMin)
      : temps.min()
    const tempMax = isSupplied(params.tempMax)
      ? tf.tensor(params.tempMax)
      : temps.max()
    const maxIndex = palette.shape[0] - 1

    const indices = temps
      .sub(tempMin)
      .div(tempMax.sub(tempMin))
      .mul(maxIndex)
      .clipByValue(0, maxIndex)
      .cast('int32')

    // Extract the colours for each index from the palette.
    const colors = palette.gather(indices)

    // Opaque alpha channel, same leading shape and dtype as the colours.
    const alpha = tf.fill([...colors.shape.slice(0, -1), 1], 255, colors.dtype)

    return tf.concat([colors, alpha], colors.rank - 1)
  })
}

/**
 * Render a temperature array as an RGBA `ImageData` using a named palette.
 *
 * @param {Array} image - nested array of temperatures accepted by `tf.tensor`.
 *   Presumably rows × cols; height is read from shape[0] and width from shape[1]
 *   of the coloured result.
 * @param {string} paletteName - key into `palettes.json`.
 * @param {{min?: number|string, max?: number|string}} [usedTemps] - temperature range.
 *   Each bound is passed through `parseFloat`. An unparsable or missing bound becomes
 *   NaN, and `applyPalette` then falls back to the data's own min/max.
 * @returns {ImageData} the coloured image.
 */
function getImageDataFromImageArray(image, paletteName, usedTemps = {}) {
  const tempMin = parseFloat(`${usedTemps.min}`)
  const tempMax = parseFloat(`${usedTemps.max}`)

  // tf.tidy frees the palette/input tensors and any intermediates; only the
  // returned RGBA tensor survives, and it is disposed below.
  const rgba = tf.tidy(() => {
    const palette = tf.tensor(palettes[paletteName])
    const temps = tf.tensor(image)
    return applyPalette(temps, palette, { tempMin, tempMax })
  })

  try {
    // Uint8ClampedArray.from copies (and clamps) the values, so the tensor can be released.
    return new ImageData(
      Uint8ClampedArray.from(rgba.dataSync()),
      rgba.shape[1],
      rgba.shape[0]
    )
  } finally {
    rgba.dispose()
  }
}

/**
 * Free OpenCV.js objects. Embind-backed objects live on the WASM heap and are
 * not garbage-collected, so each one must be released with `.delete()`.
 * Entries that are null/undefined, or that have no `delete` method, are skipped,
 * so this can run unconditionally from a `finally` block.
 */
function deleteCvHandles(...handles) {
  for (const handle of handles) {
    if (typeof handle?.delete === 'function') {
      handle.delete()
    }
  }
}

export default {
  /**
   * Colour each temperature array with `paletteName` and stitch the results into one image.
   *
   * @param {Array} [images] - temperature arrays, each accepted by `getImageDataFromImageArray`.
   * @param {string} paletteName - key into `palettes.json`.
   * @param {{min?: number|string, max?: number|string}} usedTemps - range shared by all images.
   * @returns {ImageData|null|false} `false` when no images are given; otherwise the stitched
   *   image as returned by `imageDataFromMat` (`null` if the stitcher produced an empty Mat).
   */
  stitch(images = [], paletteName, usedTemps) {
    if (!images?.length) return false
    const cv = window.self.cv

    // Every OpenCV.js object created here is released in `finally`, including on error.
    let sourceMats = null
    let imgStitch = null
    let fieldsOfView = null
    let stitchedImage = null
    let params = null
    try {
      sourceMats = new cv.MatVector()
      for (const image of images) {
        const mat = cv.matFromImageData(
          getImageDataFromImageArray(image, paletteName, usedTemps)
        )
        try {
          sourceMats.push_back(mat)
        } finally {
          // push_back stores its own copy, so our handle can be released now.
          mat.delete()
        }
      }

      imgStitch = new cv.ImgStitch(sourceMats)
      // The inputs were already freed right after construction before this change,
      // so ImgStitch presumably keeps its own copy; release them early.
      deleteCvHandles(sourceMats)
      sourceMats = null

      // One entry per image. NOTE(review): 0 presumably means "unknown" — confirm with ImgStitch.
      fieldsOfView = new cv.FloatVector()
      for (let i = 0; i < images.length; i++) {
        fieldsOfView.push_back(0)
      }

      params = getParametersStitch()
      imgStitch.set(params.paramTypes, params.paramValues)

      stitchedImage = new cv.Mat()
      imgStitch.stitch(fieldsOfView, stitchedImage)
      // imageDataFromMat copies the pixels, so stitchedImage can be freed afterwards.
      return imageDataFromMat(stitchedImage)
    } finally {
      deleteCvHandles(
        sourceMats,
        imgStitch,
        fieldsOfView,
        stitchedImage,
        params?.paramTypes,
        params?.paramValues
      )
    }
  },
  getImageDataFromImageArray
}
