How to convert binary image data to an RGB array in PySpark?


I have a PySpark DataFrame with the following schema:

root
 |-- array_bytes: binary (nullable = true)

I'd like to convert this column into image arrays. I can do it in Pandas with the following code:

import numpy as np

df_pandas = df.toPandas()
def bytes_to_array(byte_data):
    arr = np.frombuffer(byte_data, dtype=np.uint8)
    return arr.reshape((224, 224, 3))

df_pandas['image_array'] = df_pandas['array_bytes'].apply(bytes_to_array)
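A small sanity check of the Pandas conversion above, assuming each row holds exactly one 224x224 RGB image stored as raw uint8 bytes in (height, width, channel) order. The HEIGHT/WIDTH/CHANNELS constants and the length check are illustrative additions, not part of the original code; the round trip shows that ndarray.tobytes() produces exactly the layout np.frombuffer plus reshape expects.

```python
import numpy as np

# Assumed image geometry, matching the (224, 224, 3) reshape above.
HEIGHT, WIDTH, CHANNELS = 224, 224, 3

def bytes_to_array(byte_data):
    """Decode raw uint8 bytes in (height, width, channel) order into an image array."""
    expected = HEIGHT * WIDTH * CHANNELS  # 150528 bytes per image
    if len(byte_data) != expected:
        raise ValueError(f"expected {expected} bytes, got {len(byte_data)}")
    return np.frombuffer(byte_data, dtype=np.uint8).reshape((HEIGHT, WIDTH, CHANNELS))

# Deterministic test image, encoded with tobytes() and decoded again.
original = (np.arange(HEIGHT * WIDTH * CHANNELS) % 256).astype(np.uint8).reshape((HEIGHT, WIDTH, CHANNELS))
restored = bytes_to_array(original.tobytes())
print(restored.shape, np.array_equal(original, restored))  # (224, 224, 3) True
```

Note that np.frombuffer returns a read-only view of the bytes; call .copy() on the result if you need to modify pixels in place.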

I can't find a way to do the same thing in PySpark. Here's what I tried:

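import numpy as np
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, IntegerType

def convert_binary_to_array(binary_data: bytes) -> np.ndarray:
    arr = np.frombuffer(binary_data, dtype=np.uint8)
    return arr.reshape((224, 224, 3))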


def convert_binary_in_df(df, binary_column: str = 'binary_data'):
    def convert_binary_udf(byte_data):
        return convert_binary_to_array(byte_data).tolist()
    
    # register and apply udf
    convert_binary_spark_udf = udf(convert_binary_udf, ArrayType(ArrayType(IntegerType())))
    df_output = df.withColumn("image_data", convert_binary_spark_udf(binary_column))
    return df_output

df_converted = convert_binary_in_df(df, binary_column='array_bytes')

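However, the image_data column ends up full of null values. My PySpark knowledge isn't strong enough to figure out what's wrong. Thanks in advance for your help.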

Answer

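I put together a working example that creates an image, converts it into a DataFrame, and then reconstructs the image from the DataFrame to check that it looks the same: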

import numpy as np
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, IntegerType


# an image with a central red square and blue background
image = np.zeros((224, 224, 3), dtype=np.uint8)
image[:, :] = [0, 0, 255]
image[72:152, 72:152] = [255, 0, 0]


plt.imshow(image)
plt.title("Original Image")
plt.axis('off')
plt.show()



def convert_binary_to_array(binary_data: bytes) -> np.ndarray:
    arr = np.frombuffer(binary_data, dtype=np.uint8)
    return arr.reshape((224, 224, 3))

def convert_binary_in_df(df, binary_column: str = 'binary_data'):
    def convert_binary_udf(byte_data):
        return convert_binary_to_array(byte_data).tolist()
    
    # three nested ArrayType levels: rows -> pixels in a row -> RGB channel values
    convert_binary_spark_udf = udf(convert_binary_udf, ArrayType(ArrayType(ArrayType(IntegerType()))))
    df_output = df.withColumn("image_data", convert_binary_spark_udf(binary_column))
    return df_output

spark = SparkSession.builder.appName("ImageArrayConversion").getOrCreate()

binary_image_data = image.tobytes()

df = spark.createDataFrame([(binary_image_data,)], ["array_bytes"])

df_converted = convert_binary_in_df(df, binary_column='array_bytes')

df_converted.show(truncate=False)


image_data_from_df = df_converted.collect()[0]['image_data']


# Reconstruct the image back

recreated_image = np.array(image_data_from_df, dtype=np.uint8)
plt.imshow(recreated_image)
plt.title("Recreated Image from DF")
plt.axis('off')
plt.show()
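Instead of comparing the two plots by eye, the round trip can also be checked numerically. The sketch below uses a hypothetical images_match helper; because it runs without Spark, it stands in for df_converted.collect()[0]['image_data'] with image.tolist(), which is the nested-list structure the UDF returns for each row (an assumption that holds for this UDF, since collected ArrayType values come back as Python lists).

```python
import numpy as np

def images_match(original: np.ndarray, nested_pixels) -> bool:
    """Compare an image array with the nested lists collected from the image_data column."""
    recreated = np.array(nested_pixels, dtype=np.uint8)
    # array_equal also returns False when the shapes differ
    return np.array_equal(recreated, original)

# Same test image as above: blue background with a central red square.
image = np.zeros((224, 224, 3), dtype=np.uint8)
image[:, :] = [0, 0, 255]
image[72:152, 72:152] = [255, 0, 0]

# Stand-in for the value collected from Spark.
collected = image.tolist()
print(images_match(image, collected))  # True

collected[0][0][0] = 1  # corrupt one channel value of the first pixel
print(images_match(image, collected))  # False
```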


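The only fix I had to make to your code was changing the UDF's return type to

ArrayType(ArrayType(ArrayType(IntegerType())))

because the image has three dimensions (height, width, and 3 color channels). The .tolist() call returns lists nested three levels deep, so with only two ArrayType levels the innermost elements are [R, G, B] lists where IntegerType expects integers. A regular Python UDF returns null for values it cannot convert to the declared type instead of raising an error, which is why image_data filled up with nulls.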
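To see why three levels are needed, one can count how deeply the value returned by .tolist() is nested; each list level corresponds to one ArrayType in the UDF's declared return type. A minimal sketch with a hypothetical nesting_depth helper (not a PySpark function):

```python
import numpy as np

def nesting_depth(value) -> int:
    """Number of list levels in a nested Python list (0 for a scalar)."""
    depth = 0
    while isinstance(value, list):
        depth += 1
        value = value[0] if value else None  # descend into the first element
    return depth

rgb_pixels = np.zeros((224, 224, 3), dtype=np.uint8).tolist()
print(nesting_depth(rgb_pixels))      # 3 -> ArrayType(ArrayType(ArrayType(IntegerType())))
print(type(rgb_pixels[0][0][0]))      # <class 'int'>: tolist() yields plain Python ints

gray_pixels = np.zeros((224, 224), dtype=np.uint8).tolist()
print(nesting_depth(gray_pixels))     # 2 -> ArrayType(ArrayType(IntegerType()))
```

For a single-channel (grayscale) image the two-level type from the question would have been correct; it is the extra channel axis that requires the third level.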

