Spark SQL Recursive DataFrame – PySpark and Scala

Identifying the top-level hierarchy of one column from another column is an important feature that many relational databases such as Teradata, Oracle, and Snowflake support. Relational databases use recursive queries to identify hierarchies in data, such as an organizational structure, employee-manager relationships, a bill of materials, or a document hierarchy. Databases such as Teradata and Snowflake support recursive queries in the form of a recursive WITH clause or recursive views. Spark SQL, however, supports neither recursive CTEs nor recursive views. In this article, we will check Spark SQL recursive DataFrames using PySpark and Scala.

Spark SQL Recursive DataFrame

The latest Spark releases include the GraphX component, which lets you identify hierarchies in data. GraphX is Spark's API for graphs and graph-parallel computation. Whether to use GraphX or the DataFrame-based approach comes down to project requirements. Many developers prefer the graph approach because graph algorithms are iterative in nature: the properties of a vertex depend on the properties of its directly or indirectly connected vertices, and this is typically faster than the join-based database approach.
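For instance, below is a minimal GraphX sketch (not from the original article) that computes each employee's depth below the root using Pregel; the vertex and edge sets are a hypothetical slice of the employee data used later in this post.

import org.apache.spark.graphx.{Edge, Graph}

// Hypothetical slice of the employee data: vertices are employee numbers,
// edges point from manager to direct report
val vertexRDD = spark.sparkContext.parallelize(Seq(
  (801L, "root"), (1003L, "emp"), (1004L, "emp"), (1012L, "emp")))
val edgeRDD = spark.sparkContext.parallelize(Seq(
  Edge(801L, 1003L, 1), Edge(1003L, 1004L, 1), Edge(1004L, 1012L, 1)))

// Pregel-style BFS: the root starts at depth 0, everyone else at "unknown",
// and each superstep pushes depth + 1 down the manager -> employee edges
val depths = Graph(vertexRDD, edgeRDD)
  .mapVertices((id, _) => if (id == 801L) 0 else Int.MinValue)
  .pregel(Int.MinValue)(
    (_, depth, msg) => math.max(depth, msg),            // vertex program
    t =>
      if (t.srcAttr >= 0 && t.dstAttr < t.srcAttr + 1)  // send messages
        Iterator((t.dstId, t.srcAttr + 1))
      else Iterator.empty,
    (a, b) => math.max(a, b))                           // merge messages

depths.vertices.collect.sortBy(_._1).foreach { case (id, d) =>
  println(s"employee $id is at depth $d")
}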

Before jumping into the implementation, let us look at a recursive query in a relational database.

Relational Database Recursive Query Example

Consider the following Teradata recursive query example.

WITH RECURSIVE temp_table (employee_number) AS
(
  SELECT root.employee_number
  FROM employee root
  WHERE root.manager_employee_number = 801
  UNION ALL
  SELECT indirect.employee_number
  FROM temp_table direct, employee indirect
  WHERE direct.employee_number = indirect.manager_employee_number
)
SELECT * FROM temp_table ORDER BY employee_number;

Notice that the WITH clause uses the RECURSIVE keyword.

Spark SQL does not support this type of CTE. In most hierarchical data the depth is unknown, but you can still identify the top-level hierarchy of one column from another column by using a while loop and recursively joining DataFrames.

PySpark Recursive DataFrame to Identify Hierarchies of Data

The following PySpark code uses a while loop and a recursive join to identify the hierarchies of data. It is an alternative in PySpark to a Teradata or Oracle recursive query. Note that it is not an efficient solution, but it does the job.

from pyspark.sql.functions import col

# Employee DF
schema = 'EMPLOYEE_NUMBER int, MANAGER_EMPLOYEE_NUMBER int'
employees = spark.createDataFrame(
  [[801, None],
   [1016, 801],
   [1003, 801],
   [1019, 801],
   [1010, 1003],
   [1004, 1003],
   [1001, 1003],
   [1012, 1004],
   [1002, 1004],
   [1015, 1004],
   [1008, 1019],
   [1006, 1019],
   [1014, 1019],
   [1011, 1019]], schema=schema)

# initial DataFrame
empDF = employees \
  .withColumnRenamed('EMPLOYEE_NUMBER', 'level_0') \
  .withColumnRenamed('MANAGER_EMPLOYEE_NUMBER', 'level_1')

i = 1

# Loop through when you don't know the recursive depth
while True:
  this_level = 'level_{}'.format(i)
  next_level = 'level_{}'.format(i+1)
  emp_level = employees \
    .withColumnRenamed('EMPLOYEE_NUMBER', this_level) \
    .withColumnRenamed('MANAGER_EMPLOYEE_NUMBER', next_level)
  empDF = empDF.join(emp_level, on=this_level, how='left')
  
  # Stop when no row has a manager at the next level; otherwise continue
  if empDF.where(col(next_level).isNotNull()).count() == 0:
    break
  else:
    i += 1

# Sort by employee number and show the full management chain
empDF.sort('level_0').show()

Scala Recursive DataFrame to Identify Hierarchies of Data

The following code is the Scala equivalent of the above PySpark code.

import scala.util.control.Breaks._

// Employee DF
val employees = spark.createDataFrame(Seq(
  (801, None),
  (1016, Some(801)),
  (1003, Some(801)),
  (1019, Some(801)),
  (1010, Some(1003)),
  (1004, Some(1003)),
  (1001, Some(1003)),
  (1012, Some(1004)),
  (1002, Some(1004)),
  (1015, Some(1004)),
  (1008, Some(1019)),
  (1006, Some(1019)),
  (1014, Some(1019)),
  (1011, Some(1019))
)).toDF("EMPLOYEE_NUMBER", "MANAGER_EMPLOYEE_NUMBER")

// initial DataFrame
var empDF = employees
  .withColumnRenamed("EMPLOYEE_NUMBER", "level_0")
  .withColumnRenamed("MANAGER_EMPLOYEE_NUMBER", "level_1")

var i = 1

// Loop through when you don't know the recursive depth.
// Note: in Scala, break only works inside a breakable block
breakable {
  while (true) {
    val this_level = "level_" + i
    val next_level = "level_" + (i + 1)
    val level_i = employees
      .withColumnRenamed("EMPLOYEE_NUMBER", this_level)
      .withColumnRenamed("MANAGER_EMPLOYEE_NUMBER", next_level)

    empDF = empDF.join(level_i, Seq(this_level), "left")

    // Stop when no row has a manager at the next level; otherwise continue
    if (empDF.filter(next_level + " is not null").count() == 0) {
      break
    } else {
      i = i + 1
    }
  }
}

// Sort by employee number and show the full management chain
empDF.sort("level_0").show()
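
Since the goal stated at the top is identifying the top-level hierarchy of one column from another, a possible follow-up is to coalesce the level columns in reverse order, which picks out the last non-null level, i.e. each employee's root. This is a sketch assuming the level_0 through level_4 columns produced by the loop above; the column name top_level_manager is made up for illustration.

import org.apache.spark.sql.functions.{coalesce, col}

// The root of each chain is the last non-null level column; coalescing the
// level columns in reverse order picks it out. The lexicographic sort of
// column names is only safe while the depth stays below 10 levels.
val levelCols = empDF.columns.sorted.reverse.map(col(_))
val roots = empDF.select(
  col("level_0").as("employee"),
  coalesce(levelCols: _*).as("top_level_manager"))

roots.sort("employee").show()

Also note that each count() inside the loop re-executes the accumulated join lineage, so persisting empDF at every iteration with empDF.persist() can cut down that repeated work.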

Hope this helps 🙂