import vtkDataArray from '@kitware/vtk.js/Common/Core/DataArray';
import vtkPoints from '@kitware/vtk.js/Common/Core/Points';
import vtkPolyData from '@kitware/vtk.js/Common/DataModel/PolyData';
import {TypedArray} from '@kitware/vtk.js/types';
import {mat4, vec3} from 'gl-matrix';

import ActorEntry from '@/library/actors';
import {Bone, bones} from '@/library/bones';
import {
	type DigitalTwinData,
	type DigitalTwinPolygonType,
	type ReamData,
	type Vector3D,
} from '@/library/digital-twins';
import {getAxisVector, getPosition} from '@/library/vtk/math';
import {PointCloudColorScheme} from '@/library/vtk/point-cloud-color-schemes';
import {getViewports} from '@/state/viewports';
import {useDensityMapStore} from '@/state/density-map';
import {useDigitalTwinsStore} from '@/state/digital-twins';
import {useResectionPlanesStore} from '@/state/resection-planes';
import {useVisibilityStore} from '@/state/visibility';

import {type Coordinate, type ResectionPlaneKey, type Vector3} from '@/types';

/* eslint-disable-next-line import/no-cycle --
 * TODO: Break this circular dependency
 */
import {calculateGlobalRange, makeLut} from '@/utils/draco-point-cloud-reader';

/**
 * Returns the live point-cloud actor entry for `bone` together with its backup.
 *
 * @throws Error if any bone's live or backup point-cloud entry is missing from
 *   the density-map store (raised by `getPointCloudAndBackupActors`).
 */
const getPointCloudActorEntryAndBackup = ({
	bone,
}: {
	bone: Bone;
}): {actorEntry: ActorEntry; backupActorEntry: ActorEntry} => {
	const actors = getPointCloudAndBackupActors();

	// Each bone maps to its [live, backup] pair of actor entries.
	const pairsByBone: Record<Bone, [ActorEntry, ActorEntry]> = {
		femur: [actors.femoralActor, actors.femoralActorBackup],
		fibula: [actors.fibularActor, actors.fibularActorBackup],
		patella: [actors.patellarActor, actors.patellarActorBackup],
		tibia: [actors.tibialActor, actors.tibialActorBackup],
	};

	const [actorEntry, backupActorEntry] = pairsByBone[bone];

	return {actorEntry, backupActorEntry};
};

/**
 * Looks up the live and backup point-cloud actor entries of every bone in the
 * density-map store, matching on the entries' `id`.
 *
 * @returns All eight entries, named `<bone>Actor` / `<bone>ActorBackup`.
 * @throws Error for the first missing entry, checked in femoral, tibial,
 *   patellar, fibular order with the live entry before its backup.
 */
const getPointCloudAndBackupActors = () => {
	const {pointClouds, pointCloudsBackup} = useDensityMapStore.getState();

	// Returns the entry whose `id` equals `id`, or throws `message` if none does.
	function requireEntry<TEntry extends {id?: string}>(
		entries: readonly TEntry[],
		id: string,
		message: string,
	): TEntry {
		const match = entries.find((entry) => entry.id === id);

		if (!match) {
			throw new Error(message);
		}

		return match;
	}

	const femoralActor = requireEntry(
		pointClouds,
		'femur_pointcloud',
		'Femoral actor not found',
	);
	const femoralActorBackup = requireEntry(
		pointCloudsBackup,
		'femur_pointcloud',
		'Femoral actor backup not found',
	);
	const tibialActor = requireEntry(
		pointClouds,
		'tibia_pointcloud',
		'Tibial actor not found',
	);
	const tibialActorBackup = requireEntry(
		pointCloudsBackup,
		'tibia_pointcloud',
		'Tibial actor backup not found',
	);
	const patellarActor = requireEntry(
		pointClouds,
		'patella_pointcloud',
		'Patellar actor not found',
	);
	const patellarActorBackup = requireEntry(
		pointCloudsBackup,
		'patella_pointcloud',
		'Patellar actor backup not found',
	);
	const fibularActor = requireEntry(
		pointClouds,
		'fibula_pointcloud',
		'Fibular actor not found',
	);
	const fibularActorBackup = requireEntry(
		pointCloudsBackup,
		'fibula_pointcloud',
		'Fibular actor backup not found',
	);

	return {
		femoralActor,
		femoralActorBackup,
		fibularActor,
		fibularActorBackup,
		patellarActor,
		patellarActorBackup,
		tibialActor,
		tibialActorBackup,
	};
};

