# poly.rb -- polynomial-related stuff; poly.scm --> poly.rb

# Translator: Michael Scholz <mi-scholz@users.sourceforge.net>
# Created: 05/04/09 23:55:07
# Changed: 17/11/30 22:57:04

# class Complex
#  to_f
#  to_f_or_c
#  
# class Poly < Vec
#  inspect
#  to_poly
#  reduce
#  +(other)
#  *(other)
#  /(other)
#  derivative
#  resultant(other)
#  discriminant
#  gcd(other)
#  roots
#  eval(x)
#
# class Float
#  +(other)
#  *(other)
#  /(other)
#
# class String
#  to_poly
#
# class Array
#  to_poly
#
# class Vct
#  to_poly
#
# Poly(obj)
# make_poly(len, init, &body)
# poly?(obj)
# poly(*vals)
# poly_reduce(obj)
# poly_add(obj1, obj2)
# poly_multiply(obj1, obj2)
# poly_div(obj1, obj2)
# poly_derivative(obj)
# poly_gcd(obj1, obj2)
# poly_roots(obj)

require "clm"
require "mix"
# Mix in Math so that sqrt, exp, etc. can be called without the Math. prefix.
include Math
# Conversion helpers added to Complex; the root finder uses to_f_or_c.

class Complex
  # XXX: attr_writer :real, :imag
  #      Doesn't work any longer.
  #      Complex objects are now frozen objects.
  #      (Thu Nov 30 21:29:10 CET 2017)

  # Replaces Complex#to_f with a version that returns the real part as
  # a Float and drops the imaginary part.
  # NOTE(review): with_silence is defined elsewhere; presumably it
  # suppresses the "method redefined" warning -- confirm.
  with_silence do
    def to_f
      self.real.to_f
    end
  end

  # Returns a Float when the imaginary part is exactly zero, otherwise
  # returns the receiver unchanged.
  def to_f_or_c
    self.imag.zero? ? self.to_f : self
  end
end

