I am using pyspark and I have a DataFrame object df. This is what the output of df.printSchema() looks like:

root
 |-- M_MRN: string (nullable = true)
 |-- measurements: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- Observation_ID: string (nullable = true)
 |    |    |-- Observation_Name: string (nullable = true)
 |    |    |-- Observation_Result: string (nullable = true)
I would like to filter out all elements of 'measurements' whose Observation_ID is not '5' or '10'. Currently, when I run df.select('measurements').take(2) I get:

[Row(measurements=[Row(Observation_ID='5', Observation_Name='ABC', Observation_Result='108/72'),
                   Row(Observation_ID='11', Observation_Name='ABC', Observation_Result='70'),
                   Row(Observation_ID='10', Observation_Name='ABC', Observation_Result='73.029'),
                   Row(Observation_ID='14', Observation_Name='XYZ', Observation_Result='23.1')]),
 Row(measurements=[Row(Observation_ID='2', Observation_Name='ZZZ', Observation_Result='3/4'),
                   Row(Observation_ID='5', Observation_Name='ABC', Observation_Result='7')])]
After performing the filtering described above, I would like df.select('measurements').take(2) to return:

[Row(measurements=[Row(Observation_ID='5', Observation_Name='ABC', Observation_Result='108/72'),
                   Row(Observation_ID='10', Observation_Name='ABC', Observation_Result='73.029')]),
 Row(measurements=[Row(Observation_ID='5', Observation_Name='ABC', Observation_Result='7')])]
Is there a way to do this in pyspark? Thank you very much for your help!
Starting with Spark 2.4, you can use the higher-order function FILTER to filter elements out of an array. So if you want to drop every element whose Observation_ID is not '5' or '10', you can do it as follows (since Observation_ID is a string, compare against the string literals '5' and '10'; also note that withColumn returns a new DataFrame, so assign the result back):

from pyspark.sql.functions import expr

df = df.withColumn(
    'measurements',
    expr("FILTER(measurements, x -> x.Observation_ID = '5' OR x.Observation_ID = '10')")
)
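To make the expected result concrete, here is the same element-wise filtering expressed in plain Python over the sample rows from the question. This is only a sketch of what FILTER does within each row (each struct element is represented by a dict), not a replacement for the Spark expression:

```python
# Each inner list stands in for one row's 'measurements' array;
# each dict stands in for one struct element.
rows = [
    [
        {"Observation_ID": "5", "Observation_Name": "ABC", "Observation_Result": "108/72"},
        {"Observation_ID": "11", "Observation_Name": "ABC", "Observation_Result": "70"},
        {"Observation_ID": "10", "Observation_Name": "ABC", "Observation_Result": "73.029"},
        {"Observation_ID": "14", "Observation_Name": "XYZ", "Observation_Result": "23.1"},
    ],
    [
        {"Observation_ID": "2", "Observation_Name": "ZZZ", "Observation_Result": "3/4"},
        {"Observation_ID": "5", "Observation_Name": "ABC", "Observation_Result": "7"},
    ],
]

keep = {"5", "10"}

# Per-row filtering: keep only elements whose Observation_ID is '5' or '10',
# mirroring FILTER(measurements, x -> x.Observation_ID = '5' OR x.Observation_ID = '10').
filtered = [[m for m in row if m["Observation_ID"] in keep] for row in rows]

print([[m["Observation_ID"] for m in row] for row in filtered])
# → [['5', '10'], ['5']]
```

The key point is that FILTER operates inside each row's array independently: rows are never dropped, only the array elements within them.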