/**
 * Replaces the live actor's mapper input with a new `vtkPolyData` that is
 * shallow-copied from the backup actor's mapper input, so subsequent filtering
 * starts again from the backup's points.
 *
 * @throws Error if either entry has no actor, mapper, or mapper input data
 *   (the live entry is checked first).
 */
function restoreActorFromBackup({
	actorEntry,
	backupActorEntry,
}: {
	actorEntry: ActorEntry;
	backupActorEntry: ActorEntry;
}) {
	if (!actorEntry.actor?.getMapper()?.getInputData()) {
		throw new Error(`Missing properties on actorEntry`);
	}

	if (!backupActorEntry.actor?.getMapper()?.getInputData()) {
		throw new Error(`Missing properties on backupActorEntry`);
	}

	const liveMapper = actorEntry.actor.getMapper();
	const restoredPolyData = vtkPolyData.newInstance();

	restoredPolyData.shallowCopy(
		backupActorEntry.actor.getMapper()?.getInputData(),
	);
	liveMapper?.setInputData(restoredPolyData);
}

/**
 * Re-applies point visibility filtering to the point cloud of every bone in
 * `bones`, one bone at a time.
 */
export function hidePointsInPointClouds() {
	for (const boneToRefresh of bones) {
		hidePointsInPointCloud(boneToRefresh);
	}
}

/**
 * Recomputes which points of one bone's point cloud are shown.
 *
 * The live actor's input data is first reset from its backup. It is then
 * filtered in up to two stages, each operating on the output of the previous:
 *
 * 1. If resection planes are visible, each plane whose `pair` equals
 *    `selectedPairKey` and whose `plane` equals the key mapped from the actor
 *    entry's label ('femur' -> 'femoral', 'tibia' -> 'tibial') keeps only the
 *    points strictly on the side its normal points to. Planes are applied one
 *    after another.
 * 2. If digital twins are visible:
 *    - in `remaining` mode, each digital twin of this bone removes the points
 *      inside it, applied one twin after another;
 *    - in `removed` mode, only points inside at least one digital twin of this
 *      bone are kept, de-duplicated by exact coordinates.
 *
 * @param bone - The bone whose point cloud is refreshed.
 * @throws Error if any bone's live or backup point-cloud entry is missing from
 *   the density-map store, or if this bone's entries have no input data.
 */
