// File: temp_SSA_SCAN/lib/ml/model-cache.ts
// Listing metadata: 269 lines, 7.2 KiB, TypeScript

import type { ModelInfo } from './types';
/** Name of the IndexedDB database holding cached model binaries. */
const DB_NAME = 'ShoeDetectionModels';

/** Schema version passed to `indexedDB.open`; raising it triggers `onupgradeneeded`. */
const DB_VERSION = 1;

/** Object store (keyPath `id`) that holds one `CachedModel` record per variant. */
const STORE_NAME = 'models';
/**
 * Record persisted in the `models` object store, one per model variant.
 */
export interface CachedModel {
  /** Primary key (the store's keyPath); ModelCache writes `shoe-detection-<variant>`. */
  id: string;
  /** Which model build this record holds; also indexed as `variant`. */
  variant: 'quantized' | 'standard' | 'full';
  /** Raw model bytes exactly as downloaded. */
  data: ArrayBuffer;
  /** The model descriptor that was passed in when the record was written. */
  metadata: ModelInfo;
  /** `Date.now()` at write time (milliseconds); indexed and used for age-based eviction. */
  timestamp: number;
  /** Cache format version, compared against the required version when reading. */
  version: string;
}
/**
 * IndexedDB-backed cache for downloaded TensorFlow.js model binaries.
 *
 * Each variant is stored as a single `CachedModel` record keyed by
 * `shoe-detection-<variant>`. The database is opened eagerly in the
 * constructor, and every store operation waits for it via `ensureReady()`.
 * If opening fails, the next operation retries.
 */
export class ModelCache {
  /** Cache format version written with every record and required on read. */
  private static readonly CACHE_VERSION = '1.0.0';

  private db: IDBDatabase | null = null;
  private initPromise: Promise<void> | null = null;

  constructor() {
    const ready = this.init();
    // Without this, an open failure becomes an unhandled rejection when the
    // cache is never used. The error still reaches callers of ensureReady().
    void ready.catch(() => undefined);
    this.initPromise = ready;
  }

  /**
   * Open (and on first use create) the IndexedDB database.
   *
   * @throws the `IDBOpenDBRequest` error if the database cannot be opened
   */
  private async init(): Promise<void> {
    return new Promise((resolve, reject) => {
      const request = indexedDB.open(DB_NAME, DB_VERSION);

      request.onerror = () => {
        console.error('Failed to open IndexedDB:', request.error);
        reject(request.error);
      };

      request.onsuccess = () => {
        const db = request.result;
        // Another tab wants a newer schema. Close so we don't block its
        // upgrade; the next operation reopens through ensureReady().
        db.onversionchange = () => {
          db.close();
          if (this.db === db) {
            this.db = null;
            this.initPromise = null;
          }
        };
        this.db = db;
        resolve();
      };

      request.onupgradeneeded = () => {
        const db = request.result;
        if (!db.objectStoreNames.contains(STORE_NAME)) {
          const store = db.createObjectStore(STORE_NAME, { keyPath: 'id' });
          store.createIndex('variant', 'variant', { unique: false });
          store.createIndex('timestamp', 'timestamp', { unique: false });
        }
      };
    });
  }

  /**
   * Wait for (or retry) opening the database and return the open connection.
   *
   * @throws the open error, or `Error('Database not initialized')`
   */
  private async ensureReady(): Promise<IDBDatabase> {
    if (!this.db) {
      if (!this.initPromise) {
        this.initPromise = this.init();
      }
      const pending = this.initPromise;
      try {
        await pending;
      } catch (error) {
        // Forget the failed attempt so a later call can try again.
        if (this.initPromise === pending) {
          this.initPromise = null;
        }
        throw error;
      }
    }
    if (!this.db) {
      throw new Error('Database not initialized');
    }
    return this.db;
  }

  /**
   * Resolve once `tx` has committed. Reject with the transaction's error if it
   * aborts; an unhandled request error always aborts the transaction.
   */
  private static transactionDone(tx: IDBTransaction): Promise<void> {
    return new Promise((resolve, reject) => {
      tx.oncomplete = () => resolve();
      tx.onabort = () => reject(tx.error ?? new Error('IndexedDB transaction aborted'));
    });
  }

  /** Primary key used for a variant's record. */
  private static keyFor(variant: CachedModel['variant']): string {
    return `shoe-detection-${variant}`;
  }

  /**
   * Store (or overwrite) a model in IndexedDB.
   *
   * Resolves only after the write has committed, so commit-time failures
   * such as quota errors are reported instead of being lost.
   *
   * @param variant which model build this is
   * @param modelData raw model bytes
   * @param metadata descriptor saved alongside the bytes
   * @throws the IndexedDB error if opening the database or writing fails
   */
  async cacheModel(variant: 'quantized' | 'standard' | 'full', modelData: ArrayBuffer, metadata: ModelInfo): Promise<void> {
    const db = await this.ensureReady();
    const cachedModel: CachedModel = {
      id: ModelCache.keyFor(variant),
      variant,
      data: modelData,
      metadata,
      timestamp: Date.now(),
      version: ModelCache.CACHE_VERSION,
    };
    try {
      const tx = db.transaction([STORE_NAME], 'readwrite');
      tx.objectStore(STORE_NAME).put(cachedModel);
      await ModelCache.transactionDone(tx);
    } catch (error) {
      console.error(`Failed to cache model ${variant}:`, error);
      throw error;
    }
    console.log(`Model ${variant} cached successfully`);
  }

