diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/datasource/init/ScriptUtils.java b/spring-jdbc/src/main/java/org/springframework/jdbc/datasource/init/ScriptUtils.java index fa87e08f9d5..271eff02109 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/datasource/init/ScriptUtils.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/datasource/init/ScriptUtils.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -414,12 +414,18 @@ public abstract class ScriptUtils { } /** - * Does the provided SQL script contain the specified delimiter? - * @param script the SQL script - * @param delim the string delimiting each statement - typically a ';' character + * Determine if the provided SQL script contains the specified delimiter. + *
This method is intended to be used to find the string delimiting each + * SQL statement — for example, a ';' character. + *
Any occurrence of the delimiter within the script will be ignored if it + * is enclosed within single quotes ({@code '}) or double quotes ({@code "}) + * or if it is escaped with a backslash ({@code \}). + * @param script the SQL script to search within + * @param delimiter the delimiter to search for */ - public static boolean containsSqlScriptDelimiters(String script, String delim) { - boolean inLiteral = false; + public static boolean containsSqlScriptDelimiters(String script, String delimiter) { + boolean inSingleQuote = false; + boolean inDoubleQuote = false; boolean inEscape = false; for (int i = 0; i < script.length(); i++) { @@ -433,11 +439,16 @@ public abstract class ScriptUtils { inEscape = true; continue; } - if (c == '\'') { - inLiteral = !inLiteral; + if (!inDoubleQuote && (c == '\'')) { + inSingleQuote = !inSingleQuote; } - if (!inLiteral && script.startsWith(delim, i)) { - return true; + else if (!inSingleQuote && (c == '"')) { + inDoubleQuote = !inDoubleQuote; + } + if (!inSingleQuote && !inDoubleQuote) { + if (script.startsWith(delimiter, i)) { + return true; + } } } diff --git a/spring-jdbc/src/test/java/org/springframework/jdbc/datasource/init/ScriptUtilsUnitTests.java b/spring-jdbc/src/test/java/org/springframework/jdbc/datasource/init/ScriptUtilsUnitTests.java index f57adece236..5260c46f3e5 100644 --- a/spring-jdbc/src/test/java/org/springframework/jdbc/datasource/init/ScriptUtilsUnitTests.java +++ b/spring-jdbc/src/test/java/org/springframework/jdbc/datasource/init/ScriptUtilsUnitTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,8 @@ import java.util.ArrayList; import java.util.List; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.support.EncodedResource; @@ -165,17 +167,25 @@ public class ScriptUtilsUnitTests { assertThat(statements).containsExactly(statement1, statement2); } - @Test - public void containsDelimiters() { - assertThat(containsSqlScriptDelimiters("select 1\n select ';'", ";")).isFalse(); - assertThat(containsSqlScriptDelimiters("select 1; select 2", ";")).isTrue(); - assertThat(containsSqlScriptDelimiters("select 1; select '\\n\n';", "\n")).isFalse(); - assertThat(containsSqlScriptDelimiters("select 1\n select 2", "\n")).isTrue(); - assertThat(containsSqlScriptDelimiters("select 1\n select 2", "\n\n")).isFalse(); - assertThat(containsSqlScriptDelimiters("select 1\n\n select 2", "\n\n")).isTrue(); - // MySQL style escapes '\\' - assertThat(containsSqlScriptDelimiters("insert into users(first_name, last_name)\nvalues('a\\\\', 'b;')", ";")).isFalse(); - assertThat(containsSqlScriptDelimiters("insert into users(first_name, last_name)\nvalues('Charles', 'd\\'Artagnan'); select 1;", ";")).isTrue(); + @ParameterizedTest + @CsvSource(delimiter = '#', value = { + // semicolon + "'select 1\n select '';''' # ; # false", + "'select 1\n select \";\"' # ; # false", + "'select 1; select 2' # ; # true", + // newline + "'select 1; select ''\n''' # '\n' # false", + "'select 1; select \"\n\"' # '\n' # false", + "'select 1\n select 2' # '\n' # true", + // double newline + "'select 1\n select 2' # '\n\n' # false", + "'select 1\n\n select 2' # '\n\n' # true", + // semicolon with MySQL style escapes '\\' + "'insert into users(first, last)\nvalues(''a\\\\'', ''b;'')' # ; # false", + "'insert into users(first, last)\nvalues(''Charles'', ''d\\''Artagnan''); select 1' # ; # true" + }) + public void containsDelimiter(String script, String delimiter, boolean expected) { + assertThat(containsSqlScriptDelimiters(script, delimiter)).isEqualTo(expected); } private String readScript(String path) throws Exception { diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/ScriptUtils.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/ScriptUtils.java index d914ffb4e44..6cd2b6a4900 100644 --- a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/ScriptUtils.java +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/ScriptUtils.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -432,12 +432,18 @@ public abstract class ScriptUtils { } /** - * Does the provided SQL script contain the specified delimiter? - * @param script the SQL script - * @param delim the string delimiting each statement - typically a ';' character + * Determine if the provided SQL script contains the specified delimiter. + *
This method is intended to be used to find the string delimiting each + * SQL statement — for example, a ';' character. + *
Any occurrence of the delimiter within the script will be ignored if it + * is enclosed within single quotes ({@code '}) or double quotes ({@code "}) + * or if it is escaped with a backslash ({@code \}). + * @param script the SQL script to search within + * @param delimiter the delimiter to search for */ - public static boolean containsSqlScriptDelimiters(String script, String delim) { - boolean inLiteral = false; + public static boolean containsSqlScriptDelimiters(String script, String delimiter) { + boolean inSingleQuote = false; + boolean inDoubleQuote = false; boolean inEscape = false; for (int i = 0; i < script.length(); i++) { @@ -451,11 +457,16 @@ public abstract class ScriptUtils { inEscape = true; continue; } - if (c == '\'') { - inLiteral = !inLiteral; + if (!inDoubleQuote && (c == '\'')) { + inSingleQuote = !inSingleQuote; } - if (!inLiteral && script.startsWith(delim, i)) { - return true; + else if (!inSingleQuote && (c == '"')) { + inDoubleQuote = !inDoubleQuote; + } + if (!inSingleQuote && !inDoubleQuote) { + if (script.startsWith(delimiter, i)) { + return true; + } } } diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/init/ScriptUtilsUnitTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/init/ScriptUtilsUnitTests.java index 79c14178d44..a1712f6ff92 100644 --- a/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/init/ScriptUtilsUnitTests.java +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/init/ScriptUtilsUnitTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,12 +21,15 @@ import java.util.List; import org.assertj.core.util.Strings; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.core.io.support.EncodedResource; import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.r2dbc.connection.init.ScriptUtils.containsSqlScriptDelimiters; /** * Unit tests for {@link ScriptUtils}. @@ -184,30 +187,25 @@ public class ScriptUtilsUnitTests { assertThat(statements).hasSize(2).containsSequence(statement1, statement2); } - @Test - public void containsDelimiters() { - assertThat(ScriptUtils.containsSqlScriptDelimiters("select 1\n select ';'", - ";")).isFalse(); - assertThat(ScriptUtils.containsSqlScriptDelimiters("select 1; select 2", - ";")).isTrue(); - - assertThat(ScriptUtils.containsSqlScriptDelimiters("select 1; select '\\n\n';", - "\n")).isFalse(); - assertThat(ScriptUtils.containsSqlScriptDelimiters("select 1\n select 2", - "\n")).isTrue(); - - assertThat(ScriptUtils.containsSqlScriptDelimiters("select 1\n select 2", - "\n\n")).isFalse(); - assertThat(ScriptUtils.containsSqlScriptDelimiters("select 1\n\n select 2", - "\n\n")).isTrue(); - - // MySQL style escapes '\\' - assertThat(ScriptUtils.containsSqlScriptDelimiters( - "insert into users(first_name, last_name)\nvalues('a\\\\', 'b;')", - ";")).isFalse(); - assertThat(ScriptUtils.containsSqlScriptDelimiters( - "insert into users(first_name, last_name)\nvalues('Charles', 'd\\'Artagnan'); select 1;", - ";")).isTrue(); + @ParameterizedTest + @CsvSource(delimiter = '#', value = { + // semicolon + "'select 1\n select '';''' # ; # false", + "'select 1\n select \";\"' # ; # false", + "'select 1; select 2' # ; # true", + // newline + "'select 1; select ''\n''' # '\n' # false", + "'select 1; select \"\n\"' # '\n' # false", + "'select 1\n select 2' # '\n' # true", + // double newline + "'select 1\n select 2' # '\n\n' # false", + "'select 1\n\n select 2' # '\n\n' # true", + // semicolon with MySQL style escapes '\\' + "'insert into users(first, last)\nvalues(''a\\\\'', ''b;'')' # ; # false", + "'insert into users(first, last)\nvalues(''Charles'', ''d\\''Artagnan''); select 1' # ; # true" + }) + public void containsDelimiter(String script, String delimiter, boolean expected) { + assertThat(containsSqlScriptDelimiters(script, delimiter)).isEqualTo(expected); } private String readScript(String path) {