// File listing metadata: 109 lines, 3.5 KiB, TypeScript.
import type { DetectionResult } from './types';
|
|
import { CLASS_LABELS, VALIDATION_RULES } from './model-config';
|
|
|
|
/**
 * A temporal filter to smooth detections and reduce flickering.
 *
 * Keeps a fixed-size sliding window of the most recent per-frame results,
 * where `null` marks a frame without a detection. A detection is reported
 * only while every slot of the window holds a detection.
 */
class TemporalFilter {
  /** Number of consecutive detected frames required before reporting. Always >= 1. */
  private readonly windowSize: number;

  /** Sliding window of the last `windowSize` per-frame results, oldest first. */
  private readonly history: (DetectionResult | null)[];

  /**
   * @param consistencyFrames Number of consecutive frames that must contain a
   *   detection. `0` is treated as `1` (report immediately).
   * @throws RangeError If `consistencyFrames` is negative or not an integer.
   */
  constructor(consistencyFrames: number) {
    if (!Number.isInteger(consistencyFrames) || consistencyFrames < 0) {
      throw new RangeError(
        `consistencyFrames must be a non-negative integer, received ${consistencyFrames}`,
      );
    }
    // A zero-length window could never hold a frame; the smallest meaningful window is one.
    this.windowSize = Math.max(1, consistencyFrames);
    this.history = new Array<DetectionResult | null>(this.windowSize).fill(null);
  }

  /**
   * Records the result of the latest frame and returns the smoothed result.
   *
   * @param detection The detection for the current frame, or `null` if there was none.
   * @returns The most confident detection in the window when every frame in the
   *   window has a detection; otherwise `null`.
   */
  add(detection: DetectionResult | null): DetectionResult | null {
    // Slide the window: drop the oldest frame, append the newest.
    this.history.shift();
    this.history.push(detection);

    let best: DetectionResult | null = null;
    for (const entry of this.history) {
      // Any gap in the window means the detection is not yet consistent.
      if (entry === null) return null;
      // Replace unless the current best is strictly more confident, so ties go to the newer frame.
      if (best === null || !(best.confidence > entry.confidence)) {
        best = entry;
      }
    }
    return best;
  }
}
|
|
|
|
/**
 * The InferencePipeline class handles post-processing of model outputs,
 * including filtering, validation, and temporal smoothing to prevent false positives.
 */
export class InferencePipeline {
  /** Side length, in pixels, of the square model input that the size rules are measured against. */
  private static readonly MODEL_INPUT_SIZE = 320;

  /** Number of values stored per bounding box in the flat `boxes` input. */
  private static readonly VALUES_PER_BOX = 4;

  private readonly temporalFilter: TemporalFilter;

  constructor() {
    this.temporalFilter = new TemporalFilter(VALIDATION_RULES.temporalConsistencyFrames);
  }

  /**
   * Processes the raw output from the TensorFlow.js model.
   *
   * Candidates are kept only if they reach the confidence threshold, are
   * labelled `'shoe'`, and pass {@link InferencePipeline.isValid}. The most
   * confident survivor (or `null`) is then fed to the temporal filter, so this
   * method must be called exactly once per frame.
   *
   * @param boxes Raw bounding boxes, flattened as `[yMin, xMin, yMax, xMax]` per candidate.
   * @param scores Raw confidence scores, one per candidate.
   * @param classes Raw class indices into `CLASS_LABELS`, one per candidate.
   * @param confidenceThreshold The current confidence threshold.
   * @returns A single, validated DetectionResult or null. Its `bbox` is `[x, y, width, height]`.
   */
  process(
    boxes: ArrayLike<number>,
    scores: ArrayLike<number>,
    classes: ArrayLike<number>,
    confidenceThreshold: number,
  ): DetectionResult | null {
    let best: DetectionResult | null = null;

    for (let i = 0; i < scores.length; i++) {
      const score = scores[i];
      // NaN compares false against any threshold, so it has to be rejected explicitly.
      if (score === undefined || Number.isNaN(score) || score < confidenceThreshold) continue;

      const classIndex = classes[i];
      if (classIndex === undefined) continue;
      const className = CLASS_LABELS[classIndex];
      if (className !== 'shoe') continue;

      // Missing values (truncated `boxes`) become NaN and are rejected by isValid.
      const offset = i * InferencePipeline.VALUES_PER_BOX;
      const yMin = boxes[offset] ?? NaN;
      const xMin = boxes[offset + 1] ?? NaN;
      const yMax = boxes[offset + 2] ?? NaN;
      const xMax = boxes[offset + 3] ?? NaN;

      // Convert corner form to [x, y, width, height].
      const bbox: [number, number, number, number] = [xMin, yMin, xMax - xMin, yMax - yMin];
      const detection: DetectionResult = {
        bbox,
        confidence: score,
        class: className,
      };
      if (!this.isValid(detection)) continue;

      // Replace unless the current best is strictly more confident, so ties go to the later candidate.
      if (best === null || !(best.confidence > detection.confidence)) {
        best = detection;
      }
    }

    // Always feed the filter, even with no candidate, so that gaps reset the consistency window.
    return this.temporalFilter.add(best);
  }

  /**
   * Validates a detection against a set of rules.
   *
   * Rejects boxes with non-finite or non-positive dimensions, boxes smaller
   * than `VALIDATION_RULES.minBoundingBoxSize` in either direction, and boxes
   * whose width/height ratio falls outside `VALIDATION_RULES.aspectRatioRange`.
   *
   * @param detection The detection to validate.
   * @returns True if the detection is valid, false otherwise.
   */
  private isValid(detection: DetectionResult): boolean {
    const [x, y, width, height] = detection.bbox;

    // Non-finite values would slip through every `<` / `>` comparison below.
    if (![x, y, width, height].every((value) => Number.isFinite(value))) {
      return false;
    }

    // Degenerate or inverted boxes are never valid, and a zero height would break the ratio.
    if (width <= 0 || height <= 0) {
      return false;
    }

    // Bounding box size validation (relative to a 320x320 input).
    // NOTE(review): assumes box coordinates are normalised to the model input - confirm against the model's output format.
    const boxPixelWidth = width * InferencePipeline.MODEL_INPUT_SIZE;
    const boxPixelHeight = height * InferencePipeline.MODEL_INPUT_SIZE;
    if (
      boxPixelWidth < VALIDATION_RULES.minBoundingBoxSize ||
      boxPixelHeight < VALIDATION_RULES.minBoundingBoxSize
    ) {
      return false;
    }

    // Aspect ratio validation
    const [minAspectRatio, maxAspectRatio] = VALIDATION_RULES.aspectRatioRange;
    const aspectRatio = boxPixelWidth / boxPixelHeight;
    if (aspectRatio < minAspectRatio || aspectRatio > maxAspectRatio) {
      return false;
    }

    return true;
  }
}
|