From 6f64cfd1e5305566312bfa2125774aa4e3edfd75 Mon Sep 17 00:00:00 2001 From: Juergen Hoeller Date: Tue, 18 Oct 2022 23:04:23 +0200 Subject: [PATCH] Test square brackets with index/key expressions See gh-27925 --- .../core/namedparam/NamedParameterUtils.java | 27 ++++++++----- .../namedparam/NamedParameterUtilsTests.java | 27 ++++++++++++- .../r2dbc/core/NamedParameterUtils.java | 4 +- .../core/NamedParameterUtilsUnitTests.java | 40 +++++-------------- 4 files changed, 54 insertions(+), 44 deletions(-) diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterUtils.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterUtils.java index 1325cdedf5e..8fb2bf9d90b 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterUtils.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterUtils.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -78,9 +78,9 @@ public abstract class NamedParameterUtils { * Parse the SQL statement and locate any placeholders or named parameters. * Named parameters are substituted for a JDBC placeholder. * @param sql the SQL statement - * @return the parsed statement, represented as ParsedSql instance + * @return the parsed statement, represented as {@link ParsedSql} instance */ - public static ParsedSql parseSqlStatement(final String sql) { + public static ParsedSql parseSqlStatement(String sql) { Assert.notNull(sql, "SQL must not be null"); Set namedParameters = new HashSet<>(); @@ -122,17 +122,20 @@ public abstract class NamedParameterUtils { while (statement[j] != '}') { j++; if (j >= statement.length) { - throw new InvalidDataAccessApiUsageException("Non-terminated named parameter declaration " + - "at position " + i + " in statement: " + sql); + throw new InvalidDataAccessApiUsageException( + "Non-terminated named parameter declaration at position " + i + + " in statement: " + sql); } if (statement[j] == ':' || statement[j] == '{') { - throw new InvalidDataAccessApiUsageException("Parameter name contains invalid character '" + - statement[j] + "' at position " + i + " in statement: " + sql); + throw new InvalidDataAccessApiUsageException( + "Parameter name contains invalid character '" + statement[j] + + "' at position " + i + " in statement: " + sql); } } if (j - i > 2) { parameter = sql.substring(i + 2, j); - namedParameterCount = addNewNamedParameter(namedParameters, namedParameterCount, parameter); + namedParameterCount = addNewNamedParameter( + namedParameters, namedParameterCount, parameter); totalParameterCount = addNamedParameter( parameterList, totalParameterCount, escapes, i, j + 1, parameter); } @@ -144,7 +147,8 @@ public abstract class NamedParameterUtils { } if (j - i > 1) { parameter = sql.substring(i + 1, j); - namedParameterCount = addNewNamedParameter(namedParameters, namedParameterCount, parameter); + namedParameterCount = addNewNamedParameter( + namedParameters, namedParameterCount, parameter); totalParameterCount = addNamedParameter( parameterList, totalParameterCount, escapes, i, j, parameter); } @@ -185,8 +189,8 @@ public abstract class NamedParameterUtils { return parsedSql; } - private static int addNamedParameter( - List parameterList, int totalParameterCount, int escapes, int i, int j, String parameter) { + private static int addNamedParameter(List parameterList, + int totalParameterCount, int escapes, int i, int j, String parameter) { parameterList.add(new ParameterHolder(parameter, i - escapes, j - escapes)); totalParameterCount++; @@ -271,6 +275,7 @@ public abstract class NamedParameterUtils { if (paramNames.isEmpty()) { return originalSql; } + StringBuilder actualSql = new StringBuilder(originalSql.length()); int lastIndex = 0; for (int i = 0; i < paramNames.size(); i++) { diff --git a/spring-jdbc/src/test/java/org/springframework/jdbc/core/namedparam/NamedParameterUtilsTests.java b/spring-jdbc/src/test/java/org/springframework/jdbc/core/namedparam/NamedParameterUtilsTests.java index 7a1df2928a5..ff408b412b6 100644 --- a/spring-jdbc/src/test/java/org/springframework/jdbc/core/namedparam/NamedParameterUtilsTests.java +++ b/spring-jdbc/src/test/java/org/springframework/jdbc/core/namedparam/NamedParameterUtilsTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -319,4 +319,29 @@ public class NamedParameterUtilsTests { assertThat(psql2.getParameterNames().get(0)).isEqualTo("xxx"); } + @Test // gh-27925 + void namedParamMapReference() { + String sql = "insert into foos (id) values (:headers[id])"; + ParsedSql psql = NamedParameterUtils.parseSqlStatement(sql); + assertThat(psql.getNamedParameterCount()).isEqualTo(1); + assertThat(psql.getParameterNames()).containsExactly("headers[id]"); + + class Foo { + final Map headers = new HashMap<>(); + public Foo() { + this.headers.put("id", 1); + } + public Map getHeaders() { + return this.headers; + } + } + + Foo foo = new Foo(); + Object[] params = NamedParameterUtils.buildValueArray(psql, + new BeanPropertySqlParameterSource(foo), null); + + assertThat(params[0]).isInstanceOf(SqlParameterValue.class); + assertThat(((SqlParameterValue) params[0]).getValue()).isEqualTo(foo.getHeaders().get("id")); + } + } diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/NamedParameterUtils.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/NamedParameterUtils.java index 293d8451885..e89255b42c0 100644 --- a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/NamedParameterUtils.java +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/NamedParameterUtils.java @@ -37,8 +37,7 @@ import org.springframework.util.Assert; /** * Helper methods for named parameter parsing. * - *

Only intended for internal use within Spring's R2DBC - * framework. + *

Only intended for internal use within Spring's R2DBC framework. * *

References to the same parameter name are substituted with * the same bind marker placeholder if a {@link BindMarkersFactory} uses @@ -293,7 +292,6 @@ abstract class NamedParameterUtils { if (paramSource.hasValue(paramName)) { Object value = paramSource.getValue(paramName); if (value instanceof Collection) { - Iterator entryIter = ((Collection) value).iterator(); int k = 0; int counter = 0; diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/NamedParameterUtilsUnitTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/NamedParameterUtilsUnitTests.java index f332b222e8b..c26ba5acb6c 100644 --- a/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/NamedParameterUtilsUnitTests.java +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/NamedParameterUtilsUnitTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -43,6 +43,7 @@ public class NamedParameterUtilsUnitTests { private final BindMarkersFactory BIND_MARKERS = BindMarkersFactory.indexed("$", 1); + @Test public void shouldParseSql() { String sql = "xxx :a yyyy :b :c :a zzzzz"; @@ -145,7 +146,6 @@ public class NamedParameterUtilsUnitTests { String sql = "select 'first name' from artists where info->'stat'->'albums' = ?? :album and '[\"1\",\"2\",\"3\"]'::jsonb ?? '4'"; ParsedSql parsedSql = NamedParameterUtils.parseSqlStatement(sql); - assertThat(parsedSql.getTotalParameterCount()).isEqualTo(1); assertThat(expand(parsedSql)).isEqualTo(expectedSql); } @@ -156,7 +156,6 @@ public class NamedParameterUtilsUnitTests { String sql = "select '[\"3\", \"11\"]'::jsonb ?| '{1,3,11,12,17}'::text[]"; ParsedSql parsedSql = NamedParameterUtils.parseSqlStatement(sql); - assertThat(parsedSql.getTotalParameterCount()).isEqualTo(0); assertThat(expand(parsedSql)).isEqualTo(expectedSql); } @@ -177,7 +176,6 @@ public class NamedParameterUtilsUnitTests { String sql = "select '0\\:0' as a, foo from bar where baz < DATE(:p1 23\\:59\\:59) and baz = :p2"; ParsedSql parsedSql = NamedParameterUtils.parseSqlStatement(sql); - assertThat(parsedSql.getParameterNames()).containsExactly("p1", "p2"); assertThat(expand(parsedSql)).isEqualTo(expectedSql); } @@ -198,7 +196,6 @@ public class NamedParameterUtilsUnitTests { String sql = "select foo from bar where baz = b:{}z"; ParsedSql parsedSql = NamedParameterUtils.parseSqlStatement(sql); - assertThat(parsedSql.getParameterNames()).isEmpty(); assertThat(expand(parsedSql)).isEqualTo(expectedSql); @@ -225,13 +222,11 @@ public class NamedParameterUtilsUnitTests { String expectedSql = "xxx & yyyy"; ParsedSql parsedSql = NamedParameterUtils.parseSqlStatement(expectedSql); - assertThat(expand(parsedSql)).isEqualTo(expectedSql); } @Test public void substituteNamedParametersWithLogicalAnd() { - String expectedSql = "xxx & yyyy"; assertThat(expand(expectedSql)).isEqualTo(expectedSql); @@ -249,7 +244,6 @@ public class NamedParameterUtilsUnitTests { String sql = "SELECT ':foo'':doo', :xxx FROM DUAL"; ParsedSql psql = NamedParameterUtils.parseSqlStatement(sql); - assertThat(psql.getTotalParameterCount()).isEqualTo(1); assertThat(psql.getParameterNames()).containsExactly("xxx"); } @@ -259,7 +253,6 @@ public class NamedParameterUtilsUnitTests { String sql = "SELECT /*:doo*/':foo', :xxx FROM DUAL"; ParsedSql psql = NamedParameterUtils.parseSqlStatement(sql); - assertThat(psql.getTotalParameterCount()).isEqualTo(1); assertThat(psql.getParameterNames()).containsExactly("xxx"); } @@ -269,18 +262,23 @@ public class NamedParameterUtilsUnitTests { String sql2 = "SELECT ':foo'/*:doo*/, :xxx FROM DUAL"; ParsedSql psql2 = NamedParameterUtils.parseSqlStatement(sql2); - assertThat(psql2.getTotalParameterCount()).isEqualTo(1); assertThat(psql2.getParameterNames()).containsExactly("xxx"); } + @Test // gh-27925 + void namedParamMapReference() { + String sql = "insert into foos (id) values (:headers[id])"; + ParsedSql psql = NamedParameterUtils.parseSqlStatement(sql); + assertThat(psql.getNamedParameterCount()).isEqualTo(1); + assertThat(psql.getParameterNames()).containsExactly("headers[id]"); + } + @Test public void shouldAllowParsingMultipleUseOfParameter() { - String sql = "SELECT * FROM person where name = :id or lastname = :id"; ParsedSql parsed = NamedParameterUtils.parseSqlStatement(sql); - assertThat(parsed.getTotalParameterCount()).isEqualTo(2); assertThat(parsed.getNamedParameterCount()).isEqualTo(1); assertThat(parsed.getParameterNames()).containsExactly("id", "id"); @@ -300,23 +298,19 @@ public class NamedParameterUtilsUnitTests { "SELECT * FROM person where name = $0 or lastname = $0"); operation.bindTo(new BindTarget() { - @Override public void bind(String identifier, Object value) { throw new UnsupportedOperationException(); } - @Override public void bind(int index, Object value) { assertThat(index).isEqualTo(0); assertThat(value).isEqualTo("foo"); } - @Override public void bindNull(String identifier, Class type) { throw new UnsupportedOperationException(); } - @Override public void bindNull(int index, Class type) { throw new UnsupportedOperationException(); @@ -340,25 +334,20 @@ public class NamedParameterUtilsUnitTests { "SELECT * FROM person where name IN ($0, $1, $2) or lastname IN ($0, $1, $2)"); operation.bindTo(new BindTarget() { - @Override public void bind(String identifier, Object value) { throw new UnsupportedOperationException(); } - @Override public void bind(int index, Object value) { assertThat(index).isIn(0, 1, 2); assertThat(value).isIn("foo", "bar", "baz"); - bindings.add(index, value); } - @Override public void bindNull(String identifier, Class type) { throw new UnsupportedOperationException(); } - @Override public void bindNull(int index, Class type) { throw new UnsupportedOperationException(); @@ -386,22 +375,18 @@ public class NamedParameterUtilsUnitTests { Map bindValues = new LinkedHashMap<>(); operation.bindTo(new BindTarget() { - @Override public void bind(String identifier, Object value) { throw new UnsupportedOperationException(); } - @Override public void bind(int index, Object value) { bindValues.put(index, value); } - @Override public void bindNull(String identifier, Class type) { throw new UnsupportedOperationException(); } - @Override public void bindNull(int index, Class type) { throw new UnsupportedOperationException(); @@ -425,22 +410,18 @@ public class NamedParameterUtilsUnitTests { "SELECT * FROM person where name = $0 or lastname = $0"); operation.bindTo(new BindTarget() { - @Override public void bind(String identifier, Object value) { throw new UnsupportedOperationException(); } - @Override public void bind(int index, Object value) { throw new UnsupportedOperationException(); } - @Override public void bindNull(String identifier, Class type) { throw new UnsupportedOperationException(); } - @Override public void bindNull(int index, Class type) { assertThat(index).isEqualTo(0); @@ -449,6 +430,7 @@ public class NamedParameterUtilsUnitTests { }); } + private String expand(ParsedSql sql) { return NamedParameterUtils.substituteNamedParameters(sql, BIND_MARKERS, new MapBindParameterSource()).toQuery();