JPA分页踩坑指南
1. 原生 SQL 查询的结果要映射成包含主键 id 的 VO 类时无法自动映射,需要借助投影(ResultTransformer)手动映射,所以我定义了一个投影工具类 JpaCommonService。
2. 在异步线程中调用原生查询方法时,需要这样获取查询对象:
NativeQuery<?> query = entityManager.createNativeQuery(sql).unwrap(NativeQuery.class);
而不能 unwrap 成实现类:NativeQueryImpl<?> query = entityManager.createNativeQuery(sql).unwrap(NativeQueryImpl.class);
3. 用 JPQL(entityManager.createQuery())做分页查询时,如果分页需要在 FROM 子句中使用子查询,JPQL 是不支持的(JPA 2.x 规范只允许在 WHERE/HAVING 子句中使用子查询)。例如统计总数时常用的 select count(1) from (select * from User) t 就无法用 JPQL 表达,只能改用原生 SQL。
以下是分页返回vo的具体实现
分页返回vo
下面是用到的一些类和方法
分页参数
package com.example.springbootjpadruid.domain.common;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class BasePage {
private int page = 1;
private int size = 10;
/**
* 是否需要分页
*/
private boolean needPage = false;
public int getPage() {
if (page < 1) {
return 0;
} else {
return page - 1;
}
}
public int getOffset() {
return getPage() * size;
}
}
分页返回结果
package com.example.springbootjpadruid.domain.common;
import cn.hutool.core.bean.BeanUtil;
import lombok.Data;
import org.springframework.data.domain.Page;
import java.util.List;
/**
 * Paged response returned to API callers. The page number is exposed 1-based.
 *
 * @param <T> element type of {@code list}
 */
@Data
public class PageVO<T> {
    private int page = 1;
    private int size = 10;
    private long total;
    private int totalPage;
    private List<T> list;

    /**
     * Builds a response from a Spring Data page, copying every element into a new {@code clazz}
     * instance with hutool {@code BeanUtil.copyToList}.
     *
     * @param page  source page (0-based page number)
     * @param clazz element type of the resulting list
     * @return response with copied elements and paging metadata
     */
    public static <S, T> PageVO<T> of(Page<S> page, Class<T> clazz) {
        PageVO<T> vo = withMetadata(page);
        vo.setList(BeanUtil.copyToList(page.getContent(), clazz));
        return vo;
    }

    /**
     * Builds a response from a Spring Data page, reusing its content list as-is.
     *
     * @param page source page (0-based page number)
     * @return response with the page content and paging metadata
     */
    public static <T> PageVO<T> of(Page<T> page) {
        PageVO<T> vo = withMetadata(page);
        vo.setList(page.getContent());
        return vo;
    }

    /** Creates a response carrying only the paging metadata of {@code source} (number made 1-based). */
    private static <T> PageVO<T> withMetadata(Page<?> source) {
        PageVO<T> vo = new PageVO<>();
        vo.setPage(source.getNumber() + 1);
        vo.setSize(source.getSize());
        vo.setTotal(source.getTotalElements());
        vo.setTotalPage(source.getTotalPages());
        return vo;
    }
}
原生sql查询通用方法
package com.example.springbootjpadruid.service;
import com.alibaba.fastjson.JSON;
import com.example.springbootjpadruid.domain.common.BasePage;
import org.hibernate.query.NativeQuery;
import org.hibernate.transform.ResultTransformer;
import org.springframework.beans.SimpleTypeConverter;
import org.springframework.stereotype.Service;
import javax.persistence.EntityManager;
import javax.persistence.PersistenceContext;
import javax.persistence.Query;
import java.beans.Introspector;
import java.beans.PropertyDescriptor;
import java.lang.reflect.Method;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
/**
 * Runs native SQL through JPA/Hibernate and maps each result row either to a {@code Map} or to a
 * bean, matching snake_case column aliases to camelCase bean properties.
 *
 * @author Jonny
 */
@Service
public class JpaCommonService {

    /**
     * Matches named parameters such as {@code :username}. The look-behind skips the second colon of
     * a {@code ::type} cast (e.g. PostgreSQL {@code col::text}), so the cast type is not mistaken
     * for a parameter name.
     */
    private static final Pattern NAMED_PARAM = Pattern.compile("(?<!:):(\\w+)");