  /**
   * Read a variant's cached record.
   *
   * @returns the record, or `null` if none is stored
   * @throws the IndexedDB error if the database cannot be opened or read
   */
  async getCachedModel(variant: 'quantized' | 'standard' | 'full'): Promise<CachedModel | null> {
    const db = await this.ensureReady();
    const tx = db.transaction([STORE_NAME], 'readonly');
    const request = tx.objectStore(STORE_NAME).get(ModelCache.keyFor(variant));
    await ModelCache.transactionDone(tx);
    return (request.result as CachedModel | undefined) ?? null;
  }

  /**
   * Check whether a variant is cached with the given version.
   * Cache errors are logged and reported as "not cached".
   *
   * @param requiredVersion version to match; defaults to the version this class writes
   */
  async isModelCached(variant: 'quantized' | 'standard' | 'full', requiredVersion: string = ModelCache.CACHE_VERSION): Promise<boolean> {
    try {
      const cached = await this.getCachedModel(variant);
      return cached !== null && cached.version === requiredVersion;
    } catch (error) {
      console.error('Error checking cached model:', error);
      return false;
    }
  }

  /**
   * Read the full response body and report download progress.
   *
   * Progress (0-100) is reported only when Content-Length is known. It is
   * clamped because Content-Length may be the encoded (compressed) size. When
   * the size is unknown, a single 100 is reported on completion.
   */
  private static async readBody(response: Response, onProgress?: (progress: number) => void): Promise<ArrayBuffer> {
    const reader = response.body?.getReader();
    if (!reader) {
      throw new Error('Failed to get response reader');
    }

    const declared = Number.parseInt(response.headers.get('content-length') ?? '', 10);
    const total = Number.isFinite(declared) && declared > 0 ? declared : 0;
    const chunks: Uint8Array[] = [];
    let loaded = 0;

    while (true) {
      const { done, value } = await reader.read();
      if (done) break;
      chunks.push(value);
      loaded += value.byteLength;
      if (onProgress && total > 0) {
        onProgress(Math.min(100, (loaded / total) * 100));
      }
    }
    if (onProgress && total === 0) {
      onProgress(100);
    }

    // Concatenate chunks into one contiguous buffer.
    const result = new Uint8Array(loaded);
    let offset = 0;
    for (const chunk of chunks) {
      result.set(chunk, offset);
      offset += chunk.byteLength;
    }
    return result.buffer;
  }

  /**
   * Download a model from `modelInfo.url` and try to cache it.
   *
   * Caching is best-effort: if the write fails (no IndexedDB, quota exceeded)
   * the failure is logged and the downloaded bytes are still returned.
   *
   * @throws `Error` when the HTTP response is not OK or has no readable body
   */
  async downloadAndCacheModel(variant: 'quantized' | 'standard' | 'full', modelInfo: ModelInfo, onProgress?: (progress: number) => void): Promise<ArrayBuffer> {
    console.log(`Downloading model ${variant} from ${modelInfo.url}`);
    const response = await fetch(modelInfo.url);
    if (!response.ok) {
      // statusText is empty over HTTP/2, so always include the numeric status.
      throw new Error(`Failed to download model: ${`${response.status} ${response.statusText}`.trim()}`);
    }

    const modelData = await ModelCache.readBody(response, onProgress);

    try {
      await this.cacheModel(variant, modelData, modelInfo);
    } catch (error) {
      console.warn(`Model ${variant} downloaded but could not be cached:`, error);
    }
    return modelData;
  }

  /**
   * Return a variant's model bytes, using the cache when it holds the current
   * version and downloading (and caching) otherwise. A failing cache read is
   * logged and treated as a cache miss.
   */
  async getModel(variant: 'quantized' | 'standard' | 'full', modelInfo: ModelInfo, onProgress?: (progress: number) => void): Promise<ArrayBuffer> {
    // A single read avoids the check-then-fetch race of calling
    // isModelCached() followed by getCachedModel().
    let cached: CachedModel | null = null;
    try {
      cached = await this.getCachedModel(variant);
    } catch (error) {
      console.error('Error checking cached model:', error);
    }

    if (cached !== null && cached.version === ModelCache.CACHE_VERSION) {
      console.log(`Using cached model ${variant}`);
      return cached.data;
    }
    return this.downloadAndCacheModel(variant, modelInfo, onProgress);
  }

  /**
   * Delete every record whose `timestamp` is at or before `Date.now() - maxAge`.
   * Resolves after the deletions have committed.
   *
   * @param maxAge maximum age in milliseconds (default: 7 days)
   */
  async clearOldModels(maxAge: number = 7 * 24 * 60 * 60 * 1000): Promise<void> {
    const db = await this.ensureReady();
    const cutoffTime = Date.now() - maxAge;
    const tx = db.transaction([STORE_NAME], 'readwrite');
    const request = tx
      .objectStore(STORE_NAME)
      .index('timestamp')
      .openCursor(IDBKeyRange.upperBound(cutoffTime));

    request.onsuccess = () => {
      const cursor = request.result;
      if (cursor) {
        cursor.delete();
        cursor.continue();
      }
    };

    await ModelCache.transactionDone(tx);
    console.log('Old models cleared');
  }

  /**
   * Summarize the cache.
   *
   * @returns total stored bytes, number of records, and the cached variant names
   */
  async getCacheStats(): Promise<{ totalSize: number; modelCount: number; models: string[] }> {
    const db = await this.ensureReady();
    const tx = db.transaction([STORE_NAME], 'readonly');
    const request = tx.objectStore(STORE_NAME).getAll();
    await ModelCache.transactionDone(tx);

    const models = request.result as CachedModel[];
    return {
      totalSize: models.reduce((sum, model) => sum + model.data.byteLength, 0),
      modelCount: models.length,
      models: models.map((m) => m.variant),
    };
  }
}