// Copyright 2000-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.intellij.spring.model.highlighting.dom;

import com.intellij.codeInsight.daemon.impl.quickfix.ExtendsListFix;
import com.intellij.codeInspection.LocalQuickFix;
import com.intellij.openapi.util.text.StringUtil;
import com.intellij.psi.PsiClass;
import com.intellij.psi.PsiType;
import com.intellij.psi.util.PsiTypesUtil;
import com.intellij.spring.SpringApiBundle;
import com.intellij.spring.model.SpringBeanPointer;
import com.intellij.spring.model.xml.RequiredBeanType;
import com.intellij.util.SmartList;
import com.intellij.util.xml.DomElement;
import com.intellij.util.xml.DomJavaUtil;
import com.intellij.util.xml.GenericAttributeValue;
import com.intellij.util.xml.highlighting.DomElementAnnotationHolder;

import java.util.ArrayList;
import java.util.List;

/**
 * @author Yann C&eacute;bron
 * @see RequiredBeanType
 */
public final class RequiredBeanTypeChecker {

  private RequiredBeanTypeChecker() {
  }

  public static void check(DomElement element, DomElementAnnotationHolder holder) {
    if (!(element instanceof GenericAttributeValue)) {
      return;
    }

    final RequiredBeanType requiredBeanType = element.getAnnotation(RequiredBeanType.class);
    if (requiredBeanType == null) {
      return;
    }

    Object value = ((GenericAttributeValue)element).getValue();
    if (value == null) {
      return;
    }

    final String[] requiredClasses = requiredBeanType.value();
    short notFoundRequiredClasses = 0;
    final List<PsiClass> foundRequiredClasses = new SmartList<>();
    for (String requiredClassName : requiredClasses) {
      final PsiClass requiredClass = DomJavaUtil.findClass(requiredClassName, element);
      if (requiredClass == null) {
        notFoundRequiredClasses++;
        continue;
      }
      foundRequiredClasses.add(requiredClass);
    }

    // stop if we cannot resolve any of base class(es)
    final boolean isOneRequiredClass = requiredClasses.length == 1;
    if (notFoundRequiredClasses == requiredClasses.length) {
      final String message = isOneRequiredClass
                             ? SpringApiBundle.message("bean.base.class.not.found", requiredClasses[0])
                             : SpringApiBundle.message("bean.base.classes.not.found", StringUtil.join(requiredClasses, ","));
      holder.createProblem(element, message);
      return;
    }

    if (value instanceof SpringBeanPointer) {
      SpringBeanPointer<?>  springBeanPointer = (SpringBeanPointer)value;
      checkSpringBeanPointer(element, holder, requiredClasses, foundRequiredClasses, isOneRequiredClass, springBeanPointer, false);
    }
    else if (value instanceof List) {
      @SuppressWarnings("unchecked")
      List<SpringBeanPointer<?>> pointers = (List<SpringBeanPointer<?>>)value;
      for (SpringBeanPointer pointer : pointers) {
        checkSpringBeanPointer(element, holder, requiredClasses, foundRequiredClasses, isOneRequiredClass, pointer, true);
      }
    }
    else {
      throw new IllegalArgumentException("must (List)SpringBeanPointer: " + element);
    }
  }

  private static void checkSpringBeanPointer(DomElement element,
                                             DomElementAnnotationHolder holder,
                                             String[] requiredClasses,
                                             List<PsiClass> foundRequiredClasses,
                                             boolean oneRequiredClass,
                                             SpringBeanPointer<?>  springBeanPointer,
                                             boolean multiple) {

    PsiType[] psiTypes = springBeanPointer.getEffectiveBeanTypes();
    for (PsiClass requiredClass : foundRequiredClasses) {
      for (PsiType psiType : psiTypes) {
        if (PsiTypesUtil.getClassType(requiredClass).isAssignableFrom(psiType)) {
          return;
        }
      }
    }

    final String classText = oneRequiredClass ? requiredClasses[0] : StringUtil.join(requiredClasses, ",");
    final String message;
    if (!multiple) {
      message = oneRequiredClass
                ? SpringApiBundle.message("bean.must.be.of.type", classText)
                : SpringApiBundle.message("bean.must.be.one.of.these.types", classText);
    }
    else {
      final String beanName = springBeanPointer.getName();
      message = oneRequiredClass
                ? SpringApiBundle.message("bean.name.must.be.of.type", beanName, classText)
                : SpringApiBundle.message("bean.name.must.be.one.of.these.types", beanName, classText);
    }

    List<LocalQuickFix> quickfixes = new ArrayList<>(foundRequiredClasses.size());
    if (psiTypes.length > 0) {
      final PsiClass psiClass = PsiTypesUtil.getPsiClass(psiTypes[0]);
      if (psiClass != null) {
        for (PsiClass foundRequiredClass : foundRequiredClasses) {
          quickfixes.add(new ExtendsListFix(psiClass, foundRequiredClass, true));
        }
      }
    }
    holder.createProblem(element, message, quickfixes.toArray(LocalQuickFix.EMPTY_ARRAY));
  }
}
