import type { DetectionConfig, WorkerMessage } from '../lib/ml/types';
import { InferencePipeline } from '../lib/ml/inference-pipeline';

declare const self: DedicatedWorkerGlobalScope;

let tfGlobal: any = null;
let model: any = null;
let config: DetectionConfig | null = null;
let pipeline: InferencePipeline | null = null;

async function initialize(id: string) {
  console.log('Worker: Initializing...');

  // Import TensorFlow.js and the WebGL backend dynamically so the worker
  // script stays small until initialization is actually requested.
  tfGlobal = await import('@tensorflow/tfjs');
  await import('@tensorflow/tfjs-backend-webgl');

  await tfGlobal.setBackend('webgl');
  await tfGlobal.ready();
  console.log('TensorFlow.js backend set to:', tfGlobal.getBackend());

  pipeline = new InferencePipeline();

  self.postMessage({ type: 'INITIALIZED', id });
}
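// Message lifecycle (inferred from the handlers below): the host posts
// INITIALIZE, then CONFIGURE, then LOAD_MODEL (the warmup pass needs
// config.inputSize), then DETECT per frame. Every request carries an id that
// is echoed back so the host can match responses to requests.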
async function loadModelWorker(variant: 'quantized' | 'standard' | 'full', _modelData: ArrayBuffer, id: string) {
  // _modelData is currently unused: the model is fetched by URL below instead.
  console.log(`Worker: Loading model ${variant}...`);
  try {
    if (!tfGlobal) {
      throw new Error('TensorFlow.js not initialized');
    }

    // Use local model files from the public folder, built as absolute URLs
    // for the worker context. All variants currently resolve to the same file.
    const baseUrl = self.location.origin;
    const modelUrls = {
      quantized: `${baseUrl}/models/model.json`,
      standard: `${baseUrl}/models/model.json`,
      full: `${baseUrl}/models/model.json`
    };

    console.log(`Worker: Loading model from ${modelUrls[variant]}`);

    model = await tfGlobal.loadGraphModel(modelUrls[variant]);
    console.log('Worker: Model loaded successfully', model);

    // Warm up the model: the first executeAsync call compiles WebGL shader
    // programs, so a dummy inference here keeps the first real frame fast.
    if (model && config) {
      console.log('Worker: Warming up model with input size:', config.inputSize);
      const dummyFloat = tfGlobal.zeros([1, ...config.inputSize, 3]);
      const dummyInput = tfGlobal.cast(dummyFloat, 'int32');
      dummyFloat.dispose();

      const result = await model.executeAsync(
        { image_tensor: dummyInput },
        ['detection_boxes', 'num_detections', 'detection_classes', 'detection_scores']
      );
      console.log('Worker: Warmup result:', result);
      dummyInput.dispose();
      if (Array.isArray(result)) {
        result.forEach(t => t.dispose());
      } else if (result) {
        result.dispose();
      }
      console.log('Worker: Model warmed up successfully.');
    }
    self.postMessage({ type: 'LOADED_MODEL', id });
  } catch (error) {
    console.error(`Worker: Failed to load model ${variant}:`, error);
    self.postMessage({ type: 'ERROR', error: error instanceof Error ? error.message : 'Unknown error during model loading', id });
  }
}
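// Note on the model signature: the 'image_tensor' input and the
// detection_boxes / detection_scores / detection_classes / num_detections
// outputs follow the TensorFlow Object Detection API export convention for
// SSD-style models, which take a [1, H, W, 3] integer image tensor. A model
// exported differently would need these names changed to match.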
async function configureWorker(newConfig: DetectionConfig, id: string) {
  console.log('Worker: Configuring...');
  config = newConfig;
  self.postMessage({ type: 'CONFIGURED', id });
}
async function detect(imageData: ImageData, id: string) {
  console.log('Worker: detect function called.');
  if (!model || !config || !pipeline) {
    self.postMessage({ type: 'ERROR', error: 'Worker not initialized or configured.', id });
    return;
  }

  const tensor = tfGlobal.tidy(() => {
    // In a Web Worker there is no canvas to read from, so build the tensor
    // manually: convert the RGBA pixel data to RGB by dropping the alpha channel.
    const { data, width, height } = imageData;
    const rgbData = new Uint8Array(width * height * 3);
    for (let i = 0; i < width * height; i++) {
      rgbData[i * 3] = data[i * 4];         // R
      rgbData[i * 3 + 1] = data[i * 4 + 1]; // G
      rgbData[i * 3 + 2] = data[i * 4 + 2]; // B
      // data[i * 4 + 3] is the alpha channel and is skipped.
    }

    // Create a tensor from the RGB data, resize it to the model input size
    // (e.g. 300x300; resizeBilinear returns float32), then cast to int32 as
    // the model requires.
    const img = tfGlobal.tensor3d(rgbData, [height, width, 3]);
    const resized = tfGlobal.image.resizeBilinear(img, config!.inputSize);
    const int32Tensor = tfGlobal.cast(resized, 'int32');

    return int32Tensor.expandDims(0); // [1, H, W, 3], dtype int32
  });
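  // Alternative (version-dependent, untested here): recent TF.js builds can
  // run tfGlobal.browser.fromPixels(imageData) inside a worker, which would
  // replace the manual RGBA-to-RGB loop above.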
  try {
    console.log('Worker: About to execute model with tensor shape:', tensor.shape, 'dtype:', tensor.dtype);
    // Same input/output signature as the warmup call above.
    const result = await model.executeAsync(
      { image_tensor: tensor },
      ['detection_boxes', 'num_detections', 'detection_classes', 'detection_scores']
    );
    tensor.dispose();

    // Reduced logging for performance; bundlers replace process.env.NODE_ENV
    // at build time, so this branch disappears from production builds.
    if (process.env.NODE_ENV === 'development') {
      console.log('Worker: Model execution completed, processing results...');
    }

    if (!result || !Array.isArray(result) || result.length < 4) {
      console.error('Worker: Invalid model output:', result);
      // Dispose whatever did come back so tensors don't leak on this path.
      if (Array.isArray(result)) {
        result.forEach(t => t.dispose());
      } else if (result) {
        result.dispose();
      }
      self.postMessage({ type: 'DETECTION_RESULT', result: null, id });
      return;
    }

    // Output order matches the fetch list above: [boxes, num_detections, classes, scores].
    const [boxes, numDetections, classes, scores] = result;
    console.log('Worker: Extracting data from tensors...');

    const boxesData = await boxes.data();
    const scoresData = await scores.data();
    const classesData = await classes.data();

    // Only log detailed outputs when debugging specific issues.
    const maxScore = Math.max(...Array.from(scoresData.slice(0, 10)));
    const scoresAbove30 = Array.from(scoresData.slice(0, 10)).filter(s => s > 0.3).length;

    if (process.env.NODE_ENV === 'development' && (maxScore > 0.3 || scoresAbove30 > 0)) {
      console.log('Worker: Potential detection found:', { maxScore, scoresAbove30 });
    }

    result.forEach(t => t.dispose());

    // The typed arrays from data() are consumed positionally, so the double
    // cast to number[] is safe here.
    const detectionResult = pipeline.process(
      boxesData as unknown as number[],
      scoresData as unknown as number[],
      classesData as unknown as number[],
      config.confidenceThreshold
    );

    console.log('Worker detectionResult:', detectionResult);

    self.postMessage({ type: 'DETECTION_RESULT', result: detectionResult, id });
  } catch (error) {
    tensor.dispose();
    console.error('Worker: Detection execution failed:', error);
    if (error instanceof Error) {
      console.error('Worker: Error stack:', error.stack);
    }
    self.postMessage({ type: 'ERROR', error: error instanceof Error ? error.message : 'Detection execution failed', id });
  }
}
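// For reference: detection_boxes from Object Detection API exports are
// normalized [ymin, xmin, ymax, xmax]. A sketch of mapping one box back to
// pixel space, assuming InferencePipeline does something along these lines
// (toPixelRect is illustrative and not part of this codebase):
//
// function toPixelRect(box: [number, number, number, number], width: number, height: number) {
//   const [ymin, xmin, ymax, xmax] = box;
//   return { x: xmin * width, y: ymin * height, w: (xmax - xmin) * width, h: (ymax - ymin) * height };
// }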
self.onmessage = async (event: MessageEvent<WorkerMessage>) => {
  try {
    switch (event.data.type) {
      case 'INITIALIZE':
        await initialize(event.data.id);
        break;
      case 'LOAD_MODEL':
        await loadModelWorker(event.data.variant, event.data.modelData, event.data.id);
        break;
      case 'CONFIGURE':
      case 'UPDATE_CONFIG': // both messages carry a config payload
        await configureWorker(event.data.config, event.data.id);
        break;
      case 'DETECT':
        await detect(event.data.imageData, event.data.id);
        break;
      default:
        throw new Error(`Unknown message type: ${(event.data as any).type}`);
    }
  } catch (error) {
    self.postMessage({ type: 'ERROR', error: error instanceof Error ? error.message : 'Unknown error in worker', id: event.data.id });
  }
};
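// Example wiring from the main thread (a minimal sketch; the worker path,
// message ids, and the DetectionConfig fields shown are assumptions inferred
// from how this worker reads its messages, not verified against the host app):
//
// const worker = new Worker(new URL('./detection.worker.ts', import.meta.url), { type: 'module' });
// worker.onmessage = (e: MessageEvent) => {
//   if (e.data.type === 'DETECTION_RESULT') console.log('detections:', e.data.result);
// };
// worker.postMessage({ type: 'INITIALIZE', id: '1' });
// worker.postMessage({ type: 'CONFIGURE', id: '2', config: { inputSize: [300, 300], confidenceThreshold: 0.5 } });
// worker.postMessage({ type: 'LOAD_MODEL', id: '3', variant: 'standard', modelData: new ArrayBuffer(0) });
// worker.postMessage({ type: 'DETECT', id: '4', imageData: someImageData });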