269 lines
7.2 KiB
TypeScript
269 lines
7.2 KiB
TypeScript
import type { ModelInfo } from './types';
|
|
|
|
/** Name of the IndexedDB database that holds the cached model binaries. */
const DB_NAME = 'ShoeDetectionModels';

/** Schema version passed to `indexedDB.open`; a higher value triggers `onupgradeneeded`. */
const DB_VERSION = 1;

/** Object store in which each cached model record is kept. */
const STORE_NAME = 'models';

/**
 * Shape of one record in the models object store.
 *
 * The cache writes one record per model variant and keys records by `id`.
 */
export interface CachedModel {
  /** Primary key of the record; the cache writes it as `shoe-detection-<variant>`. */
  id: string;

  /** Which build of the model this record holds. */
  variant: 'quantized' | 'standard' | 'full';

  /** Entry version; a lookup treats the record as stale when it differs from the required version. */
  version: string;

  /** Write time in milliseconds since the Unix epoch (`Date.now()`); used for age-based eviction. */
  timestamp: number;

  /** Model description supplied by the caller when the record was written. */
  metadata: ModelInfo;

  /** Raw model bytes. */
  data: ArrayBuffer;
}

/** Model variants the cache stores (the same union as `CachedModel['variant']`). */
export type ModelVariant = CachedModel['variant'];

/**
 * IndexedDB-based model cache for TensorFlow.js models.
 *
 * Each variant is stored as a single `CachedModel` record keyed by
 * `shoe-detection-<variant>`. `getModel` treats the cache as best-effort:
 * when IndexedDB is unavailable or a cache read/write fails, the model is
 * still downloaded and returned.
 */
export class ModelCache {
  /** Entry version written by `cacheModel` and expected by `getModel` when the caller passes none. */
  private static readonly DEFAULT_VERSION = '1.0.0';

  private db: IDBDatabase | null = null;

  /** Pending (or settled) open; reset to null after a failed or lost connection so it can be retried. */
  private initPromise: Promise<void> | null = null;

  constructor() {
    // Start opening right away so the first lookup doesn't wait for it. The
    // no-op catch only marks the promise as handled (no unhandled-rejection
    // report when IndexedDB is unavailable); ensureReady() still sees the error.
    this.initPromise = this.init();
    this.initPromise.catch(() => undefined);
  }

  /**
   * Open the database, creating the object store and its indexes on first use.
   *
   * On success the connection is stored in `this.db`.
   */
  private init(): Promise<void> {
    return new Promise((resolve, reject) => {
      if (typeof indexedDB === 'undefined') {
        // e.g. server-side rendering or a non-browser runtime.
        reject(new Error('IndexedDB is not available in this environment'));
        return;
      }

      const request = indexedDB.open(DB_NAME, DB_VERSION);

      request.onerror = () => {
        console.error('Failed to open IndexedDB:', request.error);
        reject(request.error);
      };

      request.onsuccess = () => {
        const db = request.result;
        // Forget this connection if another context upgrades/deletes the
        // database (closing it so that context isn't blocked) or if the browser
        // closes it unexpectedly; the next ensureReady() call reopens it.
        const forget = () => {
          if (this.db === db) {
            this.db = null;
            this.initPromise = null;
          }
        };
        db.onversionchange = () => {
          db.close();
          forget();
        };
        db.onclose = forget;
        this.db = db;
        resolve();
      };

      request.onupgradeneeded = () => {
        const db = request.result;
        if (!db.objectStoreNames.contains(STORE_NAME)) {
          const store = db.createObjectStore(STORE_NAME, { keyPath: 'id' });
          store.createIndex('variant', 'variant', { unique: false });
          store.createIndex('timestamp', 'timestamp', { unique: false });
        }
      };
    });
  }

  /**
   * Resolve to an open database connection, (re)opening it if needed.
   *
   * A failed open is not cached: the next call tries again.
   *
   * @throws The open error (e.g. IndexedDB unavailable or access denied).
   */
  private async ensureReady(): Promise<IDBDatabase> {
    if (this.db === null) {
      if (this.initPromise === null) {
        this.initPromise = this.init();
      }
      const pending = this.initPromise;
      try {
        await pending;
      } catch (error) {
        if (this.initPromise === pending) {
          this.initPromise = null;
        }
        throw error;
      }
    }
    const db = this.db;
    if (!db) {
      throw new Error('Database not initialized');
    }
    return db;
  }

  /** Primary key under which a variant's record is stored. */
  private static keyFor(variant: ModelVariant): string {
    return `shoe-detection-${variant}`;
  }

