Files
temp_SSA_SCAN/lib/ml/inference-pipeline.ts

109 lines
3.5 KiB
TypeScript

import type { DetectionResult } from './types';
import { CLASS_LABELS, VALIDATION_RULES } from './model-config';
/**
 * A temporal filter that smooths detections and reduces flickering.
 *
 * It keeps a sliding window of the last `consistencyFrames` per-frame results.
 * A detection is emitted only once every slot in the window holds a non-null
 * detection. It then returns the most confident detection in the window. That
 * detection may come from an earlier frame than the one just added.
 */
class TemporalFilter {
  private readonly history: (DetectionResult | null)[];

  /**
   * @param consistencyFrames Number of consecutive frames that must each contain a
   *   detection before anything is emitted.
   * @throws RangeError if `consistencyFrames` is not a valid array length
   *   (negative or non-integer), as raised by the `Array` constructor.
   */
  constructor(private readonly consistencyFrames: number) {
    this.history = new Array<DetectionResult | null>(consistencyFrames).fill(null);
  }

  /**
   * Pushes the current frame's detection (or `null`) into the window, evicting the oldest.
   * @param detection The best detection for the current frame, or `null` if none.
   * @returns The most confident detection in the window when the window is fully
   *   populated with detections, otherwise `null`.
   */
  add(detection: DetectionResult | null): DetectionResult | null {
    this.history.shift();
    this.history.push(detection);

    // Type predicate narrows to DetectionResult[], so no non-null assertions are needed below.
    const recentDetections = this.history.filter((d): d is DetectionResult => d !== null);

    // The explicit empty check matters when consistencyFrames is 0: `reduce` without an
    // initial value throws on an empty array.
    if (recentDetections.length === 0 || recentDetections.length < this.consistencyFrames) {
      return null;
    }

    // On ties the later (more recent) detection wins, matching the original comparison.
    return recentDetections.reduce((best, current) => (best.confidence > current.confidence ? best : current));
  }
}
/**
 * The InferencePipeline class handles post-processing of model outputs,
 * including filtering, validation, and temporal smoothing to prevent false positives.
 */
export class InferencePipeline {
  private readonly temporalFilter: TemporalFilter;

  /**
   * @param inputSize Side length, in pixels, of the square model input. Box
   *   coordinates are multiplied by this before the pixel-based rules in
   *   `VALIDATION_RULES` are applied. It defaults to 320, the value that was
   *   previously hard-coded. This assumes the model emits normalized [0, 1]
   *   coordinates (TODO confirm against the model config).
   */
  constructor(private readonly inputSize = 320) {
    this.temporalFilter = new TemporalFilter(VALIDATION_RULES.temporalConsistencyFrames);
  }

  /**
   * Processes the raw output from the TensorFlow.js model.
   *
   * A candidate is kept only if all of the following hold:
   * - its score is at least `confidenceThreshold`;
   * - its label in `CLASS_LABELS` is `'shoe'`;
   * - its four box coordinates are present and finite;
   * - it passes the size and aspect-ratio rules.
   *
   * The most confident survivor, or `null` when none survive, is pushed through
   * the temporal filter, and the filter's output is returned. NOTE(review): the
   * filter may return a detection recorded in an earlier frame.
   *
   * @param boxes Flat raw boxes with four values per detection, in
   *   `[y_min, x_min, y_max, x_max]` order. Detection `i` occupies `boxes[4*i .. 4*i+3]`.
   * @param scores Raw confidence score per detection. Its length drives the iteration.
   * @param classes Raw class index per detection.
   * @param confidenceThreshold The current confidence threshold.
   * @returns A single, validated DetectionResult with `bbox` as `[x, y, width, height]`, or null.
   */
  process(boxes: number[], scores: number[], classes: number[], confidenceThreshold: number): DetectionResult | null {
    const detections: DetectionResult[] = [];
    for (let i = 0; i < scores.length; i++) {
      const score = scores[i];
      // `!(score >= threshold)` also rejects NaN, which a plain `score < threshold` lets through.
      if (score === undefined || !(score >= confidenceThreshold)) continue;

      const classIndex = classes[i];
      const className = CLASS_LABELS[classIndex];
      if (className !== 'shoe') continue;

      // A truncated or non-finite box would otherwise produce NaN coordinates.
      // NaN coordinates slip past every comparison in the validation rules.
      const box = boxes.slice(i * 4, (i + 1) * 4);
      if (box.length < 4 || !box.every(Number.isFinite)) continue;
      const [yMin, xMin, yMax, xMax] = box as [number, number, number, number];

      // Convert corner form to [x, y, width, height].
      const bbox: [number, number, number, number] = [xMin, yMin, xMax - xMin, yMax - yMin];
      const detection: DetectionResult = {
        bbox,
        confidence: score,
        class: className,
      };
      if (this.isValid(detection)) {
        detections.push(detection);
      }
    }

    if (detections.length === 0) {
      return this.temporalFilter.add(null);
    }

    // Get the single best detection; on ties the later candidate wins.
    const bestDetection = detections.reduce((prev, current) => (prev.confidence > current.confidence ? prev : current));
    return this.temporalFilter.add(bestDetection);
  }

  /**
   * Validates a detection against the size and aspect-ratio rules in `VALIDATION_RULES`.
   * @param detection The detection to validate. Its `bbox` is `[x, y, width, height]`
   *   in model-input-relative units.
   * @returns True if the detection is valid, false otherwise.
   */
  private isValid(detection: DetectionResult): boolean {
    const [, , width, height] = detection.bbox;
    const { minBoundingBoxSize, aspectRatioRange } = VALIDATION_RULES;

    // Scale to pixels of the square model input.
    const boxPixelWidth = width * this.inputSize;
    const boxPixelHeight = height * this.inputSize;

    // Negated `>=` form so NaN sizes are rejected instead of silently accepted.
    if (!(boxPixelWidth >= minBoundingBoxSize && boxPixelHeight >= minBoundingBoxSize)) {
      return false;
    }

    // Guard the division below; reachable only when minBoundingBoxSize <= 0.
    if (boxPixelHeight <= 0) {
      return false;
    }

    const aspectRatio = boxPixelWidth / boxPixelHeight;
    return aspectRatio >= aspectRatioRange[0] && aspectRatio <= aspectRatioRange[1];
  }
}