有两个问题:第一,这些收集器不具有相同的返回类型.即使您使用DoubleStream或LongStream,您也会更接近,但您仍然不会得到完全相同的返回类型.因此,我们需要做一些额外的工作来使它们返回相同的东西.例如,这可以是一个选项:
private double collectValues(AggregatorType agg, List<Long> values) {
DoubleStream stream = values.stream().mapToDouble(x -> x + 0.0d);
return switch (agg) {
case AVG -> stream.average().orElseThrow();
case SUM -> stream.sum();
case MAX -> stream.max().orElseThrow();
case MIN -> stream.min().orElseThrow();
default -> throw new IllegalArgumentException();
};
}
第二个问题是,我们不能轻易地为同一个流的元素使用不同的收集器.因此,您需要做一些不同的操作--例如,您可以分两步完成.最初,将所有B.值收集到Aggegator.Type:
public Map<AggregatorType, Double> aggregate(Map<A, B> fields) {
Map<AggregatorType, List<Long>> valuesByType = fields.entrySet()
.stream()
.collect(Collectors.groupingBy(
entry -> entry.getKey().type(),
Collectors.mapping(
entry -> entry.getValue().value(),
Collectors.toList())
));
// return valuesByType.stream()...
}
然后,使用第一个代码片段中的函数收集每个List:
return valuesByType.entrySet()
.stream()
.collect(Collectors.toMap(
entry -> entry.getKey(),
entry -> collectValues(entry.getKey(), entry.getValue())
));
以下是完整的故事:
@Test
void test() {
//given
Map<A, B> values = Map.of(
new A(1L, AggregatorType.AVG), new B(10L),
new A(2L, AggregatorType.AVG), new B(20L),
new A(3L, AggregatorType.SUM), new B(30L),
new A(4L, AggregatorType.SUM), new B(40L),
new A(5L, AggregatorType.MAX), new B(50L),
new A(6L, AggregatorType.MAX), new B(60L),
new A(7L, AggregatorType.MIN), new B(70L),
new A(8L, AggregatorType.MIN), new B(80L)
);
// when
Map<AggregatorType, Double> result = aggregate(values);
//then
assertThat(result).isEqualTo(Map.of(
AggregatorType.AVG, 15d,
AggregatorType.SUM, 70d,
AggregatorType.MAX, 60d,
AggregatorType.MIN, 70d
));
}
public Map<AggregatorType, Double> aggregate(Map<A, B> fields) {
Map<AggregatorType, List<Long>> valuesByType = fields.entrySet()
.stream()
.collect(Collectors.groupingBy(
entry -> entry.getKey().type(),
Collectors.mapping(
entry -> entry.getValue().value(),
Collectors.toList())
));
return valuesByType.entrySet()
.stream()
.collect(Collectors.toMap(
entry -> entry.getKey(),
entry -> collectValues(entry.getKey(), entry.getValue())
));
}
private double collectValues(AggregatorType aggregator, List<Long> values) {
DoubleStream stream = values.stream().mapToDouble(x -> x + 0.0d);
return switch (aggregator) {
case AVG -> stream.average().orElseThrow();
case SUM -> stream.sum();
case MAX -> stream.max().orElseThrow();
case MIN -> stream.min().orElseThrow();
default -> throw new IllegalArgumentException();
};
}
如果你想了解我是如何做到这一点的,请一步一步地阅读这篇文章:https://medium.com/javarevisited/polymorphic-stream-collector-in-java-44f9008bf043?sk=b92590ad1c65a3731746404dbd53b0f7