export function hidePointsInPointCloud(bone: Bone) {
	const {digitalTwins: digitalTwinsVisibility} = useVisibilityStore.getState();
	const {digitalTwins, mode} = useDigitalTwinsStore.getState();
	const {
		resectionPlanes,
		selectedPairKey,
		visibility: resectionPlanesVisibility,
	} = useResectionPlanesStore.getState();
	const {actorEntry, backupActorEntry} = getPointCloudActorEntryAndBackup({
		bone,
	});
	restoreActorFromBackup({actorEntry, backupActorEntry});

	const boneLabelToResectionPlaneMap: Record<
		'femur' | 'tibia',
		ResectionPlaneKey
	> = {
		femur: 'femoral',
		tibia: 'tibial',
	};

	// Reads the actor's current (possibly already filtered) points and scalars.
	const readCurrentPointsAndScalars = (): {
		points: vtkPoints;
		scalars: number[] | TypedArray;
	} => {
		const polyData = actorEntry.actor.getMapper()?.getInputData();

		return {
			points: polyData.getPoints(),
			scalars: polyData.getPointData().getScalars().getData(),
		};
	};

	// Filters the given points/scalars against one digital twin's shape.
	const filterByDigitalTwin = ({
		digitalTwin,
		filterPerspective,
		startingPoints,
		startingScalars,
	}: {
		digitalTwin: (typeof digitalTwins)[number];
		filterPerspective: 'inside' | 'outside';
		startingPoints: vtkPoints;
		startingScalars: number[] | TypedArray;
	}) => {
		const position = getPosition(digitalTwin.actor);

		return findPointsAndScalarsRelativeToShape({
			digitalTwin,
			shape: digitalTwin.type,
			center: [position.x, position.y, position.z],
			radius: (digitalTwin as ReamData).radius,
			startingPoints,
			startingScalars,
			filterPerspective,
			...(digitalTwin.type === 'drill' && {
				height: digitalTwin.height,
				direction: digitalTwin.directions.z,
			}),
			...(digitalTwin.type === 'resect' && {
				depth: digitalTwin.depth,
				height: digitalTwin.height,
				width: digitalTwin.width,
				rotation: digitalTwin.rotation,
			}),
		});
	};

	if (resectionPlanesVisibility) {
		for (const resectionPlanePair of Object.values(resectionPlanes)) {
			for (const resectionPlane of Object.values(resectionPlanePair).filter(
				(resectionPlane) =>
					resectionPlane.pair === selectedPairKey &&
					resectionPlane.plane ===
						boneLabelToResectionPlaneMap[actorEntry.label as 'femur' | 'tibia'],
			)) {
				const {origin, normal} = resectionPlane;
				// Re-read per plane so each cut applies to the previous cut's output.
				const {points: backupPoints, scalars: backupScalars} =
					readCurrentPointsAndScalars();

				const {points, scalars} = findPointsAndScalarsRelativeToPlane({
					backupPoints,
					backupScalars,
					planeOrigin: origin,
					planeNormal: normal,
				});

				showOnlySelectedPoints({actorEntry, points, scalars});
			}
		}
	}

	if (!digitalTwinsVisibility) {
		return;
	}

	const digitalTwinsOfBone = digitalTwins.filter(
		(digitalTwin) => digitalTwin.bone === bone,
	);

	if (mode === 'remaining') {
		// Each twin carves its volume out of what the previous twins left.
		for (const digitalTwin of digitalTwinsOfBone) {
			const {points: startingPoints, scalars: startingScalars} =
				readCurrentPointsAndScalars();

			const {points, scalars} = filterByDigitalTwin({
				digitalTwin,
				filterPerspective: 'outside',
				startingPoints,
				startingScalars,
			});

			showOnlySelectedPoints({actorEntry, points, scalars});
		}
	} else if (mode === 'removed') {
		// Every twin is tested against the same starting points; the union of
		// their interiors is shown.
		const {points: startingPoints, scalars: startingScalars} =
			readCurrentPointsAndScalars();
		const uniquePointIdentifiers = new Set<string>();
		const accumulatedPoints: number[] = [];
		const accumulatedScalars: number[] = [];

		for (const digitalTwin of digitalTwinsOfBone) {
			const {points, scalars} = filterByDigitalTwin({
				digitalTwin,
				filterPerspective: 'inside',
				startingPoints,
				startingScalars,
			});

			for (let index = 0; index < points.length; index += 3) {
				// NOTE(review): keys are exact coordinates, so coincident points in
				// the source cloud also collapse into one — confirm that is intended.
				const pointString = `${points[index]},${points[index + 1]},${
					points[index + 2]
				}`;

				if (!uniquePointIdentifiers.has(pointString)) {
					uniquePointIdentifiers.add(pointString);
					accumulatedPoints.push(
						points[index],
						points[index + 1],
						points[index + 2],
					);
					accumulatedScalars.push(scalars[index / 3]);
				}
			}
		}

		showOnlySelectedPoints({
			actorEntry,
			points: new Float32Array(accumulatedPoints),
			scalars: new Float32Array(accumulatedScalars),
		});
	}
}

/**
 * Keeps only the points lying strictly on the side of a plane that its normal
 * points towards.
 *
 * A point p is kept when (p - planeOrigin) · planeNormal > 0. Points exactly on
 * the plane, and points whose test is NaN, are dropped. Scalars are filtered in
 * lockstep, so `scalars[i]` belongs to the point at `points[3 * i .. 3 * i + 2]`.
 *
 * @returns Float32Array views trimmed to the kept points.
 */