    @PersistenceContext
    private EntityManager entityManager;

    /**
     * Runs a native query, binding named parameters from the properties of {@code params}.
     * <p>
     * Only properties that appear as {@code :name} in {@code sql} are bound. When the
     * {@code needPage} flag of {@code params} is {@code true}, its {@code offset} and {@code size}
     * are applied as first-result / max-results.
     *
     * @param sql    native SQL; parameters use the {@code :name} syntax
     * @param params paging settings and filter values; property names match the {@code :name}
     *               placeholders. May be {@code null}, in which case nothing is bound or paged
     * @param clazz  target bean type, or {@code null} to return each row as a {@code Map}
     * @param <P>    a {@link BasePage} subtype
     * @return the mapped rows
     * @throws RuntimeException if the query or the row mapping fails
     */
    public <P extends BasePage> List<?> queryList(String sql, P params, Class<?> clazz) {
        Map<String, Object> mapParams = toParamMap(params);
        Map<String, Object> filteredParams = filterParams(extractSqlParams(sql), mapParams);
        return queryList(sql, mapParams, filteredParams, clazz);
    }

    /**
     * Runs a native count query, binding named parameters from the properties of {@code params}.
     *
     * @param sql    native SQL returning a single numeric value; parameters use {@code :name}
     * @param params filter values; only properties referenced in {@code sql} are bound. May be null
     * @param <P>    a {@link BasePage} subtype
     * @return the count returned by the query
     * @throws RuntimeException if the query fails
     */
    public <P extends BasePage> long count(String sql, P params) {
        Map<String, Object> mapParams = toParamMap(params);
        Map<String, Object> filteredParams = filterParams(extractSqlParams(sql), mapParams);
        return count(sql, mapParams, filteredParams, Long.class);
    }

    /**
     * Runs a native query with pre-computed parameters and maps the rows.
     *
     * @param sql              native SQL; parameters use the {@code :name} syntax
     * @param params           full parameter map; only its {@code needPage}, {@code offset} and
     *                         {@code size} entries are read, to decide on paging. May be null
     * @param extractSqlParams parameters to bind (name to value); should only contain names used in
     *                         {@code sql}. May be null
     * @param clazz            target bean type, or {@code null} to return each row as a {@code Map}
     *                         keyed by the lower-cased column alias
     * @return the mapped rows
     * @throws RuntimeException wrapping any failure while querying or mapping
     */
    @SuppressWarnings({"unchecked", "deprecation"})
    public List<?> queryList(String sql, Map<String, Object> params, Map<String, Object> extractSqlParams, Class<?> clazz) {
        try {
            NativeQuery<?> query = entityManager.createNativeQuery(sql).unwrap(NativeQuery.class);
            if (extractSqlParams != null) {
                extractSqlParams.forEach((name, value) -> query.setParameter(name, value));
            }
            applyPaging(params, query);
            // query is already a Hibernate NativeQuery, so no second unwrap is needed.
            ResultTransformer transformer = clazz == null
                    ? new AliasToEntityMapResultTransformer()
                    : new AliasToBeanResultTransformer(clazz);
            query.setResultTransformer(transformer);
            return query.getResultList();
        } catch (Exception e) {
            throw new RuntimeException("查询或转换数据时出错", e);
        }
    }

    /**
     * Runs a native count query with pre-computed parameters.
     *
     * @param countSql       native SQL returning a single numeric value
     * @param params         not read; kept for signature compatibility
     * @param filteredParams parameters to bind (name to value); may be null
     * @param clazz          not read; kept for signature compatibility
     * @return the count
     * @throws RuntimeException wrapping any failure while querying
     */
    public long count(String countSql, Map<String, Object> params, Map<String, Object> filteredParams, Class<?> clazz) {
        try {
            Query countQuery = entityManager.createNativeQuery(countSql);
            if (filteredParams != null) {
                filteredParams.forEach(countQuery::setParameter);
            }
            Object result = countQuery.getSingleResult();
            // getSingleResult() is untyped; use the numeric value directly when it is a Number.
            return result instanceof Number
                    ? ((Number) result).longValue()
                    : Long.parseLong(String.valueOf(result));
        } catch (Exception e) {
            throw new RuntimeException("查询记录数出错", e);
        }
    }

