Github user viirya commented on a diff in the pull request:
https://github.com/apache/spark/pull/18994#discussion_r134144063
--- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/TableConstraints.scala ---
@@ -0,0 +1,323 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.catalog
+
+import java.util.UUID
+
+import org.json4s._
+import org.json4s.JsonAST.JValue
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.Resolver
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.SchemaUtils
+
+/**
+ * A container class to hold all the constraints defined on a table. The scope of
+ * constraint names is the table.
+ */
+case class TableConstraints(
+ primaryKey: Option[PrimaryKey] = None,
+ foreignKeys: Seq[ForeignKey] = Seq.empty) {
+
+  /**
+   * Adds the given constraint to the existing table constraints, after verifying
+   * that the constraint name is not a duplicate.
+   */
+  def addConstraint(constraint: TableConstraint, resolver: Resolver): TableConstraints = {
+    if (primaryKey.exists(pk => resolver(pk.constraintName, constraint.constraintName))
+        || foreignKeys.exists(fk => resolver(fk.constraintName, constraint.constraintName))) {
+      throw new AnalysisException(
+        s"Failed to add constraint, duplicate constraint name '${constraint.constraintName}'")
+    }
+    constraint match {
+      case pk: PrimaryKey =>
+        if (primaryKey.nonEmpty) {
+          throw new AnalysisException(
+            s"Primary key '${primaryKey.get.constraintName}' already exists.")
+        }
+        this.copy(primaryKey = Option(pk))
+      case fk: ForeignKey => this.copy(foreignKeys = foreignKeys :+ fk)
+    }
+  }
+}
+
+object TableConstraints {
+  /**
+   * Returns a [[TableConstraints]] containing [[PrimaryKey]] or [[ForeignKey]].
+   */
+ def apply(tableConstraint: TableConstraint): TableConstraints = {
+ tableConstraint match {
+ case pk: PrimaryKey => TableConstraints(primaryKey = Option(pk))
+ case fk: ForeignKey => TableConstraints(foreignKeys = Seq(fk))
+ }
+ }
+
+  /**
+   * Converts constraints represented as JSON strings to [[TableConstraints]].
+   */
+  def fromJson(pkJson: Option[String], fksJson: Seq[String]): TableConstraints = {
+ val pk = pkJson.map(pk => PrimaryKey.fromJson(parse(pk)))
+ val fks = fksJson.map(fk => ForeignKey.fromJson(parse(fk)))
+ TableConstraints(pk, fks)
+ }
+}
+
+/**
+ * Common type representing a table constraint.
+ */
+sealed trait TableConstraint {
+  val constraintName: String
+  val keyColumnNames: Seq[String]
+}
+
+object TableConstraint {
+  private[TableConstraint] val curId = new java.util.concurrent.atomic.AtomicLong(0L)
+ private[TableConstraint] val jvmId = UUID.randomUUID()
+
+  /**
+   * Generates a unique constraint name to use when adding table constraints,
+   * if the user does not specify one. The `curId` field is unique within a given JVM,
+   * while the `jvmId` uniquely identifies the JVM.
+   */
+  def generateConstraintName(constraintType: String = "constraint"): String = {
+    s"${constraintType}_${jvmId}_${curId.getAndIncrement()}"
+  }
+
+ def parseColumn(json: JValue): String = json match {
+ case JString(name) => name
+ case _ => json.toString
+ }
+
+ object JSortedObject {
+    def unapplySeq(value: JValue): Option[List[(String, JValue)]] = value match {
+ case JObject(seq) => Some(seq.toList.sortBy(_._1))
+ case _ => None
+ }
+ }
+
+  /**
+   * Returns the [[StructField]] for the given column name if it exists in the given schema.
+   */
+  def findColumnByName(
+      schema: StructType, name: String, resolver: Resolver): StructField = {
+    schema.fields.collectFirst {
+      case field if resolver(field.name, name) => field
+    }.getOrElse(throw new AnalysisException(
+      s"Invalid column reference '$name', table data schema is '${schema}'"))
+  }
+
+  /**
+   * Verifies the user-supplied constraint information and fills in missing information,
+   * such as unspecified reference columns, which default to the referenced table's primary key.
+   */
+  def verifyAndBuildConstraint(
+      inputConstraint: TableConstraint,
+      table: CatalogTable,
+      catalog: SessionCatalog,
+      resolver: Resolver): TableConstraint = {
+    SchemaUtils.checkColumnNameDuplication(
+      inputConstraint.keyColumnNames, "in the constraint key definition", resolver)
+    // Check that the column names are valid non-partition columns.
+    val keyColFields = inputConstraint.keyColumnNames
+      .map(findColumnByName(table.dataSchema, _, resolver))
+    // Constraints are only supported for basic SQL types; throw an error for other data types.
+    keyColFields.map(_.dataType).foreach {
+      case ByteType | ShortType | IntegerType | LongType | FloatType |
--- End diff ---
BinaryType?
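
As a side note on the question above, here is a minimal, self-contained sketch of what the type whitelist check that this hunk ends on could look like with BinaryType included. The PR's actual list of accepted types continues past the last line shown in the diff, so the set below is illustrative only, not the PR's real list; `ConstraintTypeCheck` and `checkKeyColumnType` are hypothetical names used just for this sketch.

    import org.apache.spark.sql.AnalysisException
    import org.apache.spark.sql.types._

    object ConstraintTypeCheck {
      // Hypothetical whitelist: the numeric types visible in this hunk (ByteType .. FloatType)
      // plus BinaryType, added here only to illustrate the question. The PR's real list is
      // not fully shown in the diff.
      def checkKeyColumnType(field: StructField): Unit = field.dataType match {
        case ByteType | ShortType | IntegerType | LongType | FloatType | BinaryType =>
          // supported key column type, nothing to do
        case other =>
          throw new AnalysisException(
            s"Constraint key column '${field.name}' has unsupported data type '${other.simpleString}'")
      }
    }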