class Poly < Vec
  Poly_roots_epsilon = 1.0e-6

  # Sets @name to "poly" and then defers to the superclass inspect.
  # NOTE(review): presumably the superclass uses @name as the label in
  # its output -- confirm against Vec#inspect.
  def inspect
    @name = "poly"
    super
  end
  
  # A Poly is already a poly: returns the receiver itself (no copy).
  def to_poly
    self
  end
  
  # Strips trailing zero coefficients (the highest-order terms).
  #
  # Returns the receiver itself when there is nothing to strip (the
  # poly is empty or its last coefficient is nonzero); otherwise
  # returns the leading slice self[0, i + 1].  At least one
  # coefficient is always kept, so an all-zero poly reduces to a single
  # zero.
  def reduce
    # An empty poly has no last coefficient to test.
    return self if self.empty? or not self.last.zero?
    i = self.length - 1
    # Walk back to the last nonzero coefficient, but never below index 0.
    i -= 1 while i > 0 and self[i].zero?
    self[0, i + 1]
  end
  # [1, 2, 3].to_poly.reduce             ==> poly(1.0, 2.0, 3.0)
  # poly(1, 2, 3, 0, 0, 0).reduce        ==> poly(1.0, 2.0, 3.0)
  # vct(0, 0, 0, 0, 1, 0).to_poly.reduce ==> poly(0.0, 0.0, 0.0, 0.0, 1.0)
  
  # Adds OTHER to this poly; also available as the + operator.
  #
  # other:: a poly, vct, array, or number (checked via assert_type).
  #
  # A number is added to the constant term (index 0) of a copy of the
  # receiver.  For a sequence, the longer operand is the receiver of
  # Vec#add so the result has the length of the longer one; when the
  # lengths are equal, OTHER (converted with Poly()) is the receiver.
  def poly_add(other)
    assert_type((array?(other) or vct?(other) or number?(other)),
                other, 0, "a poly, a vct, an array, or a number")
    if number?(other)
      sum = self.dup
      sum[0] += other
      sum
    elsif self.length > other.length
      self.add(other)
    else
      Poly(other).add(self)
    end
  end
  alias + poly_add
  # poly(0.1, 0.2, 0.3) + poly(0, 1, 2, 3, 4) ==> poly(0.1, 1.2, 2.3, 3.0, 4.0)
  # poly(0.1, 0.2, 0.3) + 0.5                 ==> poly(0.6, 0.2, 0.3)
  # 0.5 + poly(0.1, 0.2, 0.3)                 ==> poly(0.6, 0.2, 0.3)

  # Multiplies this poly by OTHER; also available as the * operator.
  #
  # other:: a poly, vct, array, or number (checked via assert_type).
  #
  # A number scales every coefficient (via scale, wrapped in Poly()).
  # For a sequence the coefficient convolution is computed into a new
  # Poly of length self.length + other.length, so the result carries
  # one trailing 0.0 (see the examples below); call reduce to drop it.
  def poly_multiply(other)
    assert_type((array?(other) or vct?(other) or number?(other)),
                other, 0, "a poly, a vct, an array, or a number")
    if number?(other)
      Poly(self.scale(Float(other)))
    else
      len = self.length + other.length
      m = Poly.new(len, 0.0)
      self.each_with_index do |val1, i|
        other.each_with_index do |val2, j|
          # The x^i and x^j terms contribute to the x^(i+j) term.
          m[i + j] = m[i + j] + val1 * val2
        end
      end
      m
    end
  end
  alias * poly_multiply
  # poly(1, 1) * poly(-1, 1)        ==> poly(-1.0, 0.0, 1.0, 0.0)
  # poly(-5, 1) * poly(3, 7, 2)     ==> poly(-15.0, -32.0, -3.0, 2.0, 0.0)
  # poly(-30, -4, 2) * poly(0.5, 1) ==> poly(-15.0, -32.0, -3.0, 2.0, 0.0)
  # poly(-30, -4, 2) * 0.5          ==> poly(-15.0, -2.0, 1.0)
  # 2.0 * poly(-30, -4, 2)          ==> poly(-60.0, -8.0, 4.0)

  # Divides this poly by OTHER; also available as the / operator.
  #
  # other:: a poly, vct, array, or number (checked via assert_type).
  #
  # Always returns a two-element array [quotient, remainder]:
  # * number: [self * (1.0 / other), poly(0.0)]
  # * OTHER longer than the receiver: [poly(0.0), other.to_poly]
  # * otherwise: long division; both results keep the receiver's
  #   length, so they may carry trailing zeros (see examples below).
  def poly_div(other)
    assert_type((array?(other) or vct?(other) or number?(other)),
                other, 0, "a poly, a vct, an array, or a number")
    if number?(other)
      [self * (1.0 / other), poly(0.0)]
    else
      if other.length > self.length
        [poly(0.0), other.to_poly]
      else
        r = self.dup
        q = Poly.new(self.length, 0.0)
        n = self.length - 1
        nv = other.length - 1
        # Work from the highest quotient term downwards.
        (n - nv).downto(0) do |i|
          q[i] = r[nv + i] / other[nv]
          # Subtract q[i] * other * x^i from the running remainder.
          (nv + i - 1).downto(i) do |j|
            r[j] = r[j] - q[i] * other[j - i]
          end
        end
        # Terms of degree >= nv have been consumed by the quotient.
        nv.upto(n) do |i|
          r[i] = 0.0
        end
        [q, r]
      end
    end
  end
  alias / poly_div
  # poly(-1.0, 0.0, 1.0) / poly(1.0, 1.0)
  #   ==> [poly(-1.0, 1.0, 0.0),       poly(0.0, 0.0, 0.0)]
  # poly(-15, -32, -3, 2) / poly(-5, 1)
  #   ==> [poly(3.0, 7.0, 2.0, 0.0),   poly(0.0, 0.0, 0.0, 0.0)]
  # poly(-15, -32, -3, 2) / poly(3, 1)
  #   ==> [poly(-5.0, -9.0, 2.0, 0.0), poly(0.0, 0.0, 0.0, 0.0)]
  # poly(-15, -32, -3, 2) / poly(0.5, 1)
  #   ==> [poly(-30.0, -4.0, 2.0, 0.0), poly(0.0, 0.0, 0.0, 0.0)]
  # poly(-15, -32, -3, 2) / poly(3, 7, 2)
  #   ==> [poly(-5.0, 1.0, 0.0, 0.0),  poly(0.0, 0.0, 0.0, 0.0)]
  # poly(-15, -32, -3, 2) / 2.0
  #   ==> [poly(-7.5, -16.0, -1.5, 1.0), poly(0.0)]

  # Returns a new Poly holding the coefficients of the first
  # derivative: entry k - 1 of the result is k * self[k].  The result
  # is one element shorter than the receiver.
  def derivative
    size = self.length - 1
    result = Poly.new(size, 0.0)
    1.upto(size) do |power|
      result[power - 1] = self[power] * power
    end
    result
  end
  # poly(0.5, 1.0, 2.0, 4.0).derivative ==> poly(1.0, 4.0, 12.0)

  # Returns the determinant of a square matrix built from the
  # coefficients of the receiver and OTHER (a Sylvester-style layout).
  #
  # With m = self.length and n = other.length the matrix has
  # (n - 1) + (m - 1) rows: the first n - 1 rows hold the receiver's
  # coefficients, highest order first, shifted right by one column per
  # row; the remaining m - 1 rows hold OTHER's coefficients likewise.
  def resultant(other)
    m = self.length
    m1 = m - 1
    n = other.length
    n1 = n - 1
    d = n1 + m1
    mat = Array.new(d) do
      Vct.new(d, 0.0)
    end
    n1.times do |i|
      m.times do |j|
        mat[i][i + j] = self[m1 - j]
      end
    end
    m1.times do |i|
      n.times do |j|
        mat[i + n1][i + j] = other[n1 - j]
      end
    end
    determinant(mat)
  end
  # poly(-1, 0, 1).resultant([1, -2, 1]) ==> 0.0
  # poly(-1, 0, 2).resultant([1, -2, 1]) ==> 1.0
  # poly(-1, 0, 1).resultant([1, 1])     ==> 0.0
  # poly(-1, 0, 1).resultant([2, 1])     ==> 3.0

  # Returns the resultant of the receiver and its own derivative.
  def discriminant
    self.resultant(self.derivative)
  end
  # poly(-1, 0, 1).discriminant ==> -4.0
  # poly(1, -2, 1).discriminant ==>  0.0
  # (poly(-1, 1) * poly(-1, 1) * poly(3, 1)).reduce.discriminant
  #   ==> 0.0
  # (poly(-1, 1) * poly(-1, 1) * poly(3, 1) * poly(2, 1)).reduce.discriminant
  #   ==> 0.0
  # (poly(1, 1) * poly(-1, 1) * poly(3, 1) * poly(2, 1)).reduce.discriminant
  #   ==> 2304.0
  # (poly(1, 1) * poly(-1, 1) * poly(3, 1) * poly(3, 1)).reduce.discriminant
  #   ==> 0.0
  
  # Common-divisor search between the receiver and OTHER.
  #
  # other:: a poly, vct, or array (checked via assert_type).
  #
  # Returns poly(0.0) when the receiver is shorter than OTHER.
  # Otherwise the receiver is divided by OTHER and both quotient and
  # remainder are reduced:
  # * remainder is a single zero coefficient: OTHER divides the
  #   receiver exactly, so Poly(other) is returned;
  # * remainder is a single nonzero coefficient: poly(0.0);
  # * otherwise the search recurses as quotient.gcd(remainder).
  def gcd(other)
    assert_type((array?(other) or vct?(other)), other, 0,
                "a poly, a vct or an array")
    if self.length < other.length
      poly(0.0)
    else
      qr = self.poly_div(other).map do |m|
        m.reduce
      end
      if qr[1].length == 1
        if qr[1][0].zero?
          Poly(other)
        else
          poly(0.0)
        end
      else
        qr[0].gcd(qr[1])
      end
    end
  end
  # (poly(2, 1) * poly(-3, 1)).reduce.gcd(poly(2, 1))
  #   ==> poly(2.0, 1.0)
  # (poly(2, 1) * poly(-3, 1)).reduce.gcd(poly(3, 1))
  #   ==> poly(0.0)
  # (poly(2, 1) * poly(-3, 1)).reduce.gcd(poly(-3, 1))
  #   ==> poly(-3.0, 1.0)
  # (poly(8, 1) * poly(2, 1) * poly(-3, 1)).reduce.gcd(poly(-3, 1))
  #   ==> poly(-3.0, 1.0)
  # (poly(8, 1) * poly(2, 1) *
  #  poly(-3, 1)).reduce.gcd((poly(8, 1) * poly(-3, 1)).reduce)
  #   ==> poly(-24.0, 5.0, 1.0)
  # poly(-1, 0, 1).gcd(poly(2, -2, -1, 1))
  #   ==> poly(0.0)
  # poly(2, -2, -1, 1).gcd(poly(-1, 0, 1))
  #   ==> poly(1.0, -1.0)
  # poly(2, -2, -1, 1).gcd(poly(-2.5, 1))
  #   ==> poly(0.0)

  # Returns a poly holding the roots of the receiver (coefficients are
  # stored lowest order first).
  #
  # Strategy, in order:
  # 1. degree <= 0: no roots, returns an empty poly;
  # 2. zero constant term: 0.0 is a root; x is factored out and the
  #    rest is solved recursively;
  # 3. degree 1 and 2: closed forms (linear_root, quadratic_root);
  # 4. degree 3 and 4: cubic_root / quartic_root, which return false
  #    when their result fails the self-check;
  # 5. a*x^n + b, or a polynomial in x^(n/2) or x^(n/3): solve the
  #    smaller polynomial and take nth roots of its roots;
  # 6. otherwise: damped Newton iteration from a fixed complex start
  #    value, deflate by the root found, and recurse on the quotient.
  def roots
    deg = self.length - 1
    # An empty or constant poly has no roots.
    return poly() if deg <= 0
    if self[0].zero?
      return poly(0.0) if deg == 1
      shifted = Poly.new(deg) do |i|
        self[i + 1]
      end
      return shifted.roots.unshift(0.0)
    end
    return linear_root(self[1], self[0]) if deg == 1
    return quadratic_root(self[2], self[1], self[0]) if deg == 2
    # The closed-form solvers return false on failure; keep that result
    # in its own local so a failure cannot clobber the accumulator
    # used further down.
    if deg == 3 and (found = cubic_root(self[3], self[2], self[1], self[0]))
      return found
    end
    if deg == 4 and
       (found = quartic_root(self[4], self[3], self[2], self[1], self[0]))
      return found
    end
    # Count the nonzero coefficients above the constant term.
    ones = 0
    1.upto(deg) do |i|
      ones += 1 if self[i].nonzero?
    end
    # a*x^deg + b
    return nth_root(self[deg], self[0], deg) if ones == 1
    # Quadratic in x^(deg/2).
    if ones == 2 and deg.even? and self[deg / 2].nonzero?
      n = deg / 2
      rts = poly()
      poly(self[0], self[deg / 2], self[deg]).roots.each do |qr|
        rts.push(*nth_root(1.0, -qr, n.to_f))
      end
      return rts
    end
    # Cubic in x^(deg/3).
    if deg > 3 and
       ones == 3 and
       (deg % 3).zero? and
       self[deg / 3].nonzero? and
       self[(deg * 2) / 3].nonzero?
      n = deg / 3
      rts = poly()
      poly(self[0],
           self[deg / 3],
           self[(deg * 2) / 3],
           self[deg]).roots.each do |qr|
        rts.push(*nth_root(1.0, -qr, n.to_f))
      end
      return rts
    end
    # General case: damped Newton iteration.
    q = self.dup
    dp = self.derivative
    qp = dp.dup
    n = deg
    x = Complex(1.3, 0.314159)
    v = q.eval(x)
    m = v.abs * v.abs
    20.times do # until c_g?
      if (dx = v / qp.eval(x)).abs <= Poly_roots_epsilon
        break
      end
      # Shrink the step until the squared magnitude of the value drops.
      20.times do
        if dx.abs <= Poly_roots_epsilon
          break
        end
        y = x - dx
        v1 = q.eval(y)
        if (m1 = v1.abs * v1.abs) < m
          x = y
          v = v1
          m = m1
          break
        else
          dx /= 4.0
        end
      end
    end
    # Two plain Newton steps to polish the estimate.
    x = x - self.eval(x) / dp.eval(x)
    x = x - self.eval(x) / dp.eval(x)
    # Compare the magnitude of the imaginary part: a root with a large
    # negative imaginary part is not a real root.
    if x.imag.abs < Poly_roots_epsilon
      q = q.poly_div(poly(-x.real, 1.0))
      n -= 1
    else
      # NOTE(review): the factor for a conjugate pair would normally be
      # x^2 - 2*Re(x)*x + |x|^2, and only one root of the pair is
      # recorded below; kept as in the original -- confirm upstream.
      q = q.poly_div(poly(x.abs, 0.0, 1.0))
      n -= 2
    end
    rts = if n > 0
            q.car.reduce.roots
          else
            poly()
          end
    rts << x.to_f_or_c
    rts
  end
  
  # Evaluates the polynomial at X using Horner's scheme.
  #
  # x:: the evaluation point; any value supporting * and + with the
  #     coefficients (the root finder passes Complex values).
  #
  # Coefficients are stored lowest order first, so the fold starts from
  # the last (highest-order) coefficient.  An empty poly evaluates to
  # 0.0.
  def eval(x)
    return 0.0 if self.empty?
    self.reverse.inject do |sum, val|
      sum * x + val
    end
  end

  private
  # Internal helpers used by the public methods above.
  # Returns a copy of the square matrix MX with row ROW and column COL
  # removed.  The result is an Array of Vct rows, each one element
  # shorter than the rows of MX.
  def submatrix(mx, row, col)
    size = mx.length
    minor = Array.new(size - 1) do
      Vct.new(size - 1, 0.0)
    end
    kept_rows = (0...size).reject { |i| i == row }
    kept_cols = (0...size).reject { |j| j == col }
    kept_rows.each_with_index do |i, ni|
      kept_cols.each_with_index do |j, nj|
        minor[ni][nj] = mx[i][j]
      end
    end
    minor
  end

  # Returns the determinant of the square matrix MX (an array of rows).
  # Sizes 1 to 3 use closed formulas; larger matrices are expanded
  # along the first row, skipping entries equal to 0.0.
  def determinant(mx)
    case mx.length
    when 1
      mx[0][0]
    when 2
      mx[0][0] * mx[1][1] - mx[0][1] * mx[1][0]
    when 3
      a, b, c = mx[0], mx[1], mx[2]
      ((a[0] * b[1] * c[2] +
        a[1] * b[2] * c[0] +
        a[2] * b[0] * c[1]) -
       (a[0] * b[2] * c[1] +
        a[1] * b[0] * c[2] +
        a[2] * b[1] * c[0]))
    else
      (0...mx.length).inject(0.0) do |acc, col|
        coeff = mx[0][col]
        if coeff != 0.0
          # Cofactor signs alternate, starting with + at column 0.
          sign = col.even? ? 1 : -1
          acc + sign * coeff * determinant(submatrix(mx, 0, col))
        else
          acc
        end
      end
    end
  end

  # ax + b
  # Returns a one-element poly holding -b / a, the root of a*x + b.
  def linear_root(a, b)
    poly(-b / a)
  end

  # ax^2 + bx + c
  # Returns a two-element poly with the roots of a*x^2 + b*x + c from
  # the quadratic formula: the "+ sqrt" root first, then the "- sqrt"
  # root.
  def quadratic_root(a, b, c)
    disc = b * b - 4.0 * a * c
    d = sqrt(disc)
    denom = 2.0 * a
    poly((-b + d) / denom, (-b - d) / denom)
  end

  # ax^3 + bx^2 + cx + d
  # Closed-form roots of a*x^3 + b*x^2 + c*x + d.
  #
  # Returns a three-element poly, or false when no combination of cube
  # root branches yields three values that all evaluate to within
  # Poly_roots_epsilon of zero.
  # NOTE(review): sqrt and exp are called with negative and complex
  # arguments, so they are assumed to be complex-aware versions
  # provided elsewhere -- confirm.
  def cubic_root(a, b, c, d)
    # Abramowitz & Stegun 3.8.2
    a0 = d / a
    a1 = c / a
    a2 = b / a
    q = (a1 / 3) - ((a2 * a2) / 9)
    r = ((a1 * a2 - 3 * a0) / 6) - ((a2 * a2 * a2) / 27)
    sq3r2 = sqrt(q * q * q + r * r)
    r1 = (r + sq3r2) ** (1 / 3.0)
    r2 = (r - sq3r2) ** (1 / 3.0)
    incr = (TWO_PI * Complex::I) / 3
    # Normalized (monic) polynomial used to verify candidate roots.
    pl = poly(a0, a1, a2, 1)
    sqrt3 = sqrt(-3)
    # Try each of the 3 x 3 combinations of cube root branches.
    3.times do |i|
      3.times do |j|
        s1 = r1 * exp(i * incr)
        s2 = r2 * exp(j * incr)
        z1 = simplify_complex((s1 + s2) - (a2 / 3))
        if pl.eval(z1).abs < Poly_roots_epsilon
          z2 = simplify_complex((-0.5 * (s1 + s2)) +
                                (a2 / -3) +
                                ((s1 - s2) * 0.5 * sqrt3))
          if pl.eval(z2).abs < Poly_roots_epsilon
            z3 = simplify_complex((-0.5 * (s1 + s2)) +
                                  (a2 / -3) +
                                  ((s1 - s2) * -0.5 * sqrt3))
            if pl.eval(z3).abs < Poly_roots_epsilon
              return poly(z1, z2, z3)
            end
          end
        end
      end
    end
    false
  end

  # ax^4 + bx^3 + cx^2 + dx + e
  # Closed-form roots of a*x^4 + b*x^3 + c*x^2 + d*x + e.
  #
  # Solves an auxiliary cubic and, for each of its roots y1, builds
  # four candidate roots.  Returns a four-element poly for the first y1
  # whose first candidate evaluates to within Poly_roots_epsilon of
  # zero, or false when none does.
  def quartic_root(a, b, c, d, e)
    # Weisstein, "Encyclopedia of Mathematics"
    a0 = e / a
    a1 = d / a
    a2 = c / a
    a3 = b / a
    # The original polynomial, used to verify a candidate root.
    pl = poly(e, d, c, b, a)
    yroot = poly((4 * a2 * a0) + -(a1 * a1) + -(a3 * a3 * a0),
                 (a1 * a3) - (4 * a0),
                 -a2,
                 1).roots
    if yroot
      yroot.each do |y1|
        r = sqrt((0.25 * a3 * a3) + (-a2 + y1))
        dd = if r.zero?
               sqrt((0.75 * a3 * a3) +
                    (-2 * a2) +
                    (2 * sqrt(y1 * y1 - 4 * a0)))
             else
               sqrt((0.75 * a3 * a3) + (-2 * a2) + (-(r * r)) +
                    (0.25 * ((4 * a3 * a2) + (-8 * a1) + (-(a3 * a3 * a3)))) / r)
             end
        ee = if r.zero?
               sqrt((0.75 * a3 * a3) +
                    (-2 * a2) +
                    (-2 * sqrt((y1 * y1) - (4 * a0))))
             else
               sqrt((0.75 * a3 * a3) + (-2 * a2) + (-(r * r)) +
                    (-0.25 *
                     ((4 * a3 * a2) + (-8 * a1) + (-(a3 * a3 * a3)))) / r)
             end
        z1 = (-0.25 * a3) + ( 0.5 * r) + ( 0.5 * dd)
        z2 = (-0.25 * a3) + ( 0.5 * r) + (-0.5 * dd)
        z3 = (-0.25 * a3) + (-0.5 * r) + ( 0.5 * ee)
        z4 = (-0.25 * a3) + (-0.5 * r) + (-0.5 * ee)
        # Only z1 is verified against the polynomial.
        if pl.eval(z1).abs < Poly_roots_epsilon
          return poly(z1, z2, z3, z4)
        end
      end
    end
    false
  end
  
  # ax^n + b
  # Roots of a*x^deg + b.
  #
  # Takes one deg-th root of -b / a and multiplies it by
  # exp(i * 2*pi*I / deg) for i = 0 ... deg - 1; each product is
  # cleaned up with simplify_complex.  Results are unshifted, so the
  # returned poly lists them in reverse order of i.  DEG may be a
  # Float; it is truncated with to_i for the loop count.
  def nth_root(a, b, deg)
    n = (-b / a) ** (1.0 / deg)
    incr = (TWO_PI * Complex::I) / deg
    rts = poly()
    deg.to_i.times do |i|
      rts.unshift(simplify_complex(exp(i * incr) * n))
    end
    rts
  end

  # Threshold below which a real or imaginary part is treated as zero.
  Poly_roots_epsilon2 = 1.0e-6

  # Snaps near-zero parts of the number A to exactly zero.
  #
  # * imaginary part negligible: returns a Float, which is 0.0 when
  #   the real part is negligible too, else the real part;
  # * only the real part negligible: returns Complex(0.0, a.imag);
  # * otherwise returns A unchanged.
  def simplify_complex(a)
    if a.imag.abs < Poly_roots_epsilon2
      (a.real.abs < Poly_roots_epsilon2) ? 0.0 : a.real.to_f
    elsif a.real.abs < Poly_roots_epsilon2
      # XXX: a.real = 0.0
      #      Doesn't work any longer (see above, class Complex).
      Complex(0.0, a.imag)
    else
      a
    end
  end
