Binomial Expansion in Java

The challenge

Write a program that can do some algebra. Write a function expand that takes in an expression with a single, one character variable, and expands it. The expression is in the form (ax+b)^n where a and b are integers which may be positive or negative, x is any single character variable, and n is a natural number. If a = 1, no coefficient will be placed in front of the variable. If a = -1, a “-” will be placed in front of the variable.

The expanded form should be returned as a string in the form ax^b+cx^d+ex^f... where a, c, and e are the coefficients of the terms, x is the original one-character variable that was passed in the original expression, and b, d, and f are the powers that x is being raised to in each term, in decreasing order. If the coefficient of a term is zero, the term should not be included. If the coefficient of a term is one, the coefficient should not be included. If the coefficient of a term is -1, only the “-” should be included. If the power of the term is 0, only the coefficient should be included. If the power of the term is 1, the caret and power should be excluded.

Examples

Solution.expand("(x+1)^2");      // returns "x^2+2x+1"
Solution.expand("(p-1)^3");      // returns "p^3-3p^2+3p-1"
Solution.expand("(2f+4)^6");     // returns "64f^6+768f^5+3840f^4+10240f^3+15360f^2+12288f+4096"
Solution.expand("(-2a-4)^0");    // returns "1"
Solution.expand("(-12t+43)^2");  // returns "144t^2-1032t+1849"
Solution.expand("(r+0)^203");    // returns "r^203"
Solution.expand("(-x-1)^2");     // returns "x^2+2x+1"

The solution in Java code

Option 1:

import java.util.regex.*;
import java.util.*;
import java.lang.*;

/**
 * Expands binomials of the form {@code (ax+b)^n} into a polynomial string
 * such as {@code "x^2+2x+1"}.
 *
 * <p>All arithmetic is done with {@link java.math.BigInteger} so coefficients
 * stay exact no matter how large the exponent is.
 */
public class Solution {

    /** Groups: 1 = coefficient text of the variable, 2 = variable, 3 = signed constant, 4 = exponent. */
    private static final Pattern BINOMIAL =
            Pattern.compile("\\((-?\\d*)(.)([-+]\\d+)\\)\\^(\\d+)");

    /**
     * Expands a binomial expression.
     *
     * @param expr text containing an expression of the form {@code (ax+b)^n}; {@code a} may be
     *             omitted (meaning 1) or be a bare {@code -} (meaning -1)
     * @return the expanded polynomial with powers in decreasing order; zero terms are omitted,
     *         a coefficient of 1 or -1 in front of the variable is written as nothing or
     *         {@code -}, and {@code "0"} is returned when every term vanishes
     * @throws IllegalArgumentException if no expression of the expected form is found
     */
    public static String expand(String expr) {
        final Matcher matcher = BINOMIAL.matcher(expr);
        if (!matcher.find()) {
            throw new IllegalArgumentException("Expected an expression like (ax+b)^n but got: " + expr);
        }

        final String rawA = matcher.group(1);
        final java.math.BigInteger a =
                new java.math.BigInteger(rawA.isEmpty() ? "1" : rawA.equals("-") ? "-1" : rawA);
        final String x = matcher.group(2);
        final java.math.BigInteger b = new java.math.BigInteger(matcher.group(3).replace("+", ""));
        final int n = Integer.parseInt(matcher.group(4));

        final StringBuilder result = new StringBuilder();
        // Running value of C(n, i); updated incrementally so no factorials are needed.
        java.math.BigInteger binomial = java.math.BigInteger.ONE;
        for (int i = 0; i <= n; i++) {
            final int power = n - i;
            final java.math.BigInteger coefficient =
                    binomial.multiply(a.pow(power)).multiply(b.pow(i));
            if (coefficient.signum() != 0) {
                appendTerm(result, coefficient, x, power);
            }
            // C(n, i+1) = C(n, i) * (n - i) / (i + 1); the division is always exact.
            binomial = binomial
                    .multiply(java.math.BigInteger.valueOf(power))
                    .divide(java.math.BigInteger.valueOf(i + 1L));
        }

        return result.length() == 0 ? "0" : result.toString();
    }

