Handling the schema explicitly in Spark 3.4.1 code


I used to apply some patches to Spark's code so that more specific data types and structures are handled explicitly. The old Spark code with the patch applied:

private def serializerFor(inputObject: Expression, typeToken: TypeToken[_], schema: DataType): Expression = {

def toCatalystArray(input: Expression, elementType: TypeToken[_], elemType: DataType): Expression = {
  val (dataType, nullable) = inferDataType(elementType)
  if (ScalaReflection.isNativeType(dataType)) {
    val cls = input.dataType.asInstanceOf[ObjectType].cls
    if (cls.isArray && cls.getComponentType.isPrimitive) {
      createSerializerForPrimitiveArray(input, dataType)
    } else {
      createSerializerForGenericArray(input, dataType, nullable = nullable)
    }
  } else {
    createSerializerForMapObjects(input, ObjectType(elementType.getRawType),
      serializerFor(_, elementType, elemType))
  }
}

if (!inputObject.dataType.isInstanceOf[ObjectType]) {
  inputObject
} else {
  typeToken.getRawType match {
    case c if c == classOf[String] => createSerializerForString(inputObject)

    case c if c == classOf[java.time.Instant] => createSerializerForJavaInstant(inputObject)

    case c if c == classOf[java.sql.Timestamp] => createSerializerForSqlTimestamp(inputObject)

    case c if c == classOf[java.time.LocalDate] => createSerializerForJavaLocalDate(inputObject)

    case c if c == classOf[java.sql.Date] => createSerializerForSqlDate(inputObject)

    case c if c == classOf[java.math.BigDecimal] =>
      createSerializerForJavaBigDecimal(inputObject, schema.asInstanceOf[DecimalType])

    case c if c == classOf[java.lang.Boolean] => createSerializerForBoolean(inputObject)
    case c if c == classOf[java.lang.Byte] => createSerializerForByte(inputObject)
    case c if c == classOf[java.lang.Short] => createSerializerForShort(inputObject)
    case c if c == classOf[java.lang.Integer] => createSerializerForInteger(inputObject)
    case c if c == classOf[java.lang.Long] => createSerializerForLong(inputObject)
    case c if c == classOf[java.lang.Float] => createSerializerForFloat(inputObject)
    case c if c == classOf[java.lang.Double] => createSerializerForDouble(inputObject)

    case _ if typeToken.isArray =>
      val baseType = schema match {
        case s: BinaryType => ByteType
        case _ => schema.asInstanceOf[ArrayType].elementType
      }
      toCatalystArray(inputObject, typeToken.getComponentType, baseType)

    case _ if ttIsAssignableFrom(listType, typeToken) =>
      toCatalystArray(inputObject, elementType(typeToken), schema.asInstanceOf[ArrayType].elementType)

    case _ if ttIsAssignableFrom(mapType, typeToken) =>
      val (keyType, valueType) = mapKeyValueType(typeToken)

      createSerializerForMap(
        inputObject,
        MapElementInformation(
          inferExternalType(keyType.getRawType),
          nullable = true,
          serializerFor(_, keyType, schema.asInstanceOf[MapType].keyType)),
        MapElementInformation(
          inferExternalType(valueType.getRawType),
          nullable = true,
          serializerFor(_, valueType, schema.asInstanceOf[MapType].valueType))
      )

    case other if other.isEnum =>
      createSerializerForString(
        Invoke(inputObject, "name", ObjectType(classOf[String]), returnNullable = false))

    case other =>
      val properties = getJavaBeanReadableAndWritableProperties(other)
      val propNames = properties.map { f => (f.getName, f) }.toMap
      // reorder the props w.r.t to given struct type.
      val orderedProps = schema.asInstanceOf[StructType].
        fields.map { f => (propNames.get(f.name).get, f) }
      val fields = orderedProps.map { case (p, f) =>
        val fieldName = p.getName
        val fieldType = typeToken.method(p.getReadMethod).getReturnType
        val fieldValue = Invoke(
          inputObject,
          p.getReadMethod.getName,
          inferExternalType(fieldType.getRawType))
        (fieldName, serializerFor(fieldValue, fieldType, f.dataType))
      }
      createSerializerForObject(inputObject, fields)
  }
  }
}

Here I added some extra code (shown below) to handle data structures that involve the schema, such as StructType or ArrayType.

// in case of Array
val baseType = schema match {
  case s: BinaryType => ByteType
  case _ => schema.asInstanceOf[ArrayType].elementType
}

// in case other
val propNames = properties.map { f => (f.getName, f) }.toMap
// reorder the props w.r.t. the given struct type.
val orderedProps = schema.asInstanceOf[StructType].
  fields.map { f => (propNames.get(f.name).get, f) }
val fields = orderedProps.map { case (p, f) =>
But now, in Spark 3.4.1, the code flow has changed and it is hard to apply these patches. Can someone help me adapt this code? Spark 3.4.1:

private def serializerFor(enc: AgnosticEncoder[_], input: Expression): Expression = enc match {
case _ if isNativeEncoder(enc) => input
case BoxedBooleanEncoder => createSerializerForBoolean(input)
case BoxedByteEncoder => createSerializerForByte(input)
case BoxedShortEncoder => createSerializerForShort(input)
case BoxedIntEncoder => createSerializerForInteger(input)
case BoxedLongEncoder => createSerializerForLong(input)
case BoxedFloatEncoder => createSerializerForFloat(input)
case BoxedDoubleEncoder => createSerializerForDouble(input)
case JavaEnumEncoder(_) => createSerializerForJavaEnum(input)
case ScalaEnumEncoder(_, _) => createSerializerForScalaEnum(input)
case StringEncoder => createSerializerForString(input)
case ScalaDecimalEncoder(dt) => createSerializerForBigDecimal(input, dt)
case JavaDecimalEncoder(dt, false) => createSerializerForBigDecimal(input, dt)
case JavaDecimalEncoder(dt, true) => createSerializerForAnyDecimal(input, dt)
case ScalaBigIntEncoder => createSerializerForBigInteger(input)
case JavaBigIntEncoder => createSerializerForBigInteger(input)
case DayTimeIntervalEncoder => createSerializerForJavaDuration(input)
case YearMonthIntervalEncoder => createSerializerForJavaPeriod(input)
case DateEncoder(true) | LocalDateEncoder(true) => createSerializerForAnyDate(input)
case DateEncoder(false) => createSerializerForSqlDate(input)
case LocalDateEncoder(false) => createSerializerForJavaLocalDate(input)
case TimestampEncoder(true) | InstantEncoder(true) => createSerializerForAnyTimestamp(input)
case TimestampEncoder(false) => createSerializerForSqlTimestamp(input)
case InstantEncoder(false) => createSerializerForJavaInstant(input)
case LocalDateTimeEncoder => createSerializerForLocalDateTime(input)
case UDTEncoder(udt, udtClass) => createSerializerForUserDefinedType(input, udt, udtClass)
case OptionEncoder(valueEnc) =>
  serializerFor(valueEnc, UnwrapOption(externalDataTypeFor(valueEnc), input))

case ArrayEncoder(elementEncoder, containsNull) =>
  if (elementEncoder.isPrimitive) {
    createSerializerForPrimitiveArray(input, elementEncoder.dataType)
  } else {
    serializerForArray(elementEncoder, containsNull, input, lenientSerialization = false)
  }

case IterableEncoder(ctag, elementEncoder, containsNull, lenientSerialization) =>
  val getter = if (classOf[scala.collection.Set[_]].isAssignableFrom(ctag.runtimeClass)) {
    // There's no corresponding Catalyst type for `Set`, we serialize a `Set` to Catalyst array.
    // Note that the property of `Set` is only kept when manipulating the data as domain object.
    Invoke(input, "toSeq", ObjectType(classOf[scala.collection.Seq[_]]))
  } else {
    input
  }
  serializerForArray(elementEncoder, containsNull, getter, lenientSerialization)

case MapEncoder(_, keyEncoder, valueEncoder, valueContainsNull) =>
  createSerializerForMap(
    input,
    MapElementInformation(
      ObjectType(classOf[AnyRef]),
      nullable = keyEncoder.nullable,
      validateAndSerializeElement(keyEncoder, keyEncoder.nullable)),
    MapElementInformation(
      ObjectType(classOf[AnyRef]),
      nullable = valueContainsNull,
      validateAndSerializeElement(valueEncoder, valueContainsNull))
  )

case ProductEncoder(_, fields) =>
  val serializedFields = fields.map { field =>
    // SPARK-26730 inputObject won't be null with If's guard below. And KnownNotNull
    // is necessary here. Because for a nullable nested inputObject with struct data
    // type, e.g. StructType(IntegerType, StringType), it will return nullable=true
    // for IntegerType without KnownNotNull. And that's what we do not expect to.
    val getter = Invoke(
      KnownNotNull(input),
      field.name,
      externalDataTypeFor(field.enc),
      returnNullable = field.nullable)
    field.name -> serializerFor(field.enc, getter)
  }
  createSerializerForObject(input, serializedFields)

case RowEncoder(fields) =>
  val serializedFields = fields.zipWithIndex.map { case (field, index) =>
    val fieldValue = serializerFor(
      field.enc,
      ValidateExternalType(
        GetExternalRowField(input, index, field.name),
        field.enc.dataType,
        lenientExternalDataTypeFor(field.enc)))

    val convertedField = if (field.nullable) {
      exprs.If(
        Invoke(input, "isNullAt", BooleanType, exprs.Literal(index) :: Nil),
        // Because we strip UDTs, `field.dataType` can be different from `fieldValue.dataType`.
        // We should use `fieldValue.dataType` here.
        exprs.Literal.create(null, fieldValue.dataType),
        fieldValue
      )
    } else {
      AssertNotNull(fieldValue)
    }
    field.name -> convertedField
  }
  createSerializerForObject(input, serializedFields)

case JavaBeanEncoder(_, fields) =>
  val serializedFields = fields.map { f =>
    val fieldValue = Invoke(
      KnownNotNull(input),
      f.readMethod.get,
      externalDataTypeFor(f.enc),
      propagateNull = f.nullable,
      returnNullable = f.nullable)
    f.name -> serializerFor(f.enc, fieldValue)
  }
  createSerializerForObject(input, serializedFields)
}

As you can see, they now use encoders. I tried to apply my patches but got some ClassCastExceptions. Thanks in advance.
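For context on why the patches no longer line up: the DecimalType, ArrayType or StructType that the patched schema parameter used to supply is now carried inside the AgnosticEncoder itself, so the natural place to inject a more specific type is the encoder that gets passed in, not serializerFor. A minimal illustration, assuming the constructors from org.apache.spark.sql.catalyst.encoders.AgnosticEncoders (the concrete types below are just examples, not code from the question):

import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
import org.apache.spark.sql.types._

// The old patch cast `schema` to DecimalType; with agnostic encoders the target
// DecimalType travels inside the encoder and serializerFor(enc, input) reads it from there.
val decimalEnc = JavaDecimalEncoder(DecimalType(38, 18), lenientSerialization = false)

// Likewise, the old BinaryType => ByteType / ArrayType.elementType lookup is replaced by
// whatever element encoder the ArrayEncoder carries.
val byteArrayEnc = ArrayEncoder(PrimitiveByteEncoder, containsNull = false)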

scala apache-spark
1 Answer

You may want to look at building/transforming AgnosticEncoders. Spark 4 preview 2 pretty much forces this as the main way of extending encoder logic, but you get usage of your types on Spark Connect as well as on classic Spark.
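To make that concrete, one way to carry the spirit of the old patch forward is to post-process the AgnosticEncoder that Spark infers so that it agrees with an explicitly supplied schema, and only then build the ExpressionEncoder from it. The sketch below is only an assumption of how that could look, not an existing Spark API: alignWithSchema is a hypothetical helper, while the encoder constructors come from org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.

import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
import org.apache.spark.sql.types._

object SchemaAlignedEncoders {
  // Hypothetical helper: rebuild `enc` so that decimals, array elements and bean fields
  // pick up the data types of the explicitly supplied `schema`.
  def alignWithSchema(enc: AgnosticEncoder[_], schema: DataType): AgnosticEncoder[_] =
    (enc, schema) match {
      case (JavaDecimalEncoder(_, lenient), dt: DecimalType) =>
        JavaDecimalEncoder(dt, lenient)          // force the schema's precision/scale
      case (ScalaDecimalEncoder(_), dt: DecimalType) =>
        ScalaDecimalEncoder(dt)
      case (ArrayEncoder(element, containsNull), at: ArrayType) =>
        ArrayEncoder(alignWithSchema(element, at.elementType), containsNull)
      case (JavaBeanEncoder(tag, fields), st: StructType) =>
        // reorder the bean fields w.r.t. the given struct type, as the old patch did
        val byName = fields.map(f => f.name -> f).toMap
        val reordered = st.fields.toSeq.map { sf =>
          val f = byName(sf.name)
          f.copy(enc = alignWithSchema(f.enc, sf.dataType))
        }
        JavaBeanEncoder(tag, reordered)
      case _ => enc                              // anything else: keep the inferred encoder
    }
}

// Usage sketch (MyBean and mySchema are placeholders; ExpressionEncoder(agnosticEncoder)
// is available in 3.4+ as far as I can tell):
// val inferred = JavaTypeInference.encoderFor(classOf[MyBean])
// val aligned  = SchemaAlignedEncoders.alignWithSchema(inferred, mySchema)
// val serde    = ExpressionEncoder(aligned)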

If Spark #48477 is accepted, you could also consider extending/customising the sparkutils Frameless fork's encoding for more future-proofing, but using the existing AgnosticEncoders seems to fit your use case (as per the code examples).
