import vtkActor from '@kitware/vtk.js/Rendering/Core/Actor';
import vtkRenderer from '@kitware/vtk.js/Rendering/Core/Renderer';
import {mat4, type quat, vec3} from 'gl-matrix';
import {type RefObject} from 'react';

import {type DigitalTwinData, type Vector3D} from '@/library/digital-twins';
import {getViewports} from '@/state/viewports';

/**
 * Decomposes the rotation part of a 4×4 matrix into X/Y/Z Euler angles in degrees.
 *
 * The index formulas decode a rotation composed as R = Rx(x)·Ry(y)·Rz(z) whose
 * 3×3 rows live at indices 0–2, 4–6 and 8–10 (row-major). NOTE(review): this
 * presumes `matrix` is row-major like `actor.getMatrix()` (whose translation
 * `getPosition` reads from indices 3/7/11) — a column-major gl-matrix `mat4`
 * would decode incorrectly.
 *
 * Near gimbal lock (|matrix[2]| ≈ 1, i.e. y ≈ ±90°) the split between the X and Z
 * angles is ill-conditioned, because both atan2 arguments approach zero.
 *
 * @param matrix - 4×4 matrix whose upper-left 3×3 block is a rotation.
 * @param rangerot2 - When 0 (default), all three angles are computed. Any other
 *   value computes only the Y angle and returns X and Z as 0.
 *   NOTE(review): this looks like an unfinished switch for the alternative
 *   solution branch (y outside [-90°, 90°]) — confirm intent.
 * @returns `{x, y, z}` rotation angles in degrees; y is in [-90, 90].
 */
export function decomposeXYZRotationMatrix(
	matrix: mat4,
	rangerot2 = 0,
): Vector3D {
	const RAD2DEG = 180 / Math.PI;

	let rotX = 0;
	let rotZ = 0;
	if (rangerot2 === 0) {
		rotX = Math.atan2(-matrix[6], matrix[10]); // X rotation: -R12 / R22
		rotZ = Math.atan2(-matrix[1], matrix[0]); // Z rotation: -R01 / R00
	}

	// R02 = sin(y). Floating-point noise can push it just outside [-1, 1], where
	// Math.asin returns NaN, so clamp it first.
	const sinY = Math.min(1, Math.max(-1, matrix[2]));
	const rotY = Math.asin(sinY);

	return {
		x: rotX * RAD2DEG,
		y: rotY * RAD2DEG,
		z: rotZ * RAD2DEG,
	};
}

/**
 * Returns the direction of one of the actor's local axes, read from
 * `selectedActorData.matrix`.
 *
 * For axis index k (x = 0, y = 1, z = 2), the elements at k, k + 4 and k + 8 are
 * returned. That is column k of a row-major matrix, or row k of a column-major one.
 * NOTE(review): the layout of `DigitalTwinData.matrix` is not visible here — confirm.
 *
 * @param axis - Which local axis to extract.
 * @param selectedActorData - Digital-twin record carrying the 4×4 matrix.
 * @returns The axis vector (not normalised by this function).
 * @throws Error `Invalid axis …` for any value other than 'x', 'y' or 'z'.
 */
export function getAxisVector(
	axis: 'x' | 'y' | 'z',
	selectedActorData: DigitalTwinData,
): Vector3D {
	const m = selectedActorData.matrix;
	const offset = ['x', 'y', 'z'].indexOf(axis);

	if (offset === -1) {
		throw new Error(`Invalid axis ${axis}`);
	}

	return {x: m[offset], y: m[offset + 4], z: m[offset + 8]};
}

/**
 * Returns the actor's translation, taken from indices 3, 7 and 11 of
 * `actor.getMatrix()`.
 *
 * Those are the last-column entries of a row-major 4×4 matrix. NOTE(review):
 * presumes `getMatrix()` is row-major (unlike the gl-matrix user matrix written
 * by `setPosition` at 12–14) — verify against the vtk.js version in use.
 */
export function getPosition(actor: vtkActor): Vector3D {
	const worldMatrix = actor.getMatrix();
	const [x, y, z] = [worldMatrix[3], worldMatrix[7], worldMatrix[11]];
	return {x, y, z};
}

