diff --git a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/NamedParameterUtils.java b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/NamedParameterUtils.java index 9a1b9ea33..8524a3e34 100644 --- a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/NamedParameterUtils.java +++ b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/NamedParameterUtils.java @@ -26,6 +26,7 @@ import java.util.Set; import java.util.TreeMap; import org.springframework.dao.InvalidDataAccessApiUsageException; +import org.springframework.lang.Nullable; import org.springframework.r2dbc.core.PreparedOperation; import org.springframework.r2dbc.core.binding.BindMarker; import org.springframework.r2dbc.core.binding.BindMarkers; @@ -436,6 +437,7 @@ abstract class NamedParameterUtils { return param; } + @Nullable List getMarker(String name) { return this.references.get(name); } @@ -499,7 +501,7 @@ abstract class NamedParameterUtils { @SuppressWarnings("unchecked") public void bind(org.springframework.r2dbc.core.binding.BindTarget target, String identifier, Object value) { - List bindMarkers = getBindMarkers(identifier); + List> bindMarkers = getBindMarkers(identifier); if (bindMarkers == null) { @@ -507,28 +509,30 @@ abstract class NamedParameterUtils { return; } - if (value instanceof Collection) { - Collection collection = (Collection) value; + for (List outer : bindMarkers) { + if (value instanceof Collection) { + Collection collection = (Collection) value; - Iterator iterator = collection.iterator(); - Iterator markers = bindMarkers.iterator(); + Iterator iterator = collection.iterator(); + Iterator markers = outer.iterator(); - while (iterator.hasNext()) { + while (iterator.hasNext()) { - Object valueToBind = iterator.next(); + Object valueToBind = iterator.next(); - if (valueToBind instanceof Object[]) { - Object[] objects = (Object[]) valueToBind; - for (Object object : objects) { - bind(target, markers, object); + if (valueToBind instanceof Object[]) { + Object[] objects = (Object[]) valueToBind; + for (Object object : objects) { + bind(target, markers, object); + } + } else { + bind(target, markers, valueToBind); } - } else { - bind(target, markers, valueToBind); } - } - } else { - for (BindMarker bindMarker : bindMarkers) { - bindMarker.bind(target, value); + } else { + for (BindMarker bindMarker : outer) { + bindMarker.bind(target, value); + } } } } @@ -547,7 +551,7 @@ abstract class NamedParameterUtils { public void bindNull(org.springframework.r2dbc.core.binding.BindTarget target, String identifier, Class valueType) { - List bindMarkers = getBindMarkers(identifier); + List> bindMarkers = getBindMarkers(identifier); if (bindMarkers == null) { @@ -555,12 +559,15 @@ abstract class NamedParameterUtils { return; } - for (BindMarker bindMarker : bindMarkers) { - bindMarker.bindNull(target, valueType); + for (List outer : bindMarkers) { + for (BindMarker bindMarker : outer) { + bindMarker.bindNull(target, valueType); + } } } - List getBindMarkers(String identifier) { + @Nullable + List> getBindMarkers(String identifier) { List parameters = this.parameters.getMarker(identifier); @@ -568,10 +575,9 @@ abstract class NamedParameterUtils { return null; } - List markers = new ArrayList<>(); - + List> markers = new ArrayList<>(); for (NamedParameters.NamedParameter parameter : parameters) { - markers.addAll(parameter.placeholders); + markers.add(new ArrayList<>(parameter.placeholders)); } return markers; @@ -582,7 +588,6 @@ abstract class NamedParameterUtils { return this.expandedSql; } - @Override public void bindTo(BindTarget target) { diff --git a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/core/NamedParameterUtilsTests.java b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/core/NamedParameterUtilsTests.java new file mode 100644 index 000000000..1a165bd6a --- /dev/null +++ b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/core/NamedParameterUtilsTests.java @@ -0,0 +1,121 @@ +/* + * Copyright 2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.core; + +import static org.assertj.core.api.Assertions.*; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import org.springframework.r2dbc.core.Parameter; +import org.springframework.r2dbc.core.PreparedOperation; +import org.springframework.r2dbc.core.binding.BindMarkersFactory; +import org.springframework.r2dbc.core.binding.BindTarget; + +/** + * Unit tests for {@link NamedParameterUtils}. + * + * @author Mark Paluch + */ +class NamedParameterUtilsTests { + + @Test // GH-1306 + void inCollectionSameParameterNameShouldBindAllAnonymousParameters() { + + ParsedSql parsedSql = NamedParameterUtils.parseSqlStatement("select :names AND :names"); + PreparedOperation operation = NamedParameterUtils.substituteNamedParameters(parsedSql, + BindMarkersFactory.anonymous("?"), + new MapBindParameterSource(Collections.singletonMap("names", Parameter.from(Arrays.asList("1", "2", "3"))))); + + List bindings = new ArrayList<>(); + + operation.bindTo(new BindingCaptor(bindings)); + + assertThat(operation.get()).isEqualTo("select ?, ?, ? AND ?, ?, ?"); + assertThat(bindings).contains("0: 1", "1: 2", "2: 3", "3: 1", "4: 2", "5: 3"); + } + + @Test // GH-1306 + void complexInCollectionSameParameterNameShouldBindAllAnonymousParameters() { + + Map parameterMap = new HashMap<>(); + parameterMap.put("names", Parameter.from(Arrays.asList("1", "2", "3"))); + parameterMap.put("hello", Parameter.from("world")); + + ParsedSql parsedSql = NamedParameterUtils.parseSqlStatement("select :names AND :hello OR :names"); + PreparedOperation operation = NamedParameterUtils.substituteNamedParameters(parsedSql, + BindMarkersFactory.anonymous("?"), new MapBindParameterSource(parameterMap)); + + List bindings = new ArrayList<>(); + + operation.bindTo(new BindingCaptor(bindings)); + + assertThat(operation.get()).isEqualTo("select ?, ?, ? AND ? OR ?, ?, ?"); + assertThat(bindings).contains("0: 1", "1: 2", "2: 3", "3: world", "4: 1", "5: 2", "6: 3"); + } + + @Test // GH-1306 + void inCollectionSameParameterNameShouldBindAllNamedParameters() { + + ParsedSql parsedSql = NamedParameterUtils.parseSqlStatement("select :names AND :names"); + PreparedOperation operation = NamedParameterUtils.substituteNamedParameters(parsedSql, + BindMarkersFactory.indexed("$", 1), + new MapBindParameterSource(Collections.singletonMap("names", Parameter.from(Arrays.asList("1", "2", "3"))))); + + List bindings = new ArrayList<>(); + + operation.bindTo(new BindingCaptor(bindings)); + + assertThat(operation.get()).isEqualTo("select $1, $2, $3 AND $1, $2, $3"); + assertThat(bindings).containsOnly("0: 1", "1: 2", "2: 3"); + } + + static class BindingCaptor implements BindTarget { + + private final List bindings; + + BindingCaptor(List bindings) { + this.bindings = bindings; + } + + @Override + public void bind(String identifier, Object value) { + bindings.add(identifier + ": " + value); + } + + @Override + public void bind(int index, Object value) { + bindings.add(index + ": " + value); + } + + @Override + public void bindNull(String identifier, Class type) { + + } + + @Override + public void bindNull(int index, Class type) { + + } + + } +}