end

# Extends the Float operators +, * and / so that a Float may appear on
# the left-hand side of a Poly.  The original operators stay available
# as fp_plus, fp_times and fp_div.  Each patch is installed only once
# (guarded by defined?), so loading this file again is harmless.
class Float
  unless defined? 0.0.poly_plus
    alias fp_plus +
    # Float + Poly adds the receiver to the constant term of a copy of
    # the poly, leaving the operand untouched.  Any other operand is
    # passed to the original Float#+.
    def poly_plus(other)
      case other
      when Poly
        sum = other.dup
        sum[0] += self
        sum
      else
        self.fp_plus(other)
      end
    end
    alias + poly_plus
  end

  unless defined? 0.0.poly_times
    alias fp_times *
    # Float * Poly scales every coefficient by the receiver.  Any other
    # operand is passed to the original Float#*.
    def poly_times(other)
      case other
      when Poly
        Poly(other.scale(self))
      else
        self.fp_times(other)
      end
    end
    alias * poly_times
  end

  unless defined? 0.0.poly_div
    alias fp_div /
    # Float / Poly returns the pair [poly(0.0), other].  Any other
    # operand is passed to the original Float#/.
    def poly_div(other)
      case other
      when Poly
        [poly(0.0), other]
      else
        self.fp_div(other)
      end
    end
    alias / poly_div
  end