    /**
     * Applies offset/limit paging when {@code params} contains {@code needPage == true}.
     * Callers enable paging with {@code setNeedPage(true)} on the {@link BasePage} they pass in.
     *
     * @param params parameter map (may be null, meaning no paging)
     * @param query  query to page
     */
    private static void applyPaging(Map<String, Object> params, Query query) {
        if (params == null || !Boolean.TRUE.equals(params.get("needPage"))) {
            return;
        }
        query.setFirstResult(intParam(params, "offset"));
        query.setMaxResults(intParam(params, "size"));
    }

    /** Reads a numeric paging value, failing with a descriptive message when it is absent. */
    private static int intParam(Map<String, Object> params, String key) {
        Object value = params.get(key);
        if (!(value instanceof Number)) {
            throw new IllegalArgumentException("Paging parameter '" + key + "' is missing or not a number: " + value);
        }
        return ((Number) value).intValue();
    }

    /**
     * Converts a parameter bean to a map via a fastjson round trip, so the map carries the bean's
     * serialized getter values (e.g. {@code offset}). Returns an empty map for {@code null}.
     */
    @SuppressWarnings("unchecked")
    private static Map<String, Object> toParamMap(Object params) {
        if (params == null) {
            return Collections.<String, Object>emptyMap();
        }
        Map<String, Object> map = JSON.parseObject(JSON.toJSONString(params), Map.class);
        return map == null ? Collections.<String, Object>emptyMap() : map;
    }

    /**
     * Extracts all named parameters ({@code :name}) from a SQL string.
     *
     * @param sql SQL query
     * @return the parameter names, without the leading colon
     */
    private static Set<String> extractSqlParams(String sql) {
        Set<String> names = new HashSet<>();
        Matcher matcher = NAMED_PARAM.matcher(sql);
        while (matcher.find()) {
            names.add(matcher.group(1));
        }
        return names;
    }

    /**
     * Keeps only the entries of {@code params} whose key is used as a named parameter in the SQL,
     * so Hibernate is never asked to bind a parameter the query does not declare.
     *
     * @param sqlParams parameter names used in the SQL
     * @param params    candidate parameters (may be null)
     * @return the filtered parameters; never null
     */
    private static Map<String, Object> filterParams(Set<String> sqlParams, Map<String, Object> params) {
        if (params == null || params.isEmpty()) {
            return Collections.emptyMap();
        }
        Map<String, Object> filtered = new HashMap<>();
        params.forEach((name, value) -> {
            if (sqlParams.contains(name)) {
                filtered.put(name, value);
            }
        });
        return filtered;
    }

    /**
     * Normalizes a column alias or property name for matching: lower-cases it locale-independently
     * and removes underscores, so {@code USER_ROLE}, {@code user_role} and {@code userRole} all
     * become {@code userrole}.
     */
    private static String normalize(String name) {
        return name.toLowerCase(Locale.ROOT).replace("_", "");
    }

    /** Maps each row to a {@code Map} keyed by the lower-cased column alias, in column order. */
    private static class AliasToEntityMapResultTransformer implements ResultTransformer {
        @Override
        public Object transformTuple(Object[] tuple, String[] aliases) {
            Map<String, Object> row = new LinkedHashMap<>();
            for (int i = 0; i < aliases.length; i++) {
                if (aliases[i] != null) {
                    // Locale.ROOT: with a Turkish default locale, "I" would lower-case to dotless "ı".
                    row.put(aliases[i].toLowerCase(Locale.ROOT), tuple[i]);
                }
            }
            return row;
        }

        @Override
        @SuppressWarnings("rawtypes")
        public List transformList(List collection) {
            return collection;
        }
    }