  /** Adapt a single IDBRequest to a promise of its result. */
  private static promisifyRequest<T>(request: IDBRequest<T>): Promise<T> {
    return new Promise((resolve, reject) => {
      request.onsuccess = () => resolve(request.result);
      request.onerror = () => reject(request.error);
    });
  }

  /**
   * Resolve once `transaction` commits; reject if it aborts.
   *
   * Writes are durable only after `complete`. A request can succeed and the
   * transaction still abort at commit time (e.g. QuotaExceededError). A failed
   * request that is not handled also aborts the transaction, so `abort` covers
   * request errors too.
   */
  private static transactionDone(transaction: IDBTransaction): Promise<void> {
    return new Promise((resolve, reject) => {
      transaction.oncomplete = () => resolve();
      transaction.onabort = () =>
        reject(transaction.error ?? new Error('IndexedDB transaction aborted'));
    });
  }

  /**
   * Cache a model in IndexedDB, replacing any existing record for the variant.
   *
   * @param variant - Model variant being stored.
   * @param modelData - Raw model bytes.
   * @param metadata - Model description stored alongside the bytes.
   * @param version - Entry version to record; defaults to `'1.0.0'`.
   * @returns Resolves after the write transaction has committed.
   * @throws The database open error, or the error that aborted the write.
   */
  async cacheModel(
    variant: ModelVariant,
    modelData: ArrayBuffer,
    metadata: ModelInfo,
    version: string = ModelCache.DEFAULT_VERSION
  ): Promise<void> {
    const db = await this.ensureReady();

    const cachedModel: CachedModel = {
      id: ModelCache.keyFor(variant),
      variant,
      data: modelData,
      metadata,
      timestamp: Date.now(),
      version,
    };

    try {
      const transaction = db.transaction([STORE_NAME], 'readwrite');
      transaction.objectStore(STORE_NAME).put(cachedModel);
      // Resolve on commit, not on the put request's success.
      await ModelCache.transactionDone(transaction);
    } catch (error) {
      console.error(`Failed to cache model ${variant}:`, error);
      throw error;
    }

    console.log(`Model ${variant} cached successfully`);
  }

  /**
   * Retrieve a cached model.
   *
   * @param variant - Model variant to look up.
   * @returns The stored record, or `null` when none exists.
   * @throws The database open error or the read error.
   */
  async getCachedModel(variant: ModelVariant): Promise<CachedModel | null> {
    const db = await this.ensureReady();
    const store = db.transaction([STORE_NAME], 'readonly').objectStore(STORE_NAME);
    const entry = await ModelCache.promisifyRequest<CachedModel | undefined>(
      store.get(ModelCache.keyFor(variant))
    );
    return entry ?? null;
  }

  /**
   * Check whether a model is cached with exactly `requiredVersion`.
   *
   * Note: this loads the full record, including the model bytes.
   *
   * @returns `false` when the record is missing, stale, or the lookup fails (the failure is logged).
   */
  async isModelCached(variant: ModelVariant, requiredVersion: string): Promise<boolean> {
    try {
      const cached = await this.getCachedModel(variant);
      return cached !== null && cached.version === requiredVersion;
    } catch (error) {
      console.error('Error checking cached model:', error);
      return false;
    }
  }

  /**
   * Download a model's bytes from `modelInfo.url`, reporting progress in percent (0-100).
   *
   * Progress is reported per chunk when the response has a Content-Length.
   * Otherwise a single 100 is reported on completion.
   */
  private async fetchModelData(
    variant: ModelVariant,
    modelInfo: ModelInfo,
    onProgress?: (progress: number) => void
  ): Promise<ArrayBuffer> {
    console.log(`Downloading model ${variant} from ${modelInfo.url}`);

    const response = await fetch(modelInfo.url);
    if (!response.ok) {
      // statusText is always empty over HTTP/2, so include the numeric status.
      throw new Error(`Failed to download model: ${response.status} ${response.statusText}`.trimEnd());
    }

    if (!response.body) {
      // No streamable body: read it in one go (no incremental progress).
      const whole = await response.arrayBuffer();
      onProgress?.(100);
      return whole;
    }

    const contentLength = response.headers.get('content-length');
    const total = contentLength ? parseInt(contentLength, 10) : 0;
    const hasTotal = total > 0; // false for a missing or malformed header (NaN)

    const reader = response.body.getReader();
    const chunks: Uint8Array[] = [];
    let loaded = 0;
    try {
      for (;;) {
        const { done, value } = await reader.read();
        if (done) break;
        chunks.push(value);
        loaded += value.byteLength;
        if (onProgress && hasTotal) {
          // Content-Length is the encoded size; a compressed response can
          // decode to more bytes than that, so clamp at 100.
          onProgress(Math.min(100, (loaded / total) * 100));
        }
      }
    } finally {
      reader.releaseLock();
    }

    if (onProgress && !hasTotal) {
      onProgress(100);
    }

    // Concatenate the chunks into one contiguous buffer.
    const merged = new Uint8Array(loaded);
    let offset = 0;
    for (const chunk of chunks) {
      merged.set(chunk, offset);
      offset += chunk.byteLength;
    }
    return merged.buffer;
  }