end

class String
  # Parses a string of the form "poly(1, 2.5, -3e2)" into a poly.
  #
  # The string is evaluated with eval only when the WHOLE string
  # (anchored with \A and \z) is a poly(...) call whose argument list
  # contains nothing but digits, signs, dots, commas, exponent markers
  # and whitespace; anything else returns an empty poly.  A per-line
  # anchor (^) with no end anchor would let arbitrary code after a
  # valid prefix reach eval.
  def to_poly
    if self =~ /\A\s*poly\([-+,.\deE\s]*\)\s*\z/
      eval(self)
    else
      poly()
    end
  end
end

class Array
  # Builds a poly from the array's elements by splatting them into
  # poly().
  def to_poly
    poly(*self)
  end
end

class Vct
  # Builds a poly from the vct's elements: converts to an Array with
  # to_a and splats it into poly().
  def to_poly
    poly(*self.to_a)
  end
end

# Conversion function: returns OBJ as a poly via its to_poly method.
#
# obj:: any object responding to to_poly (checked via assert_type);
#       nil is treated as an empty array.
def Poly(obj)
  obj = [] if obj.nil?
  assert_type(obj.respond_to?(:to_poly), obj, 0,
              "an object containing method 'to_poly'")
  obj.to_poly
end

# Creates a Poly by forwarding LEN, INIT (default 0.0) and the optional
# block straight to Poly.new.
def make_poly(len, init = 0.0, &body)
  Poly.new(len, init, &body)
