TensorFlow.js in React: Building ML-Powered Components

Our product team wanted image moderation without sending user photos to external APIs. Privacy concerns and latency requirements ruled out server-side solutions. I built client-side content classification using TensorFlow.js in React—processing happens entirely in the browser. Here's how to add machine learning capabilities to React components.

Why Client-Side ML

Server-side ML means network latency, API costs, and privacy implications. TensorFlow.js runs models directly in the browser:

  • Privacy: User data never leaves the device
  • Latency: No network round-trip for predictions
  • Offline capability: Works without internet once models load
  • Cost: No per-prediction API charges

The tradeoffs are model size (downloads can be large) and computational limits (browser JavaScript vs. GPU clusters). For many use cases—image classification, pose detection, text toxicity—pre-trained models work well within these constraints.

Setting Up TensorFlow.js

Install the core library and any model packages you need:

npm install @tensorflow/tfjs @tensorflow-models/mobilenet

TensorFlow.js automatically uses WebGL for GPU acceleration when available, falling back to CPU. No configuration needed for basic usage.
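If you want to inspect or pin the backend yourself, the core API exposes it; a minimal sketch using the built-in 'webgl' and 'cpu' backend names:

import * as tf from '@tensorflow/tfjs';

async function logBackend() {
  await tf.ready(); // waits for backend initialization to finish
  console.log('Active backend:', tf.getBackend()); // e.g. 'webgl' or 'cpu'

  // Optionally pin a backend; resolves false if it isn't available
  // await tf.setBackend('webgl');
}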

Loading Models in React

Models should load once and persist across renders. Use a ref or state to store the loaded model:

import { useEffect, useState, useCallback } from 'react';
import * as tf from '@tensorflow/tfjs';
import * as mobilenet from '@tensorflow-models/mobilenet';
 
export interface Prediction {
  className: string;
  probability: number;
}
 
export function useImageClassifier() {
  const [model, setModel] = useState<mobilenet.MobileNet | null>(null);
  const [isLoading, setIsLoading] = useState(true);
  const [error, setError] = useState<Error | null>(null);
 
  useEffect(() => {
    let mounted = true;
 
    async function loadModel() {
      try {
        // Ensure WebGL backend is ready
        await tf.ready();
 
        const loadedModel = await mobilenet.load({
          version: 2,
          alpha: 1.0, // width multiplier: 0.50, 0.75, or 1.0 (v1 also supports 0.25)
        });
 
        if (mounted) {
          setModel(loadedModel);
          setIsLoading(false);
        }
      } catch (err) {
        if (mounted) {
          setError(err as Error);
          setIsLoading(false);
        }
      }
    }
 
    loadModel();
 
    return () => {
      mounted = false;
    };
  }, []);
 
  const classify = useCallback(
    async (imageElement: HTMLImageElement): Promise<Prediction[]> => {
      if (!model) {
        throw new Error('Model not loaded');
      }
 
      const predictions = await model.classify(imageElement);
      return predictions;
    },
    [model]
  );
 
  return { classify, isLoading, error, isReady: !!model };
}

The alpha parameter controls model width, and with it download size and inference speed. Lower values (0.50 or 0.75 for MobileNet v2; v1 also offers 0.25) load and run faster but are less accurate. Start with 1.0 and reduce if performance is an issue.

Image Classification Component

Here's a complete component for classifying uploaded images:

import { useState, useRef } from 'react';
import { useImageClassifier, type Prediction } from './useImageClassifier';
 
export function ImageClassifier() {
  const { classify, isLoading, isReady, error } = useImageClassifier();
  const [predictions, setPredictions] = useState<Prediction[]>([]);
  const [isClassifying, setIsClassifying] = useState(false);
  const [imageUrl, setImageUrl] = useState<string | null>(null);
  const imageRef = useRef<HTMLImageElement>(null);
 
  const handleFileChange = (event: React.ChangeEvent<HTMLInputElement>) => {
    const file = event.target.files?.[0];
    if (file) {
      // Revoke the previous blob URL so preview images don't leak memory
      if (imageUrl) URL.revokeObjectURL(imageUrl);
      const url = URL.createObjectURL(file);
      setImageUrl(url);
      setPredictions([]);
    }
  };
 
  const handleClassify = async () => {
    if (!imageRef.current || !isReady) return;
 
    setIsClassifying(true);
    try {
      const results = await classify(imageRef.current);
      setPredictions(results);
    } catch (err) {
      console.error('Classification failed:', err);
    } finally {
      setIsClassifying(false);
    }
  };
 
  if (error) {
    return <div className="text-red-600">Failed to load model: {error.message}</div>;
  }
 
  return (
    <div className="max-w-md mx-auto p-6">
      <div className="mb-4">
        <input
          type="file"
          accept="image/*"
          onChange={handleFileChange}
          disabled={isLoading}
          className="block w-full text-sm text-gray-500 file:mr-4 file:py-2 file:px-4 file:rounded file:border-0 file:bg-blue-50 file:text-blue-700"
        />
      </div>
 
      {imageUrl && (
        <div className="mb-4">
          <img
            ref={imageRef}
            src={imageUrl}
            alt="Upload preview"
            className="w-full rounded-lg"
            crossOrigin="anonymous"
          />
        </div>
      )}
 
      <button
        onClick={handleClassify}
        disabled={!isReady || !imageUrl || isClassifying}
        className="w-full py-2 px-4 bg-blue-600 text-white rounded-lg disabled:opacity-50"
      >
        {isLoading ? 'Loading model...' : isClassifying ? 'Classifying...' : 'Classify Image'}
      </button>
 
      {predictions.length > 0 && (
        <div className="mt-4 space-y-2">
          {predictions.map((pred, index) => (
            <div key={index} className="flex justify-between items-center">
              <span className="font-medium">{pred.className}</span>
              <span className="text-gray-600">
                {(pred.probability * 100).toFixed(1)}%
              </span>
            </div>
          ))}
        </div>
      )}
    </div>
  );
}

Real-Time Webcam Detection

For live video analysis, use requestAnimationFrame to continuously process frames:

import { useEffect, useRef, useState, useCallback } from 'react';
import '@tensorflow/tfjs'; // coco-ssd needs the core library and a backend registered
import * as cocoSsd from '@tensorflow-models/coco-ssd';
 
interface Detection {
  bbox: [number, number, number, number];
  class: string;
  score: number;
}
 
export function ObjectDetector() {
  const videoRef = useRef<HTMLVideoElement>(null);
  const canvasRef = useRef<HTMLCanvasElement>(null);
  const [model, setModel] = useState<cocoSsd.ObjectDetection | null>(null);
  const [isRunning, setIsRunning] = useState(false);
  const animationRef = useRef<number | null>(null);
 
  useEffect(() => {
    cocoSsd.load().then(setModel).catch(console.error);
  }, []);
 
  const startWebcam = async () => {
    try {
      const stream = await navigator.mediaDevices.getUserMedia({
        video: { facingMode: 'environment' }
      });
      if (videoRef.current) {
        videoRef.current.srcObject = stream;
        await videoRef.current.play();
        setIsRunning(true);
      }
    } catch (err) {
      console.error('Webcam access denied:', err);
    }
  };
 
  const stopWebcam = () => {
    const stream = videoRef.current?.srcObject as MediaStream;
    stream?.getTracks().forEach(track => track.stop());
    setIsRunning(false);
    if (animationRef.current) {
      cancelAnimationFrame(animationRef.current);
    }
  };
 
  const detectFrame = useCallback(async () => {
    if (!model || !videoRef.current || !canvasRef.current || !isRunning) return;
 
    const video = videoRef.current;
    const canvas = canvasRef.current;
    const ctx = canvas.getContext('2d');
 
    if (!ctx) return;
 
    // Match canvas to video dimensions
    canvas.width = video.videoWidth;
    canvas.height = video.videoHeight;
 
    // Draw current frame
    ctx.drawImage(video, 0, 0);
 
    // Run detection
    const predictions = await model.detect(video);
 
    // Draw bounding boxes
    predictions.forEach((prediction) => {
      const [x, y, width, height] = prediction.bbox;
 
      ctx.strokeStyle = '#00ff00';
      ctx.lineWidth = 2;
      ctx.strokeRect(x, y, width, height);
 
      ctx.fillStyle = '#00ff00';
      ctx.font = '16px Arial';
      ctx.fillText(
        `${prediction.class} (${Math.round(prediction.score * 100)}%)`,
        x,
        y - 5
      );
    });
 
    animationRef.current = requestAnimationFrame(detectFrame);
  }, [model, isRunning]);
 
  useEffect(() => {
    if (isRunning && model) {
      detectFrame();
    }
    return () => {
      if (animationRef.current) {
        cancelAnimationFrame(animationRef.current);
      }
    };
  }, [isRunning, model, detectFrame]);
 
  return (
    <div className="relative">
      <video
        ref={videoRef}
        className="hidden"
        playsInline
        muted
      />
      <canvas
        ref={canvasRef}
        className="w-full rounded-lg"
      />
      <div className="mt-4 flex gap-2">
        <button
          onClick={startWebcam}
          disabled={!model || isRunning}
          className="px-4 py-2 bg-green-600 text-white rounded"
        >
          Start Detection
        </button>
        <button
          onClick={stopWebcam}
          disabled={!isRunning}
          className="px-4 py-2 bg-red-600 text-white rounded"
        >
          Stop
        </button>
      </div>
    </div>
  );
}

Memory Management

TensorFlow.js tensors consume GPU memory. Without cleanup, you'll hit memory limits. Always dispose tensors after use:

import * as tf from '@tensorflow/tfjs';
 
async function processImage(model: tf.LayersModel, imageElement: HTMLImageElement) {
  // Create tensor from image
  const tensor = tf.browser.fromPixels(imageElement);

  // Preprocess: resize, scale to [0, 1], add a batch dimension
  const resized = tf.image.resizeBilinear(tensor, [224, 224]);
  const normalized = resized.div(255.0);
  const batched = normalized.expandDims(0);

  // Run prediction (synchronous for layers models)
  const prediction = model.predict(batched) as tf.Tensor;

  try {
    // Read the result back to JavaScript
    const data = await prediction.data();
    return Array.from(data);
  } finally {
    // Dispose every tensor created above, not just the input
    tf.dispose([tensor, resized, normalized, batched, prediction]);
  }
}
 
// Or use tf.tidy for automatic cleanup
function processWithTidy(imageElement: HTMLImageElement) {
  return tf.tidy(() => {
    const tensor = tf.browser.fromPixels(imageElement);
    const resized = tf.image.resizeBilinear(tensor, [224, 224]);
    const normalized = resized.div(255.0);
    return normalized.expandDims(0);
  });
  // tf.tidy disposes the intermediates; the returned tensor survives
  // and the caller is responsible for disposing it
}
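One caveat: tf.tidy only cleans up the intermediates. Whatever the callback returns survives, and the caller owns it. A short usage sketch, assuming a layers model is in scope:

async function classifyOnce(model: tf.LayersModel, img: HTMLImageElement) {
  const batched = processWithTidy(img); // tidy handed this tensor back to us
  const output = model.predict(batched) as tf.Tensor;
  try {
    return Array.from(await output.data());
  } finally {
    tf.dispose([batched, output]); // we own both, so we dispose both
  }
}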

Use tf.memory() during development to monitor the tensor count (numTensors) and allocated bytes (numBytes):

useEffect(() => {
  const interval = setInterval(() => {
    console.log('TensorFlow memory:', tf.memory());
  }, 5000);
  return () => clearInterval(interval);
}, []);

Web Worker for Heavy Inference

Long-running predictions block the main thread, freezing the UI. Offload to a Web Worker:

// ml.worker.ts
/// <reference lib="webworker" />
import * as tf from '@tensorflow/tfjs';
 
let model: tf.LayersModel | null = null;
 
self.onmessage = async (event) => {
  const { type, payload } = event.data;
 
  switch (type) {
    case 'LOAD_MODEL':
      try {
        await tf.ready();
        model = await tf.loadLayersModel(payload.modelUrl);
        self.postMessage({ type: 'MODEL_LOADED' });
      } catch (error) {
        self.postMessage({ type: 'ERROR', error: (error as Error).message });
      }
      break;
 
    case 'PREDICT':
      if (!model) {
        self.postMessage({ type: 'ERROR', error: 'Model not loaded' });
        return;
      }
 
      try {
        const inputTensor = tf.tensor(payload.input);
        const prediction = model.predict(inputTensor) as tf.Tensor;
        const result = await prediction.data();
 
        inputTensor.dispose();
        prediction.dispose();
 
        self.postMessage({ type: 'PREDICTION', result: Array.from(result) });
      } catch (error) {
        self.postMessage({ type: 'ERROR', error: (error as Error).message });
      }
      break;
  }
};

// useMLWorker.ts
import { useEffect, useRef, useState, useCallback } from 'react';
 
export function useMLWorker(modelUrl: string) {
  const workerRef = useRef<Worker | null>(null);
  const [isReady, setIsReady] = useState(false);
 
  useEffect(() => {
    workerRef.current = new Worker(new URL('./ml.worker.ts', import.meta.url), {
      type: 'module', // the worker file uses ES module imports
    });
 
    workerRef.current.onmessage = (event) => {
      if (event.data.type === 'MODEL_LOADED') {
        setIsReady(true);
      }
    };
 
    workerRef.current.postMessage({
      type: 'LOAD_MODEL',
      payload: { modelUrl }
    });
 
    return () => {
      workerRef.current?.terminate();
    };
  }, [modelUrl]);
 
  const predict = useCallback((input: number[]): Promise<number[]> => {
    return new Promise((resolve, reject) => {
      if (!workerRef.current) {
        reject(new Error('Worker not initialized'));
        return;
      }
 
      const handler = (event: MessageEvent) => {
        if (event.data.type === 'PREDICTION') {
          workerRef.current?.removeEventListener('message', handler);
          resolve(event.data.result);
        } else if (event.data.type === 'ERROR') {
          workerRef.current?.removeEventListener('message', handler);
          reject(new Error(event.data.error));
        }
      };
 
      workerRef.current.addEventListener('message', handler);
      workerRef.current.postMessage({ type: 'PREDICT', payload: { input } });
    });
  }, []);
 
  return { predict, isReady };
}
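Consuming the hook then looks like any other async call. A hypothetical component for illustration; the model path and input vector are placeholders:

function PredictButton() {
  // Placeholder model URL for illustration
  const { predict, isReady } = useMLWorker('/models/custom/model.json');

  const handleClick = async () => {
    if (!isReady) return;
    // The input length must match what the model expects
    const result = await predict([0.1, 0.2, 0.3]);
    console.log('Prediction:', result);
  };

  return (
    <button onClick={handleClick} disabled={!isReady}>
      {isReady ? 'Predict' : 'Loading model...'}
    </button>
  );
}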

Text Classification with Toxicity Model

TensorFlow.js includes pre-trained models for text analysis:

import { useEffect, useState, useCallback } from 'react';
import * as toxicity from '@tensorflow-models/toxicity';
 
interface ToxicityResult {
  label: string;
  match: boolean;
  probability: number;
}
 
export function useToxicityClassifier(threshold = 0.9) {
  const [model, setModel] = useState<toxicity.ToxicityClassifier | null>(null);
 
  useEffect(() => {
    // Passing an empty label list loads all seven toxicity labels
    toxicity.load(threshold, []).then(setModel);
  }, [threshold]);
 
  const classify = useCallback(
    async (text: string): Promise<ToxicityResult[]> => {
      if (!model) throw new Error('Model not loaded');
 
      const predictions = await model.classify([text]);
 
      return predictions.map((prediction) => ({
        label: prediction.label,
        match: prediction.results[0].match ?? false,
        probability: prediction.results[0].probabilities[1],
      }));
    },
    [model]
  );
 
  return { classify, isReady: !!model };
}
 
// Usage
function CommentForm() {
  const { classify, isReady } = useToxicityClassifier();
  const [comment, setComment] = useState('');
  const [warning, setWarning] = useState<string | null>(null);
 
  const handleSubmit = async () => {
    if (!isReady) return;
 
    const results = await classify(comment);
    const toxic = results.filter(r => r.match);
 
    if (toxic.length > 0) {
      setWarning(`Your comment may contain: ${toxic.map(t => t.label).join(', ')}`);
    } else {
      // Submit comment
      setWarning(null);
    }
  };
 
  return (
    <div>
      <textarea
        value={comment}
        onChange={(e) => setComment(e.target.value)}
        placeholder="Write a comment..."
      />
      {warning && <p className="text-yellow-600">{warning}</p>}
      <button onClick={handleSubmit} disabled={!isReady}>
        Post Comment
      </button>
    </div>
  );
}

Custom Model Deployment

To use your own trained model:

  1. Train in Python with TensorFlow/Keras
  2. Convert to TensorFlow.js format:
pip install tensorflowjs
tensorflowjs_converter --input_format=keras \
  model.h5 \
  ./tfjs_model
  3. Host the generated files (model.json and weight shards)
  4. Load in React:
const model = await tf.loadLayersModel('/models/custom/model.json');

For Next.js, place models in the public directory.
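
Converted models can run to several megabytes, so it helps to surface download progress. tf.loadLayersModel accepts an onProgress callback in its options argument:

const model = await tf.loadLayersModel('/models/custom/model.json', {
  // fraction goes from 0 to 1 as the weight shards arrive
  onProgress: (fraction) => {
    console.log(`Model ${(fraction * 100).toFixed(0)}% loaded`);
  },
});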

Performance Tips

Reduce model size: Use quantized models or smaller architectures. MobileNet with alpha: 0.5 loads faster than alpha: 1.0.

Lazy load models: Don't load ML models on initial page load. Load when the user accesses the feature:

const [model, setModel] = useState<mobilenet.MobileNet | null>(null);
 
const initializeML = async () => {
  if (model) return;
  const loaded = await mobilenet.load();
  setModel(loaded);
};
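
To keep the library itself out of the initial bundle, not just the model weights, a dynamic import() works too; a sketch assuming a bundler that code-splits:

const initializeML = async () => {
  if (model) return;
  // import() keeps @tensorflow-models/mobilenet (and tfjs) out of the initial bundle
  const mobilenet = await import('@tensorflow-models/mobilenet');
  const loaded = await mobilenet.load();
  setModel(loaded);
};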

Cache models: TensorFlow.js doesn't persist loaded models automatically (beyond the browser's normal HTTP caching), but you can save one to IndexedDB with model.save('indexeddb://...') so later visits skip the network.
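
A minimal sketch, reusing the custom model path from the previous section:

import * as tf from '@tensorflow/tfjs';

async function loadCachedModel(): Promise<tf.LayersModel> {
  try {
    // Use the copy saved to IndexedDB on a previous visit
    return await tf.loadLayersModel('indexeddb://custom-model');
  } catch {
    // First visit (or cache evicted): fetch, then persist for next time
    const model = await tf.loadLayersModel('/models/custom/model.json');
    await model.save('indexeddb://custom-model');
    return model;
  }
}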

Throttle predictions: For real-time video, don't run predictions on every frame:

const lastPredictionTime = useRef(0);
const THROTTLE_MS = 100;
 
const detectFrame = () => {
  const now = Date.now();
  if (now - lastPredictionTime.current < THROTTLE_MS) {
    requestAnimationFrame(detectFrame);
    return;
  }
  lastPredictionTime.current = now;
  // Run prediction
};

Use an appropriate backend: WebGL is the fastest option in most browsers, and TensorFlow.js falls back to the plain CPU backend when it's unavailable. Check which backend is active with:

console.log('Backend:', tf.getBackend());
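
On devices without usable WebGL, the optional WASM backend (@tensorflow/tfjs-backend-wasm) is typically much faster than the plain CPU backend. A sketch; depending on your bundler you may also need setWasmPaths from the same package to point at the .wasm binaries:

import * as tf from '@tensorflow/tfjs';
import '@tensorflow/tfjs-backend-wasm'; // registers the 'wasm' backend

async function preferWasm() {
  const ok = await tf.setBackend('wasm'); // resolves false if unavailable
  await tf.ready();
  console.log('Backend:', tf.getBackend(), '(wasm requested:', ok, ')');
}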

Client-side ML opens possibilities that server-side APIs can't match—instant feedback, complete privacy, and offline functionality. TensorFlow.js makes these capabilities accessible in React applications without requiring ML expertise.