@ -16,7 +16,6 @@
package org.springframework.data.mongodb.core ;
package org.springframework.data.mongodb.core ;
import java.util.ArrayList ;
import java.util.ArrayList ;
import java.util.Collections ;
import java.util.List ;
import java.util.List ;
import java.util.Optional ;
import java.util.Optional ;
import java.util.stream.Collectors ;
import java.util.stream.Collectors ;
@ -25,8 +24,12 @@ import org.bson.Document;
import org.bson.conversions.Bson ;
import org.bson.conversions.Bson ;
import org.springframework.context.ApplicationEventPublisher ;
import org.springframework.context.ApplicationEventPublisher ;
import org.springframework.dao.DataIntegrityViolationException ;
import org.springframework.dao.DataIntegrityViolationException ;
import org.springframework.data.mapping.PersistentEntity ;
import org.springframework.data.mapping.callback.EntityCallbacks ;
import org.springframework.data.mapping.callback.EntityCallbacks ;
import org.springframework.data.mongodb.BulkOperationException ;
import org.springframework.data.mongodb.BulkOperationException ;
import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext ;
import org.springframework.data.mongodb.core.aggregation.AggregationUpdate ;
import org.springframework.data.mongodb.core.aggregation.RelaxedTypeBasedAggregationOperationContext ;
import org.springframework.data.mongodb.core.convert.QueryMapper ;
import org.springframework.data.mongodb.core.convert.QueryMapper ;
import org.springframework.data.mongodb.core.convert.UpdateMapper ;
import org.springframework.data.mongodb.core.convert.UpdateMapper ;
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity ;
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity ;
@ -133,12 +136,12 @@ class DefaultBulkOperations implements BulkOperations {
@Override
@Override
@SuppressWarnings ( "unchecked" )
@SuppressWarnings ( "unchecked" )
public BulkOperations updateOne ( Query query , Update update ) {
public BulkOperations updateOne ( Query query , UpdateDefinition update ) {
Assert . notNull ( query , "Query must not be null" ) ;
Assert . notNull ( query , "Query must not be null" ) ;
Assert . notNull ( update , "Update must not be null" ) ;
Assert . notNull ( update , "Update must not be null" ) ;
return updateOne ( Collections . singletonList ( Pair . of ( query , update ) ) ) ;
return update ( query , update , false , false ) ;
}
}
@Override
@Override
@ -155,12 +158,14 @@ class DefaultBulkOperations implements BulkOperations {
@Override
@Override
@SuppressWarnings ( "unchecked" )
@SuppressWarnings ( "unchecked" )
public BulkOperations updateMulti ( Query query , Update update ) {
public BulkOperations updateMulti ( Query query , UpdateDefinition update ) {
Assert . notNull ( query , "Query must not be null" ) ;
Assert . notNull ( query , "Query must not be null" ) ;
Assert . notNull ( update , "Update must not be null" ) ;
Assert . notNull ( update , "Update must not be null" ) ;
return updateMulti ( Collections . singletonList ( Pair . of ( query , update ) ) ) ;
update ( query , update , false , true ) ;
return this ;
}
}
@Override
@Override
@ -176,7 +181,7 @@ class DefaultBulkOperations implements BulkOperations {
}
}
@Override
@Override
public BulkOperations upsert ( Query query , Update update ) {
public BulkOperations upsert ( Query query , UpdateDefinition update ) {
return update ( query , update , true , true ) ;
return update ( query , update , true , true ) ;
}
}
@ -294,7 +299,7 @@ class DefaultBulkOperations implements BulkOperations {
maybeInvokeBeforeSaveCallback ( it . getSource ( ) , target ) ;
maybeInvokeBeforeSaveCallback ( it . getSource ( ) , target ) ;
}
}
return mapWriteModel ( it . getModel ( ) ) ;
return mapWriteModel ( it . getSource ( ) , it . get Model ( ) ) ;
}
}
/ * *
/ * *
@ -306,7 +311,7 @@ class DefaultBulkOperations implements BulkOperations {
* @param multi whether to issue a multi - update .
* @param multi whether to issue a multi - update .
* @return the { @link BulkOperations } with the update registered .
* @return the { @link BulkOperations } with the update registered .
* /
* /
private BulkOperations update ( Query query , Update update , boolean upsert , boolean multi ) {
private BulkOperations update ( Query query , UpdateDefinition update , boolean upsert , boolean multi ) {
Assert . notNull ( query , "Query must not be null" ) ;
Assert . notNull ( query , "Query must not be null" ) ;
Assert . notNull ( update , "Update must not be null" ) ;
Assert . notNull ( update , "Update must not be null" ) ;
@ -322,11 +327,16 @@ class DefaultBulkOperations implements BulkOperations {
return this ;
return this ;
}
}
private WriteModel < Document > mapWriteModel ( WriteModel < Document > writeModel ) {
private WriteModel < Document > mapWriteModel ( Object source , WriteModel < Document > writeModel ) {
if ( writeModel instanceof UpdateOneModel ) {
if ( writeModel instanceof UpdateOneModel ) {
UpdateOneModel < Document > model = ( UpdateOneModel < Document > ) writeModel ;
UpdateOneModel < Document > model = ( UpdateOneModel < Document > ) writeModel ;
if ( source instanceof AggregationUpdate aggregationUpdate ) {
List < Document > pipeline = mapUpdatePipeline ( aggregationUpdate ) ;
return new UpdateOneModel < > ( getMappedQuery ( model . getFilter ( ) ) , pipeline , model . getOptions ( ) ) ;
}
return new UpdateOneModel < > ( getMappedQuery ( model . getFilter ( ) ) , getMappedUpdate ( model . getUpdate ( ) ) ,
return new UpdateOneModel < > ( getMappedQuery ( model . getFilter ( ) ) , getMappedUpdate ( model . getUpdate ( ) ) ,
model . getOptions ( ) ) ;
model . getOptions ( ) ) ;
@ -335,6 +345,11 @@ class DefaultBulkOperations implements BulkOperations {
if ( writeModel instanceof UpdateManyModel ) {
if ( writeModel instanceof UpdateManyModel ) {
UpdateManyModel < Document > model = ( UpdateManyModel < Document > ) writeModel ;
UpdateManyModel < Document > model = ( UpdateManyModel < Document > ) writeModel ;
if ( source instanceof AggregationUpdate aggregationUpdate ) {
List < Document > pipeline = mapUpdatePipeline ( aggregationUpdate ) ;
return new UpdateManyModel < > ( getMappedQuery ( model . getFilter ( ) ) , pipeline , model . getOptions ( ) ) ;
}
return new UpdateManyModel < > ( getMappedQuery ( model . getFilter ( ) ) , getMappedUpdate ( model . getUpdate ( ) ) ,
return new UpdateManyModel < > ( getMappedQuery ( model . getFilter ( ) ) , getMappedUpdate ( model . getUpdate ( ) ) ,
model . getOptions ( ) ) ;
model . getOptions ( ) ) ;
@ -357,6 +372,19 @@ class DefaultBulkOperations implements BulkOperations {
return writeModel ;
return writeModel ;
}
}
private List < Document > mapUpdatePipeline ( AggregationUpdate source ) {
Class < ? > type = bulkOperationContext . getEntity ( ) . isPresent ( )
? bulkOperationContext . getEntity ( ) . map ( PersistentEntity : : getType ) . get ( )
: Object . class ;
AggregationOperationContext context = new RelaxedTypeBasedAggregationOperationContext ( type ,
bulkOperationContext . getUpdateMapper ( ) . getMappingContext ( ) , bulkOperationContext . getQueryMapper ( ) ) ;
List < Document > pipeline = new AggregationUtil ( bulkOperationContext . getQueryMapper ( ) ,
bulkOperationContext . getQueryMapper ( ) . getMappingContext ( ) ) . createPipeline ( source ,
context ) ;
return pipeline ;
}
private Bson getMappedUpdate ( Bson update ) {
private Bson getMappedUpdate ( Bson update ) {
return bulkOperationContext . getUpdateMapper ( ) . getMappedObject ( update , bulkOperationContext . getEntity ( ) ) ;
return bulkOperationContext . getUpdateMapper ( ) . getMappedObject ( update , bulkOperationContext . getEntity ( ) ) ;
}
}