/**
 * Moves an actor by overwriting the translation entries of its user matrix.
 *
 * Indices 12–14 are the translation slots of a column-major (gl-matrix) 4×4
 * matrix. All other user-matrix entries (rotation, scale) are kept as they are.
 *
 * The edited matrix is written back with `setUserMatrix` instead of relying on
 * an in-place mutation of the getter's return value. That way the update takes
 * effect whether `getUserMatrix()` returns the internal array or a copy.
 *
 * @param actor - Actor whose user matrix is updated.
 * @param position - New translation in world coordinates.
 */
export function setPosition(actor: vtkActor, position: Vector3D) {
	const userMatrix = actor.getUserMatrix();
	userMatrix[12] = position.x;
	userMatrix[13] = position.y;
	userMatrix[14] = position.z;
	actor.setUserMatrix(userMatrix);
	// Keep the explicit modified() call, so the actor is marked dirty even if
	// setUserMatrix treats an unchanged-looking matrix as a no-op (e.g. when the
	// getter returned the live array we just edited).
	actor.modified();
}

/**
 * Returns the actor's rotation as X/Y/Z Euler angles in degrees, decoded from
 * `actor.getMatrix()` with `decomposeXYZRotationMatrix` (default branch).
 */
export function getRotation(actor: vtkActor): Vector3D {
	const worldMatrix = actor.getMatrix();
	const angles = decomposeXYZRotationMatrix(worldMatrix);
	return angles;
}

/**
 * Replaces the actor's rotation with `quaternion` while keeping its current
 * translation.
 *
 * A fresh matrix is built from the quaternion, and the translation returned by
 * `getPosition` is placed in its translation slots (12–14). The result is
 * installed with `setTransformationMatrix`. Anything else in the previous
 * transform (e.g. scale) is not carried over.
 */
export function setRotation(actor: vtkActor, quaternion: quat) {
	const {x, y, z} = getPosition(actor);

	const transform = mat4.fromQuat(mat4.create(), quaternion);
	[transform[12], transform[13], transform[14]] = [x, y, z];

	setTransformationMatrix(actor, transform);
}

/**
 * Converts a mouse event's client coordinates into pixel coordinates of the
 * volume viewport's render window.
 *
 * The pointer position is taken relative to `container`'s bounding box. It is
 * rescaled from CSS pixels to the render window size reported by
 * `openGLRenderWindow.getSize()`. The Y axis is flipped so that y grows upward
 * from the container's bottom edge (presumably to match VTK display coordinates).
 *
 * @param container - Ref to the element hosting the render window canvas.
 * @param source - Mouse event to convert.
 * @returns Rounded pixel coordinates. Returns `{x: 0, y: 0}` when the container
 *   or render window is unavailable, or when the container has zero width or
 *   height.
 */
export const getScreenEventPositionFor = (
	container: RefObject<HTMLElement>,
	source: MouseEvent,
) => {
	const {
		volume: {openGLRenderWindow},
	} = getViewports();

	const containerElement = container.current;
	if (!containerElement || !openGLRenderWindow) {
		return {x: 0, y: 0};
	}

	const containerRect = containerElement.getBoundingClientRect();
	// An element that is not laid out (e.g. display: none) reports a zero-sized
	// rect. Dividing by it would produce Infinity/NaN coordinates, so fall back
	// to the same neutral result as above.
	if (containerRect.width <= 0 || containerRect.height <= 0) {
		return {x: 0, y: 0};
	}

	const [canvasWidth, canvasHeight] = openGLRenderWindow.getSize();

	// CSS pixels -> render window pixels (differs e.g. under devicePixelRatio scaling)
	const scaleX = canvasWidth / containerRect.width;
	const scaleY = canvasHeight / containerRect.height;

	// Pointer position relative to the container's top-left corner
	const relativeX = source.clientX - containerRect.left;
	const relativeY = source.clientY - containerRect.top;

	// Apply scaling and flip Y so it is measured from the bottom edge
	return {
		x: Math.round(relativeX * scaleX),
		y: Math.round((containerRect.height - relativeY) * scaleY),
	};
};

/**
 * Makes `matrix` the actor's entire transform.
 *
 * The actor's own position, orientation and scale are first reset to neutral
 * values (0 / 0 / 1). Then `matrix` is handed to `setUserMatrix`, and the actor
 * is marked modified.
 *
 * @param actor - Actor to update.
 * @param matrix - Transform to install as the user matrix.
 */
export function setTransformationMatrix(actor: vtkActor, matrix: mat4) {
	const zeros: [number, number, number] = [0, 0, 0];
	const ones: [number, number, number] = [1, 1, 1];

	actor.setPosition(...zeros);
	actor.setOrientation(...zeros);
	actor.setScale(...ones);

	actor.setUserMatrix(matrix);
	actor.modified();
}

