Introduction: The Double-Edged Sword of Big Data
We live in an era of data deluge. From social media feeds to scientific sensors, we generate and collect information faster than we can process it. For machine learning practitioners, this abundance is both a blessing and a curse. While large datasets can fuel highly accurate models, they also create computational bottlenecks—training models becomes slow, costly, and sometimes infeasible.
A well-known remedy for this is dataset distillation or coreset generation. The idea is intuitive: instead of training on millions of records, we “distill” them into a few thousand representative samples that preserve the essence of the original. These smaller, weighted datasets—coresets—can dramatically reduce training time without sacrificing much performance.
But there’s a deeper problem. Big datasets don’t just carry information; they also reflect societal biases. A loan-approval model trained on historical data, for instance, might perpetuate discriminatory patterns against certain demographic groups. When we distill biased data, do we also distill the bias? Or worse—do we amplify it?
This is where the research paper “Fair Wasserstein Coresets” (FWC) steps in. The authors propose a novel coreset technique that addresses both efficiency and fairness. FWC generates a small, weighted synthetic dataset that closely approximates the original distribution and enforces demographic fairness. It does so using the Wasserstein distance, a powerful metric from optimal transport theory, combined with the fairness criterion known as demographic parity.
In this post, we’ll explore how FWC works—both conceptually and practically. We’ll cover:
- The foundations of coresets, Wasserstein distance, and demographic parity
- How the FWC optimization framework combines these ideas
- The Majorization-Minimization algorithm used to efficiently compute fair coresets
- The surprising link between FWC and k-means clustering
- Experimental findings showing how FWC enhances fairness in standard ML models and even in large language models like GPT-4
Let’s begin.
Background: The Building Blocks of FWC
FWC brings together three building blocks—coresets, Wasserstein distance, and demographic parity. Understanding these concepts helps clarify why the method is both efficient and fair.
What Is a Coreset?
Imagine you have a dataset with millions of points. A coreset is a compact, weighted subset that approximates the original dataset for a specific learning task. Each sample in the coreset carries a weight indicating its relative importance, allowing it to “stand in” for many other points. The goal is simple: train on the coreset and obtain nearly the same performance as if you trained on all the data, but in a fraction of the time.
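To make the weighting concrete, here is a minimal sketch (with hypothetical arrays `X_core`, `y_core`, `w_core`) of how a weighted coreset plugs into an ordinary training objective: each sample's loss is simply scaled by its weight.

```python
import numpy as np

def weighted_logistic_loss(beta, X_core, y_core, w_core):
    """Logistic-regression loss over a weighted coreset.

    X_core: (m, d) coreset features, y_core: (m,) binary labels,
    w_core: (m,) nonnegative weights summing to 1. Each coreset
    point's loss is scaled by its weight, so one synthetic sample
    can "stand in" for many original ones.
    """
    logits = X_core @ beta
    log_lik = y_core * logits - np.log1p(np.exp(logits))
    return -np.sum(w_core * log_lik)
```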
Dataset distillation approaches based on coresets have become popular for tasks such as clustering, classification, and Bayesian inference. Yet most ignore fairness—an increasingly urgent concern in real-world applications.
Why the Wasserstein Distance?
To assess how faithfully the coreset represents the original dataset, we need a metric that measures similarity between distributions. The Wasserstein distance (often called the “earth mover’s distance”) serves this role perfectly.
Visualize two piles of dirt—each representing a probability distribution. The Wasserstein distance quantifies the minimal effort required to move mass from one arrangement to the other. The total “work” combines how far the dirt must move with how much has to be moved.
The Wasserstein distance compares the geometry of two probability distributions by quantifying how much “mass” must be transported between them.
This metric is ideal because it connects geometry, probability, and learning performance. When the Wasserstein distance between a coreset and the original dataset is small, the coreset can be used to train models with minimal degradation in accuracy. In fact, the FWC paper formally proves that for common models—like neural networks with ReLU activations—the performance difference is bounded by the Wasserstein distance.
Proposition 2.1 in the paper shows that the difference between model outputs on a coreset and on the original data is upper-bounded by the 1-Wasserstein distance. Minimizing this distance thus directly controls downstream error.
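In one dimension, SciPy exposes this metric directly; the snippet below (illustrative only, since FWC operates on multivariate weighted samples) compares a dataset against a tiny weighted stand-in:

```python
import numpy as np
from scipy.stats import wasserstein_distance

rng = np.random.default_rng(0)
data = rng.normal(size=10_000)           # original 1-D dataset
coreset = np.array([-1.0, 0.0, 1.0])     # three representative points
weights = np.array([0.25, 0.5, 0.25])    # their probability weights

# 1-Wasserstein ("earth mover's") distance between the two distributions
d = wasserstein_distance(data, coreset, v_weights=weights)
print(f"W1(original, coreset) = {d:.4f}")
```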
What Is Demographic Parity?
Fairness in ML can mean many things, but one of the most common criteria is Demographic Parity (DP). It requires that a model’s outcome distribution be independent of protected attributes such as gender or ethnicity. For example, in a hiring task, the rate of job offers should be similar across demographics.
FWC enforces this principle by constraining the coreset to have subgroup outcome distributions close to a target (typically the overall dataset). This “closeness” is controlled by a tolerance parameter \( \epsilon \): the smaller it is, the stricter the fairness requirement.
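As a quick illustration, demographic disparity can be measured as the largest gap in positive-outcome rates across groups (a common operationalization; the paper's exact statistic may differ):

```python
import numpy as np

def demographic_disparity(outcomes, groups):
    """Largest gap in positive-outcome rates across protected groups.
    Demographic parity asks for this gap to be (near) zero."""
    rates = [outcomes[groups == g].mean() for g in np.unique(groups)]
    return max(rates) - min(rates)

# Hypothetical outcomes and a binary protected attribute:
outcomes = np.array([1, 0, 1, 1, 0, 1])
groups = np.array([0, 0, 0, 1, 1, 1])
print(demographic_disparity(outcomes, groups))  # 0.0: equal rates
```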
The Core Method: Crafting Fair Wasserstein Coresets
At the heart of FWC lies an optimization problem. The goal is to find a small set of synthetic samples \( \{ \hat{Z}_j \}_{j=1}^m \) with weights \( \{ \theta_j \}_{j=1}^m \) that minimize the Wasserstein distance from the original dataset while maintaining fairness.
FWC seeks the coreset that best matches the original distribution while obeying demographic parity constraints.
Formally, this means minimizing the Wasserstein distance subject to a constraint ensuring demographic parity violations stay below \( \epsilon \).
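Schematically (paraphrasing the paper's notation, with the disparity term left abstract), the problem reads:

\[
\min_{\{\hat{Z}_j,\,\theta_j\}_{j=1}^{m}} \; W_1\!\left(\sum_{j=1}^{m}\theta_j\,\delta_{\hat{Z}_j},\; \frac{1}{n}\sum_{i=1}^{n}\delta_{Z_i}\right)
\quad \text{subject to} \quad \mathrm{Disparity}\big(\{\hat{Z}_j,\theta_j\}\big) \le \epsilon,
\]

where \( \delta_z \) denotes a point mass at \( z \).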
To make this tractable, the authors transform the problem through four systematic steps.
Step 1: Reducing the Search Space
Each data sample \( Z \) consists of features \( X \), protected attributes \( D \), and outcomes \( Y \). Since \( D \) and \( Y \) are categorical and known, we can fix their proportions in the coreset to reflect the original dataset. That way, we only need to optimize over the continuous features \( \hat{X} \) and the weights \( \theta \).
Fixing discrete variables like outcomes and demographics simplifies the search to feature vectors and their weights.
Step 2: Linearizing the Fairness Constraint
The demographic parity condition involves ratios, which are nonlinear and difficult to handle. Ingeniously, the authors express these ratios as pairs of linear inequalities on the coreset weights—one upper and one lower bound for each subgroup and outcome combination.
Reformulating the fairness condition as linear inequalities allows it to be represented compactly as \( A\theta \ge \mathbf{0} \), where \(A\) encodes group relationships.
This conversion is crucial—it turns a nonconvex fairness constraint into a manageable set of linear conditions.
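The sketch below shows the linearization idea for one natural choice of target (the overall outcome rates); the paper's exact matrix \( A \) may be arranged differently:

```python
import numpy as np

def fairness_constraint_matrix(D, Y, eps):
    """Build A such that A @ theta >= 0 encodes the DP constraint.

    For each group d and outcome y, the ratio condition
      r_y - eps <= sum_{j: D_j=d, Y_j=y} theta_j / sum_{j: D_j=d} theta_j
               <= r_y + eps
    is cleared of its denominator, yielding two rows linear in theta.
    (A sketch of the idea, not the paper's exact construction.)
    """
    rows = []
    for y in np.unique(Y):
        r_y = np.mean(Y == y)                        # target outcome rate
        for d in np.unique(D):
            in_group = (D == d).astype(float)
            in_cell = ((D == d) & (Y == y)).astype(float)
            rows.append(in_cell - (r_y - eps) * in_group)  # lower bound
            rows.append((r_y + eps) * in_group - in_cell)  # upper bound
    return np.vstack(rows)
```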
Step 3: Introducing a Transport Plan
To compute the Wasserstein distance efficiently, FWC adopts tools from optimal transport theory. Instead of directly comparing distributions, it defines a transportation plan matrix \( P \), where each element \( P_{ij} \) represents how much probability mass is moved from sample \( Z_i \) to coreset representative \( \hat{Z}_j \).
Each entry \( P_{ij} \) defines how much mass from \( Z_i \) is “transported” to \( \hat{Z}_j \), enabling efficient linear optimization.
Minimizing the total transport cost \( \langle C(\hat{X}), P \rangle \) then becomes a large but structured linear program, where \( C(\hat{X}) \) is a cost matrix (often based on Euclidean distance).
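For intuition, here is the unconstrained version of that linear program solved with the POT library's exact solver (`ot.emd`); FWC's inner problem additionally carries the fairness rows \( A\theta \ge \mathbf{0} \):

```python
import numpy as np
import ot                                  # POT: Python Optimal Transport
from scipy.spatial.distance import cdist

rng = np.random.default_rng(1)
X = rng.normal(size=(500, 5))              # original features (n x d)
X_hat = rng.normal(size=(20, 5))           # coreset features  (m x d)

a = np.full(500, 1 / 500)                  # uniform mass on original points
b = np.full(20, 1 / 20)                    # mass on coreset points
C = cdist(X, X_hat)                        # Euclidean cost matrix C(X_hat)

P = ot.emd(a, b, C)                        # optimal transport plan (n x m)
cost = float(np.sum(C * P))                # total cost <C(X_hat), P>
```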
Step 4: Solving the Nested Optimization
The weights \( \theta \) can be recovered from the transport plan, so the problem reduces to optimizing over the coreset features \( \hat{X} \) and the transport plan \( P \). This results in a nested optimization structure:
- The inner problem: for fixed features \( \hat{X} \), find the optimal transport plan \( P \)
- The outer problem: adjust the coreset features to minimize the total cost induced by that plan
The nested structure: an inner linear program defines \( F(C(\hat{X})) \); the outer problem searches for \( \hat{X} \) that minimizes \( F \).
The function \( F(C(\hat{X})) \) gives the optimal transport cost for a given feature configuration.
For a known cost matrix \( C \), the inner program finds the fair transport plan that yields the minimum total distance between distributions.
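Continuing the sketch above, the inner problem can be wrapped as a function of the coreset features alone (unconstrained here; FWC's version adds the fairness rows):

```python
import numpy as np
import ot
from scipy.spatial.distance import cdist

def F(X_hat, X, a, b):
    """Inner problem: optimal transport cost for fixed coreset features.
    Returns the cost F(C(X_hat)) and the optimal plan P."""
    C = cdist(X, X_hat)        # cost matrix induced by current features
    P = ot.emd(a, b, C)        # inner LP (fairness constraints omitted)
    return float(np.sum(C * P)), P
```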
The challenge is that \( F \) is continuous but nonconvex—enter the Majorization-Minimization algorithm.
The Algorithm: Solving FWC Efficiently
To handle the nonconvexity, FWC employs the Majorization-Minimization (MM) framework—an iterative procedure that substitutes the difficult objective with a simpler surrogate function that upper-bounds it.
Each iteration builds a convex surrogate for the objective, minimizing it to guarantee steady progress.
At each iteration \( k \):
Update the Surrogate Function
Fix the current coreset features \( \hat{X}^k \), and solve the inner linear program to obtain the optimal transport plan \( P_k^* \). This step uses an adaptation of the efficient FairWASP algorithm.
The surrogate is then a linear function of the cost matrix: \( g(\hat{X}; \hat{X}^k) = \langle C(\hat{X}), P_k^* \rangle \), which is convex and easy to minimize.
Update the Feature Vectors
Minimize the surrogate to find the next coreset positions \( \hat{X}^{k+1} \). This decomposes into \( m \) independent subproblems, each amounting to computing the weighted centroid of the original points whose mass is transported to that coreset representative; these centroid updates can run in parallel.
These steps repeat until convergence, progressively refining both the mapping and locations to produce fair, representative coresets.
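Putting the two updates together gives the following MM sketch (fairness constraints omitted for brevity, so the inner step is plain optimal transport rather than the paper's FairWASP adaptation):

```python
import numpy as np
import ot
from scipy.spatial.distance import cdist

def fwc_mm_sketch(X, m=20, iters=50, seed=0):
    """Alternate between the inner transport plan and centroid updates.

    With squared-Euclidean costs and no fairness rows, each iteration
    assigns mass to coreset points and moves them to weighted centroids,
    mirroring Lloyd's k-means (see the next section).
    """
    rng = np.random.default_rng(seed)
    n = len(X)
    a = np.full(n, 1 / n)                          # mass on original points
    b = np.full(m, 1 / m)                          # mass on coreset points
    X_hat = X[rng.choice(n, m, replace=False)].copy()
    for _ in range(iters):
        C = cdist(X, X_hat, metric="sqeuclidean")
        P = ot.emd(a, b, C)                        # surrogate's plan P_k*
        # Minimizing the surrogate: each coreset point becomes the
        # weighted centroid of the mass transported onto it.
        X_hat = (P.T @ X) / P.sum(axis=0)[:, None]
    return X_hat, P
```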
The “Aha!” Moment: FWC as a Generalized Clustering Algorithm
Here’s the elegant twist: when fairness constraints are dropped, FWC collapses to Lloyd’s algorithm for k-means clustering.
Without fairness, each original point is simply assigned to its nearest coreset representative, and centroids are recomputed—exactly the steps of k-means. Thus, FWC generalizes k-means and k-medians to include fairness constraints via optimal transport.
This connection provides a powerful intuition: FWC performs fair clustering in distributional space, balancing representativeness and parity.
Experiments: Putting FWC to the Test
Scalability and Runtime
FWC scales efficiently. The paper’s synthetic experiments show that runtime increases roughly linearly with dataset size. Even as datasets grow to millions of samples, FWC remains computationally feasible.
[Figure: runtime analysis (top left) confirming near-linear scalability; remaining panels show fairness–utility tradeoffs across datasets.]
Fairness–Utility Tradeoff
Using standard fairness datasets (Adult, Credit, Crime, and Drug), the authors compared FWC to baseline coreset and fair clustering methods. A downstream neural network was trained on the resulting coresets.
Each model's accuracy (AUC) and fairness (demographic disparity) were plotted; the Pareto frontier (shown as a dashed red line in the paper's figures) marks the best achievable tradeoffs. Across all datasets, FWC sits on or near the frontier, offering competitive utility while substantially reducing disparity, even without any extra fairness preprocessing.
Reducing Bias in Large Language Models
In a particularly creative experiment, the team applied FWC to few-shot prompting with GPT-3.5 and GPT-4. Using the Adult dataset, they supplied 16 FWC-selected examples to guide the LLMs' predictions. The results were striking: demographic disparities dropped by 75% for GPT-3.5 and 35% for GPT-4, with minimal accuracy loss.
Few-shot prompts based on FWC samples help GPT models produce fairer predictions without major performance tradeoffs.
This suggests that well-chosen, fair representative examples can meaningfully steer even complex language models toward equitable outcomes.
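For flavor, here is a hypothetical helper for turning coreset rows into few-shot exemplars (the paper's actual prompt template is not reproduced here):

```python
def build_fewshot_prompt(coreset_rows, query_profile):
    """Assemble a few-shot prompt from FWC-selected examples.

    coreset_rows: list of (profile_text, label) pairs, e.g. 16 samples
    distilled from the Adult dataset. The template is illustrative.
    """
    lines = ["Predict whether income exceeds $50K. Answer Yes or No."]
    for profile_text, label in coreset_rows:
        lines.append(f"Profile: {profile_text}\nAnswer: {label}")
    lines.append(f"Profile: {query_profile}\nAnswer:")
    return "\n\n".join(lines)
```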
Conclusion and Outlook
Fair Wasserstein Coresets (FWC) introduce a principled approach to distilling large datasets while maintaining fairness. By uniting optimal transport theory with demographic parity, FWC yields coresets that are small, representative, and fair.
Key takeaways:
- Effective: Achieves state-of-the-art fairness–utility tradeoffs in downstream tasks
- Scalable: Operates efficiently on large datasets via the MM algorithm
- Versatile: Extends naturally to tasks like data augmentation and LLM fairness correction
- Insightful: Bridges fairness and clustering, generalizing classic k-means
As data-driven models play an ever-larger role in decisions that affect lives, methods like FWC are vital to ensure that efficiency doesn’t come at the cost of equity. The paper opens avenues for extending the approach to other fairness measures (like equalized odds), exploring privacy–fairness tradeoffs, and leveraging fair coresets for tasks such as neural network pruning and bias-resistant training.
Fairness is more than an add-on to machine learning—it’s a foundational design principle. FWC shows that with the right mathematical framing, fairness can be built in from the start.