    /** Appends one non-zero term, including the joining sign, to {@code out}. */
    private static void appendTerm(StringBuilder out, java.math.BigInteger coefficient,
                                   String variable, int power) {
        // Negative coefficients carry their own '-', so only positive ones need a separator.
        if (coefficient.signum() > 0 && out.length() > 0) {
            out.append('+');
        }
        if (power == 0) {
            out.append(coefficient);
            return;
        }
        if (coefficient.equals(java.math.BigInteger.ONE.negate())) {
            out.append('-');
        } else if (!coefficient.equals(java.math.BigInteger.ONE)) {
            out.append(coefficient);
        }
        out.append(variable);
        if (power > 1) {
            out.append('^').append(power);
        }
    }
}

Option 2:

import java.util.*;
import java.util.regex.*;
import java.util.stream.*;
import java.math.BigInteger;

/**
 * Expands binomials of the form {@code (ax+b)^n} by repeatedly multiplying a
 * coefficient list by {@code (ax+b)}.
 */
public class Solution {

    /**
     * Groups: 1 = coefficient text of the variable, 2 = variable, 3 = signed constant,
     * 4 = exponent. The variable group is deliberately kept as permissive as before.
     */
    final static private Pattern P_POLY = Pattern.compile("\\((-?\\d*)(\\w+)\\+?(-?\\d+)\\)\\^(\\d+)");

    /**
     * Expands a binomial expression.
     *
     * @param expr text containing an expression of the form {@code (ax+b)^n}
     * @return the expanded polynomial with powers in decreasing order; zero terms are omitted,
     *         unit coefficients in front of the variable are written as nothing or {@code -},
     *         and {@code "0"} is returned when every term vanishes
     * @throws IllegalArgumentException if no expression of the expected form is found
     */
    public static String expand(String expr) {
        Matcher m = P_POLY.matcher(expr);
        if (!m.find()) {
            throw new IllegalArgumentException("Expected an expression like (ax+b)^n but got: " + expr);
        }
        String as = m.group(1),
               xs = m.group(2),
               bs = m.group(3);
        int e = Integer.parseInt(m.group(4));

        BigInteger a = new BigInteger("-".equals(as) ? "-1" : as.isEmpty() ? "1" : as),
                   b = new BigInteger(bs);

        // poly[j] is the coefficient of x^(degree - j); start from the constant polynomial 1.
        BigInteger[] poly = { BigInteger.ONE };
        for (int step = 0; step < e; step++) {
            BigInteger[] next = new BigInteger[poly.length + 1];
            for (int j = 0; j < next.length; j++) {
                // Multiplying by (ax+b): a shifts each term up one power, b keeps its power.
                BigInteger raised = j < poly.length ? a.multiply(poly[j]) : BigInteger.ZERO;
                BigInteger kept   = j > 0 ? b.multiply(poly[j - 1]) : BigInteger.ZERO;
                next[j] = raised.add(kept);
            }
            poly = next;
        }

        // Terms are formatted directly rather than cleaned up with a regex afterwards, so the
        // result does not depend on the variable being a lowercase letter or on the locale.
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < poly.length; i++) {
            BigInteger coef = poly[i];
            if (coef.signum() == 0) continue;

            int power = e - i;
            // Negative coefficients carry their own '-', so only positive ones need a '+'.
            if (coef.signum() > 0 && sb.length() > 0) sb.append('+');

            if (power == 0) {
                sb.append(coef);
                continue;
            }
            if (coef.equals(BigInteger.ONE.negate())) sb.append('-');
            else if (!coef.equals(BigInteger.ONE))    sb.append(coef);
            sb.append(xs);
            if (power > 1) sb.append('^').append(power);
        }
        return sb.length() == 0 ? "0" : sb.toString();
    }
}

Option 3:

/**
 * Expands binomials of the form {@code (ax+b)^n} term by term using the binomial theorem:
 * the term for index {@code r} is {@code C(n, r) * a^(n-r) * b^r * x^(n-r)}.
 */
public class Solution {

  /** Groups: 1 = coefficient text of the variable, 2 = variable, 3 = signed constant, 4 = exponent. */
  private static final java.util.regex.Pattern BINOMIAL = java.util.regex.Pattern.compile(
      "\\((-?[0-9]*)([a-zA-Z])\\+?(-?[0-9]+)\\)\\^([0-9]+)");

