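A complete Scala application that mitigates join skew by salting the skewed key on the left side and replicating the right side once per salt value: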
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions._

object SparkSkewnessExample extends App {

  val conf = new SparkConf()
    .setMaster("local[*]")
    .setAppName("SparkSkewnessExample")

  val spark = SparkSession
    .builder()
    .config(conf)
    .getOrCreate()

  import spark.implicits._
  // Left side: the key "a" appears far more often than the others, so a
  // plain join on "id" would pile most of the work into one partition.
  val df1 = Seq(
    ("a", "12"),
    ("a", "31"),
    ("a", "24"),
    ("a", "0"),
    ("a", "24"),
    ("b", "45"),
    ("c", "24")
  ).toDF("id", "value")
  df1.show(10, false)

  // Right side: one row per key.
  val df2 = Seq(
    ("a", "45"),
    ("b", "575"),
    ("c", "54")
  ).toDF("id", "value")
  df2.show(10, false)
  // Salting: append a random suffix in [0, 9] to the skewed left key, and
  // replicate every right-side row once per possible salt value so each
  // salted left key still finds its match.
  def eliminateSkewness(leftDf: DataFrame, leftCol: String, rightDf: DataFrame): (DataFrame, DataFrame) = {
    val df1 = leftDf
      .withColumn(leftCol, concat(
        leftDf.col(leftCol),
        lit("_"),
        floor(rand(123456) * 10)))

    // The salt range (0 to 9) must match the range produced on the left.
    val df2 = rightDf
      .withColumn("saltCol", explode(
        array((0 to 9).map(lit(_)): _*)
      ))

    (df1, df2)
  }
  val (df3, df4) = eliminateSkewness(df1, "id", df2)

  df3.show(100, false)
  df4.show(100, false)

  // Join on the salted key: the hot key "a" is now spread across up to
  // ten distinct join keys ("a_0" .. "a_9") instead of one.
  df3.join(
    df4,
    df3.col("id") <=> concat(df4.col("id"), lit("_"), df4.col("saltCol"))
  ).drop("saltCol")
    .show(100, false)
}
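For comparison, on Spark 3.x the same skewed join can often be handled without manual salting by enabling adaptive query execution, which detects and splits oversized partitions at runtime. A minimal sketch, assuming Spark 3.0+ where the spark.sql.adaptive.* configuration keys exist (not part of the listing above):

  // Sketch: let adaptive query execution (AQE) rebalance the skew instead
  // of salting by hand.
  spark.conf.set("spark.sql.adaptive.enabled", "true")
  spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")

  // Plain join on the unsalted key; AQE splits the skewed partition.
  df1.join(df2, Seq("id")).show(100, false)

The trade-off: AQE needs no changes to the join logic, but manual salting still applies on Spark 2.x and in cases where AQE's skew thresholds are not triggered.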