This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.1 by this push:
new 482b03cbd126 Revert "[SPARK-54760][SQL] DelegatingCatalogExtension as session catalog supports both V1 and V2 functions"
482b03cbd126 is described below
commit 482b03cbd1269d53ccf5ff46430d920f66d38df2
Author: Wenchen Fan <[email protected]>
AuthorDate: Sat Dec 27 13:55:36 2025 +0800
Revert "[SPARK-54760][SQL] DelegatingCatalogExtension as session catalog
supports both V1 and V2 functions"
This reverts commit 51042d66d8e88bdd1ee1a150b775d681e45d69d4.
---
.../sql/catalyst/analysis/FunctionResolution.scala | 23 ++-
.../identifier-clause-legacy.sql.out | 8 +-
.../analyzer-results/identifier-clause.sql.out | 8 +-
.../results/identifier-clause-legacy.sql.out | 8 +-
.../sql-tests/results/identifier-clause.sql.out | 8 +-
.../DataSourceV2DataFrameSessionCatalogSuite.scala | 8 +-
.../sql/connector/DataSourceV2FunctionSuite.scala | 182 ++++++++++-----------
.../DataSourceV2SQLSessionCatalogSuite.scala | 9 +-
.../connector/SupportsCatalogOptionsSuite.scala | 9 +-
.../sql/connector/TestV2SessionCatalogBase.scala | 60 +------
10 files changed, 126 insertions(+), 197 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala
index 8d6e2931a73b..800126e0030e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala
@@ -26,16 +26,17 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.connector.catalog.{
CatalogManager,
+ CatalogV2Util,
+ FunctionCatalog,
+ Identifier,
LookupCatalog
}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.functions.{
AggregateFunction => V2AggregateFunction,
- ScalarFunction,
- UnboundFunction
+ ScalarFunction
}
import org.apache.spark.sql.errors.{DataTypeErrorsBase, QueryCompilationErrors}
-import org.apache.spark.sql.internal.connector.V1Function
import org.apache.spark.sql.types._
class FunctionResolution(
@@ -51,14 +52,10 @@ class FunctionResolution(
resolveBuiltinOrTempFunction(u.nameParts, u.arguments, u).getOrElse {
val CatalogAndIdentifier(catalog, ident) =
relationResolution.expandIdentifier(u.nameParts)
- catalog.asFunctionCatalog.loadFunction(ident) match {
- case V1Function(_) =>
- // this triggers the second time v1 function resolution but should be cheap
- // (no RPC to external catalog), since the metadata has been already cached
- // in FunctionRegistry during the above `catalog.loadFunction` call.
- resolveV1Function(ident.asFunctionIdentifier, u.arguments, u)
- case unboundV2Func =>
- resolveV2Function(unboundV2Func, u.arguments, u)
+ if (CatalogV2Util.isSessionCatalog(catalog)) {
+ resolveV1Function(ident.asFunctionIdentifier, u.arguments, u)
+ } else {
+ resolveV2Function(catalog.asFunctionCatalog, ident, u.arguments, u)
}
}
}
@@ -275,9 +272,11 @@ class FunctionResolution(
}
private def resolveV2Function(
- unbound: UnboundFunction,
+ catalog: FunctionCatalog,
+ ident: Identifier,
arguments: Seq[Expression],
u: UnresolvedFunction): Expression = {
+ val unbound = catalog.loadFunction(ident)
val inputType = StructType(arguments.zipWithIndex.map {
case (exp, pos) => StructField(s"_$pos", exp.dataType, exp.nullable)
})
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause-legacy.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause-legacy.sql.out
index 95639c72a0ad..94fff8f58697 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause-legacy.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause-legacy.sql.out
@@ -972,11 +972,11 @@ VALUES(IDENTIFIER('a.b.c.d')())
-- !query analysis
org.apache.spark.sql.AnalysisException
{
- "errorClass" : "REQUIRES_SINGLE_PART_NAMESPACE",
- "sqlState" : "42K05",
+ "errorClass" : "IDENTIFIER_TOO_MANY_NAME_PARTS",
+ "sqlState" : "42601",
"messageParameters" : {
- "namespace" : "`a`.`b`.`c`",
- "sessionCatalog" : "spark_catalog"
+ "identifier" : "`a`.`b`.`c`.`d`",
+ "limit" : "2"
},
"queryContext" : [ {
"objectType" : "",
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out
index e3150b199658..e6a406072c48 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out
@@ -972,11 +972,11 @@ VALUES(IDENTIFIER('a.b.c.d')())
-- !query analysis
org.apache.spark.sql.AnalysisException
{
- "errorClass" : "REQUIRES_SINGLE_PART_NAMESPACE",
- "sqlState" : "42K05",
+ "errorClass" : "IDENTIFIER_TOO_MANY_NAME_PARTS",
+ "sqlState" : "42601",
"messageParameters" : {
- "namespace" : "`a`.`b`.`c`",
- "sessionCatalog" : "spark_catalog"
+ "identifier" : "`a`.`b`.`c`.`d`",
+ "limit" : "2"
},
"queryContext" : [ {
"objectType" : "",
diff --git a/sql/core/src/test/resources/sql-tests/results/identifier-clause-legacy.sql.out b/sql/core/src/test/resources/sql-tests/results/identifier-clause-legacy.sql.out
index 13a4b43fd058..6a99be057010 100644
--- a/sql/core/src/test/resources/sql-tests/results/identifier-clause-legacy.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/identifier-clause-legacy.sql.out
@@ -1112,11 +1112,11 @@ struct<>
-- !query output
org.apache.spark.sql.AnalysisException
{
- "errorClass" : "REQUIRES_SINGLE_PART_NAMESPACE",
- "sqlState" : "42K05",
+ "errorClass" : "IDENTIFIER_TOO_MANY_NAME_PARTS",
+ "sqlState" : "42601",
"messageParameters" : {
- "namespace" : "`a`.`b`.`c`",
- "sessionCatalog" : "spark_catalog"
+ "identifier" : "`a`.`b`.`c`.`d`",
+ "limit" : "2"
},
"queryContext" : [ {
"objectType" : "",
diff --git a/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out b/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out
index beeb3b13fe1e..0c0473791201 100644
--- a/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out
@@ -1112,11 +1112,11 @@ struct<>
-- !query output
org.apache.spark.sql.AnalysisException
{
- "errorClass" : "REQUIRES_SINGLE_PART_NAMESPACE",
- "sqlState" : "42K05",
+ "errorClass" : "IDENTIFIER_TOO_MANY_NAME_PARTS",
+ "sqlState" : "42601",
"messageParameters" : {
- "namespace" : "`a`.`b`.`c`",
- "sessionCatalog" : "spark_catalog"
+ "identifier" : "`a`.`b`.`c`.`d`",
+ "limit" : "2"
},
"queryContext" : [ {
"objectType" : "",
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala
index bc6ceeb24593..8959b285b028 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala
@@ -168,10 +168,6 @@ private [connector] trait SessionCatalogTest[T <: Table, Catalog <: TestV2Sessio
spark.sessionState.catalogManager.catalog(name)
}
- protected def sessionCatalog: Catalog = {
- catalog(SESSION_CATALOG_NAME).asInstanceOf[Catalog]
- }
-
protected val v2Format: String =
classOf[FakeV2ProviderWithCustomSchema].getName
protected val catalogClassName: String =
classOf[InMemoryTableSessionCatalog].getName
@@ -182,9 +178,7 @@ private [connector] trait SessionCatalogTest[T <: Table, Catalog <: TestV2Sessio
override def afterEach(): Unit = {
super.afterEach()
- sessionCatalog.checkUsage()
- sessionCatalog.clearTables()
- sessionCatalog.clearFunctions()
+ catalog(SESSION_CATALOG_NAME).asInstanceOf[Catalog].clearTables()
spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
index 366528e46ff2..c6f2da686fe9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
@@ -702,127 +702,127 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
comparePlans(df1.queryExecution.optimizedPlan, df2.queryExecution.optimizedPlan)
checkAnswer(df1, Row(3) :: Nil)
}
-}
-case object StrLenDefault extends ScalarFunction[Int] {
- override def inputTypes(): Array[DataType] = Array(StringType)
- override def resultType(): DataType = IntegerType
- override def name(): String = "strlen_default"
+ private case object StrLenDefault extends ScalarFunction[Int] {
+ override def inputTypes(): Array[DataType] = Array(StringType)
+ override def resultType(): DataType = IntegerType
+ override def name(): String = "strlen_default"
- override def produceResult(input: InternalRow): Int = {
- val s = input.getString(0)
- s.length
+ override def produceResult(input: InternalRow): Int = {
+ val s = input.getString(0)
+ s.length
+ }
}
-}
-case object StrLenMagic extends ScalarFunction[Int] {
- override def inputTypes(): Array[DataType] = Array(StringType)
- override def resultType(): DataType = IntegerType
- override def name(): String = "strlen_magic"
+ case object StrLenMagic extends ScalarFunction[Int] {
+ override def inputTypes(): Array[DataType] = Array(StringType)
+ override def resultType(): DataType = IntegerType
+ override def name(): String = "strlen_magic"
- def invoke(input: UTF8String): Int = {
- input.toString.length
+ def invoke(input: UTF8String): Int = {
+ input.toString.length
+ }
}
-}
-case object StrLenBadMagic extends ScalarFunction[Int] {
- override def inputTypes(): Array[DataType] = Array(StringType)
- override def resultType(): DataType = IntegerType
- override def name(): String = "strlen_bad_magic"
+ case object StrLenBadMagic extends ScalarFunction[Int] {
+ override def inputTypes(): Array[DataType] = Array(StringType)
+ override def resultType(): DataType = IntegerType
+ override def name(): String = "strlen_bad_magic"
- def invoke(input: String): Int = {
- input.length
+ def invoke(input: String): Int = {
+ input.length
+ }
}
-}
-case object StrLenBadMagicWithDefault extends ScalarFunction[Int] {
- override def inputTypes(): Array[DataType] = Array(StringType)
- override def resultType(): DataType = IntegerType
- override def name(): String = "strlen_bad_magic"
+ case object StrLenBadMagicWithDefault extends ScalarFunction[Int] {
+ override def inputTypes(): Array[DataType] = Array(StringType)
+ override def resultType(): DataType = IntegerType
+ override def name(): String = "strlen_bad_magic"
+
+ def invoke(input: String): Int = {
+ input.length
+ }
- def invoke(input: String): Int = {
- input.length
+ override def produceResult(input: InternalRow): Int = {
+ val s = input.getString(0)
+ s.length
+ }
}
- override def produceResult(input: InternalRow): Int = {
- val s = input.getString(0)
- s.length
+ private case object StrLenNoImpl extends ScalarFunction[Int] {
+ override def inputTypes(): Array[DataType] = Array(StringType)
+ override def resultType(): DataType = IntegerType
+ override def name(): String = "strlen_noimpl"
}
-}
-case object StrLenNoImpl extends ScalarFunction[Int] {
- override def inputTypes(): Array[DataType] = Array(StringType)
- override def resultType(): DataType = IntegerType
- override def name(): String = "strlen_noimpl"
-}
+ // input type doesn't match arguments accepted by `UnboundFunction.bind`
+ private case object StrLenBadInputTypes extends ScalarFunction[Int] {
+ override def inputTypes(): Array[DataType] = Array(StringType, IntegerType)
+ override def resultType(): DataType = IntegerType
+ override def name(): String = "strlen_bad_input_types"
+ }
-// input type doesn't match arguments accepted by `UnboundFunction.bind`
-case object StrLenBadInputTypes extends ScalarFunction[Int] {
- override def inputTypes(): Array[DataType] = Array(StringType, IntegerType)
- override def resultType(): DataType = IntegerType
- override def name(): String = "strlen_bad_input_types"
-}
+ private case object BadBoundFunction extends BoundFunction {
+ override def inputTypes(): Array[DataType] = Array(StringType)
+ override def resultType(): DataType = IntegerType
+ override def name(): String = "bad_bound_func"
+ }
-case object BadBoundFunction extends BoundFunction {
- override def inputTypes(): Array[DataType] = Array(StringType)
- override def resultType(): DataType = IntegerType
- override def name(): String = "bad_bound_func"
-}
+ object UnboundDecimalAverage extends UnboundFunction {
+ override def name(): String = "decimal_avg"
-object UnboundDecimalAverage extends UnboundFunction {
- override def name(): String = "decimal_avg"
+ override def bind(inputType: StructType): BoundFunction = {
+ if (inputType.fields.length > 1) {
+ throw new UnsupportedOperationException("Too many arguments")
+ }
- override def bind(inputType: StructType): BoundFunction = {
- if (inputType.fields.length > 1) {
- throw new UnsupportedOperationException("Too many arguments")
+ // put interval type here for testing purpose
+ inputType.fields(0).dataType match {
+ case _: NumericType | _: DayTimeIntervalType => DecimalAverage
+ case dataType =>
+ throw new UnsupportedOperationException(s"Unsupported input type:
$dataType")
+ }
}
- // put interval type here for testing purpose
- inputType.fields(0).dataType match {
- case _: NumericType | _: DayTimeIntervalType => DecimalAverage
- case dataType =>
- throw new UnsupportedOperationException(s"Unsupported input type:
$dataType")
- }
+ override def description(): String =
+ "decimal_avg: produces an average using decimal division"
}
- override def description(): String =
- "decimal_avg: produces an average using decimal division"
-}
-
-object DecimalAverage extends AggregateFunction[(Decimal, Int), Decimal] {
- override def name(): String = "decimal_avg"
- override def inputTypes(): Array[DataType] = Array(DecimalType.SYSTEM_DEFAULT)
- override def resultType(): DataType = DecimalType.SYSTEM_DEFAULT
+ object DecimalAverage extends AggregateFunction[(Decimal, Int), Decimal] {
+ override def name(): String = "decimal_avg"
+ override def inputTypes(): Array[DataType] = Array(DecimalType.SYSTEM_DEFAULT)
+ override def resultType(): DataType = DecimalType.SYSTEM_DEFAULT
- override def newAggregationState(): (Decimal, Int) = (Decimal.ZERO, 0)
+ override def newAggregationState(): (Decimal, Int) = (Decimal.ZERO, 0)
- override def update(state: (Decimal, Int), input: InternalRow): (Decimal, Int) = {
- if (input.isNullAt(0)) {
- state
- } else {
- val l = input.getDecimal(0, DecimalType.SYSTEM_DEFAULT.precision,
- DecimalType.SYSTEM_DEFAULT.scale)
- state match {
- case (_, d) if d == 0 =>
- (l, 1)
- case (total, count) =>
- (total + l, count + 1)
+ override def update(state: (Decimal, Int), input: InternalRow): (Decimal, Int) = {
+ if (input.isNullAt(0)) {
+ state
+ } else {
+ val l = input.getDecimal(0, DecimalType.SYSTEM_DEFAULT.precision,
+ DecimalType.SYSTEM_DEFAULT.scale)
+ state match {
+ case (_, d) if d == 0 =>
+ (l, 1)
+ case (total, count) =>
+ (total + l, count + 1)
+ }
}
}
- }
- override def merge(leftState: (Decimal, Int), rightState: (Decimal, Int)): (Decimal, Int) = {
- (leftState._1 + rightState._1, leftState._2 + rightState._2)
- }
+ override def merge(leftState: (Decimal, Int), rightState: (Decimal, Int)): (Decimal, Int) = {
+ (leftState._1 + rightState._1, leftState._2 + rightState._2)
+ }
- override def produceResult(state: (Decimal, Int)): Decimal = state._1 / Decimal(state._2)
-}
+ override def produceResult(state: (Decimal, Int)): Decimal = state._1 / Decimal(state._2)
+ }
-object NoImplAverage extends UnboundFunction {
- override def name(): String = "no_impl_avg"
- override def description(): String = name()
+ object NoImplAverage extends UnboundFunction {
+ override def name(): String = "no_impl_avg"
+ override def description(): String = name()
- override def bind(inputType: StructType): BoundFunction = {
- throw SparkUnsupportedOperationException()
+ override def bind(inputType: StructType): BoundFunction = {
+ throw SparkUnsupportedOperationException()
+ }
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala
index dcc49b252fdb..7463eb34d17f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.connector
-import org.apache.spark.sql.{DataFrame, Row, SaveMode}
+import org.apache.spark.sql.{DataFrame, SaveMode}
import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable, Table, TableCatalog}
class DataSourceV2SQLSessionCatalogSuite
@@ -79,11 +79,4 @@ class DataSourceV2SQLSessionCatalogSuite
assert(getTableMetadata("default.t").columns().map(_.name()) ===
Seq("c2", "c1"))
}
}
-
- test("SPARK-54760: DelegatingCatalogExtension supports both V1 and V2
functions") {
- sessionCatalog.createFunction(Identifier.of(Array("ns"), "strlen"),
StrLen(StrLenDefault))
- checkAnswer(
- sql("SELECT char_length('Hello') as v1, ns.strlen('Spark') as v2"),
- Row(5, 5))
- }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala
index ef4128c29722..6b5bd982ee5a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala
@@ -52,10 +52,6 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with
spark.sessionState.catalogManager.catalog(name).asInstanceOf[TableCatalog]
}
- protected def sessionCatalog: InMemoryTableSessionCatalog = {
- catalog(SESSION_CATALOG_NAME).asInstanceOf[InMemoryTableSessionCatalog]
- }
-
private implicit def stringToIdentifier(value: String): Identifier = {
Identifier.of(Array.empty, value)
}
@@ -69,8 +65,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with
override def afterEach(): Unit = {
super.afterEach()
- Try(sessionCatalog.checkUsage())
- Try(sessionCatalog.clearTables())
+ Try(catalog(SESSION_CATALOG_NAME).asInstanceOf[InMemoryTableSessionCatalog].clearTables())
catalog(catalogName).listTables(Array.empty).foreach(
catalog(catalogName).dropTable(_))
spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key)
@@ -151,7 +146,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with
val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", "t1")
dfw.save()
- val table = sessionCatalog.loadTable(Identifier.of(Array("default"), "t1"))
+ val table = catalog(SESSION_CATALOG_NAME).loadTable(Identifier.of(Array("default"), "t1"))
assert(table.partitioning().isEmpty, "Partitioning should be empty")
assert(table.columns() sameElements
Array(Column.create("id", LongType)), "Schema did not match")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala
index 6a82dca9cafc..2254abef3fcb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala
@@ -21,32 +21,21 @@ import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicBoolean
import scala.jdk.CollectionConverters._
-import scala.util.{Failure, Success, Try}
-import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException
import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column,
DelegatingCatalogExtension, Identifier, Table, TableCatalog}
-import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.types.StructType
/**
* A V2SessionCatalog implementation that can be extended to generate
arbitrary `Table` definitions
* for testing DDL as well as write operations (through df.write.saveAsTable,
df.write.insertInto
- * and SQL), also supports v2 function operations.
+ * and SQL).
*/
private[connector] trait TestV2SessionCatalogBase[T <: Table] extends DelegatingCatalogExtension {
protected val tables: java.util.Map[Identifier, T] = new ConcurrentHashMap[Identifier, T]()
- protected val functions: java.util.Map[Identifier, UnboundFunction] =
- new ConcurrentHashMap[Identifier, UnboundFunction]()
private val tableCreated: AtomicBoolean = new AtomicBoolean(false)
- private val funcCreated: AtomicBoolean = new AtomicBoolean(false)
-
- def checkUsage(): Unit = {
- assert(tableCreated.get || funcCreated.get,
- "Either tables or functions are not created, maybe didn't use the
session catalog code path?")
- }
private def addTable(ident: Identifier, table: T): Unit = {
tableCreated.set(true)
@@ -107,54 +96,13 @@ private[connector] trait TestV2SessionCatalogBase[T <: Table] extends Delegating
}
def clearTables(): Unit = {
+ assert(
+ tableCreated.get,
+ "Tables are not created, maybe didn't use the session catalog code
path?")
tables.keySet().asScala.foreach(super.dropTable)
tables.clear()
tableCreated.set(false)
}
-
- override def listFunctions(namespace: Array[String]): Array[Identifier] = {
- (Try(listFunctions0(namespace)), Try(super.listFunctions(namespace))) match {
- case (Success(v2), Success(v1)) => v2 ++ v1
- case (Success(v2), Failure(_)) => v2
- case (Failure(_), Success(v1)) => v1
- case (Failure(_), Failure(_)) =>
- throw new NoSuchNamespaceException(namespace)
- }
- }
-
- private def listFunctions0(namespace: Array[String]): Array[Identifier] = {
- if (namespace.isEmpty || namespaceExists(namespace)) {
- functions.keySet.asScala.filter(_.namespace.sameElements(namespace)).toArray
- } else {
- throw new NoSuchNamespaceException(namespace)
- }
- }
-
- override def loadFunction(ident: Identifier): UnboundFunction = {
- Option(functions.get(ident)) match {
- case Some(func) => func
- case _ =>
- super.loadFunction(ident)
- }
- }
-
- override def functionExists(ident: Identifier): Boolean = {
- functions.containsKey(ident) || super.functionExists(ident)
- }
-
- def createFunction(ident: Identifier, fn: UnboundFunction): UnboundFunction = {
- funcCreated.set(true)
- functions.put(ident, fn)
- }
-
- def dropFunction(ident: Identifier): Unit = {
- functions.remove(ident)
- }
-
- def clearFunctions(): Unit = {
- functions.clear()
- funcCreated.set(false)
- }
}
object TestV2SessionCatalogBase {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]