diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 80e2c1986d75801e823ae5b43da36de10f6ac664..27705520505a28439e8164dc7f29bf589dc0286d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -457,6 +457,18 @@ object SQLConf { .booleanConf .createWithDefault(true) + val VARIABLE_SUBSTITUTE_ENABLED = + SQLConfigBuilder("spark.sql.variable.substitute") + .doc("This enables substitution using syntax like ${var} ${system:var} and ${env:var}.") + .booleanConf + .createWithDefault(true) + + val VARIABLE_SUBSTITUTE_DEPTH = + SQLConfigBuilder("spark.sql.variable.substitute.depth") + .doc("The maximum replacements the substitution engine will do.") + .intConf + .createWithDefault(40) + // TODO: This is still WIP and shouldn't be turned on without extensive test coverage val COLUMNAR_AGGREGATE_MAP_ENABLED = SQLConfigBuilder("spark.sql.codegen.aggregate.map.enabled") .internal() @@ -615,6 +627,10 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def columnarAggregateMapEnabled: Boolean = getConf(COLUMNAR_AGGREGATE_MAP_ENABLED) + def variableSubstituteEnabled: Boolean = getConf(VARIABLE_SUBSTITUTE_ENABLED) + + def variableSubstituteDepth: Int = getConf(VARIABLE_SUBSTITUTE_DEPTH) + override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL) override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/VariableSubstitution.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/VariableSubstitution.scala new file mode 100644 index 0000000000000000000000000000000000000000..0982f1d687161359f06ea11438b90dd9076a9eaf --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/VariableSubstitution.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import java.util.regex.Pattern + +import org.apache.spark.sql.AnalysisException + +/** + * A helper class that enables substitution using syntax like + * `${var}`, `${system:var}` and `${env:var}`. + * + * Variable substitution is controlled by [[SQLConf.variableSubstituteEnabled]]. + */ +class VariableSubstitution(conf: SQLConf) { + + private val pattern = Pattern.compile("\\$\\{[^\\}\\$ ]+\\}") + + /** + * Given a query, does variable substitution and return the result. + */ + def substitute(input: String): String = { + // Note that this function is mostly copied from Hive's SystemVariables, so the style is + // very Java/Hive like. + if (input eq null) { + return null + } + + if (!conf.variableSubstituteEnabled) { + return input + } + + var eval = input + val depth = conf.variableSubstituteDepth + val builder = new StringBuilder + val m = pattern.matcher("") + + var s = 0 + while (s <= depth) { + m.reset(eval) + builder.setLength(0) + + var prev = 0 + var found = false + while (m.find(prev)) { + val group = m.group() + var substitute = substituteVariable(group.substring(2, group.length - 1)) + if (substitute.isEmpty) { + substitute = group + } else { + found = true + } + builder.append(eval.substring(prev, m.start())).append(substitute) + prev = m.end() + } + + if (!found) { + return eval + } + + builder.append(eval.substring(prev)) + eval = builder.toString + s += 1 + } + + if (s > depth) { + throw new AnalysisException( + "Variable substitution depth is deeper than " + depth + " for input " + input) + } else { + return eval + } + } + + /** + * Given a variable, replaces with the substitute value (default to ""). + */ + private def substituteVariable(variable: String): String = { + var value: String = null + + if (variable.startsWith("system:")) { + value = System.getProperty(variable.substring("system:".length())) + } + + if (value == null && variable.startsWith("env:")) { + value = System.getenv(variable.substring("env:".length())) + } + + if (value == null && conf != null && variable.startsWith("hiveconf:")) { + value = conf.getConfString(variable.substring("hiveconf:".length()), "") + } + + if (value == null && conf != null && variable.startsWith("sparkconf:")) { + value = conf.getConfString(variable.substring("sparkconf:".length()), "") + } + + if (value == null && conf != null && variable.startsWith("spark:")) { + value = conf.getConfString(variable.substring("spark:".length()), "") + } + + if (value == null && conf != null) { + value = conf.getConfString(variable, "") + } + + value + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/VariableSubstitutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/VariableSubstitutionSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..deac95918bba548eb4467a291f955a29a39fe88c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/VariableSubstitutionSuite.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException + +class VariableSubstitutionSuite extends SparkFunSuite { + + private lazy val conf = new SQLConf + private lazy val sub = new VariableSubstitution(conf) + + test("system property") { + System.setProperty("varSubSuite.var", "abcd") + assert(sub.substitute("${system:varSubSuite.var}") == "abcd") + } + + test("environmental variables") { + assert(sub.substitute("${env:SPARK_TESTING}") == "1") + } + + test("Spark configuration variable") { + conf.setConfString("some-random-string-abcd", "1234abcd") + assert(sub.substitute("${hiveconf:some-random-string-abcd}") == "1234abcd") + assert(sub.substitute("${sparkconf:some-random-string-abcd}") == "1234abcd") + assert(sub.substitute("${spark:some-random-string-abcd}") == "1234abcd") + assert(sub.substitute("${some-random-string-abcd}") == "1234abcd") + } + + test("multiple substitutes") { + val q = "select ${bar} ${foo} ${doo} this is great" + conf.setConfString("bar", "1") + conf.setConfString("foo", "2") + conf.setConfString("doo", "3") + assert(sub.substitute(q) == "select 1 2 3 this is great") + } + + test("test nested substitutes") { + val q = "select ${bar} ${foo} this is great" + conf.setConfString("bar", "1") + conf.setConfString("foo", "${bar}") + assert(sub.substitute(q) == "select 1 1 this is great") + } + + test("depth limit") { + val q = "select ${bar} ${foo} ${doo}" + conf.setConfString(SQLConf.VARIABLE_SUBSTITUTE_DEPTH.key, "2") + + // This should be OK since it is not nested. + conf.setConfString("bar", "1") + conf.setConfString("foo", "2") + conf.setConfString("doo", "3") + assert(sub.substitute(q) == "select 1 2 3") + + // This should not be OK since it is nested in 3 levels. + conf.setConfString("bar", "1") + conf.setConfString("foo", "${bar}") + conf.setConfString("doo", "${foo}") + intercept[AnalysisException] { + sub.substitute(q) + } + } +}