import { LocalFaceRecognitionModel } from "../interfaces/LocalFaceRecognitionModel"
import { BlazeFaceModel, NormalizedFace } from "@tensorflow-models/blazeface"

/** Translucent red fill used for boxes over faces that are not the primary prediction. */
const backgroundFillStyle = 'rgba(255, 0, 0, 0.2)'
/** Translucent blue fill used for the box over the primary (largest-area) prediction. */
const primaryFillStyle = 'rgba(0, 0, 255, 0.2)'

/**
 * Face detector backed by the TensorFlow.js BlazeFace model.
 *
 * Call `load()` once, then `updatePredictions()` per frame. After each update,
 * `predictions` holds every face whose confidence exceeds `minProbability`, and
 * `primaryPrediction` holds the one with the largest bounding-box area.
 */
export default class LocalBlazeFaceModel implements LocalFaceRecognitionModel {
    /** Loaded BlazeFace model; `null` until `load()` resolves. */
    model : BlazeFaceModel | null
    /** Faces from the latest update whose probability exceeded `minProbability`. */
    predictions : NormalizedFace[]
    /** The accepted face with the largest bounding-box area, or `null` if none was accepted. */
    primaryPrediction : NormalizedFace | null
    /** Predictions must have a probability strictly greater than this to be kept. */
    readonly minProbability : number

    /**
     * @param minProbability confidence a detection must strictly exceed to be kept
     *                       (defaults to 0.95, the previously hard-coded threshold).
     */
    constructor(minProbability : number = 0.95) {
        this.model = null
        this.predictions = []
        this.primaryPrediction = null
        this.minProbability = minProbability
    }

    /** Loads the BlazeFace model. Until this resolves, `updatePredictions` is a no-op. */
    async load(): Promise<void> {
        // Loaded via require at call time rather than at module import time.
        this.model = await require('@tensorflow-models/blazeface').load()
    }

    /**
     * Runs face detection on `canvas` and refreshes `predictions` / `primaryPrediction`.
     *
     * @param canvas                  image source passed to the model.
     * @param context                 2D context used for drawing when requested.
     * @param drawPredictionsToCanvas when true, overlays the *new* predictions after detection.
     */
    async updatePredictions(canvas: HTMLCanvasElement, context: CanvasRenderingContext2D, drawPredictionsToCanvas: boolean): Promise<void> {
        if (!this.model) {
            return
        }

        // With returnTensors=false, coordinates, probability and landmarks come back as plain
        // arrays. The published NormalizedFace type does not describe this precisely, hence `any`.
        const returnTensors = false
        const rawPredictions : any[] = await this.model.estimateFaces(canvas, returnTensors)

        const accepted : NormalizedFace[] = []
        let primary : NormalizedFace | null = null
        let maxArea = 0

        for (const prediction of rawPredictions) {
            const probability = LocalBlazeFaceModel.probabilityOf(prediction)
            // `>` is false for NaN, so malformed scores are rejected as well.
            const confident = probability !== undefined && probability > this.minProbability
            if (!confident) {
                continue
            }

            accepted.push(prediction)
            const area = LocalBlazeFaceModel.boxArea(prediction)
            if (area > maxArea) {
                maxArea = area
                primary = prediction
            }
        }

        this.predictions = accepted
        this.primaryPrediction = primary

        // Draw only after detection so the overlay is never fed back into the model
        // and always reflects the predictions just computed.
        if (drawPredictionsToCanvas) {
            this.drawPredictions(context)
        }
    }

    /** Draws every non-primary prediction, then the primary one on top. */
    drawPredictions(context: CanvasRenderingContext2D) {
        for (const prediction of this.predictions) {
            if (prediction !== this.primaryPrediction) {
                this.drawPrediction(prediction, context, backgroundFillStyle)
            }
        }
        // Drawn last so it is not covered by overlapping background boxes.
        if (this.primaryPrediction) {
            this.drawPrediction(this.primaryPrediction, context, primaryFillStyle)
        }
    }

    /**
     * Fills the prediction's bounding box with `fillStyle`, then marks each landmark
     * (if any) with a small white dot.
     */
    drawPrediction(prediction: any, context: CanvasRenderingContext2D, fillStyle: string) {
        const [x0, y0] = prediction.topLeft
        const [x1, y1] = prediction.bottomRight

        context.fillStyle = fillStyle
        context.fillRect(x0, y0, x1 - x0, y1 - y0)

        context.fillStyle = 'rgba(255, 255, 255, 0.5)'
        const landmarks = prediction.landmarks
        // Landmarks may be absent; skip them instead of throwing.
        if (!Array.isArray(landmarks)) {
            return
        }
        for (const [lx, ly] of landmarks) {
            context.beginPath()
            context.arc(lx, ly, 2, 0, 2 * Math.PI, true)
            context.fill()
        }
    }

    /** True when the last update produced a primary prediction. */
    hasValidFace() : boolean {
        return !!this.primaryPrediction
    }

    /** Extracts the scalar confidence from an array (`[p]`) or a plain number; otherwise undefined. */
    private static probabilityOf(prediction: any): number | undefined {
        const probability = prediction.probability
        if (Array.isArray(probability)) {
            return probability[0]
        }
        return typeof probability === 'number' ? probability : undefined
    }

    /** Axis-aligned bounding-box area: (width) * (height), parenthesised correctly. */
    private static boxArea(prediction: any): number {
        const [x0, y0] = prediction.topLeft
        const [x1, y1] = prediction.bottomRight
        return Math.abs((x1 - x0) * (y1 - y0))
    }
}
