import React, { useEffect, useState, useRef, useContext } from 'react'
import * as ort from 'onnxruntime-web'
import { LabelerContext } from '../../../../context/LabelerContext'
import { DispatchContext } from '../../../../context/DispatchContext'
import useResizeCategoryCanvas from '../../hooks/useResizeCategoryCanvas'
// eslint-disable-next-line no-unused-vars
import * as _ from 'underscore'
import useZoom from '../../hooks/useZoom'
import SamButtons from './SamButtons'
import { v4 as uuidv4 } from 'uuid'
import { useEffectWR } from '../../hooks/useEffectWR'
import { useClearCanvas } from '../../hooks/useClearCanvas'
import { useCategoryEventListener } from '../../hooks/useCategoryEventListener'
import { getPointsFromMask } from '../../../../../../io/api/ml'

const { Tensor } = ort
export default function Sam({ labeler, panZoom, mouse, model, tensor, modelScale }) {
  // Shared labeler state (read here) and the dispatcher used to update it.
  const { state } = useContext(LabelerContext)
  const { dispatch } = useContext(DispatchContext)
  // Prompt points placed by the user. Each entry is { x, y, type } in
  // un-panned, un-zoomed coordinates; type is 1 for a left-button click and 0
  // otherwise (see mouseDown).
  const [points, setPoints] = useState([])

  const { segmentationState, labelerState } = state

  // canvasSam: rendered overlay on which the prompt points are drawn.
  // canvasMaskConfirmed: rendered layer onto which the current mask is composited.
  // tempCanvas: canvas that is never rendered by this component; it holds the
  // latest mask produced by the model (see transformImageData).
  const canvasSam = useRef(document.createElement('canvas', null, null))
  const canvasMaskConfirmed = useRef(document.createElement('canvas', null, null))
  const tempCanvas = useRef(document.createElement('canvas', null, null))
  // Id of the animation frame scheduled by mouseMove, or null when none is pending.
  const requestId = useRef(null)

  // Position and visibility of the confirm/cancel toolbar (SamButtons).
  const [region, setRegion] = useState({
    toolsPosX: 0,
    toolsPosY: 0,
    isOpened: false
  })

  // NOTE(review): useEffectWR and useCategoryEventListener are project hooks
  // defined elsewhere. Presumably the first runs cancelPoints when the selected
  // category or mode changes, and the others attach the handlers to
  // labeler.current while the 'segmentation' category is active — confirm.
  useEffectWR([labelerState.selectedCategory, labelerState.mode], [cancelPoints])
  useCategoryEventListener('mousedown', mouseDown, labeler.current, 'segmentation')
  useCategoryEventListener('mousemove', mouseMove, labeler.current, 'segmentation')
  useCategoryEventListener('mouseleave', mouseLeave, labeler.current, 'segmentation')
  /**
   * Repaints the mask layer: wipes canvasMaskConfirmed and draws the content
   * of tempCanvas onto it using the current pan/zoom transform.
   */
  const drawAllMasks = () => {
    const layer = canvasMaskConfirmed.current
    const context = layer.getContext('2d')

    // Reset to the identity transform first so clearRect covers the whole
    // backing store rather than a panned/zoomed rectangle.
    context.setTransform(1, 0, 0, 1, 0, 0)
    context.clearRect(0, 0, layer.width, layer.height)

    const { scale, x: offsetX, y: offsetY } = panZoom
    context.setTransform(scale, 0, 0, scale, offsetX, offsetY)
    context.drawImage(tempCanvas.current, 0, 0)
  }

  /**
   * Redraws every prompt point on the canvasSam overlay.
   * Points of type 1 are filled with '#08CACA', all others with '#f06543';
   * every marker gets a white outline.
   */
  const drawPoints = () => {
    const overlay = canvasSam.current
    const context = overlay.getContext('2d')

    // Clear under the identity transform so the entire canvas is wiped.
    context.setTransform(1, 0, 0, 1, 0, 0)
    context.clearRect(0, 0, overlay.width, overlay.height)

    const { scale, x: offsetX, y: offsetY } = panZoom
    context.setTransform(scale, 0, 0, scale, offsetX, offsetY)

    // Dividing by the zoom factor keeps the marker radius at 5 screen pixels.
    const radius = 5 / scale
    for (const { x, y, type } of points) {
      context.beginPath()
      context.fillStyle = type === 1 ? '#08CACA' : '#f06543'
      context.strokeStyle = '#ffffff'
      context.arc(x, y, radius, 0, 2 * Math.PI)
      context.stroke()
      context.fill()
    }
  }

  // NOTE(review): useClearCanvas, useEffectWR, useResizeCategoryCanvas and
  // useZoom are project hooks defined elsewhere; the notes below are inferred
  // from their names and arguments — confirm against their implementations.
  // Presumably clears the three canvases and invokes `cancel` on a clear request.
  useClearCanvas([canvasSam, tempCanvas, canvasMaskConfirmed], '', cancel)
  // Presumably runs deleteMask when the selected category or the mode changes.
  useEffectWR([state.labelerState.selectedCategory, state.labelerState.mode], [deleteMask])
  // Presumably keeps the three canvases sized with the labeler viewport.
  useResizeCategoryCanvas([canvasSam, canvasMaskConfirmed, tempCanvas])
  // The callback repaints the mask layer and the prompt points; presumably
  // useZoom calls it after a pan/zoom change.
  useZoom(panZoom, mouse.wheel, [], () => {
    drawAllMasks()
    drawPoints()
  })

  /**
   * Resets the toolbar to its closed state at position (0, 0) and discards
   * all prompt points.
   */
  function cancel() {
    const closedRegion = { toolsPosX: 0, toolsPosY: 0, isOpened: false }
    setRegion(closedRegion)
    setPoints([])
  }
  /**
   * Discards the prompt points, erases the mask held in tempCanvas and
   * repaints the mask layer so the erased mask disappears from screen.
   */
  function deleteMask() {
    setPoints([])
    const { current: scratch } = tempCanvas
    scratch.getContext('2d').clearRect(0, 0, scratch.width, scratch.height)
    drawAllMasks()
  }

  /**
   * Schedules a hover preview ("suggestion") mask for the pointer position.
   * At most one animation frame is pending at a time: a new move cancels the
   * previously scheduled frame. The preview is only requested while no prompt
   * point exists and both the embedding tensor and the model are available.
   */
  function mouseMove() {
    // Capture the pointer position now (screen -> un-zoomed coordinates); the
    // frame callback runs later and must use the position of this event.
    const click = {
      x: (mouse.x - panZoom.x) / panZoom.scale,
      y: (mouse.y - panZoom.y) / panZoom.scale,
      clickType: 1
    }

    const pendingFrame = requestId.current
    if (pendingFrame) {
      cancelAnimationFrame(pendingFrame)
    }

    requestId.current = requestAnimationFrame(() => {
      const canPreview = points.length === 0 && tensor !== null && model !== null
      if (canPreview) {
        // NOTE(review): the returned promise is not awaited, so a rejected
        // model run surfaces as an unhandled rejection.
        getSuggestionMask([click])
      }
      requestId.current = null
    })
  }

  /**
   * Handles a mouse press: appends a prompt point at the pointer position,
   * converted from screen to un-panned, un-zoomed coordinates.
   *
   * Presses with the middle button (button 1) and presses that did not
   * originate on a <canvas> element are ignored, as are presses made before
   * the embedding tensor and the model are available. A left click (button 0)
   * adds a point of type 1; any other accepted button adds a point of type 0.
   *
   * @param {MouseEvent} e - The mousedown event.
   */
  function mouseDown(e) {
    // `srcElement` is a deprecated alias of `target`; use the standard
    // property and only fall back for event objects that lack it.
    const origin = e.target || e.srcElement
    if (e.button === 1 || origin.localName !== 'canvas') return
    if (tensor === null || model === null) return

    const x = (mouse.x - panZoom.x) / panZoom.scale
    const y = (mouse.y - panZoom.y) / panZoom.scale
    const type = e.button === 0 ? 1 : 0

    // Functional update: append to the latest state instead of the `points`
    // value captured by this render, so clicks arriving before a re-render
    // are not lost.
    setPoints((previous) => [...previous, { x, y, type }])
  }

  /**
   * Removes the hover preview when the pointer leaves: if no prompt point
   * exists, tempCanvas is erased and the mask layer repainted. With prompt
   * points present the current mask is left untouched.
   */
  function mouseLeave() {
    if (points.length !== 0) return

    const scratch = tempCanvas.current
    const context = scratch.getContext('2d')
    context.clearRect(0, 0, scratch.width, scratch.height)
    drawAllMasks()
  }

  // Runs whenever the prompt points change: redraws the markers and, when at
  // least one point exists and the embedding tensor is available, opens the
  // toolbar and requests a mask for the points.
  useEffect(() => {
    drawPoints()

    const hasPrompt = points.length > 0 && tensor !== null
    if (!hasPrompt) return

    // Horizontal position: one third of the layer width plus half of a third,
    // shifted left by 50.
    const third = canvasMaskConfirmed.current.width / 3
    setRegion({
      toolsPosX: third + third / 2 - 50,
      toolsPosY: 100,
      isOpened: true
    })
    getPointMask()
  }, [points])

  /**
   * Runs the model for the hover preview and paints the resulting mask.
   * The first model output is converted to an image in the colour of the
   * selected segmentation and the mask layer is repainted.
   *
   * NOTE(review): several runs may be in flight at once; whichever resolves
   * last is what ends up on screen.
   *
   * @param {Array<{x: number, y: number, clickType: number}>} clicks - Preview clicks.
   * @returns {Promise<void>}
   */
  const getSuggestionMask = async (clicks) => {
    const inference = await model.run(suggestionModelData(clicks, tensor, modelScale))
    const mask = inference[model.outputNames[0]]
    const { data, dims } = mask
    const { color } = state.segmentationState.selectedSegmentation
    onnxMaskToImage(data, dims[2], dims[3], color)
    drawAllMasks()
  }

  /**
   * Runs the model for the current prompt points and paints the resulting
   * mask. Does nothing when the model or the embedding tensor is missing, or
   * when no feeds could be built from the points.
   *
   * @returns {Promise<void>}
   */
  const getPointMask = async () => {
    const isReady = model !== null && tensor !== null
    if (!isReady) return

    const feeds = pointModelData(points, tensor, modelScale)
    if (feeds === undefined) return

    const inference = await model.run(feeds)
    const mask = inference[model.outputNames[0]]
    const { data, dims } = mask
    const { color } = state.segmentationState.selectedSegmentation
    onnxMaskToImage(data, dims[2], dims[3], color)
    drawAllMasks()
  }

  /**
   * Builds the model feeds for the hover preview ("suggestion") mask.
   *
   * Every click is sent as a positive point (label 1), scaled by
   * `modelScale.samScale`. Because no box prompt is supplied, one padding
   * point at (0, 0) with label -1 is appended, as the SAM ONNX export expects;
   * leaving that slot at label 0 would be read as a negative click in the
   * top-left corner.
   *
   * @param {Array<{x: number, y: number}>} clicks - Preview clicks in un-zoomed coordinates.
   * @param {*} tensor - Image embedding passed through as `image_embeddings`.
   * @param {{samScale: number, height: number, width: number}} modelScale - Scale and original image size.
   * @returns {Object|undefined} The feeds object, or undefined when there are no clicks.
   */
  const suggestionModelData = (clicks, tensor, modelScale) => {
    if (!clicks || clicks.length === 0) return

    const n = clicks.length
    // n prompt points plus one padding point.
    const pointCoords = new Float32Array(2 * (n + 1))
    const pointLabels = new Float32Array(n + 1)

    for (let i = 0; i < n; i++) {
      pointCoords[2 * i] = clicks[i].x * modelScale.samScale
      pointCoords[2 * i + 1] = clicks[i].y * modelScale.samScale
      pointLabels[i] = 1
    }

    // Padding point: coordinates stay (0, 0), label must be -1.
    pointLabels[n] = -1

    return {
      image_embeddings: tensor,
      point_coords: new Tensor('float32', pointCoords, [1, n + 1, 2]),
      point_labels: new Tensor('float32', pointLabels, [1, n + 1]),
      orig_im_size: new Tensor('float32', [modelScale.height, modelScale.width]),
      // No previous mask is fed back: empty mask and has_mask_input = 0.
      mask_input: new Tensor('float32', new Float32Array(256 * 256), [1, 1, 256, 256]),
      has_mask_input: new Tensor('float32', [0])
    }
  }
  /**
   * Builds the model feeds for the user's prompt points.
   *
   * Each point is scaled by `modelScale.samScale` and labelled with its
   * `type` (1 = positive click, 0 = negative click). Because no box prompt is
   * supplied, one padding point at (0, 0) with label -1 is appended, as the
   * SAM ONNX export expects; leaving that slot at label 0 would be read as a
   * negative click in the top-left corner.
   *
   * @param {Array<{x: number, y: number, type: number}>} points - Prompt points in un-zoomed coordinates.
   * @param {*} tensor - Image embedding passed through as `image_embeddings`.
   * @param {{samScale: number, height: number, width: number}} modelScale - Scale and original image size.
   * @returns {Object|undefined} The feeds object, or undefined when `points` is missing.
   */
  const pointModelData = (points, tensor, modelScale) => {
    // Without click prompts there is nothing to feed to the model.
    if (!points) return

    const n = points.length
    // n prompt points plus one padding point.
    const pointCoords = new Float32Array(2 * (n + 1))
    const pointLabels = new Float32Array(n + 1)

    for (let i = 0; i < n; i++) {
      pointCoords[2 * i] = points[i].x * modelScale.samScale
      pointCoords[2 * i + 1] = points[i].y * modelScale.samScale
      pointLabels[i] = points[i].type
    }

    // Padding point: coordinates stay (0, 0), label must be -1.
    pointLabels[n] = -1

    return {
      image_embeddings: tensor,
      point_coords: new Tensor('float32', pointCoords, [1, n + 1, 2]),
      point_labels: new Tensor('float32', pointLabels, [1, n + 1]),
      orig_im_size: new Tensor('float32', [modelScale.height, modelScale.width]),
      // There is no previous mask, so feed an empty mask and has_mask_input = 0.
      mask_input: new Tensor('float32', new Float32Array(256 * 256), [1, 1, 256, 256]),
      has_mask_input: new Tensor('float32', [0])
    }
  }

  /**
   * Converts a '#'-prefixed hex colour to its [r, g, b] components.
   *
   * Accepts the 6-digit form ('#rrggbb'; any further digits are ignored) and
   * the shorthand forms '#rgb' / '#rgba', whose digits are doubled before
   * parsing. A pair that is not valid hexadecimal yields 0.
   *
   * @param {string} hex - Colour string starting with '#'.
   * @returns {number[]} Red, green and blue components in the range 0-255.
   */
  function hex2rgb(hex) {
    let digits = hex.slice(1)
    if (digits.length === 3 || digits.length === 4) {
      // Expand shorthand: 'abc' -> 'aabbcc'.
      digits = [...digits].map((digit) => digit + digit).join('')
    }
    // `| 0` turns NaN (invalid or missing pair) into 0.
    return [0, 2, 4].map((start) => Number(`0x${digits.slice(start, start + 2)}`) | 0)
  }
  /**
   * Converts a mask array into RGBA pixels: every element greater than 0
   * becomes a pixel in the given colour with alpha 150, every other element
   * stays fully transparent.
   *
   * Note that the dimensions are handed to ImageData in swapped order: the
   * `height` argument becomes the image width and `width` its height.
   * NOTE(review): callers pass `dims[2]` as `width` and `dims[3]` as
   * `height`, so the swap presumably compensates for a rows-first mask
   * layout — confirm before renaming anything.
   *
   * @param {ArrayLike<number>} input - Mask values, one per pixel.
   * @param {number} width - Used as the resulting image's height.
   * @param {number} height - Used as the resulting image's width.
   * @param {string} color - Hex colour understood by hex2rgb.
   * @returns {ImageData} The coloured mask.
   */
  function arrayToImageData(input, width, height, color) {
    const [red, green, blue] = hex2rgb(color)
    const alpha = 150
    // Typed arrays start zero-filled, i.e. transparent black.
    const pixels = new Uint8ClampedArray(4 * width * height)

    for (let i = 0; i < input.length; i++) {
      if (!(input[i] > 0.0)) continue
      const offset = 4 * i
      pixels[offset] = red
      pixels[offset + 1] = green
      pixels[offset + 2] = blue
      pixels[offset + 3] = alpha
    }

    return new ImageData(pixels, height, width)
  }

  /**
   * Writes the given image data into tempCanvas, resizing the canvas to the
   * image first, and returns that canvas.
   *
   * @param {ImageData} originalImageData - Pixels to store.
   * @returns {HTMLCanvasElement} tempCanvas, now holding the image.
   */
  function transformImageData(originalImageData) {
    const { width, height } = originalImageData
    const target = tempCanvas.current
    target.width = width
    target.height = height
    target.getContext('2d').putImageData(originalImageData, 0, 0)
    return target
  }

  /**
   * Converts a raw model mask to coloured pixels and stores them in
   * tempCanvas. Arguments are forwarded unchanged to arrayToImageData.
   *
   * @returns {HTMLCanvasElement} The canvas returned by transformImageData.
   */
  function onnxMaskToImage(input, width, height, color) {
    const maskPixels = arrayToImageData(input, width, height, color)
    return transformImageData(maskPixels)
  }

  /**
   * Confirms the current mask: uploads tempCanvas as 'mask.png', receives the
   * contours of the mask and stores each one as a completed segmentation
   * shape. In 'erase' mode the shapes are stored under the '_delete' type
   * instead of the selected segmentation index.
   *
   * The AI-loading flag is raised for the duration of the call and is always
   * lowered again, even when the upload or the request fails, so the toolbar
   * cannot remain disabled after an error. Errors are logged, not rethrown.
   *
   * @returns {Promise<void>}
   */
  const confirmMask = async () => {
    try {
      dispatch({ type: 'setSegmentationAiLoading', payload: true })
      const blob = await getBlobFromCanvas(tempCanvas.current)
      const formData = new FormData()
      formData.append('file', blob, 'mask.png')
      // Named `response` so it does not shadow the `points` state.
      const response = await getPointsFromMask(formData)

      const index =
        labelerState.mode === 'erase' ? '_delete' : segmentationState.selectedSegmentationIndex
      response.data.forEach((contour) => {
        const id = uuidv4()
        const shape = {
          points: contour.pos,
          completed: true,
          deleted: false,
          type: index,
          id,
          hide: false
        }
        dispatch({ type: 'addSegmentationStack', payload: shape })
        dispatch({ type: 'addShape', payload: { index, id } })
      })

      dispatch({ type: 'setSaveTags' })
      dispatch({ type: 'redraw' })
      cancelPoints()
    } catch (error) {
      console.log(error)
    } finally {
      // Runs on success and on failure alike.
      dispatch({ type: 'setSegmentationAiLoading', payload: false })
    }
  }

  /**
   * Promise wrapper around `canvas.toBlob`.
   *
   * @param {HTMLCanvasElement} canvas - Canvas to encode.
   * @returns {Promise<Blob|null>} Resolves with whatever `toBlob` passes to its callback.
   */
  function getBlobFromCanvas(canvas) {
    return new Promise((resolve) => canvas.toBlob(resolve))
  }

  /**
   * Closes the toolbar (keeping its position) and discards all prompt points.
   *
   * The region is updated with a functional update because this function is
   * also called at the end of the asynchronous confirmMask, where the
   * `region` value captured by the closure may be out of date.
   */
  function cancelPoints() {
    setRegion((current) => ({ ...current, isOpened: false }))
    setPoints([])
  }

  // Renders the mask layer, the point overlay on top of it and, while the
  // region is open, the confirm/cancel toolbar (disabled while the AI-loading
  // flag is set).
  // NOTE(review): mouseMove and mouseLeave are attached here as React
  // handlers and are also registered through useCategoryEventListener above;
  // whether both fire for the same event depends on that hook — confirm.
  return (
    <>
      <canvas className="layout" ref={canvasMaskConfirmed}></canvas>
      <canvas
        className="layout"
        onMouseMove={mouseMove}
        onMouseLeave={mouseLeave}
        ref={canvasSam}
      ></canvas>
      {region.isOpened && (
        <SamButtons
          disabled={segmentationState.loadingAI}
          confirmMask={confirmMask}
          cancel={cancelPoints}
          region={region}
        />
      )}
    </>
  )
}