function findPointsAndScalarsRelativeToPlane({
	backupPoints,
	backupScalars,
	planeOrigin,
	planeNormal,
}: {
	backupPoints: vtkPoints;
	backupScalars: number[] | TypedArray;
	planeOrigin: Vector3;
	planeNormal: Vector3;
}): {
	points: Float32Array;
	scalars: Float32Array;
} {
	const [originX, originY, originZ] = planeOrigin;
	const [normalX, normalY, normalZ] = planeNormal;
	const pointTotal = backupPoints.getNumberOfValues() / 3;

	const keptPoints = new Float32Array(pointTotal * 3);
	const keptScalars = new Float32Array(pointTotal);
	let keptCount = 0;

	for (let pointIndex = 0; pointIndex < pointTotal; pointIndex += 1) {
		const [x, y, z] = backupPoints.getPoint(pointIndex);
		const planeSide =
			(x - originX) * normalX + (y - originY) * normalY + (z - originZ) * normalZ;

		if (planeSide > 0) {
			const offset = keptCount * 3;
			keptPoints[offset] = x;
			keptPoints[offset + 1] = y;
			keptPoints[offset + 2] = z;
			keptScalars[keptCount] = backupScalars[pointIndex];
			keptCount += 1;
		}
	}

	return {
		points: keptPoints.subarray(0, keptCount * 3),
		scalars: keptScalars.subarray(0, keptCount),
	};
}

/**
 * Rotates `point` about `center` by the negated Euler angles in `rotation`.
 *
 * The angles are given in degrees and converted to radians. The transform is
 * built as T(center) · Rz(-z) · Ry(-y) · Rx(-x) · T(-center) using gl-matrix's
 * post-multiplying `translate`/`rotate*`. The result is therefore
 * center + Rz(-z)·Ry(-y)·Rx(-x)·(point - center).
 *
 * Note that all three angles are negated, not only X. Because
 * (Rx(x)·Ry(y)·Rz(z))⁻¹ = Rz(-z)·Ry(-y)·Rx(-x), this undoes a rotation
 * Rx(x)·Ry(y)·Rz(z) applied about `center`.
 */
function transformPointRelativeToRotation(
	point: Vector3,
	center: Vector3,
	rotation: Vector3D,
): Vector3 {
	const radiansPerDegree = Math.PI / 180;
	const pivot = center as vec3;
	const transform = mat4.create();

	mat4.translate(transform, transform, pivot);
	mat4.rotateZ(transform, transform, -rotation.z * radiansPerDegree);
	mat4.rotateY(transform, transform, -rotation.y * radiansPerDegree);
	mat4.rotateX(transform, transform, -rotation.x * radiansPerDegree);
	mat4.translate(transform, transform, vec3.negate(vec3.create(), pivot));

	const mappedPoint = vec3.create();
	vec3.transformMat4(mappedPoint, point as vec3, transform);

	return [...mappedPoint] as Vector3;
}

/**
 * Filters `startingPoints` (and their scalars, in lockstep) against the volume
 * of a digital-twin shape.
 *
 * - `drill`: a cylinder around `center` along the twin's z axis (from
 *   `getAxisVector('z', digitalTwin)`), with half-length `height / 2` and
 *   radius `radius`. It requires truthy `height` and `direction`;
 *   `direction` only gates this branch and its value is otherwise unused.
 *   NOTE(review): assumes `getAxisVector` returns a unit vector, otherwise the
 *   distances are scaled.
 * - `ream`: a sphere of radius `radius` about `center`.
 * - `resect`: a box of `width` (x) × `depth` (y) × `height` (z) centred on
 *   `center`, tested after `transformPointRelativeToRotation(point, center,
 *   rotation)`. It requires truthy `width`, `height`, `depth` and `rotation`.
 *
 * A shape that is unrecognised or lacks its required dimensions encloses no
 * points. `'inside'` then yields nothing and `'outside'` keeps every point.
 *
 * @returns Float32Array views trimmed to the kept points; `scalars[i]` belongs
 *   to the point at `points[3 * i .. 3 * i + 2]`.
 */