  /**
   * Expands a binomial expression.
   *
   * @param expr an expression of the form {@code (ax+b)^n}; {@code a} may be omitted
   *             (meaning 1) or be a bare {@code -} (meaning -1)
   * @return the expanded polynomial with powers in decreasing order; zero terms are omitted,
   *         unit coefficients in front of the variable are written as nothing or {@code -},
   *         and {@code "0"} is returned when every term vanishes
   * @throws IllegalArgumentException if {@code expr} does not have the expected form
   */
  public static String expand(String expr) {
    java.util.regex.Matcher m = BINOMIAL.matcher(expr);
    if (!m.matches()) {
      throw new IllegalArgumentException("Expected an expression like (ax+b)^n but got: " + expr);
    }
    java.math.BigInteger a = getNum(m.group(1));
    String variable = m.group(2);
    java.math.BigInteger b = new java.math.BigInteger(m.group(3));
    int pow = Integer.parseInt(m.group(4));

    // Every exponent, including 0 and 1, goes through the same path so that zero and unit
    // coefficients are always suppressed consistently.
    String joined = java.util.stream.IntStream
                  .rangeClosed(0, pow)
                  .mapToObj(k -> binoThm(pow, k, a, variable, b))
                  .filter(s -> !s.isEmpty())
                  .collect(java.util.stream.Collectors.joining("+"));
    // Negative terms already start with '-', so collapse the "+-" the join produced.
    return joined.isEmpty() ? "0" : joined.replace("+-", "-");
  }

  /** Renders the r-th term of the expansion, or an empty string when its coefficient is zero. */
  private static String binoThm(int n, int r, java.math.BigInteger a, String variable,
                                java.math.BigInteger b) {
    int power = n - r;
    java.math.BigInteger coeff = nCr(n, r).multiply(a.pow(power)).multiply(b.pow(r));
    if (coeff.signum() == 0) return "";
    if (power == 0) return coeff.toString();

    String prefix;
    if (coeff.equals(java.math.BigInteger.ONE)) {
      prefix = "";
    } else if (coeff.equals(java.math.BigInteger.ONE.negate())) {
      prefix = "-";
    } else {
      prefix = coeff.toString();
    }
    return prefix + variable + (power > 1 ? "^" + power : "");
  }

  /** Converts the coefficient text in front of the variable ("", "-", or digits) to a number. */
  private static java.math.BigInteger getNum(String s) {
    if (s.isEmpty()) return java.math.BigInteger.ONE;
    if (s.equals("-")) return java.math.BigInteger.ONE.negate();
    return new java.math.BigInteger(s);
  }

  /** Computes the binomial coefficient C(n, r) exactly. */
  private static java.math.BigInteger nCr(int n, int r) {
    int k = Math.min(r, n - r);
    java.math.BigInteger result = java.math.BigInteger.ONE;
    for (int i = 1; i <= k; i++) {
      // After this step result == C(n - k + i, i), so the division is always exact.
      result = result
          .multiply(java.math.BigInteger.valueOf(n - k + (long) i))
          .divide(java.math.BigInteger.valueOf(i));
    }
    return result;
  }

}

Test cases to validate our solution

import org.junit.Test;
import static org.junit.Assert.assertEquals;

/**
 * Example-based checks for {@code Solution.expand}, grouped by the sign of the
 * constant term and of the variable's coefficient.
 */
public class ExampleTest {

	/**
	 * Runs every {expected, expression} pair in order, stopping at the first mismatch.
	 */
	private static void verifyAll(String[][] cases) {
		for (String[] pair : cases) {
			assertEquals(pair[0], Solution.expand(pair[1]));
		}
	}

	/** Positive constant term with an implicit coefficient of 1. */
	@Test
	public void testBPositive() {
		verifyAll(new String[][] {
			{"1", "(x+1)^0"},
			{"x+1", "(x+1)^1"},
			{"x^2+2x+1", "(x+1)^2"},
		});
	}

	/** Negative constant term with an implicit coefficient of 1. */
	@Test
	public void testBNegative() {
		verifyAll(new String[][] {
			{"1", "(x-1)^0"},
			{"x-1", "(x-1)^1"},
			{"x^2-2x+1", "(x-1)^2"},
		});
	}

	/** Explicit positive coefficient on the variable. */
	@Test
	public void testAPositive() {
		verifyAll(new String[][] {
			{"625m^4+1500m^3+1350m^2+540m+81", "(5m+3)^4"},
			{"8x^3-36x^2+54x-27", "(2x-3)^3"},
			{"1", "(7x-7)^0"},
		});
	}

	/** Explicit negative coefficient on the variable. */
	@Test
	public void testANegative() {
		verifyAll(new String[][] {
			{"625m^4-1500m^3+1350m^2-540m+81", "(-5m+3)^4"},
			{"-8k^3-36k^2-54k-27", "(-2k-3)^3"},
			{"1", "(-7x-7)^0"},
		});
	}
}