我有一个 PySpark DataFrame,其 schema 如下:
root
|-- array_bytes: binary (nullable = true)
我希望能够将其转换为图像数组。我可以使用以下代码在 Pandas 中完成此操作:
# Convert the Spark DataFrame `df` into a pandas DataFrame.
df_pandas = df.toPandas()
def bytes_to_array(byte_data: bytes) -> np.ndarray:
    """Decode a raw byte buffer into a 224x224 three-channel image array.

    Args:
        byte_data: Bytes-like object holding exactly 224 * 224 * 3
            unsigned 8-bit values.

    Returns:
        A ``uint8`` array of shape ``(224, 224, 3)`` that views
        ``byte_data`` without copying it.

    Raises:
        ValueError: If the buffer size does not match the target shape.
    """
    arr = np.frombuffer(byte_data, dtype=np.uint8)
    return arr.reshape((224, 224, 3))
# Decode each row's raw bytes into an image array, stored in a new column.
df_pandas['image_array'] = df_pandas['array_bytes'].apply(bytes_to_array)
我似乎找不到在 PySpark 中执行此操作的方法。这是我尝试过的:
def convert_binary_to_array(binary_data: bytes) -> np.ndarray:
    """Decode a raw byte buffer into a 224x224 three-channel image array.

    Args:
        binary_data: Bytes-like object holding exactly 224 * 224 * 3
            unsigned 8-bit values.

    Returns:
        A ``uint8`` array of shape ``(224, 224, 3)`` that views
        ``binary_data`` without copying it.

    Raises:
        ValueError: If the buffer size does not match the target shape.
    """
    arr = np.frombuffer(binary_data, dtype=np.uint8)
    return arr.reshape((224, 224, 3))
def convert_binary_in_df(df, binary_column: str = 'binary_data'):
    """Add an ``image_data`` column decoded from a binary column.

    Args:
        df: Spark DataFrame that contains ``binary_column``.
        binary_column: Name of the column holding the raw image bytes.

    Returns:
        ``df`` with an extra ``image_data`` column produced by the UDF.
    """
    def convert_binary_udf(byte_data):
        # Hand Spark plain nested Python lists rather than a NumPy array.
        return convert_binary_to_array(byte_data).tolist()

    # register and apply udf
    # NOTE(review): the declared return type nests two array levels, while
    # tolist() on the (224, 224, 3) array yields three. The declared type is
    # one level short, which appears to be why image_data comes back null.
    convert_binary_spark_udf = udf(convert_binary_udf, ArrayType(ArrayType(IntegerType())))
    df_output = df.withColumn("image_data", convert_binary_spark_udf(binary_column))
    return df_output
# Decode the 'array_bytes' column of `df` into a new image_data column.
df_converted = convert_binary_in_df(df, binary_column='array_bytes')
但是,`image_data` 列最终全是空值(null)。我对 PySpark 不太熟悉,找不出问题所在。预先感谢您的帮助。
我整理了一个可运行的示例:先生成一张图像,将其转换为 DataFrame,然后再从 DataFrame 重建图像,看看两者是否一致。
import numpy as np
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, IntegerType
import numpy as np
import matplotlib.pyplot as plt
# Test picture: a blue canvas with an 80x80 red square in the middle.
image = np.zeros((224, 224, 3), dtype=np.uint8)
image[..., 2] = 255                  # blue channel on for every pixel
image[72:152, 72:152] = (255, 0, 0)  # centre patch becomes pure red

# Show the source picture without axes for later visual comparison.
plt.imshow(image)
plt.title("Original Image")
plt.axis('off')
plt.show()
def convert_binary_to_array(binary_data: bytes) -> np.ndarray:
    """Decode a raw byte buffer into a 224x224 three-channel image array.

    Args:
        binary_data: Bytes-like object holding exactly 224 * 224 * 3
            unsigned 8-bit values.

    Returns:
        A ``uint8`` array of shape ``(224, 224, 3)`` that views
        ``binary_data`` without copying it.

    Raises:
        ValueError: If the buffer size does not match the target shape.
    """
    arr = np.frombuffer(binary_data, dtype=np.uint8)
    return arr.reshape((224, 224, 3))
def convert_binary_in_df(df, binary_column: str = 'binary_data'):
    """Add an ``image_data`` column decoded from a binary column.

    Args:
        df: Spark DataFrame that contains ``binary_column``.
        binary_column: Name of the column holding the raw image bytes.

    Returns:
        ``df`` with an extra ``image_data`` column holding each image as a
        three-level nested array of integers.
    """
    def convert_binary_udf(byte_data):
        # Hand Spark plain nested Python lists rather than a NumPy array.
        return convert_binary_to_array(byte_data).tolist()

    # Three ArrayType levels mirror the three dimensions of the
    # (224, 224, 3) array returned by convert_binary_to_array.
    convert_binary_spark_udf = udf(convert_binary_udf, ArrayType(ArrayType(ArrayType(IntegerType()))))
    df_output = df.withColumn("image_data", convert_binary_spark_udf(binary_column))
    return df_output
# Get (or create) the Spark session used for the round-trip demo.
spark = (
    SparkSession.builder
    .appName("ImageArrayConversion")
    .getOrCreate()
)

# One-row DataFrame whose single column holds the raw image bytes.
binary_image_data = image.tobytes()
df = spark.createDataFrame(data=[(binary_image_data,)], schema=["array_bytes"])

# Decode the bytes into nested integer arrays and print the result in full.
df_converted = convert_binary_in_df(df, binary_column='array_bytes')
df_converted.show(truncate=False)

# Bring the rows back to the driver and keep the first row's decoded image.
image_data_from_df = df_converted.collect()[0]['image_data']
# Rebuild a uint8 image array from the nested lists collected from Spark.
recreated_image = np.asarray(image_data_from_df, dtype=np.uint8)

# Display the reconstruction for visual comparison with the original.
plt.imshow(recreated_image)
plt.title("Recreated Image from DF")
plt.axis('off')
plt.show()
我只需要对代码做一处修改:将 UDF 的返回类型改为 `ArrayType(ArrayType(ArrayType(IntegerType())))`。原因是图像数组是三维的(高 × 宽 × 3 个通道),`tolist()` 之后得到的是三层嵌套列表,因此声明的返回类型也必须嵌套三层;否则声明的类型与实际返回的数据不匹配,列中就会出现空值。