import type { DetectionResult } from './types';
import { CLASS_LABELS, VALIDATION_RULES } from './model-config';

/** Narrowing guard: true only for real, finite numbers (rejects undefined, NaN and ±Infinity). */
function isFiniteNumber(value: unknown): value is number {
  return typeof value === 'number' && Number.isFinite(value);
}

/**
 * A temporal filter to smooth detections and reduce flickering.
 *
 * Keeps a sliding window of the last `consistencyFrames` per-frame results and only
 * emits a detection once every frame in the window contained one. The emitted value
 * is the most confident detection in the window (ties go to the newer one), so it may
 * come from an earlier frame than the one just added.
 */
class TemporalFilter {
  private readonly windowSize: number;
  private readonly history: (DetectionResult | null)[];

  /**
   * @param consistencyFrames Number of consecutive frames that must contain a detection.
   *   Values below 1 and non-finite values are treated as 1 (no smoothing); fractional
   *   values are rounded down.
   */
  constructor(consistencyFrames: number) {
    // A window of 0 used to make `add(null)` call `reduce` on an empty array and throw;
    // negative, fractional or non-finite sizes made `new Array` throw a RangeError.
    this.windowSize = isFiniteNumber(consistencyFrames) ? Math.max(1, Math.floor(consistencyFrames)) : 1;
    this.history = new Array<DetectionResult | null>(this.windowSize).fill(null);
  }

  /**
   * Records the result for the newest frame and returns the smoothed detection.
   * @param detection The best detection for this frame, or null if there was none.
   * @returns The most confident detection in the window if every frame in it had one; otherwise null.
   */
  add(detection: DetectionResult | null): DetectionResult | null {
    this.history.shift();
    this.history.push(detection);

    const recent = this.history.filter((d): d is DetectionResult => d !== null);
    if (recent.length < this.windowSize) {
      return null;
    }
    // `recent` is non-empty here because windowSize >= 1, so an unseeded reduce is safe.
    return recent.reduce((best, d) => (best.confidence > d.confidence ? best : d));
  }
}

/** Optional tuning knobs for InferencePipeline; the defaults reproduce the previous hard-coded values. */
export interface InferencePipelineOptions {
  /** Maximum number of raw candidates examined per frame. Defaults to 5. */
  maxDetections?: number;
  /** Side length, in pixels, of the square model input used to scale box sizes. Defaults to 320. */
  inputSize?: number;
}

/**
 * The InferencePipeline class handles post-processing of model outputs,
 * including filtering, validation, and temporal smoothing to prevent false positives.
 */
export class InferencePipeline {
  private readonly temporalFilter: TemporalFilter;
  private readonly maxDetections: number;
  private readonly inputSize: number;

  /**
   * @param options Optional overrides. `new InferencePipeline()` behaves as before:
   *   at most 5 candidates per frame, sizes scaled to a 320x320 input.
   */
  constructor(options: InferencePipelineOptions = {}) {
    this.maxDetections = options.maxDetections ?? 5;
    this.inputSize = options.inputSize ?? 320;
    this.temporalFilter = new TemporalFilter(VALIDATION_RULES.temporalConsistencyFrames);
  }

  /**
   * Processes the raw output from the TensorFlow.js model for one frame.
   *
   * Examines at most `maxDetections` candidates in the order given and drops those
   * scoring below `confidenceThreshold` or failing box validation. It keeps the most
   * confident survivor and passes it through the temporal filter.
   *
   * @param boxes Flat array of raw bounding boxes, 4 values per candidate in
   *   [yMin, xMin, yMax, xMax] order (presumably normalized to the model input — TODO confirm).
   * @param scores Raw confidence scores, one per candidate.
   * @param classes Raw class indices. Currently unused: every detection is labelled 'shoe'.
   * @param confidenceThreshold Minimum score, on the same scale as `scores`, a candidate must reach.
   * @returns A single, validated and temporally smoothed DetectionResult, or null.
   */
  process(boxes: number[], scores: number[], classes: number[], confidenceThreshold: number): DetectionResult | null {
    // One timestamp per frame so every candidate decoded from the same model output agrees.
    const timestamp = Date.now();
    const limit = Math.min(this.maxDetections, scores.length);
    let best: DetectionResult | null = null;

    for (let i = 0; i < limit; i++) {
      const score = scores[i];
      // Also rejects NaN/undefined scores, which a bare `score < threshold` would let through.
      if (!isFiniteNumber(score) || score < confidenceThreshold) continue;

      const detection = this.toDetection(boxes, i, score, timestamp);
      if (detection === null || !this.isValid(detection)) continue;

      // `>=` keeps the previous tie-break: on equal confidence the later candidate wins.
      if (best === null || detection.confidence >= best.confidence) {
        best = detection;
      }
    }

    return this.temporalFilter.add(best);
  }

  /**
   * Builds a DetectionResult for candidate `index`, converting the model's
   * [yMin, xMin, yMax, xMax] box into [x, y, width, height].
   * @returns null when the box is missing from `boxes` or has a non-finite coordinate.
   */
  private toDetection(boxes: number[], index: number, score: number, timestamp: number): DetectionResult | null {
    const offset = index * 4;
    const yMin = boxes[offset];
    const xMin = boxes[offset + 1];
    const yMax = boxes[offset + 2];
    const xMax = boxes[offset + 3];
    // A short `boxes` array would otherwise yield NaN coordinates that slip past every check.
    if (!isFiniteNumber(yMin) || !isFiniteNumber(xMin) || !isFiniteNumber(yMax) || !isFiniteNumber(xMax)) {
      return null;
    }
    return {
      bbox: [xMin, yMin, xMax - xMin, yMax - yMin],
      confidence: score,
      class: 'shoe', // Assume all detections are shoes
      timestamp,
    };
  }

  /**
   * Validates a detection's box against VALIDATION_RULES.
   * @param detection The detection to validate.
   * @returns True if all of these hold:
   *   - the box has positive width and height;
   *   - each side is at least `minBoundingBoxSize` pixels once scaled to the model input;
   *   - its width/height ratio lies within the inclusive `aspectRatioRange` bounds.
   */
  private isValid(detection: DetectionResult): boolean {
    const [, , width, height] = detection.bbox;
    const boxPixelWidth = width * this.inputSize;
    const boxPixelHeight = height * this.inputSize;

    // Reject degenerate/NaN boxes first: 0/0 is NaN, and NaN passes every `<`/`>` test below.
    if (!(boxPixelWidth > 0) || !(boxPixelHeight > 0)) {
      return false;
    }

    if (boxPixelWidth < VALIDATION_RULES.minBoundingBoxSize || boxPixelHeight < VALIDATION_RULES.minBoundingBoxSize) {
      return false;
    }

    const aspectRatio = boxPixelWidth / boxPixelHeight;
    return !(aspectRatio < VALIDATION_RULES.aspectRatioRange[0] || aspectRatio > VALIDATION_RULES.aspectRatioRange[1]);
  }
}