end

# Returns true only when OBJ is an instance of exactly Poly
# (instance_of? does not match subclasses).
def poly?(obj)
  obj.instance_of?(Poly)
end

# Builds a Poly from the given coefficients.  Integer arguments are
# converted to Float; every other value is stored unchanged.
def poly(*vals)
  Poly.new(vals.length) do |idx|
    elem = vals[idx]
    integer?(elem) ? Float(elem) : elem
  end
end

# Functional form of Poly#reduce: converts OBJ with Poly() and returns
# its reduced form.
#
# obj:: any object responding to to_poly (checked via assert_type).
def poly_reduce(obj)
  assert_type(obj.respond_to?(:to_poly), obj, 0,
              "an object containing method 'to_poly'")
  Poly(obj).reduce
end

# Functional form of poly addition.
#
# When OBJ1 is a number, OBJ2 must respond to to_poly and the result is
# Float(obj1) + Poly(obj2).  Otherwise OBJ1 must respond to to_poly and
# the result is Poly(obj1) + obj2.
def poly_add(obj1, obj2)
  if number?(obj1)
    assert_type(obj2.respond_to?(:to_poly), obj2, 1,
                "an object containing method 'to_poly'")
    Float(obj1) + Poly(obj2)
  else
    assert_type(obj1.respond_to?(:to_poly), obj1, 0,
                "an object containing method 'to_poly'")
    Poly(obj1) + obj2
  end
