Files
temp_SSA_SCAN/lib/ml/inference-pipeline.ts

115 lines
3.8 KiB
TypeScript

import type { DetectionResult } from './types';
import { CLASS_LABELS, VALIDATION_RULES } from './model-config';
/**
 * A temporal filter that smooths detections and reduces flickering.
 *
 * It keeps a fixed-size sliding window of the most recent frames. A detection
 * is only reported once every frame in the window contains one; a single
 * empty frame suppresses output until the window has refilled.
 */
class TemporalFilter {
  /** Sliding window of recent frames, oldest first; `null` marks a frame without a detection. */
  private readonly history: (DetectionResult | null)[];
  /** Number of consecutive frames that must all contain a detection. Always an integer >= 1. */
  private readonly windowSize: number;

  /**
   * @param consistencyFrames Requested window length. Non-finite values and
   *   values below 1 fall back to a window of one frame; fractional values are
   *   rounded down. This avoids an invalid array length and an empty window,
   *   on which the "most confident" selection would be undefined.
   */
  constructor(consistencyFrames: number) {
    this.windowSize = Number.isFinite(consistencyFrames) ? Math.max(1, Math.floor(consistencyFrames)) : 1;
    this.history = new Array<DetectionResult | null>(this.windowSize).fill(null);
  }

  /**
   * Records the outcome of the current frame and returns the smoothed result.
   *
   * @param detection The detection for the current frame, or `null` if there was none.
   * @returns The most confident detection held in the window when every frame
   *   in the window has a detection (the most recent one wins a tie);
   *   otherwise `null`.
   */
  add(detection: DetectionResult | null): DetectionResult | null {
    // Slide the window forward by one frame.
    this.history.shift();
    this.history.push(detection);

    let best: DetectionResult | null = null;
    for (const entry of this.history) {
      // One empty frame is enough to break the consistency requirement.
      if (entry === null) {
        return null;
      }
      // Replace unless the current best is strictly more confident, so later frames win ties.
      if (best === null || !(best.confidence > entry.confidence)) {
        best = entry;
      }
    }
    return best;
  }
}
/**
 * The InferencePipeline class handles post-processing of model outputs:
 * confidence filtering, bounding-box validation, and temporal smoothing to
 * prevent false positives. At most one detection is produced per frame.
 */
export class InferencePipeline {
  /** Only this many leading entries of the model output are examined per frame. */
  private static readonly maxCandidates = 5;
  /** Edge length in pixels of the square model input used to scale box sizes for validation. */
  private static readonly modelInputSize = 320;

  private readonly temporalFilter: TemporalFilter;

  constructor() {
    this.temporalFilter = new TemporalFilter(VALIDATION_RULES.temporalConsistencyFrames);
  }

  /**
   * Processes the raw output from the TensorFlow.js model.
   *
   * @param boxes Flat array of bounding boxes, four values per detection in
   *   `[y_min, x_min, y_max, x_max]` order. Presumably normalized to the model
   *   input size — confirm against the model's output signature.
   * @param scores Raw confidence scores, one per detection.
   * @param classes Raw class indices. Currently unused: every detection is
   *   labelled `'shoe'`.
   * @param confidenceThreshold Minimum score a detection must reach, on the same scale as `scores`.
   * @returns A single, validated DetectionResult after temporal smoothing, or null.
   */
  process(boxes: number[], scores: number[], classes: number[], confidenceThreshold: number): DetectionResult | null {
    // Never read past the last complete box: a truncated `boxes` array would
    // otherwise yield undefined coordinates and a NaN bounding box.
    const candidateCount = Math.min(
      InferencePipeline.maxCandidates,
      scores.length,
      Math.floor(boxes.length / 4),
    );

    let best: DetectionResult | null = null;
    for (let i = 0; i < candidateCount; i++) {
      const score = scores[i];
      // NaN compares false against everything, so it must be rejected explicitly.
      if (Number.isNaN(score) || score < confidenceThreshold) continue;

      const offset = i * 4;
      const yMin = boxes[offset];
      const xMin = boxes[offset + 1];
      const yMax = boxes[offset + 2];
      const xMax = boxes[offset + 3];

      const detection: DetectionResult = {
        // Convert to [x, y, width, height] format.
        bbox: [xMin, yMin, xMax - xMin, yMax - yMin],
        confidence: score,
        class: 'shoe', // Assume all detections are shoes
        timestamp: Date.now(),
      };
      if (!this.isValid(detection)) continue;

      // Keep the most confident candidate; a later candidate wins a tie.
      if (best === null || !(best.confidence > detection.confidence)) {
        best = detection;
      }
    }

    // The filter is fed on every frame, including frames with no detection.
    return this.temporalFilter.add(best);
  }

  /**
   * Validates a detection against the size and aspect-ratio rules in VALIDATION_RULES.
   *
   * @param detection The detection to validate.
   * @returns True if the bounding box is finite, at least the minimum size in
   *   both dimensions, and within the allowed aspect-ratio range; false otherwise.
   */
  private isValid(detection: DetectionResult): boolean {
    const { bbox } = detection;
    // Reject NaN/Infinity up front; they would slip through plain `<` / `>` checks.
    if (!bbox.every(Number.isFinite)) {
      return false;
    }

    const [, , width, height] = bbox;
    const pixelWidth = width * InferencePipeline.modelInputSize;
    const pixelHeight = height * InferencePipeline.modelInputSize;

    const minSize = VALIDATION_RULES.minBoundingBoxSize;
    if (pixelWidth < minSize || pixelHeight < minSize) {
      return false;
    }

    // Written as a positive range test so a NaN ratio (0 / 0) is rejected.
    const aspectRatio = pixelWidth / pixelHeight;
    return (
      aspectRatio >= VALIDATION_RULES.aspectRatioRange[0] &&
      aspectRatio <= VALIDATION_RULES.aspectRatioRange[1]
    );
  }
}