小能豆

将数据框放入 randomForest pyspark 中

py

我有一个DataFrame看起来像这样:

+--------------------+------------------+
|            features|           labels |
+--------------------+------------------+
|[-0.38475, 0.568...]|          label1  |
|[0.645734, 0.699...]|          label2  |
|     .....          |          ...     |
+--------------------+------------------+

两列都是字符串类型 (StringType()),我想将其放入 spark ml randomForest 中。为此,我需要将特征列转换为包含浮点数的向量。有人知道怎么做吗?


阅读 19

收藏
2024-11-04

共1个答案

小能豆

要将 features 列转换为 Vector 类型以适用于 Spark ML 中的模型,首先需要将字符串类型的列表转换为数值型列表,然后再将其转换为 Spark ML 的 Vector 类型。可以使用 pyspark.ml.feature 中的 VectorAssembler,但首先我们需要将 features 列从字符串解析为数值列表。

以下是转换代码的步骤:

  1. 将字符串转换为数值列表。
  2. 使用 VectorAssembler 将列表转换为 Spark ML 可用的 Vector

下面是具体实现:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf
from pyspark.sql.types import ArrayType, FloatType
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.linalg import Vectors

# 初始化 Spark session
spark = SparkSession.builder.appName("DataConversion").getOrCreate()

# 示例数据
data = [
    ("[-0.38475, 0.568]", "label1"),
    ("[0.645734, 0.699]", "label2")
]
df = spark.createDataFrame(data, ["features", "labels"])

# 定义 UDF,将字符串转换为浮点数列表
def parse_features(features_str):
    return [float(x) for x in features_str.strip("[]").split(",")]

parse_features_udf = udf(parse_features, ArrayType(FloatType()))

# 应用 UDF 将字符串转换为数值列表
df = df.withColumn("features_array", parse_features_udf(col("features")))

# 使用 VectorAssembler 将数值列表转换为 Vector
vector_assembler = VectorAssembler(inputCols=["features_array"], outputCol="features_vector")
df = vector_assembler.transform(df)

# 查看结果
df.select("features_vector", "labels").show(truncate=False)

结果

+-------------+-------+
|features_vector     | labels |
+-------------+-------+
|[-0.38475, 0.568]   | label1 |
|[0.645734, 0.699]   | label2 |
+-------------+-------+

说明

  • parse_features_udf 函数用于将字符串转换为浮点数列表。
  • VectorAssemblerfeatures_array 列转换为 features_vector 列,features_vector 列现在是 Spark ML 可用的 Vector 格式,可以直接用作随机森林模型的输入。
2024-11-04