115 lines
3.8 KiB
TypeScript
115 lines
3.8 KiB
TypeScript
import type { DetectionResult } from './types';
|
|
import { VALIDATION_RULES } from './model-config';
|
|
|
|
/**
 * A temporal filter that smooths detections across frames to reduce flickering.
 *
 * It keeps a fixed-size sliding window of the last `consistencyFrames` per-frame
 * results and only reports a detection once every frame in that window contained
 * one. The reported detection is the most confident one in the window, which may
 * come from an earlier frame than the current one.
 */
class TemporalFilter {
  /** Most recent per-frame results, oldest first; `null` marks a frame with no detection. */
  private readonly history: (DetectionResult | null)[];

  /**
   * @param consistencyFrames Number of consecutive frames that must contain a
   *   detection before one is reported. Values below 1 (and fractional parts) are
   *   normalized so the window holds at least one frame, i.e. a setting of 0 turns
   *   the filter into a pass-through for the current frame.
   */
  constructor(consistencyFrames: number) {
    // Clamp so a 0/negative/fractional setting cannot produce an empty or invalid
    // window (`new Array(-1)` and `new Array(2.5)` throw RangeError).
    const windowSize = Math.max(1, Math.trunc(consistencyFrames));
    this.history = new Array<DetectionResult | null>(windowSize).fill(null);
  }

  /**
   * Records the result for the current frame and returns the smoothed result.
   * @param detection The best detection for this frame, or `null` if there was none.
   * @returns The highest-confidence detection in the window when every frame in the
   *   window has a detection (on a confidence tie the more recent one wins);
   *   otherwise `null`.
   */
  add(detection: DetectionResult | null): DetectionResult | null {
    // Fixed-size window: drop the oldest entry and append the newest.
    this.history.shift();
    this.history.push(detection);

    let best: DetectionResult | null = null;
    for (const entry of this.history) {
      // A single missed frame anywhere in the window suppresses output.
      if (entry === null) return null;
      best = best !== null && best.confidence > entry.confidence ? best : entry;
    }
    return best;
  }
}
|
|
|
|
/**
 * The InferencePipeline class post-processes raw model outputs: it applies the
 * confidence threshold, validates box geometry, keeps the single best candidate
 * per frame, and passes it through temporal smoothing to suppress false positives.
 */
export class InferencePipeline {
  /** Side length (px) of the square model input that normalized boxes are scaled to. */
  private static readonly MODEL_INPUT_SIZE = 320;

  /** Only the first this-many candidates from the model output are considered. */
  private static readonly MAX_CANDIDATES = 5;

  private readonly temporalFilter: TemporalFilter;

  constructor() {
    this.temporalFilter = new TemporalFilter(VALIDATION_RULES.temporalConsistencyFrames);
  }

  /**
   * Processes the raw output from the TensorFlow.js model for one frame.
   * @param boxes Flat array of bounding boxes, four values per candidate in
   *   `[yMin, xMin, yMax, xMax]` order (presumably normalized to [0, 1], since
   *   they are scaled by the model input size during validation — confirm).
   * @param scores Confidence score per candidate.
   * @param classes Class index per candidate. Currently unused: every candidate
   *   is labeled `'shoe'`.
   * @param confidenceThreshold Minimum score (same scale as `scores`) a candidate
   *   must reach to be kept.
   * @returns The temporally smoothed, validated detection, or `null`.
   */
  process(boxes: number[], scores: number[], classes: number[], confidenceThreshold: number): DetectionResult | null {
    // Bound by both arrays so a truncated `boxes` buffer cannot yield NaN coordinates.
    const candidateCount = Math.min(
      InferencePipeline.MAX_CANDIDATES,
      scores.length,
      Math.floor(boxes.length / 4),
    );

    let best: DetectionResult | null = null;
    for (let i = 0; i < candidateCount; i++) {
      const score = scores[i];
      // Negated `>=` so that NaN scores are rejected as well.
      if (!(score >= confidenceThreshold)) continue;

      const offset = i * 4;
      const yMin = boxes[offset];
      const xMin = boxes[offset + 1];
      const yMax = boxes[offset + 2];
      const xMax = boxes[offset + 3];

      const detection: DetectionResult = {
        // Convert to [x, y, width, height].
        bbox: [xMin, yMin, xMax - xMin, yMax - yMin],
        confidence: score,
        class: 'shoe', // Assume all detections are shoes
        timestamp: Date.now(),
      };

      if (!this.isValid(detection)) continue;

      // Keep the most confident valid candidate; on a tie the later one wins.
      best = best !== null && best.confidence > detection.confidence ? best : detection;
    }

    // Always feed the filter (including `null`) so missed frames break consistency.
    return this.temporalFilter.add(best);
  }

  /**
   * Validates a detection's box geometry against VALIDATION_RULES.
   * @param detection The detection to validate.
   * @returns True if the detection is valid, false otherwise.
   */
  private isValid(detection: DetectionResult): boolean {
    const [x, y, width, height] = detection.bbox;

    // Every comparison below is false for NaN, so without this guard a malformed
    // box would pass all checks.
    if (![x, y, width, height].every(Number.isFinite)) {
      return false;
    }

    // Degenerate or inverted boxes (e.g. max < min) are never valid.
    if (width <= 0 || height <= 0) {
      return false;
    }

    // Bounding box size validation, in pixels of the square model input.
    const boxPixelWidth = width * InferencePipeline.MODEL_INPUT_SIZE;
    const boxPixelHeight = height * InferencePipeline.MODEL_INPUT_SIZE;
    if (boxPixelWidth < VALIDATION_RULES.minBoundingBoxSize || boxPixelHeight < VALIDATION_RULES.minBoundingBoxSize) {
      return false;
    }

    // Aspect ratio (width / height) must fall inside the configured inclusive range.
    const aspectRatio = boxPixelWidth / boxPixelHeight;
    const [minRatio, maxRatio] = [VALIDATION_RULES.aspectRatioRange[0], VALIDATION_RULES.aspectRatioRange[1]];
    return aspectRatio >= minRatio && aspectRatio <= maxRatio;
  }
}
|