
The Case Statement and the when Function in Spark

Tips and Traps

  1. Watch out for NaNs! Comparisons involving NaN might not behave the way you expect (see the examples below).

  2. None can be used as the value of otherwise (and of when) and yields null in the resulting DataFrame.

Column aliases and positional columns can be used in GROUP BY in Spark SQL!

Notice that the function when behaves like if-else.

In [1]:
import pandas as pd
import findspark

findspark.init("/opt/spark-3.0.1-bin-hadoop3.2/")

from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import *
from pyspark.sql.types import StructType

spark = SparkSession.builder.appName("Case/When").enableHiveSupport().getOrCreate()
In [2]:
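# pandas has no missing-value marker for float columns: the None below is stored as NaN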
df_p = pd.DataFrame({"age": [None, 30, 19], "name": ["Michael", "Andy", "Justin"]})
In [3]:
df = spark.createDataFrame(df_p)
df.show()
+----+-------+
| age|   name|
+----+-------+
| NaN|Michael|
|30.0|   Andy|
|19.0| Justin|
+----+-------+

The age column is DoubleType, which is why the missing value shows up as NaN rather than null:

In [15]:
df.schema
Out[15]:
StructType(List(StructField(age,DoubleType,true),StructField(name,StringType,true)))
Filtering on isNull returns no rows:

In [12]:
df.filter(col("age").isNull()).show()
+---+----+
|age|name|
+---+----+
+---+----+
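
The None from pandas was converted to NaN, not to a SQL null, so isNull matches nothing. A minimal sketch (names here are illustrative) of turning NaN back into a genuine null with isnan:

# Replace NaN with a real null so that null-based operations behave as expected.
df_nulls = df.withColumn(
    "age", when(isnan(col("age")), None).otherwise(col("age"))
)
df_nulls.filter(col("age").isNull()).show()  # now matches the Michael row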

In [18]:
df.createOrReplaceTempView("df")
A column alias cannot be referenced in a window definition in the same SELECT, so there are two workarounds: repeat the expression inside the window definition, or materialize the alias in a subquery. First, repeating the CASE expression in PARTITION BY:

In [31]:
spark.sql(
    """
    select 
        case 
            when age > 20 then 1
            else 0
        end as age_group,
        row_number() over (partition by case when age > 20 then 1 else 0 end order by age) as nima
    from 
        df
    """
).show()
+---------+----+
|age_group|nima|
+---------+----+
|        1|   1|
|        1|   2|
|        0|   1|
+---------+----+

Second, materializing the alias in a subquery before using it in a window definition:

In [34]:
spark.sql(
    """
    select
        age_group,
        percentile(age, 0.5) over (partition by age_group) as nima
    from (
        select 
            *,
            case 
                when age > 20 then 1
                else 0
            end as age_group
        from 
            df
        ) A
    """
).show()
+---------+----+
|age_group|nima|
+---------+----+
|        1| NaN|
|        1| NaN|
|        0|19.0|
+---------+----+
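
This output shows the NaN trap from tip 1: NaN > 20 evaluates to true, so Michael's NaN lands in group 1 and then poisons the percentile over that group. A hedged sketch of a guard, using nanvl to map NaN to null first (aggregate functions skip nulls):

# nanvl(a, b) returns b when a is NaN; a null literal cast to double turns NaN into null.
df_clean = df.withColumn("age", nanvl(col("age"), lit(None).cast("double")))
df_clean.show()  # the Michael row now has a null age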

Referencing the alias directly in the window definition of the same SELECT, however, fails:

In [32]:
spark.sql(
    """
    select 
        case 
            when age > 20 then 1
            else 0
        end as age_group,
        row_number() over (partition by age_group order by age) as nima
    from 
        df
    """
).show()
---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
... (stack trace truncated)

Py4JJavaError: An error occurred while calling o25.sql.
: org.apache.spark.sql.AnalysisException: cannot resolve '`age_group`' given input columns: [df.age, df.name]; line 7 pos 40;
'Project [CASE WHEN (age#4 > cast(20 as double)) THEN 1 ELSE 0 END AS age_group#131, row_number() windowspecdefinition('age_group, age#4 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS nima#132]
+- SubqueryAlias `df`
   +- LogicalRDD [age#4, name#5], false
	... (Java stack trace truncated)

During handling of the above exception, another exception occurred:

AnalysisException                         Traceback (most recent call last)
... (stack trace truncated)
AnalysisException: "cannot resolve '`age_group`' given input columns: [df.age, df.name]; line 7 pos 40;\n'Project [CASE WHEN (age#4 > cast(20 as double)) THEN 1 ELSE 0 END AS age_group#131, row_number() windowspecdefinition('age_group, age#4 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS nima#132]\n+- SubqueryAlias `df`\n   +- LogicalRDD [age#4, name#5], false\n"
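
The failure is expected: Spark resolves the window definition against the input columns of the SELECT, before aliases defined in that same SELECT exist, so age_group cannot be resolved. Either repeat the expression (as in the first query) or materialize it in a subquery (as in the second). For comparison, a minimal DataFrame-API sketch that materializes the column first:

from pyspark.sql import Window

# Once age_group exists as a real column, the window can partition by it.
df2 = df.withColumn("age_group", when(col("age") > 20, 1).otherwise(0))
win = Window.partitionBy("age_group").orderBy("age")
df2.select("age_group", row_number().over(win).alias("rn")).show()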
In [24]:
spark.sql(
    """
    select 
        case 
            when age > 20 then 1
            else 0
        end as age_group,
        count(*) as n
    from 
        df
    group by
        age_group
    """
).show()
+---------+---+
|age_group|  n|
+---------+---+
|        1|  2|
|        0|  1|
+---------+---+

In [21]:
spark.sql(
    """
    select 
        case 
            when age > 20 then 1
            else 0
        end as age_group,
        count(*) as n
    from 
        df
    group by
        1
    """
).show()
+---------+---+
|age_group|  n|
+---------+---+
|        1|  2|
|        0|  1|
+---------+---+
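
Both queries above rely on Spark SQL resolving the alias or the position in GROUP BY. For comparison, a minimal DataFrame-API sketch of the same aggregation:

# The column is materialized before the groupBy, so no alias tricks are needed.
df.withColumn("age_group", when(col("age") > 20, 1).otherwise(0)) \
    .groupBy("age_group").count().show()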

Now the NaN trap. Watch how the NaN row behaves under the three comparisons below:

In [4]:
df.withColumn("null_gt", when(col("age") >= 0, 1).otherwise(None)).show()
+----+-------+-------+
| age|   name|null_gt|
+----+-------+-------+
| NaN|Michael|      1|
|30.0|   Andy|      1|
|19.0| Justin|      1|
+----+-------+-------+

In [7]:
df.withColumn("null_gt", when(col("age") < 20, 1).otherwise(None)).show()
+----+-------+-------+
| age|   name|null_gt|
+----+-------+-------+
| NaN|Michael|   null|
|30.0|   Andy|   null|
|19.0| Justin|      1|
+----+-------+-------+

In [6]:
df.withColumn("null_gt", when(col("age") >= 20, 1).otherwise(None)).show()
+----+-------+-------+
| age|   name|null_gt|
+----+-------+-------+
| NaN|Michael|      1|
|30.0|   Andy|      1|
|19.0| Justin|   null|
+----+-------+-------+
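
All three results follow from Spark's NaN semantics: in comparisons, NaN is treated as larger than any other numeric value, so NaN >= 0 and NaN >= 20 are true while NaN < 20 is false. A quick check in SQL:

# NaN compares as greater than any numeric value in Spark SQL.
spark.sql(
    "select cast('NaN' as double) >= 0 as nan_ge_0, "
    "cast('NaN' as double) < 20 as nan_lt_20"
).show()  # nan_ge_0 is true, nan_lt_20 is false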

The remaining cells are Scala, run against a DataFrame in which age is a nullable integer, so the first row holds a genuine null (not NaN). Since null <= 1000 evaluates to null, which when treats as not matched, the null row falls through to otherwise:

In [14]:
df.withColumn("null_lt",
    when($"age" <= 1000, 1).otherwise(null)
).show
+----+-------+-------+
| age|   name|null_lt|
+----+-------+-------+
|null|Michael|   null|
|  30|   Andy|      1|
|  19| Justin|      1|
+----+-------+-------+

Multiple when clauses can be chained; the first matching condition wins:

In [4]:
df.select(when($"age".isNull, 0).when($"age" > 20, 100).otherwise(10).alias("age")).show
+---+
|age|
+---+
|  0|
|100|
| 10|
+---+

Without an otherwise, rows that match no condition get null:

In [5]:
df.select(when($"age".isNull, 0).alias("age")).show
+----+
| age|
+----+
|   0|
|null|
|null|
+----+

In [6]:
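// Rebind df: a single-column DataFrame of the integers 0 through 9.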
val df = Range(0, 10).toDF
df.show
+-----+
|value|
+-----+
|    0|
|    1|
|    2|
|    3|
|    4|
|    5|
|    6|
|    7|
|    8|
|    9|
+-----+


Notice that chained when calls behave like if / else if: conditions are evaluated in order and the first matching branch wins.

In [7]:
df.withColumn("group",
    when($"value" <= 3, 0)
    .when($"value" <= 100, 1)
).show
+-----+-----+
|value|group|
+-----+-----+
|    0|    0|
|    1|    0|
|    2|    0|
|    3|    0|
|    4|    1|
|    5|    1|
|    6|    1|
|    7|    1|
|    8|    1|
|    9|    1|
+-----+-----+
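
Order matters because the first matching condition wins. A minimal PySpark sketch of the same chain with the branches swapped, where the broader condition shadows the narrower one:

# With value <= 100 checked first, every row in range(10) gets group 1.
spark.range(10).withColumn(
    "group", when(col("id") <= 100, 1).when(col("id") <= 3, 0)
).show()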