function findPointsAndScalarsRelativeToShape({
	digitalTwin,
	center,
	depth,
	direction,
	filterPerspective = 'inside',
	height,
	startingPoints,
	startingScalars,
	radius,
	rotation,
	shape,
	width,
}: {
	digitalTwin: DigitalTwinData;
	center: Vector3;
	depth?: number;
	direction?: Coordinate;
	filterPerspective: 'inside' | 'outside';
	height?: number;
	startingPoints: vtkPoints;
	startingScalars: number[] | TypedArray;
	radius: number;
	rotation?: Vector3D;
	shape: DigitalTwinPolygonType;
	width?: number;
}): {
	points: Float32Array;
	scalars: Float32Array;
} {
	const isInsideShape = createShapeContainmentTest({
		center,
		depth,
		digitalTwin,
		direction,
		height,
		radius,
		rotation,
		shape,
		width,
	});
	const keepInsidePoints = filterPerspective === 'inside';

	const numberOfPoints = startingPoints.getNumberOfValues() / 3;
	const filteredPointsArray = new Float32Array(numberOfPoints * 3);
	const filteredScalarsArray = new Float32Array(numberOfPoints);
	let count = 0;

	for (let index = 0; index < numberOfPoints; index++) {
		const point = startingPoints.getPoint(index) as Vector3;

		if (isInsideShape(point) === keepInsidePoints) {
			filteredPointsArray.set(point, count * 3);
			filteredScalarsArray[count] = startingScalars[index];
			count++;
		}
	}

	return {
		points: filteredPointsArray.subarray(0, count * 3),
		scalars: filteredScalarsArray.subarray(0, count),
	};
}

/**
 * Builds a point-containment predicate for a digital-twin shape. It returns a
 * predicate that is always false when the shape is unrecognised or missing
 * required dimensions.
 */
function createShapeContainmentTest({
	center,
	depth,
	digitalTwin,
	direction,
	height,
	radius,
	rotation,
	shape,
	width,
}: {
	center: Vector3;
	depth?: number;
	digitalTwin: DigitalTwinData;
	direction?: Coordinate;
	height?: number;
	radius: number;
	rotation?: Vector3D;
	shape: DigitalTwinPolygonType;
	width?: number;
}): (point: Vector3) => boolean {
	if (shape === 'drill' && height && direction) {
		const normalizedAxis = getAxisVector('z', digitalTwin);
		// Loop-invariant: built once instead of twice per point.
		const axis: [number, number, number] = [
			normalizedAxis.x,
			normalizedAxis.y,
			normalizedAxis.z,
		];
		const halfHeight = height / 2;
		// Scratch vectors reused across points (same float32 storage as before).
		const vectorToPoint = vec3.create();
		const crossProduct = vec3.create();

		return (point) => {
			vec3.set(
				vectorToPoint,
				point[0] - center[0],
				point[1] - center[1],
				point[2] - center[2],
			);
			const projectionDistance = Math.abs(vec3.dot(vectorToPoint, axis));
			vec3.cross(crossProduct, vectorToPoint, axis);
			const perpendicularDistance = vec3.length(crossProduct);

			return projectionDistance <= halfHeight && perpendicularDistance <= radius;
		};
	}

	if (shape === 'ream') {
		const radiusSquared = radius * radius;

		return (point) =>
			(point[0] - center[0]) ** 2 +
				(point[1] - center[1]) ** 2 +
				(point[2] - center[2]) ** 2 <=
			radiusSquared;
	}

	if (shape === 'resect' && width && height && depth && rotation) {
		const halfWidth = width / 2; // X-axis
		const halfDepth = depth / 2; // Y-axis
		const halfHeight = height / 2; // Z-axis
		// Const copy keeps the narrowed (defined) type inside the closure.
		const resectRotation: Vector3D = rotation;

		return (point) => {
			const transformedPoint = transformPointRelativeToRotation(
				point,
				center,
				resectRotation,
			);

			return (
				Math.abs(transformedPoint[0] - center[0]) <= halfWidth && // Width check on X-axis
				Math.abs(transformedPoint[1] - center[1]) <= halfDepth && // Depth check on Y-axis
				Math.abs(transformedPoint[2] - center[2]) <= halfHeight // Height check on Z-axis
			);
		};
	}

	// Unrecognised shape or missing/zero dimensions: the shape encloses nothing.
	return () => false;
}

/**
 * Replaces the points of the actor's current mapper input with `points`
 * (packed xyz triples). It also replaces the point scalars with a new
 * single-component array named 'Scalars' holding `scalars`.
 *
 * NOTE(review): cells of the poly data (e.g. verts) are not updated here —
 * confirm they never reference point indices beyond the new point count.
 */
