Common PySpark syntax (with pandas comparisons)

1. Ranking functions

dense_rank(): equal values get the same rank, and the rank values are always consecutive (no gaps)

import pyspark.sql.functions as F
from pyspark.sql.window import Window
data = [(1, 'John'),
        (1, 'Mike'),
        (1, 'Emma'),
        (4, 'Sarah')]
df = spark.createDataFrame(data, ['id', 'name'])
window = Window.orderBy(F.col('id'))
df = df.withColumn("frame_id", F.dense_rank().over(window))
df.show()

A couple of other commonly used ranking functions:

rank(): equal values get the same rank, and the following rank is skipped (gaps are left)


row_number(): equal values get different ranks, numbered consecutively
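
A minimal comparison of the three, reusing the df and window from the dense_rank() example above (the new column names are only for illustration):

df = df.withColumn("dense_rank", F.dense_rank().over(window))
df = df.withColumn("rank", F.rank().over(window))
df = df.withColumn("row_number", F.row_number().over(window))
df.show()
# for id = 1, 1, 1, 4: dense_rank gives 1,1,1,2; rank gives 1,1,1,4; row_number gives 1,2,3,4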

2. pandas map(); PySpark has no map() on DataFrame columns, use when()...otherwise()

Given a PySpark df with a score column, map values greater than 0 to 'good' and everything else to 'no':

pandas:
df['score'] = df['score'].map(lambda x: 'good' if x > 0 else 'no')
pyspark:
from pyspark.sql.functions import when
df = df.withColumn('score', when(df.score > 0, 'good').otherwise('no'))

3. groupBy

For simple per-group aggregations such as max, min, and mean:

Group by "obj_frame_id" and take the mean of 'rel_pos_x', naming the result column "rel_pos_x":
data_add = data.groupBy("obj_frame_id").agg(F.avg('rel_pos_x').alias("rel_pos_x"))

Another way, using window functions (this keeps every row and attaches the group aggregate to each row):

window_spec = Window.partitionBy('follow_id')
df = df.withColumn('follow_start_time', F.min('ts').over(window_spec))
df = df.withColumn('follow_end_time', F.max('ts').over(window_spec))
df = df.withColumn('follow_count_time', F.count('ts').over(window_spec))

For more complex per-group computations, write a pandas UDF and use applyInPandas:

df = df.groupby(['obj_id']).applyInPandas(group_udf, schema=df.schema)
Note: the DataFrame returned by the UDF must not contain null values; otherwise this fails with an error that does not tell you what is actually wrong.
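
group_udf is not shown above; a minimal sketch of what such a function can look like (the 'ts' and 'rel_pos_x' column names are only for illustration, and the returned pandas DataFrame must match the schema passed to applyInPandas):

def group_udf(pdf):
    # pdf is a pandas DataFrame holding one obj_id group
    pdf = pdf.sort_values('ts')
    # fill rather than return nulls, to avoid the opaque error mentioned above
    pdf['rel_pos_x'] = pdf['rel_pos_x'].fillna(0)
    return pdf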

4. Replacing the values of a column using a dict

For example: with a dict {'a': 111, 'b': 222} and a df column whose values are 'a' and 'b', replace them with 111 and 222.

import pyspark.sql.functions as F
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

algo_dict = {'a': 111, 'b': 222}
# the UDF is declared as StringType, so cast the looked-up value to str
replace_udf = udf(lambda x: str(algo_dict.get(x, x)), StringType())
data_scene = data_scene.withColumn("code_name", replace_udf(data_scene["algorithm_id"]))
pandas:
data_scene['code_name']=data_scene['algorithm_id'].map(algo_dict)
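
A UDF-free PySpark alternative is also possible; a sketch using F.create_map to turn the dict into a map column (note that, unlike the UDF above, keys missing from the dict become null instead of falling back to the original value):

from itertools import chain
mapping = F.create_map([F.lit(x) for x in chain(*algo_dict.items())])
data_scene = data_scene.withColumn("code_name", mapping[F.col("algorithm_id")])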

5. Sorting

pandas:
df = df.sort_values('header_timestamp')
pyspark:
df = df.sort('header_timestamp')

6. Adding a constant column, or assigning one column from another

Note: PySpark does not support assigning a Python list directly as a DataFrame column.

pandas:
df['orientation pitch'] = 0
pyspark:
from pyspark.sql.functions import lit
df = df.withColumn('orientation pitch', lit(0))
pandas:
df['bag_timestamp'] = df['header_timestamp']
pyspark:
Create a new column bag_timestamp from the existing column header_timestamp:
df = df.withColumn('bag_timestamp', df['header_timestamp'])

7. Deduplication: drop_duplicates()

The call looks the same in both libraries, but PySpark has no keep parameter (a sketch of how to emulate keep='last' follows the code below).

In pandas:

  • keep: 'first', 'last', or False; defaults to 'first'. It decides which duplicate rows are kept:

            'first': keep the first occurrence of the duplicated rows

            'last': keep the last occurrence of the duplicated rows

            False: drop all duplicated rows

pandas:
df = df.drop_duplicates(subset='col_name', keep='last')
pyspark:
df = df.drop_duplicates(subset=['col_name'])
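
To emulate keep='last' in PySpark, one option is the window functions from section 1; a minimal sketch, assuming a 'ts' column defines which row counts as "last":

from pyspark.sql.window import Window
import pyspark.sql.functions as F

w = Window.partitionBy('col_name').orderBy(F.col('ts').desc())
df = (df.withColumn('rn', F.row_number().over(w))
        .filter(F.col('rn') == 1)
        .drop('rn'))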
    

8. Concatenating two columns: F.concat

Example: 'object_id' is 1 and 'class_label_pred' is car; the new 'obj_id_class' column should be '1_car'.

pyspark:
obj_table = obj_table.withColumn('obj_id_class', F.concat(F.col('object_id'), F.lit('_'), F.col('class_label_pred').cast('string')))
or:
data = data.withColumn('obj_id_class', F.concat_ws('_', "obj_id", "obj_class"))
pandas:
obj_table['obj_id_class'] = obj_table["object_id"].map(str) + '_' + obj_table["class_label_pred"].map(str)

9. Common filtering

Keep the rows of obj_table whose 'position_x' absolute value is at most 10:

pyspark:
obj_table = obj_table[F.abs(obj_table.position_x) <= 10]
pandas:
obj_table = obj_table[obj_table.position_x.abs() <= 10]
    

10. Taking rows: limit

Take the last row of df (the row with the largest 'ts') as a new DataFrame:

df2 = df.orderBy(F.col('ts').desc()).limit(1)
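
For comparison, a pandas sketch of the same selection (assuming 'ts' defines the ordering):

df2 = df.sort_values('ts').tail(1)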

11. pandas diff() in PySpark

from pyspark.sql.functions import lag, lead, col
from pyspark.sql.window import Window
data = [(1, 10), (2, 20), (3, 15), (4, 25), (5, 30)]
df = spark.createDataFrame(data, ['id', 'value'])
windowSpec = Window.orderBy('id')
df = df.withColumn('diff', col('value') - lag('value').over(windowSpec))

lag is the equivalent of pandas shift()

lead is the equivalent of pandas shift(-1); a lead example follows the code below

window_spec = Window.orderBy('ts')
data = data.withColumn('last_front_count', lag(col('front_count')).over(window_spec))
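
And the lead counterpart (a sketch; 'next_front_count' is a hypothetical column name):

data = data.withColumn('next_front_count', lead(col('front_count')).over(window_spec))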

12. Filling nulls: fillna()

pandas:
df["yaw_flag"] = df["yaw_flag"].fillna(0)
pyspark:
df = df.fillna(0, subset=["yaw_flag"])

A few other quick notes

1. When plotting with matplotlib, PySpark column values have to be collected to the driver first, e.g. with df[['col']].collect()

import matplotlib
matplotlib.use('agg')  # select the non-interactive backend before importing pyplot
import matplotlib.pyplot as plt
# y1 is assumed to be a matplotlib Axes created elsewhere
y1.plot(data_result_i[['time']].collect(), data_result_i[['ego_velocity_x']].collect(),
        color='dodgerblue', label='ego_spd', linewidth=3)
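
An alternative sketch: select the needed columns and convert them to pandas first, which avoids passing lists of Row objects to matplotlib (assumes the selected data fits in driver memory):

pdf = data_result_i.select('time', 'ego_velocity_x').toPandas()
y1.plot(pdf['time'], pdf['ego_velocity_x'], color='dodgerblue', label='ego_spd', linewidth=3)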

2. Errors encountered

A stopgap fix was to shut down and restart the Spark session, or to move the custom UDF definition inside a function; I do not fully understand the cause of this problem.
