[SPARK-50104][CONNECT] Support SparkSession.executeCommand in Connect
### What changes were proposed in this pull request?
This PR adds support for `SparkSession.executeCommand(..)` to the Scala Spark Connect client.

### Why are the changes needed?
This reduces friction between the classic and Connect implementations of the Scala SQL interface.

### Does this PR introduce _any_ user-facing change?
Yes. `SparkSession.executeCommand(..)` now works in the Connect Scala client.
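
For illustration, here is a minimal usage sketch against the Connect Scala client. The runner class name, command string, and options below are hypothetical; any class implementing `ExternalCommandRunner` that is available on the server classpath should work the same way.

```scala
// Minimal sketch, assuming "com.example.MyCommandRunner" is a hypothetical class on the
// server classpath that implements org.apache.spark.sql.connector.ExternalCommandRunner.
val result = spark.executeCommand(
  "com.example.MyCommandRunner", // fully qualified runner class name
  "refresh-cache",               // opaque command string handed to the runner
  Map("timeout" -> "30s"))       // runner-specific options

// The runner's string output is returned as a DataFrame with a single string column.
result.show()
```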

### How was this patch tested?
I have added a test case to `SparkSessionE2ESuite`.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #49774 from hvanhovell/SPARK-50104.

Authored-by: Herman van Hovell <herman@databricks.com>
Signed-off-by: Herman van Hovell <herman@databricks.com>
hvanhovell committed Feb 4, 2025
1 parent e0a7db2 commit ebcec7c
Showing 14 changed files with 300 additions and 152 deletions.
5 changes: 0 additions & 5 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -5386,11 +5386,6 @@
"Invoking SparkSession 'baseRelationToDataFrame'. This is server side developer API"
]
},
"SESSION_EXECUTE_COMMAND" : {
"message" : [
"Invoking SparkSession 'executeCommand'."
]
},
"SESSION_EXPERIMENTAL_METHODS" : {
"message" : [
"Access to SparkSession Experimental (methods). This is server side developer API"
206 changes: 106 additions & 100 deletions python/pyspark/sql/connect/proto/commands_pb2.py

Large diffs are not rendered by default.

57 changes: 57 additions & 0 deletions python/pyspark/sql/connect/proto/commands_pb2.pyi
@@ -106,6 +106,7 @@ class Command(google.protobuf.message.Message):
REMOVE_CACHED_REMOTE_RELATION_COMMAND_FIELD_NUMBER: builtins.int
MERGE_INTO_TABLE_COMMAND_FIELD_NUMBER: builtins.int
ML_COMMAND_FIELD_NUMBER: builtins.int
EXECUTE_EXTERNAL_COMMAND_FIELD_NUMBER: builtins.int
EXTENSION_FIELD_NUMBER: builtins.int
@property
def register_function(
@@ -150,6 +151,8 @@
@property
def ml_command(self) -> pyspark.sql.connect.proto.ml_pb2.MlCommand: ...
@property
def execute_external_command(self) -> global___ExecuteExternalCommand: ...
@property
def extension(self) -> google.protobuf.any_pb2.Any:
"""This field is used to mark extensions to the protocol. When plugins generate arbitrary
Commands they can add them here. During the planning the correct resolution is done.
@@ -179,6 +182,7 @@
| None = ...,
merge_into_table_command: global___MergeIntoTableCommand | None = ...,
ml_command: pyspark.sql.connect.proto.ml_pb2.MlCommand | None = ...,
execute_external_command: global___ExecuteExternalCommand | None = ...,
extension: google.protobuf.any_pb2.Any | None = ...,
) -> None: ...
def HasField(
@@ -192,6 +196,8 @@
b"create_dataframe_view",
"create_resource_profile_command",
b"create_resource_profile_command",
"execute_external_command",
b"execute_external_command",
"extension",
b"extension",
"get_resources_command",
@@ -235,6 +241,8 @@
b"create_dataframe_view",
"create_resource_profile_command",
b"create_resource_profile_command",
"execute_external_command",
b"execute_external_command",
"extension",
b"extension",
"get_resources_command",
@@ -288,6 +296,7 @@
"remove_cached_remote_relation_command",
"merge_into_table_command",
"ml_command",
"execute_external_command",
"extension",
]
| None
@@ -2339,3 +2348,51 @@
) -> None: ...

global___MergeIntoTableCommand = MergeIntoTableCommand

class ExecuteExternalCommand(google.protobuf.message.Message):
"""Execute an arbitrary string command inside an external execution engine"""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

class OptionsEntry(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
key: builtins.str
value: builtins.str
def __init__(
self,
*,
key: builtins.str = ...,
value: builtins.str = ...,
) -> None: ...
def ClearField(
self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"]
) -> None: ...

RUNNER_FIELD_NUMBER: builtins.int
COMMAND_FIELD_NUMBER: builtins.int
OPTIONS_FIELD_NUMBER: builtins.int
runner: builtins.str
"""(Required) The class name of the runner that implements `ExternalCommandRunner`"""
command: builtins.str
"""(Required) The target command to be executed."""
@property
def options(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]:
"""(Optional) The options for the runner."""
def __init__(
self,
*,
runner: builtins.str = ...,
command: builtins.str = ...,
options: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
) -> None: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"command", b"command", "options", b"options", "runner", b"runner"
],
) -> None: ...

global___ExecuteExternalCommand = ExecuteExternalCommand
@@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.sources

import scala.jdk.CollectionConverters._

import org.apache.spark.sql.connector.ExternalCommandRunner
import org.apache.spark.sql.util.CaseInsensitiveStringMap

class FakeCommandRunner extends ExternalCommandRunner {

override def executeCommand(command: String, options: CaseInsensitiveStringMap): Array[String] = {
System.setProperty("command", command)
options.keySet().iterator().asScala.toSeq.sorted.toArray
}
}
@@ -26,6 +26,7 @@ import scala.util.{Failure, Success}
import org.scalatest.concurrent.Eventually._

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession}
import org.apache.spark.util.SparkThreadUtils.awaitResult

@@ -440,4 +441,11 @@ class SparkSessionE2ESuite extends ConnectFunSuite with RemoteSparkSession {
session.stop()
}

test("executeCommand") {
val df = spark.executeCommand(
"org.apache.spark.sql.sources.FakeCommandRunner",
"command",
Map("one" -> "1", "two" -> "2"))
assert(df.as(StringEncoder).collect().toSet == Set("one", "two"))
}
}
@@ -86,10 +86,6 @@ class UnsupportedFeaturesSuite extends ConnectFunSuite {
_.baseRelationToDataFrame(new BaseRelation)
}

testUnsupportedFeature("SparkSession.executeCommand", "SESSION_EXECUTE_COMMAND") {
_.executeCommand("ds", "exec", Map.empty)
}

testUnsupportedFeature("Dataset.queryExecution", "DATASET_QUERY_EXECUTION") {
_.range(1).queryExecution
}
@@ -170,7 +170,8 @@ object SparkConnectServerUtils {
.filter { e: String =>
val fileName = e.substring(e.lastIndexOf(File.separatorChar) + 1)
fileName.endsWith(".jar") &&
(fileName.startsWith("scalatest") || fileName.startsWith("scalactic"))
(fileName.startsWith("scalatest") || fileName.startsWith("scalactic") ||
(fileName.startsWith("spark-catalyst") && fileName.endsWith("-tests")))
}
.map(e => Paths.get(e).toUri)
spark.client.artifactManager.addArtifacts(jars.toImmutableArraySeq)
14 changes: 14 additions & 0 deletions sql/connect/common/src/main/protobuf/spark/connect/commands.proto
@@ -50,6 +50,8 @@ message Command {
RemoveCachedRemoteRelationCommand remove_cached_remote_relation_command = 15;
MergeIntoTableCommand merge_into_table_command = 16;
MlCommand ml_command = 17;
ExecuteExternalCommand execute_external_command = 18;

// This field is used to mark extensions to the protocol. When plugins generate arbitrary
// Commands they can add them here. During the planning the correct resolution is done.
google.protobuf.Any extension = 999;
@@ -535,3 +537,15 @@ message MergeIntoTableCommand {
// (Required) Whether to enable schema evolution.
bool with_schema_evolution = 7;
}

// Execute an arbitrary string command inside an external execution engine
message ExecuteExternalCommand {
// (Required) The class name of the runner that implements `ExternalCommandRunner`
string runner = 1;

// (Required) The target command to be executed.
string command = 2;

// (Optional) The options for the runner.
map<string, string> options = 3;
}
@@ -33,9 +33,6 @@ private[sql] object ConnectClientUnsupportedErrors {
def queryExecution(): SparkUnsupportedOperationException =
unsupportedFeatureException("DATASET_QUERY_EXECUTION")

def executeCommand(): SparkUnsupportedOperationException =
unsupportedFeatureException("SESSION_EXECUTE_COMMAND")

def baseRelationToDataFrame(): SparkUnsupportedOperationException =
unsupportedFeatureException("SESSION_BASE_RELATION_TO_DATAFRAME")

@@ -54,9 +51,6 @@ private[sql] object ConnectClientUnsupportedErrors {
def sparkContext(): SparkUnsupportedOperationException =
unsupportedFeatureException("SESSION_SPARK_CONTEXT")

def sqlContext(): SparkUnsupportedOperationException =
unsupportedFeatureException("SESSION_SQL_CONTEXT")

def registerUdaf(): SparkUnsupportedOperationException =
unsupportedFeatureException("REGISTER_UDAF")
}
@@ -204,8 +204,14 @@ class SparkSession private[sql] (
override def executeCommand(
runner: String,
command: String,
options: Map[String, String]): DataFrame =
throw ConnectClientUnsupportedErrors.executeCommand()
options: Map[String, String]): DataFrame = {
executeCommandWithDataFrameReturn(newCommand { builder =>
builder.getExecuteExternalCommandBuilder
.setRunner(runner)
.setCommand(command)
.putAllOptions(options.asJava)
})
}

/** @inheritdoc */
def sql(sqlText: String, args: Array[_]): DataFrame = {
@@ -237,18 +243,21 @@
sql(query, Array.empty)
}

private def sql(sqlCommand: proto.SqlCommand): DataFrame = newDataFrame { builder =>
private def sql(sqlCommand: proto.SqlCommand): DataFrame = {
// Send the SQL once to the server and then check the output.
val cmd = newCommand(b => b.setSqlCommand(sqlCommand))
val plan = proto.Plan.newBuilder().setCommand(cmd)
executeCommandWithDataFrameReturn(newCommand(_.setSqlCommand(sqlCommand)))
}

private def executeCommandWithDataFrameReturn(command: proto.Command): DataFrame = {
val plan = proto.Plan.newBuilder().setCommand(command)
val responseIter = client.execute(plan.build())

try {
val response = responseIter
.find(_.hasSqlCommandResult)
.getOrElse(throw new RuntimeException("SQLCommandResult must be present"))
// Update the builder with the values from the result.
builder.mergeFrom(response.getSqlCommandResult.getRelation)
newDataFrame(_.mergeFrom(response.getSqlCommandResult.getRelation))
} finally {
// consume the rest of the iterator
responseIter.foreach(_ => ())
@@ -47,7 +47,7 @@ import org.apache.spark.sql.{Column, Encoders, ForeachWriter, Observation, Row}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, QueryPlanningTracker}
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, GlobalTempView, LazyExpression, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedPlanId, UnresolvedRegex, UnresolvedRelation, UnresolvedStar, UnresolvedStarWithColumns, UnresolvedStarWithColumnsRenames, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction, UnresolvedTranspose}
import org.apache.spark.sql.catalyst.encoders.{encoderFor, AgnosticEncoder, ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{StringEncoder, UnboundRowEncoder}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
Expand All @@ -60,6 +60,7 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
import org.apache.spark.sql.classic.{Catalog, Dataset, MergeIntoWriter, RelationalGroupedDataset, SparkSession, TypedAggUtils, UserDefinedFunctionUtils}
import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, InvalidCommandInput, InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter, StreamingListenerPacket, UdfPacket}
import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
import org.apache.spark.sql.connect.ml.MLHandler
@@ -70,7 +71,7 @@ import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, TypedAggregateExpression}
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.execution.command.CreateViewCommand
import org.apache.spark.sql.execution.command.{CreateViewCommand, ExternalCommandExecutor}
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRelation}
import org.apache.spark.sql.execution.datasources.v2.python.UserDefinedPythonDataSource
@@ -81,7 +82,7 @@ import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper
import org.apache.spark.sql.expressions.{Aggregator, ReduceAggregator, SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction}
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StreamingQuery, StreamingQueryListener, StreamingQueryProgress, Trigger}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.sql.util.{ArrowUtils, CaseInsensitiveStringMap}
import org.apache.spark.storage.CacheId
import org.apache.spark.util.ArrayImplicits._
import org.apache.spark.util.Utils
@@ -2506,6 +2507,8 @@ class SparkConnectPlanner(
handleMergeIntoTableCommand(command.getMergeIntoTableCommand)
case proto.Command.CommandTypeCase.ML_COMMAND =>
handleMlCommand(command.getMlCommand, responseObserver)
case proto.Command.CommandTypeCase.EXECUTE_EXTERNAL_COMMAND =>
handleExecuteExternalCommand(command.getExecuteExternalCommand, responseObserver)

case _ => throw new UnsupportedOperationException(s"$command not supported.")
}
@@ -2525,6 +2528,40 @@
.build())
}

private def handleExecuteExternalCommand(
command: proto.ExecuteExternalCommand,
responseObserver: StreamObserver[proto.ExecutePlanResponse]): Unit = {
if (command.getRunner.isEmpty) {
throw InvalidPlanInput("runner cannot be empty in executeExternalCommand")
}
if (command.getCommand.isEmpty) {
throw InvalidPlanInput("command cannot be empty in executeExternalCommand")
}
val executor = ExternalCommandExecutor(
session,
command.getRunner,
command.getCommand,
command.getOptionsMap.asScala.toMap)
val result = executor.execute()
executeHolder.eventsManager.postFinished(Some(result.size))

// Return a SQLCommandResult that contains a LocalRelation.
val arrowData = ArrowSerializer.serialize(
result.iterator,
StringEncoder,
ArrowUtils.rootAllocator,
session.sessionState.conf.sessionLocalTimeZone)
val sqlCommandResult = SqlCommandResult.newBuilder()
sqlCommandResult.getRelationBuilder.getLocalRelationBuilder.setData(arrowData)
responseObserver.onNext(
ExecutePlanResponse
.newBuilder()
.setSessionId(sessionId)
.setServerSideSessionId(sessionHolder.serverSessionId)
.setSqlCommandResult(sqlCommandResult)
.build())
}

private def handleSqlCommand(
command: SqlCommand,
responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
@@ -49,11 +49,10 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.classic.SparkSession.applyAndLoadExtensions
import org.apache.spark.sql.connector.ExternalCommandRunner
import org.apache.spark.sql.errors.{QueryCompilationErrors, SqlScriptingErrors}
import org.apache.spark.sql.errors.SqlScriptingErrors
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.command.ExternalCommandExecutor
import org.apache.spark.sql.execution.datasources.{DataSource, LogicalRelation}
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.internal._
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
@@ -583,15 +582,7 @@ class SparkSession private(
*/
@Unstable
def executeCommand(runner: String, command: String, options: Map[String, String]): DataFrame = {
DataSource.lookupDataSource(runner, sessionState.conf) match {
case source if classOf[ExternalCommandRunner].isAssignableFrom(source) =>
Dataset.ofRows(self, ExternalCommandExecutor(
source.getDeclaredConstructor().newInstance()
.asInstanceOf[ExternalCommandRunner], command, options))

case _ =>
throw QueryCompilationErrors.commandExecutionInRunnerUnsupportedError(runner)
}
Dataset.ofRows(self, ExternalCommandExecutor(this, runner, command, options))
}

/** @inheritdoc */