Skip to article frontmatterSkip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

The Case Statement and the when Function in Spark

Things on this page are fragmentary and immature notes/thoughts of the author. Please read with your own judgement!

Tips and Traps

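  1. Watch out for NaNs: comparisons involving NaN might not behave the way you expect. Spark treats NaN as larger than any other numeric value, so in the examples below the NaN age satisfies `age > 20` and `age >= 0` but not `age < 20`. Also, a missing value in a pandas DataFrame becomes NaN (not null) after `spark.createDataFrame`, so `isNull()` does not match it.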

  2. `None` (PySpark) or `null` (Scala) can be passed to `otherwise`, which yields null in the resulting DataFrame column.

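Column aliases and positional column references (e.g., `group by 1`) can be used in GROUP BY in Spark SQL. However, a column alias defined in a SELECT cannot be referenced in a window's PARTITION BY clause of the same SELECT; repeat the expression or compute it in a subquery instead (see the examples below).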

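Notice that the function `when` behaves like an if-else chain: conditions are checked in order and the first one that matches determines the value.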

import pandas as pd
import findspark

findspark.init("/opt/spark-3.0.1-bin-hadoop3.2/")

from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import *
from pyspark.sql.types import StructType

spark = SparkSession.builder.appName("Case/When").enableHiveSupport().getOrCreate()
df_p = pd.DataFrame({"age": [None, 30, 19], "name": ["Michael", "Andy", "Justin"]})
df = spark.createDataFrame(df_p)
df.show()
+----+-------+
| age|   name|
+----+-------+
| NaN|Michael|
|30.0|   Andy|
|19.0| Justin|
+----+-------+

df.schema
StructType(List(StructField(age,DoubleType,true),StructField(name,StringType,true)))
df.filter(col("age").isNull()).show()
+---+----+
|age|name|
+---+----+
+---+----+

df.createOrReplaceTempView("df")
spark.sql(
    """
    select 
        case 
            when age > 20 then 1
            else 0
        end as age_group,
        row_number() over (partition by case when age > 20 then 1 else 0 end order by age) as nima
    from 
        df
    """
).show()
+---------+----+
|age_group|nima|
+---------+----+
|        1|   1|
|        1|   2|
|        0|   1|
+---------+----+

spark.sql(
    """
    select
        age_group,
        percentile(age, 0.5) over (partition by age_group) as nima
    from (
        select 
            *,
            case 
                when age > 20 then 1
                else 0
            end as age_group
        from 
            df
        ) A
    """
).show()
+---------+----+
|age_group|nima|
+---------+----+
|        1| NaN|
|        1| NaN|
|        0|19.0|
+---------+----+

spark.sql(
    """
    select 
        case 
            when age > 20 then 1
            else 0
        end as age_group,
        row_number() over (partition by age_group order by age) as nima
    from 
        df
    """
).show()
---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
/opt/spark/python/pyspark/sql/utils.py in deco(*a, **kw)
     62         try:
---> 63             return f(*a, **kw)
     64         except py4j.protocol.Py4JJavaError as e:

/opt/spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    327                     "An error occurred while calling {0}{1}{2}.\n".
--> 328                     format(target_id, ".", name), value)
    329             else:

Py4JJavaError: An error occurred while calling o25.sql.
: org.apache.spark.sql.AnalysisException: cannot resolve '`age_group`' given input columns: [df.age, df.name]; line 7 pos 40;
'Project [CASE WHEN (age#4 > cast(20 as double)) THEN 1 ELSE 0 END AS age_group#131, row_number() windowspecdefinition('age_group, age#4 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS nima#132]
+- SubqueryAlias `df`
   +- LogicalRDD [age#4, name#5], false

	at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42)
	at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$3.applyOrElse(CheckAnalysis.scala:111)
	at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$3.applyOrElse(CheckAnalysis.scala:108)
	at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:280)
	at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:280)
	at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:69)
	at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:279)
	at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:277)
	at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:277)
	at org.apache.spark.sql.catalyst.trees.TreeNode.org$apache$spark$sql$catalyst$trees$TreeNode$$mapChild$2(TreeNode.scala:297)
	at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4$$anonfun$apply$13.apply(TreeNode.scala:356)
	at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
	at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
	at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
	at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
	at scala.collection.AbstractTraversable.map(Traversable.scala:104)
	at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:356)
	at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:186)
	at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:326)
	at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:277)
	at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:277)
	at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:277)
	at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:328)
	at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:186)
	at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:326)
	at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:277)
	at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:277)
	at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:277)
	at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:328)
	at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:186)
	at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:326)
	at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:277)
	at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$transformExpressionsUp$1.apply(QueryPlan.scala:93)
	at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$transformExpressionsUp$1.apply(QueryPlan.scala:93)
	at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$1.apply(QueryPlan.scala:105)
	at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$1.apply(QueryPlan.scala:105)
	at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:69)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.transformExpression$1(QueryPlan.scala:104)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.org$apache$spark$sql$catalyst$plans$QueryPlan$$recursiveTransform$1(QueryPlan.scala:116)
	at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$org$apache$spark$sql$catalyst$plans$QueryPlan$$recursiveTransform$1$2.apply(QueryPlan.scala:121)
	at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
	at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
	at scala.collection.immutable.List.foreach(List.scala:392)
	at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
	at scala.collection.immutable.List.map(List.scala:296)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.org$apache$spark$sql$catalyst$plans$QueryPlan$$recursiveTransform$1(QueryPlan.scala:121)
	at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$2.apply(QueryPlan.scala:126)
	at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:186)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.mapExpressions(QueryPlan.scala:126)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.transformExpressionsUp(QueryPlan.scala:93)
	at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1.apply(CheckAnalysis.scala:108)
	at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1.apply(CheckAnalysis.scala:86)
	at org.apache.spark.sql.catalyst.trees.TreeNode.foreachUp(TreeNode.scala:126)
	at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$class.checkAnalysis(CheckAnalysis.scala:86)
	at org.apache.spark.sql.catalyst.analysis.Analyzer.checkAnalysis(Analyzer.scala:95)
	at org.apache.spark.sql.catalyst.analysis.Analyzer$$anonfun$executeAndCheck$1.apply(Analyzer.scala:108)
	at org.apache.spark.sql.catalyst.analysis.Analyzer$$anonfun$executeAndCheck$1.apply(Analyzer.scala:105)
	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper$.markInAnalyzer(AnalysisHelper.scala:201)
	at org.apache.spark.sql.catalyst.analysis.Analyzer.executeAndCheck(Analyzer.scala:105)
	at org.apache.spark.sql.execution.QueryExecution.analyzed$lzycompute(QueryExecution.scala:58)
	at org.apache.spark.sql.execution.QueryExecution.analyzed(QueryExecution.scala:56)
	at org.apache.spark.sql.execution.QueryExecution.assertAnalyzed(QueryExecution.scala:48)
	at org.apache.spark.sql.Dataset$.ofRows(Dataset.scala:78)
	at org.apache.spark.sql.SparkSession.sql(SparkSession.scala:642)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:748)


During handling of the above exception, another exception occurred:

AnalysisException                         Traceback (most recent call last)
<ipython-input-32-d47e32b0bffd> in <module>
      8     from
      9         df
---> 10     """).show()

/opt/spark/python/pyspark/sql/session.py in sql(self, sqlQuery)
    765         [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
    766         """
--> 767         return DataFrame(self._jsparkSession.sql(sqlQuery), self._wrapped)
    768 
    769     @since(2.0)

/opt/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py in __call__(self, *args)
   1255         answer = self.gateway_client.send_command(command)
   1256         return_value = get_return_value(
-> 1257             answer, self.gateway_client, self.target_id, self.name)
   1258 
   1259         for temp_arg in temp_args:

/opt/spark/python/pyspark/sql/utils.py in deco(*a, **kw)
     67                                              e.java_exception.getStackTrace()))
     68             if s.startswith('org.apache.spark.sql.AnalysisException: '):
---> 69                 raise AnalysisException(s.split(': ', 1)[1], stackTrace)
     70             if s.startswith('org.apache.spark.sql.catalyst.analysis'):
     71                 raise AnalysisException(s.split(': ', 1)[1], stackTrace)

AnalysisException: "cannot resolve '`age_group`' given input columns: [df.age, df.name]; line 7 pos 40;\n'Project [CASE WHEN (age#4 > cast(20 as double)) THEN 1 ELSE 0 END AS age_group#131, row_number() windowspecdefinition('age_group, age#4 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS nima#132]\n+- SubqueryAlias `df`\n   +- LogicalRDD [age#4, name#5], false\n"
spark.sql(
    """
    select 
        case 
            when age > 20 then 1
            else 0
        end as age_group,
        count(*) as n
    from 
        df
    group by
        age_group
    """
).show()
+---------+---+
|age_group|  n|
+---------+---+
|        1|  2|
|        0|  1|
+---------+---+

spark.sql(
    """
    select 
        case 
            when age > 20 then 1
            else 0
        end as age_group,
        count(*) as n
    from 
        df
    group by
        1
    """
).show()
+---------+---+
|age_group|  n|
+---------+---+
|        1|  2|
|        0|  1|
+---------+---+

df.withColumn("null_gt", when(col("age") >= 0, 1).otherwise(None)).show()
+----+-------+-------+
| age|   name|null_gt|
+----+-------+-------+
| NaN|Michael|      1|
|30.0|   Andy|      1|
|19.0| Justin|      1|
+----+-------+-------+

df.withColumn("null_gt", when(col("age") < 20, 1).otherwise(None)).show()
+----+-------+-------+
| age|   name|null_gt|
+----+-------+-------+
| NaN|Michael|   null|
|30.0|   Andy|   null|
|19.0| Justin|      1|
+----+-------+-------+

df.withColumn("null_gt", when(col("age") >= 20, 1).otherwise(None)).show()
+----+-------+-------+
| age|   name|null_gt|
+----+-------+-------+
| NaN|Michael|      1|
|30.0|   Andy|      1|
|19.0| Justin|   null|
+----+-------+-------+

df.withColumn("null_lt",
    when($"age" <= 1000, 1).otherwise(null)
).show
+----+-------+-------+
| age|   name|null_lt|
+----+-------+-------+
|null|Michael|   null|
|  30|   Andy|      1|
|  19| Justin|      1|
+----+-------+-------+

df.select(when($"age".isNull, 0).when($"age" > 20 , 100).otherwise(10).alias("age")).show
+---+
|age|
+---+
|  0|
|100|
| 10|
+---+

df.select(when($"age".isNull, 0).alias("age")).show
+----+
| age|
+----+
|   0|
|null|
|null|
+----+

val df = Range(0, 10).toDF
df.show
+-----+
|value|
+-----+
|    0|
|    1|
|    2|
|    3|
|    4|
|    5|
|    6|
|    7|
|    8|
|    9|
+-----+

null

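Notice that the function `when` behaves like an if-else chain: conditions are checked in order and the first one that matches determines the value.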

df.withColumn("group",
    when($"value" <= 3, 0)
    .when($"value" <= 100, 1)
).show
+-----+-----+
|value|group|
+-----+-----+
|    0|    0|
|    1|    0|
|    2|    0|
|    3|    0|
|    4|    1|
|    5|    1|
|    6|    1|
|    7|    1|
|    8|    1|
|    9|    1|
+-----+-----+