function showOnlySelectedPoints({
	actorEntry,
	points,
	scalars,
}: {
	actorEntry: ActorEntry;
	points: Float32Array;
	scalars: Float32Array;
}): void {
	const targetPolyData = actorEntry.actor.getMapper()?.getInputData();

	const pointContainer = vtkPoints.newInstance();
	pointContainer.setData(points, 3);

	const scalarArray = vtkDataArray.newInstance({
		numberOfComponents: 1,
		values: scalars,
		name: 'Scalars',
	});

	targetPolyData.setPoints(pointContainer);
	targetPolyData.getPointData().setScalars(scalarArray);
}

/**
 * Builds one lookup table for `colorScheme` over the combined range of all
 * point clouds (via `calculateGlobalRange`/`makeLut`). It assigns that table to
 * every point cloud's mapper and re-renders the volume viewport.
 *
 * Before being replaced, each mapper's existing lookup table has its range set
 * to the global range. The shared table's threshold range is set to
 * `[threshold.lower, threshold.upper]`. When `colorScheme.mapColorsToThreshold`
 * is set, its mapping range is narrowed to the same window. Only `lower` and
 * `upper` are read from `threshold`.
 *
 * @returns The shared lookup table and the global range it was built from.
 */
export function setPointCloudColorScheme({
	colorScheme,
	threshold,
}: {
	colorScheme: PointCloudColorScheme;
	threshold: {
		lower: number;
		// eslint-disable-next-line @typescript-eslint/no-explicit-any
		lut?: any;
		max: number;
		min: number;
		upper: number;
	};
}): {
	globalLut: ReturnType<typeof makeLut>;
	globalRange: ReturnType<typeof calculateGlobalRange>;
} {
	const {pointClouds} = useDensityMapStore.getState();

	const globalRange = calculateGlobalRange(pointClouds.map(({actor}) => actor));
	const globalLut = makeLut(globalRange, colorScheme);

	for (const pointCloud of pointClouds) {
		const mapper = pointCloud.actor.getMapper();
		mapper?.getLookupTable().setRange([globalRange.min, globalRange.max]);
		mapper?.setLookupTable(globalLut);
	}

	globalLut.setThresholdRange(threshold.lower, threshold.upper);

	if (colorScheme.mapColorsToThreshold) {
		globalLut.setMappingRange(threshold.lower, threshold.upper);
	}

	getViewports().volume.render();

	return {globalLut, globalRange};
}

/**
 * Sets the opacity of every point-cloud actor and re-renders the volume
 * viewport.
 *
 * @param opacity - Divided by 100 before being applied, so presumably a
 *   percentage in 0–100 (confirm against callers).
 */
export function setPointCloudOpacity(opacity: number) {
	const actorOpacity = opacity / 100;

	for (const pointCloud of useDensityMapStore.getState().pointClouds) {
		pointCloud.actor.getProperty().setOpacity(actorOpacity);
	}

	getViewports().volume.render();
}

/**
 * Applies the threshold window `[lower, upper]` to `lut` and re-renders the
 * volume viewport. When `colorScheme.mapColorsToThreshold` is set, the colour
 * mapping range is narrowed to the same window.
 */
export function setPointCloudThreshold({
	colorScheme,
	lower,
	lut,
	upper,
}: {
	colorScheme: PointCloudColorScheme;
	lower: number;
	// eslint-disable-next-line @typescript-eslint/no-explicit-any
	lut: any;
	upper: number;
}) {
	lut.setThresholdRange(lower, upper);

	const shouldMapColorsToThreshold = colorScheme.mapColorsToThreshold;
	if (shouldMapColorsToThreshold) {
		lut.setMappingRange(lower, upper);
	}

	getViewports().volume.render();
}

/**
 * Maximum point-cloud opacity.
 *
 * NOTE(review): `setPointCloudOpacity` divides its argument by 100, so passing
 * this value (1) to it would give an actor opacity of 0.01 — confirm which
 * scale (0–1 or 0–100) this constant is expressed on.
 */
export const pointCloudsMaxOpacity = 1;

/** Point size for point-cloud rendering — presumably in pixels; confirm where it is applied. */
export const pointCloudsPointSize = 1;
