Shuffle Optimizations
When Adaptive Query Execution (AQE) is turned off, Spark does not coalesce or rebalance shuffle partitions at runtime, so every shuffle in the plan is exactly what your code and configuration ask for. The example below disables AQE so the shuffles are easy to see.
import org.apache.spark.sql.{SparkSession, functions => F}
import io.delta.tables._
import org.apache.spark.sql.functions._
object Main {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder()
.appName("CDC Example")
.master("local[*]")
.config("spark.sql.warehouse.dir", "/home/kaustubh/Documents/")
.config("spark.sql.catalogImplementation", "hive")
.enableHiveSupport()
.config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
.config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
//.config("spark.sql.shuffle.partitions","4")
.config("spark.sql.adaptive.enabled","false")
.getOrCreate()
import spark.implicits._
val sales = Seq(
("2024-01", "Electronics", 1000),
("2024-01", "Clothing", 500),
("2024-02", "Electronics", 1500),
("2024-02", "Clothing", 700),
("2024-01", "Electronics", 800)
).toDF("month", "category", "amount")
println("=== UNOPTIMIZED: Multiple shuffles ===")
// BAD: Multiple operations causing separate shuffles
val unoptimized = sales
.groupBy("category").agg(sum("amount").as("total"))
.groupBy("category").agg(max("total").as("max_total")) // Another shuffle!
unoptimized.explain()
unoptimized.show()
System.in.read() // keep the app alive so you can inspect the plans in the Spark UI (localhost:4040)
spark.stop()
}
}

Key Concepts:
- Combine Aggregations - Instead of multiple groupBy operations (each causing a shuffle), combine all aggregations into a single operation (see the sketch after this list)
- Adjust Partition Count - The default 200 shuffle partitions is often too high for small datasets. Reduce it to 4-10 for better performance
- Broadcast Small Tables - Use broadcast() for small lookup tables to avoid expensive shuffle joins
- Strategic Repartitioning - Pre-partition data by key when you'll perform multiple operations on the same grouping key
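For the sales DataFrame in the listing above, the combined version looks like this. It is a minimal sketch (not part of the original listing) that computes the per-category total together with a couple of extra aggregates in a single pass instead of chaining two groupBy calls:

// One groupBy, one shuffle: every aggregate is computed in the same pass
val optimized = sales
  .groupBy("category")
  .agg(
    sum("amount").as("total"),
    max("amount").as("max_sale"),
    count("*").as("num_sales")
  )
optimized.explain() // the plan now contains a single Exchange (shuffle)
optimized.show()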
What Causes Shuffles:
groupBy, join, repartition, distinct, sortBy - these operations move data across the network between executors
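You can see these shuffles directly in the physical plan. A quick check using the sales DataFrame from the example above (the exact plan text varies by Spark version, but the Exchange node is the shuffle):

sales.groupBy("category").count().explain()
// look for an "Exchange hashpartitioning(category, ...)" node - that is the shuffle
sales.distinct().explain()
// distinct also shuffles: Exchange hashpartitioning over all columns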
Quick Wins:
- Start by checking spark.sql.shuffle.partitions - adjust it based on your data size
- Look for multiple groupBy operations that can be combined
- Identify small tables (<10MB) that can be broadcast
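A minimal way to act on the first quick win; the value 8 is only an illustration, pick something that matches your data volume:

println(spark.conf.get("spark.sql.shuffle.partitions")) // "200" unless you have changed it
spark.conf.set("spark.sql.shuffle.partitions", "8")     // small data -> far fewer shuffle partitions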
The example includes explain() calls so you can see the execution plans and understand where shuffles occur. Try running it and comparing the plans between the optimized and unoptimized versions!

Key Differences:
- Repartition: Full shuffle (expensive), can increase/decrease partitions, redistributes data evenly
- Coalesce: Minimal shuffle (cheap), only decreases partitions, may be uneven
The Examples Cover:
- Basic repartition - increasing partitions for more parallelism
- Basic coalesce - decreasing partitions efficiently
- Repartition by column - grouping data by key (perfect before groupBy!)
- Real-world scenarios - when to use each
- Performance comparison - why coalesce is better for reducing partitions
- Multiple columns - partitioning by category AND country
- Small files problem - avoiding hundreds of tiny output files
Teaching Tips:
Each example shows the partition distribution so students can see how data is organized. The code prints which records end up in which partition, making it concrete rather than abstract.
The most important concept for students: repartition by column before groupBy eliminates an extra shuffle and can dramatically improve performance!
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
object RepartitionCoalesceGuide {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder()
.appName("Repartition vs Coalesce")
.master("local[*]")
.getOrCreate()
import spark.implicits._
// Sample dataset: employee records
val employees = Seq(
(1, "Alice", "Engineering", 80000),
(2, "Bob", "Sales", 60000),
(3, "Charlie", "Engineering", 85000),
(4, "Diana", "Sales", 65000),
(5, "Eve", "HR", 70000),
(6, "Frank", "Engineering", 90000),
(7, "Grace", "Sales", 62000),
(8, "Henry", "HR", 72000)
).toDF("id", "name", "department", "salary")
println("=== INITIAL STATE ===")
println(s"Initial partitions: ${employees.rdd.getNumPartitions}")
employees.rdd.glom().collect().zipWithIndex.foreach { case (arr, idx) =>
println(s"Partition $idx: ${arr.length} records")
}
// ============================================
// EXAMPLE 1: REPARTITION (INCREASE PARTITIONS)
// ============================================
println("\n=== EXAMPLE 1: Repartition - Increase Partitions ===")
// Use case: Increase parallelism for heavy processing
val df1 = employees.repartition(4)
println(s"After repartition(4): ${df1.rdd.getNumPartitions} partitions")
df1.rdd.glom().collect().zipWithIndex.foreach { case (arr, idx) =>
println(s"Partition $idx: ${arr.length} records")
}
// NOTE: Full shuffle - data redistributed evenly across all partitions
// ============================================
// EXAMPLE 2: COALESCE (DECREASE PARTITIONS)
// ============================================
println("\n=== EXAMPLE 2: Coalesce - Decrease Partitions ===")
// Use case: Reduce partitions before writing to disk
val df2 = employees.repartition(8).coalesce(2)
println(s"After coalesce(2): ${df2.rdd.getNumPartitions} partitions")
df2.rdd.glom().collect().zipWithIndex.foreach { case (arr, idx) =>
println(s"Partition $idx: ${arr.length} records")
}
// NOTE: No full shuffle - combines existing partitions
// ============================================
// EXAMPLE 3: REPARTITION BY COLUMN
// ============================================
println("\n=== EXAMPLE 3: Repartition by Column ===")
// Use case: Group data by key for subsequent operations
val df3 = employees.repartition(3, col("department"))
println(s"After repartition by department: ${df3.rdd.getNumPartitions} partitions")
df3.rdd.glom().collect().zipWithIndex.foreach { case (arr, idx) =>
val depts = arr.map(_.getString(2)).distinct.mkString(", ")
println(s"Partition $idx: ${arr.length} records - Departments: $depts")
}
// NOTE: Same departments go to same partition - great for groupBy!
// ============================================
// EXAMPLE 4: WHEN TO USE EACH
// ============================================
println("\n=== EXAMPLE 4: Practical Use Cases ===")
// Scenario A: Processing large file, need more parallelism
println("\nScenario A: Large file processing")
val largeData = employees.repartition(8) // Increase for parallel processing
println(s"Increased to ${largeData.rdd.getNumPartitions} for heavy computation")
// Scenario B: Writing results to files (avoid many small files)
println("\nScenario B: Writing output")
val forWriting = largeData.coalesce(1) // Reduce to fewer output files
println(s"Coalesced to ${forWriting.rdd.getNumPartitions} before writing")
// forWriting.write.csv("output/employees") // Would create 1 file
// Scenario C: GroupBy optimization
println("\nScenario C: GroupBy optimization")
val prePartitioned = employees.repartition(3, col("department"))
val aggregated = prePartitioned
.groupBy("department")
.agg(
avg("salary").as("avg_salary"),
count("*").as("count")
)
println("Repartitioned by department before groupBy - no additional shuffle!")
aggregated.show()
// ============================================
// EXAMPLE 5: PERFORMANCE COMPARISON
// ============================================
println("\n=== EXAMPLE 5: Performance Comparison ===")
val manyPartitions = employees.repartition(100)
println(s"\nMany partitions (100):")
println(s"Partitions: ${manyPartitions.rdd.getNumPartitions}")
// Reducing partitions - WRONG WAY
val wrongWay = manyPartitions.repartition(2)
println(s"\nWrong: repartition(2) - Full shuffle!")
println(s"Partitions: ${wrongWay.rdd.getNumPartitions}")
// Reducing partitions - RIGHT WAY
val rightWay = manyPartitions.coalesce(2)
println(s"\nRight: coalesce(2) - Minimal shuffle!")
println(s"Partitions: ${rightWay.rdd.getNumPartitions}")
// ============================================
// EXAMPLE 6: REPARTITION WITH MULTIPLE COLUMNS
// ============================================
println("\n=== EXAMPLE 6: Repartition by Multiple Columns ===")
val salesData = Seq(
("2024-01", "Electronics", "USA", 1000),
("2024-01", "Electronics", "UK", 800),
("2024-02", "Clothing", "USA", 500),
("2024-02", "Clothing", "UK", 600)
).toDF("month", "category", "country", "sales")
val multiPartitioned = salesData.repartition(4, col("category"), col("country"))
println(s"Repartitioned by category AND country")
multiPartitioned.rdd.glom().collect().zipWithIndex.foreach { case (arr, idx) =>
val info = arr.map(r => s"${r.getString(1)}-${r.getString(2)}").distinct.mkString(", ")
println(s"Partition $idx: $info")
}
// ============================================
// EXAMPLE 7: AVOIDING SMALL FILES PROBLEM
// ============================================
println("\n=== EXAMPLE 7: Small Files Problem ===")
// BAD: Writing with too many partitions
val badWrite = employees.repartition(50)
println(s"BAD: ${badWrite.rdd.getNumPartitions} partitions = 50 small files!")
// GOOD: Coalesce before writing
val goodWrite = badWrite.coalesce(2)
println(s"GOOD: ${goodWrite.rdd.getNumPartitions} partitions = 2 reasonably sized files")
spark.stop()
}
}
/*
=== QUICK REFERENCE GUIDE ===
REPARTITION:
- Full shuffle (expensive)
- Can increase OR decrease partitions
- Redistributes data evenly
- Use when: Need more parallelism, or need to partition by column
COALESCE:
- Minimal data movement (cheap)
- Can ONLY decrease partitions
- May create uneven partitions
- Use when: Reducing partitions before write, or combining after filter
COMMON PATTERNS:
1. Reading → Processing:
df.repartition(100) // Increase for parallel processing
2. Processing → Writing:
df.coalesce(10) // Reduce to avoid small files
3. Before GroupBy:
df.repartition(col("key")) // Avoid shuffle in groupBy
4. After Filter:
df.filter(...).coalesce(5) // Reduce empty partitions
RULES OF THUMB:
- Partition size: 100-200MB ideal
- For 10GB data: ~50-100 partitions
- Before writing: 1 partition = 1 output file
- After filtering heavily: use coalesce to remove empty partitions
*/
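To turn the partition-size rule of thumb into a concrete number, a tiny helper like the one below works. targetPartitions is a hypothetical helper for illustration, not a Spark API:

// Pick a partition count from total data size and a target partition size (128 MB here).
def targetPartitions(totalBytes: Long, targetPartitionBytes: Long = 128L * 1024 * 1024): Int =
  math.max(1, math.ceil(totalBytes.toDouble / targetPartitionBytes).toInt)

val parts = targetPartitions(10L * 1024 * 1024 * 1024) // 10 GB -> 80 partitions of ~128 MB each
// df.repartition(parts), or spark.conf.set("spark.sql.shuffle.partitions", parts.toString)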
I've created a comprehensive guide with 10 powerful methods to optimize Spark without using repartition! Here's the key insight for students:
The Best Shuffle is No Shuffle!
Top Strategies:
- Broadcast Join - For small tables (<10MB), send to all nodes instead of shuffling
- Bucketing - Pre-partition tables at write time, get free optimization forever
- Partition Pruning - Read only the data you need with smart partitioning
- Combine Operations - One groupBy with many aggregations instead of multiple groupBys
- Filter Early - Reduce data before expensive operations
The Power of Bucketing: This is often overlooked by beginners! You pay the shuffle cost ONCE when writing, then ALL future queries on that table are shuffle-free. It's like organizing your closet once and finding clothes instantly forever after.
Real Performance Impact:
- Broadcasting a 5MB table: Saves 100% shuffle on that table
- Filtering early: Can reduce shuffle data by 50-90%
- Combining aggregations: Turns 3 shuffles into 1
Each example shows the "bad way" vs "good way" with explain() plans so students can see the actual difference in execution. The comments explain WHY each approach works and when to use it.
Golden Rule for Students: Always ask "Can I avoid this shuffle?" before thinking "How should I repartition?"
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.storage.StorageLevel
object OptimizeWithoutRepartition {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder()
.appName("Optimize Without Repartition")
.master("local[*]")
.getOrCreate()
import spark.implicits._
// Sample datasets
val employees = Seq(
(1, "Alice", "Engineering", 80000),
(2, "Bob", "Sales", 60000),
(3, "Charlie", "Engineering", 85000),
(4, "Diana", "Sales", 65000),
(5, "Eve", "HR", 70000),
(6, "Frank", "Engineering", 90000)
).toDF("id", "name", "department", "salary")
val departments = Seq(
("Engineering", "Tech"),
("Sales", "Business"),
("HR", "Support")
).toDF("department", "type")
// ============================================
// METHOD 1: BROADCAST JOIN (Avoid Shuffle Join)
// ============================================
println("=== METHOD 1: Broadcast Join ===")
// BAD: Regular join causes shuffle on both sides
println("\nBad - Regular join (shuffle on both sides):")
val regularJoin = employees.join(departments, "department")
regularJoin.explain()
// GOOD: Broadcast small table (no shuffle for small table)
println("\nGood - Broadcast join (no shuffle for small table):")
val broadcastJoin = employees.join(broadcast(departments), "department")
broadcastJoin.explain()
broadcastJoin.show()
println("✓ Saved: One full shuffle eliminated!")
// ============================================
// METHOD 2: BUCKET TABLES (Pre-partition at Write Time)
// ============================================
println("\n=== METHOD 2: Bucketing (Pre-partitioned Tables) ===")
// Write data with bucketing
employees.write
.mode("overwrite")
.bucketBy(4, "department")
.sortBy("department")
.saveAsTable("employees_bucketed")
departments.write
.mode("overwrite")
.bucketBy(4, "department")
.sortBy("department")
.saveAsTable("departments_bucketed")
// Read bucketed tables
val empBucketed = spark.table("employees_bucketed")
val deptBucketed = spark.table("departments_bucketed")
// Join without shuffle!
println("\nJoin on bucketed tables (NO SHUFFLE!):")
val bucketJoin = empBucketed.join(deptBucketed, "department")
bucketJoin.explain()
println("✓ Saved: Complete shuffle eliminated at query time!")
println("✓ Cost: One-time shuffle at write time")
// ============================================
// METHOD 3: PARTITION PRUNING (Read Less Data)
// ============================================
println("\n=== METHOD 3: Partition Pruning ===")
// Write partitioned data
employees.write
.mode("overwrite")
.partitionBy("department")
.parquet("output/employees_partitioned")
// Read with filter - only reads relevant partitions
println("\nReading with partition filter (reads only needed data):")
val prunedRead = spark.read.parquet("output/employees_partitioned")
.filter(col("department") === "Engineering")
prunedRead.explain()
prunedRead.show()
println("✓ Saved: Reads only Engineering partition, not entire dataset")
// ============================================
// METHOD 4: COMBINE OPERATIONS (Single Shuffle)
// ============================================
println("\n=== METHOD 4: Combine Operations ===")
// BAD: Multiple shuffles
println("\nBad - Multiple groupBy operations:")
val bad = employees
.groupBy("department").agg(sum("salary").as("total"))
.groupBy("department").agg(max("total").as("max_total"))
bad.explain()
// GOOD: Single aggregation
println("\nGood - Combined aggregations:")
val good = employees
.groupBy("department")
.agg(
sum("salary").as("total_salary"),
avg("salary").as("avg_salary"),
max("salary").as("max_salary"),
min("salary").as("min_salary"),
count("*").as("emp_count")
)
good.explain()
good.show()
println("✓ Saved: Multiple shuffles → Single shuffle")
// ============================================
// METHOD 5: FILTER EARLY (Process Less Data)
// ============================================
println("\n=== METHOD 5: Filter Early (Predicate Pushdown) ===")
// BAD: Filter after expensive operation
println("\nBad - Filter after aggregation:")
val filterLate = employees
.groupBy("department").agg(sum("salary").as("total"))
.filter(col("total") > 100000)
filterLate.explain()
// GOOD: Filter before expensive operation
println("\nGood - Filter before aggregation:")
val filterEarly = employees
.filter(col("salary") > 70000) // Reduce data early!
.groupBy("department").agg(sum("salary").as("total"))
filterEarly.explain()
filterEarly.show()
println("✓ Saved: Less data to shuffle and aggregate")
// Example 2: Filtering Before Join - BAD vs GOOD
// ===============================================
def joinBadApproach(): Unit = {
println("\n=== BAD: Join large datasets then filter ===")
val orders = (1 to 10000).map(i => (i, s"user${i % 100}", i * 10, "2024-01-01")).toDF("order_id", "user_id", "amount", "date")
val users = (1 to 100).map(i => (s"user$i", s"User Name $i", if (i % 10 == 0) "premium" else "basic")).toDF("user_id", "name", "tier")
// BAD: Join ALL data, then filter - shuffles everything
val result = orders
.join(users, "user_id") // BIG SHUFFLE
.filter($"tier" === "premium" && $"amount" > 5000)
println(s"Result count: ${result.count()}")
println("Shuffled all orders and users before filtering")
}
def joinGoodApproach(): Unit = {
println("\n=== GOOD: Filter before join ===")
val orders = (1 to 10000).map(i => (i, s"user${i % 100}", i * 10, "2024-01-01")).toDF("order_id", "user_id", "amount", "date")
val users = (1 to 100).map(i => (s"user$i", s"User Name $i", if (i % 10 == 0) "premium" else "basic")).toDF("user_id", "name", "tier")
// GOOD: Filter first, then join - much smaller shuffle
val premiumUsers = users.filter($"tier" === "premium")
val highValueOrders = orders.filter($"amount" > 5000)
val result = highValueOrders.join(premiumUsers, "user_id")
println(s"Result count: ${result.count()}")
println("Filtered data before join - much smaller shuffle!")
}
// run both variants
joinBadApproach()
joinGoodApproach()
println("✓ Saved: No data movement between partitions")
// ============================================
// METHOD 6: CACHE WISELY (Avoid Recomputation)
// ============================================
println("\n=== METHOD 6: Strategic Caching ===")
// Scenario: Using same filtered data multiple times
val filtered = employees.filter(col("salary") > 65000)
// BAD: Recomputes filter twice
println("\nBad - Without caching (recomputes twice):")
val result1 = filtered.groupBy("department").count()
val result2 = filtered.groupBy("department").agg(avg("salary"))
// Each query recomputes the filter!
// GOOD: Cache intermediate result
println("\nGood - With caching:")
val filteredCached = employees.filter(col("salary") > 65000).cache()
filteredCached.count() // Materialize cache
val cached1 = filteredCached.groupBy("department").count()
val cached2 = filteredCached.groupBy("department").agg(avg("salary"))
cached1.show()
cached2.show()
println("✓ Saved: Recomputation eliminated, filter runs once")
filteredCached.unpersist() // Clean up
// ============================================
// METHOD 7: ADJUST SPARK CONFIGS (Smart Defaults)
// ============================================
println("\n=== METHOD 7: Tune Spark Configuration ===")
// For small datasets
spark.conf.set("spark.sql.shuffle.partitions", "4")
println("\nSmall dataset - Reduced shuffle partitions to 4")
val smallResult = employees.groupBy("department").count()
smallResult.explain()
smallResult.show()
println("✓ Saved: Less overhead with fewer partitions")
// Auto broadcast threshold
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "10485760") // 10MB
println("\nAuto-broadcast threshold set to 10MB")
println("✓ Small tables auto-broadcast without explicit broadcast()")
// ============================================
// METHOD 9: SALTING (Handle Skewed Data)
// ============================================
println("\n=== METHOD 9: Salting for Skewed Keys ===")
// Simulate skewed data (most employees in Engineering)
val skewedData = (1 to 100).map(i =>
(i, s"Emp$i", if (i <= 80) "Engineering" else "Sales", 70000)
).toDF("id", "name", "department", "salary")
**What's happening:**
- Creates 100 employee records
- **80 employees** → Engineering department
- **20 employees** → Sales department
- This is **skewed data** (80/20 split, not balanced)
**Why it's a problem:**
```
When you groupBy("department"):
Partition 0: Engineering [80 records] ████████████████████
Partition 1: Sales [20 records] █████
❌ One executor does 80% of work (slow!)
❌ Other executor sits idle (wasted resources)
❌ The job is only as fast as the slowest partition
```

**The Solution: Salting**

// Add salt to distribute the skewed key
val salted = skewedData.withColumn("salt", (rand() * 4).cast("int"))
.withColumn("dept_salted", concat(col("department"), lit("_"), col("salt")))
// explanation
.withColumn("salt", (rand() * 4).cast("int"))
// - `rand()` → Random number between 0.0 and 1.0
// - `rand() * 4` → Random number between 0.0 and 4.0
// - `.cast("int")` → Converts to integer: 0, 1, 2, or 3
**Result:**
id | name | department | salary | salt
----|-------|--------------|--------|-----
1 | Emp1 | Engineering | 70000 | 2
2 | Emp2 | Engineering | 70000 | 0
3 | Emp3 | Engineering | 70000 | 3
.withColumn("dept_salted", concat(col("department"), lit("_"), col("salt")))
- `concat()` → Combines strings
- `col("department")` → "Engineering" or "Sales"
- `lit("_")` → Literal underscore "_"
- `col("salt")` → The random number (0-3)
**Result:**
id | department | salt | dept_salted
---|--------------|------|------------------
1 | Engineering | 2 | Engineering_2
2 | Engineering | 0 | Engineering_0
3 | Engineering | 3 | Engineering_3
80 | Engineering | 1 | Engineering_1
81 | Sales | 2 | Sales_2
**How It Fixes the Skew**

**Before salting:**
```
groupBy("department")
Partition 0: Engineering [80] ████████████████████
Partition 1: Sales [20] █████
```

**After salting:**
```
groupBy("dept_salted")
Partition 0: Engineering_0 [~20] █████
Partition 1: Engineering_1 [~20] █████
Partition 2: Engineering_2 [~20] █████
Partition 3: Engineering_3 [~20] █████
Partition 4: Sales_0 [~5] ██
Partition 5: Sales_1 [~5] ██
Partition 6: Sales_2 [~5] ██
Partition 7: Sales_3 [~5] ██
✅ Work distributed across multiple partitions!
✅ All executors busy!
✅ Faster processing!
```
val saltedAgg = salted
.groupBy("dept_salted").agg(sum("salary").as("total"))
.withColumn("department", split(col("dept_salted"), "_")(0))
.groupBy("department").agg(sum("total").as("final_total"))
Why Two GroupBys Are Necessary with Salting
The Problem Salting Solves:
WITHOUT salting:
groupBy("department") → One partition processes 80 records (slow!)
→ Other partition processes 20 records (fast, sits idle)
WITH salting (Stage 1):
groupBy("dept_salted") → Engineering split into 4 groups (~20 each)
→ All partitions process in parallel (fast!)

Why You Need Stage 2:
After Stage 1, you have:
dept_salted | total
------------------|----------
Engineering_0 | 1,400,000
Engineering_1 | 1,400,000
Engineering_2 | 1,400,000
Engineering_3 | 1,400,000
Sales_0 | 350,000
You need to combine these back into "Engineering" - that requires a second groupBy!
You CANNOT avoid the second groupBy with salting - here's why:
The Math (illustrative timings for a 1,000-row dataset with the same 80/20 skew):
NO SALTING (1 shuffle):
800 records on 1 partition = 60 seconds
200 records on 1 partition = 15 seconds
Total: 60 seconds (bottlenecked!)
WITH SALTING (2 shuffles):
Stage 1: 200 records × 4 partitions in parallel = 15 seconds
Stage 2: 8 aggregated rows to shuffle = 0.1 seconds
Total: 15.1 seconds (4x faster!)

Key Insight:
The second shuffle is virtually FREE because:
- Stage 1 reduces 1000 rows → 8 aggregated rows
- Stage 2 only shuffles those 8 rows (negligible cost!)
- You gain massive parallelism in Stage 1
When You CAN Avoid the Second GroupBy:
Only if you don't need to combine results:
// If you want stats per SALTED key (don't need combined total)
val result = salted
.groupBy("dept_salted") // Only 1 groupBy!
.agg(max("salary"), avg("salary"))
// Result: Engineering_0, Engineering_1, etc. (not combined)

Verdict:
Two small balanced shuffles > One huge skewed shuffle
The second groupBy is the FEATURE that makes salting work, not a performance problem! The detailed performance comparison above is there to prove exactly this point.
println("\nSalted aggregation (distributes skewed data):")
saltedAgg.show()
println("✓ Saved: Parallel processing of skewed keys")
// ============================================
// METHOD 10: AVOID WIDE TRANSFORMATIONS
// ============================================
println("\n=== METHOD 10: Use Narrow Transformations ===")
// Narrow transformations (no shuffle): map, filter, flatMap, mapPartitions
val narrow = employees
.filter(col("salary") > 70000) // Narrow
.withColumn("bonus", col("salary") * 0.1) // Narrow
.select("name", "department", "bonus") // Narrow
println("\nNarrow transformations (no shuffle needed):")
narrow.explain()
narrow.show()
println("✓ All operations processed in parallel without shuffle")
spark.stop()
}
}
// Example 3: Window Functions - BAD vs GOOD
// ==========================================
// NOTE: these helpers assume a SparkSession with spark.implicits._ in scope and
// import org.apache.spark.sql.expressions.Window in addition to org.apache.spark.sql.functions._
def windowBadApproach(): Unit = {
  println("\n=== BAD: Multiple window operations ===")
  val sales = Seq(
    ("Electronics", "Laptop", 1200, "2024-01-01"),
    ("Electronics", "Phone", 800, "2024-01-02"),
    ("Electronics", "Tablet", 500, "2024-01-03"),
    ("Books", "Novel", 20, "2024-01-01"),
    ("Books", "Textbook", 100, "2024-01-02"),
    ("Books", "Magazine", 5, "2024-01-03")
  ).toDF("category", "product", "revenue", "date")
  // BAD: Each window operation causes a separate shuffle
  val window1 = Window.partitionBy("category").orderBy($"revenue".desc)
  val window2 = Window.partitionBy("category")
  val result = sales
    .withColumn("rank", row_number().over(window1)) // SHUFFLE 1
    .withColumn("total", sum("revenue").over(window2)) // SHUFFLE 2
    .withColumn("avg", avg("revenue").over(window2)) // Uses cached shuffle
  println("Result:")
  result.show()
}

def windowGoodApproach(): Unit = {
  println("\n=== GOOD: Combined window operations ===")
  val sales = Seq(
    ("Electronics", "Laptop", 1200, "2024-01-01"),
    ("Electronics", "Phone", 800, "2024-01-02"),
    ("Electronics", "Tablet", 500, "2024-01-03"),
    ("Books", "Novel", 20, "2024-01-01"),
    ("Books", "Textbook", 100, "2024-01-02"),
    ("Books", "Magazine", 5, "2024-01-03")
  ).toDF("category", "product", "revenue", "date")
  // GOOD: Use the same window spec for multiple operations
  val window = Window.partitionBy("category").orderBy($"revenue".desc)
  val result = sales
    .withColumn("rank", row_number().over(window))
    .withColumn("total", sum("revenue").over(window)) // NOTE: with orderBy in the spec, sum/avg are running values
    .withColumn("avg", avg("revenue").over(window))
  println("Result:")
  result.show()
  println("Single shuffle for all window operations!")
}
/*
=== OPTIMIZATION STRATEGIES SUMMARY ===
WITHOUT REPARTITION, YOU CAN:
1. BROADCAST JOIN
- For tables < 10MB
- Eliminates shuffle on small table
- Use: broadcast(df)
2. BUCKETING
- Pre-partition at write time
- Zero shuffle for future joins/groupBy
- One-time cost, repeated benefit
3. PARTITION PRUNING
- Read only needed data
- Use partitionBy() when writing
- Filter on partition column when reading
4. COMBINE OPERATIONS
- Single groupBy with multiple aggs
- Avoids multiple shuffles
- Much more efficient
5. FILTER EARLY
- Reduce data before shuffle
- Predicate pushdown
- Less data = less shuffle cost
6. CACHE STRATEGICALLY
- Cache after shuffle if reused
- Avoids recomputation
- Use for iterative algorithms
7. TUNE CONFIGS
- spark.sql.shuffle.partitions
- spark.sql.autoBroadcastJoinThreshold
- Adjust for data size
8. MAPPARTITIONS
- Process within partitions
- No data movement
- Great for heavy init operations
9. SALTING
- Handle skewed data
- Distribute hot keys
- Two-stage aggregation
10. NARROW TRANSFORMATIONS
- map, filter, select
- No shuffle required
- Process in parallel
GOLDEN RULE:
"The best shuffle is the one you avoid!"
PRIORITY ORDER:
1st: Avoid shuffle (broadcast, filter early, narrow ops)
2nd: Reduce shuffle cost (bucketing, combine ops)
3rd: Optimize shuffle (proper partitions, caching)
4th: Only then consider repartition
*/
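The summary lists mapPartitions as method 8, which the walkthrough above skips. Here is a minimal sketch of the idea, assuming the same employees DataFrame and spark.implicits._ are in scope; the connection string is a stand-in for any expensive per-partition resource (database client, parser, model, ...):

// Narrow transformation: runs inside each partition, no shuffle, and the expensive
// setup happens once per partition instead of once per row.
val tagged = employees.mapPartitions { rows =>
  val connection = s"conn-${java.util.UUID.randomUUID()}" // hypothetical heavy resource
  rows.map(r => (r.getInt(0), r.getString(1), connection))
}.toDF("id", "name", "handled_by")
tagged.show(false)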