/**
 * Intersects the line `cameraPosition + t·ray` with the plane through
 * `planeOrigin` whose normal is `planeNormal`, and returns the hit point.
 *
 * - `t` is not restricted to positive values, so a plane behind the camera still
 *   yields a point.
 * - Neither `ray` nor `planeNormal` needs to be unit length; their magnitudes
 *   cancel out of the result.
 * - If `ray` is parallel to the plane (ray · normal === 0), there is no unique
 *   intersection. The returned components are then non-finite (±Infinity or
 *   NaN), so callers should check them with `Number.isFinite`.
 *
 * Computed with plain double-precision arithmetic. The previous implementation
 * wrote intermediates into `vec3.create()` buffers, which are Float32Arrays, and
 * that truncated world coordinates.
 *
 * @param cameraPosition - Ray origin (world coordinates).
 * @param ray - Ray direction (world coordinates).
 * @param planeOrigin - Any point on the plane.
 * @param planeNormal - Plane normal.
 * @returns The intersection point.
 */
export function rayPlaneIntersection(
	cameraPosition: vec3,
	ray: Vector3D,
	planeOrigin: Vector3D,
	planeNormal: Vector3D,
): Vector3D {
	const cx = cameraPosition[0];
	const cy = cameraPosition[1];
	const cz = cameraPosition[2];

	// ray · n — zero when the ray is parallel to the plane (see doc comment)
	const denominator =
		ray.x * planeNormal.x + ray.y * planeNormal.y + ray.z * planeNormal.z;

	// (planeOrigin - camera) · n
	const numerator =
		(planeOrigin.x - cx) * planeNormal.x +
		(planeOrigin.y - cy) * planeNormal.y +
		(planeOrigin.z - cz) * planeNormal.z;

	const t = numerator / denominator;

	return {
		x: cx + ray.x * t,
		y: cy + ray.y * t,
		z: cz + ray.z * t,
	};
}

/**
 * Builds a world-space ray direction from the active camera's position through
 * the given display pixel.
 *
 * The pixel is un-projected at the same normalized display depth as the camera's
 * focal point. The returned vector is that world point minus the camera
 * position, and it is NOT normalised.
 *
 * @param screenX - Display x in render-window pixels.
 * @param screenY - Display y in render-window pixels. NOTE(review): presumably
 *   bottom-left origin, as produced by `getScreenEventPositionFor` — verify.
 * @param ren - Renderer whose active camera and first render-window view are used.
 * @returns Direction vector from the camera position toward the pixel.
 * @throws Error if the renderer has no render window or the window has no view.
 */
export function getScreenPositionRay(
	screenX: number,
	screenY: number,
	ren: vtkRenderer,
): Vector3D {
	// Camera location and focal point in world coordinates
	const camera = ren.getActiveCamera();
	const cameraPos = camera.getPosition();
	const cameraFocalPoint = camera.getFocalPoint();

	// The first view of the render window performs the display <-> normalized
	// display conversions. Fail with a clear message instead of a TypeError on
	// `undefined` when the renderer is not attached to a window.
	const view = ren.getRenderWindow()?.getViews()[0];
	if (!view) {
		throw new Error(
			'getScreenPositionRay: renderer is not attached to a render window view',
		);
	}

	const dimensions = view.getViewportSize(ren);
	const aspect = dimensions[0] / dimensions[1];

	// Depth of the focal point, used as the depth of the picked pixel
	const focalNormalized = ren.worldToNormalizedDisplay(
		cameraFocalPoint[0],
		cameraFocalPoint[1],
		cameraFocalPoint[2],
		aspect,
	);
	const focalDisplay = view.normalizedDisplayToDisplay(
		focalNormalized[0],
		focalNormalized[1],
		focalNormalized[2],
	);

	// Un-project the pixel at that depth back into world coordinates
	const normalizedDisplay = view.displayToNormalizedDisplay(
		screenX,
		screenY,
		focalDisplay[2],
		aspect,
	);
	const worldCoords = ren.normalizedDisplayToWorld(
		normalizedDisplay[0],
		normalizedDisplay[1],
		normalizedDisplay[2],
		aspect,
	);

	return {
		x: worldCoords[0] - cameraPos[0],
		y: worldCoords[1] - cameraPos[1],
		z: worldCoords[2] - cameraPos[2],
	};
}
