JAX-sklearn is a drop-in replacement for scikit-learn that provides automatic JAX acceleration for machine learning algorithms while maintaining 100% API compatibility.
- Drop-in Replacement: Use `import xlearn as sklearn`; no other code changes needed
- JAX Acceleration: 4-20x speedup on CPU for large datasets; up to 50-100x on CUDA GPUs and 100x+ on TPU
- Multi-Hardware: CPU, NVIDIA GPU (CUDA), Apple Silicon (Metal), Google TPU
- Auto Fallback: Graceful degradation when JAX is unavailable
- Prerequisite for Secret-Learn: Privacy-preserving federated ML with SecretFlow
```bash
# Apple Silicon (recommended)
uv pip install 'jax-sklearn[jax-metal]' --system
# NVIDIA GPU
uv pip install 'jax-sklearn[jax-gpu]' --system
# CPU only
uv pip install 'jax-sklearn[jax-cpu]' --system
```

Or with pip:

```bash
pip install jax-sklearn
```

Build prerequisites: when installing from source (no wheel is available for your platform), you need a C/C++ toolchain and Python headers. See the troubleshooting commands below.
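```python
import numpy as np
import xlearn as sklearn  # drop-in alias for "import sklearn"
from xlearn.linear_model import LinearRegression

rng = np.random.default_rng(0)  # toy data; substitute your own arrays
X = rng.standard_normal((1_000, 10))
y = X @ rng.standard_normal(10)
X_test = rng.standard_normal((100, 10))
model = LinearRegression()
model.fit(X, y)
predictions = model.predict(X_test)
# JAX acceleration is applied automatically when beneficial
```

JAX acceleration can be configured via `xlearn._jax`: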
```python
import xlearn._jax as jax_config

# Default: always enable JAX (best for GPU/TPU)
jax_config.set_config(enable_jax=True)

# Threshold mode: only use JAX for large datasets (CPU users)
jax_config.set_config(enable_jax=True, jax_auto_threshold=True)

# Disable JAX (pure sklearn)
jax_config.set_config(enable_jax=False)
```

| Data Size | XLearn | sklearn | Speedup |
|---|---|---|---|
| 10K × 100 | 0.0097s | 0.0113s | 1.16x |
| 10K × 1K | 0.0384s | 0.1590s | 4.14x |
| 10K × 10K | 2.82s | 55.96s | 19.86x |
| 50K × 2K | 0.54s | 1.96s | 3.60x |
| 100K × 1K | 0.40s | 1.23s | 3.04x |
JAX incurs roughly 0.2 s of JIT compilation overhead on the first run. On CPU, the crossover point where JAX starts to outperform sklearn is around 10K × 100.
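A minimal sketch of where that first-run overhead comes from, using plain JAX rather than xlearn (the `gram` function is a hypothetical stand-in for an estimator's inner computation): `jax.jit` traces and compiles on the first call for a given input shape and dtype, then reuses the compiled executable. Because JAX dispatches work asynchronously, each timing calls `block_until_ready()` so it measures the computation rather than only the dispatch.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def gram(X):
    # X^T X: the kind of dense product linear models and PCA rely on
    return X.T @ X


def timed_call(fn, *args):
    # Wait for the asynchronous result so the timing covers the real work
    start = time.perf_counter()
    out = fn(*args).block_until_ready()
    return out, time.perf_counter() - start


X = jnp.ones((2_000, 200))
_, first = timed_call(gram, X)     # trace + compile + run
out, cached = timed_call(gram, X)  # reuse the compiled executable

print(f"first call: {first:.4f}s, cached call: {cached:.4f}s")
```

Calling `gram` with a new input shape triggers another compilation, which is one reason small inputs see little or no benefit.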
| Hardware | Small Data | Medium Data | Large Data |
|---|---|---|---|
| CPU | ~1x | 0.2-0.5x | 4-20x |
| Metal (M1-M4) | ~1x | 1.5-2x | 2-3x |
| CUDA GPU | 1-2x | 5-10x | 50-100x |
| TPU | 2-5x | 10-20x | 100x+ |
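The CPU row is why threshold mode exists: below the crossover, JIT overhead can make JAX slower than sklearn. A minimal sketch of the size-based dispatch that `jax_auto_threshold=True` implies, assuming a simple element-count cutoff at the ~10K × 100 CPU crossover quoted above; `choose_backend` and `min_elements` are hypothetical and are not xlearn's actual internals.

```python
# Hypothetical cutoff: ~10K samples x 100 features, the CPU crossover quoted above
DEFAULT_MIN_ELEMENTS = 10_000 * 100


def choose_backend(n_samples, n_features, enable_jax=True,
                   jax_auto_threshold=False, min_elements=DEFAULT_MIN_ELEMENTS):
    """Return "jax" or "sklearn" for an input of the given shape (hypothetical helper)."""
    if not enable_jax:
        return "sklearn"  # JAX disabled: always use the original code path
    if not jax_auto_threshold:
        return "jax"      # default mode: always use JAX
    # Threshold mode: only inputs large enough to amortise JIT compilation use JAX
    if n_samples * n_features >= min_elements:
        return "jax"
    return "sklearn"


for shape in [(1_000, 10), (10_000, 100), (100_000, 1_000)]:
    print(shape, choose_backend(*shape, jax_auto_threshold=True))
```

A real implementation would likely tune the cutoff per algorithm and per backend; a single element count is the simplest possible rule.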
| Operation | Size | JAX | NumPy/SciPy | Speedup |
|---|---|---|---|---|
| Matrix Multiply | 5000×5000 | 0.012s | 0.122s | 9.9x 🚀 |
| RBF Kernel (pairwise distance) | 5K×100 | 0.0001s | 0.239s | 3571x 🚀 |
JAX acceleration has the largest impact on GPUs and TPUs; the benchmarks above were measured on CPU only.
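The RBF row corresponds to evaluating k(x, y) = exp(-γ‖x − y‖²) for every pair of rows. A minimal sketch of a JIT-compiled version in plain JAX (not xlearn's code), using the expansion ‖x − y‖² = ‖x‖² + ‖y‖² − 2x·y so the whole matrix reduces to one matrix product plus broadcasting. The function name and the `gamma=1.0` default are assumptions for illustration; scikit-learn's own `rbf_kernel` defaults to `gamma = 1 / n_features`.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def rbf_kernel(X, Y, gamma=1.0):
    # Squared distances for all pairs via ||x||^2 + ||y||^2 - 2 x.y
    sq_norms_x = jnp.sum(X**2, axis=1)[:, None]
    sq_norms_y = jnp.sum(Y**2, axis=1)[None, :]
    sq_dists = sq_norms_x + sq_norms_y - 2.0 * (X @ Y.T)
    # Clamp tiny negative values caused by floating-point cancellation
    return jnp.exp(-gamma * jnp.maximum(sq_dists, 0.0))


rng = np.random.default_rng(0)
X = rng.standard_normal((500, 100)).astype(np.float32)
K = rbf_kernel(X, X, gamma=0.01).block_until_ready()
print(K.shape)
```

The same kernel matrix is what SVM `decision_function` and Gaussian process prediction evaluate against the training points.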
- Linear Models: LinearRegression, Ridge, Lasso, ElasticNet
- Clustering: KMeans
- Decomposition: PCA, TruncatedSVD
- Preprocessing: StandardScaler, MinMaxScaler
- Gaussian Process: GaussianProcessRegressor (JIT kernel + Cholesky)
- SVM: SVC, SVR, LinearSVC — predict/decision_function (JIT kernel eval)
- Neural Network: MLPClassifier, MLPRegressor — predict (JIT forward pass)
- Decision Trees / Random Forest: predict/predict_proba (JIT tree walk)
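For tree models, prediction is a root-to-leaf walk. A minimal sketch of how such a walk can be JIT-compiled, assuming the tree is stored as flat arrays (`feature`, `threshold`, `left`, `right`, `value`, with `left == -1` marking a leaf, similar in spirit to scikit-learn's `tree_` attributes); `predict_tree` and the toy tree are hypothetical, not xlearn's implementation. `jax.lax.while_loop` walks one sample, and `jax.vmap` batches the walk over rows, looping until every row has reached a leaf.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def predict_tree(X, feature, threshold, left, right, value):
    def walk(x):
        def is_internal(node):
            return left[node] != -1  # leaves have no left child

        def step(node):
            go_left = x[feature[node]] <= threshold[node]
            return jnp.where(go_left, left[node], right[node])

        leaf = jax.lax.while_loop(is_internal, step, jnp.array(0, dtype=left.dtype))
        return value[leaf]

    # Walk every row independently; vmap batches the per-sample loop
    return jax.vmap(walk)(X)


# Toy tree: node 0 splits on feature 0, node 2 splits on feature 1
feature = jnp.array([0, 0, 1, 0, 0])
threshold = jnp.array([0.5, 0.0, 0.0, 0.0, 0.0])
left = jnp.array([1, -1, 3, -1, -1])
right = jnp.array([2, -1, 4, -1, -1])
value = jnp.array([0.0, 1.0, 0.0, 2.0, 3.0])

X = jnp.array([[0.2, 5.0], [0.9, -1.0], [0.9, 1.0]])
print(predict_tree(X, feature, threshold, left, right, value))
```

A random forest would repeat this walk per tree and average the leaf values.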
All other scikit-learn algorithms are fully available via automatic fallback to the original implementation.
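A minimal sketch of the fallback idea in plain Python, assuming a registry that maps operation names to an optional JAX implementation and a required original NumPy implementation; `dispatch`, `ACCELERATED`, and `ORIGINAL` are hypothetical names, not xlearn's API. The sketch runs whether or not JAX is installed, which is the point of the pattern.

```python
import importlib.util

import numpy as np

# Detect JAX without importing it eagerly (hypothetical module-level flag)
JAX_AVAILABLE = importlib.util.find_spec("jax") is not None


def _lstsq_original(X, y):
    # Original path: NumPy least squares
    coef, *_ = np.linalg.lstsq(X, y, rcond=None)
    return coef


def _lstsq_jax(X, y):
    # Accelerated path: the same computation evaluated by JAX
    import jax.numpy as jnp
    coef, *_ = jnp.linalg.lstsq(jnp.asarray(X), jnp.asarray(y))
    return np.asarray(coef)


# Only some operations have a JAX version; every operation has an original one
ACCELERATED = {"lstsq": _lstsq_jax}
ORIGINAL = {"lstsq": _lstsq_original, "column_means": lambda X: X.mean(axis=0)}


def dispatch(name, *args):
    """Run the JAX version if it exists and works, else the original (hypothetical)."""
    if JAX_AVAILABLE and name in ACCELERATED:
        try:
            return ACCELERATED[name](*args)
        except (ImportError, RuntimeError):
            pass  # e.g. a broken jaxlib install: degrade to the original path
    return ORIGINAL[name](*args)


X = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = X @ np.array([2.0, -1.0])
print(dispatch("lstsq", X, y), dispatch("column_means", X))
```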
- Python: 3.10+
- JAX: 0.4.20+
- NumPy: 1.22.0+, SciPy: 1.8.0+
Hardware-specific: NVIDIA GPU (CUDA 11.1+), Apple Silicon (macOS 12+), or Google TPU.
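Troubleshooting source builds:

```bash
# Install compilers and Python headers (Linux)
sudo apt-get install build-essential python3-dev   # Debian/Ubuntu
sudo dnf install gcc gcc-c++ python3-devel         # RHEL/Fedora

# macOS: install the Xcode Command Line Tools (compilers and headers)
xcode-select --install

# If the isolated build fails, build against already-installed build dependencies
pip install --no-build-isolation jax-sklearn
```

Check which devices and backend JAX is using:

```python
import jax

print("Devices:", jax.devices())
print("Backend:", jax.default_backend())
```

Development setup:

```bash
git clone https://github.com/chenxingqiang/jax-sklearn.git
cd jax-sklearn
pip install -e ".[tests]"
pytest xlearn/tests/ -v
```

JAX-sklearn is BSD 3-Clause licensed.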
- Secret-Learn: Privacy-preserving federated ML with SecretFlow (348 algorithm implementations)