    /**
     * Maps each row to a new instance of the target bean, matching column aliases to writable
     * properties after normalization (case-insensitive, underscores ignored).
     */
    private static class AliasToBeanResultTransformer implements ResultTransformer {
        private final Class<?> resultClass;
        /** Normalized property name to setter; built once per query instead of once per row. */
        private final Map<String, Method> writeMethods;
        /**
         * Spring documents SimpleTypeConverter as not thread-safe, so each transformer (one per
         * query execution) owns its own instance instead of sharing a static one.
         */
        private final SimpleTypeConverter converter = new SimpleTypeConverter();

        public AliasToBeanResultTransformer(Class<?> resultClass) {
            this.resultClass = resultClass;
            this.writeMethods = indexWriteMethods(resultClass);
        }

        /** Indexes the writable properties of {@code type} by normalized name; first one wins on collision. */
        private static Map<String, Method> indexWriteMethods(Class<?> type) {
            try {
                Map<String, Method> index = new HashMap<>();
                for (PropertyDescriptor descriptor : Introspector.getBeanInfo(type).getPropertyDescriptors()) {
                    Method setter = descriptor.getWriteMethod();
                    if (setter != null) {
                        index.putIfAbsent(normalize(descriptor.getName()), setter);
                    }
                }
                return index;
            } catch (java.beans.IntrospectionException e) {
                throw new RuntimeException("实体映射错误: " + type.getName(), e);
            }
        }

        /**
         * Converts one result row into a bean; columns without a matching setter are ignored.
         *
         * @param tuple   column values of the row
         * @param aliases column aliases, index-aligned with {@code tuple}
         * @return the populated bean
         */
        @Override
        public Object transformTuple(Object[] tuple, String[] aliases) {
            try {
                Object bean = resultClass.getDeclaredConstructor().newInstance();
                for (int i = 0; i < aliases.length; i++) {
                    if (aliases[i] == null) {
                        continue;
                    }
                    Method setter = writeMethods.get(normalize(aliases[i]));
                    if (setter != null) {
                        Object value = converter.convertIfNecessary(tuple[i], setter.getParameterTypes()[0]);
                        setter.invoke(bean, value);
                    }
                }
                return bean;
            } catch (Exception e) {
                throw new RuntimeException("实体映射错误: " + resultClass.getName(), e);
            }
        }

        @Override
        @SuppressWarnings("rawtypes")
        public List transformList(List collection) {
            return collection;
        }
    }
}
使用方法
package com.example.springbootjpadruid.service.impl;
import com.example.springbootjpadruid.domain.common.PageVO;
import com.example.springbootjpadruid.domain.entity.primary.User;
import com.example.springbootjpadruid.domain.query.UserQuery;
import com.example.springbootjpadruid.domain.vo.UserVO;
import com.example.springbootjpadruid.repository.primary.UserRepository;
import com.example.springbootjpadruid.service.JpaCommonService;
import com.example.springbootjpadruid.service.UserService;
import com.github.wenhao.jpa.PredicateBuilder;
import com.github.wenhao.jpa.Specifications;
import org.hibernate.query.NativeQuery;
import org.hibernate.transform.ResultTransformer;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageImpl;
import org.springframework.data.domain.PageRequest;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
import javax.annotation.Resource;
import javax.persistence.EntityManager;
import javax.persistence.PersistenceContext;
import javax.persistence.Query;
import java.util.List;
/**
 * User service demonstrating three ways to page {@code sys_user}:
 * Spring Data repository paging, a native query with a hand-written ResultTransformer,
 * and the generic JpaCommonService helper.
 */
@Service
public class UserServiceImpl implements UserService {
    @Resource
    private UserRepository repository;
    @PersistenceContext
    private EntityManager entityManager;
    @Resource
    private JpaCommonService jpaCommonService;

    /** Saves all given users and returns {@code "ok"}. */
    @Override
    public String init(List<User> users) {
        repository.saveAll(users);
        return "ok";
    }

    /** Returns the user with username "zhangsan" and age 20, or {@code null} if there is none. */
    @Override
    public User getUser() {
        PredicateBuilder<User> where = Specifications.and();
        where.eq("username", "zhangsan");
        where.eq("age", 20);
        return repository.findOne(where.build()).orElse(null);
    }