end

# Functional form of poly multiplication.
#
# When OBJ1 is a number, OBJ2 must respond to to_poly and the result is
# Float(obj1) * Poly(obj2).  Otherwise OBJ1 must respond to to_poly and
# the result is Poly(obj1) * obj2.
def poly_multiply(obj1, obj2)
  if number?(obj1)
    assert_type(obj2.respond_to?(:to_poly), obj2, 1,
                "an object containing method 'to_poly'")
    Float(obj1) * Poly(obj2)
  else
    assert_type(obj1.respond_to?(:to_poly), obj1, 0,
                "an object containing method 'to_poly'")
    Poly(obj1) * obj2
  end
end

# Functional form of poly division.
#
# When OBJ1 is a number, OBJ2 must respond to to_poly and the result is
# Float(obj1) / Poly(obj2).  Otherwise OBJ1 must respond to to_poly and
# the result is Poly(obj1) / obj2.
def poly_div(obj1, obj2)
  if number?(obj1)
    assert_type(obj2.respond_to?(:to_poly), obj2, 1,
                "an object containing method 'to_poly'")
    Float(obj1) / Poly(obj2)
  else
    assert_type(obj1.respond_to?(:to_poly), obj1, 0,
                "an object containing method 'to_poly'")
    Poly(obj1) / obj2
  end
end

# Functional form of Poly#derivative: converts OBJ with Poly() and
# returns its derivative.
#
# obj:: any object responding to to_poly (checked via assert_type).
def poly_derivative(obj)
  assert_type(obj.respond_to?(:to_poly), obj, 0,
              "an object containing method 'to_poly'")
  Poly(obj).derivative
end

# Functional form of Poly#gcd: converts OBJ1 with Poly() and calls gcd
# on it with OBJ2.
#
# obj1:: any object responding to to_poly (checked via assert_type).
# obj2:: passed through unchanged to gcd.
def poly_gcd(obj1, obj2)
  # Check obj1: this method has no parameter named obj.
  assert_type(obj1.respond_to?(:to_poly), obj1, 0,
              "an object containing method 'to_poly'")
  Poly(obj1).gcd(obj2)
end

# Functional form of Poly#roots: converts OBJ with Poly() and returns
# its roots.
#
# obj:: any object responding to to_poly (checked via assert_type).
def poly_roots(obj)
  assert_type(obj.respond_to?(:to_poly), obj, 0,
              "an object containing method 'to_poly'")
  Poly(obj).roots
end

# poly.rb ends here