  /**
   * Download a model and cache it.
   *
   * @param variant - Model variant being downloaded.
   * @param modelInfo - Model description; `modelInfo.url` is fetched.
   * @param onProgress - Optional progress callback, in percent (0-100).
   * @param version - Entry version to record; defaults to `'1.0.0'`.
   * @returns The downloaded model bytes.
   * @throws On a non-OK HTTP response, a network error, or a cache write failure.
   */
  async downloadAndCacheModel(
    variant: ModelVariant,
    modelInfo: ModelInfo,
    onProgress?: (progress: number) => void,
    version: string = ModelCache.DEFAULT_VERSION
  ): Promise<ArrayBuffer> {
    const modelData = await this.fetchModelData(variant, modelInfo, onProgress);
    await this.cacheModel(variant, modelData, modelInfo, version);
    return modelData;
  }

  /**
   * Return a model's bytes, from the cache when a record with `version` exists,
   * otherwise by downloading it and (best-effort) caching the result.
   *
   * @param variant - Model variant to load.
   * @param modelInfo - Model description; `modelInfo.url` is fetched on a cache miss.
   * @param onProgress - Optional download progress callback, in percent (0-100).
   * @param version - Required entry version; defaults to `'1.0.0'`.
   * @throws Only when the download itself fails; cache failures are logged.
   */
  async getModel(
    variant: ModelVariant,
    modelInfo: ModelInfo,
    onProgress?: (progress: number) => void,
    version: string = ModelCache.DEFAULT_VERSION
  ): Promise<ArrayBuffer> {
    // A single read serves both the freshness check and the payload. This
    // avoids loading a large record twice and the race where the record is
    // evicted between a check and a second read.
    let cached: CachedModel | null = null;
    try {
      cached = await this.getCachedModel(variant);
    } catch (error) {
      console.error('Error checking cached model:', error);
    }

    if (cached !== null && cached.version === version) {
      console.log(`Using cached model ${variant}`);
      return cached.data;
    }

    const modelData = await this.fetchModelData(variant, modelInfo, onProgress);
    try {
      await this.cacheModel(variant, modelData, modelInfo, version);
    } catch (error) {
      // Caching is best-effort; a successful download is still usable.
      console.warn(`Model ${variant} downloaded but could not be cached:`, error);
    }
    return modelData;
  }

  /**
   * Delete cached models whose write timestamp is at or before `Date.now() - maxAge`.
   *
   * @param maxAge - Maximum age in milliseconds; defaults to 7 days.
   * @returns Resolves after the deletions have committed.
   * @throws The database open error, or the error that aborted the transaction.
   */
  async clearOldModels(maxAge: number = 7 * 24 * 60 * 60 * 1000): Promise<void> {
    const db = await this.ensureReady();

    const cutoffTime = Date.now() - maxAge;
    const range = IDBKeyRange.upperBound(cutoffTime);

    const transaction = db.transaction([STORE_NAME], 'readwrite');
    const done = ModelCache.transactionDone(transaction);
    const store = transaction.objectStore(STORE_NAME);

    // Fetch only the primary keys of expired records, so the (potentially
    // large) model bytes are never loaded just to be deleted.
    const keysRequest = store.index('timestamp').getAllKeys(range);
    keysRequest.onsuccess = () => {
      for (const key of keysRequest.result) {
        store.delete(key);
      }
    };

    await done;
    console.log('Old models cleared');
  }

  /**
   * Get cache storage usage.
   *
   * @returns Total stored model bytes, the number of records, and the variants present.
   * @throws The database open error or the read error.
   */
  async getCacheStats(): Promise<{ totalSize: number; modelCount: number; models: string[] }> {
    const db = await this.ensureReady();
    const store = db.transaction([STORE_NAME], 'readonly').objectStore(STORE_NAME);
    const models = await ModelCache.promisifyRequest<CachedModel[]>(store.getAll());

    return {
      totalSize: models.reduce((sum, model) => sum + model.data.byteLength, 0),
      modelCount: models.length,
      models: models.map((model) => model.variant),
    };
  }
}