    /**
     * Paging approach 1: Spring Data repository paging, entities copied into {@link UserVO}.
     *
     * @param param paging parameters ({@code getPage()} is already 0-based)
     * @return paged data
     */
    @Override
    public PageVO<UserVO> listPage(UserQuery param) {
        PageRequest pageable = PageRequest.of(param.getPage(), param.getSize());
        Page<User> page = repository.findAll(pageable);
        return PageVO.of(page, UserVO.class);
    }

    /**
     * Paging approach 2: native SQL with rows mapped to {@link UserVO} by hand.
     *
     * @param param query parameters (username filter, paging)
     * @return paged data
     */
    @SuppressWarnings({"unchecked", "rawtypes", "deprecation"})
    @Override
    public PageVO<UserVO> nativeListPage(UserQuery param) {
        StringBuilder sql = new StringBuilder("select id as id, username as username, age as age, address as address from sys_user where 1 = 1");
        buildWhere(param, sql);
        StringBuilder countSql = new StringBuilder("select count(1) from sys_user where 1 = 1");
        buildWhere(param, countSql);
        PageRequest pageable = PageRequest.of(param.getPage(), param.getSize());
        NativeQuery query = entityManager.createNativeQuery(sql.toString()).unwrap(NativeQuery.class);
        query.setFirstResult(param.getOffset());
        query.setMaxResults(param.getSize());
        // Transformers.aliasToBean(UserVO.class) cannot map the primary key id here, so each row is
        // mapped by hand. Indexes follow the SELECT column order above; SQL NULLs stay null.
        query.setResultTransformer(new ResultTransformer() {
            @Override
            public Object transformTuple(Object[] row, String[] aliases) {
                UserVO userVO = new UserVO();
                userVO.setId(row[0] == null ? null : Long.valueOf(row[0].toString()));
                userVO.setUsername(asString(row[1]));
                userVO.setAge(row[2] == null ? null : Integer.valueOf(row[2].toString()));
                userVO.setAddress(asString(row[3]));
                return userVO;
            }

            @Override
            public List transformList(List list) {
                return list;
            }
        });
        buildParam(param, query);
        Query countQuery = entityManager.createNativeQuery(countSql.toString());
        buildParam(param, countQuery);
        long total = Long.parseLong(countQuery.getSingleResult().toString());
        PageImpl<UserVO> pageImpl = new PageImpl<>(query.getResultList(), pageable, total);
        return PageVO.of(pageImpl);
    }

    /**
     * Paging approach 3: native SQL through JpaCommonService, which maps snake_case / upper-case
     * column names to camelCase properties automatically. This removes approach 2's need for an
     * explicit {@code as} alias per column.
     *
     * @param param query parameters; its needPage flag is set to {@code true} by this method
     * @return paged data
     */
    @SuppressWarnings("unchecked")
    @Override
    public PageVO<UserVO> listPageByJpaUtil(UserQuery param) {
        // Aliases such as "user_role as USER_ROLE" would also work; they are normalized before matching.
        StringBuilder sql = new StringBuilder("select * from sys_user where 1 = 1");
        buildWhere(param, sql);
        StringBuilder countSql = new StringBuilder("select count(1) from sys_user where 1 = 1");
        buildWhere(param, countSql);
        PageRequest pageable = PageRequest.of(param.getPage(), param.getSize());
        param.setNeedPage(true);
        List<UserVO> userList = (List<UserVO>) jpaCommonService.queryList(sql.toString(), param, UserVO.class);
        long total = jpaCommonService.count(countSql.toString(), param);
        PageImpl<UserVO> pageImpl = new PageImpl<>(userList, pageable, total);
        return PageVO.of(pageImpl);
    }

    /** Binds the parameters added by {@link #buildWhere}; the two must stay in sync. */
    private void buildParam(UserQuery param, Query query) {
        if (StringUtils.hasText(param.getUsername())) {
            query.setParameter("username", param.getUsername());
        }
    }

    /** Appends the optional filter conditions to {@code sql}, using named parameters. */
    private void buildWhere(UserQuery param, StringBuilder sql) {
        if (StringUtils.hasText(param.getUsername())) {
            sql.append(" and username = :username");
        }
    }

    /** Returns {@code value.toString()}, or {@code null} for an SQL NULL column value. */
    private static String asString(Object value) {
        return value == null ? null : value.toString();
    }
}