1. Check that the two schemas are equal
2. Check that the number of rows is equal
3. Check that there are no unequal rows
/**
 * ScalaTest mix-in providing helpers for comparing Spark DataFrames in tests.
 */
trait DataFrameSuitBase extends FlatSpec with Matchers {

  /**
   * Asserts that two DataFrames are equal.
   *
   * Three checks are performed in order:
   *  1. the string representations of the two schemas are equal;
   *  2. both DataFrames contain the same number of rows;
   *  3. rows at the same position are equal to each other.
   *
   * @param expected the DataFrame holding the expected content
   * @param result   the DataFrame produced by the code under test
   */
  def equalDataFrames(expected: DataFrame, result: DataFrame): Unit = {
    // Check the equality of the two schemas (compared via their string form).
    expected.schema.toString shouldBe result.schema.toString

    val expectedRDD = zipWithIndex(expected.rdd)
    val resultRDD = zipWithIndex(result.rdd)

    // Check the number of rows in the two DataFrames.
    expectedRDD.count() shouldBe resultRDD.count()

    // cogroup pairs up the rows that share the same index. Comparing the
    // head options flags both differing rows and an index that is present on
    // one side only, so a missing row can never be reported as "equal".
    val unequal = expectedRDD
      .cogroup(resultRDD)
      .filter { case (_, (expectedRows, resultRows)) =>
        expectedRows.headOption != resultRows.headOption
      }
      .collect()

    // The number of unequal rows should be zero.
    unequal shouldBe empty
  }

  /**
   * Keys every element of an RDD by its position.
   *
   * Delegates to Spark's built-in `RDD.zipWithIndex()`, whose indices are
   * `Long`, so the position cannot overflow the way an `Int` counter would.
   *
   * @param input the RDD to index
   * @return an RDD of `(index, element)` pairs
   */
  private def zipWithIndex[T](input: RDD[T]): RDD[(Long, T)] =
    input.zipWithIndex().map { case (element, index) => (index, element) }

  /**
   * Sets the nullable flag to either true or false for all top-level fields
   * in the schema of a DataFrame.
   *
   * @param df       the DataFrame whose schema is rewritten
   * @param nullable the nullable flag to apply to every field
   * @return a DataFrame over the same rows with the adjusted schema
   */
  def setNullableFields(df: DataFrame, nullable: Boolean): DataFrame = {
    val newSchema = StructType(df.schema.map(_.copy(nullable = nullable)))
    df.sqlContext.createDataFrame(df.rdd, newSchema)
  }
}
Reference:
No comments:
Post a Comment