Things on this page are fragmentary and immature notes/thoughts of the author. Please read with your own judgement!
Tips and Traps
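Watch out for NaNs: they might not behave the way you expect. A NaN is not null (so `isNull` does not match it), and in comparisons Spark treats NaN as larger than any other numeric value.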
`None` can be passed to `otherwise`, and it yields null in the resulting DataFrame.
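Column aliases and positional columns can be used in `group by` in Spark SQL! However, a column alias cannot be referenced in the `partition by` of a window function within the same `select`; repeat the expression or use a subquery instead.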
Notice that the function `when` behaves like if-else.
import pandas as pd
import findspark
findspark.init("/opt/spark-3.0.1-bin-hadoop3.2/")
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import *
from pyspark.sql.types import StructType
spark = SparkSession.builder.appName("Case/When").enableHiveSupport().getOrCreate()df_p = pd.DataFrame({"age": [None, 30, 19], "name": ["Michael", "Andy", "Justin"]})df = spark.createDataFrame(df_p)
df.show()+----+-------+
| age| name|
+----+-------+
| NaN|Michael|
|30.0| Andy|
|19.0| Justin|
+----+-------+
df.schemaStructType(List(StructField(age,DoubleType,true),StructField(name,StringType,true)))df.filter(col("age").isNull()).show()+---+----+
|age|name|
+---+----+
+---+----+
df.createOrReplaceTempView("df")spark.sql(
"""
select
case
when age > 20 then 1
else 0
end as age_group,
row_number() over (partition by case when age > 20 then 1 else 0 end order by age) as nima
from
df
"""
).show()+---------+----+
|age_group|nima|
+---------+----+
| 1| 1|
| 1| 2|
| 0| 1|
+---------+----+
spark.sql(
"""
select
age_group,
percentile(age, 0.5) over (partition by age_group) as nima
from (
select
*,
case
when age > 20 then 1
else 0
end as age_group
from
df
) A
"""
).show()+---------+----+
|age_group|nima|
+---------+----+
| 1| NaN|
| 1| NaN|
| 0|19.0|
+---------+----+
spark.sql(
"""
select
case
when age > 20 then 1
else 0
end as age_group,
row_number() over (partition by age_group order by age) as nima
from
df
"""
).show()---------------------------------------------------------------------------
Py4JJavaError Traceback (most recent call last)
/opt/spark/python/pyspark/sql/utils.py in deco(*a, **kw)
62 try:
---> 63 return f(*a, **kw)
64 except py4j.protocol.Py4JJavaError as e:
/opt/spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
327 "An error occurred while calling {0}{1}{2}.\n".
--> 328 format(target_id, ".", name), value)
329 else:
Py4JJavaError: An error occurred while calling o25.sql.
: org.apache.spark.sql.AnalysisException: cannot resolve '`age_group`' given input columns: [df.age, df.name]; line 7 pos 40;
'Project [CASE WHEN (age#4 > cast(20 as double)) THEN 1 ELSE 0 END AS age_group#131, row_number() windowspecdefinition('age_group, age#4 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS nima#132]
+- SubqueryAlias `df`
+- LogicalRDD [age#4, name#5], false
at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$3.applyOrElse(CheckAnalysis.scala:111)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$3.applyOrElse(CheckAnalysis.scala:108)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:280)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:280)
at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:69)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:279)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:277)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:277)
at org.apache.spark.sql.catalyst.trees.TreeNode.org$apache$spark$sql$catalyst$trees$TreeNode$$mapChild$2(TreeNode.scala:297)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4$$anonfun$apply$13.apply(TreeNode.scala:356)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
at scala.collection.AbstractTraversable.map(Traversable.scala:104)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:356)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:186)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:326)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:277)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:277)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:277)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:328)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:186)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:326)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:277)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:277)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:277)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:328)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:186)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:326)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:277)
at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$transformExpressionsUp$1.apply(QueryPlan.scala:93)
at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$transformExpressionsUp$1.apply(QueryPlan.scala:93)
at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$1.apply(QueryPlan.scala:105)
at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$1.apply(QueryPlan.scala:105)
at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:69)
at org.apache.spark.sql.catalyst.plans.QueryPlan.transformExpression$1(QueryPlan.scala:104)
at org.apache.spark.sql.catalyst.plans.QueryPlan.org$apache$spark$sql$catalyst$plans$QueryPlan$$recursiveTransform$1(QueryPlan.scala:116)
at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$org$apache$spark$sql$catalyst$plans$QueryPlan$$recursiveTransform$1$2.apply(QueryPlan.scala:121)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
at scala.collection.immutable.List.foreach(List.scala:392)
at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
at scala.collection.immutable.List.map(List.scala:296)
at org.apache.spark.sql.catalyst.plans.QueryPlan.org$apache$spark$sql$catalyst$plans$QueryPlan$$recursiveTransform$1(QueryPlan.scala:121)
at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$2.apply(QueryPlan.scala:126)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:186)
at org.apache.spark.sql.catalyst.plans.QueryPlan.mapExpressions(QueryPlan.scala:126)
at org.apache.spark.sql.catalyst.plans.QueryPlan.transformExpressionsUp(QueryPlan.scala:93)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1.apply(CheckAnalysis.scala:108)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1.apply(CheckAnalysis.scala:86)
at org.apache.spark.sql.catalyst.trees.TreeNode.foreachUp(TreeNode.scala:126)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$class.checkAnalysis(CheckAnalysis.scala:86)
at org.apache.spark.sql.catalyst.analysis.Analyzer.checkAnalysis(Analyzer.scala:95)
at org.apache.spark.sql.catalyst.analysis.Analyzer$$anonfun$executeAndCheck$1.apply(Analyzer.scala:108)
at org.apache.spark.sql.catalyst.analysis.Analyzer$$anonfun$executeAndCheck$1.apply(Analyzer.scala:105)
at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper$.markInAnalyzer(AnalysisHelper.scala:201)
at org.apache.spark.sql.catalyst.analysis.Analyzer.executeAndCheck(Analyzer.scala:105)
at org.apache.spark.sql.execution.QueryExecution.analyzed$lzycompute(QueryExecution.scala:58)
at org.apache.spark.sql.execution.QueryExecution.analyzed(QueryExecution.scala:56)
at org.apache.spark.sql.execution.QueryExecution.assertAnalyzed(QueryExecution.scala:48)
at org.apache.spark.sql.Dataset$.ofRows(Dataset.scala:78)
at org.apache.spark.sql.SparkSession.sql(SparkSession.scala:642)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
at py4j.Gateway.invoke(Gateway.java:282)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:238)
at java.lang.Thread.run(Thread.java:748)
During handling of the above exception, another exception occurred:
AnalysisException Traceback (most recent call last)
<ipython-input-32-d47e32b0bffd> in <module>
8 from
9 df
---> 10 """).show()
/opt/spark/python/pyspark/sql/session.py in sql(self, sqlQuery)
765 [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
766 """
--> 767 return DataFrame(self._jsparkSession.sql(sqlQuery), self._wrapped)
768
769 @since(2.0)
/opt/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py in __call__(self, *args)
1255 answer = self.gateway_client.send_command(command)
1256 return_value = get_return_value(
-> 1257 answer, self.gateway_client, self.target_id, self.name)
1258
1259 for temp_arg in temp_args:
/opt/spark/python/pyspark/sql/utils.py in deco(*a, **kw)
67 e.java_exception.getStackTrace()))
68 if s.startswith('org.apache.spark.sql.AnalysisException: '):
---> 69 raise AnalysisException(s.split(': ', 1)[1], stackTrace)
70 if s.startswith('org.apache.spark.sql.catalyst.analysis'):
71 raise AnalysisException(s.split(': ', 1)[1], stackTrace)
AnalysisException: "cannot resolve '`age_group`' given input columns: [df.age, df.name]; line 7 pos 40;\n'Project [CASE WHEN (age#4 > cast(20 as double)) THEN 1 ELSE 0 END AS age_group#131, row_number() windowspecdefinition('age_group, age#4 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS nima#132]\n+- SubqueryAlias `df`\n +- LogicalRDD [age#4, name#5], false\n"spark.sql(
"""
select
case
when age > 20 then 1
else 0
end as age_group,
count(*) as n
from
df
group by
age_group
"""
).show()+---------+---+
|age_group| n|
+---------+---+
| 1| 2|
| 0| 1|
+---------+---+
spark.sql(
"""
select
case
when age > 20 then 1
else 0
end as age_group,
count(*) as n
from
df
group by
1
"""
).show()+---------+---+
|age_group| n|
+---------+---+
| 1| 2|
| 0| 1|
+---------+---+
df.withColumn("null_gt", when(col("age") >= 0, 1).otherwise(None)).show()+----+-------+-------+
| age| name|null_gt|
+----+-------+-------+
| NaN|Michael| 1|
|30.0| Andy| 1|
|19.0| Justin| 1|
+----+-------+-------+
df.withColumn("null_gt", when(col("age") < 20, 1).otherwise(None)).show()+----+-------+-------+
| age| name|null_gt|
+----+-------+-------+
| NaN|Michael| null|
|30.0| Andy| null|
|19.0| Justin| 1|
+----+-------+-------+
df.withColumn("null_gt", when(col("age") >= 20, 1).otherwise(None)).show()+----+-------+-------+
| age| name|null_gt|
+----+-------+-------+
| NaN|Michael| 1|
|30.0| Andy| 1|
|19.0| Justin| null|
+----+-------+-------+
df.withColumn("null_lt",
when($"age" <= 1000, 1).otherwise(null)
).show+----+-------+-------+
| age| name|null_lt|
+----+-------+-------+
|null|Michael| null|
| 30| Andy| 1|
| 19| Justin| 1|
+----+-------+-------+
df.select(when($"age".isNull, 0).when($"age" > 20 , 100).otherwise(10).alias("age")).show+---+
|age|
+---+
| 0|
|100|
| 10|
+---+
df.select(when($"age".isNull, 0).alias("age")).show+----+
| age|
+----+
| 0|
|null|
|null|
+----+
val df = Range(0, 10).toDF
df.show+-----+
|value|
+-----+
| 0|
| 1|
| 2|
| 3|
| 4|
| 5|
| 6|
| 7|
| 8|
| 9|
+-----+
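null
Notice that the function `when` behaves like if-else: chained `when` conditions are checked in order, and the first one that matches determines the value.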
df.withColumn("group",
when($"value" <= 3, 0)
.when($"value" <= 100, 1)
).show+-----+-----+
|value|group|
+-----+-----+
| 0| 0|
| 1| 0|
| 2| 0|
| 3| 0|
| 4| 1|
| 5| 1|
| 6| 1|
| 7| 1|
| 8| 1|
| 9| 1|
+-----+-----+